| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278 |
- # Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch VitPose model."""
- from dataclasses import dataclass
- import torch
- from torch import nn
- from ... import initialization as init
- from ...backbone_utils import load_backbone
- from ...modeling_outputs import BackboneOutput
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
- from ...utils.generic import can_return_tuple
- from .configuration_vitpose import VitPoseConfig
# Module-level logger for this file, obtained through the library's logging utilities.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of pose estimation models.
    """
)
class VitPoseEstimatorOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Currently always `None`: loss computation is not supported at this moment and passing `labels` to
        `VitPoseForPoseEstimation.forward` raises `NotImplementedError`. See
        https://github.com/ViTAE-Transformer/ViTPose/tree/main/mmpose/models/losses for further detail.
    heatmaps (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`):
        Heatmaps as predicted by the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
        (also called feature maps) of the model at the output of each stage. Forwarded unchanged from the
        backbone output.
    attentions (`tuple(torch.FloatTensor)`, *optional*):
        Attention weights forwarded unchanged from the backbone output, when the backbone returns them.
    """

    # Field order matters: ModelOutput exposes these positionally when converted to a tuple.
    loss: torch.FloatTensor | None = None
    heatmaps: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@auto_docstring
class VitPosePreTrainedModel(PreTrainedModel):
    config: VitPoseConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    supports_gradient_checkpointing = True

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        """Initialize the weights of `module` in place.

        - `nn.Linear`, `nn.Conv2d` and `nn.ConvTranspose2d`: weights drawn from a truncated normal with
          standard deviation `config.initializer_range`; biases (if any) set to zero.
        - `nn.LayerNorm` and `nn.BatchNorm2d`: affine weight set to one and bias to zero (when present);
          BatchNorm running statistics are reset (mean 0, variance 1, batch counter 0).

        The transposed convolutions and batch norms are the layers used by `VitPoseClassicDecoder`; without
        these branches they would never be initialized by this method.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
            # Norm layers can be built without affine parameters (weight/bias are then None).
            if module.weight is not None:
                init.ones_(module.weight)
            if module.bias is not None:
                init.zeros_(module.bias)
            # Only BatchNorm with track_running_stats=True carries running statistics.
            if getattr(module, "running_mean", None) is not None:
                init.zeros_(module.running_mean)
                init.ones_(module.running_var)
                init.zeros_(module.num_batches_tracked)
def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"):
    """Flip the flipped heatmaps back to the original form.

    The input tensor is never modified; a new tensor is returned.

    Args:
        output_flipped (`torch.tensor` of shape `(batch_size, num_keypoints, height, width)`):
            The output heatmaps obtained from the flipped images.
        flip_pairs (`torch.Tensor` of shape `(num_keypoints, 2)`, or a sequence of `(left, right)` index pairs):
            Pairs of keypoints which are mirrored (for example, left ear -- right ear).
        target_type (`str`, *optional*, defaults to `"gaussian-heatmap"`):
            Target type to use. Can be gaussian-heatmap or combined-target.
            gaussian-heatmap: Classification target with gaussian distribution.
            combined-target: The combination of classification target (response map) and regression target (offset map).
            Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).

    Returns:
        torch.Tensor: heatmaps flipped back to the original image, with the same shape as `output_flipped`.

    Raises:
        ValueError: if `target_type` is not supported, `output_flipped` is not 4-dimensional, or, for
            `"combined-target"`, the number of channels is not a multiple of 3.
    """
    if target_type not in ["gaussian-heatmap", "combined-target"]:
        raise ValueError("target_type should be gaussian-heatmap or combined-target")
    if output_flipped.ndim != 4:
        raise ValueError("output_flipped should be [batch_size, num_keypoints, height, width]")

    batch_size, num_keypoints, height, width = output_flipped.shape
    channels = 1
    if target_type == "combined-target":
        channels = 3
        if num_keypoints % channels != 0:
            raise ValueError(
                f"combined-target expects a multiple of {channels} channels, got {num_keypoints}"
            )
        # Every third channel starting at 1 is negated (the sign flip for the mirrored image).
        # Work on a copy so the caller's tensor is not modified in place.
        output_flipped = output_flipped.clone()
        output_flipped[:, 1::3, ...] = -output_flipped[:, 1::3, ...]

    # Group the channels belonging to each keypoint: (batch_size, num_groups, channels, height, width).
    grouped = output_flipped.reshape(batch_size, -1, channels, height, width)

    # Swap left/right parts with a single gather. Each target index reads from the ORIGINAL tensor and
    # later pairs override earlier ones, exactly like assigning the pairs one after another.
    pairs = flip_pairs.tolist() if isinstance(flip_pairs, torch.Tensor) else flip_pairs
    order = list(range(grouped.shape[1]))
    for left, right in pairs:
        order[left] = right
        order[right] = left
    output_flipped_back = grouped[:, order].reshape(batch_size, num_keypoints, height, width)

    # Flip horizontally (mirror along the width axis).
    return output_flipped_back.flip(-1)
class VitPoseSimpleDecoder(nn.Module):
    """
    Lightweight heatmap head.

    Pipeline: ReLU -> bilinear upsampling by `config.scale_factor` (`align_corners=False`) -> 3x3 convolution
    projecting `config.backbone_config.hidden_size` channels to `config.num_labels` heatmaps. When `flip_pairs`
    is given, the heatmaps are passed through `flip_back`.
    """

    def __init__(self, config: VitPoseConfig):
        super().__init__()
        in_channels = config.backbone_config.hidden_size
        # Attribute names are part of the checkpoint format; keep them stable.
        self.activation = nn.ReLU()
        self.upsampling = nn.Upsample(scale_factor=config.scale_factor, mode="bilinear", align_corners=False)
        self.conv = nn.Conv2d(in_channels, config.num_labels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_state: torch.Tensor, flip_pairs: torch.Tensor | None = None) -> torch.Tensor:
        heatmaps = self.conv(self.upsampling(self.activation(hidden_state)))
        # Flip-test mode: map heatmaps predicted on a mirrored image back to the original orientation.
        return heatmaps if flip_pairs is None else flip_back(heatmaps, flip_pairs)
class VitPoseClassicDecoder(nn.Module):
    """
    Heatmap head made of two upsampling stages followed by a 1x1 projection.

    Each stage is a 4x4 `ConvTranspose2d` (stride 2, padding 1, no bias, 256 output channels), then
    `BatchNorm2d` and ReLU; with these settings each stage doubles the spatial resolution. A final 1x1
    convolution maps the 256 channels to `config.num_labels` heatmaps. When `flip_pairs` is given, the
    heatmaps are passed through `flip_back`.
    """

    def __init__(self, config: VitPoseConfig):
        super().__init__()
        width = 256
        # Submodules are created in the same order as the checkpoint layout expects; names must not change.
        self.deconv1 = nn.ConvTranspose2d(
            config.backbone_config.hidden_size, width, kernel_size=4, stride=2, padding=1, bias=False
        )
        self.batchnorm1 = nn.BatchNorm2d(width)
        self.relu1 = nn.ReLU()
        self.deconv2 = nn.ConvTranspose2d(width, width, kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(width)
        self.relu2 = nn.ReLU()
        self.conv = nn.Conv2d(width, config.num_labels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_state: torch.Tensor, flip_pairs: torch.Tensor | None = None):
        stages = (
            (self.deconv1, self.batchnorm1, self.relu1),
            (self.deconv2, self.batchnorm2, self.relu2),
        )
        for deconv, norm, act in stages:
            hidden_state = act(norm(deconv(hidden_state)))
        heatmaps = self.conv(hidden_state)
        if flip_pairs is None:
            return heatmaps
        # Flip-test mode: map heatmaps predicted on a mirrored image back to the original orientation.
        return flip_back(heatmaps, flip_pairs)
@auto_docstring(
    custom_intro="""
    The VitPose model with a pose estimation head on top.
    """
)
class VitPoseForPoseEstimation(VitPosePreTrainedModel):
    def __init__(self, config: VitPoseConfig):
        super().__init__(config)

        self.backbone = load_backbone(config)

        # The head and the token-to-feature-map reshape in `forward` rely on these backbone config attributes.
        if not hasattr(self.backbone.config, "hidden_size"):
            raise ValueError("The backbone should have a hidden_size attribute")
        if not hasattr(self.backbone.config, "image_size"):
            raise ValueError("The backbone should have an image_size attribute")
        if not hasattr(self.backbone.config, "patch_size"):
            raise ValueError("The backbone should have a patch_size attribute")

        self.head = VitPoseSimpleDecoder(config) if config.use_simple_decoder else VitPoseClassicDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    @staticmethod
    def _to_pair(value) -> tuple[int, int]:
        """Normalize an `int` or an indexable `(height, width)` value into a `(height, width)` tuple."""
        if isinstance(value, int):
            return value, value
        return value[0], value[1]

    def _patch_grid_size(self) -> tuple[int, int]:
        """Return the `(height, width)` of the patch grid, read from the validated backbone config."""
        backbone_config = self.backbone.config
        image_height, image_width = self._to_pair(backbone_config.image_size)
        patch_size_height, patch_size_width = self._to_pair(backbone_config.patch_size)
        return image_height // patch_size_height, image_width // patch_size_width

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        dataset_index: torch.Tensor | None = None,
        flip_pairs: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> VitPoseEstimatorOutput:
        r"""
        dataset_index (`torch.Tensor` of shape `(batch_size,)`, *optional*):
            Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.

            This corresponds to the dataset index used during training, e.g. For the single dataset index 0 refers to the corresponding dataset. For the multiple datasets index 0 refers to dataset A (e.g. MPII) and index 1 refers to dataset B (e.g. CrowdPose).
        flip_pairs (`torch.Tensor` of shape `(num_keypoints, 2)`, *optional*):
            Pairs of mirrored keypoint indices (for example, left ear -- right ear). When provided, the heatmaps are
            treated as predictions on a horizontally flipped image and flipped back: each pair of keypoint channels
            is swapped and the heatmaps are mirrored along the width axis.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, VitPoseForPoseEstimation
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> processor = AutoImageProcessor.from_pretrained("usyd-community/vitpose-base-simple")
        >>> model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
        >>> inputs = processor(image, boxes=boxes, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> heatmaps = outputs.heatmaps
        ```"""
        loss = None
        if labels is not None:
            raise NotImplementedError("Training is not yet supported")

        outputs: BackboneOutput = self.backbone.forward_with_filtered_kwargs(
            pixel_values,
            dataset_index=dataset_index,
            **kwargs,
        )

        # The last feature map is a token sequence (batch_size, num_tokens, num_channels); turn it into a
        # tensor of shape (batch_size, num_channels, patch_height, patch_width) for the head.
        patch_height, patch_width = self._patch_grid_size()
        sequence_output = outputs.feature_maps[-1].permute(0, 2, 1)
        batch_size, num_channels = sequence_output.shape[:2]
        # Spell out the channel count: if the token count does not match the patch grid this raises,
        # instead of silently folding the mismatch into the channel dimension.
        sequence_output = sequence_output.reshape(batch_size, num_channels, patch_height, patch_width).contiguous()

        heatmaps = self.head(sequence_output, flip_pairs=flip_pairs)

        return VitPoseEstimatorOutput(
            loss=loss,
            heatmaps=heatmaps,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public API of this module (the names exported by a star-import).
__all__ = ["VitPosePreTrainedModel", "VitPoseForPoseEstimation"]
|