dedode_models.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import torch
  18. from torch import nn
  19. from .decoder import ConvRefiner, Decoder
  20. from .descriptor import DeDoDeDescriptor
  21. from .detector import DeDoDeDetector
  22. from .encoder import VGG19, VGG_DINOv2
  23. def dedode_detector_L(amp_dtype: torch.dtype = torch.float16) -> DeDoDeDetector:
  24. """Get DeDoDe descriptor of type L."""
  25. NUM_PROTOTYPES = 1
  26. residual = True
  27. hidden_blocks = 8
  28. amp = True
  29. conv_refiner = nn.ModuleDict(
  30. {
  31. "8": ConvRefiner(
  32. 512,
  33. 512,
  34. 256 + NUM_PROTOTYPES,
  35. hidden_blocks=hidden_blocks,
  36. residual=residual,
  37. amp=amp,
  38. amp_dtype=amp_dtype,
  39. ),
  40. "4": ConvRefiner(
  41. 256 + 256,
  42. 256,
  43. 128 + NUM_PROTOTYPES,
  44. hidden_blocks=hidden_blocks,
  45. residual=residual,
  46. amp=amp,
  47. amp_dtype=amp_dtype,
  48. ),
  49. "2": ConvRefiner(
  50. 128 + 128,
  51. 128,
  52. 64 + NUM_PROTOTYPES,
  53. hidden_blocks=hidden_blocks,
  54. residual=residual,
  55. amp=amp,
  56. amp_dtype=amp_dtype,
  57. ),
  58. "1": ConvRefiner(
  59. 64 + 64,
  60. 64,
  61. 1 + NUM_PROTOTYPES,
  62. hidden_blocks=hidden_blocks,
  63. residual=residual,
  64. amp=amp,
  65. amp_dtype=amp_dtype,
  66. ),
  67. }
  68. )
  69. encoder = VGG19(amp=amp, amp_dtype=amp_dtype)
  70. decoder = Decoder(conv_refiner)
  71. model = DeDoDeDetector(encoder=encoder, decoder=decoder)
  72. return model
  73. def dedode_descriptor_B(amp_dtype: torch.dtype = torch.float16) -> DeDoDeDescriptor:
  74. """Get DeDoDe descriptor of type B."""
  75. NUM_PROTOTYPES = 256 # == descriptor size
  76. residual = True
  77. hidden_blocks = 5
  78. amp = True
  79. conv_refiner = nn.ModuleDict(
  80. {
  81. "8": ConvRefiner(
  82. 512,
  83. 512,
  84. 256 + NUM_PROTOTYPES,
  85. hidden_blocks=hidden_blocks,
  86. residual=residual,
  87. amp=amp,
  88. amp_dtype=amp_dtype,
  89. ),
  90. "4": ConvRefiner(
  91. 256 + 256,
  92. 256,
  93. 128 + NUM_PROTOTYPES,
  94. hidden_blocks=hidden_blocks,
  95. residual=residual,
  96. amp=amp,
  97. amp_dtype=amp_dtype,
  98. ),
  99. "2": ConvRefiner(
  100. 128 + 128,
  101. 64,
  102. 32 + NUM_PROTOTYPES,
  103. hidden_blocks=hidden_blocks,
  104. residual=residual,
  105. amp=amp,
  106. amp_dtype=amp_dtype,
  107. ),
  108. "1": ConvRefiner(
  109. 64 + 32,
  110. 32,
  111. 1 + NUM_PROTOTYPES,
  112. hidden_blocks=hidden_blocks,
  113. residual=residual,
  114. amp=amp,
  115. amp_dtype=amp_dtype,
  116. ),
  117. }
  118. )
  119. encoder = VGG19(amp=amp, amp_dtype=amp_dtype)
  120. decoder = Decoder(conv_refiner, num_prototypes=NUM_PROTOTYPES)
  121. model = DeDoDeDescriptor(encoder=encoder, decoder=decoder)
  122. return model
  123. def dedode_descriptor_G(amp_dtype: torch.dtype = torch.float16) -> DeDoDeDescriptor:
  124. """Get DeDoDe descriptor of type G."""
  125. NUM_PROTOTYPES = 256 # == descriptor size
  126. residual = True
  127. hidden_blocks = 5
  128. amp = True
  129. conv_refiner = nn.ModuleDict(
  130. {
  131. "14": ConvRefiner(
  132. 1024,
  133. 768,
  134. 512 + NUM_PROTOTYPES,
  135. hidden_blocks=hidden_blocks,
  136. residual=residual,
  137. amp=amp,
  138. amp_dtype=amp_dtype,
  139. ),
  140. "8": ConvRefiner(
  141. 512 + 512,
  142. 512,
  143. 256 + NUM_PROTOTYPES,
  144. hidden_blocks=hidden_blocks,
  145. residual=residual,
  146. amp=amp,
  147. amp_dtype=amp_dtype,
  148. ),
  149. "4": ConvRefiner(
  150. 256 + 256,
  151. 256,
  152. 128 + NUM_PROTOTYPES,
  153. hidden_blocks=hidden_blocks,
  154. residual=residual,
  155. amp=amp,
  156. amp_dtype=amp_dtype,
  157. ),
  158. "2": ConvRefiner(
  159. 128 + 128,
  160. 64,
  161. 32 + NUM_PROTOTYPES,
  162. hidden_blocks=hidden_blocks,
  163. residual=residual,
  164. amp=amp,
  165. amp_dtype=amp_dtype,
  166. ),
  167. "1": ConvRefiner(
  168. 64 + 32,
  169. 32,
  170. 1 + NUM_PROTOTYPES,
  171. hidden_blocks=hidden_blocks,
  172. residual=residual,
  173. amp=amp,
  174. amp_dtype=amp_dtype,
  175. ),
  176. }
  177. )
  178. vgg_kwargs = {"amp": amp, "amp_dtype": amp_dtype}
  179. dinov2_kwargs = {"amp": amp, "amp_dtype": amp_dtype, "dinov2_weights": None}
  180. encoder = VGG_DINOv2(vgg_kwargs=vgg_kwargs, dinov2_kwargs=dinov2_kwargs)
  181. decoder = Decoder(conv_refiner, num_prototypes=NUM_PROTOTYPES)
  182. model = DeDoDeDescriptor(encoder=encoder, decoder=decoder)
  183. return model
  184. def get_detector(kind: str = "L", amp_dtype: torch.dtype = torch.float16) -> DeDoDeDetector:
  185. """Get DeDoDe detector."""
  186. if kind == "L":
  187. return dedode_detector_L(amp_dtype)
  188. raise ValueError(f"Unknown detector kind: {kind}")
  189. def get_descriptor(kind: str = "B", amp_dtype: torch.dtype = torch.float16) -> DeDoDeDescriptor:
  190. """Get DeDoDe descriptor."""
  191. if kind == "B":
  192. return dedode_descriptor_B(amp_dtype)
  193. if kind == "G":
  194. return dedode_descriptor_G(amp_dtype)
  195. raise ValueError(f"Unknown descriptor kind: {kind}")