| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/chmv2/modular_chmv2.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_chmv2.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2026 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import torch
- from torch import nn
- from ... import initialization as init
- from ...backbone_utils import load_backbone
- from ...modeling_outputs import DepthEstimatorOutput
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
- from .configuration_chmv2 import CHMv2Config
- def _get_backbone_hidden_size(config):
- if config.backbone_config is not None and hasattr(config.backbone_config, "hidden_size"):
- return config.backbone_config.hidden_size
- else:
- return config.hidden_size
class CHMv2ReassembleLayer(nn.Module):
    """
    Projects backbone features to `channels` channels, then rescales them spatially by `factor`.

    - `factor > 1`: transposed convolution with kernel size and stride equal to `factor` (upsampling).
    - `factor == 1`: identity.
    - `factor < 1`: 3x3 convolution with stride `int(1 / factor)` (downsampling).
    """

    def __init__(self, config: CHMv2Config, channels: int, factor: int):
        super().__init__()
        in_channels = _get_backbone_hidden_size(config)
        # 1x1 convolution from the backbone hidden size to the requested channel count.
        self.projection = nn.Conv2d(in_channels=in_channels, out_channels=channels, kernel_size=1)

        if factor == 1:
            self.resize = nn.Identity()
        elif factor > 1:
            self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
        elif factor < 1:
            downsample_stride = int(1 / factor)
            self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=downsample_stride, padding=1)

    def forward(self, hidden_state):
        return self.resize(self.projection(hidden_state))
class CHMv2ReassembleStage(nn.Module):
    """
    Reassemble stage that turns backbone outputs into image-like feature maps, one per entry of
    `config.post_process_channels` / `config.reassemble_factors`, each rescaled by its own factor.

    Each input is either a `(feature_map, cls_token)` pair, in which case the CLS token is merged according to
    `config.readout_type` ("project", "add"; any other value leaves the map unchanged), or a token sequence that
    is reshaped onto a `patch_height x patch_width` grid.
    """

    def __init__(self, config: CHMv2Config):
        super().__init__()
        self.config = config
        self.readout_type = config.readout_type

        self.layers = nn.ModuleList(
            [
                CHMv2ReassembleLayer(config=config, channels=out_channels, factor=factor)
                for out_channels, factor in zip(config.post_process_channels, config.reassemble_factors)
            ]
        )

        if self.readout_type == "project":
            hidden_size = _get_backbone_hidden_size(config)
            # One projection per level: concat(token, cls) of width 2 * hidden_size -> hidden_size.
            self.readout_projects = nn.ModuleList(
                [nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), nn.GELU()) for _ in self.layers]
            )

    def _apply_readout(self, layer_idx, feature_map, cls_token):
        """Merge `cls_token` into `feature_map` per `readout_type`; the map's original shape is restored."""
        feature_shape = feature_map.shape
        if self.readout_type == "project":
            tokens = feature_map.flatten(2).transpose(1, 2)
            readout = cls_token.unsqueeze(1).expand_as(tokens)
            tokens = self.readout_projects[layer_idx](torch.cat((tokens, readout), -1))
            return tokens.permute(0, 2, 1).reshape(feature_shape)
        if self.readout_type == "add":
            merged = feature_map.flatten(2) + cls_token.unsqueeze(-1)
            return merged.reshape(feature_shape)
        return feature_map

    @staticmethod
    def _tokens_to_feature_map(hidden_state, patch_height, patch_width):
        """Reshape a token sequence into a channels-first grid, dropping the leading token of 3D inputs."""
        if hidden_state.dim() == 3:
            hidden_state = hidden_state[:, 1:]
        batch_size, _, num_channels = hidden_state.shape
        grid = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
        return grid.permute(0, 3, 1, 2).contiguous()

    def forward(self, hidden_states: list[torch.Tensor], patch_height=None, patch_width=None) -> list[torch.Tensor]:
        reassembled = []
        for layer_idx, hidden_state in enumerate(hidden_states):
            is_pair = isinstance(hidden_state, (tuple, list)) and len(hidden_state) == 2
            if is_pair:
                feature_map = self._apply_readout(layer_idx, hidden_state[0], hidden_state[1])
            else:
                feature_map = self._tokens_to_feature_map(hidden_state, patch_height, patch_width)
            reassembled.append(self.layers[layer_idx](feature_map))
        return reassembled
