resnet.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. from functools import partial
  2. from typing import Any, Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. from torch import Tensor
  6. from torchvision.models.resnet import (
  7. BasicBlock,
  8. Bottleneck,
  9. ResNet,
  10. ResNet18_Weights,
  11. ResNet50_Weights,
  12. ResNeXt101_32X8D_Weights,
  13. ResNeXt101_64X4D_Weights,
  14. )
  15. from ...transforms._presets import ImageClassification
  16. from .._api import register_model, Weights, WeightsEnum
  17. from .._meta import _IMAGENET_CATEGORIES
  18. from .._utils import _ovewrite_named_param, handle_legacy_interface
  19. from .utils import _fuse_modules, _replace_relu, quantize_model
# Public API of this module: the quantizable ResNet class, the quantized weight
# enums defined below, and the four model builder functions.
__all__ = [
    "QuantizableResNet",
    "ResNet18_QuantizedWeights",
    "ResNet50_QuantizedWeights",
    "ResNeXt101_32X8D_QuantizedWeights",
    "ResNeXt101_64X4D_QuantizedWeights",
    "resnet18",
    "resnet50",
    "resnext101_32x8d",
    "resnext101_64x4d",
]
  31. class QuantizableBasicBlock(BasicBlock):
  32. def __init__(self, *args: Any, **kwargs: Any) -> None:
  33. super().__init__(*args, **kwargs)
  34. self.add_relu = torch.nn.quantized.FloatFunctional()
  35. def forward(self, x: Tensor) -> Tensor:
  36. identity = x
  37. out = self.conv1(x)
  38. out = self.bn1(out)
  39. out = self.relu(out)
  40. out = self.conv2(out)
  41. out = self.bn2(out)
  42. if self.downsample is not None:
  43. identity = self.downsample(x)
  44. out = self.add_relu.add_relu(out, identity)
  45. return out
  46. def fuse_model(self, is_qat: Optional[bool] = None) -> None:
  47. _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True)
  48. if self.downsample:
  49. _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
  50. class QuantizableBottleneck(Bottleneck):
  51. def __init__(self, *args: Any, **kwargs: Any) -> None:
  52. super().__init__(*args, **kwargs)
  53. self.skip_add_relu = nn.quantized.FloatFunctional()
  54. self.relu1 = nn.ReLU(inplace=False)
  55. self.relu2 = nn.ReLU(inplace=False)
  56. def forward(self, x: Tensor) -> Tensor:
  57. identity = x
  58. out = self.conv1(x)
  59. out = self.bn1(out)
  60. out = self.relu1(out)
  61. out = self.conv2(out)
  62. out = self.bn2(out)
  63. out = self.relu2(out)
  64. out = self.conv3(out)
  65. out = self.bn3(out)
  66. if self.downsample is not None:
  67. identity = self.downsample(x)
  68. out = self.skip_add_relu.add_relu(out, identity)
  69. return out
  70. def fuse_model(self, is_qat: Optional[bool] = None) -> None:
  71. _fuse_modules(
  72. self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True
  73. )
  74. if self.downsample:
  75. _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
  76. class QuantizableResNet(ResNet):
  77. def __init__(self, *args: Any, **kwargs: Any) -> None:
  78. super().__init__(*args, **kwargs)
  79. self.quant = torch.ao.quantization.QuantStub()
  80. self.dequant = torch.ao.quantization.DeQuantStub()
  81. def forward(self, x: Tensor) -> Tensor:
  82. x = self.quant(x)
  83. # Ensure scriptability
  84. # super(QuantizableResNet,self).forward(x)
  85. # is not scriptable
  86. x = self._forward_impl(x)
  87. x = self.dequant(x)
  88. return x
  89. def fuse_model(self, is_qat: Optional[bool] = None) -> None:
  90. r"""Fuse conv/bn/relu modules in resnet models
  91. Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
  92. Model is modified in place. Note that this operation does not change numerics
  93. and the model after modification is in floating point
  94. """
  95. _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True)
  96. for m in self.modules():
  97. if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock:
  98. m.fuse_model(is_qat)
  99. def _resnet(
  100. block: type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
  101. layers: list[int],
  102. weights: Optional[WeightsEnum],
  103. progress: bool,
  104. quantize: bool,
  105. **kwargs: Any,
  106. ) -> QuantizableResNet:
  107. if weights is not None:
  108. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  109. if "backend" in weights.meta:
  110. _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
  111. backend = kwargs.pop("backend", "fbgemm")
  112. model = QuantizableResNet(block, layers, **kwargs)
  113. _replace_relu(model)
  114. if quantize:
  115. quantize_model(model, backend)
  116. if weights is not None:
  117. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  118. return model
