__init__.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from typing import Union
  2. import torch
  3. from .roma_models import roma_model, tiny_roma_v1_model
  4. weight_urls = {
  5. "romatch": {
  6. "outdoor": "https://github.com/Parskatt/storage/releases/download/romatch/roma_outdoor.pth",
  7. "indoor": "https://github.com/Parskatt/storage/releases/download/romatch/roma_indoor.pth",
  8. },
  9. "tiny_roma_v1": {
  10. "outdoor": "https://github.com/Parskatt/storage/releases/download/romatch/tiny_roma_v1_outdoor.pth",
  11. },
  12. "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D
  13. }
  14. def tiny_roma_v1_outdoor(device, weights = None, xfeat = None):
  15. if weights is None:
  16. weights = torch.hub.load_state_dict_from_url(
  17. weight_urls["tiny_roma_v1"]["outdoor"],
  18. map_location=device)
  19. if xfeat is None:
  20. xfeat = torch.hub.load(
  21. 'verlab/accelerated_features',
  22. 'XFeat',
  23. pretrained = True,
  24. top_k = 4096).net
  25. return tiny_roma_v1_model(weights = weights, xfeat = xfeat).to(device)
  26. def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
  27. if isinstance(coarse_res, int):
  28. coarse_res = (coarse_res, coarse_res)
  29. if isinstance(upsample_res, int):
  30. upsample_res = (upsample_res, upsample_res)
  31. assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
  32. assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
  33. if weights is None:
  34. weights = torch.hub.load_state_dict_from_url(weight_urls["romatch"]["outdoor"],
  35. map_location=device)
  36. if dinov2_weights is None:
  37. dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
  38. map_location=device)
  39. model = roma_model(resolution=coarse_res, upsample_preds=True,
  40. weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
  41. model.upsample_res = upsample_res
  42. print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
  43. return model
  44. def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
  45. if isinstance(coarse_res, int):
  46. coarse_res = (coarse_res, coarse_res)
  47. if isinstance(upsample_res, int):
  48. upsample_res = (upsample_res, upsample_res)
  49. assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
  50. assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
  51. if weights is None:
  52. weights = torch.hub.load_state_dict_from_url(weight_urls["romatch"]["indoor"],
  53. map_location=device)
  54. if dinov2_weights is None:
  55. dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
  56. map_location=device)
  57. model = roma_model(resolution=coarse_res, upsample_preds=True,
  58. weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
  59. model.upsample_res = upsample_res
  60. print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
  61. return model