mvit.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898
  1. import math
  2. from collections.abc import Sequence
  3. from dataclasses import dataclass
  4. from functools import partial
  5. from typing import Any, Callable, Optional
  6. import torch
  7. import torch.fx
  8. import torch.nn as nn
  9. from ...ops import MLP, StochasticDepth
  10. from ...transforms._presets import VideoClassification
  11. from ...utils import _log_api_usage_once
  12. from .._api import register_model, Weights, WeightsEnum
  13. from .._meta import _KINETICS400_CATEGORIES
  14. from .._utils import _ovewrite_named_param, handle_legacy_interface
  15. __all__ = [
  16. "MViT",
  17. "MViT_V1_B_Weights",
  18. "mvit_v1_b",
  19. "MViT_V2_S_Weights",
  20. "mvit_v2_s",
  21. ]
  22. @dataclass
  23. class MSBlockConfig:
  24. num_heads: int
  25. input_channels: int
  26. output_channels: int
  27. kernel_q: list[int]
  28. kernel_kv: list[int]
  29. stride_q: list[int]
  30. stride_kv: list[int]
  31. def _prod(s: Sequence[int]) -> int:
  32. product = 1
  33. for v in s:
  34. product *= v
  35. return product
  36. def _unsqueeze(x: torch.Tensor, target_dim: int, expand_dim: int) -> tuple[torch.Tensor, int]:
  37. tensor_dim = x.dim()
  38. if tensor_dim == target_dim - 1:
  39. x = x.unsqueeze(expand_dim)
  40. elif tensor_dim != target_dim:
  41. raise ValueError(f"Unsupported input dimension {x.shape}")
  42. return x, tensor_dim
  43. def _squeeze(x: torch.Tensor, target_dim: int, expand_dim: int, tensor_dim: int) -> torch.Tensor:
  44. if tensor_dim == target_dim - 1:
  45. x = x.squeeze(expand_dim)
  46. return x
# Register the rank-adjusting helpers as torch.fx "leaf functions": a symbolic
# trace records each call as a single node instead of tracing into it, which
# keeps the branch on ``x.dim()`` inside ``_unsqueeze`` out of the traced graph.
# torch.fx.wrap must be called at module top level.
torch.fx.wrap("_unsqueeze")
torch.fx.wrap("_squeeze")
  49. class Pool(nn.Module):
  50. def __init__(
  51. self,
  52. pool: nn.Module,
  53. norm: Optional[nn.Module],
  54. activation: Optional[nn.Module] = None,
  55. norm_before_pool: bool = False,
  56. ) -> None:
  57. super().__init__()
  58. self.pool = pool
  59. layers = []
  60. if norm is not None:
  61. layers.append(norm)
  62. if activation is not None:
  63. layers.append(activation)
  64. self.norm_act = nn.Sequential(*layers) if layers else None
  65. self.norm_before_pool = norm_before_pool
  66. def forward(self, x: torch.Tensor, thw: tuple[int, int, int]) -> tuple[torch.Tensor, tuple[int, int, int]]:
  67. x, tensor_dim = _unsqueeze(x, 4, 1)
  68. # Separate the class token and reshape the input
  69. class_token, x = torch.tensor_split(x, indices=(1,), dim=2)
  70. x = x.transpose(2, 3)
  71. B, N, C = x.shape[:3]
  72. x = x.reshape((B * N, C) + thw).contiguous()
  73. # normalizing prior pooling is useful when we use BN which can be absorbed to speed up inference
  74. if self.norm_before_pool and self.norm_act is not None:
  75. x = self.norm_act(x)
  76. # apply the pool on the input and add back the token
  77. x = self.pool(x)
  78. T, H, W = x.shape[2:]
  79. x = x.reshape(B, N, C, -1).transpose(2, 3)
  80. x = torch.cat((class_token, x), dim=2)
  81. if not self.norm_before_pool and self.norm_act is not None:
  82. x = self.norm_act(x)
  83. x = _squeeze(x, 4, 1, tensor_dim)
  84. return x, (T, H, W)
  85. def _interpolate(embedding: torch.Tensor, d: int) -> torch.Tensor:
  86. if embedding.shape[0] == d:
  87. return embedding
  88. return (
  89. nn.functional.interpolate(
  90. embedding.permute(1, 0).unsqueeze(0),
  91. size=d,
  92. mode="linear",
  93. )
  94. .squeeze(0)
  95. .permute(1, 0)
  96. )
