| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185 |
- from typing import Union
- from pathlib import Path
- import os
- import shutil
- import torch
- from .roma_models import roma_model,roma_model_pad, tiny_roma_v1_model
# Download locations of the pretrained checkpoints used by the builders below,
# keyed by model family and, where a family has several checkpoints, by the
# scene type the checkpoint targets.
weight_urls = dict(
    romatch=dict(
        outdoor="https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
        indoor="https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
    ),
    tiny_roma_v1=dict(
        outdoor="https://github.com/Parskatt/storage/releases/download/roma/tiny_roma_v1_outdoor.pth",
    ),
    # DINOv2 ViT-L/14 backbone, hosted on dl.fbaipublicfiles.com (outside this
    # project's control, so the URL could change without notice).
    dinov2="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
)
def _project_root() -> Path:
    """Return the repository root directory.

    This module lives at ``romatch/models/model_zoo/__init__.py``, so the
    repository root is the fourth ancestor of the resolved file path
    (model_zoo -> models -> romatch -> root).
    """
    ancestors = Path(__file__).resolve().parents
    return ancestors[3]
def _copy_file_atomically(src: Path, dst: Path) -> None:
    """Copy ``src`` to ``dst`` via a temporary sibling so ``dst`` is never half-written."""
    tmp = dst.with_name(dst.name + ".partial")
    try:
        shutil.copyfile(src, tmp)
        # os.replace is atomic within one directory: dst either appears complete
        # or not at all, so an interrupted copy can't leave a truncated checkpoint
        # that torch.hub would later pick up as "already cached".
        os.replace(tmp, dst)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise


def _ensure_local_torch_home_with_optional_weights() -> Path:
    """Point torch.hub's cache at a project-local directory and seed it for offline use.

    Choosing the torch home:

    * If ``TORCH_HOME`` is already set (non-empty), it is left untouched, because
      that is the directory torch.hub actually reads from. In that case
      ``ROMATCH_TORCH_HOME`` is not used.
    * Otherwise ``TORCH_HOME`` is set to ``ROMATCH_TORCH_HOME`` (with ``~``
      expanded) if that is set, else to ``<project root>/torch_home``.

    Any of the following that are missing from torch.hub's cache directory
    (``torch.hub.get_dir()``) are then copied in:

    * weight files placed directly under the project root, for example for
      offline use (``tiny_roma_v1_outdoor.pth``, ``xfeat.pt``). These go into
      ``<hub dir>/checkpoints/``.
    * the cached ``verlab_accelerated_features_main`` repository code and the
      ``trusted_list`` file from torch's default cache
      (``$XDG_CACHE_HOME/torch/hub``, falling back to ``~/.cache/torch/hub``).

    Returns:
        The effective torch home directory (the value of ``TORCH_HOME``, with
        ``~`` expanded).

    Raises:
        OSError: If creating the checkpoint directory, copying a weight file, or
            copying the repository code fails. Partial copies are removed before
            the error propagates. Failure to copy ``trusted_list`` is ignored.
    """
    if not os.environ.get("TORCH_HOME"):
        romatch_home = os.environ.get("ROMATCH_TORCH_HOME")
        if romatch_home:
            # torch expands "~" in TORCH_HOME, so expand here too; otherwise the
            # copies below and torch.hub would disagree about the directory.
            target = Path(romatch_home).expanduser()
        else:
            target = _project_root() / "torch_home"
        os.environ["TORCH_HOME"] = str(target)
    torch_home = Path(os.environ["TORCH_HOME"]).expanduser()

    # Ask torch where it will look. This honors TORCH_HOME, "~" expansion and
    # torch.hub.set_dir(), so the copies below land exactly where
    # load_state_dict_from_url / torch.hub.load read from.
    hub_dir = Path(torch.hub.get_dir())
    checkpoints_dir = hub_dir / "checkpoints"
    checkpoints_dir.mkdir(parents=True, exist_ok=True)

    # Copy user-provided weight files from the project root into the checkpoint cache.
    project_root = _project_root()
    for filename in ("tiny_roma_v1_outdoor.pth", "xfeat.pt"):
        src = project_root / filename
        dst = checkpoints_dir / filename
        if not dst.exists() and src.exists():
            _copy_file_atomically(src, dst)

    # Also copy cached torch-hub repository code.
    #
    # Tiny RoMa's XFeat is loaded via torch.hub.load("verlab/accelerated_features", ...).
    # torch.hub may still hit the network unless the repo code directory is present.
    default_hub_dir = (
        Path(os.environ.get("XDG_CACHE_HOME") or "~/.cache").expanduser() / "torch" / "hub"
    )
    repo_name = "verlab_accelerated_features_main"
    repo_src = default_hub_dir / repo_name
    repo_dst = hub_dir / repo_name
    if not repo_dst.exists() and repo_src.exists():
        try:
            shutil.copytree(repo_src, repo_dst)
        except OSError:
            # A half-copied repo would pass the exists() check on the next run
            # and be loaded as if complete, so remove it before propagating.
            shutil.rmtree(repo_dst, ignore_errors=True)
            raise

    trusted_src = default_hub_dir / "trusted_list"
    trusted_dst = hub_dir / "trusted_list"
    if not trusted_dst.exists() and trusted_src.exists():
        try:
            shutil.copyfile(trusted_src, trusted_dst)
        except OSError:
            # Not critical; best-effort.
            pass

    return torch_home
