efficientnet.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132
  1. import copy
  2. import math
  3. from collections.abc import Sequence
  4. from dataclasses import dataclass
  5. from functools import partial
  6. from typing import Any, Callable, Optional, Union
  7. import torch
  8. from torch import nn, Tensor
  9. from torchvision.ops import StochasticDepth
  10. from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
  11. from ..transforms._presets import ImageClassification, InterpolationMode
  12. from ..utils import _log_api_usage_once
  13. from ._api import register_model, Weights, WeightsEnum
  14. from ._meta import _IMAGENET_CATEGORIES
  15. from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
# Public names exported by this module via ``from ... import *``: the model class,
# one weights enum per variant and one builder function per variant.
__all__ = [
    "EfficientNet",
    "EfficientNet_B0_Weights",
    "EfficientNet_B1_Weights",
    "EfficientNet_B2_Weights",
    "EfficientNet_B3_Weights",
    "EfficientNet_B4_Weights",
    "EfficientNet_B5_Weights",
    "EfficientNet_B6_Weights",
    "EfficientNet_B7_Weights",
    "EfficientNet_V2_S_Weights",
    "EfficientNet_V2_M_Weights",
    "EfficientNet_V2_L_Weights",
    "efficientnet_b0",
    "efficientnet_b1",
    "efficientnet_b2",
    "efficientnet_b3",
    "efficientnet_b4",
    "efficientnet_b5",
    "efficientnet_b6",
    "efficientnet_b7",
    "efficientnet_v2_s",
    "efficientnet_v2_m",
    "efficientnet_v2_l",
]
  41. @dataclass
  42. class _MBConvConfig:
  43. expand_ratio: float
  44. kernel: int
  45. stride: int
  46. input_channels: int
  47. out_channels: int
  48. num_layers: int
  49. block: Callable[..., nn.Module]
  50. @staticmethod
  51. def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
  52. return _make_divisible(channels * width_mult, 8, min_value)
  53. class MBConvConfig(_MBConvConfig):
  54. # Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the EfficientNetV2 paper
  55. def __init__(
  56. self,
  57. expand_ratio: float,
  58. kernel: int,
  59. stride: int,
  60. input_channels: int,
  61. out_channels: int,
  62. num_layers: int,
  63. width_mult: float = 1.0,
  64. depth_mult: float = 1.0,
  65. block: Optional[Callable[..., nn.Module]] = None,
  66. ) -> None:
  67. input_channels = self.adjust_channels(input_channels, width_mult)
  68. out_channels = self.adjust_channels(out_channels, width_mult)
  69. num_layers = self.adjust_depth(num_layers, depth_mult)
  70. if block is None:
  71. block = MBConv
  72. super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
  73. @staticmethod
  74. def adjust_depth(num_layers: int, depth_mult: float):
  75. return int(math.ceil(num_layers * depth_mult))
  76. class FusedMBConvConfig(_MBConvConfig):
  77. # Stores information listed at Table 4 of the EfficientNetV2 paper
  78. def __init__(
  79. self,
  80. expand_ratio: float,
  81. kernel: int,
  82. stride: int,
  83. input_channels: int,
  84. out_channels: int,
  85. num_layers: int,
  86. block: Optional[Callable[..., nn.Module]] = None,
  87. ) -> None:
  88. if block is None:
  89. block = FusedMBConv
  90. super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
  91. class MBConv(nn.Module):
  92. def __init__(
  93. self,
  94. cnf: MBConvConfig,
  95. stochastic_depth_prob: float,
  96. norm_layer: Callable[..., nn.Module],
  97. se_layer: Callable[..., nn.Module] = SqueezeExcitation,
  98. ) -> None:
  99. super().__init__()
  100. if not (1 <= cnf.stride <= 2):
  101. raise ValueError("illegal stride value")
  102. self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
  103. layers: list[nn.Module] = []
  104. activation_layer = nn.SiLU
  105. # expand
  106. expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
  107. if expanded_channels != cnf.input_channels:
  108. layers.append(
  109. Conv2dNormActivation(
  110. cnf.input_channels,
  111. expanded_channels,
  112. kernel_size=1,
  113. norm_layer=norm_layer,
  114. activation_layer=activation_layer,
  115. )
  116. )
  117. # depthwise
  118. layers.append(
  119. Conv2dNormActivation(
  120. expanded_channels,
  121. expanded_channels,
  122. kernel_size=cnf.kernel,
  123. stride=cnf.stride,
  124. groups=expanded_channels,
  125. norm_layer=norm_layer,
  126. activation_layer=activation_layer,
  127. )
  128. )
  129. # squeeze and excitation
  130. squeeze_channels = max(1, cnf.input_channels // 4)
  131. layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))
  132. # project
  133. layers.append(
  134. Conv2dNormActivation(
  135. expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
  136. )
  137. )
  138. self.block = nn.Sequential(*layers)
  139. self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
  140. self.out_channels = cnf.out_channels
  141. def forward(self, input: Tensor) -> Tensor:
  142. result = self.block(input)
  143. if self.use_res_connect:
  144. result = self.stochastic_depth(result)
  145. result += input
  146. return result
  147. class FusedMBConv(nn.Module):
  148. def __init__(
  149. self,
  150. cnf: FusedMBConvConfig,
  151. stochastic_depth_prob: float,
  152. norm_layer: Callable[..., nn.Module],
  153. ) -> None:
  154. super().__init__()
  155. if not (1 <= cnf.stride <= 2):
  156. raise ValueError("illegal stride value")
  157. self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
  158. layers: list[nn.Module] = []
  159. activation_layer = nn.SiLU
  160. expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
  161. if expanded_channels != cnf.input_channels:
  162. # fused expand
  163. layers.append(
  164. Conv2dNormActivation(
  165. cnf.input_channels,
  166. expanded_channels,
  167. kernel_size=cnf.kernel,
  168. stride=cnf.stride,
  169. norm_layer=norm_layer,
  170. activation_layer=activation_layer,
  171. )
  172. )
  173. # project
  174. layers.append(
  175. Conv2dNormActivation(
  176. expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
  177. )
  178. )
  179. else:
  180. layers.append(
  181. Conv2dNormActivation(
  182. cnf.input_channels,
  183. cnf.out_channels,
  184. kernel_size=cnf.kernel,
  185. stride=cnf.stride,
  186. norm_layer=norm_layer,
  187. activation_layer=activation_layer,
  188. )
  189. )
  190. self.block = nn.Sequential(*layers)
  191. self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
  192. self.out_channels = cnf.out_channels
  193. def forward(self, input: Tensor) -> Tensor:
  194. result = self.block(input)
  195. if self.use_res_connect:
  196. result = self.stochastic_depth(result)
  197. result += input
  198. return result
  199. class EfficientNet(nn.Module):
  200. def __init__(
  201. self,
  202. inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
  203. dropout: float,
  204. stochastic_depth_prob: float = 0.2,
  205. num_classes: int = 1000,
  206. norm_layer: Optional[Callable[..., nn.Module]] = None,
  207. last_channel: Optional[int] = None,
  208. ) -> None:
  209. """
  210. EfficientNet V1 and V2 main class
  211. Args:
  212. inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure
  213. dropout (float): The droupout probability
  214. stochastic_depth_prob (float): The stochastic depth probability
  215. num_classes (int): Number of classes
  216. norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
  217. last_channel (int): The number of channels on the penultimate layer
  218. """
  219. super().__init__()
  220. _log_api_usage_once(self)
  221. if not inverted_residual_setting:
  222. raise ValueError("The inverted_residual_setting should not be empty")
  223. elif not (
  224. isinstance(inverted_residual_setting, Sequence)
  225. and all([isinstance(s, _MBConvConfig) for s in inverted_residual_setting])
  226. ):
  227. raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")
  228. if norm_layer is None:
  229. norm_layer = nn.BatchNorm2d
  230. layers: list[nn.Module] = []
  231. # building first layer
  232. firstconv_output_channels = inverted_residual_setting[0].input_channels
  233. layers.append(
  234. Conv2dNormActivation(
  235. 3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
  236. )
  237. )
  238. # building inverted residual blocks
  239. total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting)
  240. stage_block_id = 0
  241. for cnf in inverted_residual_setting:
  242. stage: list[nn.Module] = []
  243. for _ in range(cnf.num_layers):
  244. # copy to avoid modifications. shallow copy is enough
  245. block_cnf = copy.copy(cnf)
  246. # overwrite info if not the first conv in the stage
  247. if stage:
  248. block_cnf.input_channels = block_cnf.out_channels
  249. block_cnf.stride = 1
  250. # adjust stochastic depth probability based on the depth of the stage block
  251. sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks
  252. stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))
  253. stage_block_id += 1
  254. layers.append(nn.Sequential(*stage))
  255. # building last several layers
  256. lastconv_input_channels = inverted_residual_setting[-1].out_channels
  257. lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
  258. layers.append(
  259. Conv2dNormActivation(
  260. lastconv_input_channels,
  261. lastconv_output_channels,
  262. kernel_size=1,
  263. norm_layer=norm_layer,
  264. activation_layer=nn.SiLU,
  265. )
  266. )
  267. self.features = nn.Sequential(*layers)
  268. self.avgpool = nn.AdaptiveAvgPool2d(1)
  269. self.classifier = nn.Sequential(
  270. nn.Dropout(p=dropout, inplace=True),
  271. nn.Linear(lastconv_output_channels, num_classes),
  272. )
  273. for m in self.modules():
  274. if isinstance(m, nn.Conv2d):
  275. nn.init.kaiming_normal_(m.weight, mode="fan_out")
  276. if m.bias is not None:
  277. nn.init.zeros_(m.bias)
  278. elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
  279. nn.init.ones_(m.weight)
  280. nn.init.zeros_(m.bias)
  281. elif isinstance(m, nn.Linear):
  282. init_range = 1.0 / math.sqrt(m.out_features)
  283. nn.init.uniform_(m.weight, -init_range, init_range)
  284. nn.init.zeros_(m.bias)
  285. def _forward_impl(self, x: Tensor) -> Tensor:
  286. x = self.features(x)
  287. x = self.avgpool(x)
  288. x = torch.flatten(x, 1)
  289. x = self.classifier(x)
  290. return x
  291. def forward(self, x: Tensor) -> Tensor:
  292. return self._forward_impl(x)
  293. def _efficientnet(
  294. inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
  295. dropout: float,
  296. last_channel: Optional[int],
  297. weights: Optional[WeightsEnum],
  298. progress: bool,
  299. **kwargs: Any,
  300. ) -> EfficientNet:
  301. if weights is not None:
  302. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  303. model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
  304. if weights is not None:
  305. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  306. return model
  307. def _efficientnet_conf(
  308. arch: str,
  309. **kwargs: Any,
  310. ) -> tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
  311. inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]]
  312. if arch.startswith("efficientnet_b"):
  313. bneck_conf = partial(MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult"))
  314. inverted_residual_setting = [
  315. bneck_conf(1, 3, 1, 32, 16, 1),
  316. bneck_conf(6, 3, 2, 16, 24, 2),
  317. bneck_conf(6, 5, 2, 24, 40, 2),
  318. bneck_conf(6, 3, 2, 40, 80, 3),
  319. bneck_conf(6, 5, 1, 80, 112, 3),
  320. bneck_conf(6, 5, 2, 112, 192, 4),
  321. bneck_conf(6, 3, 1, 192, 320, 1),
  322. ]
  323. last_channel = None
  324. elif arch.startswith("efficientnet_v2_s"):
  325. inverted_residual_setting = [
  326. FusedMBConvConfig(1, 3, 1, 24, 24, 2),
  327. FusedMBConvConfig(4, 3, 2, 24, 48, 4),
  328. FusedMBConvConfig(4, 3, 2, 48, 64, 4),
  329. MBConvConfig(4, 3, 2, 64, 128, 6),
  330. MBConvConfig(6, 3, 1, 128, 160, 9),
  331. MBConvConfig(6, 3, 2, 160, 256, 15),
  332. ]
  333. last_channel = 1280
  334. elif arch.startswith("efficientnet_v2_m"):
  335. inverted_residual_setting = [
  336. FusedMBConvConfig(1, 3, 1, 24, 24, 3),
  337. FusedMBConvConfig(4, 3, 2, 24, 48, 5),
  338. FusedMBConvConfig(4, 3, 2, 48, 80, 5),
  339. MBConvConfig(4, 3, 2, 80, 160, 7),
  340. MBConvConfig(6, 3, 1, 160, 176, 14),
  341. MBConvConfig(6, 3, 2, 176, 304, 18),
  342. MBConvConfig(6, 3, 1, 304, 512, 5),
  343. ]
  344. last_channel = 1280
  345. elif arch.startswith("efficientnet_v2_l"):
  346. inverted_residual_setting = [
  347. FusedMBConvConfig(1, 3, 1, 32, 32, 4),
  348. FusedMBConvConfig(4, 3, 2, 32, 64, 7),
  349. FusedMBConvConfig(4, 3, 2, 64, 96, 7),
  350. MBConvConfig(4, 3, 2, 96, 192, 10),
  351. MBConvConfig(6, 3, 1, 192, 224, 19),
  352. MBConvConfig(6, 3, 2, 224, 384, 25),
  353. MBConvConfig(6, 3, 1, 384, 640, 7),
  354. ]
  355. last_channel = 1280
  356. else:
  357. raise ValueError(f"Unsupported model type {arch}")
  358. return inverted_residual_setting, last_channel
  359. _COMMON_META: dict[str, Any] = {
  360. "categories": _IMAGENET_CATEGORIES,
  361. }
  362. _COMMON_META_V1 = {
  363. **_COMMON_META,
  364. "min_size": (1, 1),
  365. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v1",
  366. }
  367. _COMMON_META_V2 = {
  368. **_COMMON_META,
  369. "min_size": (33, 33),
  370. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v2",
  371. }
  372. class EfficientNet_B0_Weights(WeightsEnum):
  373. IMAGENET1K_V1 = Weights(
  374. # Weights ported from https://github.com/rwightman/pytorch-image-models/
  375. url="https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth",
  376. transforms=partial(
  377. ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
  378. ),
  379. meta={
  380. **_COMMON_META_V1,
  381. "num_params": 5288548,
  382. "_metrics": {
  383. "ImageNet-1K": {
  384. "acc@1": 77.692,
  385. "acc@5": 93.532,
  386. }
  387. },
  388. "_ops": 0.386,
  389. "_file_size": 20.451,
  390. "_docs": """These weights are ported from the original paper.""",
  391. },
  392. )
  393. DEFAULT = IMAGENET1K_V1
  394. class EfficientNet_B1_Weights(WeightsEnum):
  395. IMAGENET1K_V1 = Weights(
  396. # Weights ported from https://github.com/rwightman/pytorch-image-models/
  397. url="https://download.pytorch.org/models/efficientnet_b1_rwightman-bac287d4.pth",
  398. transforms=partial(
  399. ImageClassification, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC
  400. ),
  401. meta={
  402. **_COMMON_META_V1,
  403. "num_params": 7794184,
  404. "_metrics": {
  405. "ImageNet-1K": {
  406. "acc@1": 78.642,
  407. "acc@5": 94.186,
  408. }
  409. },
  410. "_ops": 0.687,
  411. "_file_size": 30.134,
  412. "_docs": """These weights are ported from the original paper.""",
  413. },
  414. )
  415. IMAGENET1K_V2 = Weights(
  416. url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
  417. transforms=partial(
  418. ImageClassification, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR
  419. ),
  420. meta={
  421. **_COMMON_META_V1,
  422. "num_params": 7794184,
  423. "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning",
  424. "_metrics": {
  425. "ImageNet-1K": {
  426. "acc@1": 79.838,
  427. "acc@5": 94.934,
  428. }
  429. },
  430. "_ops": 0.687,
  431. "_file_size": 30.136,
  432. "_docs": """
  433. These weights improve upon the results of the original paper by using a modified version of TorchVision's
  434. `new training recipe
  435. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
  436. """,
  437. },
  438. )
  439. DEFAULT = IMAGENET1K_V2
  440. class EfficientNet_B2_Weights(WeightsEnum):
  441. IMAGENET1K_V1 = Weights(
  442. # Weights ported from https://github.com/rwightman/pytorch-image-models/
  443. url="https://download.pytorch.org/models/efficientnet_b2_rwightman-c35c1473.pth",
  444. transforms=partial(
  445. ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC
  446. ),
  447. meta={
  448. **_COMMON_META_V1,
  449. "num_params": 9109994,
  450. "_metrics": {
  451. "ImageNet-1K": {
  452. "acc@1": 80.608,
  453. "acc@5": 95.310,
  454. }
  455. },
  456. "_ops": 1.088,
  457. "_file_size": 35.174,
  458. "_docs": """These weights are ported from the original paper.""",
  459. },
  460. )
  461. DEFAULT = IMAGENET1K_V1
  462. class EfficientNet_B3_Weights(WeightsEnum):
  463. IMAGENET1K_V1 = Weights(
  464. # Weights ported from https://github.com/rwightman/pytorch-image-models/
  465. url="https://download.pytorch.org/models/efficientnet_b3_rwightman-b3899882.pth",
  466. transforms=partial(
  467. ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC
  468. ),
  469. meta={
  470. **_COMMON_META_V1,
  471. "num_params": 12233232,
  472. "_metrics": {
  473. "ImageNet-1K": {
  474. "acc@1": 82.008,
  475. "acc@5": 96.054,
  476. }
  477. },
  478. "_ops": 1.827,
  479. "_file_size": 47.184,
  480. "_docs": """These weights are ported from the original paper.""",
  481. },
  482. )
  483. DEFAULT = IMAGENET1K_V1
  484. class EfficientNet_B4_Weights(WeightsEnum):
  485. IMAGENET1K_V1 = Weights(
  486. # Weights ported from https://github.com/rwightman/pytorch-image-models/
  487. url="https://download.pytorch.org/models/efficientnet_b4_rwightman-23ab8bcd.pth",
  488. transforms=partial(
  489. ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC
  490. ),
  491. meta={
  492. **_COMMON_META_V1,
  493. "num_params": 19341616,
  494. "_metrics": {
  495. "ImageNet-1K": {
  496. "acc@1": 83.384,
  497. "acc@5": 96.594,
  498. }
  499. },
  500. "_ops": 4.394,
  501. "_file_size": 74.489,
  502. "_docs": """These weights are ported from the original paper.""",
  503. },
  504. )
  505. DEFAULT = IMAGENET1K_V1
  506. class EfficientNet_B5_Weights(WeightsEnum):
  507. IMAGENET1K_V1 = Weights(
  508. # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
  509. url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-1a07897c.pth",
  510. transforms=partial(
  511. ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC
  512. ),
  513. meta={
  514. **_COMMON_META_V1,
  515. "num_params": 30389784,
  516. "_metrics": {
  517. "ImageNet-1K": {
  518. "acc@1": 83.444,
  519. "acc@5": 96.628,
  520. }
  521. },
  522. "_ops": 10.266,
  523. "_file_size": 116.864,
  524. "_docs": """These weights are ported from the original paper.""",
  525. },
  526. )
  527. DEFAULT = IMAGENET1K_V1
  528. class EfficientNet_B6_Weights(WeightsEnum):
  529. IMAGENET1K_V1 = Weights(
  530. # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
  531. url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-24a108a5.pth",
  532. transforms=partial(
  533. ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC
  534. ),
  535. meta={
  536. **_COMMON_META_V1,
  537. "num_params": 43040704,
  538. "_metrics": {
  539. "ImageNet-1K": {
  540. "acc@1": 84.008,
  541. "acc@5": 96.916,
  542. }
  543. },
  544. "_ops": 19.068,
  545. "_file_size": 165.362,
  546. "_docs": """These weights are ported from the original paper.""",
  547. },
  548. )
  549. DEFAULT = IMAGENET1K_V1
  550. class EfficientNet_B7_Weights(WeightsEnum):
  551. IMAGENET1K_V1 = Weights(
  552. # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
  553. url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-c5b4e57e.pth",
  554. transforms=partial(
  555. ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC
  556. ),
  557. meta={
  558. **_COMMON_META_V1,
  559. "num_params": 66347960,
  560. "_metrics": {
  561. "ImageNet-1K": {
  562. "acc@1": 84.122,
  563. "acc@5": 96.908,
  564. }
  565. },
  566. "_ops": 37.746,
  567. "_file_size": 254.675,
  568. "_docs": """These weights are ported from the original paper.""",
  569. },
  570. )
  571. DEFAULT = IMAGENET1K_V1
  572. class EfficientNet_V2_S_Weights(WeightsEnum):
  573. IMAGENET1K_V1 = Weights(
  574. url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
  575. transforms=partial(
  576. ImageClassification,
  577. crop_size=384,
  578. resize_size=384,
  579. interpolation=InterpolationMode.BILINEAR,
  580. ),
  581. meta={
  582. **_COMMON_META_V2,
  583. "num_params": 21458488,
  584. "_metrics": {
  585. "ImageNet-1K": {
  586. "acc@1": 84.228,
  587. "acc@5": 96.878,
  588. }
  589. },
  590. "_ops": 8.366,
  591. "_file_size": 82.704,
  592. "_docs": """
  593. These weights improve upon the results of the original paper by using a modified version of TorchVision's
  594. `new training recipe
  595. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
  596. """,
  597. },
  598. )
  599. DEFAULT = IMAGENET1K_V1
  600. class EfficientNet_V2_M_Weights(WeightsEnum):
  601. IMAGENET1K_V1 = Weights(
  602. url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
  603. transforms=partial(
  604. ImageClassification,
  605. crop_size=480,
  606. resize_size=480,
  607. interpolation=InterpolationMode.BILINEAR,
  608. ),
  609. meta={
  610. **_COMMON_META_V2,
  611. "num_params": 54139356,
  612. "_metrics": {
  613. "ImageNet-1K": {
  614. "acc@1": 85.112,
  615. "acc@5": 97.156,
  616. }
  617. },
  618. "_ops": 24.582,
  619. "_file_size": 208.01,
  620. "_docs": """
  621. These weights improve upon the results of the original paper by using a modified version of TorchVision's
  622. `new training recipe
  623. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
  624. """,
  625. },
  626. )
  627. DEFAULT = IMAGENET1K_V1
  628. class EfficientNet_V2_L_Weights(WeightsEnum):
  629. # Weights ported from https://github.com/google/automl/tree/master/efficientnetv2
  630. IMAGENET1K_V1 = Weights(
  631. url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
  632. transforms=partial(
  633. ImageClassification,
  634. crop_size=480,
  635. resize_size=480,
  636. interpolation=InterpolationMode.BICUBIC,
  637. mean=(0.5, 0.5, 0.5),
  638. std=(0.5, 0.5, 0.5),
  639. ),
  640. meta={
  641. **_COMMON_META_V2,
  642. "num_params": 118515272,
  643. "_metrics": {
  644. "ImageNet-1K": {
  645. "acc@1": 85.808,
  646. "acc@5": 97.788,
  647. }
  648. },
  649. "_ops": 56.08,
  650. "_file_size": 454.573,
  651. "_docs": """These weights are ported from the original paper.""",
  652. },
  653. )
  654. DEFAULT = IMAGENET1K_V1
  655. @register_model()
  656. @handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
  657. def efficientnet_b0(
  658. *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
  659. ) -> EfficientNet:
  660. """EfficientNet B0 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
  661. Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
  662. Args:
  663. weights (:class:`~torchvision.models.EfficientNet_B0_Weights`, optional): The
  664. pretrained weights to use. See
  665. :class:`~torchvision.models.EfficientNet_B0_Weights` below for
  666. more details, and possible values. By default, no pre-trained
  667. weights are used.
  668. progress (bool, optional): If True, displays a progress bar of the
  669. download to stderr. Default is True.
  670. **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
  671. base class. Please refer to the `source code
  672. <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
  673. for more details about this class.
  674. .. autoclass:: torchvision.models.EfficientNet_B0_Weights
  675. :members:
  676. """
  677. weights = EfficientNet_B0_Weights.verify(weights)
  678. inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)
  679. return _efficientnet(
  680. inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
  681. )
  682. @register_model()
  683. @handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
  684. def efficientnet_b1(
  685. *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
  686. ) -> EfficientNet:
  687. """EfficientNet B1 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
  688. Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
  689. Args:
  690. weights (:class:`~torchvision.models.EfficientNet_B1_Weights`, optional): The
  691. pretrained weights to use. See
  692. :class:`~torchvision.models.EfficientNet_B1_Weights` below for
  693. more details, and possible values. By default, no pre-trained
  694. weights are used.
  695. progress (bool, optional): If True, displays a progress bar of the
  696. download to stderr. Default is True.
  697. **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
  698. base class. Please refer to the `source code
  699. <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
  700. for more details about this class.
  701. .. autoclass:: torchvision.models.EfficientNet_B1_Weights
  702. :members:
  703. """
  704. weights = EfficientNet_B1_Weights.verify(weights)
  705. inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1)
  706. return _efficientnet(
  707. inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
  708. )
  709. @register_model()
  710. @handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
  711. def efficientnet_b2(
  712. *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
  713. ) -> EfficientNet:
  714. """EfficientNet B2 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
  715. Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
  716. Args:
  717. weights (:class:`~torchvision.models.EfficientNet_B2_Weights`, optional): The
  718. pretrained weights to use. See
  719. :class:`~torchvision.models.EfficientNet_B2_Weights` below for
  720. more details, and possible values. By default, no pre-trained
  721. weights are used.
  722. progress (bool, optional): If True, displays a progress bar of the
  723. download to stderr. Default is True.
  724. **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
  725. base class. Please refer to the `source code
  726. <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
  727. for more details about this class.
  728. .. autoclass:: torchvision.models.EfficientNet_B2_Weights
  729. :members:
  730. """
  731. weights = EfficientNet_B2_Weights.verify(weights)
  732. inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2)
  733. return _efficientnet(
  734. inverted_residual_setting, kwargs.pop("dropout", 0.3), last_channel, weights, progress, **kwargs
  735. )
  @register_model()
  @handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
  def efficientnet_b3(
      *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> EfficientNet:
      """Build an EfficientNet-B3 network as described in `EfficientNet: Rethinking Model Scaling for
      Convolutional Neural Networks <https://arxiv.org/abs/1905.11946>`_.

      The B3 variant requests its block configuration with ``width_mult=1.2`` and
      ``depth_mult=1.4`` and uses a dropout rate of 0.3 unless a ``dropout`` keyword
      argument is supplied.

      Args:
          weights (:class:`~torchvision.models.EfficientNet_B3_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.EfficientNet_B3_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.EfficientNet_B3_Weights
          :members:
      """
      verified = EfficientNet_B3_Weights.verify(weights)
      stage_settings, head_channels = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4)
      # Remove "dropout" from kwargs so it is passed exactly once (positionally).
      drop_prob = kwargs.pop("dropout", 0.3)
      return _efficientnet(stage_settings, drop_prob, head_channels, verified, progress, **kwargs)
  @register_model()
  @handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
  def efficientnet_b4(
      *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> EfficientNet:
      """Build an EfficientNet-B4 network as described in `EfficientNet: Rethinking Model Scaling for
      Convolutional Neural Networks <https://arxiv.org/abs/1905.11946>`_.

      The B4 variant requests its block configuration with ``width_mult=1.4`` and
      ``depth_mult=1.8`` and uses a dropout rate of 0.4 unless a ``dropout`` keyword
      argument is supplied.

      Args:
          weights (:class:`~torchvision.models.EfficientNet_B4_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.EfficientNet_B4_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.EfficientNet_B4_Weights
          :members:
      """
      resolved_weights = EfficientNet_B4_Weights.verify(weights)
      block_spec, output_channels = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8)
      # "dropout" travels positionally, so it must not also remain inside kwargs.
      rate = kwargs.pop("dropout", 0.4)
      return _efficientnet(block_spec, rate, output_channels, resolved_weights, progress, **kwargs)
  @register_model()
  @handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
  def efficientnet_b5(
      *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> EfficientNet:
      """EfficientNet B5 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
      Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

      The B5 variant requests its block configuration with ``width_mult=1.6`` and
      ``depth_mult=2.2``. Unless overridden through ``kwargs``, ``dropout`` defaults to 0.4
      and ``norm_layer`` defaults to ``partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)``.

      Args:
          weights (:class:`~torchvision.models.EfficientNet_B5_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.EfficientNet_B5_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.EfficientNet_B5_Weights
          :members:
      """
      weights = EfficientNet_B5_Weights.verify(weights)
      inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2)
      # Pop instead of passing norm_layer next to **kwargs: a caller-supplied norm_layer
      # would otherwise raise "got multiple values for keyword argument 'norm_layer'".
      norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=0.001, momentum=0.01))
      return _efficientnet(
          inverted_residual_setting,
          kwargs.pop("dropout", 0.4),
          last_channel,
          weights,
          progress,
          norm_layer=norm_layer,
          **kwargs,
      )
  @register_model()
  @handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1))
  def efficientnet_b6(
      *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> EfficientNet:
      """EfficientNet B6 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
      Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

      The B6 variant requests its block configuration with ``width_mult=1.8`` and
      ``depth_mult=2.6``. Unless overridden through ``kwargs``, ``dropout`` defaults to 0.5
      and ``norm_layer`` defaults to ``partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)``.

      Args:
          weights (:class:`~torchvision.models.EfficientNet_B6_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.EfficientNet_B6_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.EfficientNet_B6_Weights
          :members:
      """
      weights = EfficientNet_B6_Weights.verify(weights)
      inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6)
      # Pop instead of passing norm_layer next to **kwargs: a caller-supplied norm_layer
      # would otherwise raise "got multiple values for keyword argument 'norm_layer'".
      norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=0.001, momentum=0.01))
      return _efficientnet(
          inverted_residual_setting,
          kwargs.pop("dropout", 0.5),
          last_channel,
          weights,
          progress,
          norm_layer=norm_layer,
          **kwargs,
      )
  @register_model()
  @handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1))
  def efficientnet_b7(
      *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> EfficientNet:
      """EfficientNet B7 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
      Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

      The B7 variant requests its block configuration with ``width_mult=2.0`` and
      ``depth_mult=3.1``. Unless overridden through ``kwargs``, ``dropout`` defaults to 0.5
      and ``norm_layer`` defaults to ``partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)``.

      Args:
          weights (:class:`~torchvision.models.EfficientNet_B7_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.EfficientNet_B7_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.EfficientNet_B7_Weights
          :members:
      """
      weights = EfficientNet_B7_Weights.verify(weights)
      inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1)
      # Pop instead of passing norm_layer next to **kwargs: a caller-supplied norm_layer
      # would otherwise raise "got multiple values for keyword argument 'norm_layer'".
      norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=0.001, momentum=0.01))
      return _efficientnet(
          inverted_residual_setting,
          kwargs.pop("dropout", 0.5),
          last_channel,
          weights,
          progress,
          norm_layer=norm_layer,
          **kwargs,
      )
  @register_model()
  @handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
  def efficientnet_v2_s(
      *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> EfficientNet:
      """
      Constructs an EfficientNetV2-S architecture from
      `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

      Unless overridden through ``kwargs``, ``dropout`` defaults to 0.2 and ``norm_layer``
      defaults to ``partial(nn.BatchNorm2d, eps=1e-03)``.

      Args:
          weights (:class:`~torchvision.models.EfficientNet_V2_S_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.EfficientNet_V2_S_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.EfficientNet_V2_S_Weights
          :members:
      """
      weights = EfficientNet_V2_S_Weights.verify(weights)
      inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s")
      # Pop instead of passing norm_layer next to **kwargs: a caller-supplied norm_layer
      # would otherwise raise "got multiple values for keyword argument 'norm_layer'".
      norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-03))
      return _efficientnet(
          inverted_residual_setting,
          kwargs.pop("dropout", 0.2),
          last_channel,
          weights,
          progress,
          norm_layer=norm_layer,
          **kwargs,
      )
  @register_model()
  @handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
  def efficientnet_v2_m(
      *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> EfficientNet:
      """
      Constructs an EfficientNetV2-M architecture from
      `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

      Unless overridden through ``kwargs``, ``dropout`` defaults to 0.3 and ``norm_layer``
      defaults to ``partial(nn.BatchNorm2d, eps=1e-03)``.

      Args:
          weights (:class:`~torchvision.models.EfficientNet_V2_M_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.EfficientNet_V2_M_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.EfficientNet_V2_M_Weights
          :members:
      """
      weights = EfficientNet_V2_M_Weights.verify(weights)
      inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m")
      # Pop instead of passing norm_layer next to **kwargs: a caller-supplied norm_layer
      # would otherwise raise "got multiple values for keyword argument 'norm_layer'".
      norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-03))
      return _efficientnet(
          inverted_residual_setting,
          kwargs.pop("dropout", 0.3),
          last_channel,
          weights,
          progress,
          norm_layer=norm_layer,
          **kwargs,
      )
  @register_model()
  @handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
  def efficientnet_v2_l(
      *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> EfficientNet:
      """
      Constructs an EfficientNetV2-L architecture from
      `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

      Unless overridden through ``kwargs``, ``dropout`` defaults to 0.4 and ``norm_layer``
      defaults to ``partial(nn.BatchNorm2d, eps=1e-03)``.

      Args:
          weights (:class:`~torchvision.models.EfficientNet_V2_L_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.EfficientNet_V2_L_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.EfficientNet_V2_L_Weights
          :members:
      """
      weights = EfficientNet_V2_L_Weights.verify(weights)
      inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l")
      # Pop instead of passing norm_layer next to **kwargs: a caller-supplied norm_layer
      # would otherwise raise "got multiple values for keyword argument 'norm_layer'".
      norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-03))
      return _efficientnet(
          inverted_residual_setting,
          kwargs.pop("dropout", 0.4),
          last_channel,
          weights,
          progress,
          norm_layer=norm_layer,
          **kwargs,
      )