vision_transformer_sam.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783
""" Vision Transformer (ViT) in PyTorch

A PyTorch implementation of Vision Transformers as described in:

'Exploring Plain Vision Transformer Backbones for Object Detection'
    - https://arxiv.org/abs/2203.16527

'Segment Anything Model (SAM)'
    - https://github.com/facebookresearch/segment-anything/
"""
  8. import logging
  9. from functools import partial
  10. from typing import Callable, List, Optional, Tuple, Type, Union
  11. import torch
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
  15. from timm.layers import (
  16. PatchEmbed,
  17. Mlp,
  18. DropPath,
  19. calculate_drop_path_rates,
  20. PatchDropout,
  21. LayerNorm2d,
  22. LayerScale,
  23. ClassifierHead,
  24. NormMlpClassifierHead,
  25. Format,
  26. resample_abs_pos_embed_nhwc,
  27. RotaryEmbeddingCat,
  28. apply_rot_embed_cat,
  29. to_2tuple,
  30. use_fused_attn,
  31. )
  32. from torch.jit import Final
  33. from ._builder import build_model_with_cfg
  34. from ._features import feature_take_indices
  35. from ._features_fx import register_notrace_function
  36. from ._manipulate import checkpoint, checkpoint_seq
  37. from ._registry import generate_default_cfgs, register_model
  38. # model_registry will add each entrypoint fn to this
  39. __all__ = ['VisionTransformerSAM']
  40. _logger = logging.getLogger(__name__)
  41. def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
  42. """
  43. Get relative positional embeddings according to the relative positions of
  44. query and key sizes.
  45. Args:
  46. q_size (int): size of query q.
  47. k_size (int): size of key k.
  48. rel_pos (Tensor): relative position embeddings (L, C).
  49. Returns:
  50. Extracted positional embeddings according to relative positions.
  51. """
  52. max_rel_dist = int(2 * max(q_size, k_size) - 1)
  53. # Interpolate rel pos if needed.
  54. if rel_pos.shape[0] != max_rel_dist:
  55. # Interpolate rel pos.
  56. rel_pos_resized = F.interpolate(
  57. rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
  58. size=max_rel_dist,
  59. mode="linear",
  60. )
  61. rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
  62. else:
  63. rel_pos_resized = rel_pos
  64. # Scale the coords with short length if shapes for q and k are different.
  65. q_coords = torch.arange(q_size, dtype=torch.float32)[:, None] * max(k_size / q_size, 1.0)
  66. k_coords = torch.arange(k_size, dtype=torch.float32)[None, :] * max(q_size / k_size, 1.0)
  67. relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
  68. return rel_pos_resized[relative_coords.long()]