def _add_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    q_thw: tuple[int, int, int],
    k_thw: tuple[int, int, int],
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    rel_pos_t: torch.Tensor,
) -> torch.Tensor:
    """Add decomposed (per-axis) relative positional biases to attention logits, in place.

    Args:
        attn: attention logits. Only the patch-to-patch part ``attn[:, :, 1:, 1:]`` is updated;
            row/column 0 (the class-token slot) is left untouched.
        q: queries, unpacked below as ``(B, n_head, 1 + q_t * q_h * q_w, dim)``.
        q_thw: ``(T, H, W)`` grid of the queries.
        k_thw: ``(T, H, W)`` grid of the keys.
        rel_pos_h, rel_pos_w, rel_pos_t: 2-D tables of per-distance embeddings for each axis.
            They are linearly resized with ``_interpolate`` when their length is not
            ``2 * max(q_size, k_size) - 1`` for that axis.

    Returns:
        ``attn`` itself (the same tensor object, modified in place).
    """
    # Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932
    q_t, q_h, q_w = q_thw
    k_t, k_h, k_w = k_thw
    # Number of distinct relative offsets along each axis.
    dh = int(2 * max(q_h, k_h) - 1)
    dw = int(2 * max(q_w, k_w) - 1)
    dt = int(2 * max(q_t, k_t) - 1)

    # Scale up rel pos if shapes for q and k are different.
    # dist_* is a (q_size, k_size) table of query-minus-key offsets, measured on the finer of the
    # two grids and shifted by (k_size - 1) * k_ratio so the smallest offset maps to index 0.
    q_h_ratio = max(k_h / q_h, 1.0)
    k_h_ratio = max(q_h / k_h, 1.0)
    dist_h = torch.arange(q_h)[:, None] * q_h_ratio - (torch.arange(k_h)[None, :] + (1.0 - k_h)) * k_h_ratio
    q_w_ratio = max(k_w / q_w, 1.0)
    k_w_ratio = max(q_w / k_w, 1.0)
    dist_w = torch.arange(q_w)[:, None] * q_w_ratio - (torch.arange(k_w)[None, :] + (1.0 - k_w)) * k_w_ratio
    q_t_ratio = max(k_t / q_t, 1.0)
    k_t_ratio = max(q_t / k_t, 1.0)
    dist_t = torch.arange(q_t)[:, None] * q_t_ratio - (torch.arange(k_t)[None, :] + (1.0 - k_t)) * k_t_ratio

    # Interpolate rel pos if needed.
    rel_pos_h = _interpolate(rel_pos_h, dh)
    rel_pos_w = _interpolate(rel_pos_w, dw)
    rel_pos_t = _interpolate(rel_pos_t, dt)
    # Gather one embedding per (query position, key position) pair along each axis: (q_size, k_size, dim).
    Rh = rel_pos_h[dist_h.long()]
    Rw = rel_pos_w[dist_w.long()]
    Rt = rel_pos_t[dist_t.long()]

    B, n_head, _, dim = q.shape

    # Drop the class-token query and lay the patch queries out on their (T, H, W) grid.
    r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim)
    rel_h_q = torch.einsum("bythwc,hkc->bythwk", r_q, Rh)  # [B, H, q_t, qh, qw, k_h]
    rel_w_q = torch.einsum("bythwc,wkc->bythwk", r_q, Rw)  # [B, H, q_t, qh, qw, k_w]
    # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
    r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim)
    # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
    rel_q_t = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
    # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
    rel_q_t = rel_q_t.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)

    # Combine rel pos.
    # Broadcast the three per-axis biases over the full (k_t, k_h, k_w) key grid and sum them,
    # then flatten query and key grids to match the attention matrix layout.
    rel_pos = (
        rel_h_q[:, :, :, :, :, None, :, None]
        + rel_w_q[:, :, :, :, :, None, None, :]
        + rel_q_t[:, :, :, :, :, :, None, None]
    ).reshape(B, n_head, q_t * q_h * q_w, k_t * k_h * k_w)

    # Add it to attention
    attn[:, :, 1:, 1:] += rel_pos

    return attn
  148. def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_with_cls_embed: bool):
  149. if residual_with_cls_embed:
  150. x.add_(shortcut)
  151. else:
  152. x[:, :, 1:, :] += shortcut[:, :, 1:, :]
  153. return x
