hiera.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049
"""A PyTorch implementation of Hiera.

Adapted for timm from the originals at https://github.com/facebookresearch/hiera
"""
  4. # Copyright (c) Meta Platforms, Inc. and affiliates.
  5. # All rights reserved.
  6. # This source code is licensed under the license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. # --------------------------------------------------------
  9. #
  10. # Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles
  11. #
  12. # Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan,
  13. # Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed,
  14. # Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer.
  15. #
  16. # Paper: https://arxiv.org/abs/2306.00989/
  17. #
  18. # References:
  19. # slowfast: https://github.com/facebookresearch/SlowFast
  20. # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
  21. # --------------------------------------------------------
  22. import math
  23. from functools import partial
  24. from typing import Dict, List, Optional, Tuple, Type, Union
  25. import torch
  26. import torch.nn as nn
  27. import torch.nn.functional as F
  28. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  29. from timm.layers import (
  30. DropPath,
  31. calculate_drop_path_rates,
  32. Mlp,
  33. LayerScale,
  34. ClNormMlpClassifierHead,
  35. use_fused_attn,
  36. _assert,
  37. get_norm_layer,
  38. to_2tuple,
  39. init_weight_vit,
  40. init_weight_jax,
  41. )
  42. from ._registry import generate_default_cfgs, register_model
  43. from ._builder import build_model_with_cfg
  44. from ._features import feature_take_indices
  45. from ._features_fx import register_notrace_function
  46. from ._manipulate import named_apply, checkpoint