# Register get_rel_pos with timm's FX feature-extraction helpers (._features_fx) so it is
# treated as a leaf call rather than traced through — presumably because of the
# size-dependent branching/indexing inside it; confirm against ._features_fx.
register_notrace_function(get_rel_pos)
  70. def get_decomposed_rel_pos_bias(
  71. q: torch.Tensor,
  72. rel_pos_h: torch.Tensor,
  73. rel_pos_w: torch.Tensor,
  74. q_size: Tuple[int, int],
  75. k_size: Tuple[int, int],
  76. ) -> torch.Tensor:
  77. """
  78. Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
  79. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
  80. Args:
  81. q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
  82. rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
  83. rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
  84. q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
  85. k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
  86. Returns:
  87. bias (Tensor): attention bias to add to attention map
  88. """
  89. q_h, q_w = q_size
  90. k_h, k_w = k_size
  91. Rh = get_rel_pos(q_h, k_h, rel_pos_h)
  92. Rw = get_rel_pos(q_w, k_w, rel_pos_w)
  93. B, _, dim = q.shape
  94. r_q = q.reshape(B, q_h, q_w, dim)
  95. rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
  96. rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
  97. attn_bias = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
  98. return attn_bias.reshape(-1, q_h * q_w, k_h * k_w)
  99. class Attention(nn.Module):
  100. fused_attn: Final[bool]
  101. def __init__(
  102. self,
  103. dim: int,
  104. num_heads: int = 8,
  105. qkv_bias: bool = True,
  106. qk_norm: bool = False,
  107. attn_drop: float = 0.,
  108. proj_drop: float = 0.,
  109. norm_layer: Type[nn.Module] = nn.LayerNorm,
  110. use_rel_pos: bool = False,
  111. input_size: Optional[Tuple[int, int]] = None,
  112. rope: Optional[nn.Module] = None,
  113. device=None,
  114. dtype=None,
  115. ):
  116. dd = {'device': device, 'dtype': dtype}
  117. super().__init__()
  118. assert dim % num_heads == 0, 'dim should be divisible by num_heads'
  119. self.num_heads = num_heads
  120. self.head_dim = dim // num_heads
  121. self.scale = self.head_dim ** -0.5
  122. self.fused_attn = use_fused_attn()
  123. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  124. self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
  125. self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
  126. self.attn_drop = nn.Dropout(attn_drop)
  127. self.proj = nn.Linear(dim, dim, **dd)
  128. self.proj_drop = nn.Dropout(proj_drop)
  129. self.use_rel_pos = use_rel_pos
  130. if self.use_rel_pos:
  131. assert rope is None
  132. assert (
  133. input_size is not None
  134. ), "Input size must be provided if using relative positional encoding."
  135. # initialize relative positional embeddings
  136. self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim, **dd))
  137. self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim, **dd))
  138. self.rope = rope
  139. def forward(self, x):
  140. B, H, W, _ = x.shape
  141. N = H * W
  142. x = x.reshape(B, N, -1)
  143. qkv = self.qkv(x).view(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  144. # qkv with shape (3, B, nHead, H * W, C)
  145. q, k, v = qkv.reshape(3, B * self.num_heads, N, -1).unbind(0)
  146. # q, k, v with shape (B * nHead, H * W, C)
  147. q, k = self.q_norm(q), self.k_norm(k)
  148. if self.use_rel_pos:
  149. attn_bias = get_decomposed_rel_pos_bias(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
  150. else:
  151. attn_bias = None
  152. if self.rope is not None:
  153. rope = self.rope.get_embed()
  154. q = apply_rot_embed_cat(q, rope).type_as(v)
  155. k = apply_rot_embed_cat(k, rope).type_as(v)
  156. if self.fused_attn:
  157. x = torch.nn.functional.scaled_dot_product_attention(
  158. q, k, v,
  159. attn_mask=attn_bias,
  160. dropout_p=self.attn_drop.p if self.training else 0.,
  161. )
  162. else:
  163. q = q * self.scale
  164. attn = q @ k.transpose(-2, -1)
  165. if attn_bias is not None:
  166. attn = attn + attn_bias
  167. attn = attn.softmax(dim=-1)
  168. attn = self.attn_drop(attn)
  169. x = attn @ v
  170. x = x.view(B, self.num_heads, N, -1).transpose(1, 2).reshape(B, N, -1)
  171. x = self.proj(x)
  172. x = self.proj_drop(x)
  173. x = x.view(B, H, W, -1)
  174. return x
  175. class Block(nn.Module):
  176. def __init__(
  177. self,
  178. dim: int,
  179. num_heads: int,
  180. mlp_ratio: float = 4.,
  181. qkv_bias: bool = True,
  182. qk_norm: bool = False,
  183. proj_drop: float = 0.,
  184. attn_drop: float = 0.,
  185. init_values: Optional[float] = None,
  186. drop_path: float = 0.,
  187. act_layer: Type[nn.Module] = nn.GELU,
  188. norm_layer: Type[nn.Module] = nn.LayerNorm,
  189. mlp_layer: Type[nn.Module] = Mlp,
  190. use_rel_pos: bool = False,
  191. window_size: int = 0,
  192. input_size=None,
  193. rope=None,
  194. device=None,
  195. dtype=None,
  196. ):
  197. dd = {'device': device, 'dtype': dtype}
  198. super().__init__()
  199. self.window_size = window_size
  200. self.norm1 = norm_layer(dim, **dd)
  201. self.attn = Attention(
  202. dim,
  203. num_heads=num_heads,
  204. qkv_bias=qkv_bias,
  205. qk_norm=qk_norm,
  206. attn_drop=attn_drop,
  207. proj_drop=proj_drop,
  208. norm_layer=norm_layer,
  209. use_rel_pos=use_rel_pos,
  210. input_size=input_size if window_size == 0 else (window_size, window_size),
  211. rope=rope,
  212. **dd,
  213. )
  214. self.ls1 = LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()
  215. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  216. self.norm2 = norm_layer(dim, **dd)
  217. self.mlp = mlp_layer(
  218. in_features=dim,
  219. hidden_features=int(dim * mlp_ratio),
  220. act_layer=act_layer,
  221. drop=proj_drop,
  222. **dd,
  223. )
  224. self.ls2 = LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()
  225. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  226. def forward(self, x):
  227. B, H, W, _ = x.shape
  228. shortcut = x
  229. x = self.norm1(x)
  230. # Window partition
  231. pad_hw: Optional[Tuple[int, int]] = None
  232. if self.window_size > 0:
  233. x, pad_hw = window_partition(x, self.window_size)
  234. x = self.drop_path1(self.ls1(self.attn(x)))
  235. # Reverse window partition
  236. if self.window_size > 0:
  237. x = window_unpartition(x, self.window_size, (H, W), pad_hw)
  238. x = shortcut + x
  239. x = x.reshape(B, H * W, -1) # MLP is faster for N, L, C tensor
  240. x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
  241. x = x.reshape(B, H, W, -1)
  242. return x
  243. def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
  244. """
  245. Partition into non-overlapping windows with padding if needed.
  246. Args:
  247. x (tensor): input tokens with [B, H, W, C].
  248. window_size (int): window size.
  249. Returns:
  250. windows: windows after partition with [B * num_windows, window_size, window_size, C].
  251. (Hp, Wp): padded height and width before partition
  252. """
  253. B, H, W, C = x.shape
  254. pad_h = (window_size - H % window_size) % window_size
  255. pad_w = (window_size - W % window_size) % window_size
  256. x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
  257. Hp, Wp = H + pad_h, W + pad_w
  258. x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
  259. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  260. return windows, (Hp, Wp)
  261. def window_unpartition(
  262. windows: torch.Tensor, window_size: int, hw: Tuple[int, int], pad_hw: Optional[Tuple[int, int]] = None,
  263. ) -> torch.Tensor:
  264. """
  265. Window unpartition into original sequences and removing padding.
  266. Args:
  267. windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
  268. window_size (int): window size.
  269. pad_hw (Tuple): padded height and width (Hp, Wp).
  270. hw (Tuple): original height and width (H, W) before padding.
  271. Returns:
  272. x: unpartitioned sequences with [B, H, W, C].
  273. """
  274. Hp, Wp = pad_hw if pad_hw is not None else hw
  275. H, W = hw
  276. B = windows.shape[0] // (Hp * Wp // window_size // window_size)
  277. x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
  278. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
  279. x = x[:, :H, :W, :].contiguous()
  280. return x
  281. class VisionTransformerSAM(nn.Module):
  282. """ Vision Transformer for Segment-Anything Model(SAM)
  283. A PyTorch impl of : `Exploring Plain Vision Transformer Backbones for Object Detection` or `Segment Anything Model (SAM)`
  284. - https://arxiv.org/abs/2010.11929
  285. """
  286. def __init__(
  287. self,
  288. img_size: int = 1024,
  289. patch_size: int = 16,
  290. in_chans: int = 3,
  291. num_classes: int = 768,
  292. embed_dim: int = 768,
  293. depth: int = 12,
  294. num_heads: int = 12,
  295. mlp_ratio: float = 4.,
  296. qkv_bias: bool = True,
  297. qk_norm: bool = False,
  298. init_values: Optional[float] = None,
  299. pre_norm: bool = False,
  300. drop_rate: float = 0.,
  301. pos_drop_rate: float = 0.,
  302. patch_drop_rate: float = 0.,
  303. proj_drop_rate: float = 0.,
  304. attn_drop_rate: float = 0.,
  305. drop_path_rate: float = 0.,
  306. weight_init: str = '',
  307. embed_layer: Type[nn.Module] = partial(PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False),
  308. norm_layer: Optional[Type[nn.Module]] = nn.LayerNorm,
  309. act_layer: Optional[Type[nn.Module]] = nn.GELU,
  310. block_fn: Type[nn.Module] = Block,
  311. mlp_layer: Type[nn.Module] = Mlp,
  312. use_abs_pos: bool = True,
  313. use_rel_pos: bool = False,
  314. use_rope: bool = False,
  315. window_size: int = 14,
  316. global_attn_indexes: Tuple[int, ...] = (),
  317. neck_chans: int = 256,
  318. global_pool: str = 'avg',
  319. head_hidden_size: Optional[int] = None,
  320. ref_feat_shape: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None,
  321. device=None,
  322. dtype=None,
  323. ):
  324. """
  325. Args:
  326. img_size: Input image size.
  327. patch_size: Patch size.
  328. in_chans: Number of image input channels.
  329. num_classes: Number of classes for classification head.
  330. global_pool: Type of global pooling for final sequence (default: 'token').
  331. embed_dim: Transformer embedding dimension.
  332. depth: Depth of transformer.
  333. num_heads: Number of attention heads.
  334. mlp_ratio: Ratio of mlp hidden dim to embedding dim.
  335. qkv_bias: Enable bias for qkv projections if True.
  336. init_values: Layer-scale init values (layer-scale enabled if not None).
  337. drop_rate: Head dropout rate.
  338. pos_drop_rate: Position embedding dropout rate.
  339. attn_drop_rate: Attention dropout rate.
  340. drop_path_rate: Stochastic depth rate.
  341. weight_init: Weight initialization scheme.
  342. embed_layer: Patch embedding layer.
  343. norm_layer: Normalization layer.
  344. act_layer: MLP activation layer.
  345. block_fn: Transformer block layer.
  346. use_abs_pos: If True, use absolute positional embeddings.
  347. use_rel_pos: If True, add relative positional embeddings to the attention map.
  348. use_rope: If True, add rotary position embeddings to q/k in attention block.
  349. window_size: Window size for window attention blocks. If 0, not use window attention.
  350. global_attn_indexes: Indexes for blocks using global attention. Used when window_size > 0.
  351. global_pool: Global pooling type.
  352. head_hidden_size: If set, use NormMlpHead
  353. ref_feat_shape: Tuple of reference feature shapes for ROPE, (global, local)
  354. """
  355. super().__init__()
  356. dd = {'device': device, 'dtype': dtype}
  357. norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
  358. act_layer = act_layer or nn.GELU
  359. self.num_classes = num_classes
  360. self.in_chans = in_chans
  361. self.global_pool = global_pool
  362. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
  363. self.grad_checkpointing = False
  364. self.patch_embed = embed_layer(
  365. img_size=img_size,
  366. patch_size=patch_size,
  367. in_chans=in_chans,
  368. embed_dim=embed_dim,
  369. bias=not pre_norm, # disable bias if pre-norm is used
  370. **dd,
  371. )
  372. grid_size = self.patch_embed.grid_size
  373. r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
  374. if use_abs_pos:
  375. # Initialize absolute positional embedding with pretrain image size.
  376. self.pos_embed = nn.Parameter(torch.zeros(1, grid_size[0], grid_size[1], embed_dim, **dd))
  377. else:
  378. self.pos_embed = None
  379. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  380. if patch_drop_rate > 0:
  381. self.patch_drop = PatchDropout(
  382. patch_drop_rate,
  383. num_prefix_tokens=0,
  384. )
  385. else:
  386. self.patch_drop = nn.Identity()
  387. self.norm_pre = norm_layer(embed_dim, **dd) if pre_norm else nn.Identity()
  388. if use_rope:
  389. assert not use_rel_pos, "ROPE and relative pos embeddings should not be enabled at same time"
  390. if ref_feat_shape is not None:
  391. assert len(ref_feat_shape) == 2
  392. ref_feat_shape_global = to_2tuple(ref_feat_shape[0])
  393. ref_feat_shape_window = to_2tuple(ref_feat_shape[1])
  394. else:
  395. ref_feat_shape_global = ref_feat_shape_window = None
  396. self.rope_global = RotaryEmbeddingCat(
  397. embed_dim // num_heads,
  398. in_pixels=False,
  399. feat_shape=grid_size,
  400. ref_feat_shape=ref_feat_shape_global,
  401. )
  402. self.rope_window = RotaryEmbeddingCat(
  403. embed_dim // num_heads,
  404. in_pixels=False,
  405. feat_shape=to_2tuple(window_size),
  406. ref_feat_shape=ref_feat_shape_window,
  407. )
  408. else:
  409. self.rope_global = None
  410. self.rope_window = None
  411. # stochastic depth decay rule
  412. dpr = calculate_drop_path_rates(drop_path_rate, depth)
  413. self.blocks = nn.Sequential(*[
  414. block_fn(
  415. dim=embed_dim,
  416. num_heads=num_heads,
  417. mlp_ratio=mlp_ratio,
  418. qkv_bias=qkv_bias,
  419. qk_norm=qk_norm,
  420. init_values=init_values,
  421. proj_drop=proj_drop_rate,
  422. attn_drop=attn_drop_rate,
  423. drop_path=dpr[i],
  424. norm_layer=norm_layer,
  425. act_layer=act_layer,
  426. mlp_layer=mlp_layer,
  427. use_rel_pos=use_rel_pos,
  428. window_size=window_size if i not in global_attn_indexes else 0,
  429. input_size=grid_size,
  430. rope=self.rope_window if i not in global_attn_indexes else self.rope_global,
  431. **dd,
  432. )
  433. for i in range(depth)])
  434. self.feature_info = [
  435. dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
  436. if neck_chans:
  437. self.neck = nn.Sequential(
  438. nn.Conv2d(
  439. embed_dim,
  440. neck_chans,
  441. kernel_size=1,
  442. bias=False,
  443. **dd,
  444. ),
  445. LayerNorm2d(neck_chans, **dd),
  446. nn.Conv2d(
  447. neck_chans,
  448. neck_chans,
  449. kernel_size=3,
  450. padding=1,
  451. bias=False,
  452. **dd,
  453. ),
  454. LayerNorm2d(neck_chans, **dd),
  455. )
  456. self.num_features = neck_chans
  457. else:
  458. if head_hidden_size:
  459. self.neck = nn.Identity()
  460. else:
  461. # should have a final norm with standard ClassifierHead
  462. self.neck = LayerNorm2d(embed_dim, **dd)
  463. neck_chans = embed_dim
  464. # Classifier Head
  465. if head_hidden_size:
  466. self.head = NormMlpClassifierHead(
  467. neck_chans,
  468. num_classes,
  469. hidden_size=head_hidden_size,
  470. pool_type=global_pool,
  471. drop_rate=drop_rate,
  472. **dd,
  473. )
  474. else:
  475. self.head = ClassifierHead(
  476. neck_chans,
  477. num_classes,
  478. pool_type=global_pool,
  479. drop_rate=drop_rate,
  480. **dd,
  481. )
  482. @torch.jit.ignore
  483. def no_weight_decay(self):
  484. return {'pos_embed', 'dist_token'}
  485. @torch.jit.ignore
  486. def group_matcher(self, coarse=False):
  487. return dict(
  488. stem=r'^pos_embed|patch_embed', # stem and embed
  489. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
  490. )
  491. @torch.jit.ignore
  492. def set_grad_checkpointing(self, enable=True):
  493. self.grad_checkpointing = enable
  494. @torch.jit.ignore
  495. def get_classifier(self) -> nn.Module:
  496. return self.head
  497. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  498. self.num_classes = num_classes
  499. self.head.reset(num_classes, global_pool)
  500. def forward_intermediates(
  501. self,
  502. x: torch.Tensor,
  503. indices: Optional[Union[int, List[int]]] = None,
  504. norm: bool = False,
  505. stop_early: bool = False,
  506. output_fmt: str = 'NCHW',
  507. intermediates_only: bool = False,
  508. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  509. """ Forward features that returns intermediates.
  510. Args:
  511. x: Input image tensor
  512. indices: Take last n blocks if int, all if None, select matching indices if sequence
  513. norm: Apply norm layer to all intermediates
  514. stop_early: Stop iterating over blocks when last desired intermediate hit
  515. output_fmt: Shape of intermediate feature outputs
  516. intermediates_only: Only return intermediate features
  517. Returns:
  518. """
  519. assert output_fmt == 'NCHW', 'Output shape for ViT-SAM must be NCHW.'
  520. intermediates = []
  521. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  522. # forward pass, collect intermediates
  523. x = self.patch_embed(x)
  524. if self.pos_embed is not None:
  525. # dynamically resize abs pos embedding if needed
  526. x = x + resample_abs_pos_embed_nhwc(self.pos_embed, x.shape[1:3])
  527. x = self.pos_drop(x)
  528. x = self.patch_drop(x)
  529. x = self.norm_pre(x)
  530. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  531. blocks = self.blocks
  532. else:
  533. blocks = self.blocks[:max_index + 1]
  534. for i, blk in enumerate(blocks):
  535. if self.grad_checkpointing and not torch.jit.is_scripting():
  536. x = checkpoint(blk, x)
  537. else:
  538. x = blk(x)
  539. if i in take_indices:
  540. # make output BCHW
  541. if norm:
  542. # norm is intertwined with neck convs so apply both, changes the dim
  543. # FIXME only apply to final? Need experiments
  544. intermediates.append(self.neck(x.permute(0, 3, 1, 2)))
  545. else:
  546. intermediates.append(x.permute(0, 3, 1, 2))
  547. if intermediates_only:
  548. return intermediates
  549. x = self.neck(x.permute(0, 3, 1, 2))
  550. return x, intermediates
  551. def prune_intermediate_layers(
  552. self,
  553. indices: Optional[Union[int, List[int]]] = None,
  554. prune_norm: bool = False,
  555. prune_head: bool = True,
  556. ):
  557. """ Prune layers not required for specified intermediates.
  558. """
  559. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  560. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  561. if prune_norm:
  562. # neck is being treated as equivalent to final norm here
  563. self.neck = nn.Identity()
  564. if prune_head:
  565. self.reset_classifier(0, '')
  566. return take_indices
  567. def forward_features(self, x):
  568. x = self.patch_embed(x)
  569. if self.pos_embed is not None:
  570. # dynamically resize abs pos embedding if needed
  571. x = x + resample_abs_pos_embed_nhwc(self.pos_embed, x.shape[1:3])
  572. x = self.pos_drop(x)
  573. x = self.patch_drop(x)
  574. x = self.norm_pre(x)
  575. if self.grad_checkpointing and not torch.jit.is_scripting():
  576. x = checkpoint_seq(self.blocks, x)
  577. else:
  578. x = self.blocks(x)
  579. x = self.neck(x.permute(0, 3, 1, 2))
  580. return x
  581. def forward_head(self, x, pre_logits: bool = False):
  582. return self.head(x, pre_logits=True) if pre_logits else self.head(x)
  583. def forward(self, x):
  584. x = self.forward_features(x)
  585. x = self.forward_head(x)
  586. return x
  587. def checkpoint_filter_fn(
  588. state_dict,
  589. model,
  590. ):
  591. """ Remap SAM checkpoints -> timm """
  592. sam_checkpoint = 'image_encoder.patch_embed.proj.weight' in state_dict
  593. out_dict = {}
  594. for k, v in state_dict.items():
  595. if k.startswith('image_encoder.'):
  596. k = k[14:]
  597. k = k.replace('mlp.lin', 'mlp.fc')
  598. else:
  599. if sam_checkpoint:
  600. continue
  601. out_dict[k] = v
  602. return out_dict
  603. def _cfg(url='', **kwargs):
  604. return {
  605. 'url': url,
  606. 'num_classes': 1000, 'input_size': (3, 1024, 1024), 'pool_size': None,
  607. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  608. 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
  609. 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
  610. **kwargs
  611. }
  612. default_cfgs = generate_default_cfgs({
  613. # Segment-Anything Model (SAM) pretrained - https://github.com/facebookresearch/segment-anything (no classifier head, for fine-tune/features only)
  614. 'samvit_base_patch16.sa1b': _cfg(
  615. url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
  616. hf_hub_id='timm/',
  617. license='apache-2.0',
  618. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  619. input_size=(3, 1024, 1024), crop_pct=1.0),
  620. 'samvit_large_patch16.sa1b': _cfg(
  621. url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
  622. hf_hub_id='timm/',
  623. license='apache-2.0',
  624. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  625. input_size=(3, 1024, 1024), crop_pct=1.0),
  626. 'samvit_huge_patch16.sa1b': _cfg(
  627. url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
  628. hf_hub_id='timm/',
  629. license='apache-2.0',
  630. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  631. input_size=(3, 1024, 1024), crop_pct=1.0),
  632. 'samvit_base_patch16_224': _cfg(
  633. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=1000,
  634. input_size=(3, 224, 224), crop_pct=0.9),
  635. })
  636. def _create_vision_transformer(variant, pretrained=False, **kwargs):
  637. out_indices = kwargs.pop('out_indices', 3)
  638. return build_model_with_cfg(
  639. VisionTransformerSAM,
  640. variant,
  641. pretrained,
  642. pretrained_filter_fn=checkpoint_filter_fn,
  643. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  644. **kwargs,
  645. )
  646. @register_model
  647. def samvit_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
  648. """ ViT-B/16 for Segment-Anything
  649. """
  650. model_args = dict(
  651. patch_size=16, embed_dim=768, depth=12, num_heads=12, global_attn_indexes=[2, 5, 8, 11],
  652. window_size=14, use_rel_pos=True, img_size=1024,
  653. )
  654. model = _create_vision_transformer(
  655. 'samvit_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
  656. return model
  657. @register_model
  658. def samvit_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
  659. """ ViT-L/16 for Segment-Anything
  660. """
  661. model_args = dict(
  662. patch_size=16, embed_dim=1024, depth=24, num_heads=16, global_attn_indexes=[5, 11, 17, 23],
  663. window_size=14, use_rel_pos=True, img_size=1024,
  664. )
  665. model = _create_vision_transformer(
  666. 'samvit_large_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
  667. return model
  668. @register_model
  669. def samvit_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
  670. """ ViT-H/16 for Segment-Anything
  671. """
  672. model_args = dict(
  673. patch_size=16, embed_dim=1280, depth=32, num_heads=16, global_attn_indexes=[7, 15, 23, 31],
  674. window_size=14, use_rel_pos=True, img_size=1024,
  675. )
  676. model = _create_vision_transformer(
  677. 'samvit_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
  678. return model
  679. @register_model
  680. def samvit_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformerSAM:
  681. """ ViT-B/16 based on samvit arch
  682. """
  683. model_args = dict(
  684. patch_size=16, embed_dim=768, depth=12, num_heads=12, global_attn_indexes=[2, 5, 8, 11],
  685. window_size=14, use_rel_pos=True, use_abs_pos=False, img_size=224, neck_chans=None,
  686. )
  687. model = _create_vision_transformer(
  688. 'samvit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  689. return model