modeling_cvt.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641
  1. # Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch CvT model."""
  15. import collections.abc
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ... import initialization as init
  21. from ...modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput
  22. from ...modeling_utils import PreTrainedModel
  23. from ...utils import auto_docstring, logging
  24. from .configuration_cvt import CvtConfig
# Module-level logger obtained from the library's logging helper, keyed by this module's name.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for model's outputs, with potential hidden states and attentions.
    """
)
class BaseModelOutputWithCLSToken(ModelOutput):
    r"""
    cls_token_value (`torch.FloatTensor` of shape `(batch_size, 1, hidden_size)`):
        Classification token at the output of the last layer of the model.
    """

    # Feature map returned by the final encoder stage (see `CvtEncoder.forward`).
    last_hidden_state: torch.FloatTensor | None = None
    # CLS token returned by the final stage; stays `None` when that stage has no CLS token.
    cls_token_value: torch.FloatTensor | None = None
    # One entry per encoder stage, only populated when `output_hidden_states` is requested.
    # NOTE(review): the intro above mentions attentions, but this class defines no attentions field.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
  40. # Copied from transformers.models.beit.modeling_beit.drop_path
  41. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  42. """
  43. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  44. """
  45. if drop_prob == 0.0 or not training:
  46. return input
  47. keep_prob = 1 - drop_prob
  48. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  49. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  50. random_tensor.floor_() # binarize
  51. output = input.div(keep_prob) * random_tensor
  52. return output
  53. # Copied from transformers.models.beit.modeling_beit.BeitDropPath
  54. class CvtDropPath(nn.Module):
  55. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  56. def __init__(self, drop_prob: float | None = None) -> None:
  57. super().__init__()
  58. self.drop_prob = drop_prob
  59. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  60. return drop_path(hidden_states, self.drop_prob, self.training)
  61. def extra_repr(self) -> str:
  62. return f"p={self.drop_prob}"
  63. class CvtEmbeddings(nn.Module):
  64. """
  65. Construct the CvT embeddings.
  66. """
  67. def __init__(self, patch_size, num_channels, embed_dim, stride, padding, dropout_rate):
  68. super().__init__()
  69. self.convolution_embeddings = CvtConvEmbeddings(
  70. patch_size=patch_size, num_channels=num_channels, embed_dim=embed_dim, stride=stride, padding=padding
  71. )
  72. self.dropout = nn.Dropout(dropout_rate)
  73. def forward(self, pixel_values):
  74. hidden_state = self.convolution_embeddings(pixel_values)
  75. hidden_state = self.dropout(hidden_state)
  76. return hidden_state
  77. class CvtConvEmbeddings(nn.Module):
  78. """
  79. Image to Conv Embedding.
  80. """
  81. def __init__(self, patch_size, num_channels, embed_dim, stride, padding):
  82. super().__init__()
  83. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  84. self.patch_size = patch_size
  85. self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=stride, padding=padding)
  86. self.normalization = nn.LayerNorm(embed_dim)
  87. def forward(self, pixel_values):
  88. pixel_values = self.projection(pixel_values)
  89. batch_size, num_channels, height, width = pixel_values.shape
  90. hidden_size = height * width
  91. # rearrange "b c h w -> b (h w) c"
  92. pixel_values = pixel_values.view(batch_size, num_channels, hidden_size).permute(0, 2, 1)
  93. if self.normalization:
  94. pixel_values = self.normalization(pixel_values)
  95. # rearrange "b (h w) c" -> b c h w"
  96. pixel_values = pixel_values.permute(0, 2, 1).view(batch_size, num_channels, height, width)
  97. return pixel_values
  98. class CvtSelfAttentionConvProjection(nn.Module):
  99. def __init__(self, embed_dim, kernel_size, padding, stride):
  100. super().__init__()
  101. self.convolution = nn.Conv2d(
  102. embed_dim,
  103. embed_dim,
  104. kernel_size=kernel_size,
  105. padding=padding,
  106. stride=stride,
  107. bias=False,
  108. groups=embed_dim,
  109. )
  110. self.normalization = nn.BatchNorm2d(embed_dim)
  111. def forward(self, hidden_state):
  112. hidden_state = self.convolution(hidden_state)
  113. hidden_state = self.normalization(hidden_state)
  114. return hidden_state
  115. class CvtSelfAttentionLinearProjection(nn.Module):
  116. def forward(self, hidden_state):
  117. batch_size, num_channels, height, width = hidden_state.shape
  118. hidden_size = height * width
  119. # rearrange " b c h w -> b (h w) c"
  120. hidden_state = hidden_state.view(batch_size, num_channels, hidden_size).permute(0, 2, 1)
  121. return hidden_state
  122. class CvtSelfAttentionProjection(nn.Module):
  123. def __init__(self, embed_dim, kernel_size, padding, stride, projection_method="dw_bn"):
  124. super().__init__()
  125. if projection_method == "dw_bn":
  126. self.convolution_projection = CvtSelfAttentionConvProjection(embed_dim, kernel_size, padding, stride)
  127. self.linear_projection = CvtSelfAttentionLinearProjection()
  128. def forward(self, hidden_state):
  129. hidden_state = self.convolution_projection(hidden_state)
  130. hidden_state = self.linear_projection(hidden_state)
  131. return hidden_state
  132. class CvtSelfAttention(nn.Module):
  133. def __init__(
  134. self,
  135. num_heads,
  136. embed_dim,
  137. kernel_size,
  138. padding_q,
  139. padding_kv,
  140. stride_q,
  141. stride_kv,
  142. qkv_projection_method,
  143. qkv_bias,
  144. attention_drop_rate,
  145. with_cls_token=True,
  146. **kwargs,
  147. ):
  148. super().__init__()
  149. self.scale = embed_dim**-0.5
  150. self.with_cls_token = with_cls_token
  151. self.embed_dim = embed_dim
  152. self.num_heads = num_heads
  153. self.convolution_projection_query = CvtSelfAttentionProjection(
  154. embed_dim,
  155. kernel_size,
  156. padding_q,
  157. stride_q,
  158. projection_method="linear" if qkv_projection_method == "avg" else qkv_projection_method,
  159. )
  160. self.convolution_projection_key = CvtSelfAttentionProjection(
  161. embed_dim, kernel_size, padding_kv, stride_kv, projection_method=qkv_projection_method
  162. )
  163. self.convolution_projection_value = CvtSelfAttentionProjection(
  164. embed_dim, kernel_size, padding_kv, stride_kv, projection_method=qkv_projection_method
  165. )
  166. self.projection_query = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
  167. self.projection_key = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
  168. self.projection_value = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
  169. self.dropout = nn.Dropout(attention_drop_rate)
  170. def rearrange_for_multi_head_attention(self, hidden_state):
  171. batch_size, hidden_size, _ = hidden_state.shape
  172. head_dim = self.embed_dim // self.num_heads
  173. # rearrange 'b t (h d) -> b h t d'
  174. return hidden_state.view(batch_size, hidden_size, self.num_heads, head_dim).permute(0, 2, 1, 3)
  175. def forward(self, hidden_state, height, width):
  176. if self.with_cls_token:
  177. cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1)
  178. batch_size, hidden_size, num_channels = hidden_state.shape
  179. # rearrange "b (h w) c -> b c h w"
  180. hidden_state = hidden_state.permute(0, 2, 1).view(batch_size, num_channels, height, width)
  181. key = self.convolution_projection_key(hidden_state)
  182. query = self.convolution_projection_query(hidden_state)
  183. value = self.convolution_projection_value(hidden_state)
  184. if self.with_cls_token:
  185. query = torch.cat((cls_token, query), dim=1)
  186. key = torch.cat((cls_token, key), dim=1)
  187. value = torch.cat((cls_token, value), dim=1)
  188. head_dim = self.embed_dim // self.num_heads
  189. query = self.rearrange_for_multi_head_attention(self.projection_query(query))
  190. key = self.rearrange_for_multi_head_attention(self.projection_key(key))
  191. value = self.rearrange_for_multi_head_attention(self.projection_value(value))
  192. attention_score = torch.einsum("bhlk,bhtk->bhlt", [query, key]) * self.scale
  193. attention_probs = torch.nn.functional.softmax(attention_score, dim=-1)
  194. attention_probs = self.dropout(attention_probs)
  195. context = torch.einsum("bhlt,bhtv->bhlv", [attention_probs, value])
  196. # rearrange"b h t d -> b t (h d)"
  197. _, _, hidden_size, _ = context.shape
  198. context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, hidden_size, self.num_heads * head_dim)
  199. return context
  200. class CvtSelfOutput(nn.Module):
  201. """
  202. The residual connection is defined in CvtLayer instead of here (as is the case with other models), due to the
  203. layernorm applied before each block.
  204. """
  205. def __init__(self, embed_dim, drop_rate):
  206. super().__init__()
  207. self.dense = nn.Linear(embed_dim, embed_dim)
  208. self.dropout = nn.Dropout(drop_rate)
  209. def forward(self, hidden_state, input_tensor):
  210. hidden_state = self.dense(hidden_state)
  211. hidden_state = self.dropout(hidden_state)
  212. return hidden_state
  213. class CvtAttention(nn.Module):
  214. def __init__(
  215. self,
  216. num_heads,
  217. embed_dim,
  218. kernel_size,
  219. padding_q,
  220. padding_kv,
  221. stride_q,
  222. stride_kv,
  223. qkv_projection_method,
  224. qkv_bias,
  225. attention_drop_rate,
  226. drop_rate,
  227. with_cls_token=True,
  228. ):
  229. super().__init__()
  230. self.attention = CvtSelfAttention(
  231. num_heads,
  232. embed_dim,
  233. kernel_size,
  234. padding_q,
  235. padding_kv,
  236. stride_q,
  237. stride_kv,
  238. qkv_projection_method,
  239. qkv_bias,
  240. attention_drop_rate,
  241. with_cls_token,
  242. )
  243. self.output = CvtSelfOutput(embed_dim, drop_rate)
  244. def forward(self, hidden_state, height, width):
  245. self_output = self.attention(hidden_state, height, width)
  246. attention_output = self.output(self_output, hidden_state)
  247. return attention_output
  248. class CvtIntermediate(nn.Module):
  249. def __init__(self, embed_dim, mlp_ratio):
  250. super().__init__()
  251. self.dense = nn.Linear(embed_dim, int(embed_dim * mlp_ratio))
  252. self.activation = nn.GELU()
  253. def forward(self, hidden_state):
  254. hidden_state = self.dense(hidden_state)
  255. hidden_state = self.activation(hidden_state)
  256. return hidden_state
  257. class CvtOutput(nn.Module):
  258. def __init__(self, embed_dim, mlp_ratio, drop_rate):
  259. super().__init__()
  260. self.dense = nn.Linear(int(embed_dim * mlp_ratio), embed_dim)
  261. self.dropout = nn.Dropout(drop_rate)
  262. def forward(self, hidden_state, input_tensor):
  263. hidden_state = self.dense(hidden_state)
  264. hidden_state = self.dropout(hidden_state)
  265. hidden_state = hidden_state + input_tensor
  266. return hidden_state
  267. class CvtLayer(nn.Module):
  268. """
  269. CvtLayer composed by attention layers, normalization and multi-layer perceptrons (mlps).
  270. """
  271. def __init__(
  272. self,
  273. num_heads,
  274. embed_dim,
  275. kernel_size,
  276. padding_q,
  277. padding_kv,
  278. stride_q,
  279. stride_kv,
  280. qkv_projection_method,
  281. qkv_bias,
  282. attention_drop_rate,
  283. drop_rate,
  284. mlp_ratio,
  285. drop_path_rate,
  286. with_cls_token=True,
  287. ):
  288. super().__init__()
  289. self.attention = CvtAttention(
  290. num_heads,
  291. embed_dim,
  292. kernel_size,
  293. padding_q,
  294. padding_kv,
  295. stride_q,
  296. stride_kv,
  297. qkv_projection_method,
  298. qkv_bias,
  299. attention_drop_rate,
  300. drop_rate,
  301. with_cls_token,
  302. )
  303. self.intermediate = CvtIntermediate(embed_dim, mlp_ratio)
  304. self.output = CvtOutput(embed_dim, mlp_ratio, drop_rate)
  305. self.drop_path = CvtDropPath(drop_prob=drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
  306. self.layernorm_before = nn.LayerNorm(embed_dim)
  307. self.layernorm_after = nn.LayerNorm(embed_dim)
  308. def forward(self, hidden_state, height, width):
  309. self_attention_output = self.attention(
  310. self.layernorm_before(hidden_state), # in Cvt, layernorm is applied before self-attention
  311. height,
  312. width,
  313. )
  314. attention_output = self_attention_output
  315. attention_output = self.drop_path(attention_output)
  316. # first residual connection
  317. hidden_state = attention_output + hidden_state
  318. # in Cvt, layernorm is also applied after self-attention
  319. layer_output = self.layernorm_after(hidden_state)
  320. layer_output = self.intermediate(layer_output)
  321. # second residual connection is done here
  322. layer_output = self.output(layer_output, hidden_state)
  323. layer_output = self.drop_path(layer_output)
  324. return layer_output
  325. class CvtStage(nn.Module):
  326. def __init__(self, config, stage):
  327. super().__init__()
  328. self.config = config
  329. self.stage = stage
  330. if self.config.cls_token[self.stage]:
  331. self.cls_token = nn.Parameter(torch.randn(1, 1, self.config.embed_dim[-1]))
  332. self.embedding = CvtEmbeddings(
  333. patch_size=config.patch_sizes[self.stage],
  334. stride=config.patch_stride[self.stage],
  335. num_channels=config.num_channels if self.stage == 0 else config.embed_dim[self.stage - 1],
  336. embed_dim=config.embed_dim[self.stage],
  337. padding=config.patch_padding[self.stage],
  338. dropout_rate=config.drop_rate[self.stage],
  339. )
  340. drop_path_rates = [
  341. x.item() for x in torch.linspace(0, config.drop_path_rate[self.stage], config.depth[stage], device="cpu")
  342. ]
  343. self.layers = nn.Sequential(
  344. *[
  345. CvtLayer(
  346. num_heads=config.num_heads[self.stage],
  347. embed_dim=config.embed_dim[self.stage],
  348. kernel_size=config.kernel_qkv[self.stage],
  349. padding_q=config.padding_q[self.stage],
  350. padding_kv=config.padding_kv[self.stage],
  351. stride_kv=config.stride_kv[self.stage],
  352. stride_q=config.stride_q[self.stage],
  353. qkv_projection_method=config.qkv_projection_method[self.stage],
  354. qkv_bias=config.qkv_bias[self.stage],
  355. attention_drop_rate=config.attention_drop_rate[self.stage],
  356. drop_rate=config.drop_rate[self.stage],
  357. drop_path_rate=drop_path_rates[self.stage],
  358. mlp_ratio=config.mlp_ratio[self.stage],
  359. with_cls_token=config.cls_token[self.stage],
  360. )
  361. for _ in range(config.depth[self.stage])
  362. ]
  363. )
  364. def forward(self, hidden_state):
  365. cls_token = None
  366. hidden_state = self.embedding(hidden_state)
  367. batch_size, num_channels, height, width = hidden_state.shape
  368. # rearrange b c h w -> b (h w) c"
  369. hidden_state = hidden_state.view(batch_size, num_channels, height * width).permute(0, 2, 1)
  370. if self.config.cls_token[self.stage]:
  371. cls_token = self.cls_token.expand(batch_size, -1, -1)
  372. hidden_state = torch.cat((cls_token, hidden_state), dim=1)
  373. for layer in self.layers:
  374. layer_outputs = layer(hidden_state, height, width)
  375. hidden_state = layer_outputs
  376. if self.config.cls_token[self.stage]:
  377. cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1)
  378. hidden_state = hidden_state.permute(0, 2, 1).view(batch_size, num_channels, height, width)
  379. return hidden_state, cls_token
  380. class CvtEncoder(nn.Module):
  381. def __init__(self, config):
  382. super().__init__()
  383. self.config = config
  384. self.stages = nn.ModuleList([])
  385. for stage_idx in range(len(config.depth)):
  386. self.stages.append(CvtStage(config, stage_idx))
  387. def forward(self, pixel_values, output_hidden_states=False, return_dict=True):
  388. all_hidden_states = () if output_hidden_states else None
  389. hidden_state = pixel_values
  390. cls_token = None
  391. for _, (stage_module) in enumerate(self.stages):
  392. hidden_state, cls_token = stage_module(hidden_state)
  393. if output_hidden_states:
  394. all_hidden_states = all_hidden_states + (hidden_state,)
  395. if not return_dict:
  396. return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None)
  397. return BaseModelOutputWithCLSToken(
  398. last_hidden_state=hidden_state,
  399. cls_token_value=cls_token,
  400. hidden_states=all_hidden_states,
  401. )
  402. @auto_docstring
  403. class CvtPreTrainedModel(PreTrainedModel):
  404. config: CvtConfig
  405. base_model_prefix = "cvt"
  406. main_input_name = "pixel_values"
  407. _no_split_modules = ["CvtLayer"]
  408. @torch.no_grad()
  409. def _init_weights(self, module):
  410. """Initialize the weights"""
  411. if isinstance(module, (nn.Linear, nn.Conv2d)):
  412. init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  413. if module.bias is not None:
  414. init.zeros_(module.bias)
  415. elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
  416. init.zeros_(module.bias)
  417. init.ones_(module.weight)
  418. if getattr(module, "running_mean", None) is not None:
  419. init.zeros_(module.running_mean)
  420. init.ones_(module.running_var)
  421. init.zeros_(module.num_batches_tracked)
  422. elif isinstance(module, CvtStage):
  423. if self.config.cls_token[module.stage]:
  424. init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
  425. @auto_docstring
  426. class CvtModel(CvtPreTrainedModel):
  427. def __init__(self, config, add_pooling_layer=True):
  428. r"""
  429. add_pooling_layer (bool, *optional*, defaults to `True`):
  430. Whether to add a pooling layer
  431. """
  432. super().__init__(config)
  433. self.config = config
  434. self.encoder = CvtEncoder(config)
  435. self.post_init()
  436. @auto_docstring
  437. def forward(
  438. self,
  439. pixel_values: torch.Tensor | None = None,
  440. output_hidden_states: bool | None = None,
  441. return_dict: bool | None = None,
  442. **kwargs,
  443. ) -> tuple | BaseModelOutputWithCLSToken:
  444. output_hidden_states = (
  445. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  446. )
  447. return_dict = return_dict if return_dict is not None else self.config.return_dict
  448. if pixel_values is None:
  449. raise ValueError("You have to specify pixel_values")
  450. encoder_outputs = self.encoder(
  451. pixel_values,
  452. output_hidden_states=output_hidden_states,
  453. return_dict=return_dict,
  454. )
  455. sequence_output = encoder_outputs[0]
  456. if not return_dict:
  457. return (sequence_output,) + encoder_outputs[1:]
  458. return BaseModelOutputWithCLSToken(
  459. last_hidden_state=sequence_output,
  460. cls_token_value=encoder_outputs.cls_token_value,
  461. hidden_states=encoder_outputs.hidden_states,
  462. )
  463. @auto_docstring(
  464. custom_intro="""
  465. Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
  466. the [CLS] token) e.g. for ImageNet.
  467. """
  468. )
  469. class CvtForImageClassification(CvtPreTrainedModel):
  470. def __init__(self, config):
  471. super().__init__(config)
  472. self.num_labels = config.num_labels
  473. self.cvt = CvtModel(config, add_pooling_layer=False)
  474. self.layernorm = nn.LayerNorm(config.embed_dim[-1])
  475. # Classifier head
  476. self.classifier = (
  477. nn.Linear(config.embed_dim[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
  478. )
  479. # Initialize weights and apply final processing
  480. self.post_init()
  481. @auto_docstring
  482. def forward(
  483. self,
  484. pixel_values: torch.Tensor | None = None,
  485. labels: torch.Tensor | None = None,
  486. output_hidden_states: bool | None = None,
  487. return_dict: bool | None = None,
  488. **kwargs,
  489. ) -> tuple | ImageClassifierOutputWithNoAttention:
  490. r"""
  491. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  492. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  493. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  494. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  495. """
  496. return_dict = return_dict if return_dict is not None else self.config.return_dict
  497. outputs = self.cvt(
  498. pixel_values,
  499. output_hidden_states=output_hidden_states,
  500. return_dict=return_dict,
  501. )
  502. sequence_output = outputs[0]
  503. cls_token = outputs[1]
  504. if self.config.cls_token[-1]:
  505. sequence_output = self.layernorm(cls_token)
  506. else:
  507. batch_size, num_channels, height, width = sequence_output.shape
  508. # rearrange "b c h w -> b (h w) c"
  509. sequence_output = sequence_output.view(batch_size, num_channels, height * width).permute(0, 2, 1)
  510. sequence_output = self.layernorm(sequence_output)
  511. sequence_output_mean = sequence_output.mean(dim=1)
  512. logits = self.classifier(sequence_output_mean)
  513. loss = None
  514. if labels is not None:
  515. if self.config.problem_type is None:
  516. if self.config.num_labels == 1:
  517. self.config.problem_type = "regression"
  518. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  519. self.config.problem_type = "single_label_classification"
  520. else:
  521. self.config.problem_type = "multi_label_classification"
  522. if self.config.problem_type == "regression":
  523. loss_fct = MSELoss()
  524. if self.config.num_labels == 1:
  525. loss = loss_fct(logits.squeeze(), labels.squeeze())
  526. else:
  527. loss = loss_fct(logits, labels)
  528. elif self.config.problem_type == "single_label_classification":
  529. loss_fct = CrossEntropyLoss()
  530. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  531. elif self.config.problem_type == "multi_label_classification":
  532. loss_fct = BCEWithLogitsLoss()
  533. loss = loss_fct(logits, labels)
  534. if not return_dict:
  535. output = (logits,) + outputs[2:]
  536. return ((loss,) + output) if loss is not None else output
  537. return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
# Public names exported by this module.
__all__ = ["CvtForImageClassification", "CvtModel", "CvtPreTrainedModel"]