hieradet_sam2.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700
  1. import math
  2. from copy import deepcopy
  3. from functools import partial
  4. from typing import Dict, List, Optional, Tuple, Type, Union
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  9. from timm.layers import (
  10. PatchEmbed,
  11. Mlp,
  12. DropPath,
  13. calculate_drop_path_rates,
  14. ClNormMlpClassifierHead,
  15. LayerScale,
  16. get_norm_layer,
  17. get_act_layer,
  18. init_weight_jax,
  19. init_weight_vit,
  20. to_2tuple,
  21. use_fused_attn,
  22. )
  23. from ._builder import build_model_with_cfg
  24. from ._features import feature_take_indices
  25. from ._manipulate import named_apply, checkpoint
  26. from ._registry import generate_default_cfgs, register_model
  27. def window_partition(x, window_size: Tuple[int, int]):
  28. """
  29. Partition into non-overlapping windows with padding if needed.
  30. Args:
  31. x (tensor): input tokens with [B, H, W, C].
  32. window_size (int): window size.
  33. Returns:
  34. windows: windows after partition with [B * num_windows, window_size, window_size, C].
  35. (Hp, Wp): padded height and width before partition
  36. """
  37. B, H, W, C = x.shape
  38. x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
  39. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
  40. return windows
  41. def window_unpartition(windows: torch.Tensor, window_size: Tuple[int, int], hw: Tuple[int, int]):
  42. """
  43. Window unpartition into original sequences and removing padding.
  44. Args:
  45. x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
  46. window_size (int): window size.
  47. hw (Tuple): original height and width (H, W) before padding.
  48. Returns:
  49. x: unpartitioned sequences with [B, H, W, C].
  50. """
  51. H, W = hw
  52. B = windows.shape[0] // (H * W // window_size[0] // window_size[1])
  53. x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
  54. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
  55. return x
  56. def _calc_pad(H: int, W: int, window_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
  57. pad_h = (window_size[0] - H % window_size[0]) % window_size[0]
  58. pad_w = (window_size[1] - W % window_size[1]) % window_size[1]
  59. Hp, Wp = H + pad_h, W + pad_w
  60. return Hp, Wp, pad_h, pad_w
  61. class MultiScaleAttention(nn.Module):
  62. fused_attn: torch.jit.Final[bool]
  63. def __init__(
  64. self,
  65. dim: int,
  66. dim_out: int,
  67. num_heads: int,
  68. q_pool: nn.Module = None,
  69. device=None,
  70. dtype=None,
  71. ):
  72. dd = {'device': device, 'dtype': dtype}
  73. super().__init__()
  74. self.dim = dim
  75. self.dim_out = dim_out
  76. self.num_heads = num_heads
  77. head_dim = dim_out // num_heads
  78. self.scale = head_dim ** -0.5
  79. self.fused_attn = use_fused_attn()
  80. self.q_pool = q_pool
  81. self.qkv = nn.Linear(dim, dim_out * 3, **dd)
  82. self.proj = nn.Linear(dim_out, dim_out, **dd)
  83. def forward(self, x: torch.Tensor) -> torch.Tensor:
  84. B, H, W, _ = x.shape
  85. # qkv with shape (B, H * W, 3, nHead, C)
  86. qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
  87. # q, k, v with shape (B, H * W, nheads, C)
  88. q, k, v = torch.unbind(qkv, 2)
  89. # Q pooling (for downsample at stage changes)
  90. if self.q_pool is not None:
  91. q = q.reshape(B, H, W, -1).permute(0, 3, 1, 2) # to BCHW for pool
  92. q = self.q_pool(q).permute(0, 2, 3, 1)
  93. H, W = q.shape[1:3] # downsampled shape
  94. q = q.reshape(B, H * W, self.num_heads, -1)
  95. # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
  96. q = q.transpose(1, 2)
  97. k = k.transpose(1, 2)
  98. v = v.transpose(1, 2)
  99. if self.fused_attn:
  100. x = F.scaled_dot_product_attention(q, k, v)
  101. else:
  102. q = q * self.scale
  103. attn = q @ k.transpose(-1, -2)
  104. attn = attn.softmax(dim=-1)
  105. x = attn @ v
  106. # Transpose back
  107. x = x.transpose(1, 2).reshape(B, H, W, -1)
  108. x = self.proj(x)
  109. return x
  110. class MultiScaleBlock(nn.Module):
  111. def __init__(
  112. self,
  113. dim: int,
  114. dim_out: int,
  115. num_heads: int,
  116. mlp_ratio: float = 4.0,
  117. q_stride: Optional[Tuple[int, int]] = None,
  118. norm_layer: Union[Type[nn.Module], str] = "LayerNorm",
  119. act_layer: Union[Type[nn.Module], str] = "GELU",
  120. window_size: int = 0,
  121. init_values: Optional[float] = None,
  122. drop_path: float = 0.0,
  123. device=None,
  124. dtype=None,
  125. ):
  126. dd = {'device': device, 'dtype': dtype}
  127. super().__init__()
  128. norm_layer = get_norm_layer(norm_layer)
  129. act_layer = get_act_layer(act_layer)
  130. self.window_size = to_2tuple(window_size)
  131. self.is_windowed = any(self.window_size)
  132. self.dim = dim
  133. self.dim_out = dim_out
  134. self.q_stride = q_stride
  135. if dim != dim_out:
  136. self.proj = nn.Linear(dim, dim_out, **dd)
  137. else:
  138. self.proj = nn.Identity()
  139. self.pool = None
  140. if self.q_stride:
  141. # note make a different instance for this Module so that it's not shared with attn module
  142. self.pool = nn.MaxPool2d(
  143. kernel_size=q_stride,
  144. stride=q_stride,
  145. ceil_mode=False,
  146. )
  147. self.norm1 = norm_layer(dim, **dd)
  148. self.attn = MultiScaleAttention(
  149. dim,
  150. dim_out,
  151. num_heads=num_heads,
  152. q_pool=deepcopy(self.pool),
  153. **dd,
  154. )
  155. self.ls1 = LayerScale(dim_out, init_values, **dd) if init_values is not None else nn.Identity()
  156. self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  157. self.norm2 = norm_layer(dim_out, **dd)
  158. self.mlp = Mlp(
  159. dim_out,
  160. int(dim_out * mlp_ratio),
  161. act_layer=act_layer,
  162. **dd,
  163. )
  164. self.ls2 = LayerScale(dim_out, init_values, **dd) if init_values is not None else nn.Identity()
  165. self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  166. def forward(self, x: torch.Tensor) -> torch.Tensor:
  167. shortcut = x # B, H, W, C
  168. x = self.norm1(x)
  169. # Skip connection
  170. if self.dim != self.dim_out:
  171. shortcut = self.proj(x)
  172. if self.pool is not None:
  173. shortcut = shortcut.permute(0, 3, 1, 2)
  174. shortcut = self.pool(shortcut).permute(0, 2, 3, 1)
  175. # Window partition
  176. window_size = self.window_size
  177. H, W = x.shape[1:3]
  178. Hp, Wp = H, W # keep torchscript happy
  179. if self.is_windowed:
  180. Hp, Wp, pad_h, pad_w = _calc_pad(H, W, window_size)
  181. x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
  182. x = window_partition(x, window_size)
  183. # Window Attention + Q Pooling (if stage change)
  184. x = self.attn(x)
  185. if self.q_stride is not None:
  186. # Shapes have changed due to Q pooling
  187. window_size = (self.window_size[0] // self.q_stride[0], self.window_size[1] // self.q_stride[1])
  188. H, W = shortcut.shape[1:3]
  189. Hp, Wp, pad_h, pad_w = _calc_pad(H, W, window_size)
  190. # Reverse window partition
  191. if self.is_windowed:
  192. x = window_unpartition(x, window_size, (Hp, Wp))
  193. x = x[:, :H, :W, :].contiguous() # unpad
  194. x = shortcut + self.drop_path1(self.ls1(x))
  195. x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
  196. return x
  197. class HieraPatchEmbed(nn.Module):
  198. """
  199. Image to Patch Embedding.
  200. """
  201. def __init__(
  202. self,
  203. kernel_size: Union[int, Tuple[int, int]] = (7, 7),
  204. stride: Union[int, Tuple[int, int]] = (4, 4),
  205. padding: Union[str, int, Tuple[int, int]] = (3, 3),
  206. in_chans: int = 3,
  207. embed_dim: int = 768,
  208. device=None,
  209. dtype=None,
  210. ):
  211. """
  212. Args:
  213. kernel_size: kernel size of the projection layer.
  214. stride: stride of the projection layer.
  215. padding: padding size of the projection layer.
  216. in_chans: Number of input image channels.
  217. embed_dim: Patch embedding dimension.
  218. """
  219. super().__init__()
  220. dd = {'device': device, 'dtype': dtype}
  221. self.proj = nn.Conv2d(
  222. in_chans,
  223. embed_dim,
  224. kernel_size=kernel_size,
  225. stride=stride,
  226. padding=padding,
  227. **dd,
  228. )
  229. def forward(self, x: torch.Tensor) -> torch.Tensor:
  230. x = self.proj(x)
  231. # B C H W -> B H W C
  232. x = x.permute(0, 2, 3, 1)
  233. return x
class HieraDet(nn.Module):
    """
    Multi-stage backbone made of a patch embedding, a global + windowed
    position embedding, a sequence of ``MultiScaleBlock`` modules and a
    ``ClNormMlpClassifierHead``. Features are kept channels-last (NHWC)
    between blocks.

    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(
            self,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            embed_dim: int = 96,  # initial embed dim
            num_heads: int = 1,  # initial number of heads
            patch_kernel: Tuple[int, int] = (7, 7),
            patch_stride: Tuple[int, int] = (4, 4),
            patch_padding: Tuple[int, int] = (3, 3),
            patch_size: Optional[Tuple[int, int]] = None,
            q_pool: int = 3,  # number of q_pool stages
            q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
            stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
            dim_mul: float = 2.0,  # dim_mul factor at stage shift
            head_mul: float = 2.0,  # head_mul factor at stage shift
            global_pos_size: Tuple[int, int] = (7, 7),
            # window size per stage, when not using global att.
            window_spec: Tuple[int, ...] = (
                8,
                4,
                14,
                7,
            ),
            # global attn in these blocks
            global_att_blocks: Tuple[int, ...] = (
                12,
                16,
                20,
            ),
            init_values: Optional[float] = None,
            weight_init: str = '',
            fix_init: bool = True,
            head_init_scale: float = 0.001,
            drop_rate: float = 0.0,
            drop_path_rate: float = 0.0,  # stochastic depth
            norm_layer: Union[Type[nn.Module], str] = "LayerNorm",
            act_layer: Union[Type[nn.Module], str] = "GELU",
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_chans: number of input image channels.
            num_classes: number of classes passed to the classifier head.
            global_pool: pooling type passed to the classifier head.
            embed_dim: channel count of the first stage.
            num_heads: attention head count of the first stage.
            patch_kernel: kernel of the conv patch embed (used when ``patch_size`` is None).
            patch_stride: stride of the conv patch embed (used when ``patch_size`` is None).
            patch_padding: padding of the conv patch embed (used when ``patch_size`` is None).
            patch_size: if set, a ``PatchEmbed`` with this patch size is used instead of the conv embed.
            q_pool: how many stage transitions apply query pooling.
            q_stride: pooling stride used at those transitions.
            stages: number of blocks in each stage.
            dim_mul: channel multiplier applied at each stage transition.
            head_mul: head-count multiplier applied at each stage transition.
            global_pos_size: spatial size of the learned global position embedding.
            window_spec: attention window size per stage; must have one entry per stage.
            global_att_blocks: block indices that use window size 0 (no windowing).
            init_values: LayerScale init value forwarded to every block.
            weight_init: 'skip' disables weight init, 'jax' selects ``init_weight_jax``,
                anything else selects ``init_weight_vit``.
            fix_init: if True, rescale block output projections via ``fix_init_weight``.
            head_init_scale: multiplier applied to the classifier weight and bias after init.
            drop_rate: dropout rate passed to the classifier head.
            drop_path_rate: stochastic depth rate passed to ``calculate_drop_path_rates``.
            norm_layer: norm layer class or name.
            act_layer: activation class or name.
            device: forwarded to parameters and sub-modules.
            dtype: forwarded to parameters and sub-modules.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        norm_layer = get_norm_layer(norm_layer)
        act_layer = get_act_layer(act_layer)
        assert len(stages) == len(window_spec)
        self.grad_checkpointing = False
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.window_spec = window_spec
        self.output_fmt = 'NHWC'

        depth = sum(stages)
        self.q_stride = q_stride
        # Index of the last block of every stage.
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        # First block of each following stage, limited to the first `q_pool` transitions.
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]

        if patch_size is not None:
            # use a non-overlapping vit style patch embed
            self.patch_embed = PatchEmbed(
                img_size=None,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
                output_fmt='NHWC',
                dynamic_img_pad=True,
                **dd,
            )
        else:
            self.patch_embed = HieraPatchEmbed(
                kernel_size=patch_kernel,
                stride=patch_stride,
                padding=patch_padding,
                in_chans=in_chans,
                embed_dim=embed_dim,
                **dd,
            )

        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.global_pos_size = global_pos_size
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.global_pos_size, **dd))
        # The window embedding is square, sized by the first stage's window.
        self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0], **dd))

        dpr = calculate_drop_path_rates(drop_path_rate, depth)  # stochastic depth decay rule

        cur_stage = 0
        self.blocks = nn.Sequential()
        self.feature_info = []
        for i in range(depth):
            dim_out = embed_dim
            # lags by a block, so first block of
            # next stage uses an initial window size
            # of previous stage and final window size of current stage
            window_size = self.window_spec[cur_stage]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            # Block i is the first block of a new stage: widen channels and heads.
            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
                norm_layer=norm_layer,
                act_layer=act_layer,
                init_values=init_values,
                **dd,
            )

            embed_dim = dim_out
            self.blocks.append(block)
            if i in self.stage_ends:
                # NOTE(review): the reduction formula looks like it assumes a stride-4 patch
                # embed and 2x pooling at every stage transition -- confirm for other configs.
                self.feature_info += [
                    dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]

        self.num_features = self.head_hidden_size = embed_dim
        self.head = ClNormMlpClassifierHead(
            embed_dim,
            num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
            norm_layer=norm_layer,
            **dd,
        )

        # Initialize everything
        if self.pos_embed is not None:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)

        if self.pos_embed_window is not None:
            nn.init.trunc_normal_(self.pos_embed_window, std=0.02)

        if weight_init != 'skip':
            init_fn = init_weight_jax if weight_init == 'jax' else init_weight_vit
            init_fn = partial(init_fn, classifier_name='head.fc')
            named_apply(init_fn, self)

        if fix_init:
            self.fix_init_weight()

        # Shrink the freshly initialized classifier (skipped when head.fc is not a Linear).
        if isinstance(self.head, ClNormMlpClassifierHead) and isinstance(self.head.fc, nn.Linear):
            self.head.fc.weight.data.mul_(head_init_scale)
            self.head.fc.bias.data.mul_(head_init_scale)

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        """Add the position embedding to NHWC features ``x``.

        The global embedding is bicubically resized to the feature grid and the
        window embedding is tiled on top of it before both are added to ``x``.
        """
        h, w = x.shape[1:3]
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        # NOTE(review): tile counts use floor division, so the tiled window embedding only
        # matches the resized grid when h and w are multiples of the window embedding size
        # -- confirm callers guarantee this.
        tile_h = pos_embed.shape[-2] // window_embed.shape[-2]
        tile_w = pos_embed.shape[-1] // window_embed.shape[-1]
        pos_embed = pos_embed + window_embed.tile((tile_h, tile_w))
        # NCHW -> NHWC to match x
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return x + pos_embed

    def fix_init_weight(self):
        """Divide each block's attn.proj and mlp.fc2 weights by sqrt(2 * (block index + 1))."""
        def rescale(param, _layer_id):
            param.div_(math.sqrt(2.0 * _layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of the two position embedding parameters."""
        return ['pos_embed', 'pos_embed_window']

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict:
        """Return regex patterns grouping parameters into stem and per-block groups.

        ``coarse`` is accepted but not used here.
        """
        # NOTE(review): '^' only anchors the first alternative of the stem pattern;
        # 'pos_embed_window' and 'patch_embed' can match anywhere in a name -- confirm intended.
        return dict(
            stem=r'^pos_embed|pos_embed_window|patch_embed',
            blocks=[(r'^blocks\.(\d+)', None)]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        """Turn gradient checkpointing of the blocks on or off."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        """Return the classifier layer ``head.fc``."""
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False):
        """Record the new class count and delegate the rebuild to ``head.reset``."""
        self.num_classes = num_classes
        self.head.reset(num_classes, pool_type=global_pool, reset_other=reset_other)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = True,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
            coarse: bool = True,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Must be False; normalization of intermediates is not supported (asserted)
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs, 'NCHW' or 'NHWC'
            intermediates_only: Only return intermediate features
            coarse: Take coarse features (stage ends) if true, otherwise all block features
        Returns:
            The list of intermediates when ``intermediates_only`` is set, otherwise a tuple of
            the output of the last executed block (NHWC) and the list of intermediates.
        """
        assert not norm, 'normalization of features not supported'
        assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.'
        if coarse:
            # Indices refer to stages; translate them to the block index ending each stage.
            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
            take_indices = [self.stage_ends[i] for i in take_indices]
            max_index = self.stage_ends[max_index]
        else:
            take_indices, max_index = feature_take_indices(len(self.blocks), indices)

        x = self.patch_embed(x)
        x = self._pos_embed(x)

        intermediates = []
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            blocks = self.blocks
        else:
            blocks = self.blocks[:max_index + 1]
        for i, blk in enumerate(blocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x)
            else:
                x = blk(x)
            if i in take_indices:
                # Blocks run in NHWC; convert only the returned copy when NCHW is requested.
                x_out = x.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x
                intermediates.append(x_out)

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
            coarse: bool = True,
    ):
        """ Prune layers not required for specified intermediates.

        Blocks after the last requested index are dropped; with ``prune_head`` the head is
        reset to 0 classes. Returns the indices produced by ``feature_take_indices``
        (stage indices when ``coarse``, block indices otherwise).
        """
        if coarse:
            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
            max_index = self.stage_ends[max_index]
        else:
            take_indices, max_index = feature_take_indices(len(self.blocks), indices)
        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
        if prune_head:
            self.head.reset(0, reset_other=prune_norm)
        return take_indices

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run patch embed, position embed and all blocks; returns NHWC features."""
        x = self.patch_embed(x)  # BHWC
        x = self._pos_embed(x)
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x)
            else:
                x = blk(x)
        return x

    def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
        """Apply the classifier head; ``pre_logits`` is forwarded to the head only when True."""
        x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full forward pass: features followed by the classifier head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
  490. # NOTE sam2 appears to use 1024x1024 for all models, but T, S, & B+ have windows that fit multiples of 224.
  491. def _cfg(url='', **kwargs):
  492. return {
  493. 'url': url,
  494. 'num_classes': 0, 'input_size': (3, 896, 896), 'pool_size': (28, 28),
  495. 'crop_pct': 1.0, 'interpolation': 'bicubic', 'min_input_size': (3, 224, 224),
  496. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  497. 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
  498. 'license': 'apache-2.0',
  499. **kwargs
  500. }
  501. default_cfgs = generate_default_cfgs({
  502. "sam2_hiera_tiny.fb_r896": _cfg(
  503. # hf_hub_id='facebook/sam2-hiera-tiny',
  504. # hf_hub_filename='sam2_hiera_tiny.pt',
  505. hf_hub_id='timm/',
  506. ),
  507. "sam2_hiera_tiny.fb_r896_2pt1": _cfg(
  508. # hf_hub_id='facebook/sam2.1-hiera-tiny',
  509. # hf_hub_filename='sam2.1_hiera_tiny.pt',
  510. hf_hub_id='timm/',
  511. ),
  512. "sam2_hiera_small.fb_r896": _cfg(
  513. # hf_hub_id='facebook/sam2-hiera-small',
  514. # hf_hub_filename='sam2_hiera_small.pt',
  515. hf_hub_id='timm/',
  516. ),
  517. "sam2_hiera_small.fb_r896_2pt1": _cfg(
  518. # hf_hub_id='facebook/sam2.1-hiera-small',
  519. # hf_hub_filename='sam2.1_hiera_small.pt',
  520. hf_hub_id='timm/',
  521. ),
  522. "sam2_hiera_base_plus.fb_r896": _cfg(
  523. # hf_hub_id='facebook/sam2-hiera-base-plus',
  524. # hf_hub_filename='sam2_hiera_base_plus.pt',
  525. hf_hub_id='timm/',
  526. ),
  527. "sam2_hiera_base_plus.fb_r896_2pt1": _cfg(
  528. # hf_hub_id='facebook/sam2.1-hiera-base-plus',
  529. # hf_hub_filename='sam2.1_hiera_base_plus.pt',
  530. hf_hub_id='timm/',
  531. ),
  532. "sam2_hiera_large.fb_r1024": _cfg(
  533. # hf_hub_id='facebook/sam2-hiera-large',
  534. # hf_hub_filename='sam2_hiera_large.pt',
  535. hf_hub_id='timm/',
  536. min_input_size=(3, 256, 256),
  537. input_size=(3, 1024, 1024), pool_size=(32, 32),
  538. ),
  539. "sam2_hiera_large.fb_r1024_2pt1": _cfg(
  540. # hf_hub_id='facebook/sam2.1-hiera-large',
  541. # hf_hub_filename='sam2.1_hiera_large.pt',
  542. hf_hub_id='timm/',
  543. min_input_size=(3, 256, 256),
  544. input_size=(3, 1024, 1024), pool_size=(32, 32),
  545. ),
  546. "hieradet_small.untrained": _cfg(
  547. num_classes=1000,
  548. input_size=(3, 256, 256), pool_size=(8, 8),
  549. ),
  550. })
  551. def checkpoint_filter_fn(state_dict, model=None, prefix=''):
  552. state_dict = state_dict.get('model', state_dict)
  553. output = {}
  554. for k, v in state_dict.items():
  555. if k.startswith(prefix):
  556. k = k.replace(prefix, '')
  557. else:
  558. continue
  559. k = k.replace('mlp.layers.0', 'mlp.fc1')
  560. k = k.replace('mlp.layers.1', 'mlp.fc2')
  561. output[k] = v
  562. return output
  563. def _create_hiera_det(variant: str, pretrained: bool = False, **kwargs) -> HieraDet:
  564. out_indices = kwargs.pop('out_indices', 4)
  565. checkpoint_prefix = ''
  566. # if 'sam2' in variant:
  567. # # SAM2 pretrained weights have no classifier or final norm-layer (`head.norm`)
  568. # # This is workaround loading with num_classes=0 w/o removing norm-layer.
  569. # kwargs.setdefault('pretrained_strict', False)
  570. # checkpoint_prefix = 'image_encoder.trunk.'
  571. return build_model_with_cfg(
  572. HieraDet,
  573. variant,
  574. pretrained,
  575. pretrained_filter_fn=partial(checkpoint_filter_fn, prefix=checkpoint_prefix),
  576. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  577. **kwargs,
  578. )
  579. @register_model
  580. def sam2_hiera_tiny(pretrained=False, **kwargs):
  581. model_args = dict(stages=(1, 2, 7, 2), global_att_blocks=(5, 7, 9))
  582. return _create_hiera_det('sam2_hiera_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
  583. @register_model
  584. def sam2_hiera_small(pretrained=False, **kwargs):
  585. model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13))
  586. return _create_hiera_det('sam2_hiera_small', pretrained=pretrained, **dict(model_args, **kwargs))
  587. @register_model
  588. def sam2_hiera_base_plus(pretrained=False, **kwargs):
  589. model_args = dict(embed_dim=112, num_heads=2, global_pos_size=(14, 14))
  590. return _create_hiera_det('sam2_hiera_base_plus', pretrained=pretrained, **dict(model_args, **kwargs))
  591. @register_model
  592. def sam2_hiera_large(pretrained=False, **kwargs):
  593. model_args = dict(
  594. embed_dim=144,
  595. num_heads=2,
  596. stages=(2, 6, 36, 4),
  597. global_att_blocks=(23, 33, 43),
  598. window_spec=(8, 4, 16, 8),
  599. )
  600. return _create_hiera_det('sam2_hiera_large', pretrained=pretrained, **dict(model_args, **kwargs))
  601. @register_model
  602. def hieradet_small(pretrained=False, **kwargs):
  603. model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8), init_values=1e-5)
  604. return _create_hiera_det('hieradet_small', pretrained=pretrained, **dict(model_args, **kwargs))
  605. # @register_model
  606. # def hieradet_base(pretrained=False, **kwargs):
  607. # model_args = dict(window_spec=(8, 4, 16, 8))
  608. # return _create_hiera_det('hieradet_base', pretrained=pretrained, **dict(model_args, **kwargs))