modeling_mobilevit.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963
  1. # Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE
  16. """PyTorch MobileViT model."""
  17. import math
  18. import torch
  19. from torch import nn
  20. from torch.nn import CrossEntropyLoss
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import (
  25. BaseModelOutputWithNoAttention,
  26. BaseModelOutputWithPoolingAndNoAttention,
  27. ImageClassifierOutputWithNoAttention,
  28. SemanticSegmenterOutput,
  29. )
  30. from ...modeling_utils import PreTrainedModel
  31. from ...utils import auto_docstring, logging, torch_int
  32. from .configuration_mobilevit import MobileViTConfig
  33. logger = logging.get_logger(__name__)
  34. def make_divisible(value: int, divisor: int = 8, min_value: int | None = None) -> int:
  35. """
  36. Ensure that all layers have a channel count that is divisible by `divisor`.
  37. """
  38. if min_value is None:
  39. min_value = divisor
  40. new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
  41. # Make sure that round down does not go down by more than 10%.
  42. if new_value < 0.9 * value:
  43. new_value += divisor
  44. return int(new_value)
  45. class MobileViTConvLayer(nn.Module):
  46. def __init__(
  47. self,
  48. config: MobileViTConfig,
  49. in_channels: int,
  50. out_channels: int,
  51. kernel_size: int,
  52. stride: int = 1,
  53. groups: int = 1,
  54. bias: bool = False,
  55. dilation: int = 1,
  56. use_normalization: bool = True,
  57. use_activation: bool | str = True,
  58. ) -> None:
  59. super().__init__()
  60. padding = int((kernel_size - 1) / 2) * dilation
  61. if in_channels % groups != 0:
  62. raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
  63. if out_channels % groups != 0:
  64. raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")
  65. self.convolution = nn.Conv2d(
  66. in_channels=in_channels,
  67. out_channels=out_channels,
  68. kernel_size=kernel_size,
  69. stride=stride,
  70. padding=padding,
  71. dilation=dilation,
  72. groups=groups,
  73. bias=bias,
  74. padding_mode="zeros",
  75. )
  76. if use_normalization:
  77. self.normalization = nn.BatchNorm2d(
  78. num_features=out_channels,
  79. eps=1e-5,
  80. momentum=0.1,
  81. affine=True,
  82. track_running_stats=True,
  83. )
  84. else:
  85. self.normalization = None
  86. if use_activation:
  87. if isinstance(use_activation, str):
  88. self.activation = ACT2FN[use_activation]
  89. elif isinstance(config.hidden_act, str):
  90. self.activation = ACT2FN[config.hidden_act]
  91. else:
  92. self.activation = config.hidden_act
  93. else:
  94. self.activation = None
  95. def forward(self, features: torch.Tensor) -> torch.Tensor:
  96. features = self.convolution(features)
  97. if self.normalization is not None:
  98. features = self.normalization(features)
  99. if self.activation is not None:
  100. features = self.activation(features)
  101. return features
  102. class MobileViTInvertedResidual(nn.Module):
  103. """
  104. Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381
  105. """
  106. def __init__(
  107. self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1
  108. ) -> None:
  109. super().__init__()
  110. expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)
  111. if stride not in [1, 2]:
  112. raise ValueError(f"Invalid stride {stride}.")
  113. self.use_residual = (stride == 1) and (in_channels == out_channels)
  114. self.expand_1x1 = MobileViTConvLayer(
  115. config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1
  116. )
  117. self.conv_3x3 = MobileViTConvLayer(
  118. config,
  119. in_channels=expanded_channels,
  120. out_channels=expanded_channels,
  121. kernel_size=3,
  122. stride=stride,
  123. groups=expanded_channels,
  124. dilation=dilation,
  125. )
  126. self.reduce_1x1 = MobileViTConvLayer(
  127. config,
  128. in_channels=expanded_channels,
  129. out_channels=out_channels,
  130. kernel_size=1,
  131. use_activation=False,
  132. )
  133. def forward(self, features: torch.Tensor) -> torch.Tensor:
  134. residual = features
  135. features = self.expand_1x1(features)
  136. features = self.conv_3x3(features)
  137. features = self.reduce_1x1(features)
  138. return residual + features if self.use_residual else features
  139. class MobileViTMobileNetLayer(nn.Module):
  140. def __init__(
  141. self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1
  142. ) -> None:
  143. super().__init__()
  144. self.layer = nn.ModuleList()
  145. for i in range(num_stages):
  146. layer = MobileViTInvertedResidual(
  147. config,
  148. in_channels=in_channels,
  149. out_channels=out_channels,
  150. stride=stride if i == 0 else 1,
  151. )
  152. self.layer.append(layer)
  153. in_channels = out_channels
  154. def forward(self, features: torch.Tensor) -> torch.Tensor:
  155. for layer_module in self.layer:
  156. features = layer_module(features)
  157. return features
  158. class MobileViTSelfAttention(nn.Module):
  159. def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
  160. super().__init__()
  161. if hidden_size % config.num_attention_heads != 0:
  162. raise ValueError(
  163. f"The hidden size {hidden_size} is not a multiple of the number of attention "
  164. f"heads {config.num_attention_heads}."
  165. )
  166. self.num_attention_heads = config.num_attention_heads
  167. self.attention_head_size = int(hidden_size / config.num_attention_heads)
  168. self.all_head_size = self.num_attention_heads * self.attention_head_size
  169. self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
  170. self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
  171. self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
  172. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  173. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  174. input_shape = hidden_states.shape[:-1]
  175. hidden_shape = (*input_shape, -1, self.attention_head_size)
  176. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  177. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  178. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  179. # Take the dot product between "query" and "key" to get the raw attention scores.
  180. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  181. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  182. # Normalize the attention scores to probabilities.
  183. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  184. # This is actually dropping out entire tokens to attend to, which might
  185. # seem a bit unusual, but is taken from the original Transformer paper.
  186. attention_probs = self.dropout(attention_probs)
  187. context_layer = torch.matmul(attention_probs, value_layer)
  188. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  189. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  190. context_layer = context_layer.view(*new_context_layer_shape)
  191. return context_layer
  192. class MobileViTSelfOutput(nn.Module):
  193. def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
  194. super().__init__()
  195. self.dense = nn.Linear(hidden_size, hidden_size)
  196. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  197. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  198. hidden_states = self.dense(hidden_states)
  199. hidden_states = self.dropout(hidden_states)
  200. return hidden_states
  201. class MobileViTAttention(nn.Module):
  202. def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
  203. super().__init__()
  204. self.attention = MobileViTSelfAttention(config, hidden_size)
  205. self.output = MobileViTSelfOutput(config, hidden_size)
  206. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  207. self_outputs = self.attention(hidden_states)
  208. attention_output = self.output(self_outputs)
  209. return attention_output
  210. class MobileViTIntermediate(nn.Module):
  211. def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
  212. super().__init__()
  213. self.dense = nn.Linear(hidden_size, intermediate_size)
  214. if isinstance(config.hidden_act, str):
  215. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  216. else:
  217. self.intermediate_act_fn = config.hidden_act
  218. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  219. hidden_states = self.dense(hidden_states)
  220. hidden_states = self.intermediate_act_fn(hidden_states)
  221. return hidden_states
  222. class MobileViTOutput(nn.Module):
  223. def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
  224. super().__init__()
  225. self.dense = nn.Linear(intermediate_size, hidden_size)
  226. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  227. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  228. hidden_states = self.dense(hidden_states)
  229. hidden_states = self.dropout(hidden_states)
  230. hidden_states = hidden_states + input_tensor
  231. return hidden_states
  232. class MobileViTTransformerLayer(nn.Module):
  233. def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
  234. super().__init__()
  235. self.attention = MobileViTAttention(config, hidden_size)
  236. self.intermediate = MobileViTIntermediate(config, hidden_size, intermediate_size)
  237. self.output = MobileViTOutput(config, hidden_size, intermediate_size)
  238. self.layernorm_before = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  239. self.layernorm_after = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  240. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  241. attention_output = self.attention(self.layernorm_before(hidden_states))
  242. hidden_states = attention_output + hidden_states
  243. layer_output = self.layernorm_after(hidden_states)
  244. layer_output = self.intermediate(layer_output)
  245. layer_output = self.output(layer_output, hidden_states)
  246. return layer_output
  247. class MobileViTTransformer(nn.Module):
  248. def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int) -> None:
  249. super().__init__()
  250. self.layer = nn.ModuleList()
  251. for _ in range(num_stages):
  252. transformer_layer = MobileViTTransformerLayer(
  253. config,
  254. hidden_size=hidden_size,
  255. intermediate_size=int(hidden_size * config.mlp_ratio),
  256. )
  257. self.layer.append(transformer_layer)
  258. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  259. for layer_module in self.layer:
  260. hidden_states = layer_module(hidden_states)
  261. return hidden_states
  262. class MobileViTLayer(GradientCheckpointingLayer):
  263. """
  264. MobileViT block: https://huggingface.co/papers/2110.02178
  265. """
  266. def __init__(
  267. self,
  268. config: MobileViTConfig,
  269. in_channels: int,
  270. out_channels: int,
  271. stride: int,
  272. hidden_size: int,
  273. num_stages: int,
  274. dilation: int = 1,
  275. ) -> None:
  276. super().__init__()
  277. self.patch_width = config.patch_size
  278. self.patch_height = config.patch_size
  279. if stride == 2:
  280. self.downsampling_layer = MobileViTInvertedResidual(
  281. config,
  282. in_channels=in_channels,
  283. out_channels=out_channels,
  284. stride=stride if dilation == 1 else 1,
  285. dilation=dilation // 2 if dilation > 1 else 1,
  286. )
  287. in_channels = out_channels
  288. else:
  289. self.downsampling_layer = None
  290. self.conv_kxk = MobileViTConvLayer(
  291. config,
  292. in_channels=in_channels,
  293. out_channels=in_channels,
  294. kernel_size=config.conv_kernel_size,
  295. )
  296. self.conv_1x1 = MobileViTConvLayer(
  297. config,
  298. in_channels=in_channels,
  299. out_channels=hidden_size,
  300. kernel_size=1,
  301. use_normalization=False,
  302. use_activation=False,
  303. )
  304. self.transformer = MobileViTTransformer(
  305. config,
  306. hidden_size=hidden_size,
  307. num_stages=num_stages,
  308. )
  309. self.layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  310. self.conv_projection = MobileViTConvLayer(
  311. config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1
  312. )
  313. self.fusion = MobileViTConvLayer(
  314. config, in_channels=2 * in_channels, out_channels=in_channels, kernel_size=config.conv_kernel_size
  315. )
  316. def unfolding(self, features: torch.Tensor) -> tuple[torch.Tensor, dict]:
  317. patch_width, patch_height = self.patch_width, self.patch_height
  318. patch_area = int(patch_width * patch_height)
  319. batch_size, channels, orig_height, orig_width = features.shape
  320. new_height = (
  321. torch_int(torch.ceil(orig_height / patch_height) * patch_height)
  322. if torch.jit.is_tracing()
  323. else int(math.ceil(orig_height / patch_height) * patch_height)
  324. )
  325. new_width = (
  326. torch_int(torch.ceil(orig_width / patch_width) * patch_width)
  327. if torch.jit.is_tracing()
  328. else int(math.ceil(orig_width / patch_width) * patch_width)
  329. )
  330. interpolate = False
  331. if new_width != orig_width or new_height != orig_height:
  332. # Note: Padding can be done, but then it needs to be handled in attention function.
  333. features = nn.functional.interpolate(
  334. features, size=(new_height, new_width), mode="bilinear", align_corners=False
  335. )
  336. interpolate = True
  337. # number of patches along width and height
  338. num_patch_width = new_width // patch_width
  339. num_patch_height = new_height // patch_height
  340. num_patches = num_patch_height * num_patch_width
  341. # convert from shape (batch_size, channels, orig_height, orig_width)
  342. # to the shape (batch_size * patch_area, num_patches, channels)
  343. patches = features.reshape(
  344. batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width
  345. )
  346. patches = patches.transpose(1, 2)
  347. patches = patches.reshape(batch_size, channels, num_patches, patch_area)
  348. patches = patches.transpose(1, 3)
  349. patches = patches.reshape(batch_size * patch_area, num_patches, -1)
  350. info_dict = {
  351. "orig_size": (orig_height, orig_width),
  352. "batch_size": batch_size,
  353. "channels": channels,
  354. "interpolate": interpolate,
  355. "num_patches": num_patches,
  356. "num_patches_width": num_patch_width,
  357. "num_patches_height": num_patch_height,
  358. }
  359. return patches, info_dict
  360. def folding(self, patches: torch.Tensor, info_dict: dict) -> torch.Tensor:
  361. patch_width, patch_height = self.patch_width, self.patch_height
  362. patch_area = int(patch_width * patch_height)
  363. batch_size = info_dict["batch_size"]
  364. channels = info_dict["channels"]
  365. num_patches = info_dict["num_patches"]
  366. num_patch_height = info_dict["num_patches_height"]
  367. num_patch_width = info_dict["num_patches_width"]
  368. # convert from shape (batch_size * patch_area, num_patches, channels)
  369. # back to shape (batch_size, channels, orig_height, orig_width)
  370. features = patches.contiguous().view(batch_size, patch_area, num_patches, -1)
  371. features = features.transpose(1, 3)
  372. features = features.reshape(
  373. batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width
  374. )
  375. features = features.transpose(1, 2)
  376. features = features.reshape(
  377. batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width
  378. )
  379. if info_dict["interpolate"]:
  380. features = nn.functional.interpolate(
  381. features, size=info_dict["orig_size"], mode="bilinear", align_corners=False
  382. )
  383. return features
  384. def forward(self, features: torch.Tensor) -> torch.Tensor:
  385. # reduce spatial dimensions if needed
  386. if self.downsampling_layer:
  387. features = self.downsampling_layer(features)
  388. residual = features
  389. # local representation
  390. features = self.conv_kxk(features)
  391. features = self.conv_1x1(features)
  392. # convert feature map to patches
  393. patches, info_dict = self.unfolding(features)
  394. # learn global representations
  395. patches = self.transformer(patches)
  396. patches = self.layernorm(patches)
  397. # convert patches back to feature maps
  398. features = self.folding(patches, info_dict)
  399. features = self.conv_projection(features)
  400. features = self.fusion(torch.cat((residual, features), dim=1))
  401. return features
  402. class MobileViTEncoder(nn.Module):
  403. def __init__(self, config: MobileViTConfig) -> None:
  404. super().__init__()
  405. self.config = config
  406. self.layer = nn.ModuleList()
  407. self.gradient_checkpointing = False
  408. # segmentation architectures like DeepLab and PSPNet modify the strides
  409. # of the classification backbones
  410. dilate_layer_4 = dilate_layer_5 = False
  411. if config.output_stride == 8:
  412. dilate_layer_4 = True
  413. dilate_layer_5 = True
  414. elif config.output_stride == 16:
  415. dilate_layer_5 = True
  416. dilation = 1
  417. layer_1 = MobileViTMobileNetLayer(
  418. config,
  419. in_channels=config.neck_hidden_sizes[0],
  420. out_channels=config.neck_hidden_sizes[1],
  421. stride=1,
  422. num_stages=1,
  423. )
  424. self.layer.append(layer_1)
  425. layer_2 = MobileViTMobileNetLayer(
  426. config,
  427. in_channels=config.neck_hidden_sizes[1],
  428. out_channels=config.neck_hidden_sizes[2],
  429. stride=2,
  430. num_stages=3,
  431. )
  432. self.layer.append(layer_2)
  433. layer_3 = MobileViTLayer(
  434. config,
  435. in_channels=config.neck_hidden_sizes[2],
  436. out_channels=config.neck_hidden_sizes[3],
  437. stride=2,
  438. hidden_size=config.hidden_sizes[0],
  439. num_stages=2,
  440. )
  441. self.layer.append(layer_3)
  442. if dilate_layer_4:
  443. dilation *= 2
  444. layer_4 = MobileViTLayer(
  445. config,
  446. in_channels=config.neck_hidden_sizes[3],
  447. out_channels=config.neck_hidden_sizes[4],
  448. stride=2,
  449. hidden_size=config.hidden_sizes[1],
  450. num_stages=4,
  451. dilation=dilation,
  452. )
  453. self.layer.append(layer_4)
  454. if dilate_layer_5:
  455. dilation *= 2
  456. layer_5 = MobileViTLayer(
  457. config,
  458. in_channels=config.neck_hidden_sizes[4],
  459. out_channels=config.neck_hidden_sizes[5],
  460. stride=2,
  461. hidden_size=config.hidden_sizes[2],
  462. num_stages=3,
  463. dilation=dilation,
  464. )
  465. self.layer.append(layer_5)
  466. def forward(
  467. self,
  468. hidden_states: torch.Tensor,
  469. output_hidden_states: bool = False,
  470. return_dict: bool = True,
  471. ) -> tuple | BaseModelOutputWithNoAttention:
  472. all_hidden_states = () if output_hidden_states else None
  473. for i, layer_module in enumerate(self.layer):
  474. hidden_states = layer_module(hidden_states)
  475. if output_hidden_states:
  476. all_hidden_states = all_hidden_states + (hidden_states,)
  477. if not return_dict:
  478. return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
  479. return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
  480. @auto_docstring
  481. class MobileViTPreTrainedModel(PreTrainedModel):
  482. config: MobileViTConfig
  483. base_model_prefix = "mobilevit"
  484. main_input_name = "pixel_values"
  485. input_modalities = ("image",)
  486. supports_gradient_checkpointing = True
  487. _no_split_modules = ["MobileViTLayer"]
  488. @torch.no_grad()
  489. def _init_weights(self, module: nn.Module) -> None:
  490. """Initialize the weights"""
  491. if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
  492. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  493. if module.bias is not None:
  494. init.zeros_(module.bias)
  495. if getattr(module, "running_mean", None) is not None:
  496. init.zeros_(module.running_mean)
  497. init.ones_(module.running_var)
  498. init.zeros_(module.num_batches_tracked)
  499. elif isinstance(module, nn.LayerNorm):
  500. init.zeros_(module.bias)
  501. init.ones_(module.weight)
  502. @auto_docstring
  503. class MobileViTModel(MobileViTPreTrainedModel):
  504. def __init__(self, config: MobileViTConfig, expand_output: bool = True):
  505. r"""
  506. expand_output (`bool`, *optional*, defaults to `True`):
  507. Whether to expand the output of the model using a 1x1 convolution. If `True`, the model will apply an additional
  508. 1x1 convolution to expand the output channels from `config.neck_hidden_sizes[5]` to `config.neck_hidden_sizes[6]`.
  509. """
  510. super().__init__(config)
  511. self.config = config
  512. self.expand_output = expand_output
  513. self.conv_stem = MobileViTConvLayer(
  514. config,
  515. in_channels=config.num_channels,
  516. out_channels=config.neck_hidden_sizes[0],
  517. kernel_size=3,
  518. stride=2,
  519. )
  520. self.encoder = MobileViTEncoder(config)
  521. if self.expand_output:
  522. self.conv_1x1_exp = MobileViTConvLayer(
  523. config,
  524. in_channels=config.neck_hidden_sizes[5],
  525. out_channels=config.neck_hidden_sizes[6],
  526. kernel_size=1,
  527. )
  528. # Initialize weights and apply final processing
  529. self.post_init()
  530. @auto_docstring
  531. def forward(
  532. self,
  533. pixel_values: torch.Tensor | None = None,
  534. output_hidden_states: bool | None = None,
  535. return_dict: bool | None = None,
  536. **kwargs,
  537. ) -> tuple | BaseModelOutputWithPoolingAndNoAttention:
  538. output_hidden_states = (
  539. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  540. )
  541. return_dict = return_dict if return_dict is not None else self.config.return_dict
  542. if pixel_values is None:
  543. raise ValueError("You have to specify pixel_values")
  544. embedding_output = self.conv_stem(pixel_values)
  545. encoder_outputs = self.encoder(
  546. embedding_output,
  547. output_hidden_states=output_hidden_states,
  548. return_dict=return_dict,
  549. )
  550. if self.expand_output:
  551. last_hidden_state = self.conv_1x1_exp(encoder_outputs[0])
  552. # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
  553. pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False)
  554. else:
  555. last_hidden_state = encoder_outputs[0]
  556. pooled_output = None
  557. if not return_dict:
  558. output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)
  559. return output + encoder_outputs[1:]
  560. return BaseModelOutputWithPoolingAndNoAttention(
  561. last_hidden_state=last_hidden_state,
  562. pooler_output=pooled_output,
  563. hidden_states=encoder_outputs.hidden_states,
  564. )
  565. @auto_docstring(
  566. custom_intro="""
  567. MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  568. ImageNet.
  569. """
  570. )
  571. class MobileViTForImageClassification(MobileViTPreTrainedModel):
  572. def __init__(self, config: MobileViTConfig) -> None:
  573. super().__init__(config)
  574. self.num_labels = config.num_labels
  575. self.mobilevit = MobileViTModel(config)
  576. # Classifier head
  577. self.dropout = nn.Dropout(config.classifier_dropout_prob, inplace=True)
  578. self.classifier = (
  579. nn.Linear(config.neck_hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
  580. )
  581. # Initialize weights and apply final processing
  582. self.post_init()
  583. @auto_docstring
  584. def forward(
  585. self,
  586. pixel_values: torch.Tensor | None = None,
  587. output_hidden_states: bool | None = None,
  588. labels: torch.Tensor | None = None,
  589. return_dict: bool | None = None,
  590. **kwargs,
  591. ) -> tuple | ImageClassifierOutputWithNoAttention:
  592. r"""
  593. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  594. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  595. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
  596. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  597. """
  598. return_dict = return_dict if return_dict is not None else self.config.return_dict
  599. outputs = self.mobilevit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  600. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  601. logits = self.classifier(self.dropout(pooled_output))
  602. loss = None
  603. if labels is not None:
  604. loss = self.loss_function(labels, logits, self.config)
  605. if not return_dict:
  606. output = (logits,) + outputs[2:]
  607. return ((loss,) + output) if loss is not None else output
  608. return ImageClassifierOutputWithNoAttention(
  609. loss=loss,
  610. logits=logits,
  611. hidden_states=outputs.hidden_states,
  612. )
  613. class MobileViTASPPPooling(nn.Module):
  614. def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int) -> None:
  615. super().__init__()
  616. self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)
  617. self.conv_1x1 = MobileViTConvLayer(
  618. config,
  619. in_channels=in_channels,
  620. out_channels=out_channels,
  621. kernel_size=1,
  622. stride=1,
  623. use_normalization=True,
  624. use_activation="relu",
  625. )
  626. def forward(self, features: torch.Tensor) -> torch.Tensor:
  627. spatial_size = features.shape[-2:]
  628. features = self.global_pool(features)
  629. features = self.conv_1x1(features)
  630. features = nn.functional.interpolate(features, size=spatial_size, mode="bilinear", align_corners=False)
  631. return features
  632. class MobileViTASPP(nn.Module):
  633. """
  634. ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587
  635. """
  636. def __init__(self, config: MobileViTConfig) -> None:
  637. super().__init__()
  638. in_channels = config.neck_hidden_sizes[-2]
  639. out_channels = config.aspp_out_channels
  640. if len(config.atrous_rates) != 3:
  641. raise ValueError("Expected 3 values for atrous_rates")
  642. self.convs = nn.ModuleList()
  643. in_projection = MobileViTConvLayer(
  644. config,
  645. in_channels=in_channels,
  646. out_channels=out_channels,
  647. kernel_size=1,
  648. use_activation="relu",
  649. )
  650. self.convs.append(in_projection)
  651. self.convs.extend(
  652. [
  653. MobileViTConvLayer(
  654. config,
  655. in_channels=in_channels,
  656. out_channels=out_channels,
  657. kernel_size=3,
  658. dilation=rate,
  659. use_activation="relu",
  660. )
  661. for rate in config.atrous_rates
  662. ]
  663. )
  664. pool_layer = MobileViTASPPPooling(config, in_channels, out_channels)
  665. self.convs.append(pool_layer)
  666. self.project = MobileViTConvLayer(
  667. config, in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu"
  668. )
  669. self.dropout = nn.Dropout(p=config.aspp_dropout_prob)
  670. def forward(self, features: torch.Tensor) -> torch.Tensor:
  671. pyramid = []
  672. for conv in self.convs:
  673. pyramid.append(conv(features))
  674. pyramid = torch.cat(pyramid, dim=1)
  675. pooled_features = self.project(pyramid)
  676. pooled_features = self.dropout(pooled_features)
  677. return pooled_features
  678. class MobileViTDeepLabV3(nn.Module):
  679. """
  680. DeepLabv3 architecture: https://huggingface.co/papers/1706.05587
  681. """
  682. def __init__(self, config: MobileViTConfig) -> None:
  683. super().__init__()
  684. self.aspp = MobileViTASPP(config)
  685. self.dropout = nn.Dropout2d(config.classifier_dropout_prob)
  686. self.classifier = MobileViTConvLayer(
  687. config,
  688. in_channels=config.aspp_out_channels,
  689. out_channels=config.num_labels,
  690. kernel_size=1,
  691. use_normalization=False,
  692. use_activation=False,
  693. bias=True,
  694. )
  695. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  696. features = self.aspp(hidden_states[-1])
  697. features = self.dropout(features)
  698. features = self.classifier(features)
  699. return features
  700. @auto_docstring(
  701. custom_intro="""
  702. MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC.
  703. """
  704. )
  705. class MobileViTForSemanticSegmentation(MobileViTPreTrainedModel):
  706. def __init__(self, config: MobileViTConfig) -> None:
  707. super().__init__(config)
  708. self.num_labels = config.num_labels
  709. self.mobilevit = MobileViTModel(config, expand_output=False)
  710. self.segmentation_head = MobileViTDeepLabV3(config)
  711. # Initialize weights and apply final processing
  712. self.post_init()
  713. @auto_docstring
  714. def forward(
  715. self,
  716. pixel_values: torch.Tensor | None = None,
  717. labels: torch.Tensor | None = None,
  718. output_hidden_states: bool | None = None,
  719. return_dict: bool | None = None,
  720. **kwargs,
  721. ) -> tuple | SemanticSegmenterOutput:
  722. r"""
  723. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  724. Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
  725. config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
  726. Examples:
  727. ```python
  728. >>> import httpx
  729. >>> from io import BytesIO
  730. >>> import torch
  731. >>> from PIL import Image
  732. >>> from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation
  733. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  734. >>> with httpx.stream("GET", url) as response:
  735. ... image = Image.open(BytesIO(response.read()))
  736. >>> image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
  737. >>> model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")
  738. >>> inputs = image_processor(images=image, return_tensors="pt")
  739. >>> with torch.no_grad():
  740. ... outputs = model(**inputs)
  741. >>> # logits are of shape (batch_size, num_labels, height, width)
  742. >>> logits = outputs.logits
  743. ```"""
  744. output_hidden_states = (
  745. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  746. )
  747. return_dict = return_dict if return_dict is not None else self.config.return_dict
  748. if labels is not None and self.config.num_labels == 1:
  749. raise ValueError("The number of labels should be greater than one")
  750. outputs = self.mobilevit(
  751. pixel_values,
  752. output_hidden_states=True, # we need the intermediate hidden states
  753. return_dict=return_dict,
  754. )
  755. encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
  756. logits = self.segmentation_head(encoder_hidden_states)
  757. loss = None
  758. if labels is not None:
  759. # upsample logits to the images' original size
  760. upsampled_logits = nn.functional.interpolate(
  761. logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
  762. )
  763. loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
  764. loss = loss_fct(upsampled_logits, labels)
  765. if not return_dict:
  766. if output_hidden_states:
  767. output = (logits,) + outputs[1:]
  768. else:
  769. output = (logits,) + outputs[2:]
  770. return ((loss,) + output) if loss is not None else output
  771. return SemanticSegmenterOutput(
  772. loss=loss,
  773. logits=logits,
  774. hidden_states=outputs.hidden_states if output_hidden_states else None,
  775. attentions=None,
  776. )
# Names exported by this module (what `from <module> import *` exposes).
__all__ = [
    "MobileViTForImageClassification",
    "MobileViTForSemanticSegmentation",
    "MobileViTModel",
    "MobileViTPreTrainedModel",
]