def tiny_roma_v1_outdoor(device, weights=None, xfeat=None):
    """Build the Tiny RoMa v1 (outdoor) matcher on ``device``.

    Args:
        device: Passed as ``map_location`` when downloading ``weights``; the
            built model is moved to it with ``.to(device)``.
        weights: Optional Tiny RoMa state dict. When ``None``, it is fetched with
            ``torch.hub.load_state_dict_from_url`` (served from torch.hub's
            checkpoint cache if already present there).
        xfeat: Optional XFeat network. When ``None``, it is loaded through
            torch.hub (``"XFeat"`` entry point, ``pretrained=True``,
            ``top_k=4096``) and its ``.net`` attribute is used.

    Returns:
        The model built by ``tiny_roma_v1_model``, moved to ``device``.
    """
    _ensure_local_torch_home_with_optional_weights()
    if weights is None:
        weights = torch.hub.load_state_dict_from_url(
            weight_urls["tiny_roma_v1"]["outdoor"], map_location=device
        )
    if xfeat is None:
        # Important for offline use:
        # torch.hub.load("verlab/accelerated_features", ...) may still query GitHub
        # even if the repository code is cached. If a cached repo directory exists,
        # load from it directly. torch.hub.get_dir() is the directory torch.hub
        # itself caches repos in (it honors TORCH_HOME, "~" and set_dir()).
        hub_repo_dir = Path(torch.hub.get_dir()) / "verlab_accelerated_features_main"
        if hub_repo_dir.exists():
            repo, source = str(hub_repo_dir), "local"
        else:
            repo, source = "verlab/accelerated_features", "github"
        xfeat = torch.hub.load(
            repo,
            "XFeat",
            source=source,
            pretrained=True,
            top_k=4096,
        ).net
    return tiny_roma_v1_model(weights=weights, xfeat=xfeat).to(device)