# Register the attention helpers as torch.fx "leaf functions": symbolic tracing
# records each call as one node instead of tracing into their Python-level
# arithmetic on grid sizes and their in-place indexed updates.
# torch.fx.wrap must be called at module top level.
torch.fx.wrap("_add_rel_pos")
torch.fx.wrap("_add_shortcut")
  156. class MultiscaleAttention(nn.Module):
  157. def __init__(
  158. self,
  159. input_size: list[int],
  160. embed_dim: int,
  161. output_dim: int,
  162. num_heads: int,
  163. kernel_q: list[int],
  164. kernel_kv: list[int],
  165. stride_q: list[int],
  166. stride_kv: list[int],
  167. residual_pool: bool,
  168. residual_with_cls_embed: bool,
  169. rel_pos_embed: bool,
  170. dropout: float = 0.0,
  171. norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
  172. ) -> None:
  173. super().__init__()
  174. self.embed_dim = embed_dim
  175. self.output_dim = output_dim
  176. self.num_heads = num_heads
  177. self.head_dim = output_dim // num_heads
  178. self.scaler = 1.0 / math.sqrt(self.head_dim)
  179. self.residual_pool = residual_pool
  180. self.residual_with_cls_embed = residual_with_cls_embed
  181. self.qkv = nn.Linear(embed_dim, 3 * output_dim)
  182. layers: list[nn.Module] = [nn.Linear(output_dim, output_dim)]
  183. if dropout > 0.0:
  184. layers.append(nn.Dropout(dropout, inplace=True))
  185. self.project = nn.Sequential(*layers)
  186. self.pool_q: Optional[nn.Module] = None
  187. if _prod(kernel_q) > 1 or _prod(stride_q) > 1:
  188. padding_q = [int(q // 2) for q in kernel_q]
  189. self.pool_q = Pool(
  190. nn.Conv3d(
  191. self.head_dim,
  192. self.head_dim,
  193. kernel_q, # type: ignore[arg-type]
  194. stride=stride_q, # type: ignore[arg-type]
  195. padding=padding_q, # type: ignore[arg-type]
  196. groups=self.head_dim,
  197. bias=False,
  198. ),
  199. norm_layer(self.head_dim),
  200. )
  201. self.pool_k: Optional[nn.Module] = None
  202. self.pool_v: Optional[nn.Module] = None
  203. if _prod(kernel_kv) > 1 or _prod(stride_kv) > 1:
  204. padding_kv = [int(kv // 2) for kv in kernel_kv]
  205. self.pool_k = Pool(
  206. nn.Conv3d(
  207. self.head_dim,
  208. self.head_dim,
  209. kernel_kv, # type: ignore[arg-type]
  210. stride=stride_kv, # type: ignore[arg-type]
  211. padding=padding_kv, # type: ignore[arg-type]
  212. groups=self.head_dim,
  213. bias=False,
  214. ),
  215. norm_layer(self.head_dim),
  216. )
  217. self.pool_v = Pool(
  218. nn.Conv3d(
  219. self.head_dim,
  220. self.head_dim,
  221. kernel_kv, # type: ignore[arg-type]
  222. stride=stride_kv, # type: ignore[arg-type]
  223. padding=padding_kv, # type: ignore[arg-type]
  224. groups=self.head_dim,
  225. bias=False,
  226. ),
  227. norm_layer(self.head_dim),
  228. )
  229. self.rel_pos_h: Optional[nn.Parameter] = None
  230. self.rel_pos_w: Optional[nn.Parameter] = None
  231. self.rel_pos_t: Optional[nn.Parameter] = None
  232. if rel_pos_embed:
  233. size = max(input_size[1:])
  234. q_size = size // stride_q[1] if len(stride_q) > 0 else size
  235. kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
  236. spatial_dim = 2 * max(q_size, kv_size) - 1
  237. temporal_dim = 2 * input_size[0] - 1
  238. self.rel_pos_h = nn.Parameter(torch.zeros(spatial_dim, self.head_dim))
  239. self.rel_pos_w = nn.Parameter(torch.zeros(spatial_dim, self.head_dim))
  240. self.rel_pos_t = nn.Parameter(torch.zeros(temporal_dim, self.head_dim))
  241. nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
  242. nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
  243. nn.init.trunc_normal_(self.rel_pos_t, std=0.02)
  244. def forward(self, x: torch.Tensor, thw: tuple[int, int, int]) -> tuple[torch.Tensor, tuple[int, int, int]]:
  245. B, N, C = x.shape
  246. q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(dim=2)
  247. if self.pool_k is not None:
  248. k, k_thw = self.pool_k(k, thw)
  249. else:
  250. k_thw = thw
  251. if self.pool_v is not None:
  252. v = self.pool_v(v, thw)[0]
  253. if self.pool_q is not None:
  254. q, thw = self.pool_q(q, thw)
  255. attn = torch.matmul(self.scaler * q, k.transpose(2, 3))
  256. if self.rel_pos_h is not None and self.rel_pos_w is not None and self.rel_pos_t is not None:
  257. attn = _add_rel_pos(
  258. attn,
  259. q,
  260. thw,
  261. k_thw,
  262. self.rel_pos_h,
  263. self.rel_pos_w,
  264. self.rel_pos_t,
  265. )
  266. attn = attn.softmax(dim=-1)
  267. x = torch.matmul(attn, v)
  268. if self.residual_pool:
  269. _add_shortcut(x, q, self.residual_with_cls_embed)
  270. x = x.transpose(1, 2).reshape(B, -1, self.output_dim)
  271. x = self.project(x)
  272. return x, thw
  273. class MultiscaleBlock(nn.Module):
  274. def __init__(
  275. self,
  276. input_size: list[int],
  277. cnf: MSBlockConfig,
  278. residual_pool: bool,
  279. residual_with_cls_embed: bool,
  280. rel_pos_embed: bool,
  281. proj_after_attn: bool,
  282. dropout: float = 0.0,
  283. stochastic_depth_prob: float = 0.0,
  284. norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
  285. ) -> None:
  286. super().__init__()
  287. self.proj_after_attn = proj_after_attn
  288. self.pool_skip: Optional[nn.Module] = None
  289. if _prod(cnf.stride_q) > 1:
  290. kernel_skip = [s + 1 if s > 1 else s for s in cnf.stride_q]
  291. padding_skip = [int(k // 2) for k in kernel_skip]
  292. self.pool_skip = Pool(
  293. nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, padding=padding_skip), None # type: ignore[arg-type]
  294. )
  295. attn_dim = cnf.output_channels if proj_after_attn else cnf.input_channels
  296. self.norm1 = norm_layer(cnf.input_channels)
  297. self.norm2 = norm_layer(attn_dim)
  298. self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d)
  299. self.attn = MultiscaleAttention(
  300. input_size,
  301. cnf.input_channels,
  302. attn_dim,
  303. cnf.num_heads,
  304. kernel_q=cnf.kernel_q,
  305. kernel_kv=cnf.kernel_kv,
  306. stride_q=cnf.stride_q,
  307. stride_kv=cnf.stride_kv,
  308. rel_pos_embed=rel_pos_embed,
  309. residual_pool=residual_pool,
  310. residual_with_cls_embed=residual_with_cls_embed,
  311. dropout=dropout,
  312. norm_layer=norm_layer,
  313. )
  314. self.mlp = MLP(
  315. attn_dim,
  316. [4 * attn_dim, cnf.output_channels],
  317. activation_layer=nn.GELU,
  318. dropout=dropout,
  319. inplace=None,
  320. )
  321. self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
  322. self.project: Optional[nn.Module] = None
  323. if cnf.input_channels != cnf.output_channels:
  324. self.project = nn.Linear(cnf.input_channels, cnf.output_channels)
  325. def forward(self, x: torch.Tensor, thw: tuple[int, int, int]) -> tuple[torch.Tensor, tuple[int, int, int]]:
  326. x_norm1 = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x)
  327. x_attn, thw_new = self.attn(x_norm1, thw)
  328. x = x if self.project is None or not self.proj_after_attn else self.project(x_norm1)
  329. x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0]
  330. x = x_skip + self.stochastic_depth(x_attn)
  331. x_norm2 = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x)
  332. x_proj = x if self.project is None or self.proj_after_attn else self.project(x_norm2)
  333. return x_proj + self.stochastic_depth(self.mlp(x_norm2)), thw_new
  334. class PositionalEncoding(nn.Module):
  335. def __init__(self, embed_size: int, spatial_size: tuple[int, int], temporal_size: int, rel_pos_embed: bool) -> None:
  336. super().__init__()
  337. self.spatial_size = spatial_size
  338. self.temporal_size = temporal_size
  339. self.class_token = nn.Parameter(torch.zeros(embed_size))
  340. self.spatial_pos: Optional[nn.Parameter] = None
  341. self.temporal_pos: Optional[nn.Parameter] = None
  342. self.class_pos: Optional[nn.Parameter] = None
  343. if not rel_pos_embed:
  344. self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size))
  345. self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size))
  346. self.class_pos = nn.Parameter(torch.zeros(embed_size))
  347. def forward(self, x: torch.Tensor) -> torch.Tensor:
  348. class_token = self.class_token.expand(x.size(0), -1).unsqueeze(1)
  349. x = torch.cat((class_token, x), dim=1)
  350. if self.spatial_pos is not None and self.temporal_pos is not None and self.class_pos is not None:
  351. hw_size, embed_size = self.spatial_pos.shape
  352. pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0)
  353. pos_embedding.add_(self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size))
  354. pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0)
  355. x.add_(pos_embedding)
  356. return x
  357. class MViT(nn.Module):
  358. def __init__(
  359. self,
  360. spatial_size: tuple[int, int],
  361. temporal_size: int,
  362. block_setting: Sequence[MSBlockConfig],
  363. residual_pool: bool,
  364. residual_with_cls_embed: bool,
  365. rel_pos_embed: bool,
  366. proj_after_attn: bool,
  367. dropout: float = 0.5,
  368. attention_dropout: float = 0.0,
  369. stochastic_depth_prob: float = 0.0,
  370. num_classes: int = 400,
  371. block: Optional[Callable[..., nn.Module]] = None,
  372. norm_layer: Optional[Callable[..., nn.Module]] = None,
  373. patch_embed_kernel: tuple[int, int, int] = (3, 7, 7),
  374. patch_embed_stride: tuple[int, int, int] = (2, 4, 4),
  375. patch_embed_padding: tuple[int, int, int] = (1, 3, 3),
  376. ) -> None:
  377. """
  378. MViT main class.
  379. Args:
  380. spatial_size (tuple of ints): The spacial size of the input as ``(H, W)``.
  381. temporal_size (int): The temporal size ``T`` of the input.
  382. block_setting (sequence of MSBlockConfig): The Network structure.
  383. residual_pool (bool): If True, use MViTv2 pooling residual connection.
  384. residual_with_cls_embed (bool): If True, the addition on the residual connection will include
  385. the class embedding.
  386. rel_pos_embed (bool): If True, use MViTv2's relative positional embeddings.
  387. proj_after_attn (bool): If True, apply the projection after the attention.
  388. dropout (float): Dropout rate. Default: 0.0.
  389. attention_dropout (float): Attention dropout rate. Default: 0.0.
  390. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
  391. num_classes (int): The number of classes.
  392. block (callable, optional): Module specifying the layer which consists of the attention and mlp.
  393. norm_layer (callable, optional): Module specifying the normalization layer to use.
  394. patch_embed_kernel (tuple of ints): The kernel of the convolution that patchifies the input.
  395. patch_embed_stride (tuple of ints): The stride of the convolution that patchifies the input.
  396. patch_embed_padding (tuple of ints): The padding of the convolution that patchifies the input.
  397. """
  398. super().__init__()
  399. # This implementation employs a different parameterization scheme than the one used at PyTorch Video:
  400. # https://github.com/facebookresearch/pytorchvideo/blob/718d0a4/pytorchvideo/models/vision_transformers.py
  401. # We remove any experimental configuration that didn't make it to the final variants of the models. To represent
  402. # the configuration of the architecture we use the simplified form suggested at Table 1 of the paper.
  403. _log_api_usage_once(self)
  404. total_stage_blocks = len(block_setting)
  405. if total_stage_blocks == 0:
  406. raise ValueError("The configuration parameter can't be empty.")
  407. if block is None:
  408. block = MultiscaleBlock
  409. if norm_layer is None:
  410. norm_layer = partial(nn.LayerNorm, eps=1e-6)
  411. # Patch Embedding module
  412. self.conv_proj = nn.Conv3d(
  413. in_channels=3,
  414. out_channels=block_setting[0].input_channels,
  415. kernel_size=patch_embed_kernel,
  416. stride=patch_embed_stride,
  417. padding=patch_embed_padding,
  418. )
  419. input_size = [size // stride for size, stride in zip((temporal_size,) + spatial_size, self.conv_proj.stride)]
  420. # Spatio-Temporal Class Positional Encoding
  421. self.pos_encoding = PositionalEncoding(
  422. embed_size=block_setting[0].input_channels,
  423. spatial_size=(input_size[1], input_size[2]),
  424. temporal_size=input_size[0],
  425. rel_pos_embed=rel_pos_embed,
  426. )
  427. # Encoder module
  428. self.blocks = nn.ModuleList()
  429. for stage_block_id, cnf in enumerate(block_setting):
  430. # adjust stochastic depth probability based on the depth of the stage block
  431. sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
  432. self.blocks.append(
  433. block(
  434. input_size=input_size,
  435. cnf=cnf,
  436. residual_pool=residual_pool,
  437. residual_with_cls_embed=residual_with_cls_embed,
  438. rel_pos_embed=rel_pos_embed,
  439. proj_after_attn=proj_after_attn,
  440. dropout=attention_dropout,
  441. stochastic_depth_prob=sd_prob,
  442. norm_layer=norm_layer,
  443. )
  444. )
  445. if len(cnf.stride_q) > 0:
  446. input_size = [size // stride for size, stride in zip(input_size, cnf.stride_q)]
  447. self.norm = norm_layer(block_setting[-1].output_channels)
  448. # Classifier module
  449. self.head = nn.Sequential(
  450. nn.Dropout(dropout, inplace=True),
  451. nn.Linear(block_setting[-1].output_channels, num_classes),
  452. )
  453. for m in self.modules():
  454. if isinstance(m, nn.Linear):
  455. nn.init.trunc_normal_(m.weight, std=0.02)
  456. if isinstance(m, nn.Linear) and m.bias is not None:
  457. nn.init.constant_(m.bias, 0.0)
  458. elif isinstance(m, nn.LayerNorm):
  459. if m.weight is not None:
  460. nn.init.constant_(m.weight, 1.0)
  461. if m.bias is not None:
  462. nn.init.constant_(m.bias, 0.0)
  463. elif isinstance(m, PositionalEncoding):
  464. for weights in m.parameters():
  465. nn.init.trunc_normal_(weights, std=0.02)
  466. def forward(self, x: torch.Tensor) -> torch.Tensor:
  467. # Convert if necessary (B, C, H, W) -> (B, C, 1, H, W)
  468. x = _unsqueeze(x, 5, 2)[0]
  469. # patchify and reshape: (B, C, T, H, W) -> (B, embed_channels[0], T', H', W') -> (B, THW', embed_channels[0])
  470. x = self.conv_proj(x)
  471. x = x.flatten(2).transpose(1, 2)
  472. # add positional encoding
  473. x = self.pos_encoding(x)
  474. # pass patches through the encoder
  475. thw = (self.pos_encoding.temporal_size,) + self.pos_encoding.spatial_size
  476. for block in self.blocks:
  477. x, thw = block(x, thw)
  478. x = self.norm(x)
  479. # classifier "token" as used by standard language architectures
  480. x = x[:, 0]
  481. x = self.head(x)
  482. return x
  483. def _mvit(
  484. block_setting: list[MSBlockConfig],
  485. stochastic_depth_prob: float,
  486. weights: Optional[WeightsEnum],
  487. progress: bool,
  488. **kwargs: Any,
  489. ) -> MViT:
  490. if weights is not None:
  491. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  492. assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
  493. _ovewrite_named_param(kwargs, "spatial_size", weights.meta["min_size"])
  494. _ovewrite_named_param(kwargs, "temporal_size", weights.meta["min_temporal_size"])
  495. spatial_size = kwargs.pop("spatial_size", (224, 224))
  496. temporal_size = kwargs.pop("temporal_size", 16)
  497. model = MViT(
  498. spatial_size=spatial_size,
  499. temporal_size=temporal_size,
  500. block_setting=block_setting,
  501. residual_pool=kwargs.pop("residual_pool", False),
  502. residual_with_cls_embed=kwargs.pop("residual_with_cls_embed", True),
  503. rel_pos_embed=kwargs.pop("rel_pos_embed", False),
  504. proj_after_attn=kwargs.pop("proj_after_attn", False),
  505. stochastic_depth_prob=stochastic_depth_prob,
  506. **kwargs,
  507. )
  508. if weights is not None:
  509. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  510. return model
  511. class MViT_V1_B_Weights(WeightsEnum):
  512. KINETICS400_V1 = Weights(
  513. url="https://download.pytorch.org/models/mvit_v1_b-dbeb1030.pth",
  514. transforms=partial(
  515. VideoClassification,
  516. crop_size=(224, 224),
  517. resize_size=(256,),
  518. mean=(0.45, 0.45, 0.45),
  519. std=(0.225, 0.225, 0.225),
  520. ),
  521. meta={
  522. "min_size": (224, 224),
  523. "min_temporal_size": 16,
  524. "categories": _KINETICS400_CATEGORIES,
  525. "recipe": "https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md",
  526. "_docs": (
  527. "The weights were ported from the paper. The accuracies are estimated on video-level "
  528. "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
  529. ),
  530. "num_params": 36610672,
  531. "_metrics": {
  532. "Kinetics-400": {
  533. "acc@1": 78.477,
  534. "acc@5": 93.582,
  535. }
  536. },
  537. "_ops": 70.599,
  538. "_file_size": 139.764,
  539. },
  540. )
  541. DEFAULT = KINETICS400_V1
  542. class MViT_V2_S_Weights(WeightsEnum):
  543. KINETICS400_V1 = Weights(
  544. url="https://download.pytorch.org/models/mvit_v2_s-ae3be167.pth",
  545. transforms=partial(
  546. VideoClassification,
  547. crop_size=(224, 224),
  548. resize_size=(256,),
  549. mean=(0.45, 0.45, 0.45),
  550. std=(0.225, 0.225, 0.225),
  551. ),
  552. meta={
  553. "min_size": (224, 224),
  554. "min_temporal_size": 16,
  555. "categories": _KINETICS400_CATEGORIES,
  556. "recipe": "https://github.com/facebookresearch/SlowFast/blob/main/MODEL_ZOO.md",
  557. "_docs": (
  558. "The weights were ported from the paper. The accuracies are estimated on video-level "
  559. "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
  560. ),
  561. "num_params": 34537744,
  562. "_metrics": {
  563. "Kinetics-400": {
  564. "acc@1": 80.757,
  565. "acc@5": 94.665,
  566. }
  567. },
  568. "_ops": 64.224,
  569. "_file_size": 131.884,
  570. },
  571. )
  572. DEFAULT = KINETICS400_V1
  573. @register_model()
  574. @handle_legacy_interface(weights=("pretrained", MViT_V1_B_Weights.KINETICS400_V1))
  575. def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
  576. """
  577. Constructs a base MViTV1 architecture from
  578. `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__.
  579. .. betastatus:: video module
  580. Args:
  581. weights (:class:`~torchvision.models.video.MViT_V1_B_Weights`, optional): The
  582. pretrained weights to use. See
  583. :class:`~torchvision.models.video.MViT_V1_B_Weights` below for
  584. more details, and possible values. By default, no pre-trained
  585. weights are used.
  586. progress (bool, optional): If True, displays a progress bar of the
  587. download to stderr. Default is True.
  588. **kwargs: parameters passed to the ``torchvision.models.video.MViT``
  589. base class. Please refer to the `source code
  590. <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_
  591. for more details about this class.
  592. .. autoclass:: torchvision.models.video.MViT_V1_B_Weights
  593. :members:
  594. """
  595. weights = MViT_V1_B_Weights.verify(weights)
  596. config: dict[str, list] = {
  597. "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8],
  598. "input_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768],
  599. "output_channels": [192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768, 768],
  600. "kernel_q": [[], [3, 3, 3], [], [3, 3, 3], [], [], [], [], [], [], [], [], [], [], [3, 3, 3], []],
  601. "kernel_kv": [
  602. [3, 3, 3],
  603. [3, 3, 3],
  604. [3, 3, 3],
  605. [3, 3, 3],
  606. [3, 3, 3],
  607. [3, 3, 3],
  608. [3, 3, 3],
  609. [3, 3, 3],
  610. [3, 3, 3],
  611. [3, 3, 3],
  612. [3, 3, 3],
  613. [3, 3, 3],
  614. [3, 3, 3],
  615. [3, 3, 3],
  616. [3, 3, 3],
  617. [3, 3, 3],
  618. ],
  619. "stride_q": [[], [1, 2, 2], [], [1, 2, 2], [], [], [], [], [], [], [], [], [], [], [1, 2, 2], []],
  620. "stride_kv": [
  621. [1, 8, 8],
  622. [1, 4, 4],
  623. [1, 4, 4],
  624. [1, 2, 2],
  625. [1, 2, 2],
  626. [1, 2, 2],
  627. [1, 2, 2],
  628. [1, 2, 2],
  629. [1, 2, 2],
  630. [1, 2, 2],
  631. [1, 2, 2],
  632. [1, 2, 2],
  633. [1, 2, 2],
  634. [1, 2, 2],
  635. [1, 1, 1],
  636. [1, 1, 1],
  637. ],
  638. }
  639. block_setting = []
  640. for i in range(len(config["num_heads"])):
  641. block_setting.append(
  642. MSBlockConfig(
  643. num_heads=config["num_heads"][i],
  644. input_channels=config["input_channels"][i],
  645. output_channels=config["output_channels"][i],
  646. kernel_q=config["kernel_q"][i],
  647. kernel_kv=config["kernel_kv"][i],
  648. stride_q=config["stride_q"][i],
  649. stride_kv=config["stride_kv"][i],
  650. )
  651. )
  652. return _mvit(
  653. spatial_size=(224, 224),
  654. temporal_size=16,
  655. block_setting=block_setting,
  656. residual_pool=False,
  657. residual_with_cls_embed=False,
  658. stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2),
  659. weights=weights,
  660. progress=progress,
  661. **kwargs,
  662. )
  663. @register_model()
  664. @handle_legacy_interface(weights=("pretrained", MViT_V2_S_Weights.KINETICS400_V1))
  665. def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
  666. """Constructs a small MViTV2 architecture from
  667. `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__ and
  668. `MViTv2: Improved Multiscale Vision Transformers for Classification
  669. and Detection <https://arxiv.org/abs/2112.01526>`__.
  670. .. betastatus:: video module
  671. Args:
  672. weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The
  673. pretrained weights to use. See
  674. :class:`~torchvision.models.video.MViT_V2_S_Weights` below for
  675. more details, and possible values. By default, no pre-trained
  676. weights are used.
  677. progress (bool, optional): If True, displays a progress bar of the
  678. download to stderr. Default is True.
  679. **kwargs: parameters passed to the ``torchvision.models.video.MViT``
  680. base class. Please refer to the `source code
  681. <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_
  682. for more details about this class.
  683. .. autoclass:: torchvision.models.video.MViT_V2_S_Weights
  684. :members:
  685. """
  686. weights = MViT_V2_S_Weights.verify(weights)
  687. config: dict[str, list] = {
  688. "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8],
  689. "input_channels": [96, 96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768],
  690. "output_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768],
  691. "kernel_q": [
  692. [3, 3, 3],
  693. [3, 3, 3],
  694. [3, 3, 3],
  695. [3, 3, 3],
  696. [3, 3, 3],
  697. [3, 3, 3],
  698. [3, 3, 3],
  699. [3, 3, 3],
  700. [3, 3, 3],
  701. [3, 3, 3],
  702. [3, 3, 3],
  703. [3, 3, 3],
  704. [3, 3, 3],
  705. [3, 3, 3],
  706. [3, 3, 3],
  707. [3, 3, 3],
  708. ],
  709. "kernel_kv": [
  710. [3, 3, 3],
  711. [3, 3, 3],
  712. [3, 3, 3],
  713. [3, 3, 3],
  714. [3, 3, 3],
  715. [3, 3, 3],
  716. [3, 3, 3],
  717. [3, 3, 3],
  718. [3, 3, 3],
  719. [3, 3, 3],
  720. [3, 3, 3],
  721. [3, 3, 3],
  722. [3, 3, 3],
  723. [3, 3, 3],
  724. [3, 3, 3],
  725. [3, 3, 3],
  726. ],
  727. "stride_q": [
  728. [1, 1, 1],
  729. [1, 2, 2],
  730. [1, 1, 1],
  731. [1, 2, 2],
  732. [1, 1, 1],
  733. [1, 1, 1],
  734. [1, 1, 1],
  735. [1, 1, 1],
  736. [1, 1, 1],
  737. [1, 1, 1],
  738. [1, 1, 1],
  739. [1, 1, 1],
  740. [1, 1, 1],
  741. [1, 1, 1],
  742. [1, 2, 2],
  743. [1, 1, 1],
  744. ],
  745. "stride_kv": [
  746. [1, 8, 8],
  747. [1, 4, 4],
  748. [1, 4, 4],
  749. [1, 2, 2],
  750. [1, 2, 2],
  751. [1, 2, 2],
  752. [1, 2, 2],
  753. [1, 2, 2],
  754. [1, 2, 2],
  755. [1, 2, 2],
  756. [1, 2, 2],
  757. [1, 2, 2],
  758. [1, 2, 2],
  759. [1, 2, 2],
  760. [1, 1, 1],
  761. [1, 1, 1],
  762. ],
  763. }
  764. block_setting = []
  765. for i in range(len(config["num_heads"])):
  766. block_setting.append(
  767. MSBlockConfig(
  768. num_heads=config["num_heads"][i],
  769. input_channels=config["input_channels"][i],
  770. output_channels=config["output_channels"][i],
  771. kernel_q=config["kernel_q"][i],
  772. kernel_kv=config["kernel_kv"][i],
  773. stride_q=config["stride_q"][i],
  774. stride_kv=config["stride_kv"][i],
  775. )
  776. )
  777. return _mvit(
  778. spatial_size=(224, 224),
  779. temporal_size=16,
  780. block_setting=block_setting,
  781. residual_pool=True,
  782. residual_with_cls_embed=False,
  783. rel_pos_embed=True,
  784. proj_after_attn=True,
  785. stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2),
  786. weights=weights,
  787. progress=progress,
  788. **kwargs,
  789. )