class CHMv2PreActResidualLayer(nn.Module):
    """
    Pre-activation residual convolution unit: ReLU -> Conv3x3 -> ReLU -> Conv3x3, added to the input.

    Args:
        config (`[CHMv2Config]`):
            Model configuration; only `fusion_hidden_size` is read, used as both the input and output channel
            count of the two convolutions.
    """

    def __init__(self, config):
        super().__init__()
        channels = config.fusion_hidden_size
        conv_kwargs = {"kernel_size": 3, "stride": 1, "padding": 1, "bias": True}
        self.activation1 = nn.ReLU()
        self.convolution1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.activation2 = nn.ReLU()
        self.convolution2 = nn.Conv2d(channels, channels, **conv_kwargs)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        # The ReLUs are not in-place, so `hidden_state` is still the untouched skip input below.
        branch = self.convolution1(self.activation1(hidden_state))
        branch = self.convolution2(self.activation2(branch))
        return branch + hidden_state
class CHMv2FeatureFusionLayer(nn.Module):
    """
    One step of the fusion pyramid.

    Optionally adds a residual-refined skip feature to the running map, refines the result, upsamples it
    (2x by default, or to an explicit `size`), and applies a 1x1 projection.
    """

    def __init__(self, config: CHMv2Config, is_first_layer: bool = False):
        super().__init__()
        self.is_first_layer = is_first_layer
        self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
        # The first layer is built without `residual_layer1`; any `residual` passed to it is ignored in forward.
        if not is_first_layer:
            self.residual_layer1 = CHMv2PreActResidualLayer(config)
        self.residual_layer2 = CHMv2PreActResidualLayer(config)

    def forward(self, hidden_state, residual=None, size=None):
        if residual is not None and not self.is_first_layer:
            if hidden_state.shape != residual.shape:
                # Resample the skip feature to the running map's spatial size before adding it.
                _, _, target_height, target_width = hidden_state.shape
                residual = nn.functional.interpolate(
                    residual, size=(target_height, target_width), mode="bilinear", align_corners=False
                )
            hidden_state = hidden_state + self.residual_layer1(residual)

        hidden_state = self.residual_layer2(hidden_state)

        if size is None:
            hidden_state = nn.functional.interpolate(hidden_state, scale_factor=2, mode="bilinear", align_corners=True)
        else:
            hidden_state = nn.functional.interpolate(hidden_state, size=size, mode="bilinear", align_corners=True)

        return self.projection(hidden_state)
