__init__.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. from typing import Union
  2. from pathlib import Path
  3. import os
  4. import shutil
  5. import torch
  6. from .roma_models import roma_model,roma_model_pad, tiny_roma_v1_model
# Download locations for pretrained checkpoints.
#
# "romatch" and "tiny_roma_v1" map a scene type ("outdoor"/"indoor") to the URL
# of the matching checkpoint. "dinov2" is a single URL (not a per-scene mapping)
# pointing at the dinov2_vitl14 pretrain checkpoint.
weight_urls = {
    "romatch": {
        "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
        "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
    },
    "tiny_roma_v1": {
        "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/tiny_roma_v1_outdoor.pth",
    },
    "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",  # NOTE: externally hosted; downloads break if upstream moves this file
}
  17. def _project_root() -> Path:
  18. # romatch/models/model_zoo/__init__.py -> repo root is 3 levels up
  19. return Path(__file__).resolve().parents[3]
  20. def _ensure_local_torch_home_with_optional_weights() -> Path:
  21. """
  22. Make torch-hub caching point to a directory inside the current project.
  23. If the user places weight files directly under the project root (e.g. for offline use),
  24. we copy them into torch's expected hub/checkpoints directory so that subsequent
  25. torch.hub calls won't need to download the weights.
  26. """
  27. torch_home_env = os.environ.get("ROMATCH_TORCH_HOME")
  28. torch_home = Path(torch_home_env) if torch_home_env else (_project_root() / "torch_home")
  29. checkpoints_dir = torch_home / "hub" / "checkpoints"
  30. checkpoints_dir.mkdir(parents=True, exist_ok=True)
  31. # If present, copy user-provided weight files next to the project to where torch.hub expects them.
  32. local_candidates = [
  33. (_project_root() / "tiny_roma_v1_outdoor.pth", checkpoints_dir / "tiny_roma_v1_outdoor.pth"),
  34. (_project_root() / "xfeat.pt", checkpoints_dir / "xfeat.pt"),
  35. ]
  36. for src, dst in local_candidates:
  37. if dst.exists():
  38. continue
  39. if src.exists():
  40. shutil.copyfile(str(src), str(dst))
  41. # Also copy cached torch-hub repository code into the project torch_home.
  42. #
  43. # Tiny RoMa's XFeat is loaded via torch.hub.load("verlab/accelerated_features", ...).
  44. # torch.hub may still hit the network unless the repo code directory is present.
  45. default_torch_home = Path.home() / ".cache" / "torch"
  46. repo_name = "verlab_accelerated_features_main"
  47. hub_src = default_torch_home / "hub" / repo_name
  48. hub_dst = torch_home / "hub" / repo_name
  49. if hub_dst.exists() is False and hub_src.exists() is True:
  50. shutil.copytree(str(hub_src), str(hub_dst))
  51. trusted_src = default_torch_home / "hub" / "trusted_list"
  52. trusted_dst = torch_home / "hub" / "trusted_list"
  53. if trusted_dst.exists() is False and trusted_src.exists() is True:
  54. try:
  55. shutil.copyfile(str(trusted_src), str(trusted_dst))
  56. except OSError:
  57. # Not critical; best-effort.
  58. pass
  59. # Redirect torch-hub cache to this folder.
  60. os.environ.setdefault("TORCH_HOME", str(torch_home))
  61. return torch_home
  62. def tiny_roma_v1_outdoor(device, weights=None, xfeat=None):
  63. _ensure_local_torch_home_with_optional_weights()
  64. if weights is None:
  65. weights = torch.hub.load_state_dict_from_url(
  66. weight_urls["tiny_roma_v1"]["outdoor"], map_location=device
  67. )
  68. if xfeat is None:
  69. # Important for offline use:
  70. # torch.hub.load("verlab/accelerated_features", ...) may still query GitHub
  71. # even if repository code is cached. If we have a local cached repo directory,
  72. # load from that directory directly.
  73. torch_home = Path(os.environ.get("TORCH_HOME", _project_root() / "torch_home"))
  74. hub_repo_dir = torch_home / "hub" / "verlab_accelerated_features_main"
  75. if hub_repo_dir.exists():
  76. xfeat = torch.hub.load(
  77. str(hub_repo_dir),
  78. "XFeat",
  79. source="local",
  80. pretrained=True,
  81. top_k=4096,
  82. ).net
  83. else:
  84. xfeat = torch.hub.load(
  85. "verlab/accelerated_features",
  86. "XFeat",
  87. pretrained=True,
  88. top_k=4096,
  89. ).net
  90. return tiny_roma_v1_model(weights=weights, xfeat=xfeat).to(device)
  91. def roma_outdoor(
  92. device,
  93. weights=None,
  94. dinov2_weights=None,
  95. coarse_res: Union[int, tuple[int, int]] = 560,
  96. upsample_res: Union[int, tuple[int, int]] = 864,
  97. amp_dtype: torch.dtype = torch.float16,
  98. symmetric=True,
  99. use_custom_corr=True,
  100. upsample_preds=True,
  101. with_padding=False,
  102. do_compile=False,
  103. ):
  104. if torch.get_float32_matmul_precision() != "highest":
  105. raise RuntimeError("Float32 matmul precision must be set to highest for RoMa. See also https://github.com/Parskatt/RoMaV2/issues/35")
  106. if weights is None:
  107. weights = torch.hub.load_state_dict_from_url(
  108. weight_urls["romatch"]["outdoor"], map_location=device
  109. )
  110. if dinov2_weights is None:
  111. dinov2_weights = torch.hub.load_state_dict_from_url(
  112. weight_urls["dinov2"], map_location=device
  113. )
  114. model_init = roma_model if not with_padding else roma_model_pad
  115. model = model_init(
  116. resolution=coarse_res,
  117. upsample_preds=upsample_preds,
  118. weights=weights,
  119. dinov2_weights=dinov2_weights,
  120. device=device,
  121. amp_dtype=amp_dtype,
  122. symmetric=symmetric,
  123. use_custom_corr=use_custom_corr,
  124. upsample_res=upsample_res,
  125. )
  126. if do_compile:
  127. model.compile()
  128. return model
  129. def roma_indoor(
  130. device,
  131. weights=None,
  132. dinov2_weights=None,
  133. coarse_res: Union[int, tuple[int, int]] = 560,
  134. upsample_res: Union[int, tuple[int, int]] = 864,
  135. amp_dtype: torch.dtype = torch.float16,
  136. symmetric=True,
  137. use_custom_corr=True,
  138. upsample_preds=True,
  139. with_padding=False,
  140. do_compile=False,
  141. ):
  142. if torch.get_float32_matmul_precision() != "highest":
  143. raise RuntimeError("Float32 matmul precision must be set to highest for RoMa. See also https://github.com/Parskatt/RoMaV2/issues/35")
  144. if weights is None:
  145. weights = torch.hub.load_state_dict_from_url(
  146. weight_urls["romatch"]["indoor"], map_location=device
  147. )
  148. if dinov2_weights is None:
  149. dinov2_weights = torch.hub.load_state_dict_from_url(
  150. weight_urls["dinov2"], map_location=device
  151. )
  152. model_init = roma_model if not with_padding else roma_model_pad
  153. model = model_init(
  154. resolution=coarse_res,
  155. upsample_preds=upsample_preds,
  156. weights=weights,
  157. dinov2_weights=dinov2_weights,
  158. device=device,
  159. amp_dtype=amp_dtype,
  160. symmetric=symmetric,
  161. use_custom_corr=use_custom_corr,
  162. upsample_res=upsample_res,
  163. )
  164. if do_compile:
  165. model.compile()
  166. return model