modeling_glpn.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667
  1. # Copyright 2022 KAIST and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch GLPN model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from ...activations import ACT2FN
  19. from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput
  20. from ...modeling_utils import PreTrainedModel
  21. from ...utils import auto_docstring, logging
  22. from .configuration_glpn import GLPNConfig
  23. logger = logging.get_logger(__name__)
  24. # Copied from transformers.models.beit.modeling_beit.drop_path
  25. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  26. """
  27. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  28. """
  29. if drop_prob == 0.0 or not training:
  30. return input
  31. keep_prob = 1 - drop_prob
  32. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  33. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  34. random_tensor.floor_() # binarize
  35. output = input.div(keep_prob) * random_tensor
  36. return output
  37. # Copied from transformers.models.segformer.modeling_segformer.SegformerDropPath
  38. class GLPNDropPath(nn.Module):
  39. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  40. def __init__(self, drop_prob: float | None = None) -> None:
  41. super().__init__()
  42. self.drop_prob = drop_prob
  43. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  44. return drop_path(hidden_states, self.drop_prob, self.training)
  45. def extra_repr(self) -> str:
  46. return f"p={self.drop_prob}"
  47. # Copied from transformers.models.segformer.modeling_segformer.SegformerOverlapPatchEmbeddings
  48. class GLPNOverlapPatchEmbeddings(nn.Module):
  49. """Construct the overlapping patch embeddings."""
  50. def __init__(self, patch_size, stride, num_channels, hidden_size):
  51. super().__init__()
  52. self.proj = nn.Conv2d(
  53. num_channels,
  54. hidden_size,
  55. kernel_size=patch_size,
  56. stride=stride,
  57. padding=patch_size // 2,
  58. )
  59. self.layer_norm = nn.LayerNorm(hidden_size)
  60. def forward(self, pixel_values):
  61. embeddings = self.proj(pixel_values)
  62. _, _, height, width = embeddings.shape
  63. # (batch_size, num_channels, height, width) -> (batch_size, num_channels, height*width) -> (batch_size, height*width, num_channels)
  64. # this can be fed to a Transformer layer
  65. embeddings = embeddings.flatten(2).transpose(1, 2)
  66. embeddings = self.layer_norm(embeddings)
  67. return embeddings, height, width
  68. # Copied from transformers.models.segformer.modeling_segformer.SegformerEfficientSelfAttention
  69. class GLPNEfficientSelfAttention(nn.Module):
  70. """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
  71. paper](https://huggingface.co/papers/2102.12122)."""
  72. def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
  73. super().__init__()
  74. self.hidden_size = hidden_size
  75. self.num_attention_heads = num_attention_heads
  76. if self.hidden_size % self.num_attention_heads != 0:
  77. raise ValueError(
  78. f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
  79. f"heads ({self.num_attention_heads})"
  80. )
  81. self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
  82. self.all_head_size = self.num_attention_heads * self.attention_head_size
  83. self.query = nn.Linear(self.hidden_size, self.all_head_size)
  84. self.key = nn.Linear(self.hidden_size, self.all_head_size)
  85. self.value = nn.Linear(self.hidden_size, self.all_head_size)
  86. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  87. self.sr_ratio = sequence_reduction_ratio
  88. if sequence_reduction_ratio > 1:
  89. self.sr = nn.Conv2d(
  90. hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
  91. )
  92. self.layer_norm = nn.LayerNorm(hidden_size)
  93. def forward(
  94. self,
  95. hidden_states,
  96. height,
  97. width,
  98. output_attentions=False,
  99. ):
  100. input_shape = hidden_states.shape[:-1]
  101. hidden_shape = (*input_shape, -1, self.attention_head_size)
  102. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  103. if self.sr_ratio > 1:
  104. batch_size, seq_len, num_channels = hidden_states.shape
  105. # Reshape to (batch_size, num_channels, height, width)
  106. hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
  107. # Apply sequence reduction
  108. hidden_states = self.sr(hidden_states)
  109. # Reshape back to (batch_size, seq_len, num_channels)
  110. hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
  111. hidden_states = self.layer_norm(hidden_states)
  112. kv_shape = (*hidden_states.shape[:-1], -1, self.attention_head_size)
  113. key_layer = self.key(hidden_states).view(kv_shape).transpose(1, 2)
  114. value_layer = self.value(hidden_states).view(kv_shape).transpose(1, 2)
  115. # Take the dot product between "query" and "key" to get the raw attention scores.
  116. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  117. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  118. # Normalize the attention scores to probabilities.
  119. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  120. # This is actually dropping out entire tokens to attend to, which might
  121. # seem a bit unusual, but is taken from the original Transformer paper.
  122. attention_probs = self.dropout(attention_probs)
  123. context_layer = torch.matmul(attention_probs, value_layer)
  124. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  125. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  126. context_layer = context_layer.view(new_context_layer_shape)
  127. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  128. return outputs
  129. # Copied from transformers.models.segformer.modeling_segformer.SegformerSelfOutput
  130. class GLPNSelfOutput(nn.Module):
  131. def __init__(self, config, hidden_size):
  132. super().__init__()
  133. self.dense = nn.Linear(hidden_size, hidden_size)
  134. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  135. def forward(self, hidden_states, input_tensor):
  136. hidden_states = self.dense(hidden_states)
  137. hidden_states = self.dropout(hidden_states)
  138. return hidden_states
  139. # Copied from transformers.models.segformer.modeling_segformer.SegformerAttention with Segformer->GLPN
  140. class GLPNAttention(nn.Module):
  141. def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
  142. super().__init__()
  143. self.self = GLPNEfficientSelfAttention(
  144. config=config,
  145. hidden_size=hidden_size,
  146. num_attention_heads=num_attention_heads,
  147. sequence_reduction_ratio=sequence_reduction_ratio,
  148. )
  149. self.output = GLPNSelfOutput(config, hidden_size=hidden_size)
  150. def forward(self, hidden_states, height, width, output_attentions=False):
  151. self_outputs = self.self(hidden_states, height, width, output_attentions)
  152. attention_output = self.output(self_outputs[0], hidden_states)
  153. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  154. return outputs
  155. # Copied from transformers.models.segformer.modeling_segformer.SegformerDWConv
  156. class GLPNDWConv(nn.Module):
  157. def __init__(self, dim=768):
  158. super().__init__()
  159. self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
  160. def forward(self, hidden_states, height, width):
  161. batch_size, seq_len, num_channels = hidden_states.shape
  162. hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width)
  163. hidden_states = self.dwconv(hidden_states)
  164. hidden_states = hidden_states.flatten(2).transpose(1, 2)
  165. return hidden_states
  166. # Copied from transformers.models.segformer.modeling_segformer.SegformerMixFFN with Segformer->GLPN
  167. class GLPNMixFFN(nn.Module):
  168. def __init__(self, config, in_features, hidden_features=None, out_features=None):
  169. super().__init__()
  170. out_features = out_features or in_features
  171. self.dense1 = nn.Linear(in_features, hidden_features)
  172. self.dwconv = GLPNDWConv(hidden_features)
  173. if isinstance(config.hidden_act, str):
  174. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  175. else:
  176. self.intermediate_act_fn = config.hidden_act
  177. self.dense2 = nn.Linear(hidden_features, out_features)
  178. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  179. def forward(self, hidden_states, height, width):
  180. hidden_states = self.dense1(hidden_states)
  181. hidden_states = self.dwconv(hidden_states, height, width)
  182. hidden_states = self.intermediate_act_fn(hidden_states)
  183. hidden_states = self.dropout(hidden_states)
  184. hidden_states = self.dense2(hidden_states)
  185. hidden_states = self.dropout(hidden_states)
  186. return hidden_states
  187. # Copied from transformers.models.segformer.modeling_segformer.SegformerLayer with Segformer->GLPN
  188. class GLPNLayer(nn.Module):
  189. """This corresponds to the Block class in the original implementation."""
  190. def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio):
  191. super().__init__()
  192. self.layer_norm_1 = nn.LayerNorm(hidden_size)
  193. self.attention = GLPNAttention(
  194. config,
  195. hidden_size=hidden_size,
  196. num_attention_heads=num_attention_heads,
  197. sequence_reduction_ratio=sequence_reduction_ratio,
  198. )
  199. self.drop_path = GLPNDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  200. self.layer_norm_2 = nn.LayerNorm(hidden_size)
  201. mlp_hidden_size = int(hidden_size * mlp_ratio)
  202. self.mlp = GLPNMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size)
  203. def forward(self, hidden_states, height, width, output_attentions=False):
  204. self_attention_outputs = self.attention(
  205. self.layer_norm_1(hidden_states), # in GLPN, layernorm is applied before self-attention
  206. height,
  207. width,
  208. output_attentions=output_attentions,
  209. )
  210. attention_output = self_attention_outputs[0]
  211. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  212. # first residual connection (with stochastic depth)
  213. attention_output = self.drop_path(attention_output)
  214. hidden_states = attention_output + hidden_states
  215. mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)
  216. # second residual connection (with stochastic depth)
  217. mlp_output = self.drop_path(mlp_output)
  218. layer_output = mlp_output + hidden_states
  219. outputs = (layer_output,) + outputs
  220. return outputs
  221. class GLPNEncoder(nn.Module):
  222. def __init__(self, config):
  223. super().__init__()
  224. self.config = config
  225. # stochastic depth decay rule
  226. dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
  227. # patch embeddings
  228. embeddings = []
  229. for i in range(config.num_encoder_blocks):
  230. embeddings.append(
  231. GLPNOverlapPatchEmbeddings(
  232. patch_size=config.patch_sizes[i],
  233. stride=config.strides[i],
  234. num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
  235. hidden_size=config.hidden_sizes[i],
  236. )
  237. )
  238. self.patch_embeddings = nn.ModuleList(embeddings)
  239. # Transformer blocks
  240. blocks = []
  241. cur = 0
  242. for i in range(config.num_encoder_blocks):
  243. # each block consists of layers
  244. layers = []
  245. if i != 0:
  246. cur += config.depths[i - 1]
  247. for j in range(config.depths[i]):
  248. layers.append(
  249. GLPNLayer(
  250. config,
  251. hidden_size=config.hidden_sizes[i],
  252. num_attention_heads=config.num_attention_heads[i],
  253. drop_path=dpr[cur + j],
  254. sequence_reduction_ratio=config.sr_ratios[i],
  255. mlp_ratio=config.mlp_ratios[i],
  256. )
  257. )
  258. blocks.append(nn.ModuleList(layers))
  259. self.block = nn.ModuleList(blocks)
  260. # Layer norms
  261. self.layer_norm = nn.ModuleList(
  262. [nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)]
  263. )
  264. def forward(
  265. self,
  266. pixel_values,
  267. output_attentions=False,
  268. output_hidden_states=False,
  269. return_dict=True,
  270. ):
  271. all_hidden_states = () if output_hidden_states else None
  272. all_self_attentions = () if output_attentions else None
  273. batch_size = pixel_values.shape[0]
  274. hidden_states = pixel_values
  275. for idx, x in enumerate(zip(self.patch_embeddings, self.block, self.layer_norm)):
  276. embedding_layer, block_layer, norm_layer = x
  277. # first, obtain patch embeddings
  278. hidden_states, height, width = embedding_layer(hidden_states)
  279. # second, send embeddings through blocks
  280. for i, blk in enumerate(block_layer):
  281. layer_outputs = blk(hidden_states, height, width, output_attentions)
  282. hidden_states = layer_outputs[0]
  283. if output_attentions:
  284. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  285. # third, apply layer norm
  286. hidden_states = norm_layer(hidden_states)
  287. # fourth, optionally reshape back to (batch_size, num_channels, height, width)
  288. hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
  289. if output_hidden_states:
  290. all_hidden_states = all_hidden_states + (hidden_states,)
  291. if not return_dict:
  292. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  293. return BaseModelOutput(
  294. last_hidden_state=hidden_states,
  295. hidden_states=all_hidden_states,
  296. attentions=all_self_attentions,
  297. )
  298. @auto_docstring
  299. class GLPNPreTrainedModel(PreTrainedModel):
  300. config: GLPNConfig
  301. base_model_prefix = "glpn"
  302. main_input_name = "pixel_values"
  303. input_modalities = ("image",)
  304. _no_split_modules = []
  305. @auto_docstring
  306. class GLPNModel(GLPNPreTrainedModel):
  307. # Copied from transformers.models.segformer.modeling_segformer.SegformerModel.__init__ with Segformer->GLPN
  308. def __init__(self, config):
  309. super().__init__(config)
  310. self.config = config
  311. # hierarchical Transformer encoder
  312. self.encoder = GLPNEncoder(config)
  313. # Initialize weights and apply final processing
  314. self.post_init()
  315. @auto_docstring
  316. # Copied from transformers.models.segformer.modeling_segformer.SegformerModel.forward
  317. def forward(
  318. self,
  319. pixel_values: torch.FloatTensor,
  320. output_attentions: bool | None = None,
  321. output_hidden_states: bool | None = None,
  322. return_dict: bool | None = None,
  323. **kwargs,
  324. ) -> tuple | BaseModelOutput:
  325. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  326. output_hidden_states = (
  327. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  328. )
  329. return_dict = return_dict if return_dict is not None else self.config.return_dict
  330. encoder_outputs = self.encoder(
  331. pixel_values,
  332. output_attentions=output_attentions,
  333. output_hidden_states=output_hidden_states,
  334. return_dict=return_dict,
  335. )
  336. sequence_output = encoder_outputs[0]
  337. if not return_dict:
  338. return (sequence_output,) + encoder_outputs[1:]
  339. return BaseModelOutput(
  340. last_hidden_state=sequence_output,
  341. hidden_states=encoder_outputs.hidden_states,
  342. attentions=encoder_outputs.attentions,
  343. )
  344. class GLPNSelectiveFeatureFusion(nn.Module):
  345. """
  346. Selective Feature Fusion module, as explained in the [paper](https://huggingface.co/papers/2201.07436) (section 3.4). This
  347. module adaptively selects and integrates local and global features by attaining an attention map for each feature.
  348. """
  349. def __init__(self, in_channel=64):
  350. super().__init__()
  351. self.convolutional_layer1 = nn.Sequential(
  352. nn.Conv2d(in_channels=int(in_channel * 2), out_channels=in_channel, kernel_size=3, stride=1, padding=1),
  353. nn.BatchNorm2d(in_channel),
  354. nn.ReLU(),
  355. )
  356. self.convolutional_layer2 = nn.Sequential(
  357. nn.Conv2d(in_channels=in_channel, out_channels=int(in_channel / 2), kernel_size=3, stride=1, padding=1),
  358. nn.BatchNorm2d(int(in_channel / 2)),
  359. nn.ReLU(),
  360. )
  361. self.convolutional_layer3 = nn.Conv2d(
  362. in_channels=int(in_channel / 2), out_channels=2, kernel_size=3, stride=1, padding=1
  363. )
  364. self.sigmoid = nn.Sigmoid()
  365. def forward(self, local_features, global_features):
  366. # concatenate features along the channel dimension
  367. features = torch.cat((local_features, global_features), dim=1)
  368. # pass through convolutional layers
  369. features = self.convolutional_layer1(features)
  370. features = self.convolutional_layer2(features)
  371. features = self.convolutional_layer3(features)
  372. # apply sigmoid to get two-channel attention map
  373. attn = self.sigmoid(features)
  374. # construct hybrid features by adding element-wise
  375. hybrid_features = local_features * attn[:, 0, :, :].unsqueeze(1) + global_features * attn[
  376. :, 1, :, :
  377. ].unsqueeze(1)
  378. return hybrid_features
  379. class GLPNDecoderStage(nn.Module):
  380. def __init__(self, in_channels, out_channels):
  381. super().__init__()
  382. should_skip = in_channels == out_channels
  383. self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1) if not should_skip else nn.Identity()
  384. self.fusion = GLPNSelectiveFeatureFusion(out_channels)
  385. self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
  386. def forward(self, hidden_state, residual=None):
  387. hidden_state = self.convolution(hidden_state)
  388. if residual is not None:
  389. hidden_state = self.fusion(hidden_state, residual)
  390. hidden_state = self.upsample(hidden_state)
  391. return hidden_state
  392. hidden_state = self.upsample(hidden_state)
  393. return hidden_state
  394. class GLPNDecoder(nn.Module):
  395. def __init__(self, config):
  396. super().__init__()
  397. # we use features from end -> start
  398. reserved_hidden_sizes = config.hidden_sizes[::-1]
  399. out_channels = config.decoder_hidden_size
  400. self.stages = nn.ModuleList(
  401. [GLPNDecoderStage(hidden_size, out_channels) for hidden_size in reserved_hidden_sizes]
  402. )
  403. # don't fuse in first stage
  404. self.stages[0].fusion = None
  405. self.final_upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
  406. def forward(self, hidden_states: list[torch.Tensor]) -> list[torch.Tensor]:
  407. stage_hidden_states = []
  408. stage_hidden_state = None
  409. for hidden_state, stage in zip(hidden_states[::-1], self.stages):
  410. stage_hidden_state = stage(hidden_state, stage_hidden_state)
  411. stage_hidden_states.append(stage_hidden_state)
  412. stage_hidden_states[-1] = self.final_upsample(stage_hidden_state)
  413. return stage_hidden_states
  414. class SiLogLoss(nn.Module):
  415. r"""
  416. Implements the Scale-invariant log scale loss [Eigen et al., 2014](https://huggingface.co/papers/1406.2283).
  417. $$L=\frac{1}{n} \sum_{i} d_{i}^{2}-\frac{1}{2 n^{2}}\left(\sum_{i} d_{i}^{2}\right)$$ where $d_{i}=\log y_{i}-\log
  418. y_{i}^{*}$.
  419. """
  420. def __init__(self, lambd=0.5):
  421. super().__init__()
  422. self.lambd = lambd
  423. def forward(self, pred, target):
  424. valid_mask = (target > 0).detach()
  425. diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])
  426. loss = torch.sqrt(torch.pow(diff_log, 2).mean() - self.lambd * torch.pow(diff_log.mean(), 2))
  427. return loss
  428. class GLPNDepthEstimationHead(nn.Module):
  429. def __init__(self, config):
  430. super().__init__()
  431. self.config = config
  432. channels = config.decoder_hidden_size
  433. self.head = nn.Sequential(
  434. nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
  435. nn.ReLU(inplace=False),
  436. nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1),
  437. )
  438. def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
  439. # use last features of the decoder
  440. hidden_states = hidden_states[self.config.head_in_index]
  441. hidden_states = self.head(hidden_states)
  442. predicted_depth = torch.sigmoid(hidden_states) * self.config.max_depth
  443. predicted_depth = predicted_depth.squeeze(dim=1)
  444. return predicted_depth
