roma_models.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. from functools import partial
  2. import warnings
  3. import torch.nn as nn
  4. import torch
  5. from romatch.models.matcher import (
  6. ConvRefiner,
  7. CosKernel,
  8. GP,
  9. Decoder,
  10. RegressionMatcher,
  11. )
  12. from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
  13. from romatch.models.encoders import CNNandDinov2
  14. from romatch.models.tiny import TinyRoMa
  15. def tiny_roma_v1_model(
  16. weights=None, freeze_xfeat=False, exact_softmax=False, xfeat=None
  17. ):
  18. model = TinyRoMa(
  19. xfeat=xfeat, freeze_xfeat=freeze_xfeat, exact_softmax=exact_softmax
  20. )
  21. if weights is not None:
  22. model.load_state_dict(weights)
  23. return model
  24. def roma_model(
  25. resolution,
  26. upsample_preds,
  27. device=None,
  28. weights=None,
  29. dinov2_weights=None,
  30. amp_dtype: torch.dtype = torch.float16,
  31. use_custom_corr=False,
  32. symmetric=True,
  33. **kwargs,
  34. ):
  35. warnings.filterwarnings(
  36. "ignore", category=UserWarning, message="TypedStorage is deprecated"
  37. )
  38. gp_dim = 512
  39. feat_dim = 512
  40. decoder_dim = gp_dim + feat_dim
  41. cls_to_coord_res = 64
  42. coordinate_decoder = TransformerDecoder(
  43. nn.Sequential(
  44. *[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]
  45. ),
  46. decoder_dim,
  47. cls_to_coord_res**2 + 1,
  48. is_classifier=True,
  49. amp=True,
  50. pos_enc=False,
  51. )
  52. dw = True
  53. hidden_blocks = 8
  54. kernel_size = 5
  55. displacement_emb = "linear"
  56. disable_local_corr_grad = True
  57. partial_conv_refiner = partial(
  58. ConvRefiner,
  59. kernel_size=kernel_size,
  60. dw=dw,
  61. hidden_blocks=hidden_blocks,
  62. displacement_emb=displacement_emb,
  63. corr_in_other=True,
  64. amp=True,
  65. disable_local_corr_grad=disable_local_corr_grad,
  66. bn_momentum=0.01,
  67. use_custom_corr=use_custom_corr,
  68. )
  69. conv_refiner = nn.ModuleDict(
  70. {
  71. "16": partial_conv_refiner(
  72. 2 * 512 + 128 + (2 * 7 + 1) ** 2,
  73. 2 * 512 + 128 + (2 * 7 + 1) ** 2,
  74. 2 + 1,
  75. displacement_emb_dim=128,
  76. local_corr_radius=7,
  77. ),
  78. "8": partial_conv_refiner(
  79. 2 * 512 + 64 + (2 * 3 + 1) ** 2,
  80. 2 * 512 + 64 + (2 * 3 + 1) ** 2,
  81. 2 + 1,
  82. displacement_emb_dim=64,
  83. local_corr_radius=3,
  84. ),
  85. "4": partial_conv_refiner(
  86. 2 * 256 + 32 + (2 * 2 + 1) ** 2,
  87. 2 * 256 + 32 + (2 * 2 + 1) ** 2,
  88. 2 + 1,
  89. displacement_emb_dim=32,
  90. local_corr_radius=2,
  91. ),
  92. "2": partial_conv_refiner(
  93. 2 * 64 + 16,
  94. 128 + 16,
  95. 2 + 1,
  96. displacement_emb_dim=16,
  97. ),
  98. "1": partial_conv_refiner(
  99. 2 * 9 + 6,
  100. 24,
  101. 2 + 1,
  102. displacement_emb_dim=6,
  103. ),
  104. }
  105. )
  106. kernel_temperature = 0.2
  107. learn_temperature = False
  108. no_cov = True
  109. kernel = CosKernel
  110. only_attention = False
  111. basis = "fourier"
  112. gp16 = GP(
  113. kernel,
  114. T=kernel_temperature,
  115. learn_temperature=learn_temperature,
  116. only_attention=only_attention,
  117. gp_dim=gp_dim,
  118. basis=basis,
  119. no_cov=no_cov,
  120. )
  121. gps = nn.ModuleDict({"16": gp16})
  122. proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
  123. proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
  124. proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
  125. proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
  126. proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
  127. proj = nn.ModuleDict(
  128. {
  129. "16": proj16,
  130. "8": proj8,
  131. "4": proj4,
  132. "2": proj2,
  133. "1": proj1,
  134. }
  135. )
  136. displacement_dropout_p = 0.0
  137. gm_warp_dropout_p = 0.0
  138. decoder = Decoder(
  139. coordinate_decoder,
  140. gps,
  141. proj,
  142. conv_refiner,
  143. detach=True,
  144. scales=["16", "8", "4", "2", "1"],
  145. displacement_dropout_p=displacement_dropout_p,
  146. gm_warp_dropout_p=gm_warp_dropout_p,
  147. )
  148. encoder = CNNandDinov2(
  149. cnn_kwargs=dict(pretrained=False, amp=True),
  150. amp=True,
  151. dinov2_weights=dinov2_weights,
  152. amp_dtype=amp_dtype,
  153. )
  154. h, w = resolution
  155. attenuate_cert = True
  156. sample_mode = "threshold_balanced"
  157. matcher = RegressionMatcher(
  158. encoder,
  159. decoder,
  160. h=h,
  161. w=w,
  162. upsample_preds=upsample_preds,
  163. symmetric=symmetric,
  164. attenuate_cert=attenuate_cert,
  165. sample_mode=sample_mode,
  166. **kwargs,
  167. ).to(device)
  168. matcher.load_state_dict(weights)
  169. return matcher