modeling_chmv2.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/chmv2/modular_chmv2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_chmv2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2026 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import torch
  21. from torch import nn
  22. from ... import initialization as init
  23. from ...backbone_utils import load_backbone
  24. from ...modeling_outputs import DepthEstimatorOutput
  25. from ...modeling_utils import PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  28. from .configuration_chmv2 import CHMv2Config
  29. def _get_backbone_hidden_size(config):
  30. if config.backbone_config is not None and hasattr(config.backbone_config, "hidden_size"):
  31. return config.backbone_config.hidden_size
  32. else:
  33. return config.hidden_size
  34. class CHMv2ReassembleLayer(nn.Module):
  35. def __init__(self, config: CHMv2Config, channels: int, factor: int):
  36. super().__init__()
  37. # projection
  38. hidden_size = _get_backbone_hidden_size(config)
  39. self.projection = nn.Conv2d(in_channels=hidden_size, out_channels=channels, kernel_size=1)
  40. # up/down sampling depending on factor
  41. if factor > 1:
  42. self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
  43. elif factor == 1:
  44. self.resize = nn.Identity()
  45. elif factor < 1:
  46. # so should downsample
  47. self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1)
  48. def forward(self, hidden_state):
  49. hidden_state = self.projection(hidden_state)
  50. hidden_state = self.resize(hidden_state)
  51. return hidden_state
  52. class CHMv2ReassembleStage(nn.Module):
  53. """
  54. Reassemble stage that processes hidden states from the backbone into image-like feature
  55. representations at various resolutions.
  56. """
  57. def __init__(self, config: CHMv2Config):
  58. super().__init__()
  59. self.config = config
  60. self.readout_type = config.readout_type
  61. self.layers = nn.ModuleList()
  62. for out_channels, factor in zip(config.post_process_channels, config.reassemble_factors):
  63. self.layers.append(
  64. CHMv2ReassembleLayer(
  65. config=config,
  66. channels=out_channels,
  67. factor=factor,
  68. )
  69. )
  70. hidden_size = _get_backbone_hidden_size(config)
  71. if self.readout_type == "project":
  72. self.readout_projects = nn.ModuleList()
  73. for _ in range(len(self.layers)):
  74. self.readout_projects.append(nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), nn.GELU()))
  75. def forward(self, hidden_states: list[torch.Tensor], patch_height=None, patch_width=None) -> list[torch.Tensor]:
  76. out = []
  77. for layer_idx, hidden_state in enumerate(hidden_states):
  78. if isinstance(hidden_state, (tuple, list)) and len(hidden_state) == 2:
  79. hidden_state, cls_token = hidden_state[0], hidden_state[1]
  80. feature_shape = hidden_state.shape
  81. if self.readout_type == "project":
  82. hidden_state = hidden_state.flatten(2).transpose(1, 2)
  83. readout = cls_token.unsqueeze(1).expand_as(hidden_state)
  84. hidden_state = self.readout_projects[layer_idx](torch.cat((hidden_state, readout), -1))
  85. hidden_state = hidden_state.permute(0, 2, 1).reshape(feature_shape)
  86. elif self.readout_type == "add":
  87. hidden_state = hidden_state.flatten(2) + cls_token.unsqueeze(-1)
  88. hidden_state = hidden_state.reshape(feature_shape)
  89. else:
  90. if hidden_state.dim() == 3:
  91. hidden_state = hidden_state[:, 1:]
  92. batch_size, _, num_channels = hidden_state.shape
  93. hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
  94. hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
  95. hidden_state = self.layers[layer_idx](hidden_state)
  96. out.append(hidden_state)
  97. return out
  98. class CHMv2PreActResidualLayer(nn.Module):
  99. """
  100. ResidualConvUnit, pre-activate residual unit.
  101. Args:
  102. config (`[CHMv2Config]`):
  103. Model configuration class defining the model architecture.
  104. """
  105. def __init__(self, config):
  106. super().__init__()
  107. self.activation1 = nn.ReLU()
  108. self.convolution1 = nn.Conv2d(
  109. config.fusion_hidden_size,
  110. config.fusion_hidden_size,
  111. kernel_size=3,
  112. stride=1,
  113. padding=1,
  114. bias=True,
  115. )
  116. self.activation2 = nn.ReLU()
  117. self.convolution2 = nn.Conv2d(
  118. config.fusion_hidden_size,
  119. config.fusion_hidden_size,
  120. kernel_size=3,
  121. stride=1,
  122. padding=1,
  123. bias=True,
  124. )
  125. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  126. residual = hidden_state
  127. hidden_state = self.activation1(hidden_state)
  128. hidden_state = self.convolution1(hidden_state)
  129. hidden_state = self.activation2(hidden_state)
  130. hidden_state = self.convolution2(hidden_state)
  131. return hidden_state + residual
  132. class CHMv2FeatureFusionLayer(nn.Module):
  133. def __init__(self, config: CHMv2Config, is_first_layer: bool = False):
  134. super().__init__()
  135. self.is_first_layer = is_first_layer
  136. self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
  137. if not is_first_layer:
  138. self.residual_layer1 = CHMv2PreActResidualLayer(config)
  139. self.residual_layer2 = CHMv2PreActResidualLayer(config)
  140. def forward(self, hidden_state, residual=None, size=None):
  141. if residual is not None and not self.is_first_layer:
  142. if hidden_state.shape != residual.shape:
  143. _, _, height, width = hidden_state.shape
  144. residual = nn.functional.interpolate(
  145. residual, size=(height, width), mode="bilinear", align_corners=False
  146. )
  147. hidden_state = hidden_state + self.residual_layer1(residual)
  148. hidden_state = self.residual_layer2(hidden_state)
  149. modifier = {"scale_factor": 2} if size is None else {"size": size}
  150. hidden_state = nn.functional.interpolate(
  151. hidden_state,
  152. **modifier,
  153. mode="bilinear",
  154. align_corners=True,
  155. )
  156. hidden_state = self.projection(hidden_state)
  157. return hidden_state
  158. class CHMv2UpsampleConvHead(nn.Module):
  159. """
  160. Convolutional head with intermediate upsampling.
  161. Architecture: Conv3x3 -> 2x bilinear upsample -> Conv3x3 -> ReLU -> Conv1x1.
  162. """
  163. def __init__(self, features, number_output_channels, n_hidden_channels=128):
  164. super().__init__()
  165. self.head = nn.ModuleList(
  166. [
  167. nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
  168. nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
  169. nn.Conv2d(features // 2, n_hidden_channels, kernel_size=3, stride=1, padding=1),
  170. nn.ReLU(),
  171. nn.Conv2d(n_hidden_channels, number_output_channels, kernel_size=1, stride=1, padding=0),
  172. ]
  173. )
  174. def forward(self, hidden_states):
  175. for layer in self.head:
  176. hidden_states = layer(hidden_states)
  177. return hidden_states
  178. class CHMv2Head(nn.Module):
  179. """
  180. CHMv2 dense-prediction head adapted from DPT.
  181. Integrates reassemble, projection convs, feature fusion, and UpConv depth head.
  182. """
  183. def __init__(self, config: CHMv2Config):
  184. super().__init__()
  185. self.config = config
  186. self.reassemble_stage = CHMv2ReassembleStage(config)
  187. self.convs = nn.ModuleList()
  188. for channel in config.post_process_channels:
  189. self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))
  190. self.fusion_layers = nn.ModuleList()
  191. for idx in range(len(config.post_process_channels)):
  192. self.fusion_layers.append(CHMv2FeatureFusionLayer(config, is_first_layer=(idx == 0)))
  193. self.conv_depth = CHMv2UpsampleConvHead(
  194. features=config.fusion_hidden_size,
  195. number_output_channels=config.number_output_channels,
  196. n_hidden_channels=config.head_hidden_size,
  197. )
  198. def forward_features(self, hidden_states: list[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor:
  199. hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
  200. features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
  201. features.reverse()
  202. fused_hidden_state = self.fusion_layers[0](features[0])
  203. for i in range(1, len(self.fusion_layers)):
  204. fused_hidden_state = self.fusion_layers[i](fused_hidden_state, features[i])
  205. return fused_hidden_state
  206. def forward(self, hidden_states: list[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor:
  207. out = self.forward_features(hidden_states, patch_height, patch_width)
  208. out = self.conv_depth(out)
  209. return out
  210. class CHMv2FeaturesToDepth(nn.Module):
  211. """Converts raw logits from the CHMv2 head into a depth map using depth bins."""
  212. def __init__(self, config: CHMv2Config):
  213. super().__init__()
  214. self.min_depth = config.min_depth
  215. self.max_depth = config.max_depth
  216. self.bins_strategy = config.bins_strategy
  217. self.norm_strategy = config.norm_strategy
  218. self._mixlog_max_clamp_value = 1e-4
  219. self._mixlog_eps_shift = 1e-8
  220. self._mixlog_eps = 1e-12
  221. def _create_mixlog_bins(self, n_bins: int, device: torch.device) -> torch.Tensor:
  222. """
  223. Creates mixed log bins interpolated between linear and log distributions.
  224. The max_depth is divided by 8.0 internally; this scaling is reversed in
  225. `_create_outputs_with_mixlog_norm` by multiplying by 8.0.
  226. """
  227. scaled_max_depth = self.max_depth / 8.0
  228. linear = torch.linspace(self.min_depth, scaled_max_depth, n_bins, device=device)
  229. log = torch.exp(
  230. torch.linspace(
  231. torch.log(torch.tensor(self.min_depth, device=device)),
  232. torch.log(torch.tensor(scaled_max_depth, device=device)),
  233. n_bins,
  234. device=device,
  235. )
  236. )
  237. interp_weight = torch.linspace(1.0, 0.0, n_bins, device=device)
  238. bins = interp_weight * log + (1.0 - interp_weight) * linear
  239. return bins
  240. def _create_outputs_with_mixlog_norm(self, input: torch.Tensor, bins: torch.Tensor) -> torch.Tensor:
  241. """Converts depth bin logits to depth values using mixlog normalization."""
  242. logits = torch.relu(input)
  243. min_per_sample = logits.amin(dim=1, keepdim=True)
  244. shift = (-min_per_sample).clamp_min(0.0).clamp_max(self._mixlog_max_clamp_value) + self._mixlog_eps_shift
  245. logits_pos = logits + shift
  246. denom = logits_pos.sum(dim=1, keepdim=True)
  247. denom = torch.nan_to_num(denom, nan=1.0, posinf=1.0, neginf=1.0).clamp_min(self._mixlog_eps)
  248. weights = logits_pos / denom
  249. bins_broadcast = bins.view(1, -1, 1, 1).clamp_min(self._mixlog_eps)
  250. output = (weights * bins_broadcast).sum(dim=1, keepdim=True).clamp_min(self._mixlog_eps)
  251. output = output * 8.0
  252. return output
  253. def forward(self, x: torch.Tensor) -> torch.Tensor:
  254. n_bins = x.shape[1]
  255. if n_bins > 1:
  256. if self.bins_strategy == "linear":
  257. bins = torch.linspace(self.min_depth, self.max_depth, n_bins, device=x.device)
  258. elif self.bins_strategy == "log":
  259. bins = torch.linspace(
  260. torch.log(torch.tensor(self.min_depth)),
  261. torch.log(torch.tensor(self.max_depth)),
  262. n_bins,
  263. device=x.device,
  264. )
  265. bins = torch.exp(bins)
  266. else:
  267. bins = self._create_mixlog_bins(n_bins, x.device)
  268. if self.norm_strategy in ["linear", "softmax", "sigmoid"]:
  269. if self.norm_strategy == "linear":
  270. logit = torch.relu(x)
  271. eps = 0.1
  272. logit = logit + eps
  273. logit = logit / logit.sum(dim=1, keepdim=True)
  274. elif self.norm_strategy == "softmax":
  275. logit = torch.softmax(x, dim=1)
  276. else:
  277. logit = torch.sigmoid(x)
  278. logit = logit / logit.sum(dim=1, keepdim=True)
  279. output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
  280. else:
  281. output = self._create_outputs_with_mixlog_norm(x, bins)
  282. else:
  283. output = torch.relu(x) + self.min_depth
  284. return output
  285. @auto_docstring
  286. class CHMv2PreTrainedModel(PreTrainedModel):
  287. config: CHMv2Config
  288. base_model_prefix = "chmv2"
  289. main_input_name = "pixel_values"
  290. input_modalities = ("image",)
  291. supports_gradient_checkpointing = True
  292. _supports_sdpa = True
  293. _supports_flash_attn = True
  294. _supports_flex_attn = True
  295. _supports_attention_backend = True
  296. def _init_weights(self, module) -> None:
  297. super()._init_weights(module)
  298. if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
  299. init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  300. if module.bias is not None:
  301. init.zeros_(module.bias)
  302. @auto_docstring(
  303. custom_intro="""
  304. CHMv2 Model with a depth estimation head on top (consisting of convolutional layers) e.g. for canopy height
  305. estimation.
  306. """
  307. )
  308. class CHMv2ForDepthEstimation(CHMv2PreTrainedModel):
  309. def __init__(self, config: CHMv2Config):
  310. super().__init__(config)
  311. self.backbone = load_backbone(config)
  312. self.head = CHMv2Head(config)
  313. self.features_to_depth = CHMv2FeaturesToDepth(config)
  314. self.post_init()
  315. def get_input_embeddings(self):
  316. return self.backbone.get_input_embeddings()
  317. @can_return_tuple
  318. @auto_docstring
  319. def forward(
  320. self,
  321. pixel_values: torch.FloatTensor,
  322. labels: torch.LongTensor | None = None,
  323. **kwargs: Unpack[TransformersKwargs],
  324. ) -> DepthEstimatorOutput:
  325. r"""
  326. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  327. Ground truth depth estimation maps for computing the loss.
  328. """
  329. loss = None
  330. if labels is not None:
  331. raise NotImplementedError("Training is not implemented yet")
  332. _, _, height, width = pixel_values.shape
  333. patch_size = self.config.patch_size
  334. patch_height = height // patch_size
  335. patch_width = width // patch_size
  336. backbone_output = self.backbone(pixel_values, **kwargs)
  337. intermediate_features = list(zip(backbone_output.feature_maps, backbone_output.cls_tokens))
  338. head_output = self.head(intermediate_features, patch_height, patch_width)
  339. predicted_depth = self.features_to_depth(head_output)
  340. predicted_depth = predicted_depth.squeeze(dim=1)
  341. return DepthEstimatorOutput(
  342. loss=loss,
  343. predicted_depth=predicted_depth,
  344. hidden_states=backbone_output.hidden_states,
  345. attentions=backbone_output.attentions,
  346. )
  347. __all__ = ["CHMv2ForDepthEstimation", "CHMv2PreTrainedModel"]