@auto_docstring(
    custom_intro="""
    GLPN Model transformer with a lightweight depth estimation head on top e.g. for KITTI, NYUv2.
    """
)
class GLPNForDepthEstimation(GLPNPreTrainedModel):
    # Pipeline: GLPNModel (encoder) -> GLPNDecoder (multi-scale fusion + upsampling) -> GLPNDepthEstimationHead.

    def __init__(self, config):
        super().__init__(config)

        self.glpn = GLPNModel(config)
        self.decoder = GLPNDecoder(config)
        self.head = GLPNDepthEstimationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor] | DepthEstimatorOutput:
        r"""
        labels (`torch.FloatTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth depth estimation maps for computing the loss.

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, GLPNForDepthEstimation
        >>> import torch
        >>> import numpy as np
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("vinvino02/glpn-kitti")
        >>> model = GLPNForDepthEstimation.from_pretrained("vinvino02/glpn-kitti")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> # interpolate to original size
        >>> post_processed_output = image_processor.post_process_depth_estimation(
        ...     outputs,
        ...     target_sizes=[(image.height, image.width)],
        ... )

        >>> # visualize the prediction
        >>> predicted_depth = post_processed_output[0]["predicted_depth"]
        >>> depth = predicted_depth * 255 / predicted_depth.max()
        >>> depth = depth.detach().cpu().numpy()
        >>> depth = Image.fromarray(depth.astype("uint8"))
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # The decoder consumes every encoder stage, so the backbone is always asked for hidden states. The caller's
        # `output_hidden_states` only decides whether they are returned.
        outputs = self.glpn(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        # Tuple form of the backbone output: (last_hidden_state, hidden_states, [attentions]).
        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        out = self.decoder(hidden_states)
        predicted_depth = self.head(out)

        loss = None
        if labels is not None:
            # Scale-invariant log loss; label entries <= 0 are ignored by SiLogLoss.
            loss_fct = SiLogLoss()
            loss = loss_fct(predicted_depth, labels)

        if not return_dict:
            if output_hidden_states:
                # (predicted_depth, hidden_states, [attentions])
                output = (predicted_depth,) + outputs[1:]
            else:
                # Drop the hidden states that were only requested for the decoder: (predicted_depth, [attentions])
                output = (predicted_depth,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return DepthEstimatorOutput(
            loss=loss,
            predicted_depth=predicted_depth,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
  528. __all__ = ["GLPNForDepthEstimation", "GLPNLayer", "GLPNModel", "GLPNPreTrainedModel"]