modeling_mobilevitv2.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942
  1. # Copyright 2023 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE
  16. """PyTorch MobileViTV2 model."""
  17. import torch
  18. from torch import nn
  19. from torch.nn import CrossEntropyLoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BaseModelOutputWithNoAttention,
  25. BaseModelOutputWithPoolingAndNoAttention,
  26. ImageClassifierOutputWithNoAttention,
  27. SemanticSegmenterOutput,
  28. )
  29. from ...modeling_utils import PreTrainedModel
  30. from ...utils import auto_docstring, logging
  31. from .configuration_mobilevitv2 import MobileViTV2Config
  32. logger = logging.get_logger(__name__)
  33. # Copied from transformers.models.mobilevit.modeling_mobilevit.make_divisible
  34. def make_divisible(value: int, divisor: int = 8, min_value: int | None = None) -> int:
  35. """
  36. Ensure that all layers have a channel count that is divisible by `divisor`.
  37. """
  38. if min_value is None:
  39. min_value = divisor
  40. new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
  41. # Make sure that round down does not go down by more than 10%.
  42. if new_value < 0.9 * value:
  43. new_value += divisor
  44. return int(new_value)
  45. def clip(value: float, min_val: float = float("-inf"), max_val: float = float("inf")) -> float:
  46. return max(min_val, min(max_val, value))
  47. # Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTConvLayer with MobileViT->MobileViTV2
  48. class MobileViTV2ConvLayer(nn.Module):
  49. def __init__(
  50. self,
  51. config: MobileViTV2Config,
  52. in_channels: int,
  53. out_channels: int,
  54. kernel_size: int,
  55. stride: int = 1,
  56. groups: int = 1,
  57. bias: bool = False,
  58. dilation: int = 1,
  59. use_normalization: bool = True,
  60. use_activation: bool | str = True,
  61. ) -> None:
  62. super().__init__()
  63. padding = int((kernel_size - 1) / 2) * dilation
  64. if in_channels % groups != 0:
  65. raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
  66. if out_channels % groups != 0:
  67. raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")
  68. self.convolution = nn.Conv2d(
  69. in_channels=in_channels,
  70. out_channels=out_channels,
  71. kernel_size=kernel_size,
  72. stride=stride,
  73. padding=padding,
  74. dilation=dilation,
  75. groups=groups,
  76. bias=bias,
  77. padding_mode="zeros",
  78. )
  79. if use_normalization:
  80. self.normalization = nn.BatchNorm2d(
  81. num_features=out_channels,
  82. eps=1e-5,
  83. momentum=0.1,
  84. affine=True,
  85. track_running_stats=True,
  86. )
  87. else:
  88. self.normalization = None
  89. if use_activation:
  90. if isinstance(use_activation, str):
  91. self.activation = ACT2FN[use_activation]
  92. elif isinstance(config.hidden_act, str):
  93. self.activation = ACT2FN[config.hidden_act]
  94. else:
  95. self.activation = config.hidden_act
  96. else:
  97. self.activation = None
  98. def forward(self, features: torch.Tensor) -> torch.Tensor:
  99. features = self.convolution(features)
  100. if self.normalization is not None:
  101. features = self.normalization(features)
  102. if self.activation is not None:
  103. features = self.activation(features)
  104. return features
  105. # Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTInvertedResidual with MobileViT->MobileViTV2
  106. class MobileViTV2InvertedResidual(nn.Module):
  107. """
  108. Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381
  109. """
  110. def __init__(
  111. self, config: MobileViTV2Config, in_channels: int, out_channels: int, stride: int, dilation: int = 1
  112. ) -> None:
  113. super().__init__()
  114. expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)
  115. if stride not in [1, 2]:
  116. raise ValueError(f"Invalid stride {stride}.")
  117. self.use_residual = (stride == 1) and (in_channels == out_channels)
  118. self.expand_1x1 = MobileViTV2ConvLayer(
  119. config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1
  120. )
  121. self.conv_3x3 = MobileViTV2ConvLayer(
  122. config,
  123. in_channels=expanded_channels,
  124. out_channels=expanded_channels,
  125. kernel_size=3,
  126. stride=stride,
  127. groups=expanded_channels,
  128. dilation=dilation,
  129. )
  130. self.reduce_1x1 = MobileViTV2ConvLayer(
  131. config,
  132. in_channels=expanded_channels,
  133. out_channels=out_channels,
  134. kernel_size=1,
  135. use_activation=False,
  136. )
  137. def forward(self, features: torch.Tensor) -> torch.Tensor:
  138. residual = features
  139. features = self.expand_1x1(features)
  140. features = self.conv_3x3(features)
  141. features = self.reduce_1x1(features)
  142. return residual + features if self.use_residual else features
  143. # Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTMobileNetLayer with MobileViT->MobileViTV2
  144. class MobileViTV2MobileNetLayer(nn.Module):
  145. def __init__(
  146. self, config: MobileViTV2Config, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1
  147. ) -> None:
  148. super().__init__()
  149. self.layer = nn.ModuleList()
  150. for i in range(num_stages):
  151. layer = MobileViTV2InvertedResidual(
  152. config,
  153. in_channels=in_channels,
  154. out_channels=out_channels,
  155. stride=stride if i == 0 else 1,
  156. )
  157. self.layer.append(layer)
  158. in_channels = out_channels
  159. def forward(self, features: torch.Tensor) -> torch.Tensor:
  160. for layer_module in self.layer:
  161. features = layer_module(features)
  162. return features
  163. class MobileViTV2LinearSelfAttention(nn.Module):
  164. """
  165. This layer applies a self-attention with linear complexity, as described in MobileViTV2 paper:
  166. https://huggingface.co/papers/2206.02680
  167. Args:
  168. config (`MobileVitv2Config`):
  169. Model configuration object
  170. embed_dim (`int`):
  171. `input_channels` from an expected input of size :math:`(batch_size, input_channels, height, width)`
  172. """
  173. def __init__(self, config: MobileViTV2Config, embed_dim: int) -> None:
  174. super().__init__()
  175. self.qkv_proj = MobileViTV2ConvLayer(
  176. config=config,
  177. in_channels=embed_dim,
  178. out_channels=1 + (2 * embed_dim),
  179. bias=True,
  180. kernel_size=1,
  181. use_normalization=False,
  182. use_activation=False,
  183. )
  184. self.attn_dropout = nn.Dropout(p=config.attn_dropout)
  185. self.out_proj = MobileViTV2ConvLayer(
  186. config=config,
  187. in_channels=embed_dim,
  188. out_channels=embed_dim,
  189. bias=True,
  190. kernel_size=1,
  191. use_normalization=False,
  192. use_activation=False,
  193. )
  194. self.embed_dim = embed_dim
  195. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  196. # (batch_size, embed_dim, num_pixels_in_patch, num_patches) --> (batch_size, 1+2*embed_dim, num_pixels_in_patch, num_patches)
  197. qkv = self.qkv_proj(hidden_states)
  198. # Project hidden_states into query, key and value
  199. # Query --> [batch_size, 1, num_pixels_in_patch, num_patches]
  200. # value, key --> [batch_size, embed_dim, num_pixels_in_patch, num_patches]
  201. query, key, value = torch.split(qkv, split_size_or_sections=[1, self.embed_dim, self.embed_dim], dim=1)
  202. # apply softmax along num_patches dimension
  203. context_scores = torch.nn.functional.softmax(query, dim=-1)
  204. context_scores = self.attn_dropout(context_scores)
  205. # Compute context vector
  206. # [batch_size, embed_dim, num_pixels_in_patch, num_patches] x [batch_size, 1, num_pixels_in_patch, num_patches] -> [batch_size, embed_dim, num_pixels_in_patch, num_patches]
  207. context_vector = key * context_scores
  208. # [batch_size, embed_dim, num_pixels_in_patch, num_patches] --> [batch_size, embed_dim, num_pixels_in_patch, 1]
  209. context_vector = torch.sum(context_vector, dim=-1, keepdim=True)
  210. # combine context vector with values
  211. # [batch_size, embed_dim, num_pixels_in_patch, num_patches] * [batch_size, embed_dim, num_pixels_in_patch, 1] --> [batch_size, embed_dim, num_pixels_in_patch, num_patches]
  212. out = torch.nn.functional.relu(value) * context_vector.expand_as(value)
  213. out = self.out_proj(out)
  214. return out
  215. class MobileViTV2FFN(nn.Module):
  216. def __init__(
  217. self,
  218. config: MobileViTV2Config,
  219. embed_dim: int,
  220. ffn_latent_dim: int,
  221. ffn_dropout: float = 0.0,
  222. ) -> None:
  223. super().__init__()
  224. self.conv1 = MobileViTV2ConvLayer(
  225. config=config,
  226. in_channels=embed_dim,
  227. out_channels=ffn_latent_dim,
  228. kernel_size=1,
  229. stride=1,
  230. bias=True,
  231. use_normalization=False,
  232. use_activation=True,
  233. )
  234. self.dropout1 = nn.Dropout(ffn_dropout)
  235. self.conv2 = MobileViTV2ConvLayer(
  236. config=config,
  237. in_channels=ffn_latent_dim,
  238. out_channels=embed_dim,
  239. kernel_size=1,
  240. stride=1,
  241. bias=True,
  242. use_normalization=False,
  243. use_activation=False,
  244. )
  245. self.dropout2 = nn.Dropout(ffn_dropout)
  246. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  247. hidden_states = self.conv1(hidden_states)
  248. hidden_states = self.dropout1(hidden_states)
  249. hidden_states = self.conv2(hidden_states)
  250. hidden_states = self.dropout2(hidden_states)
  251. return hidden_states
  252. class MobileViTV2TransformerLayer(nn.Module):
  253. def __init__(
  254. self,
  255. config: MobileViTV2Config,
  256. embed_dim: int,
  257. ffn_latent_dim: int,
  258. dropout: float = 0.0,
  259. ) -> None:
  260. super().__init__()
  261. self.layernorm_before = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=config.layer_norm_eps)
  262. self.attention = MobileViTV2LinearSelfAttention(config, embed_dim)
  263. self.dropout1 = nn.Dropout(p=dropout)
  264. self.layernorm_after = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=config.layer_norm_eps)
  265. self.ffn = MobileViTV2FFN(config, embed_dim, ffn_latent_dim, config.ffn_dropout)
  266. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  267. layernorm_1_out = self.layernorm_before(hidden_states)
  268. attention_output = self.attention(layernorm_1_out)
  269. hidden_states = attention_output + hidden_states
  270. layer_output = self.layernorm_after(hidden_states)
  271. layer_output = self.ffn(layer_output)
  272. layer_output = layer_output + hidden_states
  273. return layer_output
  274. class MobileViTV2Transformer(nn.Module):
  275. def __init__(self, config: MobileViTV2Config, n_layers: int, d_model: int) -> None:
  276. super().__init__()
  277. ffn_multiplier = config.ffn_multiplier
  278. ffn_dims = [ffn_multiplier * d_model] * n_layers
  279. # ensure that dims are multiple of 16
  280. ffn_dims = [int((d // 16) * 16) for d in ffn_dims]
  281. self.layer = nn.ModuleList()
  282. for block_idx in range(n_layers):
  283. transformer_layer = MobileViTV2TransformerLayer(
  284. config, embed_dim=d_model, ffn_latent_dim=ffn_dims[block_idx]
  285. )
  286. self.layer.append(transformer_layer)
  287. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  288. for layer_module in self.layer:
  289. hidden_states = layer_module(hidden_states)
  290. return hidden_states
  291. class MobileViTV2Layer(GradientCheckpointingLayer):
  292. """
  293. MobileViTV2 layer: https://huggingface.co/papers/2206.02680
  294. """
  295. def __init__(
  296. self,
  297. config: MobileViTV2Config,
  298. in_channels: int,
  299. out_channels: int,
  300. attn_unit_dim: int,
  301. n_attn_blocks: int = 2,
  302. dilation: int = 1,
  303. stride: int = 2,
  304. ) -> None:
  305. super().__init__()
  306. self.patch_width = config.patch_size
  307. self.patch_height = config.patch_size
  308. cnn_out_dim = attn_unit_dim
  309. if stride == 2:
  310. self.downsampling_layer = MobileViTV2InvertedResidual(
  311. config,
  312. in_channels=in_channels,
  313. out_channels=out_channels,
  314. stride=stride if dilation == 1 else 1,
  315. dilation=dilation // 2 if dilation > 1 else 1,
  316. )
  317. in_channels = out_channels
  318. else:
  319. self.downsampling_layer = None
  320. # Local representations
  321. self.conv_kxk = MobileViTV2ConvLayer(
  322. config,
  323. in_channels=in_channels,
  324. out_channels=in_channels,
  325. kernel_size=config.conv_kernel_size,
  326. groups=in_channels,
  327. )
  328. self.conv_1x1 = MobileViTV2ConvLayer(
  329. config,
  330. in_channels=in_channels,
  331. out_channels=cnn_out_dim,
  332. kernel_size=1,
  333. use_normalization=False,
  334. use_activation=False,
  335. )
  336. # Global representations
  337. self.transformer = MobileViTV2Transformer(config, d_model=attn_unit_dim, n_layers=n_attn_blocks)
  338. # self.layernorm = MobileViTV2LayerNorm2D(attn_unit_dim, eps=config.layer_norm_eps)
  339. self.layernorm = nn.GroupNorm(num_groups=1, num_channels=attn_unit_dim, eps=config.layer_norm_eps)
  340. # Fusion
  341. self.conv_projection = MobileViTV2ConvLayer(
  342. config,
  343. in_channels=cnn_out_dim,
  344. out_channels=in_channels,
  345. kernel_size=1,
  346. use_normalization=True,
  347. use_activation=False,
  348. )
  349. def unfolding(self, feature_map: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
  350. batch_size, in_channels, img_height, img_width = feature_map.shape
  351. patches = nn.functional.unfold(
  352. feature_map,
  353. kernel_size=(self.patch_height, self.patch_width),
  354. stride=(self.patch_height, self.patch_width),
  355. )
  356. patches = patches.reshape(batch_size, in_channels, self.patch_height * self.patch_width, -1)
  357. return patches, (img_height, img_width)
  358. def folding(self, patches: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
  359. batch_size, in_dim, patch_size, n_patches = patches.shape
  360. patches = patches.reshape(batch_size, in_dim * patch_size, n_patches)
  361. feature_map = nn.functional.fold(
  362. patches,
  363. output_size=output_size,
  364. kernel_size=(self.patch_height, self.patch_width),
  365. stride=(self.patch_height, self.patch_width),
  366. )
  367. return feature_map
  368. def forward(self, features: torch.Tensor) -> torch.Tensor:
  369. # reduce spatial dimensions if needed
  370. if self.downsampling_layer:
  371. features = self.downsampling_layer(features)
  372. # local representation
  373. features = self.conv_kxk(features)
  374. features = self.conv_1x1(features)
  375. # convert feature map to patches
  376. patches, output_size = self.unfolding(features)
  377. # learn global representations
  378. patches = self.transformer(patches)
  379. patches = self.layernorm(patches)
  380. # convert patches back to feature maps
  381. # [batch_size, patch_height, patch_width, input_dim] --> [batch_size, input_dim, patch_height, patch_width]
  382. features = self.folding(patches, output_size)
  383. features = self.conv_projection(features)
  384. return features
  385. class MobileViTV2Encoder(nn.Module):
  386. def __init__(self, config: MobileViTV2Config) -> None:
  387. super().__init__()
  388. self.config = config
  389. self.layer = nn.ModuleList()
  390. self.gradient_checkpointing = False
  391. # segmentation architectures like DeepLab and PSPNet modify the strides
  392. # of the classification backbones
  393. dilate_layer_4 = dilate_layer_5 = False
  394. if config.output_stride == 8:
  395. dilate_layer_4 = True
  396. dilate_layer_5 = True
  397. elif config.output_stride == 16:
  398. dilate_layer_5 = True
  399. dilation = 1
  400. layer_0_dim = make_divisible(
  401. clip(value=32 * config.width_multiplier, min_val=16, max_val=64), divisor=8, min_value=16
  402. )
  403. layer_1_dim = make_divisible(64 * config.width_multiplier, divisor=16)
  404. layer_2_dim = make_divisible(128 * config.width_multiplier, divisor=8)
  405. layer_3_dim = make_divisible(256 * config.width_multiplier, divisor=8)
  406. layer_4_dim = make_divisible(384 * config.width_multiplier, divisor=8)
  407. layer_5_dim = make_divisible(512 * config.width_multiplier, divisor=8)
  408. layer_1 = MobileViTV2MobileNetLayer(
  409. config,
  410. in_channels=layer_0_dim,
  411. out_channels=layer_1_dim,
  412. stride=1,
  413. num_stages=1,
  414. )
  415. self.layer.append(layer_1)
  416. layer_2 = MobileViTV2MobileNetLayer(
  417. config,
  418. in_channels=layer_1_dim,
  419. out_channels=layer_2_dim,
  420. stride=2,
  421. num_stages=2,
  422. )
  423. self.layer.append(layer_2)
  424. layer_3 = MobileViTV2Layer(
  425. config,
  426. in_channels=layer_2_dim,
  427. out_channels=layer_3_dim,
  428. attn_unit_dim=make_divisible(config.base_attn_unit_dims[0] * config.width_multiplier, divisor=8),
  429. n_attn_blocks=config.n_attn_blocks[0],
  430. )
  431. self.layer.append(layer_3)
  432. if dilate_layer_4:
  433. dilation *= 2
  434. layer_4 = MobileViTV2Layer(
  435. config,
  436. in_channels=layer_3_dim,
  437. out_channels=layer_4_dim,
  438. attn_unit_dim=make_divisible(config.base_attn_unit_dims[1] * config.width_multiplier, divisor=8),
  439. n_attn_blocks=config.n_attn_blocks[1],
  440. dilation=dilation,
  441. )
  442. self.layer.append(layer_4)
  443. if dilate_layer_5:
  444. dilation *= 2
  445. layer_5 = MobileViTV2Layer(
  446. config,
  447. in_channels=layer_4_dim,
  448. out_channels=layer_5_dim,
  449. attn_unit_dim=make_divisible(config.base_attn_unit_dims[2] * config.width_multiplier, divisor=8),
  450. n_attn_blocks=config.n_attn_blocks[2],
  451. dilation=dilation,
  452. )
  453. self.layer.append(layer_5)
  454. def forward(
  455. self,
  456. hidden_states: torch.Tensor,
  457. output_hidden_states: bool = False,
  458. return_dict: bool = True,
  459. ) -> tuple | BaseModelOutputWithNoAttention:
  460. all_hidden_states = () if output_hidden_states else None
  461. for i, layer_module in enumerate(self.layer):
  462. hidden_states = layer_module(hidden_states)
  463. if output_hidden_states:
  464. all_hidden_states = all_hidden_states + (hidden_states,)
  465. if not return_dict:
  466. return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
  467. return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
  468. @auto_docstring
  469. class MobileViTV2PreTrainedModel(PreTrainedModel):
  470. config: MobileViTV2Config
  471. base_model_prefix = "mobilevitv2"
  472. main_input_name = "pixel_values"
  473. input_modalities = ("image",)
  474. supports_gradient_checkpointing = True
  475. _no_split_modules = ["MobileViTV2Layer"]
  476. @torch.no_grad()
  477. def _init_weights(self, module: nn.Module) -> None:
  478. """Initialize the weights"""
  479. if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
  480. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  481. if module.bias is not None:
  482. init.zeros_(module.bias)
  483. if getattr(module, "running_mean", None) is not None:
  484. init.zeros_(module.running_mean)
  485. init.ones_(module.running_var)
  486. init.zeros_(module.num_batches_tracked)
  487. elif isinstance(module, nn.GroupNorm):
  488. init.zeros_(module.bias)
  489. init.ones_(module.weight)
  490. @auto_docstring
  491. class MobileViTV2Model(MobileViTV2PreTrainedModel):
  492. def __init__(self, config: MobileViTV2Config, expand_output: bool = True):
  493. r"""
  494. expand_output (`bool`, *optional*, defaults to `True`):
  495. Whether to expand the output of the model. If `True`, the model will output pooled features in addition to
  496. hidden states. If `False`, only the hidden states will be returned.
  497. """
  498. super().__init__(config)
  499. self.config = config
  500. self.expand_output = expand_output
  501. layer_0_dim = make_divisible(
  502. clip(value=32 * config.width_multiplier, min_val=16, max_val=64), divisor=8, min_value=16
  503. )
  504. self.conv_stem = MobileViTV2ConvLayer(
  505. config,
  506. in_channels=config.num_channels,
  507. out_channels=layer_0_dim,
  508. kernel_size=3,
  509. stride=2,
  510. use_normalization=True,
  511. use_activation=True,
  512. )
  513. self.encoder = MobileViTV2Encoder(config)
  514. # Initialize weights and apply final processing
  515. self.post_init()
  516. @auto_docstring
  517. def forward(
  518. self,
  519. pixel_values: torch.Tensor | None = None,
  520. output_hidden_states: bool | None = None,
  521. return_dict: bool | None = None,
  522. **kwargs,
  523. ) -> tuple | BaseModelOutputWithPoolingAndNoAttention:
  524. output_hidden_states = (
  525. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  526. )
  527. return_dict = return_dict if return_dict is not None else self.config.return_dict
  528. if pixel_values is None:
  529. raise ValueError("You have to specify pixel_values")
  530. embedding_output = self.conv_stem(pixel_values)
  531. encoder_outputs = self.encoder(
  532. embedding_output,
  533. output_hidden_states=output_hidden_states,
  534. return_dict=return_dict,
  535. )
  536. if self.expand_output:
  537. last_hidden_state = encoder_outputs[0]
  538. # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
  539. pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False)
  540. else:
  541. last_hidden_state = encoder_outputs[0]
  542. pooled_output = None
  543. if not return_dict:
  544. output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)
  545. return output + encoder_outputs[1:]
  546. return BaseModelOutputWithPoolingAndNoAttention(
  547. last_hidden_state=last_hidden_state,
  548. pooler_output=pooled_output,
  549. hidden_states=encoder_outputs.hidden_states,
  550. )
  551. @auto_docstring(
  552. custom_intro="""
  553. MobileViTV2 model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  554. ImageNet.
  555. """
  556. )
  557. class MobileViTV2ForImageClassification(MobileViTV2PreTrainedModel):
  558. def __init__(self, config: MobileViTV2Config) -> None:
  559. super().__init__(config)
  560. self.num_labels = config.num_labels
  561. self.mobilevitv2 = MobileViTV2Model(config)
  562. out_channels = make_divisible(512 * config.width_multiplier, divisor=8) # layer 5 output dimension
  563. # Classifier head
  564. self.classifier = (
  565. nn.Linear(in_features=out_channels, out_features=config.num_labels)
  566. if config.num_labels > 0
  567. else nn.Identity()
  568. )
  569. # Initialize weights and apply final processing
  570. self.post_init()
  571. @auto_docstring
  572. def forward(
  573. self,
  574. pixel_values: torch.Tensor | None = None,
  575. output_hidden_states: bool | None = None,
  576. labels: torch.Tensor | None = None,
  577. return_dict: bool | None = None,
  578. **kwargs,
  579. ) -> tuple | ImageClassifierOutputWithNoAttention:
  580. r"""
  581. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  582. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  583. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
  584. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  585. """
  586. return_dict = return_dict if return_dict is not None else self.config.return_dict
  587. outputs = self.mobilevitv2(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  588. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  589. logits = self.classifier(pooled_output)
  590. loss = None
  591. if labels is not None:
  592. loss = self.loss_function(labels, logits, self.config)
  593. if not return_dict:
  594. output = (logits,) + outputs[2:]
  595. return ((loss,) + output) if loss is not None else output
  596. return ImageClassifierOutputWithNoAttention(
  597. loss=loss,
  598. logits=logits,
  599. hidden_states=outputs.hidden_states,
  600. )
  601. # Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTASPPPooling with MobileViT->MobileViTV2
  602. class MobileViTV2ASPPPooling(nn.Module):
  603. def __init__(self, config: MobileViTV2Config, in_channels: int, out_channels: int) -> None:
  604. super().__init__()
  605. self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)
  606. self.conv_1x1 = MobileViTV2ConvLayer(
  607. config,
  608. in_channels=in_channels,
  609. out_channels=out_channels,
  610. kernel_size=1,
  611. stride=1,
  612. use_normalization=True,
  613. use_activation="relu",
  614. )
  615. def forward(self, features: torch.Tensor) -> torch.Tensor:
  616. spatial_size = features.shape[-2:]
  617. features = self.global_pool(features)
  618. features = self.conv_1x1(features)
  619. features = nn.functional.interpolate(features, size=spatial_size, mode="bilinear", align_corners=False)
  620. return features
  621. class MobileViTV2ASPP(nn.Module):
  622. """
  623. ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587
  624. """
  625. def __init__(self, config: MobileViTV2Config) -> None:
  626. super().__init__()
  627. encoder_out_channels = make_divisible(512 * config.width_multiplier, divisor=8) # layer 5 output dimension
  628. in_channels = encoder_out_channels
  629. out_channels = config.aspp_out_channels
  630. if len(config.atrous_rates) != 3:
  631. raise ValueError("Expected 3 values for atrous_rates")
  632. self.convs = nn.ModuleList()
  633. in_projection = MobileViTV2ConvLayer(
  634. config,
  635. in_channels=in_channels,
  636. out_channels=out_channels,
  637. kernel_size=1,
  638. use_activation="relu",
  639. )
  640. self.convs.append(in_projection)
  641. self.convs.extend(
  642. [
  643. MobileViTV2ConvLayer(
  644. config,
  645. in_channels=in_channels,
  646. out_channels=out_channels,
  647. kernel_size=3,
  648. dilation=rate,
  649. use_activation="relu",
  650. )
  651. for rate in config.atrous_rates
  652. ]
  653. )
  654. pool_layer = MobileViTV2ASPPPooling(config, in_channels, out_channels)
  655. self.convs.append(pool_layer)
  656. self.project = MobileViTV2ConvLayer(
  657. config, in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu"
  658. )
  659. self.dropout = nn.Dropout(p=config.aspp_dropout_prob)
  660. def forward(self, features: torch.Tensor) -> torch.Tensor:
  661. pyramid = []
  662. for conv in self.convs:
  663. pyramid.append(conv(features))
  664. pyramid = torch.cat(pyramid, dim=1)
  665. pooled_features = self.project(pyramid)
  666. pooled_features = self.dropout(pooled_features)
  667. return pooled_features
  668. # Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTDeepLabV3 with MobileViT->MobileViTV2
  669. class MobileViTV2DeepLabV3(nn.Module):
  670. """
  671. DeepLabv3 architecture: https://huggingface.co/papers/1706.05587
  672. """
  673. def __init__(self, config: MobileViTV2Config) -> None:
  674. super().__init__()
  675. self.aspp = MobileViTV2ASPP(config)
  676. self.dropout = nn.Dropout2d(config.classifier_dropout_prob)
  677. self.classifier = MobileViTV2ConvLayer(
  678. config,
  679. in_channels=config.aspp_out_channels,
  680. out_channels=config.num_labels,
  681. kernel_size=1,
  682. use_normalization=False,
  683. use_activation=False,
  684. bias=True,
  685. )
  686. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  687. features = self.aspp(hidden_states[-1])
  688. features = self.dropout(features)
  689. features = self.classifier(features)
  690. return features
  691. @auto_docstring(
  692. custom_intro="""
  693. MobileViTV2 model with a semantic segmentation head on top, e.g. for Pascal VOC.
  694. """
  695. )
  696. class MobileViTV2ForSemanticSegmentation(MobileViTV2PreTrainedModel):
  697. def __init__(self, config: MobileViTV2Config) -> None:
  698. super().__init__(config)
  699. self.num_labels = config.num_labels
  700. self.mobilevitv2 = MobileViTV2Model(config, expand_output=False)
  701. self.segmentation_head = MobileViTV2DeepLabV3(config)
  702. # Initialize weights and apply final processing
  703. self.post_init()
  704. @auto_docstring
  705. def forward(
  706. self,
  707. pixel_values: torch.Tensor | None = None,
  708. labels: torch.Tensor | None = None,
  709. output_hidden_states: bool | None = None,
  710. return_dict: bool | None = None,
  711. **kwargs,
  712. ) -> tuple | SemanticSegmenterOutput:
  713. r"""
  714. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  715. Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
  716. config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
  717. Examples:
  718. ```python
  719. >>> import httpx
  720. >>> from io import BytesIO
  721. >>> import torch
  722. >>> from PIL import Image
  723. >>> from transformers import AutoImageProcessor, MobileViTV2ForSemanticSegmentation
  724. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  725. >>> with httpx.stream("GET", url) as response:
  726. ... image = Image.open(BytesIO(response.read()))
  727. >>> image_processor = AutoImageProcessor.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
  728. >>> model = MobileViTV2ForSemanticSegmentation.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
  729. >>> inputs = image_processor(images=image, return_tensors="pt")
  730. >>> with torch.no_grad():
  731. ... outputs = model(**inputs)
  732. >>> # logits are of shape (batch_size, num_labels, height, width)
  733. >>> logits = outputs.logits
  734. ```"""
  735. output_hidden_states = (
  736. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  737. )
  738. return_dict = return_dict if return_dict is not None else self.config.return_dict
  739. if labels is not None and self.config.num_labels == 1:
  740. raise ValueError("The number of labels should be greater than one")
  741. outputs = self.mobilevitv2(
  742. pixel_values,
  743. output_hidden_states=True, # we need the intermediate hidden states
  744. return_dict=return_dict,
  745. )
  746. encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
  747. logits = self.segmentation_head(encoder_hidden_states)
  748. loss = None
  749. if labels is not None:
  750. # upsample logits to the images' original size
  751. upsampled_logits = nn.functional.interpolate(
  752. logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
  753. )
  754. loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
  755. loss = loss_fct(upsampled_logits, labels)
  756. if not return_dict:
  757. if output_hidden_states:
  758. output = (logits,) + outputs[1:]
  759. else:
  760. output = (logits,) + outputs[2:]
  761. return ((loss,) + output) if loss is not None else output
  762. return SemanticSegmenterOutput(
  763. loss=loss,
  764. logits=logits,
  765. hidden_states=outputs.hidden_states if output_hidden_states else None,
  766. attentions=None,
  767. )
  768. __all__ = [
  769. "MobileViTV2ForImageClassification",
  770. "MobileViTV2ForSemanticSegmentation",
  771. "MobileViTV2Model",
  772. "MobileViTV2PreTrainedModel",
  773. ]