modeling_levit.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665
  1. # Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch LeViT model."""
  15. import itertools
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from ... import initialization as init
  20. from ...modeling_outputs import (
  21. BaseModelOutputWithNoAttention,
  22. BaseModelOutputWithPoolingAndNoAttention,
  23. ImageClassifierOutputWithNoAttention,
  24. ModelOutput,
  25. )
  26. from ...modeling_utils import PreTrainedModel
  27. from ...utils import auto_docstring, logging
  28. from .configuration_levit import LevitConfig
  29. logger = logging.get_logger(__name__)
# Output container returned by the two-headed (classification + distillation) LeViT classifier.
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`LevitForImageClassificationWithTeacher`].
    """
)
class LevitForImageClassificationWithTeacherOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Prediction scores as the average of the `cls_logits` and `distillation_logits`.
    cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
        class token).
    distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
        distillation token).
    """

    # Every field is optional and defaults to `None`.
    logits: torch.FloatTensor | None = None
    cls_logits: torch.FloatTensor | None = None
    distillation_logits: torch.FloatTensor | None = None
    # Filled from the backbone's `hidden_states` output by the model's forward pass.
    hidden_states: tuple[torch.FloatTensor] | None = None
  51. class LevitConvEmbeddings(nn.Module):
  52. """
  53. LeViT Conv Embeddings with Batch Norm, used in the initial patch embedding layer.
  54. """
  55. def __init__(
  56. self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bn_weight_init=1
  57. ):
  58. super().__init__()
  59. self.convolution = nn.Conv2d(
  60. in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=False
  61. )
  62. self.batch_norm = nn.BatchNorm2d(out_channels)
  63. def forward(self, embeddings):
  64. embeddings = self.convolution(embeddings)
  65. embeddings = self.batch_norm(embeddings)
  66. return embeddings
  67. class LevitPatchEmbeddings(nn.Module):
  68. """
  69. LeViT patch embeddings, for final embeddings to be passed to transformer blocks. It consists of multiple
  70. `LevitConvEmbeddings`.
  71. """
  72. def __init__(self, config):
  73. super().__init__()
  74. self.embedding_layer_1 = LevitConvEmbeddings(
  75. config.num_channels, config.hidden_sizes[0] // 8, config.kernel_size, config.stride, config.padding
  76. )
  77. self.activation_layer_1 = nn.Hardswish()
  78. self.embedding_layer_2 = LevitConvEmbeddings(
  79. config.hidden_sizes[0] // 8, config.hidden_sizes[0] // 4, config.kernel_size, config.stride, config.padding
  80. )
  81. self.activation_layer_2 = nn.Hardswish()
  82. self.embedding_layer_3 = LevitConvEmbeddings(
  83. config.hidden_sizes[0] // 4, config.hidden_sizes[0] // 2, config.kernel_size, config.stride, config.padding
  84. )
  85. self.activation_layer_3 = nn.Hardswish()
  86. self.embedding_layer_4 = LevitConvEmbeddings(
  87. config.hidden_sizes[0] // 2, config.hidden_sizes[0], config.kernel_size, config.stride, config.padding
  88. )
  89. self.num_channels = config.num_channels
  90. def forward(self, pixel_values):
  91. num_channels = pixel_values.shape[1]
  92. if num_channels != self.num_channels:
  93. raise ValueError(
  94. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  95. )
  96. embeddings = self.embedding_layer_1(pixel_values)
  97. embeddings = self.activation_layer_1(embeddings)
  98. embeddings = self.embedding_layer_2(embeddings)
  99. embeddings = self.activation_layer_2(embeddings)
  100. embeddings = self.embedding_layer_3(embeddings)
  101. embeddings = self.activation_layer_3(embeddings)
  102. embeddings = self.embedding_layer_4(embeddings)
  103. return embeddings.flatten(2).transpose(1, 2)
  104. class MLPLayerWithBN(nn.Module):
  105. def __init__(self, input_dim, output_dim, bn_weight_init=1):
  106. super().__init__()
  107. self.linear = nn.Linear(in_features=input_dim, out_features=output_dim, bias=False)
  108. self.batch_norm = nn.BatchNorm1d(output_dim)
  109. def forward(self, hidden_state):
  110. hidden_state = self.linear(hidden_state)
  111. hidden_state = self.batch_norm(hidden_state.flatten(0, 1)).reshape_as(hidden_state)
  112. return hidden_state
  113. class LevitSubsample(nn.Module):
  114. def __init__(self, stride, resolution):
  115. super().__init__()
  116. self.stride = stride
  117. self.resolution = resolution
  118. def forward(self, hidden_state):
  119. batch_size, _, channels = hidden_state.shape
  120. hidden_state = hidden_state.view(batch_size, self.resolution, self.resolution, channels)[
  121. :, :: self.stride, :: self.stride
  122. ].reshape(batch_size, -1, channels)
  123. return hidden_state
  124. class LevitAttention(nn.Module):
  125. def __init__(self, hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution):
  126. super().__init__()
  127. self.num_attention_heads = num_attention_heads
  128. self.scale = key_dim**-0.5
  129. self.key_dim = key_dim
  130. self.attention_ratio = attention_ratio
  131. self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads * 2
  132. self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
  133. self.queries_keys_values = MLPLayerWithBN(hidden_sizes, self.out_dim_keys_values)
  134. self.activation = nn.Hardswish()
  135. self.projection = MLPLayerWithBN(self.out_dim_projection, hidden_sizes, bn_weight_init=0)
  136. points = list(itertools.product(range(resolution), range(resolution)))
  137. len_points = len(points)
  138. self.len_points = len_points
  139. attention_offsets, indices = {}, []
  140. for p1 in points:
  141. for p2 in points:
  142. offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
  143. if offset not in attention_offsets:
  144. attention_offsets[offset] = len(attention_offsets)
  145. indices.append(attention_offsets[offset])
  146. self.indices = indices
  147. self.attention_bias_cache = {}
  148. self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
  149. self.register_buffer(
  150. "attention_bias_idxs", torch.LongTensor(indices).view(len_points, len_points), persistent=False
  151. )
  152. @torch.no_grad()
  153. def train(self, mode=True):
  154. super().train(mode)
  155. if mode and self.attention_bias_cache:
  156. self.attention_bias_cache = {} # clear ab cache
  157. def get_attention_biases(self, device):
  158. if self.training:
  159. return self.attention_biases[:, self.attention_bias_idxs]
  160. else:
  161. device_key = str(device)
  162. if device_key not in self.attention_bias_cache:
  163. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  164. return self.attention_bias_cache[device_key]
  165. def forward(self, hidden_state):
  166. batch_size, seq_length, _ = hidden_state.shape
  167. queries_keys_values = self.queries_keys_values(hidden_state)
  168. query, key, value = queries_keys_values.view(batch_size, seq_length, self.num_attention_heads, -1).split(
  169. [self.key_dim, self.key_dim, self.attention_ratio * self.key_dim], dim=3
  170. )
  171. query = query.permute(0, 2, 1, 3)
  172. key = key.permute(0, 2, 1, 3)
  173. value = value.permute(0, 2, 1, 3)
  174. attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
  175. attention = attention.softmax(dim=-1)
  176. hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, seq_length, self.out_dim_projection)
  177. hidden_state = self.projection(self.activation(hidden_state))
  178. return hidden_state
  179. class LevitAttentionSubsample(nn.Module):
  180. def __init__(
  181. self,
  182. input_dim,
  183. output_dim,
  184. key_dim,
  185. num_attention_heads,
  186. attention_ratio,
  187. stride,
  188. resolution_in,
  189. resolution_out,
  190. ):
  191. super().__init__()
  192. self.num_attention_heads = num_attention_heads
  193. self.scale = key_dim**-0.5
  194. self.key_dim = key_dim
  195. self.attention_ratio = attention_ratio
  196. self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads
  197. self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
  198. self.resolution_out = resolution_out
  199. # resolution_in is the initial resolution, resolution_out is final resolution after downsampling
  200. self.keys_values = MLPLayerWithBN(input_dim, self.out_dim_keys_values)
  201. self.queries_subsample = LevitSubsample(stride, resolution_in)
  202. self.queries = MLPLayerWithBN(input_dim, key_dim * num_attention_heads)
  203. self.activation = nn.Hardswish()
  204. self.projection = MLPLayerWithBN(self.out_dim_projection, output_dim)
  205. self.attention_bias_cache = {}
  206. points = list(itertools.product(range(resolution_in), range(resolution_in)))
  207. points_ = list(itertools.product(range(resolution_out), range(resolution_out)))
  208. len_points, len_points_ = len(points), len(points_)
  209. self.len_points_ = len_points_
  210. self.len_points = len_points
  211. attention_offsets, indices = {}, []
  212. for p1 in points_:
  213. for p2 in points:
  214. size = 1
  215. offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), abs(p1[1] * stride - p2[1] + (size - 1) / 2))
  216. if offset not in attention_offsets:
  217. attention_offsets[offset] = len(attention_offsets)
  218. indices.append(attention_offsets[offset])
  219. self.indices = indices
  220. self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
  221. self.register_buffer(
  222. "attention_bias_idxs", torch.LongTensor(indices).view(len_points_, len_points), persistent=False
  223. )
  224. @torch.no_grad()
  225. def train(self, mode=True):
  226. super().train(mode)
  227. if mode and self.attention_bias_cache:
  228. self.attention_bias_cache = {} # clear ab cache
  229. def get_attention_biases(self, device):
  230. if self.training:
  231. return self.attention_biases[:, self.attention_bias_idxs]
  232. else:
  233. device_key = str(device)
  234. if device_key not in self.attention_bias_cache:
  235. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  236. return self.attention_bias_cache[device_key]
  237. def forward(self, hidden_state):
  238. batch_size, seq_length, _ = hidden_state.shape
  239. key, value = (
  240. self.keys_values(hidden_state)
  241. .view(batch_size, seq_length, self.num_attention_heads, -1)
  242. .split([self.key_dim, self.attention_ratio * self.key_dim], dim=3)
  243. )
  244. key = key.permute(0, 2, 1, 3)
  245. value = value.permute(0, 2, 1, 3)
  246. query = self.queries(self.queries_subsample(hidden_state))
  247. query = query.view(batch_size, self.resolution_out**2, self.num_attention_heads, self.key_dim).permute(
  248. 0, 2, 1, 3
  249. )
  250. attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
  251. attention = attention.softmax(dim=-1)
  252. hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, -1, self.out_dim_projection)
  253. hidden_state = self.projection(self.activation(hidden_state))
  254. return hidden_state
  255. class LevitMLPLayer(nn.Module):
  256. """
  257. MLP Layer with `2X` expansion in contrast to ViT with `4X`.
  258. """
  259. def __init__(self, input_dim, hidden_dim):
  260. super().__init__()
  261. self.linear_up = MLPLayerWithBN(input_dim, hidden_dim)
  262. self.activation = nn.Hardswish()
  263. self.linear_down = MLPLayerWithBN(hidden_dim, input_dim)
  264. def forward(self, hidden_state):
  265. hidden_state = self.linear_up(hidden_state)
  266. hidden_state = self.activation(hidden_state)
  267. hidden_state = self.linear_down(hidden_state)
  268. return hidden_state
  269. class LevitResidualLayer(nn.Module):
  270. """
  271. Residual Block for LeViT
  272. """
  273. def __init__(self, module, drop_rate):
  274. super().__init__()
  275. self.module = module
  276. self.drop_rate = drop_rate
  277. def forward(self, hidden_state):
  278. if self.training and self.drop_rate > 0:
  279. rnd = torch.rand(hidden_state.size(0), 1, 1, device=hidden_state.device)
  280. rnd = rnd.ge_(self.drop_rate).div(1 - self.drop_rate).detach()
  281. hidden_state = hidden_state + self.module(hidden_state) * rnd
  282. return hidden_state
  283. else:
  284. hidden_state = hidden_state + self.module(hidden_state)
  285. return hidden_state
  286. class LevitStage(nn.Module):
  287. """
  288. LeViT Stage consisting of `LevitMLPLayer` and `LevitAttention` layers.
  289. """
  290. def __init__(
  291. self,
  292. config,
  293. idx,
  294. hidden_sizes,
  295. key_dim,
  296. depths,
  297. num_attention_heads,
  298. attention_ratio,
  299. mlp_ratio,
  300. down_ops,
  301. resolution_in,
  302. ):
  303. super().__init__()
  304. self.layers = []
  305. self.config = config
  306. self.resolution_in = resolution_in
  307. # resolution_in is the initial resolution, resolution_out is final resolution after downsampling
  308. for _ in range(depths):
  309. self.layers.append(
  310. LevitResidualLayer(
  311. LevitAttention(hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution_in),
  312. self.config.drop_path_rate,
  313. )
  314. )
  315. if mlp_ratio > 0:
  316. hidden_dim = hidden_sizes * mlp_ratio
  317. self.layers.append(
  318. LevitResidualLayer(LevitMLPLayer(hidden_sizes, hidden_dim), self.config.drop_path_rate)
  319. )
  320. if down_ops[0] == "Subsample":
  321. self.resolution_out = (self.resolution_in - 1) // down_ops[5] + 1
  322. self.layers.append(
  323. LevitAttentionSubsample(
  324. *self.config.hidden_sizes[idx : idx + 2],
  325. key_dim=down_ops[1],
  326. num_attention_heads=down_ops[2],
  327. attention_ratio=down_ops[3],
  328. stride=down_ops[5],
  329. resolution_in=resolution_in,
  330. resolution_out=self.resolution_out,
  331. )
  332. )
  333. self.resolution_in = self.resolution_out
  334. if down_ops[4] > 0:
  335. hidden_dim = self.config.hidden_sizes[idx + 1] * down_ops[4]
  336. self.layers.append(
  337. LevitResidualLayer(
  338. LevitMLPLayer(self.config.hidden_sizes[idx + 1], hidden_dim), self.config.drop_path_rate
  339. )
  340. )
  341. self.layers = nn.ModuleList(self.layers)
  342. def get_resolution(self):
  343. return self.resolution_in
  344. def forward(self, hidden_state):
  345. for layer in self.layers:
  346. hidden_state = layer(hidden_state)
  347. return hidden_state
  348. class LevitEncoder(nn.Module):
  349. """
  350. LeViT Encoder consisting of multiple `LevitStage` stages.
  351. """
  352. def __init__(self, config):
  353. super().__init__()
  354. self.config = config
  355. resolution = self.config.image_size // self.config.patch_size
  356. self.stages = []
  357. self.config.down_ops.append([""])
  358. for stage_idx in range(len(config.depths)):
  359. stage = LevitStage(
  360. config,
  361. stage_idx,
  362. config.hidden_sizes[stage_idx],
  363. config.key_dim[stage_idx],
  364. config.depths[stage_idx],
  365. config.num_attention_heads[stage_idx],
  366. config.attention_ratio[stage_idx],
  367. config.mlp_ratio[stage_idx],
  368. config.down_ops[stage_idx],
  369. resolution,
  370. )
  371. resolution = stage.get_resolution()
  372. self.stages.append(stage)
  373. self.stages = nn.ModuleList(self.stages)
  374. def forward(self, hidden_state, output_hidden_states=False, return_dict=True):
  375. all_hidden_states = () if output_hidden_states else None
  376. for stage in self.stages:
  377. if output_hidden_states:
  378. all_hidden_states = all_hidden_states + (hidden_state,)
  379. hidden_state = stage(hidden_state)
  380. if output_hidden_states:
  381. all_hidden_states = all_hidden_states + (hidden_state,)
  382. if not return_dict:
  383. return tuple(v for v in [hidden_state, all_hidden_states] if v is not None)
  384. return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states)
  385. class LevitClassificationLayer(nn.Module):
  386. """
  387. LeViT Classification Layer
  388. """
  389. def __init__(self, input_dim, output_dim):
  390. super().__init__()
  391. self.batch_norm = nn.BatchNorm1d(input_dim)
  392. self.linear = nn.Linear(input_dim, output_dim)
  393. def forward(self, hidden_state):
  394. hidden_state = self.batch_norm(hidden_state)
  395. logits = self.linear(hidden_state)
  396. return logits
  397. @auto_docstring
  398. class LevitPreTrainedModel(PreTrainedModel):
  399. config: LevitConfig
  400. base_model_prefix = "levit"
  401. main_input_name = "pixel_values"
  402. input_modalities = ("image",)
  403. _no_split_modules = ["LevitResidualLayer"]
  404. def _init_weights(self, module):
  405. super()._init_weights(module)
  406. if isinstance(module, LevitAttention):
  407. init.copy_(
  408. module.attention_bias_idxs, torch.LongTensor(module.indices).view(module.len_points, module.len_points)
  409. )
  410. elif isinstance(module, LevitAttentionSubsample):
  411. init.copy_(
  412. module.attention_bias_idxs,
  413. torch.LongTensor(module.indices).view(module.len_points_, module.len_points),
  414. )
  415. @auto_docstring
  416. class LevitModel(LevitPreTrainedModel):
  417. def __init__(self, config):
  418. super().__init__(config)
  419. self.config = config
  420. self.patch_embeddings = LevitPatchEmbeddings(config)
  421. self.encoder = LevitEncoder(config)
  422. # Initialize weights and apply final processing
  423. self.post_init()
  424. @auto_docstring
  425. def forward(
  426. self,
  427. pixel_values: torch.FloatTensor | None = None,
  428. output_hidden_states: bool | None = None,
  429. return_dict: bool | None = None,
  430. **kwargs,
  431. ) -> tuple | BaseModelOutputWithPoolingAndNoAttention:
  432. output_hidden_states = (
  433. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  434. )
  435. return_dict = return_dict if return_dict is not None else self.config.return_dict
  436. if pixel_values is None:
  437. raise ValueError("You have to specify pixel_values")
  438. embeddings = self.patch_embeddings(pixel_values)
  439. encoder_outputs = self.encoder(
  440. embeddings,
  441. output_hidden_states=output_hidden_states,
  442. return_dict=return_dict,
  443. )
  444. last_hidden_state = encoder_outputs[0]
  445. # global average pooling, (batch_size, seq_length, hidden_sizes) -> (batch_size, hidden_sizes)
  446. pooled_output = last_hidden_state.mean(dim=1)
  447. if not return_dict:
  448. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  449. return BaseModelOutputWithPoolingAndNoAttention(
  450. last_hidden_state=last_hidden_state,
  451. pooler_output=pooled_output,
  452. hidden_states=encoder_outputs.hidden_states,
  453. )
  454. @auto_docstring(
  455. custom_intro="""
  456. Levit Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  457. ImageNet.
  458. """
  459. )
  460. class LevitForImageClassification(LevitPreTrainedModel):
  461. def __init__(self, config):
  462. super().__init__(config)
  463. self.config = config
  464. self.num_labels = config.num_labels
  465. self.levit = LevitModel(config)
  466. # Classifier head
  467. self.classifier = (
  468. LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
  469. if config.num_labels > 0
  470. else torch.nn.Identity()
  471. )
  472. # Initialize weights and apply final processing
  473. self.post_init()
  474. @auto_docstring
  475. def forward(
  476. self,
  477. pixel_values: torch.FloatTensor | None = None,
  478. labels: torch.LongTensor | None = None,
  479. output_hidden_states: bool | None = None,
  480. return_dict: bool | None = None,
  481. **kwargs,
  482. ) -> tuple | ImageClassifierOutputWithNoAttention:
  483. r"""
  484. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  485. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  486. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  487. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  488. """
  489. return_dict = return_dict if return_dict is not None else self.config.return_dict
  490. outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  491. sequence_output = outputs[0]
  492. sequence_output = sequence_output.mean(1)
  493. logits = self.classifier(sequence_output)
  494. loss = None
  495. if labels is not None:
  496. loss = self.loss_function(labels, logits, self.config)
  497. if not return_dict:
  498. output = (logits,) + outputs[2:]
  499. return ((loss,) + output) if loss is not None else output
  500. return ImageClassifierOutputWithNoAttention(
  501. loss=loss,
  502. logits=logits,
  503. hidden_states=outputs.hidden_states,
  504. )
  505. @auto_docstring(
  506. custom_intro="""
  507. LeViT Model transformer with image classification heads on top (a linear layer on top of the final hidden state and
  508. a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet. .. warning::
  509. This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
  510. supported.
  511. """
  512. )
  513. class LevitForImageClassificationWithTeacher(LevitPreTrainedModel):
  514. def __init__(self, config):
  515. super().__init__(config)
  516. self.config = config
  517. self.num_labels = config.num_labels
  518. self.levit = LevitModel(config)
  519. # Classifier head
  520. self.classifier = (
  521. LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
  522. if config.num_labels > 0
  523. else torch.nn.Identity()
  524. )
  525. self.classifier_distill = (
  526. LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
  527. if config.num_labels > 0
  528. else torch.nn.Identity()
  529. )
  530. # Initialize weights and apply final processing
  531. self.post_init()
  532. @auto_docstring
  533. def forward(
  534. self,
  535. pixel_values: torch.FloatTensor | None = None,
  536. output_hidden_states: bool | None = None,
  537. return_dict: bool | None = None,
  538. **kwargs,
  539. ) -> tuple | LevitForImageClassificationWithTeacherOutput:
  540. return_dict = return_dict if return_dict is not None else self.config.return_dict
  541. outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  542. sequence_output = outputs[0]
  543. sequence_output = sequence_output.mean(1)
  544. cls_logits, distill_logits = self.classifier(sequence_output), self.classifier_distill(sequence_output)
  545. logits = (cls_logits + distill_logits) / 2
  546. if not return_dict:
  547. output = (logits, cls_logits, distill_logits) + outputs[2:]
  548. return output
  549. return LevitForImageClassificationWithTeacherOutput(
  550. logits=logits,
  551. cls_logits=cls_logits,
  552. distillation_logits=distill_logits,
  553. hidden_states=outputs.hidden_states,
  554. )
# Names exported by `from <module> import *`; the helper layers defined above are not listed.
__all__ = [
    "LevitForImageClassification",
    "LevitForImageClassificationWithTeacher",
    "LevitModel",
    "LevitPreTrainedModel",
]