# Only `Hiera` is exported by `from <this module> import *`.
__all__ = ['Hiera']
  48. def conv_nd(n: int) -> Type[nn.Module]:
  49. """
  50. Returns a conv with nd (e.g., Conv2d for n=2). Work up to n=3.
  51. If you wanted a 4d Hiera, you could probably just implement this for n=4. (no promises)
  52. """
  53. return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]
  54. @register_notrace_function
  55. def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
  56. # target_size: [(T), (H), W]
  57. # (spatial) mask: [B, C, (t), (h), w]
  58. if mask is None:
  59. return mask
  60. _assert(len(mask.shape[2:]) == len(target_size), "mask spatial shape and target_size must match.")
  61. if mask.shape[2:] != target_size:
  62. return F.interpolate(mask.float(), size=target_size)
  63. return mask
  64. def undo_windowing(
  65. x: torch.Tensor,
  66. shape: List[int],
  67. mu_shape: List[int],
  68. ) -> torch.Tensor:
  69. """
  70. Restore spatial organization by undoing windowed organization of mask units.
  71. Args:
  72. x: organized by mask units windows, e.g. in 2d [B, #MUy*#MUx, MUy, MUx, C]
  73. shape: current spatial shape, if it were not organized into mask unit
  74. windows, e.g. in 2d [B, #MUy*MUy, #MUx*MUx, C].
  75. mu_shape: current mask unit shape, e.g. in 2d [MUy, MUx]
  76. Returns:
  77. x: e.g. in 2d, [B, #MUy*MUy, #MUx*MUx, C]
  78. """
  79. D = len(shape)
  80. B, C = x.shape[0], x.shape[-1]
  81. # [B, #MUy*#MUx, MUy, MUx, C] -> [B, #MUy, #MUx, MUy, MUx, C]
  82. num_MUs = [s // mu for s, mu in zip(shape, mu_shape)]
  83. x = x.view(B, *num_MUs, *mu_shape, C)
  84. # [B, #MUy, #MUx, MUy, MUx, C] -> [B, #MUy*MUy, #MUx*MUx, C]
  85. permute = (
  86. [0]
  87. + sum([list(p) for p in zip(range(1, 1 + D), range(1 + D, 1 + 2 * D))], [])
  88. + [len(x.shape) - 1]
  89. )
  90. x = x.permute(permute).reshape(B, *shape, C)
  91. return x
  92. class Unroll(nn.Module):
  93. """
  94. Reorders the tokens such that patches are contiguous in memory.
  95. E.g., given [B, (H, W), C] and stride of (Sy, Sx), this will re-order the tokens as
  96. [B, (Sy, Sx, H // Sy, W // Sx), C]
  97. This allows operations like Max2d to be computed as x.view(B, Sx*Sy, -1, C).max(dim=1).
  98. Not only is this faster, but it also makes it easy to support inputs of arbitrary
  99. dimensions in addition to patch-wise sparsity.
  100. Performing this operation multiple times in sequence puts entire windows as contiguous
  101. in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of
  102. size 8x8 would be contiguous in memory, allowing operations like mask unit attention
  103. computed easily and efficiently, while also allowing max to be applied sequentially.
  104. Note: This means that intermediate values of the model are not in HxW order, so they
  105. need to be re-rolled if you want to use the intermediate values as a HxW feature map.
  106. The last block of the network is fine though, since by then the strides are all consumed.
  107. """
  108. def __init__(
  109. self,
  110. input_size: Tuple[int, ...],
  111. patch_stride: Tuple[int, ...],
  112. unroll_schedule: List[Tuple[int, ...]],
  113. ):
  114. super().__init__()
  115. self.size = [i // s for i, s in zip(input_size, patch_stride)]
  116. self.schedule = unroll_schedule
  117. def forward(self, x: torch.Tensor) -> torch.Tensor:
  118. """
  119. Input: Flattened patch embeddings [B, N, C]
  120. Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. performs MaxPoolNd
  121. """
  122. B, _, C = x.shape
  123. cur_size = self.size
  124. x = x.view(*([B] + cur_size + [C]))
  125. for strides in self.schedule:
  126. # Move patches with the given strides to the batch dimension
  127. # Create a view of the tensor with the patch stride as separate dims
  128. # For example in 2d: [B, H // Sy, Sy, W // Sx, Sx, C]
  129. cur_size = [i // s for i, s in zip(cur_size, strides)]
  130. new_shape = [B] + sum([[i, s] for i, s in zip(cur_size, strides)], []) + [C]
  131. x = x.view(new_shape)
  132. # Move the patch stride into the batch dimension
  133. # For example in 2d: [B, Sy, Sx, H // Sy, W // Sx, C]
  134. L = len(new_shape)
  135. permute = [0] + list(range(2, L - 1, 2)) + list(range(1, L - 1, 2)) + [L - 1]
  136. x = x.permute(permute)
  137. # Now finally flatten the relevant dims into the batch dimension
  138. x = x.flatten(0, len(strides))
  139. B *= math.prod(strides)
  140. x = x.reshape(-1, math.prod(self.size), C)
  141. return x
  142. class Reroll(nn.Module):
  143. """
  144. Undos the "unroll" operation so that you can use intermediate features.
  145. """
  146. def __init__(
  147. self,
  148. input_size: Tuple[int, ...],
  149. patch_stride: Tuple[int, ...],
  150. unroll_schedule: List[Tuple[int, ...]],
  151. stage_ends: List[int],
  152. q_pool: int,
  153. ):
  154. super().__init__()
  155. self.size = [i // s for i, s in zip(input_size, patch_stride)]
  156. # The first stage has to reverse everything
  157. # The next stage has to reverse all but the first unroll, etc.
  158. self.schedule = {}
  159. size = self.size
  160. for i in range(stage_ends[-1] + 1):
  161. self.schedule[i] = unroll_schedule, size
  162. # schedule unchanged if no pooling at a stage end
  163. if i in stage_ends[:q_pool]:
  164. if len(unroll_schedule) > 0:
  165. size = [n // s for n, s in zip(size, unroll_schedule[0])]
  166. unroll_schedule = unroll_schedule[1:]
  167. def forward(
  168. self,
  169. x: torch.Tensor,
  170. block_idx: int,
  171. mask: torch.Tensor = None
  172. ) -> torch.Tensor:
  173. """
  174. Roll the given tensor back up to spatial order assuming it's from the given block.
  175. If no mask is provided:
  176. - Returns [B, H, W, C] for 2d, [B, T, H, W, C] for 3d, etc.
  177. If a mask is provided:
  178. - Returns [B, #MUs, MUy, MUx, C] for 2d, etc.
  179. """
  180. schedule, size = self.schedule[block_idx]
  181. B, N, C = x.shape
  182. D = len(size)
  183. cur_mu_shape = [1] * D
  184. for strides in schedule:
  185. # Extract the current patch from N
  186. x = x.view(B, *strides, N // math.prod(strides), *cur_mu_shape, C)
  187. # Move that patch into the current MU
  188. # Example in 2d: [B, Sy, Sx, N//(Sy*Sx), MUy, MUx, C] -> [B, N//(Sy*Sx), Sy, MUy, Sx, MUx, C]
  189. L = len(x.shape)
  190. permute = (
  191. [0, 1 + D]
  192. + sum([list(p) for p in zip(range(1, 1 + D), range(1 + D + 1, L - 1))], [])
  193. + [L - 1]
  194. )
  195. x = x.permute(permute)
  196. # Reshape to [B, N//(Sy*Sx), *MU, C]
  197. for i in range(D):
  198. cur_mu_shape[i] *= strides[i]
  199. x = x.reshape(B, -1, *cur_mu_shape, C)
  200. N = x.shape[1]
  201. # Current shape (e.g., 2d: [B, #MUy*#MUx, MUy, MUx, C])
  202. x = x.view(B, N, *cur_mu_shape, C)
  203. # If masked, return [B, #MUs, MUy, MUx, C]
  204. if mask is not None:
  205. return x
  206. # If not masked, we can return [B, H, W, C]
  207. x = undo_windowing(x, size, cur_mu_shape)
  208. return x
  209. class MaskUnitAttention(nn.Module):
  210. """
  211. Computes either Mask Unit or Global Attention. Also is able to perform q pooling.
  212. Note: this assumes the tokens have already been flattened and unrolled into mask units.
  213. See `Unroll` for more details.
  214. """
  215. fused_attn: torch.jit.Final[bool]
  216. def __init__(
  217. self,
  218. dim: int,
  219. dim_out: int,
  220. heads: int,
  221. q_stride: int = 1,
  222. window_size: int = 0,
  223. use_mask_unit_attn: bool = False,
  224. device=None,
  225. dtype=None,
  226. ):
  227. """
  228. Args:
  229. - dim, dim_out: The input and output feature dimensions.
  230. - heads: The number of attention heads.
  231. - q_stride: If greater than 1, pool q with this stride. The stride should be flattened (e.g., 2x2 = 4).
  232. - window_size: The current (flattened) size of a mask unit *after* pooling (if any).
  233. - use_mask_unit_attn: Use Mask Unit or Global Attention.
  234. """
  235. dd = {'device': device, 'dtype': dtype}
  236. super().__init__()
  237. self.dim = dim
  238. self.dim_out = dim_out
  239. self.heads = heads
  240. self.q_stride = q_stride
  241. self.head_dim = dim_out // heads
  242. self.scale = self.head_dim ** -0.5
  243. self.fused_attn = use_fused_attn()
  244. self.qkv = nn.Linear(dim, 3 * dim_out, **dd)
  245. self.proj = nn.Linear(dim_out, dim_out, **dd)
  246. self.window_size = window_size
  247. self.use_mask_unit_attn = use_mask_unit_attn
  248. def forward(self, x: torch.Tensor) -> torch.Tensor:
  249. """ Input should be of shape [batch, tokens, channels]. """
  250. B, N, _ = x.shape
  251. if self.use_mask_unit_attn:
  252. # Windowed attention: 5D path [B, heads, num_windows, tokens_per_window, head_dim]
  253. num_windows = N // (self.q_stride * self.window_size)
  254. qkv = self.qkv(x).reshape(
  255. B, -1, num_windows, 3, self.heads, self.head_dim,
  256. ).permute(3, 0, 4, 2, 1, 5)
  257. q, k, v = qkv.unbind(0)
  258. if self.q_stride > 1:
  259. # Refer to Unroll to see how this performs a maxpool-Nd
  260. q = q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim).amax(dim=3)
  261. else:
  262. # Global attention: 4D path [B, heads, N, head_dim]
  263. # Avoids the dummy num_windows=1 dimension that prevents FlashAttention dispatch.
  264. qkv = self.qkv(x).reshape(B, N, 3, self.heads, self.head_dim).permute(2, 0, 3, 1, 4)
  265. q, k, v = qkv.unbind(0)
  266. if self.q_stride > 1:
  267. # dim=2 instead of dim=3 because num_windows dimension is absent
  268. q = q.view(B, self.heads, self.q_stride, -1, self.head_dim).amax(dim=2)
  269. # Enforce contiguous memory layout so SDPA dispatches to FlashAttention
  270. # instead of silently falling back to the O(N^2) math backend.
  271. q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
  272. if self.fused_attn:
  273. # Note: the original paper did *not* use SDPA, it's a free boost!
  274. x = F.scaled_dot_product_attention(q, k, v)
  275. else:
  276. attn = (q * self.scale) @ k.transpose(-1, -2)
  277. attn = attn.softmax(dim=-1)
  278. x = attn @ v
  279. # Output transpose adapts to 5D (windowed) vs 4D (global) layout
  280. if self.use_mask_unit_attn:
  281. x = x.transpose(1, 3).reshape(B, -1, self.dim_out)
  282. else:
  283. x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
  284. x = self.proj(x)
  285. return x
  286. class HieraBlock(nn.Module):
  287. def __init__(
  288. self,
  289. dim: int,
  290. dim_out: int,
  291. heads: int,
  292. mlp_ratio: float = 4.0,
  293. drop_path: float = 0.0,
  294. init_values: Optional[float] = None,
  295. norm_layer: Type[nn.Module] = nn.LayerNorm,
  296. act_layer: Type[nn.Module] = nn.GELU,
  297. q_stride: int = 1,
  298. window_size: int = 0,
  299. use_expand_proj: bool = True,
  300. use_mask_unit_attn: bool = False,
  301. device=None,
  302. dtype=None,
  303. ):
  304. dd = {'device': device, 'dtype': dtype}
  305. super().__init__()
  306. self.dim = dim
  307. self.dim_out = dim_out
  308. self.norm1 = norm_layer(dim, **dd)
  309. if dim != dim_out:
  310. self.do_expand = True
  311. if use_expand_proj:
  312. self.proj = nn.Linear(dim, dim_out, **dd)
  313. else:
  314. assert dim_out == dim * 2
  315. self.proj = None
  316. else:
  317. self.do_expand = False
  318. self.proj = None
  319. self.attn = MaskUnitAttention(
  320. dim,
  321. dim_out,
  322. heads,
  323. q_stride,
  324. window_size,
  325. use_mask_unit_attn,
  326. **dd
  327. )
  328. self.ls1 = LayerScale(dim_out, init_values=init_values, **dd) if init_values is not None else nn.Identity()
  329. self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
  330. self.norm2 = norm_layer(dim_out, **dd)
  331. self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer, **dd)
  332. self.ls2 = LayerScale(dim_out, init_values=init_values, **dd) if init_values is not None else nn.Identity()
  333. self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
  334. def forward(self, x: torch.Tensor) -> torch.Tensor:
  335. # Attention + Q Pooling
  336. x_norm = self.norm1(x)
  337. if self.do_expand:
  338. if self.proj is not None:
  339. x = self.proj(x_norm)
  340. x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1) # max-pool
  341. else:
  342. x = torch.cat([
  343. x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1), # max-pool
  344. x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).mean(dim=1), # avg-pool
  345. ],
  346. dim=-1,
  347. )
  348. x = x + self.drop_path1(self.ls1(self.attn(x_norm)))
  349. # MLP
  350. x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
  351. return x
  352. class PatchEmbed(nn.Module):
  353. """Patch embed that supports any number of spatial dimensions (1d, 2d, 3d)."""
  354. def __init__(
  355. self,
  356. dim_in: int,
  357. dim_out: int,
  358. kernel: Tuple[int, ...],
  359. stride: Tuple[int, ...],
  360. padding: Tuple[int, ...],
  361. reshape: bool = True,
  362. device=None,
  363. dtype=None,
  364. ):
  365. dd = {'device': device, 'dtype': dtype}
  366. super().__init__()
  367. # Support any number of spatial dimensions
  368. self.spatial_dims = len(kernel)
  369. self.reshape = reshape
  370. self.proj = conv_nd(self.spatial_dims)(
  371. dim_in,
  372. dim_out,
  373. kernel_size=kernel,
  374. stride=stride,
  375. padding=padding,
  376. **dd,
  377. )
  378. def forward(
  379. self,
  380. x: torch.Tensor,
  381. mask: Optional[torch.Tensor] = None,
  382. ) -> torch.Tensor:
  383. if mask is not None:
  384. mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
  385. x = self.proj(x * mask.to(torch.bool))
  386. else:
  387. x = self.proj(x)
  388. if self.reshape:
  389. x = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1)
  390. return x
  391. class Hiera(nn.Module):
  392. def __init__(
  393. self,
  394. img_size: Tuple[int, ...] = (224, 224),
  395. in_chans: int = 3,
  396. embed_dim: int = 96, # initial embed dim
  397. num_heads: int = 1, # initial number of heads
  398. num_classes: int = 1000,
  399. global_pool: str = 'avg',
  400. stages: Tuple[int, ...] = (2, 3, 16, 3),
  401. q_pool: int = 3, # number of q_pool stages
  402. q_stride: Tuple[int, ...] = (2, 2),
  403. mask_unit_size: Tuple[int, ...] = (8, 8), # must divide q_stride ** (#stages-1)
  404. # mask_unit_attn: which stages use mask unit attention?
  405. mask_unit_attn: Tuple[bool, ...] = (True, True, False, False),
  406. use_expand_proj: bool = True,
  407. dim_mul: float = 2.0,
  408. head_mul: float = 2.0,
  409. patch_kernel: Tuple[int, ...] = (7, 7),
  410. patch_stride: Tuple[int, ...] = (4, 4),
  411. patch_padding: Tuple[int, ...] = (3, 3),
  412. mlp_ratio: float = 4.0,
  413. drop_path_rate: float = 0.0,
  414. init_values: Optional[float] = None,
  415. fix_init: bool = True,
  416. weight_init: str = '',
  417. norm_layer: Union[str, Type[nn.Module]] = "LayerNorm",
  418. drop_rate: float = 0.0,
  419. patch_drop_rate: float = 0.0,
  420. head_init_scale: float = 0.001,
  421. sep_pos_embed: bool = False,
  422. abs_win_pos_embed: bool = False,
  423. global_pos_size: Tuple[int, int] = (14, 14),
  424. device=None,
  425. dtype=None,
  426. ):
  427. super().__init__()
  428. dd = {'device': device, 'dtype': dtype}
  429. self.num_classes = num_classes
  430. self.in_chans = in_chans
  431. self.grad_checkpointing = False
  432. norm_layer = get_norm_layer(norm_layer)
  433. if isinstance(img_size, int):
  434. img_size = to_2tuple(img_size)
  435. self.patch_stride = patch_stride
  436. self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
  437. num_tokens = math.prod(self.tokens_spatial_shape)
  438. flat_mu_size = math.prod(mask_unit_size)
  439. flat_q_stride = math.prod(q_stride)
  440. assert q_pool < len(stages)
  441. self.q_pool, self.q_stride = q_pool, q_stride
  442. self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
  443. self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, self.mask_unit_size)]
  444. self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
  445. self.patch_drop_rate = patch_drop_rate
  446. self.patch_embed = PatchEmbed(
  447. in_chans,
  448. embed_dim,
  449. patch_kernel,
  450. patch_stride,
  451. patch_padding,
  452. **dd,
  453. )
  454. self.pos_embed: Optional[nn.Parameter] = None
  455. self.pos_embed_win: Optional[nn.Parameter] = None
  456. self.pos_embed_spatial: Optional[nn.Parameter] = None
  457. self.pos_embed_temporal: Optional[nn.Parameter] = None
  458. if sep_pos_embed:
  459. self.pos_embed_spatial = nn.Parameter(
  460. torch.zeros(1, self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], embed_dim, **dd)
  461. )
  462. self.pos_embed_temporal = nn.Parameter(
  463. torch.zeros(1, self.tokens_spatial_shape[0], embed_dim, **dd)
  464. )
  465. else:
  466. if abs_win_pos_embed:
  467. # absolute win, params NCHW to make tile & interpolate more natural before add & reshape
  468. self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *global_pos_size, **dd))
  469. self.pos_embed_win = nn.Parameter(torch.zeros(1, embed_dim, *mask_unit_size, **dd))
  470. else:
  471. self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim, **dd))
  472. # Setup roll and reroll modules
  473. self.unroll = Unroll(
  474. img_size,
  475. patch_stride,
  476. [q_stride] * len(self.stage_ends[:-1])
  477. )
  478. self.reroll = Reroll(
  479. img_size,
  480. patch_stride,
  481. [q_stride] * len(self.stage_ends[:-1]),
  482. self.stage_ends,
  483. q_pool,
  484. )
  485. # q_pool locations
  486. q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]]
  487. # Transformer blocks
  488. cur_stage = 0
  489. depth = sum(stages)
  490. dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
  491. self.blocks = nn.ModuleList()
  492. self.feature_info = []
  493. for i in range(depth):
  494. dim_out = embed_dim
  495. # Mask unit or global attention.
  496. # Lag by 1 block, so that global attention,
  497. # applied post pooling on lower resolution
  498. use_mask_unit_attn = mask_unit_attn[cur_stage]
  499. if i - 1 in self.stage_ends:
  500. dim_out = int(embed_dim * dim_mul)
  501. num_heads = int(num_heads * head_mul)
  502. cur_stage += 1
  503. if i in q_pool_blocks:
  504. flat_mu_size //= flat_q_stride
  505. block = HieraBlock(
  506. dim=embed_dim,
  507. dim_out=dim_out,
  508. heads=num_heads,
  509. mlp_ratio=mlp_ratio,
  510. drop_path=dpr[i],
  511. init_values=init_values,
  512. norm_layer=norm_layer,
  513. q_stride=(flat_q_stride if i in q_pool_blocks else 1),
  514. window_size=flat_mu_size,
  515. use_expand_proj=use_expand_proj,
  516. use_mask_unit_attn=use_mask_unit_attn,
  517. **dd,
  518. )
  519. embed_dim = dim_out
  520. if i in self.stage_ends:
  521. self.feature_info += [
  522. dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
  523. self.blocks.append(block)
  524. self.num_features = self.head_hidden_size = embed_dim
  525. self.head = ClNormMlpClassifierHead(
  526. embed_dim,
  527. num_classes,
  528. pool_type=global_pool,
  529. drop_rate=drop_rate,
  530. norm_layer=norm_layer,
  531. input_fmt='NLC',
  532. **dd,
  533. )
  534. # Initialize everything
  535. if sep_pos_embed:
  536. nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
  537. nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)
  538. else:
  539. if self.pos_embed is not None:
  540. nn.init.trunc_normal_(self.pos_embed, std=0.02)
  541. if self.pos_embed_win is not None:
  542. nn.init.trunc_normal_(self.pos_embed_win, std=0.02)
  543. if weight_init != 'skip':
  544. init_fn = init_weight_jax if weight_init == 'jax' else init_weight_vit
  545. init_fn = partial(init_fn, classifier_name='head.fc')
  546. named_apply(init_fn, self)
  547. if fix_init:
  548. self.fix_init_weight()
  549. if isinstance(self.head.fc, nn.Linear):
  550. self.head.fc.weight.data.mul_(head_init_scale)
  551. self.head.fc.bias.data.mul_(head_init_scale)
  552. def fix_init_weight(self):
  553. def rescale(param, _layer_id):
  554. param.div_(math.sqrt(2.0 * _layer_id))
  555. for layer_id, layer in enumerate(self.blocks):
  556. rescale(layer.attn.proj.weight.data, layer_id + 1)
  557. rescale(layer.mlp.fc2.weight.data, layer_id + 1)
  558. @torch.jit.ignore
  559. def no_weight_decay(self):
  560. if self.pos_embed is not None:
  561. return ["pos_embed"]
  562. elif self.pos_embed_abs is not None:
  563. return ['pos_embed_abs', 'pos_embed_win']
  564. else:
  565. return ["pos_embed_spatial", "pos_embed_temporal"]
  566. @torch.jit.ignore
  567. def group_matcher(self, coarse: bool = False) -> Dict:
  568. return dict(
  569. stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|pos_embed_abs|pos_embed_win|patch_embed',
  570. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
  571. )
  572. @torch.jit.ignore
  573. def set_grad_checkpointing(self, enable: bool = True) -> None:
  574. self.grad_checkpointing = enable
  575. @torch.jit.ignore
  576. def get_classifier(self):
  577. return self.head.fc
  578. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False):
  579. self.num_classes = num_classes
  580. self.head.reset(num_classes, global_pool, reset_other=reset_other)
  581. def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
  582. """
  583. Generates a random mask, mask_ratio fraction are dropped.
  584. 1 is *keep*, 0 is *remove*. Useful for MAE, FLIP, etc.
  585. """
  586. B = x.shape[0]
  587. # Tokens selected for masking at mask unit level
  588. num_windows = math.prod(self.mask_spatial_shape) # num_mask_units
  589. len_keep = int(num_windows * (1 - mask_ratio))
  590. noise = torch.rand(B, num_windows, device=x.device)
  591. # Sort noise for each sample
  592. ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
  593. ids_restore = torch.argsort(ids_shuffle, dim=1)
  594. # Generate the binary mask: 1 is *keep*, 0 is *remove*
  595. # Note this is opposite to original MAE
  596. mask = torch.zeros([B, num_windows], device=x.device)
  597. mask[:, :len_keep] = 1
  598. # Unshuffle to get the binary mask
  599. mask = torch.gather(mask, dim=1, index=ids_restore)
  600. return mask.bool()
  601. def _pos_embed(self, x) -> torch.Tensor:
  602. if self.pos_embed_win is not None:
  603. # absolute win position embedding, from
  604. # Window Attention is Bugged: How not to Interpolate Position Embeddings (https://arxiv.org/abs/2311.05613)
  605. pos_embed_win = self.pos_embed_win.tile(self.mask_spatial_shape)
  606. pos_embed = F.interpolate(
  607. self.pos_embed,
  608. size=pos_embed_win.shape[-2:],
  609. mode='bicubic',
  610. antialias=True,
  611. )
  612. pos_embed = pos_embed + pos_embed_win
  613. pos_embed = pos_embed.flatten(2).transpose(1, 2)
  614. elif self.pos_embed is not None:
  615. pos_embed = self.pos_embed
  616. else:
  617. pos_embed = (
  618. self.pos_embed_spatial.repeat(1, self.tokens_spatial_shape[0], 1)
  619. +
  620. torch.repeat_interleave(
  621. self.pos_embed_temporal,
  622. self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
  623. dim=1,
  624. )
  625. )
  626. x = x + pos_embed
  627. return x
  628. def forward_intermediates(
  629. self,
  630. x: torch.Tensor,
  631. mask: Optional[torch.Tensor] = None,
  632. indices: Optional[Union[int, List[int]]] = None,
  633. norm: bool = False,
  634. stop_early: bool = True,
  635. output_fmt: str = 'NCHW',
  636. intermediates_only: bool = False,
  637. coarse: bool = True,
  638. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  639. """ Forward features that returns intermediates.
  640. Args:
  641. x: Input image tensor
  642. indices: Take last n blocks if int, all if None, select matching indices if sequence
  643. norm: Apply norm layer to all intermediates
  644. stop_early: Stop iterating over blocks when last desired intermediate hit
  645. output_fmt: Shape of intermediate feature outputs
  646. intermediates_only: Only return intermediate features
  647. Returns:
  648. """
  649. assert not norm, 'normalization of features not supported'
  650. assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.'
  651. if coarse:
  652. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  653. take_indices = [self.stage_ends[i] for i in take_indices]
  654. max_index = self.stage_ends[max_index]
  655. else:
  656. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  657. if mask is not None:
  658. patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape) # B, C, *mask_spatial_shape
  659. else:
  660. patch_mask = None
  661. x = self.patch_embed(x, mask=patch_mask)
  662. x = self._pos_embed(x)
  663. x = self.unroll(x)
  664. # Discard masked tokens
  665. if mask is not None:
  666. x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1])
  667. intermediates = []
  668. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  669. blocks = self.blocks
  670. else:
  671. blocks = self.blocks[:max_index + 1]
  672. for i, blk in enumerate(blocks):
  673. if self.grad_checkpointing and not torch.jit.is_scripting():
  674. x = checkpoint(blk, x)
  675. else:
  676. x = blk(x)
  677. if i in take_indices:
  678. x_int = self.reroll(x, i, mask=mask)
  679. intermediates.append(x_int.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x_int)
  680. if intermediates_only:
  681. return intermediates
  682. return x, intermediates
  683. def prune_intermediate_layers(
  684. self,
  685. indices: Union[int, List[int]] = 1,
  686. prune_norm: bool = False,
  687. prune_head: bool = True,
  688. coarse: bool = True,
  689. ):
  690. """ Prune layers not required for specified intermediates.
  691. """
  692. if coarse:
  693. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  694. max_index = self.stage_ends[max_index]
  695. else:
  696. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  697. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  698. if prune_head:
  699. self.head.reset(0, reset_other=True)
  700. return take_indices
  701. def forward_features(
  702. self,
  703. x: torch.Tensor,
  704. mask: Optional[torch.Tensor] = None,
  705. return_intermediates: bool = False,
  706. ) -> torch.Tensor:
  707. """
  708. mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim.
  709. Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch.
  710. """
  711. if self.training and self.patch_drop_rate > 0:
  712. # using mask for something like 'patch dropout' via mask-units in supervised train / fine-tune
  713. assert mask is None
  714. mask = self.get_random_mask(x, mask_ratio=self.patch_drop_rate)
  715. if mask is not None:
  716. patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape) # B, C, *mask_spatial_shape
  717. else:
  718. patch_mask = None
  719. x = self.patch_embed(x, mask=patch_mask)
  720. x = self._pos_embed(x)
  721. x = self.unroll(x)
  722. # Discard masked tokens
  723. if mask is not None:
  724. x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1])
  725. intermediates = []
  726. for i, blk in enumerate(self.blocks):
  727. if self.grad_checkpointing and not torch.jit.is_scripting():
  728. x = checkpoint(blk, x)
  729. else:
  730. x = blk(x)
  731. if return_intermediates and i in self.stage_ends:
  732. intermediates.append(self.reroll(x, i, mask=mask))
  733. # x may not always be in spatial order here.
  734. # e.g. if q_pool = 2, mask_unit_size = (8, 8), and
  735. # q_stride = (2, 2), not all unrolls were consumed,
  736. # intermediates[-1] is x in spatial order
  737. if return_intermediates:
  738. return x, intermediates
  739. return x
  740. def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
  741. x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  742. return x
  743. def forward(
  744. self,
  745. x: torch.Tensor,
  746. mask: Optional[torch.Tensor] = None,
  747. ) -> torch.Tensor:
  748. x = self.forward_features(x, mask=mask)
  749. if mask is None:
  750. x = self.forward_head(x)
  751. return x
  752. def _cfg(url='', **kwargs):
  753. return {
  754. 'url': url,
  755. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  756. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  757. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  758. 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
  759. 'license': 'apache-2.0',
  760. **kwargs
  761. }
  762. default_cfgs = generate_default_cfgs({
  763. "hiera_tiny_224.mae_in1k_ft_in1k": _cfg(
  764. hf_hub_id='timm/',
  765. license='cc-by-nc-4.0',
  766. ),
  767. "hiera_tiny_224.mae": _cfg(
  768. hf_hub_id='timm/',
  769. license='cc-by-nc-4.0',
  770. num_classes=0,
  771. ),
  772. "hiera_small_224.mae_in1k_ft_in1k": _cfg(
  773. hf_hub_id='timm/',
  774. license='cc-by-nc-4.0',
  775. ),
  776. "hiera_small_224.mae": _cfg(
  777. hf_hub_id='timm/',
  778. license='cc-by-nc-4.0',
  779. num_classes=0,
  780. ),
  781. "hiera_base_224.mae_in1k_ft_in1k": _cfg(
  782. hf_hub_id='timm/',
  783. license='cc-by-nc-4.0',
  784. ),
  785. "hiera_base_224.mae": _cfg(
  786. hf_hub_id='timm/',
  787. license='cc-by-nc-4.0',
  788. num_classes=0,
  789. ),
  790. "hiera_base_plus_224.mae_in1k_ft_in1k": _cfg(
  791. hf_hub_id='timm/',
  792. license='cc-by-nc-4.0',
  793. ),
  794. "hiera_base_plus_224.mae": _cfg(
  795. hf_hub_id='timm/',
  796. license='cc-by-nc-4.0',
  797. num_classes=0,
  798. ),
  799. "hiera_large_224.mae_in1k_ft_in1k": _cfg(
  800. hf_hub_id='timm/',
  801. license='cc-by-nc-4.0',
  802. ),
  803. "hiera_large_224.mae": _cfg(
  804. hf_hub_id='timm/',
  805. license='cc-by-nc-4.0',
  806. num_classes=0,
  807. ),
  808. "hiera_huge_224.mae_in1k_ft_in1k": _cfg(
  809. hf_hub_id='timm/',
  810. license='cc-by-nc-4.0',
  811. ),
  812. "hiera_huge_224.mae": _cfg(
  813. hf_hub_id='timm/',
  814. license='cc-by-nc-4.0',
  815. num_classes=0,
  816. ),
  817. "hiera_small_abswin_256.sbb2_e200_in12k_ft_in1k": _cfg(
  818. hf_hub_id='timm/',
  819. input_size=(3, 256, 256), crop_pct=0.95,
  820. ),
  821. "hiera_small_abswin_256.sbb2_pd_e200_in12k_ft_in1k": _cfg(
  822. hf_hub_id='timm/',
  823. input_size=(3, 256, 256), crop_pct=0.95,
  824. ),
  825. "hiera_small_abswin_256.sbb2_e200_in12k": _cfg(
  826. hf_hub_id='timm/',
  827. num_classes=11821,
  828. input_size=(3, 256, 256), crop_pct=0.95,
  829. ),
  830. "hiera_small_abswin_256.sbb2_pd_e200_in12k": _cfg(
  831. hf_hub_id='timm/',
  832. num_classes=11821,
  833. input_size=(3, 256, 256), crop_pct=0.95,
  834. ),
  835. "hiera_base_abswin_256.untrained": _cfg(
  836. # hf_hub_id='timm/',
  837. input_size=(3, 256, 256), crop_pct=0.95,
  838. ),
  839. })
  840. def checkpoint_filter_fn(state_dict, model=None):
  841. state_dict = state_dict.get('model_state', state_dict)
  842. output = {}
  843. for k, v in state_dict.items():
  844. # if k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
  845. # # To resize pos embedding when using model at different size from pretrained weights
  846. # from timm.layers import resample_abs_pos_embed
  847. # v = resample_abs_pos_embed(
  848. # v,
  849. # new_size=(64, 64),
  850. # num_prefix_tokens=0,
  851. # verbose=True,
  852. # )
  853. if 'head.projection.' in k:
  854. k = k.replace('head.projection.', 'head.fc.')
  855. if k.startswith('encoder_norm.'):
  856. k = k.replace('encoder_norm.', 'head.norm.')
  857. elif k.startswith('norm.'):
  858. k = k.replace('norm.', 'head.norm.')
  859. if k == 'pos_embed_abs':
  860. k = 'pos_embed'
  861. output[k] = v
  862. return output
  863. def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera:
  864. out_indices = kwargs.pop('out_indices', 4)
  865. return build_model_with_cfg(
  866. Hiera,
  867. variant,
  868. pretrained,
  869. pretrained_filter_fn=checkpoint_filter_fn,
  870. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  871. **kwargs,
  872. )
  873. @register_model
  874. def hiera_tiny_224(pretrained=False, **kwargs):
  875. model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))
  876. return _create_hiera('hiera_tiny_224', pretrained=pretrained, **dict(model_args, **kwargs))
  877. @register_model
  878. def hiera_small_224(pretrained=False, **kwargs):
  879. model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2))
  880. return _create_hiera('hiera_small_224', pretrained=pretrained, **dict(model_args, **kwargs))
  881. @register_model
  882. def hiera_base_224(pretrained=False, **kwargs):
  883. model_args = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
  884. return _create_hiera('hiera_base_224', pretrained=pretrained, **dict(model_args, **kwargs))
  885. @register_model
  886. def hiera_base_plus_224(pretrained=False, **kwargs):
  887. model_args = dict(embed_dim=112, num_heads=2, stages=(2, 3, 16, 3))
  888. return _create_hiera('hiera_base_plus_224', pretrained=pretrained, **dict(model_args, **kwargs))
  889. @register_model
  890. def hiera_large_224(pretrained=False, **kwargs):
  891. model_args = dict(embed_dim=144, num_heads=2, stages=(2, 6, 36, 4))
  892. return _create_hiera('hiera_large_224', pretrained=pretrained, **dict(model_args, **kwargs))
  893. @register_model
  894. def hiera_huge_224(pretrained=False, **kwargs):
  895. model_args = dict(embed_dim=256, num_heads=4, stages=(2, 6, 36, 4))
  896. return _create_hiera('hiera_huge_224', pretrained=pretrained, **dict(model_args, **kwargs))
  897. @register_model
  898. def hiera_small_abswin_256(pretrained=False, **kwargs):
  899. model_args = dict(
  900. embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, global_pos_size=(16, 16),
  901. init_values=1e-5, weight_init='jax', use_expand_proj=False,
  902. )
  903. return _create_hiera('hiera_small_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
  904. @register_model
  905. def hiera_base_abswin_256(pretrained=False, **kwargs):
  906. model_args = dict(
  907. embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, init_values=1e-5, weight_init='jax')
  908. return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))