# Metadata shared by every quantized weight entry below. Each entry spreads it with
# ``**_COMMON_META`` and may override individual keys (ResNeXt101_64X4D overrides "recipe").
_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    # Read by _resnet(), which forwards it as the quantization backend.
    "backend": "fbgemm",
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
    "_docs": """
These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized
weights listed below.
""",
}
# Quantized ResNet-18 checkpoints. The single entry is derived from the unquantized
# ResNet18_Weights.IMAGENET1K_V1 (see its "unquantized" meta key).
class ResNet18_QuantizedWeights(WeightsEnum):
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11689512,
            "unquantized": ResNet18_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.494,
                    "acc@5": 88.882,
                }
            },
            "_ops": 1.814,
            "_file_size": 11.238,
        },
    )
    # Alias used when callers ask for the default quantized weights.
    DEFAULT = IMAGENET1K_FBGEMM_V1
# Quantized ResNet-50 checkpoints: V1 derives from ResNet50_Weights.IMAGENET1K_V1,
# V2 from ResNet50_Weights.IMAGENET1K_V2 (see each entry's "unquantized" key).
class ResNet50_QuantizedWeights(WeightsEnum):
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "unquantized": ResNet50_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.920,
                    "acc@5": 92.814,
                }
            },
            "_ops": 4.089,
            "_file_size": 24.759,
        },
    )
    # V2 uses a different eval preprocessing (resize_size=232) than V1.
    IMAGENET1K_FBGEMM_V2 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "unquantized": ResNet50_Weights.IMAGENET1K_V2,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.282,
                    "acc@5": 94.976,
                }
            },
            "_ops": 4.089,
            "_file_size": 24.953,
        },
    )
    # Default points at the newer V2 checkpoint.
    DEFAULT = IMAGENET1K_FBGEMM_V2
# Quantized ResNeXt-101 32x8d checkpoints: V1 derives from ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
# V2 from ResNeXt101_32X8D_Weights.IMAGENET1K_V2 (see each entry's "unquantized" key).
class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.986,
                    "acc@5": 94.480,
                }
            },
            "_ops": 16.414,
            "_file_size": 86.034,
        },
    )
    # V2 uses a different eval preprocessing (resize_size=232) than V1.
    IMAGENET1K_FBGEMM_V2 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.574,
                    "acc@5": 96.132,
                }
            },
            "_ops": 16.414,
            "_file_size": 86.645,
        },
    )
    # Default points at the newer V2 checkpoint.
    DEFAULT = IMAGENET1K_FBGEMM_V2
