modeling_dinat.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805
  1. # Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Dilated Neighborhood Attention Transformer model."""
  15. import math
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from ...activations import ACT2FN
  20. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  21. from ...modeling_outputs import BackboneOutput
  22. from ...modeling_utils import PreTrainedModel
  23. from ...utils import (
  24. ModelOutput,
  25. OptionalDependencyNotAvailable,
  26. auto_docstring,
  27. is_natten_available,
  28. logging,
  29. requires_backends,
  30. )
  31. from ...utils.generic import can_return_tuple
  32. from .configuration_dinat import DinatConfig
  33. if is_natten_available():
  34. from natten.functional import natten2dav, natten2dqkrpb
  35. else:
  36. def natten2dqkrpb(*args, **kwargs):
  37. raise OptionalDependencyNotAvailable()
  38. def natten2dav(*args, **kwargs):
  39. raise OptionalDependencyNotAvailable()
  40. logger = logging.get_logger(__name__)
  41. # drop_path and DinatDropPath are from the timm library.
@dataclass
@auto_docstring(
    custom_intro="""
    Dinat encoder's outputs, with potential hidden states and attentions.
    """
)
class DinatEncoderOutput(ModelOutput):
    r"""
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Callers index this output positionally (e.g. `encoder_outputs[0]` in `DinatModel`), so field order matters.
    # Output of the final stage, channels-last `(batch, height, width, channels)`.
    last_hidden_state: torch.FloatTensor | None = None
    # Channels-last hidden states: the embedding output plus one entry per stage; None unless requested.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # One attention-probability tensor per stage (from that stage's last layer); None unless requested.
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # The same tensors as `hidden_states`, permuted to channels-first `(batch, channels, height, width)`.
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Dinat model's outputs that also contains a pooling of the last hidden states.
    """
)
class DinatModelOutput(ModelOutput):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # `DinatForImageClassification` reads this output positionally (`outputs[1]`, `outputs[2:]`),
    # so field order matters.
    # Final-stage features after `DinatModel.layernorm`, channels-last.
    last_hidden_state: torch.FloatTensor | None = None
    # Mean over all spatial positions of `last_hidden_state`; None when the model was built without a pooler.
    pooler_output: torch.FloatTensor | None = None
    # Forwarded unchanged from the encoder output.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Forwarded unchanged from the encoder output.
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # Forwarded unchanged from the encoder output (channels-first copies of `hidden_states`).
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Dinat outputs for image classification.
    """
)
class DinatImageClassifierOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Only populated when `labels` are passed to `DinatForImageClassification.forward`.
    loss: torch.FloatTensor | None = None
    # Output of the classifier head applied to the pooled features.
    logits: torch.FloatTensor | None = None
    # The remaining fields are forwarded from the underlying `DinatModel` output.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  104. class DinatEmbeddings(nn.Module):
  105. """
  106. Construct the patch and position embeddings.
  107. """
  108. def __init__(self, config):
  109. super().__init__()
  110. self.patch_embeddings = DinatPatchEmbeddings(config)
  111. self.norm = nn.LayerNorm(config.embed_dim)
  112. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  113. def forward(self, pixel_values: torch.FloatTensor | None) -> tuple[torch.Tensor]:
  114. embeddings = self.patch_embeddings(pixel_values)
  115. embeddings = self.norm(embeddings)
  116. embeddings = self.dropout(embeddings)
  117. return embeddings
  118. class DinatPatchEmbeddings(nn.Module):
  119. """
  120. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  121. `hidden_states` (patch embeddings) of shape `(batch_size, height, width, hidden_size)` to be consumed by a
  122. Transformer.
  123. """
  124. def __init__(self, config):
  125. super().__init__()
  126. patch_size = config.patch_size
  127. num_channels, hidden_size = config.num_channels, config.embed_dim
  128. self.num_channels = num_channels
  129. if patch_size == 4:
  130. pass
  131. else:
  132. # TODO: Support arbitrary patch sizes.
  133. raise ValueError("Dinat only supports patch size of 4 at the moment.")
  134. self.projection = nn.Sequential(
  135. nn.Conv2d(self.num_channels, hidden_size // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
  136. nn.Conv2d(hidden_size // 2, hidden_size, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
  137. )
  138. def forward(self, pixel_values: torch.FloatTensor | None) -> torch.Tensor:
  139. _, num_channels, height, width = pixel_values.shape
  140. if num_channels != self.num_channels:
  141. raise ValueError(
  142. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  143. )
  144. embeddings = self.projection(pixel_values)
  145. embeddings = embeddings.permute(0, 2, 3, 1)
  146. return embeddings
  147. class DinatDownsampler(nn.Module):
  148. """
  149. Convolutional Downsampling Layer.
  150. Args:
  151. dim (`int`):
  152. Number of input channels.
  153. norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
  154. Normalization layer class.
  155. """
  156. def __init__(self, dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
  157. super().__init__()
  158. self.dim = dim
  159. self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  160. self.norm = norm_layer(2 * dim)
  161. def forward(self, input_feature: torch.Tensor) -> torch.Tensor:
  162. input_feature = self.reduction(input_feature.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
  163. input_feature = self.norm(input_feature)
  164. return input_feature
  165. # Copied from transformers.models.beit.modeling_beit.drop_path
  166. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  167. """
  168. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  169. """
  170. if drop_prob == 0.0 or not training:
  171. return input
  172. keep_prob = 1 - drop_prob
  173. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  174. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  175. random_tensor.floor_() # binarize
  176. output = input.div(keep_prob) * random_tensor
  177. return output
  178. # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Dinat
  179. class DinatDropPath(nn.Module):
  180. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  181. def __init__(self, drop_prob: float | None = None) -> None:
  182. super().__init__()
  183. self.drop_prob = drop_prob
  184. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  185. return drop_path(hidden_states, self.drop_prob, self.training)
  186. def extra_repr(self) -> str:
  187. return f"p={self.drop_prob}"
  188. class NeighborhoodAttention(nn.Module):
  189. def __init__(self, config, dim, num_heads, kernel_size, dilation):
  190. super().__init__()
  191. if dim % num_heads != 0:
  192. raise ValueError(
  193. f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
  194. )
  195. self.num_attention_heads = num_heads
  196. self.attention_head_size = int(dim / num_heads)
  197. self.all_head_size = self.num_attention_heads * self.attention_head_size
  198. self.kernel_size = kernel_size
  199. self.dilation = dilation
  200. # rpb is learnable relative positional biases; same concept is used Swin.
  201. self.rpb = nn.Parameter(torch.zeros(num_heads, (2 * self.kernel_size - 1), (2 * self.kernel_size - 1)))
  202. self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  203. self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  204. self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  205. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  206. def forward(
  207. self,
  208. hidden_states: torch.Tensor,
  209. output_attentions: bool | None = False,
  210. ) -> tuple[torch.Tensor]:
  211. input_shape = hidden_states.shape[:-1]
  212. hidden_shape = (*input_shape, -1, self.attention_head_size)
  213. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  214. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  215. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  216. # Apply the scale factor before computing attention weights. It's usually more efficient because
  217. # attention weights are typically a bigger tensor compared to query.
  218. # It gives identical results because scalars are commutable in matrix multiplication.
  219. query_layer = query_layer / math.sqrt(self.attention_head_size)
  220. # Compute NA between "query" and "key" to get the raw attention scores, and add relative positional biases.
  221. attention_scores = natten2dqkrpb(query_layer, key_layer, self.rpb, self.kernel_size, self.dilation)
  222. # Normalize the attention scores to probabilities.
  223. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  224. # This is actually dropping out entire tokens to attend to, which might
  225. # seem a bit unusual, but is taken from the original Transformer paper.
  226. attention_probs = self.dropout(attention_probs)
  227. context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, self.dilation)
  228. context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
  229. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  230. context_layer = context_layer.view(new_context_layer_shape)
  231. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  232. return outputs
  233. class NeighborhoodAttentionOutput(nn.Module):
  234. def __init__(self, config, dim):
  235. super().__init__()
  236. self.dense = nn.Linear(dim, dim)
  237. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  238. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  239. hidden_states = self.dense(hidden_states)
  240. hidden_states = self.dropout(hidden_states)
  241. return hidden_states
  242. class NeighborhoodAttentionModule(nn.Module):
  243. def __init__(self, config, dim, num_heads, kernel_size, dilation):
  244. super().__init__()
  245. self.self = NeighborhoodAttention(config, dim, num_heads, kernel_size, dilation)
  246. self.output = NeighborhoodAttentionOutput(config, dim)
  247. def forward(
  248. self,
  249. hidden_states: torch.Tensor,
  250. output_attentions: bool | None = False,
  251. ) -> tuple[torch.Tensor]:
  252. self_outputs = self.self(hidden_states, output_attentions)
  253. attention_output = self.output(self_outputs[0], hidden_states)
  254. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  255. return outputs
  256. class DinatIntermediate(nn.Module):
  257. def __init__(self, config, dim):
  258. super().__init__()
  259. self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
  260. if isinstance(config.hidden_act, str):
  261. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  262. else:
  263. self.intermediate_act_fn = config.hidden_act
  264. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  265. hidden_states = self.dense(hidden_states)
  266. hidden_states = self.intermediate_act_fn(hidden_states)
  267. return hidden_states
  268. class DinatOutput(nn.Module):
  269. def __init__(self, config, dim):
  270. super().__init__()
  271. self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
  272. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  273. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  274. hidden_states = self.dense(hidden_states)
  275. hidden_states = self.dropout(hidden_states)
  276. return hidden_states
  277. class DinatLayer(nn.Module):
  278. def __init__(self, config, dim, num_heads, dilation, drop_path_rate=0.0):
  279. super().__init__()
  280. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  281. self.kernel_size = config.kernel_size
  282. self.dilation = dilation
  283. self.window_size = self.kernel_size * self.dilation
  284. self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  285. self.attention = NeighborhoodAttentionModule(
  286. config, dim, num_heads, kernel_size=self.kernel_size, dilation=self.dilation
  287. )
  288. self.drop_path = DinatDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
  289. self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  290. self.intermediate = DinatIntermediate(config, dim)
  291. self.output = DinatOutput(config, dim)
  292. self.layer_scale_parameters = (
  293. nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
  294. if config.layer_scale_init_value > 0
  295. else None
  296. )
  297. def maybe_pad(self, hidden_states, height, width):
  298. window_size = self.window_size
  299. pad_values = (0, 0, 0, 0, 0, 0)
  300. if height < window_size or width < window_size:
  301. pad_l = pad_t = 0
  302. pad_r = max(0, window_size - width)
  303. pad_b = max(0, window_size - height)
  304. pad_values = (0, 0, pad_l, pad_r, pad_t, pad_b)
  305. hidden_states = nn.functional.pad(hidden_states, pad_values)
  306. return hidden_states, pad_values
  307. def forward(
  308. self,
  309. hidden_states: torch.Tensor,
  310. output_attentions: bool | None = False,
  311. ) -> tuple[torch.Tensor, torch.Tensor]:
  312. batch_size, height, width, channels = hidden_states.size()
  313. shortcut = hidden_states
  314. hidden_states = self.layernorm_before(hidden_states)
  315. # pad hidden_states if they are smaller than kernel size x dilation
  316. hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
  317. _, height_pad, width_pad, _ = hidden_states.shape
  318. attention_outputs = self.attention(hidden_states, output_attentions=output_attentions)
  319. attention_output = attention_outputs[0]
  320. was_padded = pad_values[3] > 0 or pad_values[5] > 0
  321. if was_padded:
  322. attention_output = attention_output[:, :height, :width, :].contiguous()
  323. if self.layer_scale_parameters is not None:
  324. attention_output = self.layer_scale_parameters[0] * attention_output
  325. hidden_states = shortcut + self.drop_path(attention_output)
  326. layer_output = self.layernorm_after(hidden_states)
  327. layer_output = self.output(self.intermediate(layer_output))
  328. if self.layer_scale_parameters is not None:
  329. layer_output = self.layer_scale_parameters[1] * layer_output
  330. layer_output = hidden_states + self.drop_path(layer_output)
  331. layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
  332. return layer_outputs
  333. class DinatStage(nn.Module):
  334. def __init__(self, config, dim, depth, num_heads, dilations, drop_path_rate, downsample):
  335. super().__init__()
  336. self.config = config
  337. self.dim = dim
  338. self.layers = nn.ModuleList(
  339. [
  340. DinatLayer(
  341. config=config,
  342. dim=dim,
  343. num_heads=num_heads,
  344. dilation=dilations[i],
  345. drop_path_rate=drop_path_rate[i],
  346. )
  347. for i in range(depth)
  348. ]
  349. )
  350. # patch merging layer
  351. if downsample is not None:
  352. self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm)
  353. else:
  354. self.downsample = None
  355. self.pointing = False
  356. def forward(
  357. self,
  358. hidden_states: torch.Tensor,
  359. output_attentions: bool | None = False,
  360. ) -> tuple[torch.Tensor]:
  361. _, height, width, _ = hidden_states.size()
  362. for i, layer_module in enumerate(self.layers):
  363. layer_outputs = layer_module(hidden_states, output_attentions)
  364. hidden_states = layer_outputs[0]
  365. hidden_states_before_downsampling = hidden_states
  366. if self.downsample is not None:
  367. hidden_states = self.downsample(hidden_states_before_downsampling)
  368. stage_outputs = (hidden_states, hidden_states_before_downsampling)
  369. if output_attentions:
  370. stage_outputs += layer_outputs[1:]
  371. return stage_outputs
  372. class DinatEncoder(nn.Module):
  373. def __init__(self, config):
  374. super().__init__()
  375. self.num_levels = len(config.depths)
  376. self.config = config
  377. dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
  378. self.levels = nn.ModuleList(
  379. [
  380. DinatStage(
  381. config=config,
  382. dim=int(config.embed_dim * 2**i_layer),
  383. depth=config.depths[i_layer],
  384. num_heads=config.num_heads[i_layer],
  385. dilations=config.dilations[i_layer],
  386. drop_path_rate=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
  387. downsample=DinatDownsampler if (i_layer < self.num_levels - 1) else None,
  388. )
  389. for i_layer in range(self.num_levels)
  390. ]
  391. )
  392. def forward(
  393. self,
  394. hidden_states: torch.Tensor,
  395. output_attentions: bool | None = False,
  396. output_hidden_states: bool | None = False,
  397. output_hidden_states_before_downsampling: bool | None = False,
  398. return_dict: bool | None = True,
  399. ) -> tuple | DinatEncoderOutput:
  400. all_hidden_states = () if output_hidden_states else None
  401. all_reshaped_hidden_states = () if output_hidden_states else None
  402. all_self_attentions = () if output_attentions else None
  403. if output_hidden_states:
  404. # rearrange b h w c -> b c h w
  405. reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
  406. all_hidden_states += (hidden_states,)
  407. all_reshaped_hidden_states += (reshaped_hidden_state,)
  408. for i, layer_module in enumerate(self.levels):
  409. layer_outputs = layer_module(hidden_states, output_attentions)
  410. hidden_states = layer_outputs[0]
  411. hidden_states_before_downsampling = layer_outputs[1]
  412. if output_hidden_states and output_hidden_states_before_downsampling:
  413. # rearrange b h w c -> b c h w
  414. reshaped_hidden_state = hidden_states_before_downsampling.permute(0, 3, 1, 2)
  415. all_hidden_states += (hidden_states_before_downsampling,)
  416. all_reshaped_hidden_states += (reshaped_hidden_state,)
  417. elif output_hidden_states and not output_hidden_states_before_downsampling:
  418. # rearrange b h w c -> b c h w
  419. reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
  420. all_hidden_states += (hidden_states,)
  421. all_reshaped_hidden_states += (reshaped_hidden_state,)
  422. if output_attentions:
  423. all_self_attentions += layer_outputs[2:]
  424. if not return_dict:
  425. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  426. return DinatEncoderOutput(
  427. last_hidden_state=hidden_states,
  428. hidden_states=all_hidden_states,
  429. attentions=all_self_attentions,
  430. reshaped_hidden_states=all_reshaped_hidden_states,
  431. )
  432. @auto_docstring
  433. class DinatPreTrainedModel(PreTrainedModel):
  434. config: DinatConfig
  435. base_model_prefix = "dinat"
  436. main_input_name = "pixel_values"
  437. input_modalities = ("image",)
  438. @auto_docstring
  439. class DinatModel(DinatPreTrainedModel):
  440. def __init__(self, config, add_pooling_layer=True):
  441. r"""
  442. add_pooling_layer (bool, *optional*, defaults to `True`):
  443. Whether to add a pooling layer
  444. """
  445. super().__init__(config)
  446. requires_backends(self, ["natten"])
  447. self.config = config
  448. self.num_levels = len(config.depths)
  449. self.num_features = int(config.embed_dim * 2 ** (self.num_levels - 1))
  450. self.embeddings = DinatEmbeddings(config)
  451. self.encoder = DinatEncoder(config)
  452. self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
  453. self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
  454. # Initialize weights and apply final processing
  455. self.post_init()
  456. def get_input_embeddings(self):
  457. return self.embeddings.patch_embeddings
  458. @auto_docstring
  459. def forward(
  460. self,
  461. pixel_values: torch.FloatTensor | None = None,
  462. output_attentions: bool | None = None,
  463. output_hidden_states: bool | None = None,
  464. return_dict: bool | None = None,
  465. **kwargs,
  466. ) -> tuple | DinatModelOutput:
  467. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  468. output_hidden_states = (
  469. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  470. )
  471. return_dict = return_dict if return_dict is not None else self.config.return_dict
  472. if pixel_values is None:
  473. raise ValueError("You have to specify pixel_values")
  474. embedding_output = self.embeddings(pixel_values)
  475. encoder_outputs = self.encoder(
  476. embedding_output,
  477. output_attentions=output_attentions,
  478. output_hidden_states=output_hidden_states,
  479. return_dict=return_dict,
  480. )
  481. sequence_output = encoder_outputs[0]
  482. sequence_output = self.layernorm(sequence_output)
  483. pooled_output = None
  484. if self.pooler is not None:
  485. pooled_output = self.pooler(sequence_output.flatten(1, 2).transpose(1, 2))
  486. pooled_output = torch.flatten(pooled_output, 1)
  487. if not return_dict:
  488. output = (sequence_output, pooled_output) + encoder_outputs[1:]
  489. return output
  490. return DinatModelOutput(
  491. last_hidden_state=sequence_output,
  492. pooler_output=pooled_output,
  493. hidden_states=encoder_outputs.hidden_states,
  494. attentions=encoder_outputs.attentions,
  495. reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
  496. )
  497. @auto_docstring(
  498. custom_intro="""
  499. Dinat Model transformer with an image classification head on top (a linear layer on top of the final hidden state
  500. of the [CLS] token) e.g. for ImageNet.
  501. """
  502. )
  503. class DinatForImageClassification(DinatPreTrainedModel):
  504. def __init__(self, config):
  505. super().__init__(config)
  506. requires_backends(self, ["natten"])
  507. self.num_labels = config.num_labels
  508. self.dinat = DinatModel(config)
  509. # Classifier head
  510. self.classifier = (
  511. nn.Linear(self.dinat.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
  512. )
  513. # Initialize weights and apply final processing
  514. self.post_init()
  515. @auto_docstring
  516. def forward(
  517. self,
  518. pixel_values: torch.FloatTensor | None = None,
  519. labels: torch.LongTensor | None = None,
  520. output_attentions: bool | None = None,
  521. output_hidden_states: bool | None = None,
  522. return_dict: bool | None = None,
  523. **kwargs,
  524. ) -> tuple | DinatImageClassifierOutput:
  525. r"""
  526. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  527. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  528. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  529. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  530. """
  531. return_dict = return_dict if return_dict is not None else self.config.return_dict
  532. outputs = self.dinat(
  533. pixel_values,
  534. output_attentions=output_attentions,
  535. output_hidden_states=output_hidden_states,
  536. return_dict=return_dict,
  537. )
  538. pooled_output = outputs[1]
  539. logits = self.classifier(pooled_output)
  540. loss = None
  541. if labels is not None:
  542. loss = self.loss_function(labels, logits, self.config)
  543. if not return_dict:
  544. output = (logits,) + outputs[2:]
  545. return ((loss,) + output) if loss is not None else output
  546. return DinatImageClassifierOutput(
  547. loss=loss,
  548. logits=logits,
  549. hidden_states=outputs.hidden_states,
  550. attentions=outputs.attentions,
  551. reshaped_hidden_states=outputs.reshaped_hidden_states,
  552. )
  553. @auto_docstring(
  554. custom_intro="""
  555. NAT backbone, to be used with frameworks like DETR and MaskFormer.
  556. """
  557. )
  558. class DinatBackbone(BackboneMixin, DinatPreTrainedModel):
  559. def __init__(self, config):
  560. super().__init__(config)
  561. requires_backends(self, ["natten"])
  562. self.embeddings = DinatEmbeddings(config)
  563. self.encoder = DinatEncoder(config)
  564. self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
  565. # Add layer norms to hidden states of out_features
  566. hidden_states_norms = {}
  567. for stage, num_channels in zip(self.out_features, self.channels):
  568. hidden_states_norms[stage] = nn.LayerNorm(num_channels)
  569. self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
  570. # Initialize weights and apply final processing
  571. self.post_init()
  572. def get_input_embeddings(self):
  573. return self.embeddings.patch_embeddings
  574. @can_return_tuple
  575. @filter_output_hidden_states
  576. @auto_docstring
  577. def forward(
  578. self,
  579. pixel_values: torch.Tensor,
  580. output_hidden_states: bool | None = None,
  581. output_attentions: bool | None = None,
  582. return_dict: bool | None = None,
  583. **kwargs,
  584. ) -> BackboneOutput:
  585. r"""
  586. Examples:
  587. ```python
  588. >>> from transformers import AutoImageProcessor, AutoBackbone
  589. >>> import torch
  590. >>> from PIL import Image
  591. >>> import httpx
  592. >>> from io import BytesIO
  593. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  594. >>> with httpx.stream("GET", url) as response:
  595. ... image = Image.open(BytesIO(response.read()))
  596. >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
  597. >>> model = AutoBackbone.from_pretrained(
  598. ... "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
  599. ... )
  600. >>> inputs = processor(image, return_tensors="pt")
  601. >>> outputs = model(**inputs)
  602. >>> feature_maps = outputs.feature_maps
  603. >>> list(feature_maps[-1].shape)
  604. [1, 512, 7, 7]
  605. ```"""
  606. return_dict = return_dict if return_dict is not None else self.config.return_dict
  607. output_hidden_states = (
  608. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  609. )
  610. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  611. embedding_output = self.embeddings(pixel_values)
  612. outputs = self.encoder(
  613. embedding_output,
  614. output_attentions=output_attentions,
  615. output_hidden_states=True,
  616. output_hidden_states_before_downsampling=True,
  617. return_dict=True,
  618. )
  619. hidden_states = outputs.reshaped_hidden_states
  620. feature_maps = ()
  621. for stage, hidden_state in zip(self.stage_names, hidden_states):
  622. if stage in self.out_features:
  623. batch_size, num_channels, height, width = hidden_state.shape
  624. hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
  625. hidden_state = hidden_state.view(batch_size, height * width, num_channels)
  626. hidden_state = self.hidden_states_norms[stage](hidden_state)
  627. hidden_state = hidden_state.view(batch_size, height, width, num_channels)
  628. hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
  629. feature_maps += (hidden_state,)
  630. if not return_dict:
  631. output = (feature_maps,)
  632. if output_hidden_states:
  633. output += (outputs.hidden_states,)
  634. return output
  635. return BackboneOutput(
  636. feature_maps=feature_maps,
  637. hidden_states=outputs.hidden_states if output_hidden_states else None,
  638. attentions=outputs.attentions,
  639. )
  640. __all__ = ["DinatForImageClassification", "DinatModel", "DinatPreTrainedModel", "DinatBackbone"]