modeling_efficientnet.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. # Copyright 2023 Google Research, Inc. and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch EfficientNet model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from ... import initialization as init
  19. from ...activations import ACT2FN
  20. from ...modeling_outputs import (
  21. BaseModelOutputWithNoAttention,
  22. BaseModelOutputWithPoolingAndNoAttention,
  23. ImageClassifierOutputWithNoAttention,
  24. )
  25. from ...modeling_utils import PreTrainedModel
  26. from ...utils import auto_docstring, logging
  27. from .configuration_efficientnet import EfficientNetConfig
# Module-level logger, obtained from the library's logging utilities and keyed by this module's name.
logger = logging.get_logger(__name__)
  29. def round_filters(config: EfficientNetConfig, num_channels: int):
  30. r"""
  31. Round number of filters based on depth multiplier.
  32. """
  33. divisor = config.depth_divisor
  34. num_channels *= config.width_coefficient
  35. new_dim = max(divisor, int(num_channels + divisor / 2) // divisor * divisor)
  36. # Make sure that round down does not go down by more than 10%.
  37. if new_dim < 0.9 * num_channels:
  38. new_dim += divisor
  39. return int(new_dim)
  40. def correct_pad(kernel_size: int | tuple, adjust: bool = True):
  41. r"""
  42. Utility function to get the tuple padding value for the depthwise convolution.
  43. Args:
  44. kernel_size (`int` or `tuple`):
  45. Kernel size of the convolution layers.
  46. adjust (`bool`, *optional*, defaults to `True`):
  47. Adjusts padding value to apply to right and bottom sides of the input.
  48. """
  49. if isinstance(kernel_size, int):
  50. kernel_size = (kernel_size, kernel_size)
  51. correct = (kernel_size[0] // 2, kernel_size[1] // 2)
  52. if adjust:
  53. return (correct[1] - 1, correct[1], correct[0] - 1, correct[0])
  54. else:
  55. return (correct[1], correct[1], correct[0], correct[0])
  56. class EfficientNetEmbeddings(nn.Module):
  57. r"""
  58. A module that corresponds to the stem module of the original work.
  59. """
  60. def __init__(self, config: EfficientNetConfig):
  61. super().__init__()
  62. self.out_dim = round_filters(config, 32)
  63. self.padding = nn.ZeroPad2d(padding=(0, 1, 0, 1))
  64. self.convolution = nn.Conv2d(
  65. config.num_channels, self.out_dim, kernel_size=3, stride=2, padding="valid", bias=False
  66. )
  67. self.batchnorm = nn.BatchNorm2d(self.out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum)
  68. self.activation = ACT2FN[config.hidden_act]
  69. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  70. features = self.padding(pixel_values)
  71. features = self.convolution(features)
  72. features = self.batchnorm(features)
  73. features = self.activation(features)
  74. return features
  75. class EfficientNetDepthwiseConv2d(nn.Conv2d):
  76. def __init__(
  77. self,
  78. in_channels,
  79. depth_multiplier=1,
  80. kernel_size=3,
  81. stride=1,
  82. padding=0,
  83. dilation=1,
  84. bias=True,
  85. padding_mode="zeros",
  86. ):
  87. out_channels = in_channels * depth_multiplier
  88. super().__init__(
  89. in_channels=in_channels,
  90. out_channels=out_channels,
  91. kernel_size=kernel_size,
  92. stride=stride,
  93. padding=padding,
  94. dilation=dilation,
  95. groups=in_channels,
  96. bias=bias,
  97. padding_mode=padding_mode,
  98. )
  99. class EfficientNetExpansionLayer(nn.Module):
  100. r"""
  101. This corresponds to the expansion phase of each block in the original implementation.
  102. """
  103. def __init__(self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int):
  104. super().__init__()
  105. self.expand_conv = nn.Conv2d(
  106. in_channels=in_dim,
  107. out_channels=out_dim,
  108. kernel_size=1,
  109. padding="same",
  110. bias=False,
  111. )
  112. self.expand_bn = nn.BatchNorm2d(num_features=out_dim, eps=config.batch_norm_eps)
  113. self.expand_act = ACT2FN[config.hidden_act]
  114. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  115. # Expand phase
  116. hidden_states = self.expand_conv(hidden_states)
  117. hidden_states = self.expand_bn(hidden_states)
  118. hidden_states = self.expand_act(hidden_states)
  119. return hidden_states
  120. class EfficientNetDepthwiseLayer(nn.Module):
  121. r"""
  122. This corresponds to the depthwise convolution phase of each block in the original implementation.
  123. """
  124. def __init__(
  125. self,
  126. config: EfficientNetConfig,
  127. in_dim: int,
  128. stride: int,
  129. kernel_size: int,
  130. adjust_padding: bool,
  131. ):
  132. super().__init__()
  133. self.stride = stride
  134. conv_pad = "valid" if self.stride == 2 else "same"
  135. padding = correct_pad(kernel_size, adjust=adjust_padding)
  136. self.depthwise_conv_pad = nn.ZeroPad2d(padding=padding)
  137. self.depthwise_conv = EfficientNetDepthwiseConv2d(
  138. in_dim, kernel_size=kernel_size, stride=stride, padding=conv_pad, bias=False
  139. )
  140. self.depthwise_norm = nn.BatchNorm2d(
  141. num_features=in_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
  142. )
  143. self.depthwise_act = ACT2FN[config.hidden_act]
  144. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  145. # Depthwise convolution
  146. if self.stride == 2:
  147. hidden_states = self.depthwise_conv_pad(hidden_states)
  148. hidden_states = self.depthwise_conv(hidden_states)
  149. hidden_states = self.depthwise_norm(hidden_states)
  150. hidden_states = self.depthwise_act(hidden_states)
  151. return hidden_states
  152. class EfficientNetSqueezeExciteLayer(nn.Module):
  153. r"""
  154. This corresponds to the Squeeze and Excitement phase of each block in the original implementation.
  155. """
  156. def __init__(self, config: EfficientNetConfig, in_dim: int, expand_dim: int, expand: bool = False):
  157. super().__init__()
  158. self.dim = expand_dim if expand else in_dim
  159. self.dim_se = max(1, int(in_dim * config.squeeze_expansion_ratio))
  160. self.squeeze = nn.AdaptiveAvgPool2d(output_size=1)
  161. self.reduce = nn.Conv2d(
  162. in_channels=self.dim,
  163. out_channels=self.dim_se,
  164. kernel_size=1,
  165. padding="same",
  166. )
  167. self.expand = nn.Conv2d(
  168. in_channels=self.dim_se,
  169. out_channels=self.dim,
  170. kernel_size=1,
  171. padding="same",
  172. )
  173. self.act_reduce = ACT2FN[config.hidden_act]
  174. self.act_expand = nn.Sigmoid()
  175. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  176. inputs = hidden_states
  177. hidden_states = self.squeeze(hidden_states)
  178. hidden_states = self.reduce(hidden_states)
  179. hidden_states = self.act_reduce(hidden_states)
  180. hidden_states = self.expand(hidden_states)
  181. hidden_states = self.act_expand(hidden_states)
  182. hidden_states = torch.mul(inputs, hidden_states)
  183. return hidden_states
  184. class EfficientNetFinalBlockLayer(nn.Module):
  185. r"""
  186. This corresponds to the final phase of each block in the original implementation.
  187. """
  188. def __init__(
  189. self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int, drop_rate: float, id_skip: bool
  190. ):
  191. super().__init__()
  192. self.apply_dropout = stride == 1 and not id_skip
  193. self.project_conv = nn.Conv2d(
  194. in_channels=in_dim,
  195. out_channels=out_dim,
  196. kernel_size=1,
  197. padding="same",
  198. bias=False,
  199. )
  200. self.project_bn = nn.BatchNorm2d(
  201. num_features=out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
  202. )
  203. self.dropout = nn.Dropout(p=drop_rate)
  204. def forward(self, embeddings: torch.FloatTensor, hidden_states: torch.FloatTensor) -> torch.Tensor:
  205. hidden_states = self.project_conv(hidden_states)
  206. hidden_states = self.project_bn(hidden_states)
  207. if self.apply_dropout:
  208. hidden_states = self.dropout(hidden_states)
  209. hidden_states = hidden_states + embeddings
  210. return hidden_states
  211. class EfficientNetBlock(nn.Module):
  212. r"""
  213. This corresponds to the expansion and depthwise convolution phase of each block in the original implementation.
  214. Args:
  215. config ([`EfficientNetConfig`]):
  216. Model configuration class.
  217. in_dim (`int`):
  218. Number of input channels.
  219. out_dim (`int`):
  220. Number of output channels.
  221. stride (`int`):
  222. Stride size to be used in convolution layers.
  223. expand_ratio (`int`):
  224. Expand ratio to set the output dimensions for the expansion and squeeze-excite layers.
  225. kernel_size (`int`):
  226. Kernel size for the depthwise convolution layer.
  227. drop_rate (`float`):
  228. Dropout rate to be used in the final phase of each block.
  229. id_skip (`bool`):
  230. Whether to apply dropout and sum the final hidden states with the input embeddings during the final phase
  231. of each block. Set to `True` for the first block of each stage.
  232. adjust_padding (`bool`):
  233. Whether to apply padding to only right and bottom side of the input kernel before the depthwise convolution
  234. operation, set to `True` for inputs with odd input sizes.
  235. """
  236. def __init__(
  237. self,
  238. config: EfficientNetConfig,
  239. in_dim: int,
  240. out_dim: int,
  241. stride: int,
  242. expand_ratio: int,
  243. kernel_size: int,
  244. drop_rate: float,
  245. id_skip: bool,
  246. adjust_padding: bool,
  247. ):
  248. super().__init__()
  249. self.expand_ratio = expand_ratio
  250. self.expand = self.expand_ratio != 1
  251. expand_in_dim = in_dim * expand_ratio
  252. if self.expand:
  253. self.expansion = EfficientNetExpansionLayer(
  254. config=config, in_dim=in_dim, out_dim=expand_in_dim, stride=stride
  255. )
  256. self.depthwise_conv = EfficientNetDepthwiseLayer(
  257. config=config,
  258. in_dim=expand_in_dim if self.expand else in_dim,
  259. stride=stride,
  260. kernel_size=kernel_size,
  261. adjust_padding=adjust_padding,
  262. )
  263. self.squeeze_excite = EfficientNetSqueezeExciteLayer(
  264. config=config, in_dim=in_dim, expand_dim=expand_in_dim, expand=self.expand
  265. )
  266. self.projection = EfficientNetFinalBlockLayer(
  267. config=config,
  268. in_dim=expand_in_dim if self.expand else in_dim,
  269. out_dim=out_dim,
  270. stride=stride,
  271. drop_rate=drop_rate,
  272. id_skip=id_skip,
  273. )
  274. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  275. embeddings = hidden_states
  276. # Expansion and depthwise convolution phase
  277. if self.expand_ratio != 1:
  278. hidden_states = self.expansion(hidden_states)
  279. hidden_states = self.depthwise_conv(hidden_states)
  280. # Squeeze and excite phase
  281. hidden_states = self.squeeze_excite(hidden_states)
  282. hidden_states = self.projection(embeddings, hidden_states)
  283. return hidden_states
  284. class EfficientNetEncoder(nn.Module):
  285. r"""
  286. Forward propagates the embeddings through each EfficientNet block.
  287. Args:
  288. config ([`EfficientNetConfig`]):
  289. Model configuration class.
  290. """
  291. def __init__(self, config: EfficientNetConfig):
  292. super().__init__()
  293. self.config = config
  294. self.depth_coefficient = config.depth_coefficient
  295. def round_repeats(repeats):
  296. # Round number of block repeats based on depth multiplier.
  297. return int(math.ceil(self.depth_coefficient * repeats))
  298. num_base_blocks = len(config.in_channels)
  299. num_blocks = sum(round_repeats(n) for n in config.num_block_repeats)
  300. curr_block_num = 0
  301. blocks = []
  302. for i in range(num_base_blocks):
  303. in_dim = round_filters(config, config.in_channels[i])
  304. out_dim = round_filters(config, config.out_channels[i])
  305. stride = config.strides[i]
  306. kernel_size = config.kernel_sizes[i]
  307. expand_ratio = config.expand_ratios[i]
  308. for j in range(round_repeats(config.num_block_repeats[i])):
  309. id_skip = j == 0
  310. stride = 1 if j > 0 else stride
  311. in_dim = out_dim if j > 0 else in_dim
  312. adjust_padding = curr_block_num not in config.depthwise_padding
  313. drop_rate = config.drop_connect_rate * curr_block_num / num_blocks
  314. block = EfficientNetBlock(
  315. config=config,
  316. in_dim=in_dim,
  317. out_dim=out_dim,
  318. stride=stride,
  319. kernel_size=kernel_size,
  320. expand_ratio=expand_ratio,
  321. drop_rate=drop_rate,
  322. id_skip=id_skip,
  323. adjust_padding=adjust_padding,
  324. )
  325. blocks.append(block)
  326. curr_block_num += 1
  327. self.blocks = nn.ModuleList(blocks)
  328. self.top_conv = nn.Conv2d(
  329. in_channels=out_dim,
  330. out_channels=round_filters(config, 1280),
  331. kernel_size=1,
  332. padding="same",
  333. bias=False,
  334. )
  335. self.top_bn = nn.BatchNorm2d(
  336. num_features=config.hidden_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
  337. )
  338. self.top_activation = ACT2FN[config.hidden_act]
  339. def forward(
  340. self,
  341. hidden_states: torch.FloatTensor,
  342. output_hidden_states: bool | None = False,
  343. return_dict: bool | None = True,
  344. ) -> BaseModelOutputWithNoAttention:
  345. all_hidden_states = (hidden_states,) if output_hidden_states else None
  346. for block in self.blocks:
  347. hidden_states = block(hidden_states)
  348. if output_hidden_states:
  349. all_hidden_states += (hidden_states,)
  350. hidden_states = self.top_conv(hidden_states)
  351. hidden_states = self.top_bn(hidden_states)
  352. hidden_states = self.top_activation(hidden_states)
  353. if not return_dict:
  354. return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
  355. return BaseModelOutputWithNoAttention(
  356. last_hidden_state=hidden_states,
  357. hidden_states=all_hidden_states,
  358. )
  359. @auto_docstring
  360. class EfficientNetPreTrainedModel(PreTrainedModel):
  361. config: EfficientNetConfig
  362. base_model_prefix = "efficientnet"
  363. main_input_name = "pixel_values"
  364. input_modalities = ("image",)
  365. _no_split_modules = ["EfficientNetBlock"]
  366. @torch.no_grad()
  367. def _init_weights(self, module: nn.Module):
  368. """Initialize the weights"""
  369. if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
  370. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  371. if module.bias is not None:
  372. init.zeros_(module.bias)
  373. if getattr(module, "running_mean", None) is not None:
  374. init.zeros_(module.running_mean)
  375. init.ones_(module.running_var)
  376. init.zeros_(module.num_batches_tracked)
  377. @auto_docstring
  378. class EfficientNetModel(EfficientNetPreTrainedModel):
  379. def __init__(self, config: EfficientNetConfig):
  380. super().__init__(config)
  381. self.config = config
  382. self.embeddings = EfficientNetEmbeddings(config)
  383. self.encoder = EfficientNetEncoder(config)
  384. # Final pooling layer
  385. if config.pooling_type == "mean":
  386. self.pooler = nn.AvgPool2d(config.hidden_dim, ceil_mode=True)
  387. elif config.pooling_type == "max":
  388. self.pooler = nn.MaxPool2d(config.hidden_dim, ceil_mode=True)
  389. else:
  390. raise ValueError(f"config.pooling must be one of ['mean', 'max'] got {config.pooling}")
  391. # Initialize weights and apply final processing
  392. self.post_init()
  393. @auto_docstring
  394. def forward(
  395. self,
  396. pixel_values: torch.FloatTensor | None = None,
  397. output_hidden_states: bool | None = None,
  398. return_dict: bool | None = None,
  399. **kwargs,
  400. ) -> tuple | BaseModelOutputWithPoolingAndNoAttention:
  401. output_hidden_states = (
  402. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  403. )
  404. return_dict = return_dict if return_dict is not None else self.config.return_dict
  405. if pixel_values is None:
  406. raise ValueError("You have to specify pixel_values")
  407. embedding_output = self.embeddings(pixel_values)
  408. encoder_outputs = self.encoder(
  409. embedding_output,
  410. output_hidden_states=output_hidden_states,
  411. return_dict=return_dict,
  412. )
  413. # Apply pooling
  414. last_hidden_state = encoder_outputs[0]
  415. pooled_output = self.pooler(last_hidden_state)
  416. # Reshape (batch_size, 1280, 1 , 1) -> (batch_size, 1280)
  417. pooled_output = pooled_output.reshape(pooled_output.shape[:2])
  418. if not return_dict:
  419. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  420. return BaseModelOutputWithPoolingAndNoAttention(
  421. last_hidden_state=last_hidden_state,
  422. pooler_output=pooled_output,
  423. hidden_states=encoder_outputs.hidden_states,
  424. )
  425. @auto_docstring(
  426. custom_intro="""
  427. EfficientNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g.
  428. for ImageNet.
  429. """
  430. )
  431. class EfficientNetForImageClassification(EfficientNetPreTrainedModel):
  432. def __init__(self, config):
  433. super().__init__(config)
  434. self.num_labels = config.num_labels
  435. self.config = config
  436. self.efficientnet = EfficientNetModel(config)
  437. # Classifier head
  438. self.dropout = nn.Dropout(p=config.dropout_rate)
  439. self.classifier = nn.Linear(config.hidden_dim, self.num_labels) if self.num_labels > 0 else nn.Identity()
  440. # Initialize weights and apply final processing
  441. self.post_init()
  442. @auto_docstring
  443. def forward(
  444. self,
  445. pixel_values: torch.FloatTensor | None = None,
  446. labels: torch.LongTensor | None = None,
  447. output_hidden_states: bool | None = None,
  448. return_dict: bool | None = None,
  449. **kwargs,
  450. ) -> tuple | ImageClassifierOutputWithNoAttention:
  451. r"""
  452. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  453. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  454. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  455. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  456. """
  457. return_dict = return_dict if return_dict is not None else self.config.return_dict
  458. outputs = self.efficientnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  459. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  460. pooled_output = self.dropout(pooled_output)
  461. logits = self.classifier(pooled_output)
  462. loss = None
  463. if labels is not None:
  464. loss = self.loss_function(labels, logits, self.config)
  465. if not return_dict:
  466. output = (logits,) + outputs[2:]
  467. return ((loss,) + output) if loss is not None else output
  468. return ImageClassifierOutputWithNoAttention(
  469. loss=loss,
  470. logits=logits,
  471. hidden_states=outputs.hidden_states,
  472. )
# Names exported by `from <this module> import *`; the internal layer and encoder classes are not listed.
__all__ = ["EfficientNetForImageClassification", "EfficientNetModel", "EfficientNetPreTrainedModel"]