modular_chmv2.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. # Copyright 2026 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """CHMv2 model — Canopy Height Model v2, adapted from DPT."""
  15. from typing import Literal
  16. import torch
  17. from huggingface_hub.dataclasses import strict
  18. from torch import nn
  19. from ... import initialization as init
  20. from ...backbone_utils import consolidate_backbone_kwargs_to_config, load_backbone
  21. from ...configuration_utils import PreTrainedConfig
  22. from ...modeling_outputs import DepthEstimatorOutput
  23. from ...modeling_utils import PreTrainedModel
  24. from ...processing_utils import ImagesKwargs, Unpack
  25. from ...utils import TensorType, TransformersKwargs, auto_docstring, can_return_tuple, requires_backends
  26. from ..auto import AutoConfig
  27. from ..depth_anything.modeling_depth_anything import (
  28. DepthAnythingPreActResidualLayer,
  29. )
  30. from ..dpt.image_processing_dpt import DPTImageProcessor
  31. from ..dpt.modeling_dpt import DPTReassembleLayer, _get_backbone_hidden_size
  32. @auto_docstring(checkpoint="facebook/dinov3-vitl16-chmv2-dpt-head")
  33. @strict
  34. class CHMv2Config(PreTrainedConfig):
  35. r"""
  36. backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*):
  37. The configuration of the backbone model. Only DINOv3ViTConfig is currently supported.
  38. patch_size (`int`, *optional*, defaults to 16):
  39. The patch size used by the backbone vision transformer.
  40. reassemble_factors (`list[float]`, *optional*, defaults to `[4, 2, 1, 0.5]`):
  41. The up/downsampling factors of the reassemble layers.
  42. post_process_channels (`list[int]`, *optional*, defaults to `[128, 256, 512, 1024]`):
  43. The output channel sizes of the reassemble stage for each backbone feature level.
  44. fusion_hidden_size (`int`, *optional*, defaults to 256):
  45. The number of channels before fusion.
  46. head_hidden_size (`int`, *optional*, defaults to 128):
  47. The number of channels in the hidden layer of the depth estimation head.
  48. number_output_channels (`int`, *optional*, defaults to 256):
  49. Number of output channels for the CHMv2 head (number of depth bins).
  50. readout_type (`str`, *optional*, defaults to `"project"`):
  51. Type of readout operation for the CLS token. One of `["ignore", "add", "project"]`.
  52. min_depth (`float`, *optional*, defaults to 0.001):
  53. The minimum depth value for depth bin calculation.
  54. max_depth (`float`, *optional*, defaults to 96.0):
  55. The maximum depth value for depth bin calculation.
  56. bins_strategy (`str`, *optional*, defaults to `"chmv2_mixlog"`):
  57. The strategy for depth bins distribution. One of `["linear", "log", "chmv2_mixlog"]`.
  58. norm_strategy (`str`, *optional*, defaults to `"chmv2_mixlog"`):
  59. The normalization strategy for depth prediction. One of `["linear", "softmax", "sigmoid", "chmv2_mixlog"]`.
  60. ```python
  61. >>> from transformers import CHMv2Config, CHMv2ForDepthEstimation
  62. >>> configuration = CHMv2Config()
  63. >>> model = CHMv2ForDepthEstimation(configuration)
  64. >>> configuration = model.config
  65. ```
  66. """
  67. model_type = "chmv2"
  68. sub_configs = {"backbone_config": AutoConfig}
  69. backbone_config: dict | PreTrainedConfig | None = None
  70. patch_size: int = 16
  71. initializer_range: float = 0.02
  72. reassemble_factors: list[float | int] | None = None
  73. post_process_channels: list[int] | None = None
  74. fusion_hidden_size: int = 256
  75. head_hidden_size: int = 128
  76. number_output_channels: int = 256
  77. readout_type: str = "project"
  78. min_depth: float = 0.001
  79. max_depth: float = 96.0
  80. bins_strategy: Literal["linear", "log", "chmv2_mixlog"] = "chmv2_mixlog"
  81. norm_strategy: Literal["linear", "softmax", "sigmoid", "chmv2_mixlog"] = "chmv2_mixlog"
  82. def __post_init__(self, **kwargs):
  83. if self.reassemble_factors is None:
  84. self.reassemble_factors = [4, 2, 1, 0.5]
  85. if self.post_process_channels is None:
  86. self.post_process_channels = [128, 256, 512, 1024]
  87. default_config_kwargs = {
  88. "image_size": 416,
  89. "hidden_size": 1024,
  90. "intermediate_size": 4096,
  91. "num_attention_heads": 16,
  92. "num_hidden_layers": 24,
  93. "num_register_tokens": 4,
  94. "key_bias": True,
  95. "out_indices": [6, 12, 18, 24],
  96. "reshape_hidden_states": True,
  97. "apply_layernorm": True,
  98. "layer_norm_eps": 1e-6,
  99. "return_class_token": True,
  100. }
  101. self.backbone_config, kwargs = consolidate_backbone_kwargs_to_config(
  102. backbone_config=self.backbone_config,
  103. default_config_type="dinov3_vit",
  104. default_config_kwargs=default_config_kwargs,
  105. **kwargs,
  106. )
  107. super().__post_init__(**kwargs)
  108. class CHMv2ImageProcessorKwargs(ImagesKwargs, total=False):
  109. r"""
  110. ensure_multiple_of (`int`, *optional*, defaults to 1):
  111. If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overridden
  112. by `ensure_multiple_of` in `preprocess`.
  113. keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
  114. If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
  115. be overridden by `keep_aspect_ratio` in `preprocess`.
  116. do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
  117. Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
  118. is used for background, and background itself is not included in all classes of a dataset (e.g.
  119. ADE20k). The background label will be replaced by 255.
  120. """
  121. ensure_multiple_of: int
  122. size_divisor: int
  123. keep_aspect_ratio: bool
  124. do_reduce_labels: bool
  125. class CHMv2ImageProcessor(DPTImageProcessor):
  126. do_resize = False
  127. do_pad = True
  128. size_divisor = 16
  129. ensure_multiple_of = 16
  130. keep_aspect_ratio = True
  131. image_mean = [0.420, 0.411, 0.296]
  132. image_std = [0.213, 0.156, 0.143]
  133. valid_kwargs = CHMv2ImageProcessorKwargs
  134. def post_process_depth_estimation(
  135. self,
  136. outputs: "DepthEstimatorOutput",
  137. target_sizes: TensorType | list[tuple[int, int]] | None | None = None,
  138. ) -> list[dict[str, TensorType]]:
  139. """
  140. Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images.
  141. Only supports PyTorch.
  142. Args:
  143. outputs ([`DepthEstimatorOutput`]):
  144. Raw outputs of the model.
  145. target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
  146. Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
  147. (height, width) of each image in the batch. If left to None, predictions will not be resized.
  148. Returns:
  149. `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
  150. predictions.
  151. """
  152. requires_backends(self, "torch")
  153. predicted_depth = outputs.predicted_depth
  154. if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
  155. raise ValueError(
  156. "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
  157. )
  158. results = []
  159. target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
  160. for depth, target_size in zip(predicted_depth, target_sizes):
  161. if target_size is not None:
  162. depth = torch.nn.functional.interpolate(
  163. depth[None, None, ...], size=target_size, mode="bilinear", align_corners=True
  164. ).squeeze()
  165. results.append({"predicted_depth": depth})
  166. return results
  167. class CHMv2ReassembleLayer(DPTReassembleLayer):
  168. pass
  169. class CHMv2ReassembleStage(nn.Module):
  170. """
  171. Reassemble stage that processes hidden states from the backbone into image-like feature
  172. representations at various resolutions.
  173. """
  174. def __init__(self, config: CHMv2Config):
  175. super().__init__()
  176. self.config = config
  177. self.readout_type = config.readout_type
  178. self.layers = nn.ModuleList()
  179. for out_channels, factor in zip(config.post_process_channels, config.reassemble_factors):
  180. self.layers.append(
  181. CHMv2ReassembleLayer(
  182. config=config,
  183. channels=out_channels,
  184. factor=factor,
  185. )
  186. )
  187. hidden_size = _get_backbone_hidden_size(config)
  188. if self.readout_type == "project":
  189. self.readout_projects = nn.ModuleList()
  190. for _ in range(len(self.layers)):
  191. self.readout_projects.append(nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), nn.GELU()))
  192. def forward(self, hidden_states: list[torch.Tensor], patch_height=None, patch_width=None) -> list[torch.Tensor]:
  193. out = []
  194. for layer_idx, hidden_state in enumerate(hidden_states):
  195. if isinstance(hidden_state, (tuple, list)) and len(hidden_state) == 2:
  196. hidden_state, cls_token = hidden_state[0], hidden_state[1]
  197. feature_shape = hidden_state.shape
  198. if self.readout_type == "project":
  199. hidden_state = hidden_state.flatten(2).transpose(1, 2)
  200. readout = cls_token.unsqueeze(1).expand_as(hidden_state)
  201. hidden_state = self.readout_projects[layer_idx](torch.cat((hidden_state, readout), -1))
  202. hidden_state = hidden_state.permute(0, 2, 1).reshape(feature_shape)
  203. elif self.readout_type == "add":
  204. hidden_state = hidden_state.flatten(2) + cls_token.unsqueeze(-1)
  205. hidden_state = hidden_state.reshape(feature_shape)
  206. else:
  207. if hidden_state.dim() == 3:
  208. hidden_state = hidden_state[:, 1:]
  209. batch_size, _, num_channels = hidden_state.shape
  210. hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
  211. hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
  212. hidden_state = self.layers[layer_idx](hidden_state)
  213. out.append(hidden_state)
  214. return out
  215. class CHMv2PreActResidualLayer(DepthAnythingPreActResidualLayer):
  216. pass
  217. class CHMv2FeatureFusionLayer(nn.Module):
  218. def __init__(self, config: CHMv2Config, is_first_layer: bool = False):
  219. super().__init__()
  220. self.is_first_layer = is_first_layer
  221. self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
  222. if not is_first_layer:
  223. self.residual_layer1 = CHMv2PreActResidualLayer(config)
  224. self.residual_layer2 = CHMv2PreActResidualLayer(config)
  225. def forward(self, hidden_state, residual=None, size=None):
  226. if residual is not None and not self.is_first_layer:
  227. if hidden_state.shape != residual.shape:
  228. _, _, height, width = hidden_state.shape
  229. residual = nn.functional.interpolate(
  230. residual, size=(height, width), mode="bilinear", align_corners=False
  231. )
  232. hidden_state = hidden_state + self.residual_layer1(residual)
  233. hidden_state = self.residual_layer2(hidden_state)
  234. modifier = {"scale_factor": 2} if size is None else {"size": size}
  235. hidden_state = nn.functional.interpolate(
  236. hidden_state,
  237. **modifier,
  238. mode="bilinear",
  239. align_corners=True,
  240. )
  241. hidden_state = self.projection(hidden_state)
  242. return hidden_state
  243. class CHMv2UpsampleConvHead(nn.Module):
  244. """
  245. Convolutional head with intermediate upsampling.
  246. Architecture: Conv3x3 -> 2x bilinear upsample -> Conv3x3 -> ReLU -> Conv1x1.
  247. """
  248. def __init__(self, features, number_output_channels, n_hidden_channels=128):
  249. super().__init__()
  250. self.head = nn.ModuleList(
  251. [
  252. nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
  253. nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
  254. nn.Conv2d(features // 2, n_hidden_channels, kernel_size=3, stride=1, padding=1),
  255. nn.ReLU(),
  256. nn.Conv2d(n_hidden_channels, number_output_channels, kernel_size=1, stride=1, padding=0),
  257. ]
  258. )
  259. def forward(self, hidden_states):
  260. for layer in self.head:
  261. hidden_states = layer(hidden_states)
  262. return hidden_states
  263. class CHMv2Head(nn.Module):
  264. """
  265. CHMv2 dense-prediction head adapted from DPT.
  266. Integrates reassemble, projection convs, feature fusion, and UpConv depth head.
  267. """
  268. def __init__(self, config: CHMv2Config):
  269. super().__init__()
  270. self.config = config
  271. self.reassemble_stage = CHMv2ReassembleStage(config)
  272. self.convs = nn.ModuleList()
  273. for channel in config.post_process_channels:
  274. self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))
  275. self.fusion_layers = nn.ModuleList()
  276. for idx in range(len(config.post_process_channels)):
  277. self.fusion_layers.append(CHMv2FeatureFusionLayer(config, is_first_layer=(idx == 0)))
  278. self.conv_depth = CHMv2UpsampleConvHead(
  279. features=config.fusion_hidden_size,
  280. number_output_channels=config.number_output_channels,
  281. n_hidden_channels=config.head_hidden_size,
  282. )
  283. def forward_features(self, hidden_states: list[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor:
  284. hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
  285. features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
  286. features.reverse()
  287. fused_hidden_state = self.fusion_layers[0](features[0])
  288. for i in range(1, len(self.fusion_layers)):
  289. fused_hidden_state = self.fusion_layers[i](fused_hidden_state, features[i])
  290. return fused_hidden_state
  291. def forward(self, hidden_states: list[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor:
  292. out = self.forward_features(hidden_states, patch_height, patch_width)
  293. out = self.conv_depth(out)
  294. return out
  295. class CHMv2FeaturesToDepth(nn.Module):
  296. """Converts raw logits from the CHMv2 head into a depth map using depth bins."""
  297. def __init__(self, config: CHMv2Config):
  298. super().__init__()
  299. self.min_depth = config.min_depth
  300. self.max_depth = config.max_depth
  301. self.bins_strategy = config.bins_strategy
  302. self.norm_strategy = config.norm_strategy
  303. self._mixlog_max_clamp_value = 1e-4
  304. self._mixlog_eps_shift = 1e-8
  305. self._mixlog_eps = 1e-12
  306. def _create_mixlog_bins(self, n_bins: int, device: torch.device) -> torch.Tensor:
  307. """
  308. Creates mixed log bins interpolated between linear and log distributions.
  309. The max_depth is divided by 8.0 internally; this scaling is reversed in
  310. `_create_outputs_with_mixlog_norm` by multiplying by 8.0.
  311. """
  312. scaled_max_depth = self.max_depth / 8.0
  313. linear = torch.linspace(self.min_depth, scaled_max_depth, n_bins, device=device)
  314. log = torch.exp(
  315. torch.linspace(
  316. torch.log(torch.tensor(self.min_depth, device=device)),
  317. torch.log(torch.tensor(scaled_max_depth, device=device)),
  318. n_bins,
  319. device=device,
  320. )
  321. )
  322. interp_weight = torch.linspace(1.0, 0.0, n_bins, device=device)
  323. bins = interp_weight * log + (1.0 - interp_weight) * linear
  324. return bins
  325. def _create_outputs_with_mixlog_norm(self, input: torch.Tensor, bins: torch.Tensor) -> torch.Tensor:
  326. """Converts depth bin logits to depth values using mixlog normalization."""
  327. logits = torch.relu(input)
  328. min_per_sample = logits.amin(dim=1, keepdim=True)
  329. shift = (-min_per_sample).clamp_min(0.0).clamp_max(self._mixlog_max_clamp_value) + self._mixlog_eps_shift
  330. logits_pos = logits + shift
  331. denom = logits_pos.sum(dim=1, keepdim=True)
  332. denom = torch.nan_to_num(denom, nan=1.0, posinf=1.0, neginf=1.0).clamp_min(self._mixlog_eps)
  333. weights = logits_pos / denom
  334. bins_broadcast = bins.view(1, -1, 1, 1).clamp_min(self._mixlog_eps)
  335. output = (weights * bins_broadcast).sum(dim=1, keepdim=True).clamp_min(self._mixlog_eps)
  336. output = output * 8.0
  337. return output
  338. def forward(self, x: torch.Tensor) -> torch.Tensor:
  339. n_bins = x.shape[1]
  340. if n_bins > 1:
  341. if self.bins_strategy == "linear":
  342. bins = torch.linspace(self.min_depth, self.max_depth, n_bins, device=x.device)
  343. elif self.bins_strategy == "log":
  344. bins = torch.linspace(
  345. torch.log(torch.tensor(self.min_depth)),
  346. torch.log(torch.tensor(self.max_depth)),
  347. n_bins,
  348. device=x.device,
  349. )
  350. bins = torch.exp(bins)
  351. else:
  352. bins = self._create_mixlog_bins(n_bins, x.device)
  353. if self.norm_strategy in ["linear", "softmax", "sigmoid"]:
  354. if self.norm_strategy == "linear":
  355. logit = torch.relu(x)
  356. eps = 0.1
  357. logit = logit + eps
  358. logit = logit / logit.sum(dim=1, keepdim=True)
  359. elif self.norm_strategy == "softmax":
  360. logit = torch.softmax(x, dim=1)
  361. else:
  362. logit = torch.sigmoid(x)
  363. logit = logit / logit.sum(dim=1, keepdim=True)
  364. output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
  365. else:
  366. output = self._create_outputs_with_mixlog_norm(x, bins)
  367. else:
  368. output = torch.relu(x) + self.min_depth
  369. return output
  370. @auto_docstring
  371. class CHMv2PreTrainedModel(PreTrainedModel):
  372. config: CHMv2Config
  373. base_model_prefix = "chmv2"
  374. main_input_name = "pixel_values"
  375. input_modalities = ("image",)
  376. supports_gradient_checkpointing = True
  377. _supports_sdpa = True
  378. _supports_flash_attn = True
  379. _supports_flex_attn = True
  380. _supports_attention_backend = True
  381. def _init_weights(self, module) -> None:
  382. super()._init_weights(module)
  383. if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
  384. init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  385. if module.bias is not None:
  386. init.zeros_(module.bias)
  387. @auto_docstring(
  388. custom_intro="""
  389. CHMv2 Model with a depth estimation head on top (consisting of convolutional layers) e.g. for canopy height
  390. estimation.
  391. """
  392. )
  393. class CHMv2ForDepthEstimation(CHMv2PreTrainedModel):
  394. def __init__(self, config: CHMv2Config):
  395. super().__init__(config)
  396. self.backbone = load_backbone(config)
  397. self.head = CHMv2Head(config)
  398. self.features_to_depth = CHMv2FeaturesToDepth(config)
  399. self.post_init()
  400. def get_input_embeddings(self):
  401. return self.backbone.get_input_embeddings()
  402. @can_return_tuple
  403. @auto_docstring
  404. def forward(
  405. self,
  406. pixel_values: torch.FloatTensor,
  407. labels: torch.LongTensor | None = None,
  408. **kwargs: Unpack[TransformersKwargs],
  409. ) -> DepthEstimatorOutput:
  410. r"""
  411. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  412. Ground truth depth estimation maps for computing the loss.
  413. """
  414. loss = None
  415. if labels is not None:
  416. raise NotImplementedError("Training is not implemented yet")
  417. _, _, height, width = pixel_values.shape
  418. patch_size = self.config.patch_size
  419. patch_height = height // patch_size
  420. patch_width = width // patch_size
  421. backbone_output = self.backbone(pixel_values, **kwargs)
  422. intermediate_features = list(zip(backbone_output.feature_maps, backbone_output.cls_tokens))
  423. head_output = self.head(intermediate_features, patch_height, patch_width)
  424. predicted_depth = self.features_to_depth(head_output)
  425. predicted_depth = predicted_depth.squeeze(dim=1)
  426. return DepthEstimatorOutput(
  427. loss=loss,
  428. predicted_depth=predicted_depth,
  429. hidden_states=backbone_output.hidden_states,
  430. attentions=backbone_output.attentions,
  431. )
# Public names exported by this module.
__all__ = [
    "CHMv2Config",
    "CHMv2ImageProcessor",
    "CHMv2ForDepthEstimation",
    "CHMv2PreTrainedModel",
]