hybrid_embed.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. """ Image to Patch Hybrid Embedding Layer
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. import logging
  5. import math
  6. from typing import List, Optional, Tuple, Union
  7. import torch
  8. from torch import nn as nn
  9. import torch.nn.functional as F
  10. from .format import Format, nchw_to
  11. from .helpers import to_2tuple
  12. from .patch_embed import resample_patch_embed
# Module-level logger, named after this module via ``__name__``.
_logger = logging.getLogger(__name__)
  14. class HybridEmbed(nn.Module):
  15. """ CNN Feature Map Embedding
  16. Extract feature map from CNN, flatten, project to embedding dim.
  17. """
  18. output_fmt: Format
  19. dynamic_img_pad: torch.jit.Final[bool]
  20. def __init__(
  21. self,
  22. backbone: nn.Module,
  23. img_size: Union[int, Tuple[int, int]] = 224,
  24. patch_size: Union[int, Tuple[int, int]] = 1,
  25. feature_size: Optional[Union[int, Tuple[int, int]]] = None,
  26. feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
  27. in_chans: int = 3,
  28. embed_dim: int = 768,
  29. bias: bool = True,
  30. proj: bool = True,
  31. flatten: bool = True,
  32. output_fmt: Optional[str] = None,
  33. strict_img_size: bool = True,
  34. dynamic_img_pad: bool = False,
  35. device=None,
  36. dtype=None,
  37. ):
  38. dd = {'device': device, 'dtype': dtype}
  39. super().__init__()
  40. assert isinstance(backbone, nn.Module)
  41. self.backbone = backbone
  42. self.in_chans = in_chans
  43. (
  44. self.img_size,
  45. self.patch_size,
  46. self.feature_size,
  47. self.feature_ratio,
  48. self.feature_dim,
  49. self.grid_size,
  50. self.num_patches,
  51. ) = self._init_backbone(
  52. img_size=img_size,
  53. patch_size=patch_size,
  54. feature_size=feature_size,
  55. feature_ratio=feature_ratio,
  56. **dd,
  57. )
  58. if output_fmt is not None:
  59. self.flatten = False
  60. self.output_fmt = Format(output_fmt)
  61. else:
  62. # flatten spatial dim and transpose to channels last, kept for bwd compat
  63. self.flatten = flatten
  64. self.output_fmt = Format.NCHW
  65. self.strict_img_size = strict_img_size
  66. self.dynamic_img_pad = dynamic_img_pad
  67. if not dynamic_img_pad:
  68. assert self.feature_size[0] % self.patch_size[0] == 0 and self.feature_size[1] % self.patch_size[1] == 0
  69. if proj:
  70. self.proj = nn.Conv2d(
  71. self.feature_dim,
  72. embed_dim,
  73. kernel_size=patch_size,
  74. stride=patch_size,
  75. bias=bias,
  76. **dd,
  77. )
  78. else:
  79. assert self.feature_dim == embed_dim, \
  80. f'The feature dim ({self.feature_dim} must match embed dim ({embed_dim}) when projection disabled.'
  81. self.proj = nn.Identity()
  82. def _init_backbone(
  83. self,
  84. img_size: Union[int, Tuple[int, int]] = 224,
  85. patch_size: Union[int, Tuple[int, int]] = 1,
  86. feature_size: Optional[Union[int, Tuple[int, int]]] = None,
  87. feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
  88. feature_dim: Optional[int] = None,
  89. device=None,
  90. dtype=None,
  91. ):
  92. img_size = to_2tuple(img_size)
  93. patch_size = to_2tuple(patch_size)
  94. if feature_size is None:
  95. with torch.no_grad():
  96. # NOTE Most reliable way of determining output dims is to run forward pass
  97. training = self.backbone.training
  98. if training:
  99. self.backbone.eval()
  100. # FIXME whatif meta device?
  101. o = self.backbone(torch.zeros(1, self.in_chans, img_size[0], img_size[1], device=device, dtype=dtype))
  102. if isinstance(o, (list, tuple)):
  103. o = o[-1] # last feature if backbone outputs list/tuple of features
  104. feature_size = o.shape[-2:]
  105. feature_dim = o.shape[1]
  106. self.backbone.train(training)
  107. feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)])
  108. else:
  109. feature_size = to_2tuple(feature_size)
  110. feature_ratio = to_2tuple(feature_ratio or 16)
  111. if feature_dim is None:
  112. if hasattr(self.backbone, 'feature_info'):
  113. feature_dim = self.backbone.feature_info.channels()[-1]
  114. else:
  115. feature_dim = self.backbone.num_features
  116. grid_size = tuple([f // p for f, p in zip(feature_size, patch_size)])
  117. num_patches = grid_size[0] * grid_size[1]
  118. return img_size, patch_size, feature_size, feature_ratio, feature_dim, grid_size, num_patches
  119. def set_input_size(
  120. self,
  121. img_size: Optional[Union[int, Tuple[int, int]]] = None,
  122. patch_size: Optional[Union[int, Tuple[int, int]]] = None,
  123. feature_size: Optional[Union[int, Tuple[int, int]]] = None,
  124. feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
  125. feature_dim: Optional[int] = None,
  126. ):
  127. assert img_size is not None or patch_size is not None
  128. img_size = img_size or self.img_size
  129. new_patch_size = None
  130. if patch_size is not None:
  131. new_patch_size = to_2tuple(patch_size)
  132. if new_patch_size is not None and new_patch_size != self.patch_size:
  133. assert isinstance(self.proj, nn.Conv2d), 'HybridEmbed must have a projection layer to change patch size.'
  134. with torch.no_grad():
  135. new_proj = nn.Conv2d(
  136. self.proj.in_channels,
  137. self.proj.out_channels,
  138. kernel_size=new_patch_size,
  139. stride=new_patch_size,
  140. bias=self.proj.bias is not None,
  141. device=self.proj.device,
  142. dtype=self.proj.dtype,
  143. )
  144. new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
  145. if self.proj.bias is not None:
  146. new_proj.bias.copy_(self.proj.bias)
  147. self.proj = new_proj
  148. patch_size = new_patch_size
  149. patch_size = patch_size or self.patch_size
  150. if img_size != self.img_size or patch_size != self.patch_size:
  151. (
  152. self.img_size,
  153. self.patch_size,
  154. self.feature_size,
  155. self.feature_ratio,
  156. self.feature_dim,
  157. self.grid_size,
  158. self.num_patches,
  159. ) = self._init_backbone(
  160. img_size=img_size,
  161. patch_size=patch_size,
  162. feature_size=feature_size,
  163. feature_ratio=feature_ratio,
  164. feature_dim=feature_dim,
  165. # FIXME device/dtype?
  166. )
  167. def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
  168. total_reduction = (
  169. self.feature_ratio[0] * self.patch_size[0],
  170. self.feature_ratio[1] * self.patch_size[1]
  171. )
  172. if as_scalar:
  173. return max(total_reduction)
  174. else:
  175. return total_reduction
  176. def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
  177. """ Get feature grid size taking account dynamic padding and backbone network feat reduction
  178. """
  179. feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1])
  180. if self.dynamic_img_pad:
  181. return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1])
  182. else:
  183. return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]
  184. @torch.jit.ignore
  185. def set_grad_checkpointing(self, enable: bool = True):
  186. if hasattr(self.backbone, 'set_grad_checkpointing'):
  187. self.backbone.set_grad_checkpointing(enable=enable)
  188. elif hasattr(self.backbone, 'grad_checkpointing'):
  189. self.backbone.grad_checkpointing = enable
  190. def forward(self, x):
  191. x = self.backbone(x)
  192. if isinstance(x, (list, tuple)):
  193. x = x[-1] # last feature if backbone outputs list/tuple of features
  194. _, _, H, W = x.shape
  195. if self.dynamic_img_pad:
  196. pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
  197. pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
  198. x = F.pad(x, (0, pad_w, 0, pad_h))
  199. x = self.proj(x)
  200. if self.flatten:
  201. x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
  202. elif self.output_fmt != Format.NCHW:
  203. x = nchw_to(x, self.output_fmt)
  204. return x
  205. class HybridEmbedWithSize(HybridEmbed):
  206. """ CNN Feature Map Embedding
  207. Extract feature map from CNN, flatten, project to embedding dim.
  208. """
  209. def __init__(
  210. self,
  211. backbone: nn.Module,
  212. img_size: Union[int, Tuple[int, int]] = 224,
  213. patch_size: Union[int, Tuple[int, int]] = 1,
  214. feature_size: Optional[Union[int, Tuple[int, int]]] = None,
  215. feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
  216. in_chans: int = 3,
  217. embed_dim: int = 768,
  218. bias=True,
  219. proj=True,
  220. device=None,
  221. dtype=None,
  222. ):
  223. super().__init__(
  224. backbone=backbone,
  225. img_size=img_size,
  226. patch_size=patch_size,
  227. feature_size=feature_size,
  228. feature_ratio=feature_ratio,
  229. in_chans=in_chans,
  230. embed_dim=embed_dim,
  231. bias=bias,
  232. proj=proj,
  233. device=device,
  234. dtype=dtype,
  235. )
  236. @torch.jit.ignore
  237. def set_grad_checkpointing(self, enable: bool = True):
  238. if hasattr(self.backbone, 'set_grad_checkpointing'):
  239. self.backbone.set_grad_checkpointing(enable=enable)
  240. elif hasattr(self.backbone, 'grad_checkpointing'):
  241. self.backbone.grad_checkpointing = enable
  242. def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
  243. x = self.backbone(x)
  244. if isinstance(x, (list, tuple)):
  245. x = x[-1] # last feature if backbone outputs list/tuple of features
  246. x = self.proj(x)
  247. return x.flatten(2).transpose(1, 2), x.shape[-2:]