patch_embed.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648
  1. """ Image to Patch Embedding using Conv2d
  2. A convolution based approach to patchifying a 2D image w/ embedding projection.
  3. Based on code in:
  4. * https://github.com/google-research/vision_transformer
  5. * https://github.com/google-research/big_vision/tree/main/big_vision
  6. Hacked together by / Copyright 2020 Ross Wightman
  7. """
  8. import logging
  9. import math
  10. from typing import Callable, Dict, List, Optional, Tuple, Union
  11. import torch
  12. from torch import nn as nn
  13. import torch.nn.functional as F
  14. from .format import Format, nchw_to
  15. from .helpers import to_2tuple
  16. from .trace_utils import _assert
  17. _logger = logging.getLogger(__name__)
  18. class PatchEmbed(nn.Module):
  19. """ 2D Image to Patch Embedding
  20. """
  21. output_fmt: Format
  22. dynamic_img_pad: torch.jit.Final[bool]
  23. def __init__(
  24. self,
  25. img_size: Optional[Union[int, Tuple[int, int]]] = 224,
  26. patch_size: int = 16,
  27. in_chans: int = 3,
  28. embed_dim: int = 768,
  29. norm_layer: Optional[Callable] = None,
  30. flatten: bool = True,
  31. output_fmt: Optional[str] = None,
  32. bias: bool = True,
  33. strict_img_size: bool = True,
  34. dynamic_img_pad: bool = False,
  35. device=None,
  36. dtype=None,
  37. ):
  38. dd = {'device': device, 'dtype': dtype}
  39. super().__init__()
  40. self.patch_size = to_2tuple(patch_size)
  41. self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
  42. if output_fmt is not None:
  43. self.flatten = False
  44. self.output_fmt = Format(output_fmt)
  45. else:
  46. # flatten spatial dim and transpose to channels last, kept for bwd compat
  47. self.flatten = flatten
  48. self.output_fmt = Format.NCHW
  49. self.strict_img_size = strict_img_size
  50. self.dynamic_img_pad = dynamic_img_pad
  51. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **dd)
  52. self.norm = norm_layer(embed_dim, **dd) if norm_layer else nn.Identity()
  53. def _init_img_size(self, img_size: Union[int, Tuple[int, int]]):
  54. assert self.patch_size
  55. if img_size is None:
  56. return None, None, None
  57. img_size = to_2tuple(img_size)
  58. grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
  59. num_patches = grid_size[0] * grid_size[1]
  60. return img_size, grid_size, num_patches
  61. def set_input_size(
  62. self,
  63. img_size: Optional[Union[int, Tuple[int, int]]] = None,
  64. patch_size: Optional[Union[int, Tuple[int, int]]] = None,
  65. ):
  66. new_patch_size = None
  67. if patch_size is not None:
  68. new_patch_size = to_2tuple(patch_size)
  69. if new_patch_size is not None and new_patch_size != self.patch_size:
  70. with torch.no_grad():
  71. new_proj = nn.Conv2d(
  72. self.proj.in_channels,
  73. self.proj.out_channels,
  74. kernel_size=new_patch_size,
  75. stride=new_patch_size,
  76. bias=self.proj.bias is not None,
  77. device=self.proj.weight.device,
  78. dtype=self.proj.weight.dtype,
  79. )
  80. new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
  81. if self.proj.bias is not None:
  82. new_proj.bias.copy_(self.proj.bias)
  83. self.proj = new_proj
  84. self.patch_size = new_patch_size
  85. img_size = img_size or self.img_size
  86. if img_size != self.img_size or new_patch_size is not None:
  87. self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
  88. def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
  89. if as_scalar:
  90. return max(self.patch_size)
  91. else:
  92. return self.patch_size
  93. def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
  94. """ Get grid (feature) size for given image size taking account of dynamic padding.
  95. NOTE: must be torchscript compatible so using fixed tuple indexing
  96. """
  97. if self.dynamic_img_pad:
  98. return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
  99. else:
  100. return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
  101. def forward(self, x):
  102. B, C, H, W = x.shape
  103. if self.img_size is not None:
  104. if self.strict_img_size:
  105. _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
  106. _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
  107. elif not self.dynamic_img_pad:
  108. _assert(
  109. H % self.patch_size[0] == 0,
  110. f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
  111. )
  112. _assert(
  113. W % self.patch_size[1] == 0,
  114. f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
  115. )
  116. if self.dynamic_img_pad:
  117. pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
  118. pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
  119. x = F.pad(x, (0, pad_w, 0, pad_h))
  120. x = self.proj(x)
  121. if self.flatten:
  122. x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
  123. elif self.output_fmt != Format.NCHW:
  124. x = nchw_to(x, self.output_fmt)
  125. x = self.norm(x)
  126. return x
  127. class PatchEmbedWithSize(PatchEmbed):
  128. """ 2D Image to Patch Embedding
  129. """
  130. output_fmt: Format
  131. def __init__(
  132. self,
  133. img_size: Optional[Union[int, Tuple[int, int]]] = 224,
  134. patch_size: int = 16,
  135. in_chans: int = 3,
  136. embed_dim: int = 768,
  137. norm_layer: Optional[Callable] = None,
  138. flatten: bool = True,
  139. output_fmt: Optional[str] = None,
  140. bias: bool = True,
  141. device=None,
  142. dtype=None,
  143. ):
  144. super().__init__(
  145. img_size=img_size,
  146. patch_size=patch_size,
  147. in_chans=in_chans,
  148. embed_dim=embed_dim,
  149. norm_layer=norm_layer,
  150. flatten=flatten,
  151. output_fmt=output_fmt,
  152. bias=bias,
  153. device=device,
  154. dtype=dtype,
  155. )
  156. def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
  157. B, C, H, W = x.shape
  158. if self.img_size is not None:
  159. _assert(H % self.patch_size[0] == 0, f"Input image height ({H}) must be divisible by patch size ({self.patch_size[0]}).")
  160. _assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).")
  161. x = self.proj(x)
  162. feat_size = x.shape[-2:]
  163. if self.flatten:
  164. x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
  165. elif self.output_fmt != Format.NCHW:
  166. x = nchw_to(x, self.output_fmt)
  167. x = self.norm(x)
  168. return x, feat_size
  169. # FIXME to remove, keeping for comparison for now
  170. def resample_patch_embed_old(
  171. patch_embed,
  172. new_size: List[int],
  173. interpolation: str = 'bicubic',
  174. antialias: bool = True,
  175. verbose: bool = False,
  176. ):
  177. """Resample the weights of the patch embedding kernel to target resolution.
  178. We resample the patch embedding kernel by approximately inverting the effect
  179. of patch resizing.
  180. Code based on:
  181. https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
  182. With this resizing, we can for example load a B/8 filter into a B/16 model
  183. and, on 2x larger input image, the result will match.
  184. Args:
  185. patch_embed: original parameter to be resized.
  186. new_size (tuple(int, int): target shape (height, width)-only.
  187. interpolation (str): interpolation for resize
  188. antialias (bool): use anti-aliasing filter in resize
  189. verbose (bool): log operation
  190. Returns:
  191. Resized patch embedding kernel.
  192. """
  193. import numpy as np
  194. try:
  195. from torch import vmap
  196. except ImportError:
  197. from functorch import vmap
  198. assert len(patch_embed.shape) == 4, "Four dimensions expected"
  199. assert len(new_size) == 2, "New shape should only be hw"
  200. old_size = patch_embed.shape[-2:]
  201. if tuple(old_size) == tuple(new_size):
  202. return patch_embed
  203. if verbose:
  204. _logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.")
  205. def resize(x_np, _new_size):
  206. x_tf = torch.Tensor(x_np)[None, None, ...]
  207. x_upsampled = F.interpolate(
  208. x_tf, size=_new_size, mode=interpolation, antialias=antialias)[0, 0, ...].numpy()
  209. return x_upsampled
  210. def get_resize_mat(_old_size, _new_size):
  211. mat = []
  212. for i in range(np.prod(_old_size)):
  213. basis_vec = np.zeros(_old_size)
  214. basis_vec[np.unravel_index(i, _old_size)] = 1.
  215. mat.append(resize(basis_vec, _new_size).reshape(-1))
  216. return np.stack(mat).T
  217. resize_mat = get_resize_mat(old_size, new_size)
  218. resize_mat_pinv = torch.tensor(np.linalg.pinv(resize_mat.T), device=patch_embed.device)
  219. def resample_kernel(kernel):
  220. resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
  221. return resampled_kernel.reshape(new_size)
  222. v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
  223. orig_dtype = patch_embed.dtype
  224. patch_embed = patch_embed.float()
  225. patch_embed = v_resample_kernel(patch_embed)
  226. patch_embed = patch_embed.to(orig_dtype)
  227. return patch_embed
  228. DTYPE_INTERMEDIATE = torch.float32
  229. def _compute_resize_matrix(
  230. old_size: Tuple[int, int],
  231. new_size: Tuple[int, int],
  232. interpolation: str,
  233. antialias: bool,
  234. device: torch.device,
  235. dtype: torch.dtype = DTYPE_INTERMEDIATE
  236. ) -> torch.Tensor:
  237. """Computes the resize matrix basis vectors and interpolates them to new_size."""
  238. old_h, old_w = old_size
  239. new_h, new_w = new_size
  240. old_total = old_h * old_w
  241. new_total = new_h * new_w
  242. eye_matrix = torch.eye(old_total, device=device, dtype=dtype)
  243. basis_vectors_batch = eye_matrix.reshape(old_total, 1, old_h, old_w)
  244. resized_basis_vectors_batch = F.interpolate(
  245. basis_vectors_batch,
  246. size=new_size,
  247. mode=interpolation,
  248. antialias=antialias,
  249. align_corners=False
  250. ) # Output shape: (old_total, 1, new_h, new_w)
  251. resize_matrix = resized_basis_vectors_batch.squeeze(1).permute(1, 2, 0).reshape(new_total, old_total)
  252. return resize_matrix # Shape: (new_total, old_total)
  253. def _apply_resampling(
  254. patch_embed: torch.Tensor,
  255. pinv_matrix: torch.Tensor,
  256. new_size_tuple: Tuple[int, int],
  257. orig_dtype: torch.dtype,
  258. intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE
  259. ) -> torch.Tensor:
  260. """ Simplified resampling w/o vmap use.
  261. As proposed by https://github.com/stas-sl
  262. """
  263. c_out, c_in, *_ = patch_embed.shape
  264. patch_embed = patch_embed.reshape(c_out, c_in, -1).to(dtype=intermediate_dtype)
  265. pinv_matrix = pinv_matrix.to(dtype=intermediate_dtype)
  266. resampled_patch_embed = patch_embed @ pinv_matrix # (C_out, C_in, P_old * P_old) @ (P_old * P_old, P_new * P_new)
  267. resampled_patch_embed = resampled_patch_embed.reshape(c_out, c_in, *new_size_tuple).to(dtype=orig_dtype)
  268. return resampled_patch_embed
  269. def resample_patch_embed(
  270. patch_embed: torch.Tensor,
  271. new_size: List[int],
  272. interpolation: str = 'bicubic',
  273. antialias: bool = True,
  274. verbose: bool = False,
  275. ):
  276. """ Standalone function (computes matrix on each call). """
  277. assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_ch, in_ch, h, w)"
  278. assert len(new_size) == 2, "New shape should only be hw (height, width)"
  279. old_size_tuple: Tuple[int, int] = tuple(patch_embed.shape[-2:])
  280. new_size_tuple: Tuple[int, int] = tuple(new_size)
  281. if old_size_tuple == new_size_tuple:
  282. return patch_embed
  283. device = patch_embed.device
  284. orig_dtype = patch_embed.dtype
  285. resize_mat = _compute_resize_matrix(
  286. old_size_tuple, new_size_tuple, interpolation, antialias, device, DTYPE_INTERMEDIATE
  287. )
  288. pinv_matrix = torch.linalg.pinv(resize_mat) # Calculates the pseudoinverse matrix used for resampling
  289. resampled_patch_embed = _apply_resampling(
  290. patch_embed, pinv_matrix, new_size_tuple, orig_dtype, DTYPE_INTERMEDIATE
  291. )
  292. return resampled_patch_embed
  293. class PatchEmbedResamplerFixedOrigSize(nn.Module):
  294. """
  295. Resample patch embedding weights from a fixed original size,
  296. caching the pseudoinverse matrix based on the target size.
  297. """
  298. def __init__(
  299. self,
  300. orig_size: Tuple[int, int],
  301. interpolation: str = 'bicubic',
  302. antialias: bool = True
  303. ):
  304. """
  305. Args:
  306. orig_size (Tuple[int, int]): The expected original (height, width) of input patch_embed tensors.
  307. interpolation (str): Interpolation mode.
  308. antialias (bool): Use anti-aliasing filter in resize.
  309. """
  310. super().__init__()
  311. assert isinstance(orig_size, tuple) and len(orig_size) == 2, \
  312. "`orig_size` must be a tuple of (height, width)"
  313. self.orig_size = orig_size # expected original size
  314. self.interpolation = interpolation
  315. self.antialias = antialias
  316. # Cache map key is the target new_size tuple
  317. self._pinv_cache_map: Dict[Tuple[int, int], str] = {}
  318. def _get_or_create_pinv_matrix(
  319. self,
  320. new_size: Tuple[int, int],
  321. device: torch.device,
  322. dtype: torch.dtype = DTYPE_INTERMEDIATE
  323. ) -> torch.Tensor:
  324. """Retrieves the cached pinv matrix or computes and caches it for the given new_size."""
  325. cache_key = new_size
  326. buffer_name = self._pinv_cache_map.get(cache_key)
  327. if buffer_name and hasattr(self, buffer_name):
  328. pinv_matrix = getattr(self, buffer_name)
  329. if pinv_matrix.device == device and pinv_matrix.dtype == dtype:
  330. return pinv_matrix
  331. # Calculate the matrix if not cached or needs update
  332. resize_mat = _compute_resize_matrix(
  333. self.orig_size, new_size, self.interpolation, self.antialias, device, dtype
  334. )
  335. pinv_matrix = torch.linalg.pinv(resize_mat) # Calculates the pseudoinverse matrix used for resampling
  336. # Cache using register_buffer
  337. buffer_name = f"pinv_{new_size[0]}x{new_size[1]}"
  338. if hasattr(self, buffer_name):
  339. delattr(self, buffer_name)
  340. self.register_buffer(buffer_name, pinv_matrix)
  341. self._pinv_cache_map[cache_key] = buffer_name # Map new_size key to buffer name
  342. return pinv_matrix
  343. def forward(self, patch_embed: torch.Tensor, new_size: List[int]) -> torch.Tensor:
  344. """ Resamples the patch embedding weights to new_size.
  345. Args:
  346. patch_embed (torch.Tensor): Original weights (out_ch, in_ch, H_orig, W_orig).
  347. new_size (List[int]): Target [height, width].
  348. Returns:
  349. torch.Tensor: Resampled weights.
  350. """
  351. assert len(patch_embed.shape) == 4
  352. assert len(new_size) == 2
  353. # Input Validation
  354. input_size = tuple(patch_embed.shape[-2:])
  355. assert input_size == self.orig_size, \
  356. f"Input patch_embed spatial size {input_size} does not match " \
  357. f"module's expected original size {self.orig_size}"
  358. new_size_tuple: Tuple[int, int] = tuple(new_size)
  359. # Check no-op case against self.orig_size
  360. if self.orig_size == new_size_tuple:
  361. return patch_embed
  362. device = patch_embed.device
  363. orig_dtype = patch_embed.dtype
  364. # Get or compute the required pseudoinverse matrix
  365. pinv_matrix = self._get_or_create_pinv_matrix(new_size_tuple, device)
  366. # Apply the resampling
  367. resampled_patch_embed = _apply_resampling(patch_embed, pinv_matrix, new_size_tuple, orig_dtype)
  368. return resampled_patch_embed
  369. class PatchEmbedInterpolator(nn.Module):
  370. """Dynamically interpolates patch embedding weights for variable patch sizes.
  371. This module wraps patch embedding weight resampling functionality to support
  372. on-the-fly patch size variation during training. It handles both Conv2d and
  373. Linear patch embeddings.
  374. Args:
  375. base_patch_size: The original patch size the model was initialized with
  376. in_chans: Number of input channels
  377. embed_dim: Embedding dimension
  378. interpolation: Interpolation mode for resampling
  379. antialias: Whether to use antialiasing during interpolation
  380. """
  381. def __init__(
  382. self,
  383. base_patch_size: Tuple[int, int],
  384. in_chans: int = 3,
  385. embed_dim: int = 768,
  386. interpolation: str = 'bicubic',
  387. antialias: bool = True,
  388. ):
  389. super().__init__()
  390. self.base_patch_size = base_patch_size
  391. self.in_chans = in_chans
  392. self.embed_dim = embed_dim
  393. self.interpolation = interpolation
  394. self.antialias = antialias
  395. def resample_linear_weight(
  396. self,
  397. weight: torch.Tensor,
  398. target_patch_size: Tuple[int, int],
  399. ) -> torch.Tensor:
  400. """Resample linear patch embedding weights for a new patch size.
  401. Args:
  402. weight: Linear weight tensor of shape [embed_dim, patch_h * patch_w * in_chans]
  403. target_patch_size: Target (patch_h, patch_w) to resample to
  404. Returns:
  405. Resampled weight tensor
  406. """
  407. if target_patch_size == self.base_patch_size:
  408. return weight
  409. embed_dim = weight.shape[0]
  410. base_ph, base_pw = self.base_patch_size
  411. target_ph, target_pw = target_patch_size
  412. # Reshape linear weight to conv2d format
  413. # [embed_dim, ph*pw*C] -> [embed_dim, C, ph, pw]
  414. weight_conv = weight.reshape(embed_dim, base_ph, base_pw, self.in_chans)
  415. weight_conv = weight_conv.permute(0, 3, 1, 2)
  416. # Resample using existing function
  417. weight_conv_resampled = resample_patch_embed(
  418. weight_conv,
  419. new_size=[target_ph, target_pw],
  420. interpolation=self.interpolation,
  421. antialias=self.antialias,
  422. verbose=False,
  423. )
  424. # Reshape back to linear format
  425. # [embed_dim, C, ph, pw] -> [embed_dim, ph*pw*C]
  426. weight_resampled = weight_conv_resampled.permute(0, 2, 3, 1)
  427. weight_resampled = weight_resampled.reshape(embed_dim, -1)
  428. return weight_resampled
  429. def resample_conv_weight(
  430. self,
  431. weight: torch.Tensor,
  432. target_patch_size: Tuple[int, int],
  433. ) -> torch.Tensor:
  434. """Resample conv2d patch embedding weights for a new patch size.
  435. Args:
  436. weight: Conv2d weight tensor of shape [embed_dim, in_chans, patch_h, patch_w]
  437. target_patch_size: Target (patch_h, patch_w) to resample to
  438. Returns:
  439. Resampled weight tensor
  440. """
  441. if target_patch_size == self.base_patch_size:
  442. return weight
  443. # Resample using existing function
  444. weight_resampled = resample_patch_embed(
  445. weight,
  446. new_size=list(target_patch_size),
  447. interpolation=self.interpolation,
  448. antialias=self.antialias,
  449. verbose=False,
  450. )
  451. return weight_resampled
  452. def forward(
  453. self,
  454. patches: torch.Tensor,
  455. proj_weight: torch.Tensor,
  456. proj_bias: Optional[torch.Tensor] = None,
  457. patch_size: Optional[Tuple[int, int]] = None,
  458. is_linear: bool = True,
  459. ) -> torch.Tensor:
  460. """Apply patch embedding with dynamic weight resampling.
  461. Args:
  462. patches: Input patches
  463. - For linear mode with resampling: [B, N, Ph, Pw, C]
  464. - For linear mode without resampling: [B, N, Ph*Pw*C]
  465. - For conv mode: [B, C, H, W]
  466. proj_weight: Original projection weight
  467. proj_bias: Optional projection bias
  468. patch_size: Current patch size (if None, uses base_patch_size)
  469. is_linear: Whether using linear (True) or conv2d (False) projection
  470. Returns:
  471. Embedded patches
  472. """
  473. if patch_size is None:
  474. patch_size = self.base_patch_size
  475. if is_linear:
  476. if patch_size != self.base_patch_size:
  477. # Need to resample - expects unflattened patches
  478. assert patches.ndim == 5, "Patches must be [B, N, Ph, Pw, C] for resampling"
  479. B, N, Ph, Pw, C = patches.shape
  480. # Resample the weight
  481. weight_resampled = self.resample_linear_weight(proj_weight, patch_size)
  482. # Flatten patches and apply linear projection
  483. patches_flat = patches.reshape(B, N, -1)
  484. output = torch.nn.functional.linear(patches_flat, weight_resampled, proj_bias)
  485. else:
  486. # No resampling needed, patches can be pre-flattened
  487. if patches.ndim == 5:
  488. B, N, Ph, Pw, C = patches.shape
  489. patches = patches.reshape(B, N, -1)
  490. output = torch.nn.functional.linear(patches, proj_weight, proj_bias)
  491. else:
  492. # Conv mode
  493. if patch_size != self.base_patch_size:
  494. weight_resampled = self.resample_conv_weight(proj_weight, patch_size)
  495. output = torch.nn.functional.conv2d(
  496. patches, weight_resampled, proj_bias,
  497. stride=patch_size, padding=0
  498. )
  499. else:
  500. output = torch.nn.functional.conv2d(
  501. patches, proj_weight, proj_bias,
  502. stride=patch_size, padding=0
  503. )
  504. return output
  505. # def divs(n, m=None):
  506. # m = m or n // 2
  507. # if m == 1:
  508. # return [1]
  509. # if n % m == 0:
  510. # return [m] + divs(n, m - 1)
  511. # return divs(n, m - 1)
  512. #
  513. #
  514. # class FlexiPatchEmbed(nn.Module):
  515. # """ 2D Image to Patch Embedding w/ Flexible Patch sizes (FlexiViT)
  516. # FIXME WIP
  517. # """
  518. # def __init__(
  519. # self,
  520. # img_size=240,
  521. # patch_size=16,
  522. # in_chans=3,
  523. # embed_dim=768,
  524. # base_img_size=240,
  525. # base_patch_size=32,
  526. # norm_layer=None,
  527. # flatten=True,
  528. # bias=True,
  529. # ):
  530. # super().__init__()
  531. # self.img_size = to_2tuple(img_size)
  532. # self.patch_size = to_2tuple(patch_size)
  533. # self.num_patches = 0
  534. #
  535. # # full range for 240 = (5, 6, 8, 10, 12, 14, 15, 16, 20, 24, 30, 40, 48)
  536. # self.seqhw = (6, 8, 10, 12, 14, 15, 16, 20, 24, 30)
  537. #
  538. # self.base_img_size = to_2tuple(base_img_size)
  539. # self.base_patch_size = to_2tuple(base_patch_size)
  540. # self.base_grid_size = tuple([i // p for i, p in zip(self.base_img_size, self.base_patch_size)])
  541. # self.base_num_patches = self.base_grid_size[0] * self.base_grid_size[1]
  542. #
  543. # self.flatten = flatten
  544. # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=bias)
  545. # self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
  546. #
  547. # def forward(self, x):
  548. # B, C, H, W = x.shape
  549. #
  550. # if self.patch_size == self.base_patch_size:
  551. # weight = self.proj.weight
  552. # else:
  553. # weight = resample_patch_embed(self.proj.weight, self.patch_size)
  554. # patch_size = self.patch_size
  555. # x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)
  556. # if self.flatten:
  557. # x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
  558. # x = self.norm(x)
  559. # return x