modular_sam3.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. from ...image_utils import (
  16. IMAGENET_STANDARD_MEAN,
  17. IMAGENET_STANDARD_STD,
  18. )
  19. from ..sam2.image_processing_sam2 import Sam2ImageProcessor
  20. def _scale_boxes(boxes, target_sizes):
  21. """
  22. Scale batch of bounding boxes to the target sizes.
  23. Args:
  24. boxes (`torch.Tensor` of shape `(batch_size, num_boxes, 4)`):
  25. Bounding boxes to scale. Each box is expected to be in (x1, y1, x2, y2) format.
  26. target_sizes (`list[tuple[int, int]]` or `torch.Tensor` of shape `(batch_size, 2)`):
  27. Target sizes to scale the boxes to. Each target size is expected to be in (height, width) format.
  28. Returns:
  29. `torch.Tensor` of shape `(batch_size, num_boxes, 4)`: Scaled bounding boxes.
  30. """
  31. if isinstance(target_sizes, (list, tuple)):
  32. image_height = torch.tensor([i[0] for i in target_sizes])
  33. image_width = torch.tensor([i[1] for i in target_sizes])
  34. elif isinstance(target_sizes, torch.Tensor):
  35. image_height, image_width = target_sizes.unbind(1)
  36. else:
  37. raise TypeError("`target_sizes` must be a list, tuple or torch.Tensor")
  38. scale_factor = torch.stack([image_width, image_height, image_width, image_height], dim=1)
  39. scale_factor = scale_factor.unsqueeze(1).to(boxes.device)
  40. boxes = boxes * scale_factor
  41. return boxes
  42. class Sam3ImageProcessor(Sam2ImageProcessor):
  43. image_mean = IMAGENET_STANDARD_MEAN
  44. image_std = IMAGENET_STANDARD_STD
  45. size = {"height": 1008, "width": 1008}
  46. mask_size = {"height": 288, "width": 288}
  47. def post_process_semantic_segmentation(
  48. self, outputs, target_sizes: list[tuple] | None = None, threshold: float = 0.5
  49. ):
  50. """
  51. Converts the output of [`Sam3Model`] into semantic segmentation maps.
  52. Args:
  53. outputs ([`Sam3ImageSegmentationOutput`]):
  54. Raw outputs of the model containing semantic_seg.
  55. target_sizes (`list[tuple]` of length `batch_size`, *optional*):
  56. List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
  57. predictions will not be resized.
  58. threshold (`float`, *optional*, defaults to 0.5):
  59. Threshold for binarizing the semantic segmentation masks.
  60. Returns:
  61. semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
  62. segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
  63. specified). Each entry is a binary mask (0 or 1).
  64. """
  65. # Get semantic segmentation output
  66. # semantic_seg has shape (batch_size, 1, height, width)
  67. semantic_logits = outputs.semantic_seg
  68. if semantic_logits is None:
  69. raise ValueError(
  70. "Semantic segmentation output is not available in the model outputs. "
  71. "Make sure the model was run with semantic segmentation enabled."
  72. )
  73. # Apply sigmoid to convert logits to probabilities
  74. semantic_probs = semantic_logits.sigmoid()
  75. # Resize and binarize semantic segmentation maps
  76. if target_sizes is not None:
  77. if len(semantic_logits) != len(target_sizes):
  78. raise ValueError(
  79. "Make sure that you pass in as many target sizes as the batch dimension of the logits"
  80. )
  81. semantic_segmentation = []
  82. for idx in range(len(semantic_logits)):
  83. resized_probs = torch.nn.functional.interpolate(
  84. semantic_probs[idx].unsqueeze(dim=0),
  85. size=target_sizes[idx],
  86. mode="bilinear",
  87. align_corners=False,
  88. )
  89. # Binarize: values > threshold become 1, otherwise 0
  90. semantic_map = (resized_probs[0, 0] > threshold).to(torch.long)
  91. semantic_segmentation.append(semantic_map)
  92. else:
  93. # Binarize without resizing
  94. semantic_segmentation = (semantic_probs[:, 0] > threshold).to(torch.long)
  95. semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
  96. return semantic_segmentation
  97. def post_process_object_detection(self, outputs, threshold: float = 0.3, target_sizes: list[tuple] | None = None):
  98. """
  99. Converts the raw output of [`Sam3Model`] into final bounding boxes in (top_left_x, top_left_y,
  100. bottom_right_x, bottom_right_y) format.
  101. Args:
  102. outputs ([`Sam3ImageSegmentationOutput`]):
  103. Raw outputs of the model containing pred_boxes, pred_logits, and optionally presence_logits.
  104. threshold (`float`, *optional*, defaults to 0.3):
  105. Score threshold to keep object detection predictions.
  106. target_sizes (`list[tuple[int, int]]`, *optional*):
  107. List of tuples (`tuple[int, int]`) containing the target size `(height, width)` of each image in the
  108. batch. If unset, predictions will not be resized.
  109. Returns:
  110. `list[dict]`: A list of dictionaries, each dictionary containing the following keys:
  111. - **scores** (`torch.Tensor`): The confidence scores for each predicted box on the image.
  112. - **boxes** (`torch.Tensor`): Image bounding boxes in (top_left_x, top_left_y, bottom_right_x,
  113. bottom_right_y) format.
  114. """
  115. pred_logits = outputs.pred_logits # (batch_size, num_queries)
  116. pred_boxes = outputs.pred_boxes # (batch_size, num_queries, 4) in xyxy format
  117. presence_logits = outputs.presence_logits # (batch_size, 1) or None
  118. batch_size = pred_logits.shape[0]
  119. if target_sizes is not None and len(target_sizes) != batch_size:
  120. raise ValueError("Make sure that you pass in as many target sizes as images")
  121. # Compute scores: combine pred_logits with presence_logits if available
  122. batch_scores = pred_logits.sigmoid()
  123. if presence_logits is not None:
  124. presence_scores = presence_logits.sigmoid() # (batch_size, 1)
  125. batch_scores = batch_scores * presence_scores # Broadcast multiplication
  126. # Boxes are already in xyxy format from the model
  127. batch_boxes = pred_boxes
  128. # Convert from relative [0, 1] to absolute [0, height/width] coordinates
  129. if target_sizes is not None:
  130. batch_boxes = _scale_boxes(batch_boxes, target_sizes)
  131. results = []
  132. for scores, boxes in zip(batch_scores, batch_boxes):
  133. keep = scores > threshold
  134. scores = scores[keep]
  135. boxes = boxes[keep]
  136. results.append({"scores": scores, "boxes": boxes})
  137. return results
  138. def post_process_instance_segmentation(
  139. self,
  140. outputs,
  141. threshold: float = 0.3,
  142. mask_threshold: float = 0.5,
  143. target_sizes: list[tuple] | None = None,
  144. ):
  145. """
  146. Converts the raw output of [`Sam3Model`] into instance segmentation predictions with bounding boxes and masks.
  147. Args:
  148. outputs ([`Sam3ImageSegmentationOutput`]):
  149. Raw outputs of the model containing pred_boxes, pred_logits, pred_masks, and optionally
  150. presence_logits.
  151. threshold (`float`, *optional*, defaults to 0.3):
  152. Score threshold to keep instance predictions.
  153. mask_threshold (`float`, *optional*, defaults to 0.5):
  154. Threshold for binarizing the predicted masks.
  155. target_sizes (`list[tuple[int, int]]`, *optional*):
  156. List of tuples (`tuple[int, int]`) containing the target size `(height, width)` of each image in the
  157. batch. If unset, predictions will not be resized.
  158. Returns:
  159. `list[dict]`: A list of dictionaries, each dictionary containing the following keys:
  160. - **scores** (`torch.Tensor`): The confidence scores for each predicted instance on the image.
  161. - **boxes** (`torch.Tensor`): Image bounding boxes in (top_left_x, top_left_y, bottom_right_x,
  162. bottom_right_y) format.
  163. - **masks** (`torch.Tensor`): Binary segmentation masks for each instance, shape (num_instances,
  164. height, width).
  165. """
  166. pred_logits = outputs.pred_logits # (batch_size, num_queries)
  167. pred_boxes = outputs.pred_boxes # (batch_size, num_queries, 4) in xyxy format
  168. pred_masks = outputs.pred_masks # (batch_size, num_queries, height, width)
  169. presence_logits = outputs.presence_logits # (batch_size, 1) or None
  170. batch_size = pred_logits.shape[0]
  171. if target_sizes is not None and len(target_sizes) != batch_size:
  172. raise ValueError("Make sure that you pass in as many target sizes as images")
  173. # Compute scores: combine pred_logits with presence_logits if available
  174. batch_scores = pred_logits.sigmoid()
  175. if presence_logits is not None:
  176. presence_scores = presence_logits.sigmoid() # (batch_size, 1)
  177. batch_scores = batch_scores * presence_scores # Broadcast multiplication
  178. # Apply sigmoid to mask logits
  179. batch_masks = pred_masks.sigmoid()
  180. # Boxes are already in xyxy format from the model
  181. batch_boxes = pred_boxes
  182. # Scale boxes to target sizes if provided
  183. if target_sizes is not None:
  184. batch_boxes = _scale_boxes(batch_boxes, target_sizes)
  185. results = []
  186. for idx, (scores, boxes, masks) in enumerate(zip(batch_scores, batch_boxes, batch_masks)):
  187. # Filter by score threshold
  188. keep = scores > threshold
  189. scores = scores[keep]
  190. boxes = boxes[keep]
  191. masks = masks[keep] # (num_keep, height, width)
  192. # Resize masks to target size if provided
  193. if target_sizes is not None:
  194. target_size = target_sizes[idx]
  195. if len(masks) > 0:
  196. masks = torch.nn.functional.interpolate(
  197. masks.unsqueeze(0), # (1, num_keep, height, width)
  198. size=target_size,
  199. mode="bilinear",
  200. align_corners=False,
  201. ).squeeze(0) # (num_keep, target_height, target_width)
  202. # Binarize masks
  203. masks = (masks > mask_threshold).to(torch.long)
  204. results.append({"scores": scores, "boxes": boxes, "masks": masks})
  205. return results
# Public API of this module: only the image processor class is exported; `_scale_boxes` stays private.
__all__ = ["Sam3ImageProcessor"]