roma_models.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. import sys
  2. import warnings
  3. from functools import partial
  4. import torch
  5. import torch.nn as nn
  6. from loguru import logger
  7. from romatch.models.encoders import CNNandDinov2
  8. from romatch.models.matcher import (
  9. GP,
  10. ConvRefiner,
  11. CosKernel,
  12. Decoder,
  13. RegressionMatcher,
  14. )
  15. from romatch.models.tiny import TinyRoMa
  16. from romatch.models.transformer import Block, MemEffAttention, TransformerDecoder
  17. def tiny_roma_v1_model(
  18. weights=None, freeze_xfeat=False, exact_softmax=False, xfeat=None
  19. ):
  20. model = TinyRoMa(
  21. xfeat=xfeat, freeze_xfeat=freeze_xfeat, exact_softmax=exact_softmax
  22. )
  23. if weights is not None:
  24. model.load_state_dict(weights)
  25. return model
  26. def pad_refiner_state_dict(state_dict_old,state_dict_pad):
  27. for key in state_dict_pad.keys():
  28. if key.startswith('decoder.conv_refiner'):
  29. param = state_dict_old[key]
  30. shape_old = param.shape
  31. shape_pad = state_dict_pad[key].shape
  32. if shape_old != shape_pad:
  33. new_param = torch.zeros(shape_pad, device=param.device, dtype=param.dtype)
  34. slices = tuple(slice(0, s) for s in shape_old)
  35. new_param[slices] = param
  36. state_dict_old[key] = new_param
  37. return state_dict_old
  38. def roma_model_pad(
  39. resolution,
  40. upsample_preds,
  41. device=None,
  42. weights=None,
  43. dinov2_weights=None,
  44. amp_dtype: torch.dtype = torch.float16,
  45. use_custom_corr=True,
  46. symmetric=True,
  47. upsample_res=None,
  48. sample_thresh=0.05,
  49. sample_mode="threshold_balanced",
  50. attenuate_cert = True,
  51. refiner_channels= [1384, 1144, 576, 144, 24],
  52. **kwargs,
  53. ):
  54. if sys.platform != "linux":
  55. use_custom_corr = False
  56. warnings.warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
  57. if isinstance(resolution, int):
  58. resolution = (resolution, resolution)
  59. if isinstance(upsample_res, int):
  60. upsample_res = (upsample_res, upsample_res)
  61. if str(device) == "cpu":
  62. amp_dtype = torch.float32
  63. assert resolution[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
  64. assert resolution[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
  65. logger.info(
  66. f"Using coarse resolution {resolution}, and upsample res {upsample_res}"
  67. )
  68. if sys.platform != "linux":
  69. use_custom_corr = False
  70. warnings.warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
  71. warnings.filterwarnings(
  72. "ignore", category=UserWarning, message="TypedStorage is deprecated"
  73. )
  74. gp_dim = 512
  75. feat_dim = 512
  76. decoder_dim = gp_dim + feat_dim
  77. cls_to_coord_res = 64
  78. coordinate_decoder = TransformerDecoder(
  79. nn.Sequential(
  80. *[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]
  81. ),
  82. decoder_dim,
  83. cls_to_coord_res**2 + 1,
  84. is_classifier=True,
  85. amp=True,
  86. pos_enc=False,
  87. )
  88. dw = True
  89. hidden_blocks = 8
  90. kernel_size = 5
  91. displacement_emb = "linear"
  92. disable_local_corr_grad = True
  93. partial_conv_refiner = partial(
  94. ConvRefiner,
  95. kernel_size=kernel_size,
  96. dw=dw,
  97. hidden_blocks=hidden_blocks,
  98. displacement_emb=displacement_emb,
  99. corr_in_other=True,
  100. amp=True,
  101. disable_local_corr_grad=disable_local_corr_grad,
  102. bn_momentum=0.01,
  103. use_custom_corr=use_custom_corr,
  104. )
  105. conv_refiner = nn.ModuleDict(
  106. {
  107. "16": partial_conv_refiner(
  108. refiner_channels[0],
  109. refiner_channels[0],
  110. 2 + 1,
  111. displacement_emb_dim=128,
  112. local_corr_radius=7,
  113. ),
  114. "8": partial_conv_refiner(
  115. refiner_channels[1],
  116. refiner_channels[1],
  117. 2 + 1,
  118. displacement_emb_dim=64,
  119. local_corr_radius=3,
  120. ),
  121. "4": partial_conv_refiner(
  122. refiner_channels[2],
  123. refiner_channels[2],
  124. 2 + 1,
  125. displacement_emb_dim=32,
  126. local_corr_radius=2,
  127. ),
  128. "2": partial_conv_refiner(
  129. refiner_channels[3],
  130. refiner_channels[3],
  131. 2 + 1,
  132. displacement_emb_dim=16,
  133. ),
  134. "1": partial_conv_refiner(
  135. refiner_channels[4],
  136. refiner_channels[4],
  137. 2 + 1,
  138. displacement_emb_dim=6,
  139. ),
  140. }
  141. )
  142. kernel_temperature = 0.2
  143. learn_temperature = False
  144. no_cov = True
  145. kernel = CosKernel
  146. only_attention = False
  147. basis = "fourier"
  148. gp16 = GP(
  149. kernel,
  150. T=kernel_temperature,
  151. learn_temperature=learn_temperature,
  152. only_attention=only_attention,
  153. gp_dim=gp_dim,
  154. basis=basis,
  155. no_cov=no_cov,
  156. )
  157. gps = nn.ModuleDict({"16": gp16})
  158. proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
  159. proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
  160. proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
  161. proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
  162. proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
  163. proj = nn.ModuleDict(
  164. {
  165. "16": proj16,
  166. "8": proj8,
  167. "4": proj4,
  168. "2": proj2,
  169. "1": proj1,
  170. }
  171. )
  172. displacement_dropout_p = 0.0
  173. gm_warp_dropout_p = 0.0
  174. decoder = Decoder(
  175. coordinate_decoder,
  176. gps,
  177. proj,
  178. conv_refiner,
  179. detach=True,
  180. scales=["16", "8", "4", "2", "1"],
  181. displacement_dropout_p=displacement_dropout_p,
  182. gm_warp_dropout_p=gm_warp_dropout_p,
  183. )
  184. encoder = CNNandDinov2(
  185. cnn_kwargs=dict(pretrained=False, amp=True),
  186. amp=True,
  187. dinov2_weights=dinov2_weights,
  188. amp_dtype=amp_dtype,
  189. )
  190. h, w = resolution
  191. matcher = RegressionMatcher(
  192. encoder,
  193. decoder,
  194. h=h,
  195. w=w,
  196. upsample_preds=upsample_preds,
  197. upsample_res=upsample_res,
  198. symmetric=symmetric,
  199. attenuate_cert=attenuate_cert,
  200. sample_mode=sample_mode,
  201. sample_thresh=sample_thresh,
  202. **kwargs,
  203. ).to(device)
  204. if weights is not None:
  205. state_dict_pad = matcher.state_dict()
  206. weights = pad_refiner_state_dict(weights,state_dict_pad)
  207. del state_dict_pad
  208. matcher.load_state_dict(weights)
  209. return matcher
  210. def roma_model(
  211. resolution,
  212. upsample_preds,
  213. device=None,
  214. weights=None,
  215. dinov2_weights=None,
  216. amp_dtype: torch.dtype = torch.float16,
  217. use_custom_corr=True,
  218. symmetric=True,
  219. upsample_res=None,
  220. sample_thresh=0.05,
  221. sample_mode="threshold_balanced",
  222. attenuate_cert = True,
  223. **kwargs,
  224. ):
  225. if sys.platform != "linux":
  226. use_custom_corr = False
  227. warnings.warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
  228. if isinstance(resolution, int):
  229. resolution = (resolution, resolution)
  230. if isinstance(upsample_res, int):
  231. upsample_res = (upsample_res, upsample_res)
  232. if str(device) == "cpu":
  233. amp_dtype = torch.float32
  234. assert resolution[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
  235. assert resolution[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
  236. logger.info(
  237. f"Using coarse resolution {resolution}, and upsample res {upsample_res}"
  238. )
  239. if sys.platform != "linux":
  240. use_custom_corr = False
  241. warnings.warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
  242. warnings.filterwarnings(
  243. "ignore", category=UserWarning, message="TypedStorage is deprecated"
  244. )
  245. gp_dim = 512
  246. feat_dim = 512
  247. decoder_dim = gp_dim + feat_dim
  248. cls_to_coord_res = 64
  249. coordinate_decoder = TransformerDecoder(
  250. nn.Sequential(
  251. *[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]
  252. ),
  253. decoder_dim,
  254. cls_to_coord_res**2 + 1,
  255. is_classifier=True,
  256. amp=True,
  257. pos_enc=False,
  258. )
  259. dw = True
  260. hidden_blocks = 8
  261. kernel_size = 5
  262. displacement_emb = "linear"
  263. disable_local_corr_grad = True
  264. partial_conv_refiner = partial(
  265. ConvRefiner,
  266. kernel_size=kernel_size,
  267. dw=dw,
  268. hidden_blocks=hidden_blocks,
  269. displacement_emb=displacement_emb,
  270. corr_in_other=True,
  271. amp=True,
  272. disable_local_corr_grad=disable_local_corr_grad,
  273. bn_momentum=0.01,
  274. use_custom_corr=use_custom_corr,
  275. )
  276. conv_refiner = nn.ModuleDict(
  277. {
  278. "16": partial_conv_refiner(
  279. 2 * 512 + 128 + (2 * 7 + 1) ** 2,
  280. 2 * 512 + 128 + (2 * 7 + 1) ** 2,
  281. 2 + 1,
  282. displacement_emb_dim=128,
  283. local_corr_radius=7,
  284. ),
  285. "8": partial_conv_refiner(
  286. 2 * 512 + 64 + (2 * 3 + 1) ** 2,
  287. 2 * 512 + 64 + (2 * 3 + 1) ** 2,
  288. 2 + 1,
  289. displacement_emb_dim=64,
  290. local_corr_radius=3,
  291. ),
  292. "4": partial_conv_refiner(
  293. 2 * 256 + 32 + (2 * 2 + 1) ** 2,
  294. 2 * 256 + 32 + (2 * 2 + 1) ** 2,
  295. 2 + 1,
  296. displacement_emb_dim=32,
  297. local_corr_radius=2,
  298. ),
  299. "2": partial_conv_refiner(
  300. 2 * 64 + 16,
  301. 128 + 16,
  302. 2 + 1,
  303. displacement_emb_dim=16,
  304. ),
  305. "1": partial_conv_refiner(
  306. 2 * 9 + 6,
  307. 24,
  308. 2 + 1,
  309. displacement_emb_dim=6,
  310. ),
  311. }
  312. )
  313. kernel_temperature = 0.2
  314. learn_temperature = False
  315. no_cov = True
  316. kernel = CosKernel
  317. only_attention = False
  318. basis = "fourier"
  319. gp16 = GP(
  320. kernel,
  321. T=kernel_temperature,
  322. learn_temperature=learn_temperature,
  323. only_attention=only_attention,
  324. gp_dim=gp_dim,
  325. basis=basis,
  326. no_cov=no_cov,
  327. )
  328. gps = nn.ModuleDict({"16": gp16})
  329. proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
  330. proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
  331. proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
  332. proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
  333. proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
  334. proj = nn.ModuleDict(
  335. {
  336. "16": proj16,
  337. "8": proj8,
  338. "4": proj4,
  339. "2": proj2,
  340. "1": proj1,
  341. }
  342. )
  343. displacement_dropout_p = 0.0
  344. gm_warp_dropout_p = 0.0
  345. decoder = Decoder(
  346. coordinate_decoder,
  347. gps,
  348. proj,
  349. conv_refiner,
  350. detach=True,
  351. scales=["16", "8", "4", "2", "1"],
  352. displacement_dropout_p=displacement_dropout_p,
  353. gm_warp_dropout_p=gm_warp_dropout_p,
  354. )
  355. encoder = CNNandDinov2(
  356. cnn_kwargs=dict(pretrained=False, amp=True),
  357. amp=True,
  358. dinov2_weights=dinov2_weights,
  359. amp_dtype=amp_dtype,
  360. )
  361. h, w = resolution
  362. matcher = RegressionMatcher(
  363. encoder,
  364. decoder,
  365. h=h,
  366. w=w,
  367. upsample_preds=upsample_preds,
  368. upsample_res=upsample_res,
  369. symmetric=symmetric,
  370. attenuate_cert=attenuate_cert,
  371. sample_mode=sample_mode,
  372. sample_thresh=sample_thresh,
  373. **kwargs,
  374. ).to(device)
  375. matcher.load_state_dict(weights)
  376. return matcher