class CHMv2UpsampleConvHead(nn.Module):
    """
    Convolutional output head with one intermediate 2x upsampling.

    Pipeline: Conv3x3 (features -> features // 2) -> 2x bilinear upsample -> Conv3x3 (-> n_hidden_channels)
    -> ReLU -> Conv1x1 (-> number_output_channels).
    """

    def __init__(self, features, number_output_channels, n_hidden_channels=128):
        super().__init__()
        reduced_channels = features // 2
        # Kept as a flat ModuleList so parameter names stay `head.0`, `head.2`, `head.4`.
        self.head = nn.ModuleList(
            [
                nn.Conv2d(features, reduced_channels, kernel_size=3, stride=1, padding=1),
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(reduced_channels, n_hidden_channels, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv2d(n_hidden_channels, number_output_channels, kernel_size=1, stride=1, padding=0),
            ]
        )

    def forward(self, hidden_states):
        reduce_conv, upsample, hidden_conv, activation, output_conv = self.head
        return output_conv(activation(hidden_conv(upsample(reduce_conv(hidden_states)))))
class CHMv2Head(nn.Module):
    """
    CHMv2 dense-prediction head adapted from DPT.

    Reassembles backbone features into multi-scale maps, projects each to `config.fusion_hidden_size` channels,
    fuses them starting from the last level, and decodes the result with an upsampling conv head.
    """

    def __init__(self, config: CHMv2Config):
        super().__init__()
        self.config = config
        self.reassemble_stage = CHMv2ReassembleStage(config)
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(channels, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False)
                for channels in config.post_process_channels
            ]
        )
        self.fusion_layers = nn.ModuleList(
            [
                CHMv2FeatureFusionLayer(config, is_first_layer=(level_idx == 0))
                for level_idx in range(len(config.post_process_channels))
            ]
        )
        self.conv_depth = CHMv2UpsampleConvHead(
            features=config.fusion_hidden_size,
            number_output_channels=config.number_output_channels,
            n_hidden_channels=config.head_hidden_size,
        )

    def forward_features(self, hidden_states: list[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor:
        reassembled = self.reassemble_stage(hidden_states, patch_height, patch_width)
        projected = [self.convs[level_idx](feature) for level_idx, feature in enumerate(reassembled)]
        # Fusion runs from the last reassembled level back to the first one.
        levels = projected[::-1]
        fused = self.fusion_layers[0](levels[0])
        for level_idx in range(1, len(self.fusion_layers)):
            fused = self.fusion_layers[level_idx](fused, levels[level_idx])
        return fused

    def forward(self, hidden_states: list[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor:
        return self.conv_depth(self.forward_features(hidden_states, patch_height, patch_width))
class CHMv2FeaturesToDepth(nn.Module):
    """
    Converts raw logits from the CHMv2 head into a depth map using depth bins.

    With more than one channel, each channel is treated as the score of one depth bin.
    - Bin centers: built between `config.min_depth` and `config.max_depth` according to `config.bins_strategy`.
      Accepted values are "linear", "log", or anything else for "mixlog".
    - Normalization: per-pixel scores are normalized according to `config.norm_strategy`. Accepted values are
      "linear", "softmax", "sigmoid", or anything else for the mixlog normalization.
    - The depth is the score-weighted sum of bin centers.

    With a single channel, the output is `relu(x) + min_depth`.
    """

    def __init__(self, config: CHMv2Config):
        super().__init__()
        self.min_depth = config.min_depth
        self.max_depth = config.max_depth
        self.bins_strategy = config.bins_strategy
        self.norm_strategy = config.norm_strategy
        # Numerical-stability constants used only by the mixlog normalization.
        self._mixlog_max_clamp_value = 1e-4
        self._mixlog_eps_shift = 1e-8
        self._mixlog_eps = 1e-12

    def _create_mixlog_bins(self, n_bins: int, device: torch.device) -> torch.Tensor:
        """
        Creates mixed log bins interpolated between linear and log distributions.

        The first bin is purely logarithmic, the last purely linear, with a linear blend in between.
        The max_depth is divided by 8.0 internally; this scaling is reversed in
        `_create_outputs_with_mixlog_norm` by multiplying by 8.0.
        """
        scaled_max_depth = self.max_depth / 8.0
        linear = torch.linspace(self.min_depth, scaled_max_depth, n_bins, device=device)
        log = torch.exp(
            torch.linspace(
                torch.log(torch.tensor(self.min_depth, device=device)),
                torch.log(torch.tensor(scaled_max_depth, device=device)),
                n_bins,
                device=device,
            )
        )
        interp_weight = torch.linspace(1.0, 0.0, n_bins, device=device)
        bins = interp_weight * log + (1.0 - interp_weight) * linear
        return bins

    def _create_outputs_with_mixlog_norm(self, input: torch.Tensor, bins: torch.Tensor) -> torch.Tensor:
        """Converts depth bin logits to depth values using mixlog normalization (see `_create_mixlog_bins`)."""
        logits = torch.relu(input)
        min_per_sample = logits.amin(dim=1, keepdim=True)
        # Small positive shift so every bin weight is strictly positive before normalizing.
        shift = (-min_per_sample).clamp_min(0.0).clamp_max(self._mixlog_max_clamp_value) + self._mixlog_eps_shift
        logits_pos = logits + shift
        denom = logits_pos.sum(dim=1, keepdim=True)
        denom = torch.nan_to_num(denom, nan=1.0, posinf=1.0, neginf=1.0).clamp_min(self._mixlog_eps)
        weights = logits_pos / denom
        bins_broadcast = bins.view(1, -1, 1, 1).clamp_min(self._mixlog_eps)
        output = (weights * bins_broadcast).sum(dim=1, keepdim=True).clamp_min(self._mixlog_eps)
        # Undo the max_depth / 8.0 scaling applied in `_create_mixlog_bins`.
        output = output * 8.0
        return output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (`torch.Tensor`): head logits; dim 1 indexes depth bins. The einsum below requires rank 4
                (presumably `(batch, n_bins, height, width)` — confirm against the head output).

        Returns:
            `torch.Tensor` with a channel dimension of size 1 at dim 1 (when `n_bins >= 1`).
        """
        n_bins = x.shape[1]
        if n_bins > 1:
            if self.bins_strategy == "linear":
                bins = torch.linspace(self.min_depth, self.max_depth, n_bins, device=x.device)
            elif self.bins_strategy == "log":
                bins = torch.linspace(
                    torch.log(torch.tensor(self.min_depth, device=x.device)),
                    torch.log(torch.tensor(self.max_depth, device=x.device)),
                    n_bins,
                    device=x.device,
                )
                bins = torch.exp(bins)
            else:
                bins = self._create_mixlog_bins(n_bins, x.device)

            if self.norm_strategy in ["linear", "softmax", "sigmoid"]:
                if self.norm_strategy == "linear":
                    logit = torch.relu(x)
                    eps = 0.1
                    logit = logit + eps
                    logit = logit / logit.sum(dim=1, keepdim=True)
                elif self.norm_strategy == "softmax":
                    logit = torch.softmax(x, dim=1)
                else:
                    logit = torch.sigmoid(x)
                    logit = logit / logit.sum(dim=1, keepdim=True)
                # `torch.linspace` produces default-dtype (float32) bins, and einsum does not type-promote.
                # Casting to the logits' dtype keeps this path working for fp16/bf16/fp64 inputs.
                output = torch.einsum("ikmn,k->imn", [logit, bins.to(dtype=logit.dtype)]).unsqueeze(dim=1)
            else:
                output = self._create_outputs_with_mixlog_norm(x, bins)
        else:
            output = torch.relu(x) + self.min_depth
        return output
@auto_docstring
class CHMv2PreTrainedModel(PreTrainedModel):
    config: CHMv2Config
    base_model_prefix = "chmv2"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    supports_gradient_checkpointing = True
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module) -> None:
        # Run the base-class initialization first, then override linear/conv layers below.
        super()._init_weights(module)
        if not isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            return
        # Truncated-normal weights with std `config.initializer_range`; biases (if any) start at zero.
        init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            init.zeros_(module.bias)
@auto_docstring(
    custom_intro="""
    CHMv2 Model with a depth estimation head on top (consisting of convolutional layers) e.g. for canopy height
    estimation.
    """
)
class CHMv2ForDepthEstimation(CHMv2PreTrainedModel):
    def __init__(self, config: CHMv2Config):
        super().__init__(config)
        self.backbone = load_backbone(config)
        self.head = CHMv2Head(config)
        self.features_to_depth = CHMv2FeaturesToDepth(config)
        self.post_init()

    def get_input_embeddings(self):
        # Input embeddings live in the backbone; this model adds none of its own.
        return self.backbone.get_input_embeddings()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> DepthEstimatorOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth depth estimation maps for computing the loss.
        """
        # Fail fast before running the network: no loss is implemented.
        if labels is not None:
            raise NotImplementedError("Training is not implemented yet")

        _, _, height, width = pixel_values.shape
        patch_size = self.config.patch_size
        patch_height, patch_width = height // patch_size, width // patch_size

        backbone_output = self.backbone(pixel_values, **kwargs)
        # Pair each backbone feature map with its matching CLS token for the head's reassemble stage.
        feature_cls_pairs = list(zip(backbone_output.feature_maps, backbone_output.cls_tokens))
        depth_logits = self.head(feature_cls_pairs, patch_height, patch_width)
        predicted_depth = self.features_to_depth(depth_logits).squeeze(dim=1)

        return DepthEstimatorOutput(
            loss=None,
            predicted_depth=predicted_depth,
            hidden_states=backbone_output.hidden_states,
            attentions=backbone_output.attentions,
        )
# Public classes exported by this module.
__all__ = ["CHMv2ForDepthEstimation", "CHMv2PreTrainedModel"]
|