modeling_dpt.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185
  1. # Copyright 2022 Intel Labs, OpenMMLab and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch DPT (Dense Prediction Transformers) model.
  15. This implementation is heavily inspired by OpenMMLab's implementation, found here:
  16. https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/dpt_head.py.
  17. """
  18. import collections.abc
  19. from collections.abc import Callable
  20. from dataclasses import dataclass
  21. import torch
  22. from torch import nn
  23. from torch.nn import CrossEntropyLoss
  24. from ... import initialization as init
  25. from ...activations import ACT2FN
  26. from ...backbone_utils import load_backbone
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
  29. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging, torch_int
  32. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  33. from ...utils.output_capturing import capture_outputs
  34. from .configuration_dpt import DPTConfig
# Logger for this module, created through the library's `logging` utilities.
logger = logging.get_logger(__name__)
  36. @dataclass
  37. @auto_docstring(
  38. custom_intro="""
  39. Base class for model's outputs that also contains intermediate activations that can be used at later stages. Useful
  40. in the context of Vision models.:
  41. """
  42. )
  43. class BaseModelOutputWithIntermediateActivations(ModelOutput):
  44. r"""
  45. last_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  46. Sequence of hidden-states at the output of the last layer of the model.
  47. intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
  48. Intermediate activations that can be used to compute hidden states of the model at various layers.
  49. """
  50. last_hidden_states: torch.FloatTensor | None = None
  51. intermediate_activations: tuple[torch.FloatTensor, ...] | None = None
  52. @dataclass
  53. @auto_docstring(
  54. custom_intro="""
  55. Base class for model's outputs that also contains a pooling of the last hidden states as well as intermediate
  56. activations that can be used by the model at later stages.
  57. """
  58. )
  59. class BaseModelOutputWithPoolingAndIntermediateActivations(ModelOutput):
  60. r"""
  61. pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
  62. Last layer hidden-state of the first token of the sequence (classification token) after further processing
  63. through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
  64. the classification token after processing through a linear layer and a tanh activation function. The linear
  65. layer weights are trained from the next sentence prediction (classification) objective during pretraining.
  66. intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
  67. Intermediate activations that can be used to compute hidden states of the model at various layers.
  68. """
  69. last_hidden_state: torch.FloatTensor | None = None
  70. pooler_output: torch.FloatTensor | None = None
  71. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  72. attentions: tuple[torch.FloatTensor, ...] | None = None
  73. intermediate_activations: tuple[torch.FloatTensor, ...] | None = None
  74. class DPTViTHybridEmbeddings(nn.Module):
  75. """
  76. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  77. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  78. Transformer.
  79. """
  80. def __init__(self, config: DPTConfig, feature_size: tuple[int, int] | None = None):
  81. super().__init__()
  82. image_size, patch_size = config.image_size, config.patch_size
  83. num_channels, hidden_size = config.num_channels, config.hidden_size
  84. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  85. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  86. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  87. self.backbone = load_backbone(config)
  88. feature_dim = self.backbone.channels[-1]
  89. if len(self.backbone.channels) != 3:
  90. raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.channels)}")
  91. self.residual_feature_map_index = [0, 1] # Always take the output of the first and second backbone stage
  92. if feature_size is None:
  93. feat_map_shape = config.backbone_featmap_shape
  94. feature_size = feat_map_shape[-2:]
  95. feature_dim = feat_map_shape[1]
  96. else:
  97. feature_size = (
  98. feature_size if isinstance(feature_size, collections.abc.Iterable) else (feature_size, feature_size)
  99. )
  100. feature_dim = self.backbone.channels[-1]
  101. self.image_size = image_size
  102. self.patch_size = patch_size[0]
  103. self.num_channels = num_channels
  104. self.projection = nn.Conv2d(feature_dim, hidden_size, kernel_size=1)
  105. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  106. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  107. def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
  108. posemb_tok = posemb[:, :start_index]
  109. posemb_grid = posemb[0, start_index:]
  110. old_grid_size = torch_int(len(posemb_grid) ** 0.5)
  111. posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
  112. posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
  113. posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)
  114. posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
  115. return posemb
  116. def forward(
  117. self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False
  118. ) -> BaseModelOutputWithIntermediateActivations:
  119. batch_size, num_channels, height, width = pixel_values.shape
  120. if num_channels != self.num_channels:
  121. raise ValueError(
  122. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  123. )
  124. if not interpolate_pos_encoding:
  125. if height != self.image_size[0] or width != self.image_size[1]:
  126. raise ValueError(
  127. f"Input image size ({height}*{width}) doesn't match model"
  128. f" ({self.image_size[0]}*{self.image_size[1]})."
  129. )
  130. position_embeddings = self._resize_pos_embed(
  131. self.position_embeddings, height // self.patch_size, width // self.patch_size
  132. )
  133. backbone_output = self.backbone(pixel_values)
  134. features = backbone_output.feature_maps[-1]
  135. # Retrieve also the intermediate activations to use them at later stages
  136. output_hidden_states = [backbone_output.feature_maps[index] for index in self.residual_feature_map_index]
  137. embeddings = self.projection(features).flatten(2).transpose(1, 2)
  138. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  139. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  140. # add positional encoding to each token
  141. embeddings = embeddings + position_embeddings
  142. # Return hidden states and intermediate activations
  143. return BaseModelOutputWithIntermediateActivations(
  144. last_hidden_states=embeddings,
  145. intermediate_activations=output_hidden_states,
  146. )
  147. class DPTViTEmbeddings(nn.Module):
  148. """
  149. Construct the CLS token, position and patch embeddings.
  150. """
  151. def __init__(self, config):
  152. super().__init__()
  153. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  154. self.patch_embeddings = DPTViTPatchEmbeddings(config)
  155. num_patches = self.patch_embeddings.num_patches
  156. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  157. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  158. self.config = config
  159. def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
  160. posemb_tok = posemb[:, :start_index]
  161. posemb_grid = posemb[0, start_index:]
  162. old_grid_size = torch_int(posemb_grid.size(0) ** 0.5)
  163. posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
  164. posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
  165. posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)
  166. posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
  167. return posemb
  168. def forward(self, pixel_values: torch.Tensor) -> BaseModelOutputWithIntermediateActivations:
  169. batch_size, num_channels, height, width = pixel_values.shape
  170. # possibly interpolate position encodings to handle varying image sizes
  171. patch_size = self.config.patch_size
  172. position_embeddings = self._resize_pos_embed(
  173. self.position_embeddings, height // patch_size, width // patch_size
  174. )
  175. embeddings = self.patch_embeddings(pixel_values)
  176. batch_size, seq_len, _ = embeddings.size()
  177. # add the [CLS] token to the embedded patch tokens
  178. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  179. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  180. # add positional encoding to each token
  181. embeddings = embeddings + position_embeddings
  182. embeddings = self.dropout(embeddings)
  183. return BaseModelOutputWithIntermediateActivations(last_hidden_states=embeddings)
  184. class DPTViTPatchEmbeddings(nn.Module):
  185. """
  186. Image to Patch Embedding.
  187. """
  188. def __init__(self, config: DPTConfig):
  189. super().__init__()
  190. image_size, patch_size = config.image_size, config.patch_size
  191. num_channels, hidden_size = config.num_channels, config.hidden_size
  192. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  193. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  194. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  195. self.image_size = image_size
  196. self.patch_size = patch_size
  197. self.num_channels = num_channels
  198. self.num_patches = num_patches
  199. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  200. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  201. batch_size, num_channels, height, width = pixel_values.shape
  202. if num_channels != self.num_channels:
  203. raise ValueError(
  204. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  205. )
  206. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  207. return embeddings
  208. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  209. def eager_attention_forward(
  210. module: nn.Module,
  211. query: torch.Tensor,
  212. key: torch.Tensor,
  213. value: torch.Tensor,
  214. attention_mask: torch.Tensor | None,
  215. scaling: float | None = None,
  216. dropout: float = 0.0,
  217. **kwargs: Unpack[TransformersKwargs],
  218. ):
  219. if scaling is None:
  220. scaling = query.size(-1) ** -0.5
  221. # Take the dot product between "query" and "key" to get the raw attention scores.
  222. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  223. if attention_mask is not None:
  224. attn_weights = attn_weights + attention_mask
  225. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  226. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  227. attn_output = torch.matmul(attn_weights, value)
  228. attn_output = attn_output.transpose(1, 2).contiguous()
  229. return attn_output, attn_weights
  230. # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DPT
  231. class DPTSelfAttention(nn.Module):
  232. def __init__(self, config: DPTConfig):
  233. super().__init__()
  234. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  235. raise ValueError(
  236. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  237. f"heads {config.num_attention_heads}."
  238. )
  239. self.config = config
  240. self.num_attention_heads = config.num_attention_heads
  241. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  242. self.all_head_size = self.num_attention_heads * self.attention_head_size
  243. self.dropout_prob = config.attention_probs_dropout_prob
  244. self.scaling = self.attention_head_size**-0.5
  245. self.is_causal = False
  246. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  247. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  248. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  249. def forward(
  250. self,
  251. hidden_states: torch.Tensor,
  252. **kwargs: Unpack[TransformersKwargs],
  253. ) -> tuple[torch.Tensor, torch.Tensor]:
  254. batch_size = hidden_states.shape[0]
  255. new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
  256. key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
  257. value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
  258. query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
  259. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  260. self.config._attn_implementation, eager_attention_forward
  261. )
  262. context_layer, attention_probs = attention_interface(
  263. self,
  264. query_layer,
  265. key_layer,
  266. value_layer,
  267. None,
  268. is_causal=self.is_causal,
  269. scaling=self.scaling,
  270. dropout=0.0 if not self.training else self.dropout_prob,
  271. **kwargs,
  272. )
  273. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  274. context_layer = context_layer.reshape(new_context_layer_shape)
  275. return context_layer, attention_probs
  276. # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViTConfig->DPTConfig, ViTSelfOutput->DPTViTSelfOutput
  277. class DPTViTSelfOutput(nn.Module):
  278. """
  279. The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
  280. layernorm applied before each block.
  281. """
  282. def __init__(self, config: DPTConfig):
  283. super().__init__()
  284. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  285. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  286. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  287. hidden_states = self.dense(hidden_states)
  288. hidden_states = self.dropout(hidden_states)
  289. return hidden_states
  290. # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViTConfig->DPTConfig, ViTSelfAttention->DPTSelfAttention, ViTSelfOutput->DPTViTSelfOutput
  291. class DPTViTAttention(nn.Module):
  292. def __init__(self, config: DPTConfig):
  293. super().__init__()
  294. self.attention = DPTSelfAttention(config)
  295. self.output = DPTViTSelfOutput(config)
  296. def forward(
  297. self,
  298. hidden_states: torch.Tensor,
  299. **kwargs: Unpack[TransformersKwargs],
  300. ) -> torch.Tensor:
  301. self_attn_output, _ = self.attention(hidden_states, **kwargs)
  302. output = self.output(self_attn_output, hidden_states)
  303. return output
  304. # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViTConfig->DPTConfig, ViTIntermediate->DPTViTIntermediate
  305. class DPTViTIntermediate(nn.Module):
  306. def __init__(self, config: DPTConfig):
  307. super().__init__()
  308. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  309. if isinstance(config.hidden_act, str):
  310. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  311. else:
  312. self.intermediate_act_fn = config.hidden_act
  313. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  314. hidden_states = self.dense(hidden_states)
  315. hidden_states = self.intermediate_act_fn(hidden_states)
  316. return hidden_states
  317. # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViTConfig->DPTConfig, ViTOutput->DPTViTOutput
  318. class DPTViTOutput(nn.Module):
  319. def __init__(self, config: DPTConfig):
  320. super().__init__()
  321. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  322. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  323. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  324. hidden_states = self.dense(hidden_states)
  325. hidden_states = self.dropout(hidden_states)
  326. hidden_states = hidden_states + input_tensor
  327. return hidden_states
  328. # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViTConfig->DPTConfig, ViTAttention->DPTViTAttention, ViTIntermediate->DPTViTIntermediate, ViTOutput->DPTViTOutput, ViTLayer->DPTViTLayer
  329. class DPTViTLayer(GradientCheckpointingLayer):
  330. """This corresponds to the Block class in the timm implementation."""
  331. def __init__(self, config: DPTConfig):
  332. super().__init__()
  333. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  334. self.seq_len_dim = 1
  335. self.attention = DPTViTAttention(config)
  336. self.intermediate = DPTViTIntermediate(config)
  337. self.output = DPTViTOutput(config)
  338. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  339. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  340. def forward(
  341. self,
  342. hidden_states: torch.Tensor,
  343. **kwargs: Unpack[TransformersKwargs],
  344. ) -> torch.Tensor:
  345. hidden_states_norm = self.layernorm_before(hidden_states)
  346. attention_output = self.attention(hidden_states_norm, **kwargs)
  347. # first residual connection
  348. hidden_states = attention_output + hidden_states
  349. # in ViT, layernorm is also applied after self-attention
  350. layer_output = self.layernorm_after(hidden_states)
  351. layer_output = self.intermediate(layer_output)
  352. # second residual connection is done here
  353. layer_output = self.output(layer_output, hidden_states)
  354. return layer_output
  355. class DPTReassembleStage(nn.Module):
  356. """
  357. This class reassembles the hidden states of the backbone into image-like feature representations at various
  358. resolutions.
  359. This happens in 3 stages:
  360. 1. Map the N + 1 tokens to a set of N tokens, by taking into account the readout ([CLS]) token according to
  361. `config.readout_type`.
  362. 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
  363. 3. Resizing the spatial dimensions (height, width).
  364. Args:
  365. config (`[DPTConfig]`):
  366. Model configuration class defining the model architecture.
  367. """
  368. def __init__(self, config):
  369. super().__init__()
  370. self.config = config
  371. self.layers = nn.ModuleList()
  372. if config.is_hybrid:
  373. self._init_reassemble_dpt_hybrid(config)
  374. else:
  375. self._init_reassemble_dpt(config)
  376. self.neck_ignore_stages = config.neck_ignore_stages
  377. def _init_reassemble_dpt_hybrid(self, config):
  378. r""" "
  379. For DPT-Hybrid the first 2 reassemble layers are set to `nn.Identity()`, please check the official
  380. implementation: https://github.com/isl-org/DPT/blob/f43ef9e08d70a752195028a51be5e1aff227b913/dpt/vit.py#L438
  381. for more details.
  382. """
  383. for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
  384. if i <= 1:
  385. self.layers.append(nn.Identity())
  386. elif i > 1:
  387. self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))
  388. if config.readout_type != "project":
  389. raise ValueError(f"Readout type {config.readout_type} is not supported for DPT-Hybrid.")
  390. # When using DPT-Hybrid the readout type is set to "project". The sanity check is done on the config file
  391. self.readout_projects = nn.ModuleList()
  392. hidden_size = _get_backbone_hidden_size(config)
  393. for i in range(len(config.neck_hidden_sizes)):
  394. if i <= 1:
  395. self.readout_projects.append(nn.Sequential(nn.Identity()))
  396. elif i > 1:
  397. self.readout_projects.append(
  398. nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
  399. )
  400. def _init_reassemble_dpt(self, config):
  401. for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
  402. self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))
  403. if config.readout_type == "project":
  404. self.readout_projects = nn.ModuleList()
  405. hidden_size = _get_backbone_hidden_size(config)
  406. for _ in range(len(config.neck_hidden_sizes)):
  407. self.readout_projects.append(
  408. nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
  409. )
  410. def forward(self, hidden_states: list[torch.Tensor], patch_height=None, patch_width=None) -> list[torch.Tensor]:
  411. """
  412. Args:
  413. hidden_states (`list[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
  414. List of hidden states from the backbone.
  415. """
  416. out = []
  417. for i, hidden_state in enumerate(hidden_states):
  418. if i not in self.neck_ignore_stages:
  419. # reshape to (batch_size, num_channels, height, width)
  420. cls_token, hidden_state = hidden_state[:, 0], hidden_state[:, 1:]
  421. batch_size, sequence_length, num_channels = hidden_state.shape
  422. if patch_height is not None and patch_width is not None:
  423. hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
  424. else:
  425. size = torch_int(sequence_length**0.5)
  426. hidden_state = hidden_state.reshape(batch_size, size, size, num_channels)
  427. hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
  428. feature_shape = hidden_state.shape
  429. if self.config.readout_type == "project":
  430. # reshape to (batch_size, height*width, num_channels)
  431. hidden_state = hidden_state.flatten(2).permute((0, 2, 1))
  432. readout = cls_token.unsqueeze(1).expand_as(hidden_state)
  433. # concatenate the readout token to the hidden states and project
  434. hidden_state = self.readout_projects[i](torch.cat((hidden_state, readout), -1))
  435. # reshape back to (batch_size, num_channels, height, width)
  436. hidden_state = hidden_state.permute(0, 2, 1).reshape(feature_shape)
  437. elif self.config.readout_type == "add":
  438. hidden_state = hidden_state.flatten(2) + cls_token.unsqueeze(-1)
  439. hidden_state = hidden_state.reshape(feature_shape)
  440. hidden_state = self.layers[i](hidden_state)
  441. out.append(hidden_state)
  442. return out
  443. def _get_backbone_hidden_size(config):
  444. if config.backbone_config is not None and hasattr(config.backbone_config, "hidden_size"):
  445. return config.backbone_config.hidden_size
  446. else:
  447. return config.hidden_size
  448. class DPTReassembleLayer(nn.Module):
  449. def __init__(self, config: DPTConfig, channels: int, factor: int):
  450. super().__init__()
  451. # projection
  452. hidden_size = _get_backbone_hidden_size(config)
  453. self.projection = nn.Conv2d(in_channels=hidden_size, out_channels=channels, kernel_size=1)
  454. # up/down sampling depending on factor
  455. if factor > 1:
  456. self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
  457. elif factor == 1:
  458. self.resize = nn.Identity()
  459. elif factor < 1:
  460. # so should downsample
  461. self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1)
  462. def forward(self, hidden_state):
  463. hidden_state = self.projection(hidden_state)
  464. hidden_state = self.resize(hidden_state)
  465. return hidden_state
  466. class DPTFeatureFusionStage(nn.Module):
  467. def __init__(self, config: DPTConfig):
  468. super().__init__()
  469. self.layers = nn.ModuleList()
  470. for _ in range(len(config.neck_hidden_sizes)):
  471. self.layers.append(DPTFeatureFusionLayer(config))
  472. def forward(self, hidden_states):
  473. # reversing the hidden_states, we start from the last
  474. hidden_states = hidden_states[::-1]
  475. fused_hidden_states = []
  476. fused_hidden_state = None
  477. for hidden_state, layer in zip(hidden_states, self.layers):
  478. if fused_hidden_state is None:
  479. # first layer only uses the last hidden_state
  480. fused_hidden_state = layer(hidden_state)
  481. else:
  482. fused_hidden_state = layer(fused_hidden_state, hidden_state)
  483. fused_hidden_states.append(fused_hidden_state)
  484. return fused_hidden_states
  485. class DPTPreActResidualLayer(nn.Module):
  486. """
  487. ResidualConvUnit, pre-activate residual unit.
  488. Args:
  489. config (`[DPTConfig]`):
  490. Model configuration class defining the model architecture.
  491. """
  492. def __init__(self, config: DPTConfig):
  493. super().__init__()
  494. self.use_batch_norm = config.use_batch_norm_in_fusion_residual
  495. use_bias_in_fusion_residual = (
  496. config.use_bias_in_fusion_residual
  497. if config.use_bias_in_fusion_residual is not None
  498. else not self.use_batch_norm
  499. )
  500. self.activation1 = nn.ReLU()
  501. self.convolution1 = nn.Conv2d(
  502. config.fusion_hidden_size,
  503. config.fusion_hidden_size,
  504. kernel_size=3,
  505. stride=1,
  506. padding=1,
  507. bias=use_bias_in_fusion_residual,
  508. )
  509. self.activation2 = nn.ReLU()
  510. self.convolution2 = nn.Conv2d(
  511. config.fusion_hidden_size,
  512. config.fusion_hidden_size,
  513. kernel_size=3,
  514. stride=1,
  515. padding=1,
  516. bias=use_bias_in_fusion_residual,
  517. )
  518. if self.use_batch_norm:
  519. self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
  520. self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)
  521. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  522. residual = hidden_state
  523. hidden_state = self.activation1(hidden_state)
  524. hidden_state = self.convolution1(hidden_state)
  525. if self.use_batch_norm:
  526. hidden_state = self.batch_norm1(hidden_state)
  527. hidden_state = self.activation2(hidden_state)
  528. hidden_state = self.convolution2(hidden_state)
  529. if self.use_batch_norm:
  530. hidden_state = self.batch_norm2(hidden_state)
  531. return hidden_state + residual
  class DPTFeatureFusionLayer(nn.Module):
      """
      Feature fusion layer of the DPT neck.

      Optionally merges a second feature map (`residual`) into `hidden_state`, refines the result, upsamples it by a
      factor of 2 (bilinear) and applies a 1x1 projection.

      Args:
          config (`DPTConfig`):
              Model configuration class defining the model architecture.
          align_corners (`bool`, *optional*, defaults to `True`):
              `align_corners` flag used for the final bilinear 2x upsampling.
      """

      def __init__(self, config: DPTConfig, align_corners: bool = True):
          super().__init__()
          self.align_corners = align_corners
          channels = config.fusion_hidden_size
          self.projection = nn.Conv2d(channels, channels, kernel_size=1, bias=True)
          self.residual_layer1 = DPTPreActResidualLayer(config)
          self.residual_layer2 = DPTPreActResidualLayer(config)

      def forward(self, hidden_state: torch.Tensor, residual: torch.Tensor | None = None) -> torch.Tensor:
          if residual is not None:
              if residual.shape != hidden_state.shape:
                  # Resize `residual` to the spatial size (dims 2 and 3) of `hidden_state` before adding it.
                  target_size = (hidden_state.shape[2], hidden_state.shape[3])
                  residual = nn.functional.interpolate(
                      residual, size=target_size, mode="bilinear", align_corners=False
                  )
              hidden_state = hidden_state + self.residual_layer1(residual)

          refined = self.residual_layer2(hidden_state)
          upsampled = nn.functional.interpolate(
              refined, scale_factor=2, mode="bilinear", align_corners=self.align_corners
          )
          return self.projection(upsampled)
  @auto_docstring
  class DPTPreTrainedModel(PreTrainedModel):
      # Model-level metadata and capability flags consumed by the `PreTrainedModel` base class.
      config: DPTConfig
      base_model_prefix = "dpt"
      main_input_name = "pixel_values"
      input_modalities = ("image",)
      supports_gradient_checkpointing = True
      _supports_sdpa = True
      _supports_flash_attn = True
      _supports_flex_attn = True
      _supports_attention_backend = True
      # Output name -> module class whose outputs are recorded for it. DPTViTEncoder only returns the last hidden
      # state, so per-layer hidden states / attentions presumably come from this recording (see the
      # `capture_outputs` decorator on DPTModel.forward) — confirm against that implementation.
      _can_record_outputs = {
          "hidden_states": DPTViTLayer,
          "attentions": DPTSelfAttention,
      }

      @torch.no_grad()
      def _init_weights(self, module):
          """
          Initialize the weights of `module`.

          Runs the parent initialization, then zeroes `cls_token` and `position_embeddings` on the DPT embedding
          modules (both the plain and the hybrid variant).
          """
          super()._init_weights(module)
          if not isinstance(module, (DPTViTEmbeddings, DPTViTHybridEmbeddings)):
              return
          for parameter in (module.cls_token, module.position_embeddings):
              init.zeros_(parameter)
  class DPTViTEncoder(nn.Module):
      """Applies `config.num_hidden_layers` `DPTViTLayer` modules one after another."""

      def __init__(self, config: DPTConfig):
          super().__init__()
          self.config = config
          self.layer = nn.ModuleList(DPTViTLayer(config) for _ in range(config.num_hidden_layers))

      def forward(
          self, hidden_states: torch.Tensor, output_hidden_states: bool = False, **kwargs: Unpack[TransformersKwargs]
      ) -> BaseModelOutput:
          # `output_hidden_states` and `kwargs` are accepted so callers can forward their keyword arguments,
          # but they are not read here; only the final hidden state is returned.
          current = hidden_states
          for vit_layer in self.layer:
              current = vit_layer(current)
          return BaseModelOutput(last_hidden_state=current)
  @auto_docstring
  class DPTModel(DPTPreTrainedModel):
      def __init__(self, config: DPTConfig, add_pooling_layer: bool = True):
          r"""
          add_pooling_layer (bool, *optional*, defaults to `True`):
              Whether to add a pooling layer
          """
          super().__init__(config)
          self.config = config

          # vit encoder: hybrid configs use DPTViTHybridEmbeddings, all others DPTViTEmbeddings
          embeddings_class = DPTViTHybridEmbeddings if config.is_hybrid else DPTViTEmbeddings
          self.embeddings = embeddings_class(config)
          self.encoder = DPTViTEncoder(config)
          self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.pooler = None if not add_pooling_layer else DPTViTPooler(config)

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          # Hybrid: the whole embedding module; otherwise only its `patch_embeddings` submodule.
          return self.embeddings if self.config.is_hybrid else self.embeddings.patch_embeddings

      @merge_with_config_defaults
      @capture_outputs(tie_last_hidden_states=False)
      @auto_docstring
      def forward(
          self,
          pixel_values: torch.FloatTensor,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPoolingAndIntermediateActivations:
          # embeddings -> encoder -> final layer norm -> optional pooler on the normalized sequence
          embedding_output: BaseModelOutputWithIntermediateActivations = self.embeddings(pixel_values)
          encoder_outputs: BaseModelOutput = self.encoder(embedding_output.last_hidden_states, **kwargs)
          normalized = self.layernorm(encoder_outputs.last_hidden_state)
          pooled = None if self.pooler is None else self.pooler(normalized)

          # the embedding stage's intermediate activations are passed through unchanged
          return BaseModelOutputWithPoolingAndIntermediateActivations(
              last_hidden_state=normalized,
              pooler_output=pooled,
              intermediate_activations=embedding_output.intermediate_activations,
          )
  # Copied from transformers.models.vit.modeling_vit.ViTPooler with ViTConfig->DPTConfig, ViTPooler->DPTViTPooler
  class DPTViTPooler(nn.Module):
      # Pools a token sequence into one vector: token at index 0 along dim 1 -> Linear(hidden_size,
      # pooler_output_size) -> the activation named by `config.pooler_act` (looked up in ACT2FN).
      # NOTE(review): the "Copied from" marker presumably means repo tooling keeps this body in sync with
      # ViTPooler; change ViTPooler rather than editing this copy.
      def __init__(self, config: DPTConfig):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
          self.activation = ACT2FN[config.pooler_act]

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          # We "pool" the model by simply taking the hidden state corresponding
          # to the first token.
          # assumes hidden_states is (batch_size, sequence_length, hidden_size) — TODO confirm
          first_token_tensor = hidden_states[:, 0]
          pooled_output = self.dense(first_token_tensor)
          pooled_output = self.activation(pooled_output)
          return pooled_output
  class DPTNeck(nn.Module):
      """
      DPTNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
      input and produces another list of tensors as output. For DPT, it includes 2 stages:

      * DPTReassembleStage
      * DPTFeatureFusionStage.

      Args:
          config (`DPTConfig`):
              Model configuration class defining the model architecture.
      """

      def __init__(self, config: DPTConfig):
          super().__init__()
          self.config = config

          # postprocessing: only required in case of a non-hierarchical backbone (e.g. ViT, BEiT)
          if config.backbone_config is not None and config.backbone_config.model_type == "swinv2":
              self.reassemble_stage = None
          else:
              self.reassemble_stage = DPTReassembleStage(config)

          # one 3x3 conv per input feature map, projecting `neck_hidden_sizes[i]` channels to `fusion_hidden_size`
          self.convs = nn.ModuleList()
          for channel in config.neck_hidden_sizes:
              self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))

          # fusion
          self.fusion_stage = DPTFeatureFusionStage(config)

      def forward(
          self,
          hidden_states: list[torch.Tensor],
          patch_height: int | None = None,
          patch_width: int | None = None,
      ) -> list[torch.Tensor]:
          """
          Args:
              hidden_states (`list[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
                  List of hidden states from the backbone, one per entry of `config.neck_hidden_sizes`.
              patch_height (`int`, *optional*):
                  Passed through to the reassemble stage; unused when there is none.
              patch_width (`int`, *optional*):
                  Passed through to the reassemble stage; unused when there is none.

          Returns:
              The output of the feature fusion stage applied to the projected feature maps.

          Raises:
              TypeError: if `hidden_states` is not a tuple or list.
              ValueError: if the number of hidden states differs from `len(config.neck_hidden_sizes)`, or if the
                  number of feature maps reaching the projection convs differs from the number of convs.
          """
          if not isinstance(hidden_states, (tuple, list)):
              raise TypeError("hidden_states should be a tuple or list of tensors")

          if len(hidden_states) != len(self.config.neck_hidden_sizes):
              raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")

          # postprocess hidden states
          if self.reassemble_stage is not None:
              hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)

          # Project every feature map to `fusion_hidden_size`. `strict=True` turns a count mismatch into an explicit
          # error instead of an IndexError or silently unused convolutions.
          features = [conv(feature) for conv, feature in zip(self.convs, hidden_states, strict=True)]

          # fusion blocks
          return self.fusion_stage(features)
  class DPTDepthEstimationHead(nn.Module):
      """
      Output head consisting of 3 convolutional layers.

      The first convolution halves the feature dimension, the result is upsampled 2x (bilinear), and two more
      convolutions reduce it to 32 and then 1 channel, with ReLU activations after each of those two (details can
      be found in the paper's supplementary material).

      Args:
          config (`DPTConfig`):
              Model configuration class defining the model architecture.
      """

      def __init__(self, config: DPTConfig):
          super().__init__()

          self.config = config
          features = config.fusion_hidden_size

          # Optional 3x3 projection on the selected fused feature map. Its input comes from the neck, whose
          # outputs have `fusion_hidden_size` channels, so it is sized from that value instead of a fixed 256.
          # For the default size of 256 the layer (and checkpoint weights) are identical.
          self.projection = None
          if config.add_projection:
              self.projection = nn.Conv2d(features, features, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

          self.head = nn.Sequential(
              nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
              nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
              nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
              nn.ReLU(),
              nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
              nn.ReLU(),
          )

      def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
          """
          Args:
              hidden_states (`list[torch.Tensor]`):
                  Fused feature maps; only the one at `config.head_in_index` is used.

          Returns:
              `torch.Tensor`: the head output with dimension 1 (the single depth channel) squeezed out.
          """
          hidden_states = hidden_states[self.config.head_in_index]

          if self.projection is not None:
              hidden_states = self.projection(hidden_states)
              # functional ReLU instead of constructing an `nn.ReLU` module on every call
              hidden_states = nn.functional.relu(hidden_states)

          predicted_depth = self.head(hidden_states)
          return predicted_depth.squeeze(dim=1)
  @auto_docstring(
      custom_intro="""
      DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
      """
  )
  class DPTForDepthEstimation(DPTPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

          # Non-hybrid configs with a `backbone_config` use a loaded backbone; otherwise the built-in DPTModel is used.
          self.backbone = None
          if config.is_hybrid is False and config.backbone_config is not None:
              self.backbone = load_backbone(config)
          else:
              self.dpt = DPTModel(config, add_pooling_layer=False)

          # Neck
          self.neck = DPTNeck(config)

          # Depth estimation head
          self.head = DPTDepthEstimationHead(config)

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          pixel_values: torch.FloatTensor,
          labels: torch.LongTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> DepthEstimatorOutput:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
              Ground truth depth estimation maps for computing the loss. Training is not supported yet: passing
              labels raises `NotImplementedError`.

          Examples:
          ```python
          >>> from transformers import AutoImageProcessor, DPTForDepthEstimation
          >>> import torch
          >>> import numpy as np
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))

          >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
          >>> model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")

          >>> # prepare image for the model
          >>> inputs = image_processor(images=image, return_tensors="pt")

          >>> with torch.no_grad():
          ...     outputs = model(**inputs)

          >>> # interpolate to original size
          >>> post_processed_output = image_processor.post_process_depth_estimation(
          ...     outputs,
          ...     target_sizes=[(image.height, image.width)],
          ... )

          >>> # visualize the prediction
          >>> predicted_depth = post_processed_output[0]["predicted_depth"]
          >>> depth = predicted_depth * 255 / predicted_depth.max()
          >>> depth = depth.detach().cpu().numpy()
          >>> depth = Image.fromarray(depth.astype("uint8"))
          ```"""
          loss = None
          if labels is not None:
              raise NotImplementedError("Training is not implemented yet")

          # Internally the model always needs to output hidden states; whether they are returned is decided per
          # request. An explicit `output_hidden_states` kwarg (including `False`) takes precedence over the config.
          user_requested_hidden_states = kwargs.get("output_hidden_states")
          if user_requested_hidden_states is None:
              user_requested_hidden_states = getattr(self.config, "output_hidden_states", False)
          kwargs["output_hidden_states"] = True

          if self.backbone is not None:
              outputs = self.backbone.forward_with_filtered_kwargs(pixel_values, **kwargs)
              hidden_states = outputs.feature_maps
          else:
              outputs = self.dpt(pixel_values, **kwargs)
              hidden_states = outputs.hidden_states

              # only keep certain features based on config.backbone_out_indices
              # note that the hidden_states also include the initial embeddings
              if not self.config.is_hybrid:
                  hidden_states = [
                      feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
                  ]
              else:
                  # copy first so the container stored on `outputs` is not mutated (and tuples are accepted)
                  backbone_hidden_states = list(outputs.intermediate_activations)
                  backbone_hidden_states.extend(
                      feature
                      for idx, feature in enumerate(hidden_states[1:])
                      if idx in self.config.backbone_out_indices[2:]
                  )
                  hidden_states = backbone_hidden_states

          # Backbone path only: number of patches along each spatial axis, needed by the neck's reassemble stage.
          patch_height, patch_width = None, None
          if self.config.backbone_config is not None and self.config.is_hybrid is False:
              _, _, height, width = pixel_values.shape
              patch_size = self.config.backbone_config.patch_size
              patch_height = height // patch_size
              patch_width = width // patch_size

          hidden_states = self.neck(hidden_states, patch_height, patch_width)

          predicted_depth = self.head(hidden_states)

          return DepthEstimatorOutput(
              loss=loss,
              predicted_depth=predicted_depth,
              hidden_states=outputs.hidden_states if user_requested_hidden_states else None,
              attentions=outputs.attentions,
          )
  class DPTSemanticSegmentationHead(nn.Module):
      """
      Segmentation head.

      Takes the fused feature map at `config.head_in_index` and maps it to `config.num_labels` logit channels:
      3x3 conv -> batch norm -> ReLU -> dropout -> 1x1 conv -> bilinear 2x upsampling.
      """

      def __init__(self, config: DPTConfig):
          super().__init__()
          self.config = config
          channels = config.fusion_hidden_size
          layers = [
              nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
              nn.BatchNorm2d(channels),
              nn.ReLU(),
              nn.Dropout(config.semantic_classifier_dropout),
              nn.Conv2d(channels, config.num_labels, kernel_size=1),
              nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
          ]
          self.head = nn.Sequential(*layers)

      def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
          selected = hidden_states[self.config.head_in_index]
          return self.head(selected)
  class DPTAuxiliaryHead(nn.Module):
      """
      Auxiliary segmentation head.

      3x3 conv -> batch norm -> ReLU -> dropout(0.1) -> 1x1 conv to `config.num_labels` channels; the spatial size
      of the input is preserved.
      """

      def __init__(self, config: DPTConfig):
          super().__init__()
          channels = config.fusion_hidden_size
          layers = [
              nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
              nn.BatchNorm2d(channels),
              nn.ReLU(),
              nn.Dropout(p=0.1, inplace=False),
              nn.Conv2d(channels, config.num_labels, kernel_size=1),
          ]
          self.head = nn.Sequential(*layers)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          return self.head(hidden_states)
  @auto_docstring
  class DPTForSemanticSegmentation(DPTPreTrainedModel):
      def __init__(self, config: DPTConfig):
          super().__init__(config)

          self.dpt = DPTModel(config, add_pooling_layer=False)

          # Neck
          self.neck = DPTNeck(config)

          # Segmentation head(s)
          self.head = DPTSemanticSegmentationHead(config)
          self.auxiliary_head = DPTAuxiliaryHead(config) if config.use_auxiliary_head else None

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          pixel_values: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> SemanticSegmenterOutput:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
              Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
              Passing labels with `config.num_labels == 1` raises `ValueError`. When the model has an auxiliary
              head, its loss is added with weight `config.auxiliary_loss_weight`.

          Examples:
          ```python
          >>> from transformers import AutoImageProcessor, DPTForSemanticSegmentation
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))

          >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large-ade")
          >>> model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")

          >>> inputs = image_processor(images=image, return_tensors="pt")

          >>> outputs = model(**inputs)
          >>> logits = outputs.logits
          ```"""
          if labels is not None and self.config.num_labels == 1:
              raise ValueError("The number of labels should be greater than one")

          # Internally the model always needs to output hidden states; whether they are returned is decided per
          # request. An explicit `output_hidden_states` kwarg (including `False`) takes precedence over the config.
          user_requested_hidden_states = kwargs.get("output_hidden_states")
          if user_requested_hidden_states is None:
              user_requested_hidden_states = getattr(self.config, "output_hidden_states", False)
          kwargs["output_hidden_states"] = True

          outputs: BaseModelOutputWithPoolingAndIntermediateActivations = self.dpt(pixel_values, **kwargs)
          hidden_states = outputs.hidden_states

          # only keep certain features based on config.backbone_out_indices
          # note that the hidden_states also include the initial embeddings
          if not self.config.is_hybrid:
              hidden_states = [
                  feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
              ]
          else:
              # copy first so the container stored on `outputs` is not mutated (and tuples are accepted)
              backbone_hidden_states = list(outputs.intermediate_activations)
              backbone_hidden_states.extend(
                  feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices[2:]
              )
              hidden_states = backbone_hidden_states

          hidden_states = self.neck(hidden_states=hidden_states)

          logits = self.head(hidden_states)

          auxiliary_logits = None
          if self.auxiliary_head is not None:
              auxiliary_logits = self.auxiliary_head(hidden_states[-1])

          loss = None
          if labels is not None:
              # upsample logits to the images' original size
              upsampled_logits = nn.functional.interpolate(
                  logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
              )
              loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
              loss = loss_fct(upsampled_logits, labels)

              # The auxiliary term only exists when an auxiliary head was built (`config.use_auxiliary_head`);
              # guarding it keeps training without that head from referencing an undefined tensor.
              if auxiliary_logits is not None:
                  upsampled_auxiliary_logits = nn.functional.interpolate(
                      auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
                  )
                  auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
                  loss = loss + self.config.auxiliary_loss_weight * auxiliary_loss

          return SemanticSegmenterOutput(
              loss=loss,
              logits=logits,
              hidden_states=outputs.hidden_states if user_requested_hidden_states else None,
              attentions=outputs.attentions,
          )
  # Public classes exported by this module.
  __all__ = [
      "DPTForDepthEstimation",
      "DPTForSemanticSegmentation",
      "DPTModel",
      "DPTPreTrainedModel",
  ]