# Quantized ResNeXt-101 64x4d checkpoint, derived from ResNeXt101_64X4D_Weights.IMAGENET1K_V1.
class ResNeXt101_64X4D_QuantizedWeights(WeightsEnum):
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_64x4d_fbgemm-605a1cb3.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 83455272,
            # Overrides the shared "recipe" from _COMMON_META for this checkpoint.
            "recipe": "https://github.com/pytorch/vision/pull/5935",
            "unquantized": ResNeXt101_64X4D_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.898,
                    "acc@5": 96.326,
                }
            },
            "_ops": 15.46,
            "_file_size": 81.556,
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V1
  240. @register_model(name="quantized_resnet18")
  241. @handle_legacy_interface(
  242. weights=(
  243. "pretrained",
  244. lambda kwargs: (
  245. ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  246. if kwargs.get("quantize", False)
  247. else ResNet18_Weights.IMAGENET1K_V1
  248. ),
  249. )
  250. )
  251. def resnet18(
  252. *,
  253. weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
  254. progress: bool = True,
  255. quantize: bool = False,
  256. **kwargs: Any,
  257. ) -> QuantizableResNet:
  258. """ResNet-18 model from
  259. `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_
  260. .. note::
  261. Note that ``quantize = True`` returns a quantized model with 8 bit
  262. weights. Quantized models only support inference and run on CPUs.
  263. GPU inference is not yet supported.
  264. Args:
  265. weights (:class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` or :class:`~torchvision.models.ResNet18_Weights`, optional): The
  266. pretrained weights for the model. See
  267. :class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` below for
  268. more details, and possible values. By default, no pre-trained
  269. weights are used.
  270. progress (bool, optional): If True, displays a progress bar of the
  271. download to stderr. Default is True.
  272. quantize (bool, optional): If True, return a quantized version of the model. Default is False.
  273. **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
  274. base class. Please refer to the `source code
  275. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
  276. for more details about this class.
  277. .. autoclass:: torchvision.models.quantization.ResNet18_QuantizedWeights
  278. :members:
  279. .. autoclass:: torchvision.models.ResNet18_Weights
  280. :members:
  281. :noindex:
  282. """
  283. weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights)
  284. return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
  285. @register_model(name="quantized_resnet50")
  286. @handle_legacy_interface(
  287. weights=(
  288. "pretrained",
  289. lambda kwargs: (
  290. ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  291. if kwargs.get("quantize", False)
  292. else ResNet50_Weights.IMAGENET1K_V1
  293. ),
  294. )
  295. )
  296. def resnet50(
  297. *,
  298. weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
  299. progress: bool = True,
  300. quantize: bool = False,
  301. **kwargs: Any,
  302. ) -> QuantizableResNet:
  303. """ResNet-50 model from
  304. `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_
  305. .. note::
  306. Note that ``quantize = True`` returns a quantized model with 8 bit
  307. weights. Quantized models only support inference and run on CPUs.
  308. GPU inference is not yet supported.
  309. Args:
  310. weights (:class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` or :class:`~torchvision.models.ResNet50_Weights`, optional): The
  311. pretrained weights for the model. See
  312. :class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` below for
  313. more details, and possible values. By default, no pre-trained
  314. weights are used.
  315. progress (bool, optional): If True, displays a progress bar of the
  316. download to stderr. Default is True.
  317. quantize (bool, optional): If True, return a quantized version of the model. Default is False.
  318. **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
  319. base class. Please refer to the `source code
  320. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
  321. for more details about this class.
  322. .. autoclass:: torchvision.models.quantization.ResNet50_QuantizedWeights
  323. :members:
  324. .. autoclass:: torchvision.models.ResNet50_Weights
  325. :members:
  326. :noindex:
  327. """
  328. weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights)
  329. return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
  330. @register_model(name="quantized_resnext101_32x8d")
  331. @handle_legacy_interface(
  332. weights=(
  333. "pretrained",
  334. lambda kwargs: (
  335. ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  336. if kwargs.get("quantize", False)
  337. else ResNeXt101_32X8D_Weights.IMAGENET1K_V1
  338. ),
  339. )
  340. )
  341. def resnext101_32x8d(
  342. *,
  343. weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
  344. progress: bool = True,
  345. quantize: bool = False,
  346. **kwargs: Any,
  347. ) -> QuantizableResNet:
  348. """ResNeXt-101 32x8d model from
  349. `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_
  350. .. note::
  351. Note that ``quantize = True`` returns a quantized model with 8 bit
  352. weights. Quantized models only support inference and run on CPUs.
  353. GPU inference is not yet supported.
  354. Args:
  355. weights (:class:`~torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The
  356. pretrained weights for the model. See
  357. :class:`~torchvision.models.quantization.ResNet101_32X8D_QuantizedWeights` below for
  358. more details, and possible values. By default, no pre-trained
  359. weights are used.
  360. progress (bool, optional): If True, displays a progress bar of the
  361. download to stderr. Default is True.
  362. quantize (bool, optional): If True, return a quantized version of the model. Default is False.
  363. **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
  364. base class. Please refer to the `source code
  365. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
  366. for more details about this class.
  367. .. autoclass:: torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights
  368. :members:
  369. .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights
  370. :members:
  371. :noindex:
  372. """
  373. weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights)
  374. _ovewrite_named_param(kwargs, "groups", 32)
  375. _ovewrite_named_param(kwargs, "width_per_group", 8)
  376. return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
  377. @register_model(name="quantized_resnext101_64x4d")
  378. @handle_legacy_interface(
  379. weights=(
  380. "pretrained",
  381. lambda kwargs: (
  382. ResNeXt101_64X4D_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  383. if kwargs.get("quantize", False)
  384. else ResNeXt101_64X4D_Weights.IMAGENET1K_V1
  385. ),
  386. )
  387. )
  388. def resnext101_64x4d(
  389. *,
  390. weights: Optional[Union[ResNeXt101_64X4D_QuantizedWeights, ResNeXt101_64X4D_Weights]] = None,
  391. progress: bool = True,
  392. quantize: bool = False,
  393. **kwargs: Any,
  394. ) -> QuantizableResNet:
  395. """ResNeXt-101 64x4d model from
  396. `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_
  397. .. note::
  398. Note that ``quantize = True`` returns a quantized model with 8 bit
  399. weights. Quantized models only support inference and run on CPUs.
  400. GPU inference is not yet supported.
  401. Args:
  402. weights (:class:`~torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The
  403. pretrained weights for the model. See
  404. :class:`~torchvision.models.quantization.ResNet101_64X4D_QuantizedWeights` below for
  405. more details, and possible values. By default, no pre-trained
  406. weights are used.
  407. progress (bool, optional): If True, displays a progress bar of the
  408. download to stderr. Default is True.
  409. quantize (bool, optional): If True, return a quantized version of the model. Default is False.
  410. **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
  411. base class. Please refer to the `source code
  412. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
  413. for more details about this class.
  414. .. autoclass:: torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights
  415. :members:
  416. .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights
  417. :members:
  418. :noindex:
  419. """
  420. weights = (ResNeXt101_64X4D_QuantizedWeights if quantize else ResNeXt101_64X4D_Weights).verify(weights)
  421. _ovewrite_named_param(kwargs, "groups", 64)
  422. _ovewrite_named_param(kwargs, "width_per_group", 4)
  423. return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)