__init__.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from typing import Union
  2. import torch
  3. from .roma_models import roma_model,roma_model_pad, tiny_roma_v1_model
  4. weight_urls = {
  5. "romatch": {
  6. "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
  7. "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
  8. },
  9. "tiny_roma_v1": {
  10. "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/tiny_roma_v1_outdoor.pth",
  11. },
  12. "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", # hopefully this doesnt change :D
  13. }
  14. def tiny_roma_v1_outdoor(device, weights=None, xfeat=None):
  15. if weights is None:
  16. weights = torch.hub.load_state_dict_from_url(
  17. weight_urls["tiny_roma_v1"]["outdoor"], map_location=device
  18. )
  19. if xfeat is None:
  20. xfeat = torch.hub.load(
  21. "verlab/accelerated_features", "XFeat", pretrained=True, top_k=4096
  22. ).net
  23. return tiny_roma_v1_model(weights=weights, xfeat=xfeat).to(device)
  24. def roma_outdoor(
  25. device,
  26. weights=None,
  27. dinov2_weights=None,
  28. coarse_res: Union[int, tuple[int, int]] = 560,
  29. upsample_res: Union[int, tuple[int, int]] = 864,
  30. amp_dtype: torch.dtype = torch.float16,
  31. symmetric=True,
  32. use_custom_corr=True,
  33. upsample_preds=True,
  34. with_padding=False,
  35. do_compile=False,
  36. ):
  37. if torch.get_float32_matmul_precision() != "highest":
  38. raise RuntimeError("Float32 matmul precision must be set to highest for RoMa. See also https://github.com/Parskatt/RoMaV2/issues/35")
  39. if weights is None:
  40. weights = torch.hub.load_state_dict_from_url(
  41. weight_urls["romatch"]["outdoor"], map_location=device
  42. )
  43. if dinov2_weights is None:
  44. dinov2_weights = torch.hub.load_state_dict_from_url(
  45. weight_urls["dinov2"], map_location=device
  46. )
  47. model_init = roma_model if not with_padding else roma_model_pad
  48. model = model_init(
  49. resolution=coarse_res,
  50. upsample_preds=upsample_preds,
  51. weights=weights,
  52. dinov2_weights=dinov2_weights,
  53. device=device,
  54. amp_dtype=amp_dtype,
  55. symmetric=symmetric,
  56. use_custom_corr=use_custom_corr,
  57. upsample_res=upsample_res,
  58. )
  59. if do_compile:
  60. model.compile()
  61. return model
  62. def roma_indoor(
  63. device,
  64. weights=None,
  65. dinov2_weights=None,
  66. coarse_res: Union[int, tuple[int, int]] = 560,
  67. upsample_res: Union[int, tuple[int, int]] = 864,
  68. amp_dtype: torch.dtype = torch.float16,
  69. symmetric=True,
  70. use_custom_corr=True,
  71. upsample_preds=True,
  72. with_padding=False,
  73. do_compile=False,
  74. ):
  75. if torch.get_float32_matmul_precision() != "highest":
  76. raise RuntimeError("Float32 matmul precision must be set to highest for RoMa. See also https://github.com/Parskatt/RoMaV2/issues/35")
  77. if weights is None:
  78. weights = torch.hub.load_state_dict_from_url(
  79. weight_urls["romatch"]["indoor"], map_location=device
  80. )
  81. if dinov2_weights is None:
  82. dinov2_weights = torch.hub.load_state_dict_from_url(
  83. weight_urls["dinov2"], map_location=device
  84. )
  85. model_init = roma_model if not with_padding else roma_model_pad
  86. model = model_init(
  87. resolution=coarse_res,
  88. upsample_preds=upsample_preds,
  89. weights=weights,
  90. dinov2_weights=dinov2_weights,
  91. device=device,
  92. amp_dtype=amp_dtype,
  93. symmetric=symmetric,
  94. use_custom_corr=use_custom_corr,
  95. upsample_res=upsample_res,
  96. )
  97. if do_compile:
  98. model.compile()
  99. return model