def _build_roma(
    variant: str,
    *,
    device,
    weights,
    dinov2_weights,
    coarse_res: Union[int, tuple[int, int]],
    upsample_res: Union[int, tuple[int, int]],
    amp_dtype: torch.dtype,
    symmetric: bool,
    use_custom_corr: bool,
    upsample_preds: bool,
    with_padding: bool,
    do_compile: bool,
):
    """Shared implementation of ``roma_outdoor`` and ``roma_indoor``.

    ``variant`` is the key of the RoMa checkpoint in ``weight_urls["romatch"]``
    (``"outdoor"`` or ``"indoor"``). All other arguments have the meaning given
    in the ``roma_outdoor`` docstring.
    """
    # Checked before any download so a misconfigured process fails fast.
    if torch.get_float32_matmul_precision() != "highest":
        raise RuntimeError("Float32 matmul precision must be set to highest for RoMa. See also https://github.com/Parskatt/RoMaV2/issues/35")
    if weights is None:
        weights = torch.hub.load_state_dict_from_url(
            weight_urls["romatch"][variant], map_location=device
        )
    if dinov2_weights is None:
        dinov2_weights = torch.hub.load_state_dict_from_url(
            weight_urls["dinov2"], map_location=device
        )
    model_init = roma_model_pad if with_padding else roma_model
    model = model_init(
        resolution=coarse_res,
        upsample_preds=upsample_preds,
        weights=weights,
        dinov2_weights=dinov2_weights,
        device=device,
        amp_dtype=amp_dtype,
        symmetric=symmetric,
        use_custom_corr=use_custom_corr,
        upsample_res=upsample_res,
    )
    if do_compile:
        model.compile()
    return model


def roma_outdoor(
    device,
    weights=None,
    dinov2_weights=None,
    coarse_res: Union[int, tuple[int, int]] = 560,
    upsample_res: Union[int, tuple[int, int]] = 864,
    amp_dtype: torch.dtype = torch.float16,
    symmetric=True,
    use_custom_corr=True,
    upsample_preds=True,
    with_padding=False,
    do_compile=False,
):
    """Build RoMa with the outdoor checkpoint.

    Args:
        device: Used as ``map_location`` for downloaded weights and passed to
            the model constructor.
        weights: Optional RoMa state dict; downloaded from
            ``weight_urls["romatch"]["outdoor"]`` when ``None``.
        dinov2_weights: Optional DINOv2 state dict; downloaded from
            ``weight_urls["dinov2"]`` when ``None``.
        coarse_res: Forwarded to the model constructor as ``resolution``.
        upsample_res: Forwarded to the model constructor.
        amp_dtype: Forwarded to the model constructor.
        symmetric: Forwarded to the model constructor.
        use_custom_corr: Forwarded to the model constructor.
        upsample_preds: Forwarded to the model constructor.
        with_padding: Use ``roma_model_pad`` instead of ``roma_model``.
        do_compile: Call ``model.compile()`` on the built model.

    Returns:
        The constructed model.

    Raises:
        RuntimeError: If ``torch.get_float32_matmul_precision()`` is not
            ``"highest"``.
    """
    return _build_roma(
        "outdoor",
        device=device,
        weights=weights,
        dinov2_weights=dinov2_weights,
        coarse_res=coarse_res,
        upsample_res=upsample_res,
        amp_dtype=amp_dtype,
        symmetric=symmetric,
        use_custom_corr=use_custom_corr,
        upsample_preds=upsample_preds,
        with_padding=with_padding,
        do_compile=do_compile,
    )


def roma_indoor(
    device,
    weights=None,
    dinov2_weights=None,
    coarse_res: Union[int, tuple[int, int]] = 560,
    upsample_res: Union[int, tuple[int, int]] = 864,
    amp_dtype: torch.dtype = torch.float16,
    symmetric=True,
    use_custom_corr=True,
    upsample_preds=True,
    with_padding=False,
    do_compile=False,
):
    """Build RoMa with the indoor checkpoint.

    Takes the same arguments, returns the same kind of model, and raises the
    same error as ``roma_outdoor``. The only difference is that ``weights``
    defaults to ``weight_urls["romatch"]["indoor"]``.
    """
    return _build_roma(
        "indoor",
        device=device,
        weights=weights,
        dinov2_weights=dinov2_weights,
        coarse_res=coarse_res,
        upsample_res=upsample_res,
        amp_dtype=amp_dtype,
        symmetric=symmetric,
        use_custom_corr=use_custom_corr,
        upsample_preds=upsample_preds,
        with_padding=with_padding,
        do_compile=do_compile,
    )
|