modeling_vitpose.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch VitPose model."""
  15. from dataclasses import dataclass
  16. import torch
  17. from torch import nn
  18. from ... import initialization as init
  19. from ...backbone_utils import load_backbone
  20. from ...modeling_outputs import BackboneOutput
  21. from ...modeling_utils import PreTrainedModel
  22. from ...processing_utils import Unpack
  23. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
  24. from ...utils.generic import can_return_tuple
  25. from .configuration_vitpose import VitPoseConfig
  26. logger = logging.get_logger(__name__)
  27. # General docstring
# Output container for pose estimation. In this file it is built by
# `VitPoseForPoseEstimation.forward` from the decoder head's heatmaps and the
# backbone's `hidden_states` / `attentions`.
@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of pose estimation models.
    """
)
class VitPoseEstimatorOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Loss is not supported at this moment. See https://github.com/ViTAE-Transformer/ViTPose/tree/main/mmpose/models/losses for further detail.
    heatmaps (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`):
        Heatmaps as predicted by the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
        (also called feature maps) of the model at the output of each stage.
    """

    # Stays `None` when produced by the forward pass in this file: passing
    # `labels` there raises `NotImplementedError` before any loss is computed.
    loss: torch.FloatTensor | None = None
    # Output of the decoder head.
    heatmaps: torch.FloatTensor | None = None
    # Copied from the backbone output's `hidden_states`.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Copied from the backbone output's `attentions`. This field has no entry
    # in the docstring above.
    attentions: tuple[torch.FloatTensor, ...] | None = None
  49. @auto_docstring
  50. class VitPosePreTrainedModel(PreTrainedModel):
  51. config: VitPoseConfig
  52. base_model_prefix = "vit"
  53. main_input_name = "pixel_values"
  54. input_modalities = ("image",)
  55. supports_gradient_checkpointing = True
  56. @torch.no_grad()
  57. def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm):
  58. """Initialize the weights"""
  59. if isinstance(module, (nn.Linear, nn.Conv2d)):
  60. init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  61. if module.bias is not None:
  62. init.zeros_(module.bias)
  63. elif isinstance(module, nn.LayerNorm):
  64. init.zeros_(module.bias)
  65. init.ones_(module.weight)
  66. def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"):
  67. """Flip the flipped heatmaps back to the original form.
  68. Args:
  69. output_flipped (`torch.tensor` of shape `(batch_size, num_keypoints, height, width)`):
  70. The output heatmaps obtained from the flipped images.
  71. flip_pairs (`torch.Tensor` of shape `(num_keypoints, 2)`):
  72. Pairs of keypoints which are mirrored (for example, left ear -- right ear).
  73. target_type (`str`, *optional*, defaults to `"gaussian-heatmap"`):
  74. Target type to use. Can be gaussian-heatmap or combined-target.
  75. gaussian-heatmap: Classification target with gaussian distribution.
  76. combined-target: The combination of classification target (response map) and regression target (offset map).
  77. Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
  78. Returns:
  79. torch.Tensor: heatmaps that flipped back to the original image
  80. """
  81. if target_type not in ["gaussian-heatmap", "combined-target"]:
  82. raise ValueError("target_type should be gaussian-heatmap or combined-target")
  83. if output_flipped.ndim != 4:
  84. raise ValueError("output_flipped should be [batch_size, num_keypoints, height, width]")
  85. batch_size, num_keypoints, height, width = output_flipped.shape
  86. channels = 1
  87. if target_type == "combined-target":
  88. channels = 3
  89. output_flipped[:, 1::3, ...] = -output_flipped[:, 1::3, ...]
  90. output_flipped = output_flipped.reshape(batch_size, -1, channels, height, width)
  91. output_flipped_back = output_flipped.clone()
  92. # Swap left-right parts
  93. for left, right in flip_pairs.tolist():
  94. output_flipped_back[:, left, ...] = output_flipped[:, right, ...]
  95. output_flipped_back[:, right, ...] = output_flipped[:, left, ...]
  96. output_flipped_back = output_flipped_back.reshape((batch_size, num_keypoints, height, width))
  97. # Flip horizontally
  98. output_flipped_back = output_flipped_back.flip(-1)
  99. return output_flipped_back
  100. class VitPoseSimpleDecoder(nn.Module):
  101. """
  102. Simple decoding head consisting of a ReLU activation, 4x upsampling and a 3x3 convolution, turning the
  103. feature maps into heatmaps.
  104. """
  105. def __init__(self, config: VitPoseConfig):
  106. super().__init__()
  107. self.activation = nn.ReLU()
  108. self.upsampling = nn.Upsample(scale_factor=config.scale_factor, mode="bilinear", align_corners=False)
  109. self.conv = nn.Conv2d(
  110. config.backbone_config.hidden_size, config.num_labels, kernel_size=3, stride=1, padding=1
  111. )
  112. def forward(self, hidden_state: torch.Tensor, flip_pairs: torch.Tensor | None = None) -> torch.Tensor:
  113. # Transform input: ReLU + upsample
  114. hidden_state = self.activation(hidden_state)
  115. hidden_state = self.upsampling(hidden_state)
  116. heatmaps = self.conv(hidden_state)
  117. if flip_pairs is not None:
  118. heatmaps = flip_back(heatmaps, flip_pairs)
  119. return heatmaps
  120. class VitPoseClassicDecoder(nn.Module):
  121. """
  122. Classic decoding head consisting of a 2 deconvolutional blocks, followed by a 1x1 convolution layer,
  123. turning the feature maps into heatmaps.
  124. """
  125. def __init__(self, config: VitPoseConfig):
  126. super().__init__()
  127. self.deconv1 = nn.ConvTranspose2d(
  128. config.backbone_config.hidden_size, 256, kernel_size=4, stride=2, padding=1, bias=False
  129. )
  130. self.batchnorm1 = nn.BatchNorm2d(256)
  131. self.relu1 = nn.ReLU()
  132. self.deconv2 = nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False)
  133. self.batchnorm2 = nn.BatchNorm2d(256)
  134. self.relu2 = nn.ReLU()
  135. self.conv = nn.Conv2d(256, config.num_labels, kernel_size=1, stride=1, padding=0)
  136. def forward(self, hidden_state: torch.Tensor, flip_pairs: torch.Tensor | None = None):
  137. hidden_state = self.deconv1(hidden_state)
  138. hidden_state = self.batchnorm1(hidden_state)
  139. hidden_state = self.relu1(hidden_state)
  140. hidden_state = self.deconv2(hidden_state)
  141. hidden_state = self.batchnorm2(hidden_state)
  142. hidden_state = self.relu2(hidden_state)
  143. heatmaps = self.conv(hidden_state)
  144. if flip_pairs is not None:
  145. heatmaps = flip_back(heatmaps, flip_pairs)
  146. return heatmaps
  147. @auto_docstring(
  148. custom_intro="""
  149. The VitPose model with a pose estimation head on top.
  150. """
  151. )
  152. class VitPoseForPoseEstimation(VitPosePreTrainedModel):
  153. def __init__(self, config: VitPoseConfig):
  154. super().__init__(config)
  155. self.backbone = load_backbone(config)
  156. # add backbone attributes
  157. if not hasattr(self.backbone.config, "hidden_size"):
  158. raise ValueError("The backbone should have a hidden_size attribute")
  159. if not hasattr(self.backbone.config, "image_size"):
  160. raise ValueError("The backbone should have an image_size attribute")
  161. if not hasattr(self.backbone.config, "patch_size"):
  162. raise ValueError("The backbone should have a patch_size attribute")
  163. self.head = VitPoseSimpleDecoder(config) if config.use_simple_decoder else VitPoseClassicDecoder(config)
  164. # Initialize weights and apply final processing
  165. self.post_init()
  166. @can_return_tuple
  167. @auto_docstring
  168. def forward(
  169. self,
  170. pixel_values: torch.Tensor,
  171. dataset_index: torch.Tensor | None = None,
  172. flip_pairs: torch.Tensor | None = None,
  173. labels: torch.Tensor | None = None,
  174. **kwargs: Unpack[TransformersKwargs],
  175. ) -> VitPoseEstimatorOutput:
  176. r"""
  177. dataset_index (`torch.Tensor` of shape `(batch_size,)`):
  178. Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.
  179. This corresponds to the dataset index used during training, e.g. For the single dataset index 0 refers to the corresponding dataset. For the multiple datasets index 0 refers to dataset A (e.g. MPII) and index 1 refers to dataset B (e.g. CrowdPose).
  180. flip_pairs (`torch.tensor`, *optional*):
  181. Whether to mirror pairs of keypoints (for example, left ear -- right ear).
  182. Examples:
  183. ```python
  184. >>> from transformers import AutoImageProcessor, VitPoseForPoseEstimation
  185. >>> import torch
  186. >>> from PIL import Image
  187. >>> import httpx
  188. >>> from io import BytesIO
  189. >>> processor = AutoImageProcessor.from_pretrained("usyd-community/vitpose-base-simple")
  190. >>> model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")
  191. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  192. >>> with httpx.stream("GET", url) as response:
  193. ... image = Image.open(BytesIO(response.read()))
  194. >>> boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
  195. >>> inputs = processor(image, boxes=boxes, return_tensors="pt")
  196. >>> with torch.no_grad():
  197. ... outputs = model(**inputs)
  198. >>> heatmaps = outputs.heatmaps
  199. ```"""
  200. loss = None
  201. if labels is not None:
  202. raise NotImplementedError("Training is not yet supported")
  203. outputs: BackboneOutput = self.backbone.forward_with_filtered_kwargs(
  204. pixel_values,
  205. dataset_index=dataset_index,
  206. **kwargs,
  207. )
  208. # Turn output hidden states in tensor of shape (batch_size, num_channels, height, width)
  209. sequence_output = outputs.feature_maps[-1]
  210. batch_size = sequence_output.shape[0]
  211. patch_height = self.config.backbone_config.image_size[0] // self.config.backbone_config.patch_size[0]
  212. patch_width = self.config.backbone_config.image_size[1] // self.config.backbone_config.patch_size[1]
  213. sequence_output = sequence_output.permute(0, 2, 1)
  214. sequence_output = sequence_output.reshape(batch_size, -1, patch_height, patch_width).contiguous()
  215. heatmaps = self.head(sequence_output, flip_pairs=flip_pairs)
  216. return VitPoseEstimatorOutput(
  217. loss=loss,
  218. heatmaps=heatmaps,
  219. hidden_states=outputs.hidden_states,
  220. attentions=outputs.attentions,
  221. )
# Names exported by `from <module> import *`. The output dataclass, `flip_back`
# and the decoder heads defined above are not listed here.
__all__ = ["VitPosePreTrainedModel", "VitPoseForPoseEstimation"]