# modeling_zoedepth.py
  1. # Copyright 2024 Intel Labs and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ZoeDepth model."""
  15. import math
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...backbone_utils import load_backbone
  22. from ...modeling_outputs import DepthEstimatorOutput
  23. from ...modeling_utils import PreTrainedModel
  24. from ...utils import ModelOutput, auto_docstring, logging
  25. from .configuration_zoedepth import ZoeDepthConfig
  26. logger = logging.get_logger(__name__)
# Output container returned by the ZoeDepth depth estimator; every field is optional and defaults to `None`.
@dataclass
@auto_docstring(
    custom_intro="""
    Extension of `DepthEstimatorOutput` to include domain logits (ZoeDepth specific).
    """
)
class ZoeDepthDepthEstimatorOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    domain_logits (`torch.FloatTensor` of shape `(batch_size, num_domains)`):
        Logits for each domain (e.g. NYU and KITTI) in case multiple metric heads are used.
    """

    # NOTE(review): the declaration order below is presumably significant for positional / tuple-style
    # access provided by `ModelOutput` - confirm before reordering or inserting fields.
    # NOTE(review): the class docstring above is presumably parsed by `auto_docstring`; only `loss` and
    # `domain_logits` are described there - confirm the remaining fields are documented by the decorator.
    loss: torch.FloatTensor | None = None
    predicted_depth: torch.FloatTensor | None = None
    # ZoeDepth-specific addition on top of the generic depth-estimation output (see `custom_intro`).
    domain_logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
  45. class ZoeDepthReassembleStage(nn.Module):
  46. """
  47. This class reassembles the hidden states of the backbone into image-like feature representations at various
  48. resolutions.
  49. This happens in 3 stages:
  50. 1. Map the N + 1 tokens to a set of N tokens, by taking into account the readout ([CLS]) token according to
  51. `config.readout_type`.
  52. 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
  53. 3. Resizing the spatial dimensions (height, width).
  54. Args:
  55. config (`[ZoeDepthConfig]`):
  56. Model configuration class defining the model architecture.
  57. """
  58. def __init__(self, config):
  59. super().__init__()
  60. self.readout_type = config.readout_type
  61. self.layers = nn.ModuleList()
  62. for neck_hidden_size, factor in zip(config.neck_hidden_sizes, config.reassemble_factors):
  63. self.layers.append(ZoeDepthReassembleLayer(config, channels=neck_hidden_size, factor=factor))
  64. if config.readout_type == "project":
  65. self.readout_projects = nn.ModuleList()
  66. hidden_size = config.backbone_hidden_size
  67. for _ in config.neck_hidden_sizes:
  68. self.readout_projects.append(
  69. nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
  70. )
  71. def forward(self, hidden_states: list[torch.Tensor], patch_height, patch_width) -> list[torch.Tensor]:
  72. """
  73. Args:
  74. hidden_states (`list[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
  75. List of hidden states from the backbone.
  76. """
  77. batch_size = hidden_states[0].shape[0]
  78. # stack along batch dimension
  79. # shape (batch_size*num_stages, sequence_length + 1, hidden_size)
  80. hidden_states = torch.cat(hidden_states, dim=0)
  81. cls_token, hidden_states = hidden_states[:, 0], hidden_states[:, 1:]
  82. # reshape hidden_states to (batch_size*num_stages, num_channels, height, width)
  83. total_batch_size, sequence_length, num_channels = hidden_states.shape
  84. hidden_states = hidden_states.reshape(total_batch_size, patch_height, patch_width, num_channels)
  85. hidden_states = hidden_states.permute(0, 3, 1, 2).contiguous()
  86. if self.readout_type == "project":
  87. # reshape to (batch_size*num_stages, height*width, num_channels)
  88. hidden_states = hidden_states.flatten(2).permute((0, 2, 1))
  89. readout = cls_token.unsqueeze(dim=1).expand_as(hidden_states)
  90. # concatenate the readout token to the hidden states
  91. # to get (batch_size*num_stages, height*width, 2*num_channels)
  92. hidden_states = torch.cat((hidden_states, readout), -1)
  93. elif self.readout_type == "add":
  94. hidden_states = hidden_states + cls_token.unsqueeze(-1)
  95. out = []
  96. for stage_idx, hidden_state in enumerate(hidden_states.split(batch_size, dim=0)):
  97. if self.readout_type == "project":
  98. hidden_state = self.readout_projects[stage_idx](hidden_state)
  99. # reshape back to (batch_size, num_channels, height, width)
  100. hidden_state = hidden_state.permute(0, 2, 1).reshape(batch_size, -1, patch_height, patch_width)
  101. hidden_state = self.layers[stage_idx](hidden_state)
  102. out.append(hidden_state)
  103. return out
  104. class ZoeDepthReassembleLayer(nn.Module):
  105. def __init__(self, config, channels, factor):
  106. super().__init__()
  107. # projection
  108. hidden_size = config.backbone_hidden_size
  109. self.projection = nn.Conv2d(in_channels=hidden_size, out_channels=channels, kernel_size=1)
  110. # up/down sampling depending on factor
  111. if factor > 1:
  112. self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
  113. elif factor == 1:
  114. self.resize = nn.Identity()
  115. elif factor < 1:
  116. # so should downsample
  117. self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1)
  118. # Copied from transformers.models.dpt.modeling_dpt.DPTReassembleLayer.forward with DPT->ZoeDepth
  119. def forward(self, hidden_state):
  120. hidden_state = self.projection(hidden_state)
  121. hidden_state = self.resize(hidden_state)
  122. return hidden_state
  123. # Copied from transformers.models.dpt.modeling_dpt.DPTFeatureFusionStage with DPT->ZoeDepth
  124. class ZoeDepthFeatureFusionStage(nn.Module):
  125. def __init__(self, config: ZoeDepthConfig):
  126. super().__init__()
  127. self.layers = nn.ModuleList()
  128. for _ in range(len(config.neck_hidden_sizes)):
  129. self.layers.append(ZoeDepthFeatureFusionLayer(config))
  130. def forward(self, hidden_states):
  131. # reversing the hidden_states, we start from the last
  132. hidden_states = hidden_states[::-1]
  133. fused_hidden_states = []
  134. fused_hidden_state = None
  135. for hidden_state, layer in zip(hidden_states, self.layers):
  136. if fused_hidden_state is None:
  137. # first layer only uses the last hidden_state
  138. fused_hidden_state = layer(hidden_state)
  139. else:
  140. fused_hidden_state = layer(fused_hidden_state, hidden_state)
  141. fused_hidden_states.append(fused_hidden_state)
  142. return fused_hidden_states
  143. # Copied from transformers.models.dpt.modeling_dpt.DPTPreActResidualLayer with DPT->ZoeDepth
  144. class ZoeDepthPreActResidualLayer(nn.Module):
  145. """
  146. ResidualConvUnit, pre-activate residual unit.
  147. Args:
  148. config (`[ZoeDepthConfig]`):
  149. Model configuration class defining the model architecture.
  150. """
  151. # Ignore copy
  152. def __init__(self, config):
  153. super().__init__()
  154. self.use_batch_norm = config.use_batch_norm_in_fusion_residual
  155. use_bias_in_fusion_residual = (
  156. config.use_bias_in_fusion_residual
  157. if config.use_bias_in_fusion_residual is not None
  158. else not self.use_batch_norm
  159. )
  160. self.activation1 = nn.ReLU()
  161. self.convolution1 = nn.Conv2d(
  162. config.fusion_hidden_size,
  163. config.fusion_hidden_size,
  164. kernel_size=3,
  165. stride=1,
  166. padding=1,
  167. bias=use_bias_in_fusion_residual,
  168. )
  169. self.activation2 = nn.ReLU()
  170. self.convolution2 = nn.Conv2d(
  171. config.fusion_hidden_size,
  172. config.fusion_hidden_size,
  173. kernel_size=3,
  174. stride=1,
  175. padding=1,
  176. bias=use_bias_in_fusion_residual,
  177. )
  178. if self.use_batch_norm:
  179. self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size, eps=config.batch_norm_eps)
  180. self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size, eps=config.batch_norm_eps)
  181. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  182. residual = hidden_state
  183. hidden_state = self.activation1(hidden_state)
  184. hidden_state = self.convolution1(hidden_state)
  185. if self.use_batch_norm:
  186. hidden_state = self.batch_norm1(hidden_state)
  187. hidden_state = self.activation2(hidden_state)
  188. hidden_state = self.convolution2(hidden_state)
  189. if self.use_batch_norm:
  190. hidden_state = self.batch_norm2(hidden_state)
  191. return hidden_state + residual
  192. # Copied from transformers.models.dpt.modeling_dpt.DPTFeatureFusionLayer with DPT->ZoeDepth
  193. class ZoeDepthFeatureFusionLayer(nn.Module):
  194. """Feature fusion layer, merges feature maps from different stages.
  195. Args:
  196. config (`[ZoeDepthConfig]`):
  197. Model configuration class defining the model architecture.
  198. align_corners (`bool`, *optional*, defaults to `True`):
  199. The align_corner setting for bilinear upsample.
  200. """
  201. def __init__(self, config: ZoeDepthConfig, align_corners: bool = True):
  202. super().__init__()
  203. self.align_corners = align_corners
  204. self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
  205. self.residual_layer1 = ZoeDepthPreActResidualLayer(config)
  206. self.residual_layer2 = ZoeDepthPreActResidualLayer(config)
  207. def forward(self, hidden_state: torch.Tensor, residual: torch.Tensor | None = None) -> torch.Tensor:
  208. if residual is not None:
  209. if hidden_state.shape != residual.shape:
  210. residual = nn.functional.interpolate(
  211. residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False
  212. )
  213. hidden_state = hidden_state + self.residual_layer1(residual)
  214. hidden_state = self.residual_layer2(hidden_state)
  215. hidden_state = nn.functional.interpolate(
  216. hidden_state, scale_factor=2, mode="bilinear", align_corners=self.align_corners
  217. )
  218. hidden_state = self.projection(hidden_state)
  219. return hidden_state
  220. class ZoeDepthNeck(nn.Module):
  221. """
  222. ZoeDepthNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
  223. input and produces another list of tensors as output. For ZoeDepth, it includes 2 stages:
  224. * ZoeDepthReassembleStage
  225. * ZoeDepthFeatureFusionStage.
  226. Args:
  227. config (dict): config dict.
  228. """
  229. # Copied from transformers.models.dpt.modeling_dpt.DPTNeck.__init__ with DPT->ZoeDepth
  230. def __init__(self, config: ZoeDepthConfig):
  231. super().__init__()
  232. self.config = config
  233. # postprocessing: only required in case of a non-hierarchical backbone (e.g. ViT, BEiT)
  234. if config.backbone_config is not None and config.backbone_config.model_type == "swinv2":
  235. self.reassemble_stage = None
  236. else:
  237. self.reassemble_stage = ZoeDepthReassembleStage(config)
  238. self.convs = nn.ModuleList()
  239. for channel in config.neck_hidden_sizes:
  240. self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))
  241. # fusion
  242. self.fusion_stage = ZoeDepthFeatureFusionStage(config)
  243. def forward(self, hidden_states: list[torch.Tensor], patch_height, patch_width) -> list[torch.Tensor]:
  244. """
  245. Args:
  246. hidden_states (`list[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
  247. List of hidden states from the backbone.
  248. """
  249. if not isinstance(hidden_states, (tuple, list)):
  250. raise TypeError("hidden_states should be a tuple or list of tensors")
  251. if len(hidden_states) != len(self.config.neck_hidden_sizes):
  252. raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")
  253. # postprocess hidden states
  254. if self.reassemble_stage is not None:
  255. hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
  256. features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
  257. # fusion blocks
  258. output = self.fusion_stage(features)
  259. return output, features[-1]
  260. class ZoeDepthRelativeDepthEstimationHead(nn.Module):
  261. """
  262. Relative depth estimation head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
  263. the predictions to the input resolution after the first convolutional layer (details can be found in DPT's paper's
  264. supplementary material).
  265. """
  266. def __init__(self, config):
  267. super().__init__()
  268. self.head_in_index = config.head_in_index
  269. self.projection = None
  270. if config.add_projection:
  271. self.projection = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  272. features = config.fusion_hidden_size
  273. self.conv1 = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1)
  274. self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
  275. self.conv2 = nn.Conv2d(features // 2, config.num_relative_features, kernel_size=3, stride=1, padding=1)
  276. self.conv3 = nn.Conv2d(config.num_relative_features, 1, kernel_size=1, stride=1, padding=0)
  277. def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
  278. # use last features
  279. hidden_states = hidden_states[self.head_in_index]
  280. if self.projection is not None:
  281. hidden_states = self.projection(hidden_states)
  282. hidden_states = nn.ReLU()(hidden_states)
  283. hidden_states = self.conv1(hidden_states)
  284. hidden_states = self.upsample(hidden_states)
  285. hidden_states = self.conv2(hidden_states)
  286. hidden_states = nn.ReLU()(hidden_states)
  287. # we need the features here (after second conv + ReLu)
  288. features = hidden_states
  289. hidden_states = self.conv3(hidden_states)
  290. hidden_states = nn.ReLU()(hidden_states)
  291. predicted_depth = hidden_states.squeeze(dim=1)
  292. return predicted_depth, features
  293. def log_binom(n, k, eps=1e-7):
  294. """log(nCk) using stirling approximation"""
  295. n = n + eps
  296. k = k + eps
  297. return n * torch.log(n) - k * torch.log(k) - (n - k) * torch.log(n - k + eps)
  298. class LogBinomialSoftmax(nn.Module):
  299. def __init__(self, n_classes=256, act=torch.softmax):
  300. """Compute log binomial distribution for n_classes
  301. Args:
  302. n_classes (`int`, *optional*, defaults to 256):
  303. Number of output classes.
  304. act (`torch.nn.Module`, *optional*, defaults to `torch.softmax`):
  305. Activation function to apply to the output.
  306. """
  307. super().__init__()
  308. self.k = n_classes
  309. self.act = act
  310. self.register_buffer("k_idx", torch.arange(0, n_classes).view(1, -1, 1, 1), persistent=False)
  311. self.register_buffer("k_minus_1", torch.tensor([self.k - 1]).view(1, -1, 1, 1), persistent=False)
  312. def forward(self, probabilities, temperature=1.0, eps=1e-4):
  313. """Compute the log binomial distribution for probabilities.
  314. Args:
  315. probabilities (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
  316. Tensor containing probabilities of each class.
  317. temperature (`float` or `torch.Tensor` of shape `(batch_size, num_channels, height, width)`, *optional*, defaults to 1):
  318. Temperature of distribution.
  319. eps (`float`, *optional*, defaults to 1e-4):
  320. Small number for numerical stability.
  321. Returns:
  322. `torch.Tensor` of shape `(batch_size, num_channels, height, width)`:
  323. Log binomial distribution logbinomial(p;t).
  324. """
  325. if probabilities.ndim == 3:
  326. probabilities = probabilities.unsqueeze(1) # make it (batch_size, num_channels, height, width)
  327. one_minus_probabilities = torch.clamp(1 - probabilities, eps, 1)
  328. probabilities = torch.clamp(probabilities, eps, 1)
  329. y = (
  330. log_binom(self.k_minus_1, self.k_idx)
  331. + self.k_idx * torch.log(probabilities)
  332. + (self.k_minus_1 - self.k_idx) * torch.log(one_minus_probabilities)
  333. )
  334. return self.act(y / temperature, dim=1)
  335. class ZoeDepthConditionalLogBinomialSoftmax(nn.Module):
  336. def __init__(
  337. self,
  338. config,
  339. in_features,
  340. condition_dim,
  341. n_classes=256,
  342. bottleneck_factor=2,
  343. ):
  344. """Per-pixel MLP followed by a Conditional Log Binomial softmax.
  345. Args:
  346. in_features (`int`):
  347. Number of input channels in the main feature.
  348. condition_dim (`int`):
  349. Number of input channels in the condition feature.
  350. n_classes (`int`, *optional*, defaults to 256):
  351. Number of classes.
  352. bottleneck_factor (`int`, *optional*, defaults to 2):
  353. Hidden dim factor.
  354. """
  355. super().__init__()
  356. bottleneck = (in_features + condition_dim) // bottleneck_factor
  357. self.mlp = nn.Sequential(
  358. nn.Conv2d(in_features + condition_dim, bottleneck, kernel_size=1, stride=1, padding=0),
  359. nn.GELU(),
  360. # 2 for probabilities linear norm, 2 for temperature linear norm
  361. nn.Conv2d(bottleneck, 2 + 2, kernel_size=1, stride=1, padding=0),
  362. nn.Softplus(),
  363. )
  364. self.p_eps = 1e-4
  365. self.max_temp = config.max_temp
  366. self.min_temp = config.min_temp
  367. self.log_binomial_transform = LogBinomialSoftmax(n_classes, act=torch.softmax)
  368. def forward(self, main_feature, condition_feature):
  369. """
  370. Args:
  371. main_feature (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
  372. Main feature.
  373. condition_feature (torch.Tensor of shape `(batch_size, num_channels, height, width)`):
  374. Condition feature.
  375. Returns:
  376. `torch.Tensor`:
  377. Output log binomial distribution
  378. """
  379. probabilities_and_temperature = self.mlp(torch.concat((main_feature, condition_feature), dim=1))
  380. probabilities, temperature = (
  381. probabilities_and_temperature[:, :2, ...],
  382. probabilities_and_temperature[:, 2:, ...],
  383. )
  384. probabilities = probabilities + self.p_eps
  385. probabilities = probabilities[:, 0, ...] / (probabilities[:, 0, ...] + probabilities[:, 1, ...])
  386. temperature = temperature + self.p_eps
  387. temperature = temperature[:, 0, ...] / (temperature[:, 0, ...] + temperature[:, 1, ...])
  388. temperature = temperature.unsqueeze(1)
  389. temperature = (self.max_temp - self.min_temp) * temperature + self.min_temp
  390. return self.log_binomial_transform(probabilities, temperature)
  391. class ZoeDepthSeedBinRegressor(nn.Module):
  392. def __init__(self, config, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
  393. """Bin center regressor network.
  394. Can be "normed" or "unnormed". If "normed", bin centers are bounded on the (min_depth, max_depth) interval.
  395. Args:
  396. config (`int`):
  397. Model configuration.
  398. n_bins (`int`, *optional*, defaults to 16):
  399. Number of bin centers.
  400. mlp_dim (`int`, *optional*, defaults to 256):
  401. Hidden dimension.
  402. min_depth (`float`, *optional*, defaults to 1e-3):
  403. Min depth value.
  404. max_depth (`float`, *optional*, defaults to 10):
  405. Max depth value.
  406. """
  407. super().__init__()
  408. self.in_features = config.bottleneck_features
  409. self.bin_centers_type = config.bin_centers_type
  410. self.min_depth = min_depth
  411. self.max_depth = max_depth
  412. self.conv1 = nn.Conv2d(self.in_features, mlp_dim, 1, 1, 0)
  413. self.act1 = nn.ReLU(inplace=True)
  414. self.conv2 = nn.Conv2d(mlp_dim, n_bins, 1, 1, 0)
  415. self.act2 = nn.ReLU(inplace=True) if self.bin_centers_type == "normed" else nn.Softplus()
  416. def forward(self, x):
  417. """
  418. Returns tensor of bin_width vectors (centers). One vector b for every pixel
  419. """
  420. x = self.conv1(x)
  421. x = self.act1(x)
  422. x = self.conv2(x)
  423. bin_centers = self.act2(x)
  424. if self.bin_centers_type == "normed":
  425. bin_centers = bin_centers + 1e-3
  426. bin_widths_normed = bin_centers / bin_centers.sum(dim=1, keepdim=True)
  427. # shape (batch_size, num_channels, height, width)
  428. bin_widths = (self.max_depth - self.min_depth) * bin_widths_normed
  429. # pad has the form (left, right, top, bottom, front, back)
  430. bin_widths = nn.functional.pad(bin_widths, (0, 0, 0, 0, 1, 0), mode="constant", value=self.min_depth)
  431. # shape (batch_size, num_channels, height, width)
  432. bin_edges = torch.cumsum(bin_widths, dim=1)
  433. bin_centers = 0.5 * (bin_edges[:, :-1, ...] + bin_edges[:, 1:, ...])
  434. return bin_widths_normed, bin_centers
  435. else:
  436. return bin_centers, bin_centers
  437. @torch.jit.script
  438. def inv_attractor(dx, alpha: float = 300, gamma: int = 2):
  439. """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center
  440. This is the default one according to the accompanying paper.
  441. Args:
  442. dx (`torch.Tensor`):
  443. The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
  444. alpha (`float`, *optional*, defaults to 300):
  445. Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction.
  446. gamma (`int`, *optional*, defaults to 2):
  447. Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected.
  448. Lower gamma = farther reach.
  449. Returns:
  450. torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc
  451. """
  452. return dx.div(1 + alpha * dx.pow(gamma))
  453. class ZoeDepthAttractorLayer(nn.Module):
  454. def __init__(
  455. self,
  456. config,
  457. n_bins,
  458. n_attractors=16,
  459. min_depth=1e-3,
  460. max_depth=10,
  461. memory_efficient=False,
  462. ):
  463. """
  464. Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth)
  465. """
  466. super().__init__()
  467. self.alpha = config.attractor_alpha
  468. self.gemma = config.attractor_gamma
  469. self.kind = config.attractor_kind
  470. self.n_attractors = n_attractors
  471. self.n_bins = n_bins
  472. self.min_depth = min_depth
  473. self.max_depth = max_depth
  474. self.memory_efficient = memory_efficient
  475. # MLP to predict attractor points
  476. in_features = mlp_dim = config.bin_embedding_dim
  477. self.conv1 = nn.Conv2d(in_features, mlp_dim, 1, 1, 0)
  478. self.act1 = nn.ReLU(inplace=True)
  479. self.conv2 = nn.Conv2d(mlp_dim, n_attractors * 2, 1, 1, 0) # x2 for linear norm
  480. self.act2 = nn.ReLU(inplace=True)
  481. def forward(self, x, prev_bin, prev_bin_embedding=None, interpolate=True):
  482. """
  483. The forward pass of the attractor layer. This layer predicts the new bin centers based on the previous bin centers
  484. and the attractor points (the latter are predicted by the MLP).
  485. Args:
  486. x (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
  487. Feature block.
  488. prev_bin (`torch.Tensor` of shape `(batch_size, prev_number_of_bins, height, width)`):
  489. Previous bin centers normed.
  490. prev_bin_embedding (`torch.Tensor`, *optional*):
  491. Optional previous bin embeddings.
  492. interpolate (`bool`, *optional*, defaults to `True`):
  493. Whether to interpolate the previous bin embeddings to the size of the input features.
  494. Returns:
  495. `tuple[`torch.Tensor`, `torch.Tensor`]:
  496. New bin centers normed and scaled.
  497. """
  498. if prev_bin_embedding is not None:
  499. if interpolate:
  500. prev_bin_embedding = nn.functional.interpolate(
  501. prev_bin_embedding, x.shape[-2:], mode="bilinear", align_corners=True
  502. )
  503. x = x + prev_bin_embedding
  504. x = self.conv1(x)
  505. x = self.act1(x)
  506. x = self.conv2(x)
  507. attractors = self.act2(x)
  508. attractors = attractors + 1e-3
  509. batch_size, _, height, width = attractors.shape
  510. attractors = attractors.view(batch_size, self.n_attractors, 2, height, width)
  511. # batch_size, num_attractors, 2, height, width
  512. # note: original repo had a bug here: https://github.com/isl-org/ZoeDepth/blame/edb6daf45458569e24f50250ef1ed08c015f17a7/zoedepth/models/layers/attractor.py#L105C9-L106C50
  513. # we include the bug to maintain compatibility with the weights
  514. attractors_normed = attractors[:, :, 0, ...] # batch_size, batch_size*num_attractors, height, width
  515. bin_centers = nn.functional.interpolate(prev_bin, (height, width), mode="bilinear", align_corners=True)
  516. # note: only attractor_type = "exp" is supported here, since no checkpoints were released with other attractor types
  517. if not self.memory_efficient:
  518. func = {"mean": torch.mean, "sum": torch.sum}[self.kind]
  519. # shape (batch_size, num_bins, height, width)
  520. delta_c = func(inv_attractor(attractors_normed.unsqueeze(2) - bin_centers.unsqueeze(1)), dim=1)
  521. else:
  522. delta_c = torch.zeros_like(bin_centers, device=bin_centers.device)
  523. for i in range(self.n_attractors):
  524. # shape (batch_size, num_bins, height, width)
  525. delta_c += inv_attractor(attractors_normed[:, i, ...].unsqueeze(1) - bin_centers)
  526. if self.kind == "mean":
  527. delta_c = delta_c / self.n_attractors
  528. bin_new_centers = bin_centers + delta_c
  529. bin_centers = (self.max_depth - self.min_depth) * bin_new_centers + self.min_depth
  530. bin_centers, _ = torch.sort(bin_centers, dim=1)
  531. bin_centers = torch.clip(bin_centers, self.min_depth, self.max_depth)
  532. return bin_new_centers, bin_centers
  533. class ZoeDepthAttractorLayerUnnormed(nn.Module):
  534. def __init__(
  535. self,
  536. config,
  537. n_bins,
  538. n_attractors=16,
  539. min_depth=1e-3,
  540. max_depth=10,
  541. memory_efficient=True,
  542. ):
  543. """
  544. Attractor layer for bin centers. Bin centers are unbounded
  545. """
  546. super().__init__()
  547. self.n_attractors = n_attractors
  548. self.n_bins = n_bins
  549. self.min_depth = min_depth
  550. self.max_depth = max_depth
  551. self.alpha = config.attractor_alpha
  552. self.gamma = config.attractor_alpha
  553. self.kind = config.attractor_kind
  554. self.memory_efficient = memory_efficient
  555. in_features = mlp_dim = config.bin_embedding_dim
  556. self.conv1 = nn.Conv2d(in_features, mlp_dim, 1, 1, 0)
  557. self.act1 = nn.ReLU(inplace=True)
  558. self.conv2 = nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0)
  559. self.act2 = nn.Softplus()
  560. def forward(self, x, prev_bin, prev_bin_embedding=None, interpolate=True):
  561. """
  562. The forward pass of the attractor layer. This layer predicts the new bin centers based on the previous bin centers
  563. and the attractor points (the latter are predicted by the MLP).
  564. Args:
  565. x (`torch.Tensor` of shape (batch_size, num_channels, height, width)`):
  566. Feature block.
  567. prev_bin (`torch.Tensor` of shape (batch_size, prev_num_bins, height, width)`):
  568. Previous bin centers normed.
  569. prev_bin_embedding (`torch.Tensor`, *optional*):
  570. Optional previous bin embeddings.
  571. interpolate (`bool`, *optional*, defaults to `True`):
  572. Whether to interpolate the previous bin embeddings to the size of the input features.
  573. Returns:
  574. `tuple[`torch.Tensor`, `torch.Tensor`]:
  575. New bin centers unbounded. Two outputs just to keep the API consistent with the normed version.
  576. """
  577. if prev_bin_embedding is not None:
  578. if interpolate:
  579. prev_bin_embedding = nn.functional.interpolate(
  580. prev_bin_embedding, x.shape[-2:], mode="bilinear", align_corners=True
  581. )
  582. x = x + prev_bin_embedding
  583. x = self.conv1(x)
  584. x = self.act1(x)
  585. x = self.conv2(x)
  586. attractors = self.act2(x)
  587. height, width = attractors.shape[-2:]
  588. bin_centers = nn.functional.interpolate(prev_bin, (height, width), mode="bilinear", align_corners=True)
  589. if not self.memory_efficient:
  590. func = {"mean": torch.mean, "sum": torch.sum}[self.kind]
  591. # shape batch_size, num_bins, height, width
  592. delta_c = func(inv_attractor(attractors.unsqueeze(2) - bin_centers.unsqueeze(1)), dim=1)
  593. else:
  594. delta_c = torch.zeros_like(bin_centers, device=bin_centers.device)
  595. for i in range(self.n_attractors):
  596. # shape batch_size, num_bins, height, width
  597. delta_c += inv_attractor(attractors[:, i, ...].unsqueeze(1) - bin_centers)
  598. if self.kind == "mean":
  599. delta_c = delta_c / self.n_attractors
  600. bin_new_centers = bin_centers + delta_c
  601. bin_centers = bin_new_centers
  602. return bin_new_centers, bin_centers
  603. class ZoeDepthProjector(nn.Module):
  604. def __init__(self, in_features, out_features, mlp_dim=128):
  605. """Projector MLP.
  606. Args:
  607. in_features (`int`):
  608. Number of input channels.
  609. out_features (`int`):
  610. Number of output channels.
  611. mlp_dim (`int`, *optional*, defaults to 128):
  612. Hidden dimension.
  613. """
  614. super().__init__()
  615. self.conv1 = nn.Conv2d(in_features, mlp_dim, 1, 1, 0)
  616. self.act = nn.ReLU(inplace=True)
  617. self.conv2 = nn.Conv2d(mlp_dim, out_features, 1, 1, 0)
  618. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  619. hidden_state = self.conv1(hidden_state)
  620. hidden_state = self.act(hidden_state)
  621. hidden_state = self.conv2(hidden_state)
  622. return hidden_state
  623. # Copied from transformers.models.grounding_dino.modeling_grounding_dino.GroundingDinoMultiheadAttention with GroundingDino->ZoeDepth
  624. class ZoeDepthMultiheadAttention(nn.Module):
  625. """Equivalent implementation of nn.MultiheadAttention with `batch_first=True`."""
  626. # Ignore copy
  627. def __init__(self, hidden_size, num_attention_heads, dropout):
  628. super().__init__()
  629. if hidden_size % num_attention_heads != 0:
  630. raise ValueError(
  631. f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
  632. f"heads ({num_attention_heads})"
  633. )
  634. self.num_attention_heads = num_attention_heads
  635. self.attention_head_size = int(hidden_size / num_attention_heads)
  636. self.all_head_size = self.num_attention_heads * self.attention_head_size
  637. self.query = nn.Linear(hidden_size, self.all_head_size)
  638. self.key = nn.Linear(hidden_size, self.all_head_size)
  639. self.value = nn.Linear(hidden_size, self.all_head_size)
  640. self.out_proj = nn.Linear(hidden_size, hidden_size)
  641. self.dropout = nn.Dropout(dropout)
  642. def forward(
  643. self,
  644. queries: torch.Tensor,
  645. keys: torch.Tensor,
  646. values: torch.Tensor,
  647. attention_mask: torch.FloatTensor | None = None,
  648. output_attentions: bool | None = False,
  649. ) -> tuple[torch.Tensor]:
  650. batch_size, seq_length, _ = queries.shape
  651. query_layer = (
  652. self.query(queries)
  653. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  654. .transpose(1, 2)
  655. )
  656. key_layer = (
  657. self.key(keys).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  658. )
  659. value_layer = (
  660. self.value(values).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  661. )
  662. # Take the dot product between "query" and "key" to get the raw attention scores.
  663. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  664. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  665. if attention_mask is not None:
  666. # Apply the attention mask is (precomputed for all layers in ZoeDepthModel forward() function)
  667. attention_scores = attention_scores + attention_mask
  668. # Normalize the attention scores to probabilities.
  669. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  670. # This is actually dropping out entire tokens to attend to, which might
  671. # seem a bit unusual, but is taken from the original Transformer paper.
  672. attention_probs = self.dropout(attention_probs)
  673. context_layer = torch.matmul(attention_probs, value_layer)
  674. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  675. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  676. context_layer = context_layer.view(new_context_layer_shape)
  677. context_layer = self.out_proj(context_layer)
  678. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  679. return outputs
  680. class ZoeDepthTransformerEncoderLayer(nn.Module):
  681. def __init__(self, config, dropout=0.1, activation="relu"):
  682. super().__init__()
  683. hidden_size = config.patch_transformer_hidden_size
  684. intermediate_size = config.patch_transformer_intermediate_size
  685. num_attention_heads = config.patch_transformer_num_attention_heads
  686. self.self_attn = ZoeDepthMultiheadAttention(hidden_size, num_attention_heads, dropout=dropout)
  687. self.linear1 = nn.Linear(hidden_size, intermediate_size)
  688. self.dropout = nn.Dropout(dropout)
  689. self.linear2 = nn.Linear(intermediate_size, hidden_size)
  690. self.norm1 = nn.LayerNorm(hidden_size)
  691. self.norm2 = nn.LayerNorm(hidden_size)
  692. self.dropout1 = nn.Dropout(dropout)
  693. self.dropout2 = nn.Dropout(dropout)
  694. self.activation = ACT2FN[activation]
  695. def forward(
  696. self,
  697. src,
  698. src_mask: torch.Tensor | None = None,
  699. ):
  700. queries = keys = src
  701. src2 = self.self_attn(queries=queries, keys=keys, values=src, attention_mask=src_mask)[0]
  702. src = src + self.dropout1(src2)
  703. src = self.norm1(src)
  704. src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
  705. src = src + self.dropout2(src2)
  706. src = self.norm2(src)
  707. return src
  708. class ZoeDepthPatchTransformerEncoder(nn.Module):
  709. def __init__(self, config):
  710. """ViT-like transformer block
  711. Args:
  712. config (`ZoeDepthConfig`):
  713. Model configuration class defining the model architecture.
  714. """
  715. super().__init__()
  716. in_channels = config.bottleneck_features
  717. self.transformer_encoder = nn.ModuleList(
  718. [ZoeDepthTransformerEncoderLayer(config) for _ in range(config.num_patch_transformer_layers)]
  719. )
  720. self.embedding_convPxP = nn.Conv2d(
  721. in_channels, config.patch_transformer_hidden_size, kernel_size=1, stride=1, padding=0
  722. )
  723. def positional_encoding_1d(self, batch_size, sequence_length, embedding_dim, device="cpu", dtype=torch.float32):
  724. """Generate positional encodings
  725. Args:
  726. sequence_length (int): Sequence length
  727. embedding_dim (int): Embedding dimension
  728. Returns:
  729. torch.Tensor: Positional encodings.
  730. """
  731. position = torch.arange(0, sequence_length, dtype=dtype, device=device).unsqueeze(1)
  732. index = torch.arange(0, embedding_dim, 2, dtype=dtype, device=device).unsqueeze(0)
  733. div_term = torch.exp(index * (-torch.log(torch.tensor(10000.0, device=device)) / embedding_dim))
  734. pos_encoding = position * div_term
  735. pos_encoding = torch.cat([torch.sin(pos_encoding), torch.cos(pos_encoding)], dim=1)
  736. pos_encoding = pos_encoding.unsqueeze(dim=0).repeat(batch_size, 1, 1)
  737. return pos_encoding
  738. def forward(self, x):
  739. """Forward pass
  740. Args:
  741. x (torch.Tensor - NCHW): Input feature tensor
  742. Returns:
  743. torch.Tensor - Transformer output embeddings of shape (batch_size, sequence_length, embedding_dim)
  744. """
  745. embeddings = self.embedding_convPxP(x).flatten(2) # shape (batch_size, num_channels, sequence_length)
  746. # add an extra special CLS token at the start for global accumulation
  747. embeddings = nn.functional.pad(embeddings, (1, 0))
  748. embeddings = embeddings.permute(0, 2, 1)
  749. batch_size, sequence_length, embedding_dim = embeddings.shape
  750. embeddings = embeddings + self.positional_encoding_1d(
  751. batch_size, sequence_length, embedding_dim, device=embeddings.device, dtype=embeddings.dtype
  752. )
  753. for i in range(4):
  754. embeddings = self.transformer_encoder[i](embeddings)
  755. return embeddings
  756. class ZoeDepthMLPClassifier(nn.Module):
  757. def __init__(self, in_features, out_features) -> None:
  758. super().__init__()
  759. hidden_features = in_features
  760. self.linear1 = nn.Linear(in_features, hidden_features)
  761. self.activation = nn.ReLU()
  762. self.linear2 = nn.Linear(hidden_features, out_features)
  763. def forward(self, hidden_state):
  764. hidden_state = self.linear1(hidden_state)
  765. hidden_state = self.activation(hidden_state)
  766. domain_logits = self.linear2(hidden_state)
  767. return domain_logits
  768. class ZoeDepthMultipleMetricDepthEstimationHeads(nn.Module):
  769. """
  770. Multiple metric depth estimation heads. A MLP classifier is used to route between 2 different heads.
  771. """
  772. def __init__(self, config):
  773. super().__init__()
  774. bin_embedding_dim = config.bin_embedding_dim
  775. n_attractors = config.num_attractors
  776. self.bin_configurations = config.bin_configurations
  777. self.bin_centers_type = config.bin_centers_type
  778. # Bottleneck convolution
  779. bottleneck_features = config.bottleneck_features
  780. self.conv2 = nn.Conv2d(bottleneck_features, bottleneck_features, kernel_size=1, stride=1, padding=0)
  781. # Transformer classifier on the bottleneck
  782. self.patch_transformer = ZoeDepthPatchTransformerEncoder(config)
  783. # MLP classifier
  784. self.mlp_classifier = ZoeDepthMLPClassifier(in_features=128, out_features=2)
  785. # Regressor and attractor
  786. if self.bin_centers_type == "normed":
  787. Attractor = ZoeDepthAttractorLayer
  788. elif self.bin_centers_type == "softplus":
  789. Attractor = ZoeDepthAttractorLayerUnnormed
  790. # We have bins for each bin configuration
  791. # Create a map (ModuleDict) of 'name' -> seed_bin_regressor
  792. self.seed_bin_regressors = nn.ModuleDict(
  793. {
  794. conf["name"]: ZoeDepthSeedBinRegressor(
  795. config,
  796. n_bins=conf["n_bins"],
  797. mlp_dim=bin_embedding_dim // 2,
  798. min_depth=conf["min_depth"],
  799. max_depth=conf["max_depth"],
  800. )
  801. for conf in config.bin_configurations
  802. }
  803. )
  804. self.seed_projector = ZoeDepthProjector(
  805. in_features=bottleneck_features, out_features=bin_embedding_dim, mlp_dim=bin_embedding_dim // 2
  806. )
  807. self.projectors = nn.ModuleList(
  808. [
  809. ZoeDepthProjector(
  810. in_features=config.fusion_hidden_size,
  811. out_features=bin_embedding_dim,
  812. mlp_dim=bin_embedding_dim // 2,
  813. )
  814. for _ in range(4)
  815. ]
  816. )
  817. # Create a map (ModuleDict) of 'name' -> attractors (ModuleList)
  818. self.attractors = nn.ModuleDict(
  819. {
  820. configuration["name"]: nn.ModuleList(
  821. [
  822. Attractor(
  823. config,
  824. n_bins=n_attractors[i],
  825. min_depth=configuration["min_depth"],
  826. max_depth=configuration["max_depth"],
  827. )
  828. for i in range(len(n_attractors))
  829. ]
  830. )
  831. for configuration in config.bin_configurations
  832. }
  833. )
  834. last_in = config.num_relative_features
  835. # conditional log binomial for each bin configuration
  836. self.conditional_log_binomial = nn.ModuleDict(
  837. {
  838. configuration["name"]: ZoeDepthConditionalLogBinomialSoftmax(
  839. config,
  840. last_in,
  841. bin_embedding_dim,
  842. configuration["n_bins"],
  843. bottleneck_factor=4,
  844. )
  845. for configuration in config.bin_configurations
  846. }
  847. )
  848. def forward(self, outconv_activation, bottleneck, feature_blocks, relative_depth):
  849. x = self.conv2(bottleneck)
  850. # Predict which path to take
  851. # Embedding is of shape (batch_size, hidden_size)
  852. embedding = self.patch_transformer(x)[:, 0, :]
  853. # MLP classifier to get logits of shape (batch_size, 2)
  854. domain_logits = self.mlp_classifier(embedding)
  855. domain_vote = torch.softmax(domain_logits.sum(dim=0, keepdim=True), dim=-1)
  856. # Get the path
  857. names = [configuration["name"] for configuration in self.bin_configurations]
  858. bin_configurations_name = names[torch.argmax(domain_vote, dim=-1).squeeze().item()]
  859. try:
  860. conf = [config for config in self.bin_configurations if config["name"] == bin_configurations_name][0]
  861. except IndexError:
  862. raise ValueError(f"bin_configurations_name {bin_configurations_name} not found in bin_configurationss")
  863. min_depth = conf["min_depth"]
  864. max_depth = conf["max_depth"]
  865. seed_bin_regressor = self.seed_bin_regressors[bin_configurations_name]
  866. _, seed_bin_centers = seed_bin_regressor(x)
  867. if self.bin_centers_type in ["normed", "hybrid2"]:
  868. prev_bin = (seed_bin_centers - min_depth) / (max_depth - min_depth)
  869. else:
  870. prev_bin = seed_bin_centers
  871. prev_bin_embedding = self.seed_projector(x)
  872. attractors = self.attractors[bin_configurations_name]
  873. for projector, attractor, feature in zip(self.projectors, attractors, feature_blocks):
  874. bin_embedding = projector(feature)
  875. bin, bin_centers = attractor(bin_embedding, prev_bin, prev_bin_embedding, interpolate=True)
  876. prev_bin = bin
  877. prev_bin_embedding = bin_embedding
  878. last = outconv_activation
  879. bin_centers = nn.functional.interpolate(bin_centers, last.shape[-2:], mode="bilinear", align_corners=True)
  880. bin_embedding = nn.functional.interpolate(bin_embedding, last.shape[-2:], mode="bilinear", align_corners=True)
  881. conditional_log_binomial = self.conditional_log_binomial[bin_configurations_name]
  882. x = conditional_log_binomial(last, bin_embedding)
  883. # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor
  884. out = torch.sum(x * bin_centers, dim=1, keepdim=True)
  885. return out, domain_logits
  886. class ZoeDepthMetricDepthEstimationHead(nn.Module):
  887. def __init__(self, config):
  888. super().__init__()
  889. bin_configuration = config.bin_configurations[0]
  890. n_bins = bin_configuration["n_bins"]
  891. min_depth = bin_configuration["min_depth"]
  892. max_depth = bin_configuration["max_depth"]
  893. bin_embedding_dim = config.bin_embedding_dim
  894. n_attractors = config.num_attractors
  895. bin_centers_type = config.bin_centers_type
  896. self.min_depth = min_depth
  897. self.max_depth = max_depth
  898. self.bin_centers_type = bin_centers_type
  899. # Bottleneck convolution
  900. bottleneck_features = config.bottleneck_features
  901. self.conv2 = nn.Conv2d(bottleneck_features, bottleneck_features, kernel_size=1, stride=1, padding=0)
  902. # Regressor and attractor
  903. if self.bin_centers_type == "normed":
  904. Attractor = ZoeDepthAttractorLayer
  905. elif self.bin_centers_type == "softplus":
  906. Attractor = ZoeDepthAttractorLayerUnnormed
  907. self.seed_bin_regressor = ZoeDepthSeedBinRegressor(
  908. config, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth
  909. )
  910. self.seed_projector = ZoeDepthProjector(in_features=bottleneck_features, out_features=bin_embedding_dim)
  911. self.projectors = nn.ModuleList(
  912. [
  913. ZoeDepthProjector(in_features=config.fusion_hidden_size, out_features=bin_embedding_dim)
  914. for _ in range(4)
  915. ]
  916. )
  917. self.attractors = nn.ModuleList(
  918. [
  919. Attractor(
  920. config,
  921. n_bins=n_bins,
  922. n_attractors=n_attractors[i],
  923. min_depth=min_depth,
  924. max_depth=max_depth,
  925. )
  926. for i in range(4)
  927. ]
  928. )
  929. last_in = config.num_relative_features + 1 # +1 for relative depth
  930. # use log binomial instead of softmax
  931. self.conditional_log_binomial = ZoeDepthConditionalLogBinomialSoftmax(
  932. config,
  933. last_in,
  934. bin_embedding_dim,
  935. n_classes=n_bins,
  936. )
  937. def forward(self, outconv_activation, bottleneck, feature_blocks, relative_depth):
  938. x = self.conv2(bottleneck)
  939. _, seed_bin_centers = self.seed_bin_regressor(x)
  940. if self.bin_centers_type in ["normed", "hybrid2"]:
  941. prev_bin = (seed_bin_centers - self.min_depth) / (self.max_depth - self.min_depth)
  942. else:
  943. prev_bin = seed_bin_centers
  944. prev_bin_embedding = self.seed_projector(x)
  945. # unroll this loop for better performance
  946. for projector, attractor, feature in zip(self.projectors, self.attractors, feature_blocks):
  947. bin_embedding = projector(feature)
  948. bin, bin_centers = attractor(bin_embedding, prev_bin, prev_bin_embedding, interpolate=True)
  949. prev_bin = bin.clone()
  950. prev_bin_embedding = bin_embedding.clone()
  951. last = outconv_activation
  952. # concatenative relative depth with last. First interpolate relative depth to last size
  953. relative_conditioning = relative_depth.unsqueeze(1)
  954. relative_conditioning = nn.functional.interpolate(
  955. relative_conditioning, size=last.shape[2:], mode="bilinear", align_corners=True
  956. )
  957. last = torch.cat([last, relative_conditioning], dim=1)
  958. bin_embedding = nn.functional.interpolate(bin_embedding, last.shape[-2:], mode="bilinear", align_corners=True)
  959. x = self.conditional_log_binomial(last, bin_embedding)
  960. # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor
  961. bin_centers = nn.functional.interpolate(bin_centers, x.shape[-2:], mode="bilinear", align_corners=True)
  962. out = torch.sum(x * bin_centers, dim=1, keepdim=True)
  963. return out, None
  964. # Modified from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->ZoeDepth,dpt->zoedepth
# avoiding sdpa and flash_attn_2 support, it's done in the backend
  966. @auto_docstring
  967. class ZoeDepthPreTrainedModel(PreTrainedModel):
  968. config: ZoeDepthConfig
  969. base_model_prefix = "zoedepth"
  970. main_input_name = "pixel_values"
  971. input_modalities = ("image",)
  972. supports_gradient_checkpointing = True
  973. def _init_weights(self, module):
  974. super()._init_weights(module)
  975. if isinstance(module, LogBinomialSoftmax):
  976. init.copy_(module.k_idx, torch.arange(0, module.k).view(1, -1, 1, 1))
  977. init.copy_(module.k_minus_1, torch.tensor([module.k - 1]).view(1, -1, 1, 1))
  978. @auto_docstring(
  979. custom_intro="""
  980. ZoeDepth model with one or multiple metric depth estimation head(s) on top.
  981. """
  982. )
  983. class ZoeDepthForDepthEstimation(ZoeDepthPreTrainedModel):
  984. def __init__(self, config):
  985. super().__init__(config)
  986. self.backbone = load_backbone(config)
  987. if hasattr(self.backbone.config, "hidden_size") and hasattr(self.backbone.config, "patch_size"):
  988. config.backbone_hidden_size = self.backbone.config.hidden_size
  989. self.patch_size = self.backbone.config.patch_size
  990. else:
  991. raise ValueError(
  992. "ZoeDepth assumes the backbone's config to have `hidden_size` and `patch_size` attributes"
  993. )
  994. self.neck = ZoeDepthNeck(config)
  995. self.relative_head = ZoeDepthRelativeDepthEstimationHead(config)
  996. self.metric_head = (
  997. ZoeDepthMultipleMetricDepthEstimationHeads(config)
  998. if len(config.bin_configurations) > 1
  999. else ZoeDepthMetricDepthEstimationHead(config)
  1000. )
  1001. # Initialize weights and apply final processing
  1002. self.post_init()
  1003. @auto_docstring
  1004. def forward(
  1005. self,
  1006. pixel_values: torch.FloatTensor,
  1007. labels: torch.LongTensor | None = None,
  1008. output_attentions: bool | None = None,
  1009. output_hidden_states: bool | None = None,
  1010. return_dict: bool | None = None,
  1011. **kwargs,
  1012. ) -> tuple[torch.Tensor] | DepthEstimatorOutput:
  1013. r"""
  1014. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  1015. Ground truth depth estimation maps for computing the loss.
  1016. Examples:
  1017. ```python
  1018. >>> from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation
  1019. >>> import torch
  1020. >>> import numpy as np
  1021. >>> from PIL import Image
  1022. >>> import httpx
  1023. >>> from io import BytesIO
  1024. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1025. >>> with httpx.stream("GET", url) as response:
  1026. ... image = Image.open(BytesIO(response.read()))
  1027. >>> image_processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")
  1028. >>> model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")
  1029. >>> # prepare image for the model
  1030. >>> inputs = image_processor(images=image, return_tensors="pt")
  1031. >>> with torch.no_grad():
  1032. ... outputs = model(**inputs)
  1033. >>> # interpolate to original size
  1034. >>> post_processed_output = image_processor.post_process_depth_estimation(
  1035. ... outputs,
  1036. ... source_sizes=[(image.height, image.width)],
  1037. ... )
  1038. >>> # visualize the prediction
  1039. >>> predicted_depth = post_processed_output[0]["predicted_depth"]
  1040. >>> depth = predicted_depth * 255 / predicted_depth.max()
  1041. >>> depth = depth.detach().cpu().numpy()
  1042. >>> depth = Image.fromarray(depth.astype("uint8"))
  1043. ```"""
  1044. loss = None
  1045. if labels is not None:
  1046. raise NotImplementedError("Training is not implemented yet")
  1047. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1048. output_hidden_states = (
  1049. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1050. )
  1051. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1052. outputs = self.backbone.forward_with_filtered_kwargs(
  1053. pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
  1054. )
  1055. hidden_states = outputs.feature_maps
  1056. _, _, height, width = pixel_values.shape
  1057. patch_size = self.patch_size
  1058. patch_height = height // patch_size
  1059. patch_width = width // patch_size
  1060. hidden_states, features = self.neck(hidden_states, patch_height, patch_width)
  1061. out = [features] + hidden_states
  1062. relative_depth, features = self.relative_head(hidden_states)
  1063. out = [features] + out
  1064. metric_depth, domain_logits = self.metric_head(
  1065. outconv_activation=out[0], bottleneck=out[1], feature_blocks=out[2:], relative_depth=relative_depth
  1066. )
  1067. metric_depth = metric_depth.squeeze(dim=1)
  1068. if not return_dict:
  1069. if domain_logits is not None:
  1070. output = (metric_depth, domain_logits) + outputs[1:]
  1071. else:
  1072. output = (metric_depth,) + outputs[1:]
  1073. return ((loss,) + output) if loss is not None else output
  1074. return ZoeDepthDepthEstimatorOutput(
  1075. loss=loss,
  1076. predicted_depth=metric_depth,
  1077. domain_logits=domain_logits,
  1078. hidden_states=outputs.hidden_states,
  1079. attentions=outputs.attentions,
  1080. )
# Public names exported by this module.
__all__ = ["ZoeDepthForDepthEstimation", "ZoeDepthPreTrainedModel"]