resnet.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. from collections.abc import Sequence
  2. from functools import partial
  3. from typing import Any, Callable, Optional, Union
  4. import torch.nn as nn
  5. from torch import Tensor
  6. from ...transforms._presets import VideoClassification
  7. from ...utils import _log_api_usage_once
  8. from .._api import register_model, Weights, WeightsEnum
  9. from .._meta import _KINETICS400_CATEGORIES
  10. from .._utils import _ovewrite_named_param, handle_legacy_interface
  # Public names exported by this module: the generic network class, the
  # pretrained-weight enums, and the three model builder functions.
  __all__ = [
      # network
      "VideoResNet",
      # weight enums
      "R3D_18_Weights",
      "MC3_18_Weights",
      "R2Plus1D_18_Weights",
      # model builders
      "r3d_18",
      "mc3_18",
      "r2plus1d_18",
  ]
  class Conv3DSimple(nn.Conv3d):
      """Bias-free 3x3x3 convolution usable as a residual-block conv builder.

      Args:
          in_planes (int): number of input channels.
          out_planes (int): number of output channels.
          midplanes (Optional[int]): unused; accepted so that all conv builders
              can be called with the same arguments.
          stride (int): stride of the convolution. Defaults to 1.
          padding (int): padding of the convolution. Defaults to 1.
      """

      def __init__(
          self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
      ) -> None:
          kernel = (3, 3, 3)
          super().__init__(in_planes, out_planes, kernel, stride=stride, padding=padding, bias=False)

      @staticmethod
      def get_downsample_stride(stride: int) -> tuple[int, int, int]:
          """Return the 3-tuple stride for a shortcut conv paired with this builder."""
          return (stride,) * 3
  class Conv2Plus1D(nn.Sequential):
      """Factorised convolution: a 1x3x3 conv, BatchNorm, ReLU, then a 3x1x1 conv.

      Args:
          in_planes (int): number of input channels.
          out_planes (int): number of output channels.
          midplanes (int): number of channels between the two convolutions.
          stride (int): stride used by both convolutions, each along the axes
              its kernel spans. Defaults to 1.
          padding (int): padding used by both convolutions, each along the axes
              its kernel spans. Defaults to 1.
      """

      def __init__(self, in_planes: int, out_planes: int, midplanes: int, stride: int = 1, padding: int = 1) -> None:
          # 1x3x3 kernel: acts on the last two axes only.
          conv_1x3x3 = nn.Conv3d(
              in_planes,
              midplanes,
              kernel_size=(1, 3, 3),
              stride=(1, stride, stride),
              padding=(0, padding, padding),
              bias=False,
          )
          # 3x1x1 kernel: acts on the first non-channel axis only.
          conv_3x1x1 = nn.Conv3d(
              midplanes,
              out_planes,
              kernel_size=(3, 1, 1),
              stride=(stride, 1, 1),
              padding=(padding, 0, 0),
              bias=False,
          )
          super().__init__(conv_1x3x3, nn.BatchNorm3d(midplanes), nn.ReLU(inplace=True), conv_3x1x1)

      @staticmethod
      def get_downsample_stride(stride: int) -> tuple[int, int, int]:
          """Return the 3-tuple stride for a shortcut conv paired with this builder."""
          return (stride,) * 3
  class Conv3DNoTemporal(nn.Conv3d):
      """Bias-free 1x3x3 convolution; stride and padding apply to the last two axes only.

      Args:
          in_planes (int): number of input channels.
          out_planes (int): number of output channels.
          midplanes (Optional[int]): unused; accepted so that all conv builders
              can be called with the same arguments.
          stride (int): stride along the last two axes. Defaults to 1.
          padding (int): padding along the last two axes. Defaults to 1.
      """

      def __init__(
          self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
      ) -> None:
          kernel = (1, 3, 3)
          strides = (1, stride, stride)
          pads = (0, padding, padding)
          super().__init__(in_planes, out_planes, kernel, stride=strides, padding=pads, bias=False)

      @staticmethod
      def get_downsample_stride(stride: int) -> tuple[int, int, int]:
          """Return the shortcut stride: 1 on the first axis, ``stride`` on the other two."""
          return (1, stride, stride)
  class BasicBlock(nn.Module):
      """Residual block made of two conv-builder convolutions plus a shortcut.

      Args:
          inplanes (int): number of input channels.
          planes (int): number of output channels of both convolutions.
          conv_builder (Callable[..., nn.Module]): factory called as
              ``conv_builder(in, out, midplanes[, stride])`` to create each convolution.
          stride (int): stride forwarded to the first convolution. Defaults to 1.
          downsample (Optional[nn.Module]): module applied to the input to produce
              the shortcut; when ``None`` the input itself is the shortcut.
      """

      expansion = 1

      def __init__(
          self,
          inplanes: int,
          planes: int,
          conv_builder: Callable[..., nn.Module],
          stride: int = 1,
          downsample: Optional[nn.Module] = None,
      ) -> None:
          super().__init__()
          # Intermediate channel count handed to the conv builder
          # (of the builders in this file only Conv2Plus1D uses it).
          midplanes = (27 * inplanes * planes) // (9 * inplanes + 3 * planes)

          first_conv = conv_builder(inplanes, planes, midplanes, stride)
          self.conv1 = nn.Sequential(first_conv, nn.BatchNorm3d(planes), nn.ReLU(inplace=True))

          second_conv = conv_builder(planes, planes, midplanes)
          self.conv2 = nn.Sequential(second_conv, nn.BatchNorm3d(planes))

          self.relu = nn.ReLU(inplace=True)
          self.downsample = downsample
          self.stride = stride

      def forward(self, x: Tensor) -> Tensor:
          """Apply both convolutions, add the shortcut, then apply the final ReLU."""
          out = self.conv2(self.conv1(x))
          shortcut = x if self.downsample is None else self.downsample(x)
          out += shortcut
          return self.relu(out)
  class Bottleneck(nn.Module):
      """Residual block: 1x1x1 conv, conv-builder conv, 1x1x1 expansion conv, plus a shortcut.

      Args:
          inplanes (int): number of input channels.
          planes (int): channel count of the first two convolutions; the block
              outputs ``planes * expansion`` channels.
          conv_builder (Callable[..., nn.Module]): factory called as
              ``conv_builder(planes, planes, midplanes, stride)`` for the middle convolution.
          stride (int): stride forwarded to the middle convolution. Defaults to 1.
          downsample (Optional[nn.Module]): module applied to the input to produce
              the shortcut; when ``None`` the input itself is the shortcut.
      """

      expansion = 4

      def __init__(
          self,
          inplanes: int,
          planes: int,
          conv_builder: Callable[..., nn.Module],
          stride: int = 1,
          downsample: Optional[nn.Module] = None,
      ) -> None:
          super().__init__()
          out_planes = planes * self.expansion
          # Intermediate channel count handed to the conv builder.
          midplanes = (27 * inplanes * planes) // (9 * inplanes + 3 * planes)

          # 1x1x1 reduction
          reduce_conv = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
          self.conv1 = nn.Sequential(reduce_conv, nn.BatchNorm3d(planes), nn.ReLU(inplace=True))

          # Second kernel, supplied by the conv builder
          middle_conv = conv_builder(planes, planes, midplanes, stride)
          self.conv2 = nn.Sequential(middle_conv, nn.BatchNorm3d(planes), nn.ReLU(inplace=True))

          # 1x1x1 expansion
          expand_conv = nn.Conv3d(planes, out_planes, kernel_size=1, bias=False)
          self.conv3 = nn.Sequential(expand_conv, nn.BatchNorm3d(out_planes))

          self.relu = nn.ReLU(inplace=True)
          self.downsample = downsample
          self.stride = stride

      def forward(self, x: Tensor) -> Tensor:
          """Apply the three convolutions, add the shortcut, then apply the final ReLU."""
          out = self.conv3(self.conv2(self.conv1(x)))
          shortcut = x if self.downsample is None else self.downsample(x)
          out += shortcut
          return self.relu(out)
  class BasicStem(nn.Sequential):
      """The default conv-batchnorm-relu stem (3 input channels, 64 output channels)."""

      def __init__(self) -> None:
          stem_layers = [
              nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False),
              nn.BatchNorm3d(64),
              nn.ReLU(inplace=True),
          ]
          super().__init__(*stem_layers)
  class R2Plus1dStem(nn.Sequential):
      """R(2+1)D stem: unlike the default stem it uses a separated 3D convolution.

      A 1x7x7 conv (3 -> 45 channels) is followed by a 3x1x1 conv (45 -> 64
      channels), each with its own BatchNorm and ReLU.
      """

      def __init__(self) -> None:
          stem_layers = [
              nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False),
              nn.BatchNorm3d(45),
              nn.ReLU(inplace=True),
              nn.Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False),
              nn.BatchNorm3d(64),
              nn.ReLU(inplace=True),
          ]
          super().__init__(*stem_layers)
  class VideoResNet(nn.Module):
      """Video ResNet: a stem, four residual stages, global average pooling and a linear classifier."""

      def __init__(
          self,
          block: type[Union[BasicBlock, Bottleneck]],
          conv_makers: Sequence[type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
          layers: list[int],
          stem: Callable[..., nn.Module],
          num_classes: int = 400,
          zero_init_residual: bool = False,
      ) -> None:
          """Generic resnet video generator.

          Args:
              block (Type[Union[BasicBlock, Bottleneck]]): resnet building block
              conv_makers (List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): generator
                  function for each layer; the first four entries are used, one per stage
              layers (List[int]): number of blocks per layer; the first four entries are used
              stem (Callable[..., nn.Module]): module specifying the ResNet stem.
              num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
              zero_init_residual (bool, optional): Zero init the weight of the last BatchNorm of
                  every ``Bottleneck`` block. Defaults to False.
          """
          super().__init__()
          _log_api_usage_once(self)
          # Channel count entering the next stage; updated by _make_layer.
          self.inplanes = 64

          self.stem = stem()

          self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
          self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
          self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)
          self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)

          self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
          self.fc = nn.Linear(512 * block.expansion, num_classes)

          # init weights
          for m in self.modules():
              if isinstance(m, nn.Conv3d):
                  nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                  if m.bias is not None:
                      nn.init.constant_(m.bias, 0)
              elif isinstance(m, nn.BatchNorm3d):
                  nn.init.constant_(m.weight, 1)
                  nn.init.constant_(m.bias, 0)
              elif isinstance(m, nn.Linear):
                  nn.init.normal_(m.weight, 0, 0.01)
                  nn.init.constant_(m.bias, 0)

          if zero_init_residual:
              for m in self.modules():
                  if isinstance(m, Bottleneck):
                      # Bottleneck has no ``bn3`` attribute: its last BatchNorm is the final
                      # element of the ``conv3`` Sequential, so that is the weight to zero.
                      nn.init.constant_(m.conv3[-1].weight, 0)  # type: ignore[union-attr, arg-type]

      def forward(self, x: Tensor) -> Tensor:
          """Run the stem and the four stages, pool, flatten and classify."""
          x = self.stem(x)

          x = self.layer1(x)
          x = self.layer2(x)
          x = self.layer3(x)
          x = self.layer4(x)

          x = self.avgpool(x)
          # Flatten the layer to fc
          x = x.flatten(1)
          x = self.fc(x)

          return x

      def _make_layer(
          self,
          block: type[Union[BasicBlock, Bottleneck]],
          conv_builder: type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]],
          planes: int,
          blocks: int,
          stride: int = 1,
      ) -> nn.Sequential:
          """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``.

          Only the first block receives ``stride`` and, when the stride or the
          channel count changes, a 1x1x1 conv + BatchNorm shortcut.
          """
          downsample = None

          if stride != 1 or self.inplanes != planes * block.expansion:
              # The shortcut stride comes from the conv builder so that the shortcut
              # and the main path are strided along the same axes.
              ds_stride = conv_builder.get_downsample_stride(stride)
              downsample = nn.Sequential(
                  nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False),
                  nn.BatchNorm3d(planes * block.expansion),
              )

          layers = [block(self.inplanes, planes, conv_builder, stride, downsample)]

          self.inplanes = planes * block.expansion
          layers.extend(block(self.inplanes, planes, conv_builder) for _ in range(1, blocks))

          return nn.Sequential(*layers)
  def _video_resnet(
      block: type[Union[BasicBlock, Bottleneck]],
      conv_makers: Sequence[type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
      layers: list[int],
      stem: Callable[..., nn.Module],
      weights: Optional[WeightsEnum],
      progress: bool,
      **kwargs: Any,
  ) -> VideoResNet:
      """Build a ``VideoResNet`` and, when ``weights`` is given, load its state dict.

      Args:
          block: residual block class forwarded to ``VideoResNet``.
          conv_makers: conv builder classes forwarded to ``VideoResNet``.
          layers: number of blocks per stage, forwarded to ``VideoResNet``.
          stem: stem factory forwarded to ``VideoResNet``.
          weights: weights to load, or ``None`` to leave the model as initialised.
          progress: forwarded to ``weights.get_state_dict``.
          **kwargs: extra keyword arguments for the ``VideoResNet`` constructor.

      Returns:
          The constructed model.
      """
      if weights is None:
          return VideoResNet(block, conv_makers, layers, stem, **kwargs)

      # The classifier size must match the number of categories of the weights.
      n_categories = len(weights.meta["categories"])
      _ovewrite_named_param(kwargs, "num_classes", n_categories)

      model = VideoResNet(block, conv_makers, layers, stem, **kwargs)
      state_dict = weights.get_state_dict(progress=progress, check_hash=True)
      model.load_state_dict(state_dict)
      return model
  # Metadata entries shared by every weights enum defined below.
  _COMMON_META = {
      # NOTE(review): presumably the smallest supported input size — confirm the axis meaning.
      "min_size": (1, 1),
      # Category list; its length sets ``num_classes`` when weights are loaded.
      "categories": _KINETICS400_CATEGORIES,
      # Link to the reference training scripts.
      "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
      # Free-text note about how the reported accuracies were estimated.
      "_docs": (
          "The weights reproduce closely the accuracy of the paper. The accuracies are estimated on video-level "
          "with parameters `frame_rate=15`, `clips_per_video=5`, and `clip_len=16`."
      ),
  }
  class R3D_18_Weights(WeightsEnum):
      # Weight entries accepted by the ``r3d_18`` builder.

      # Kinetics-400 checkpoint with its preprocessing preset and reported metrics.
      KINETICS400_V1 = Weights(
          url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
          transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
          meta={
              **_COMMON_META,
              "num_params": 33_371_472,
              "_metrics": {"Kinetics-400": {"acc@1": 63.200, "acc@5": 83.479}},
              "_ops": 40.697,
              "_file_size": 127.359,
          },
      )
      # Alias for the entry used by default.
      DEFAULT = KINETICS400_V1
  class MC3_18_Weights(WeightsEnum):
      # Weight entries accepted by the ``mc3_18`` builder.

      # Kinetics-400 checkpoint with its preprocessing preset and reported metrics.
      KINETICS400_V1 = Weights(
          url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
          transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
          meta={
              **_COMMON_META,
              "num_params": 11_695_440,
              "_metrics": {"Kinetics-400": {"acc@1": 63.960, "acc@5": 84.130}},
              "_ops": 43.343,
              "_file_size": 44.672,
          },
      )
      # Alias for the entry used by default.
      DEFAULT = KINETICS400_V1
  class R2Plus1D_18_Weights(WeightsEnum):
      # Weight entries accepted by the ``r2plus1d_18`` builder.

      # Kinetics-400 checkpoint with its preprocessing preset and reported metrics.
      KINETICS400_V1 = Weights(
          url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
          transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
          meta={
              **_COMMON_META,
              "num_params": 31_505_325,
              "_metrics": {"Kinetics-400": {"acc@1": 67.463, "acc@5": 86.175}},
              "_ops": 40.519,
              "_file_size": 120.318,
          },
      )
      # Alias for the entry used by default.
      DEFAULT = KINETICS400_V1
  @register_model()
  @handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1))
  def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
      """Construct the 18-layer ResNet3D model.

      .. betastatus:: video module

      Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.

      Args:
          weights (:class:`~torchvision.models.video.R3D_18_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.video.R3D_18_Weights`
              below for more details, and possible values. By default, no
              pre-trained weights are used.
          progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
              Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.video.R3D_18_Weights
          :members:
      """
      weights = R3D_18_Weights.verify(weights)

      # Every stage uses the full 3x3x3 convolution.
      return _video_resnet(
          block=BasicBlock,
          conv_makers=[Conv3DSimple] * 4,
          layers=[2, 2, 2, 2],
          stem=BasicStem,
          weights=weights,
          progress=progress,
          **kwargs,
      )
  @register_model()
  @handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1))
  def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
      """Construct the 18-layer Mixed Convolution (MC3) network.

      .. betastatus:: video module

      Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.

      Args:
          weights (:class:`~torchvision.models.video.MC3_18_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.video.MC3_18_Weights`
              below for more details, and possible values. By default, no
              pre-trained weights are used.
          progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
              Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.video.MC3_18_Weights
          :members:
      """
      weights = MC3_18_Weights.verify(weights)

      # First stage uses the full 3x3x3 conv; the remaining three use the 1x3x3 conv.
      stage_convs = [Conv3DSimple, Conv3DNoTemporal, Conv3DNoTemporal, Conv3DNoTemporal]
      return _video_resnet(
          block=BasicBlock,
          conv_makers=stage_convs,
          layers=[2, 2, 2, 2],
          stem=BasicStem,
          weights=weights,
          progress=progress,
          **kwargs,
      )
  @register_model()
  @handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1))
  def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
      """Construct the 18-layer deep R(2+1)D network.

      .. betastatus:: video module

      Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.

      Args:
          weights (:class:`~torchvision.models.video.R2Plus1D_18_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.video.R2Plus1D_18_Weights`
              below for more details, and possible values. By default, no
              pre-trained weights are used.
          progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
              Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.video.R2Plus1D_18_Weights
          :members:
      """
      weights = R2Plus1D_18_Weights.verify(weights)

      # Every stage uses the factorised (2+1)D convolution, paired with the matching stem.
      return _video_resnet(
          block=BasicBlock,
          conv_makers=[Conv2Plus1D] * 4,
          layers=[2, 2, 2, 2],
          stem=R2Plus1dStem,
          weights=weights,
          progress=progress,
          **kwargs,
      )
  # The dictionary below is internal implementation detail and will be removed in v0.15
  from .._utils import _ModelURLs


  # Maps each builder name to the URL of its KINETICS400_V1 weights.
  model_urls = _ModelURLs(
      {
          arch: weights_enum.KINETICS400_V1.url
          for arch, weights_enum in (
              ("r3d_18", R3D_18_Weights),
              ("mc3_18", MC3_18_Weights),
              ("r2plus1d_18", R2Plus1D_18_Weights),
          )
      }
  )