processing_sam3.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657
  1. # Copyright 2025 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Processor class for SAM3.
  16. """
  17. from copy import deepcopy
  18. import numpy as np
  19. from ...image_utils import ImageInput
  20. from ...processing_utils import ProcessorMixin
  21. from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
  22. from ...utils import TensorType, auto_docstring, is_torch_available, logging
  23. from ...utils.import_utils import requires
logger = logging.get_logger(__name__)

# torch is imported only when available; the box helpers and Sam3Processor below reference `torch`,
# and the processor class is gated with @requires(backends=("torch",)).
if is_torch_available():
    import torch
  27. def box_cxcywh_to_xyxy(x):
  28. x_c, y_c, w, h = x.unbind(-1)
  29. b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
  30. return torch.stack(b, dim=-1)
  31. def box_cxcywh_to_xywh(x):
  32. x_c, y_c, w, h = x.unbind(-1)
  33. b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (w), (h)]
  34. return torch.stack(b, dim=-1)
  35. def box_xywh_to_xyxy(x):
  36. x, y, w, h = x.unbind(-1)
  37. b = [(x), (y), (x + w), (y + h)]
  38. return torch.stack(b, dim=-1)
  39. def box_xywh_to_cxcywh(x):
  40. x, y, w, h = x.unbind(-1)
  41. b = [(x + 0.5 * w), (y + 0.5 * h), (w), (h)]
  42. return torch.stack(b, dim=-1)
  43. def box_xyxy_to_xywh(x):
  44. x, y, X, Y = x.unbind(-1)
  45. b = [(x), (y), (X - x), (Y - y)]
  46. return torch.stack(b, dim=-1)
  47. def box_xyxy_to_cxcywh(x):
  48. x0, y0, x1, y1 = x.unbind(-1)
  49. b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
  50. return torch.stack(b, dim=-1)
  51. def box_area(boxes):
  52. """
  53. Batched version of box area. Boxes should be in [x0, y0, x1, y1] format.
  54. Inputs:
  55. - boxes: Tensor of shape (..., 4)
  56. Returns:
  57. - areas: Tensor of shape (...,)
  58. """
  59. x0, y0, x1, y1 = boxes.unbind(-1)
  60. return (x1 - x0) * (y1 - y0)
  61. @requires(backends=("torch",))
  62. @auto_docstring
  63. class Sam3Processor(ProcessorMixin):
  64. def __init__(
  65. self, image_processor, tokenizer, target_size: int | None = None, point_pad_value: int = -10, **kwargs
  66. ):
  67. r"""
  68. target_size (`int`, *optional*):
  69. The target size (target_size, target_size) to which the image will be resized.
  70. point_pad_value (`int`, *optional*, defaults to -10):
  71. The value used for padding input boxes.
  72. """
  73. super().__init__(image_processor, tokenizer, **kwargs)
  74. self.point_pad_value = point_pad_value
  75. self.target_size = target_size if target_size is not None else self.image_processor.size["height"]
  76. @auto_docstring
  77. def __call__(
  78. self,
  79. images: ImageInput | None = None,
  80. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
  81. segmentation_maps: ImageInput | None = None,
  82. input_boxes: list[list[list[float]]] | torch.Tensor | None = None,
  83. input_boxes_labels: list[list[list[int]]] | torch.Tensor | None = None,
  84. original_sizes: list[list[float]] | torch.Tensor | None = None,
  85. return_tensors: str | TensorType | None = None,
  86. **kwargs,
  87. ) -> BatchEncoding:
  88. r"""
  89. images (`ImageInput`, *optional*):
  90. The image(s) to process.
  91. text (`str`, `list[str]`, `list[list[str]]`, *optional*):
  92. The text to process.
  93. segmentation_maps (`ImageInput`, *optional*):
  94. The segmentation maps to process.
  95. input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*):
  96. The bounding boxes to process.
  97. input_boxes_labels (`list[list[int]]`, `torch.Tensor`, *optional*):
  98. The labels for the bounding boxes.
  99. original_sizes (`list[list[float]]`, `torch.Tensor`, *optional*):
  100. The original sizes of the images.
  101. Returns:
  102. A [`BatchEncoding`] with the following fields:
  103. - `pixel_values` (`torch.Tensor`): The processed image(s).
  104. - `original_sizes` (`list[list[float]]`): The original sizes of the images.
  105. - `labels` (`torch.Tensor`): The processed segmentation maps (if provided).
  106. - `input_boxes_labels` (`torch.Tensor`): The processed labels for the bounding boxes.
  107. - `input_boxes` (`torch.Tensor`): The processed bounding boxes.
  108. """
  109. encoding = None
  110. if images is not None:
  111. encoding = self.image_processor(
  112. images,
  113. segmentation_maps=segmentation_maps,
  114. return_tensors=return_tensors,
  115. **kwargs,
  116. )
  117. elif original_sizes is not None:
  118. if isinstance(original_sizes, torch.Tensor):
  119. original_sizes = original_sizes.cpu().tolist()
  120. encoding = BatchEncoding({"original_sizes": original_sizes}, tensor_type=return_tensors)
  121. elif input_boxes is not None:
  122. raise ValueError("Either images or original_sizes must be provided if input_boxes is not None")
  123. text = self._resolve_text_prompts(text, input_boxes)
  124. if text is not None:
  125. text_inputs = self.tokenizer(text, return_tensors=return_tensors, padding="max_length", max_length=32)
  126. if encoding is not None:
  127. encoding.update(text_inputs)
  128. else:
  129. encoding = text_inputs
  130. # Process input boxes if provided
  131. if input_boxes is not None:
  132. original_sizes = encoding["original_sizes"]
  133. # Validate and convert inputs to standardized format
  134. processed_boxes = self._validate_single_input(
  135. input_boxes,
  136. expected_depth=3,
  137. input_name="boxes",
  138. expected_format="[image level, box level, box coordinates]",
  139. expected_coord_size=4,
  140. )
  141. processed_boxes_labels = self._validate_single_input(
  142. input_boxes_labels,
  143. expected_depth=2,
  144. input_name="labels",
  145. expected_format="[image level, box level]",
  146. )
  147. # Get padding requirements for all inputs
  148. if processed_boxes is not None:
  149. boxes_max_dims = self._get_nested_dimensions(processed_boxes)[:2]
  150. if processed_boxes_labels is not None:
  151. boxes_labels_max_dims = self._get_nested_dimensions(processed_boxes_labels)[:2]
  152. # Ensure boxes and labels have consistent dimensions
  153. if processed_boxes is not None and processed_boxes_labels is not None:
  154. if boxes_max_dims != boxes_labels_max_dims:
  155. raise ValueError(
  156. "Input boxes and labels have inconsistent dimensions. Please ensure they have the same dimensions."
  157. )
  158. # Pad and normalize all inputs to final tensor format
  159. if processed_boxes is not None:
  160. padded_boxes = self._pad_nested_list(processed_boxes, boxes_max_dims + [4])
  161. final_boxes = torch.tensor(padded_boxes, dtype=torch.float32)
  162. self._normalize_tensor_coordinates(
  163. final_boxes, original_sizes, is_bounding_box=True, preserve_padding=True
  164. )
  165. final_boxes = box_xyxy_to_cxcywh(final_boxes)
  166. encoding.update({"input_boxes": final_boxes})
  167. if processed_boxes_labels is not None:
  168. padded_boxes_labels = self._pad_nested_list(processed_boxes_labels, boxes_labels_max_dims)
  169. final_boxes_labels = torch.tensor(padded_boxes_labels, dtype=torch.int64)
  170. encoding.update({"input_boxes_labels": final_boxes_labels})
  171. return encoding
  172. def _normalize_coordinates(self, coords: "torch.Tensor", original_size, is_bounding_box=False) -> "torch.Tensor":
  173. """
  174. Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format.
  175. Args:
  176. target_size (`int`):
  177. The target size of the image.
  178. coords (`torch.Tensor`):
  179. The coordinates to be normalized.
  180. original_size (`tuple`):
  181. The original size of the image.
  182. is_bounding_box (`bool`, *optional*, defaults to `False`):
  183. Whether the coordinates are bounding boxes.
  184. """
  185. old_h, old_w = original_size
  186. coords = deepcopy(coords).float()
  187. if is_bounding_box:
  188. coords = coords.reshape(-1, 2, 2)
  189. coords[..., 0] = coords[..., 0] / old_w
  190. coords[..., 1] = coords[..., 1] / old_h
  191. if is_bounding_box:
  192. coords = coords.reshape(-1, 4)
  193. return coords
  194. def _convert_to_nested_list(self, data, expected_depth, current_depth=0):
  195. """
  196. Recursively convert various input formats (tensors, numpy arrays, lists) to nested lists.
  197. Preserves None values within lists.
  198. Args:
  199. data: Input data in any format (may be None or contain None values)
  200. expected_depth: Expected nesting depth
  201. current_depth: Current depth in recursion
  202. Returns:
  203. Nested list representation of the data (or None)
  204. """
  205. if data is None:
  206. return None
  207. # Convert tensor/numpy to list if we're at a leaf level or if it's a multi-dimensional array
  208. if isinstance(data, torch.Tensor): # PyTorch tensor
  209. if current_depth == expected_depth - 2 or len(data.shape) <= 2: # At coordinate level or small tensor
  210. return data.numpy().tolist()
  211. else:
  212. return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
  213. elif isinstance(data, np.ndarray): # NumPy array
  214. if current_depth == expected_depth - 2 or len(data.shape) <= 2: # At coordinate level or small array
  215. return data.tolist()
  216. else:
  217. return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
  218. elif isinstance(data, list):
  219. if current_depth == expected_depth:
  220. # We've reached the expected depth, return as is
  221. return data
  222. else:
  223. # Continue recursion, preserving None values
  224. return [
  225. self._convert_to_nested_list(item, expected_depth, current_depth + 1) if item is not None else None
  226. for item in data
  227. ]
  228. elif isinstance(data, (int, float)):
  229. return data
  230. else:
  231. raise ValueError(f"Unsupported data type: {type(data)}")
  232. def _resolve_text_prompts(self, text, input_boxes):
  233. """
  234. Resolve text prompts by setting defaults based on prompt types.
  235. """
  236. # If no text provided, infer default based on prompt type
  237. if text is None:
  238. return "visual" if input_boxes else None
  239. if not isinstance(text, (list, tuple)):
  240. return text
  241. # Validate list/tuple length matches both prompt types if provided
  242. text = list(text) # Convert to list to allow modification
  243. if input_boxes and len(text) != len(input_boxes):
  244. raise ValueError(
  245. f"The number of text prompts must match the number of input boxes. "
  246. f"Got {len(text)} text prompts and {len(input_boxes)} input boxes."
  247. )
  248. # Fill in None values with defaults based on corresponding prompt
  249. for i, text_value in enumerate(text):
  250. if text_value is None and input_boxes and input_boxes[i] is not None:
  251. text[i] = "visual"
  252. return text
  253. def _get_nested_dimensions(self, nested_list, max_dims=None):
  254. """
  255. Get the maximum dimensions at each level of nesting, skipping None values.
  256. Args:
  257. nested_list (`list`):
  258. Nested list structure (may contain None values).
  259. max_dims (`list`, *optional*):
  260. Current maximum dimensions (for recursion).
  261. Returns:
  262. `list`: A list of maximum dimensions for each nesting level.
  263. """
  264. if max_dims is None:
  265. max_dims = []
  266. if not isinstance(nested_list, list):
  267. return max_dims
  268. if len(max_dims) == 0:
  269. max_dims.append(len(nested_list))
  270. else:
  271. max_dims[0] = max(max_dims[0], len(nested_list))
  272. if len(nested_list) > 0:
  273. for item in nested_list:
  274. # Skip None values
  275. if item is None:
  276. continue
  277. if isinstance(item, list):
  278. sub_dims = self._get_nested_dimensions(item)
  279. # Merge sub_dims into max_dims
  280. for i, dim in enumerate(sub_dims):
  281. if i + 1 >= len(max_dims):
  282. max_dims.append(dim)
  283. else:
  284. max_dims[i + 1] = max(max_dims[i + 1], dim)
  285. return max_dims
  286. def _pad_nested_list(self, nested_list, target_dims, current_level=0, pad_value=None):
  287. """
  288. Recursively pad a nested list to match target dimensions. Replaces None values with padded structures.
  289. Args:
  290. nested_list (`list`):
  291. Nested list to pad (may contain None values).
  292. target_dims (`list`):
  293. Target dimensions for each level.
  294. current_level (`int`, *optional*, defaults to 0):
  295. Current nesting level.
  296. pad_value (`int`, *optional*):
  297. Value to use for padding.
  298. Returns:
  299. `list`: The padded nested list.
  300. """
  301. if pad_value is None:
  302. pad_value = self.point_pad_value
  303. if current_level >= len(target_dims):
  304. return nested_list
  305. # Ensure we have a list
  306. if not isinstance(nested_list, list):
  307. nested_list = [nested_list]
  308. # Pad current level
  309. current_size = len(nested_list)
  310. target_size = target_dims[current_level]
  311. # Pad with appropriate values
  312. if current_level == len(target_dims) - 1:
  313. # At the coordinate level, pad with pad_value
  314. nested_list.extend([pad_value] * (target_size - current_size))
  315. else:
  316. # At higher levels, pad with nested structures
  317. if current_size > 0:
  318. # Create appropriately sized template
  319. if current_level < len(target_dims) - 2:
  320. # For non-coordinate levels, create empty nested structure
  321. template_dims = target_dims[current_level + 1 :]
  322. template = self._create_empty_nested_structure(template_dims, pad_value)
  323. else:
  324. # For coordinate level, create list of pad_values
  325. template = [pad_value] * target_dims[current_level + 1]
  326. nested_list.extend([deepcopy(template) for _ in range(target_size - current_size)])
  327. else:
  328. # Create from scratch
  329. template_dims = target_dims[current_level + 1 :]
  330. template = self._create_empty_nested_structure(template_dims, pad_value)
  331. nested_list.extend([deepcopy(template) for _ in range(target_size)])
  332. # Recursively pad sublists, replacing None with padded structures
  333. if current_level < len(target_dims) - 1:
  334. for i in range(len(nested_list)):
  335. if nested_list[i] is None:
  336. # Replace None with fully padded structure
  337. template_dims = target_dims[current_level + 1 :]
  338. nested_list[i] = self._create_empty_nested_structure(template_dims, pad_value)
  339. elif isinstance(nested_list[i], list):
  340. nested_list[i] = self._pad_nested_list(nested_list[i], target_dims, current_level + 1, pad_value)
  341. return nested_list
  342. def _create_empty_nested_structure(self, dims, pad_value):
  343. """
  344. Create an empty nested structure with given dimensions filled with pad_value.
  345. Args:
  346. dims (`list`):
  347. The dimensions of the nested structure.
  348. pad_value (`int`):
  349. The value to fill the structure with.
  350. """
  351. if len(dims) == 1:
  352. return [pad_value] * dims[0]
  353. else:
  354. return [self._create_empty_nested_structure(dims[1:], pad_value) for _ in range(dims[0])]
  355. def _get_nesting_level(self, input_list):
  356. """
  357. Get the nesting level of a list structure, skipping None values.
  358. Args:
  359. input_list (`list`):
  360. The list to get the nesting level of.
  361. """
  362. if isinstance(input_list, list):
  363. if len(input_list) == 0:
  364. return 1
  365. # Find first non-None element to determine nesting level
  366. for item in input_list:
  367. if item is not None:
  368. return 1 + self._get_nesting_level(item)
  369. # All elements are None, treat as single level
  370. return 1
  371. elif isinstance(input_list, (np.ndarray, torch.Tensor)):
  372. # For arrays/tensors, the nesting level is the number of dimensions
  373. return len(input_list.shape)
  374. return 0
  375. def _validate_single_input(
  376. self,
  377. data: torch.Tensor | np.ndarray | list,
  378. expected_depth: int,
  379. input_name: str,
  380. expected_format: str,
  381. expected_coord_size: int | None = None,
  382. ) -> list:
  383. """
  384. Validate a single input by ensuring proper nesting and raising an error if the input is not valid.
  385. Args:
  386. data (`torch.Tensor`, `np.ndarray`, or `list`):
  387. Input data to process.
  388. expected_depth (`int`):
  389. Expected nesting depth.
  390. input_name (`str`):
  391. Name of the input for error messages.
  392. expected_format (`str`):
  393. The expected format of the input.
  394. expected_coord_size (`int`, *optional*):
  395. Expected coordinate size (4 for boxes, None for labels).
  396. .
  397. """
  398. if data is None:
  399. return None
  400. # Handle tensors and numpy arrays first
  401. if isinstance(data, (torch.Tensor, np.ndarray)):
  402. # For tensors/arrays, we can directly check the number of dimensions
  403. if data.ndim != expected_depth:
  404. raise ValueError(
  405. f"Input {input_name} must be a tensor/array with {expected_depth} dimensions. The expected nesting format is {expected_format}. Got {data.ndim} dimensions."
  406. )
  407. elif expected_coord_size is not None:
  408. if data.shape[-1] != expected_coord_size:
  409. raise ValueError(
  410. f"Input {input_name} must be a tensor/array with {expected_coord_size} as the last dimension, got {data.shape[-1]}."
  411. )
  412. return self._convert_to_nested_list(data, expected_depth)
  413. # Handle nested lists
  414. if isinstance(data, list):
  415. current_depth = self._get_nesting_level(data)
  416. if current_depth != expected_depth:
  417. raise ValueError(
  418. f"Input {input_name} must be a nested list with {expected_depth} levels. The expected nesting format is {expected_format}. Got {current_depth} levels."
  419. )
  420. return self._convert_to_nested_list(data, expected_depth)
  421. def _normalize_tensor_coordinates(self, tensor, original_sizes, is_bounding_box=False, preserve_padding=False):
  422. """
  423. Helper method to normalize coordinates in a tensor across multiple images.
  424. Args:
  425. tensor (`torch.Tensor`):
  426. Input tensor with coordinates.
  427. original_sizes (`list`):
  428. Original image sizes.
  429. is_bounding_box (`bool`, *optional*, defaults to `False`):
  430. Whether coordinates are bounding boxes.
  431. preserve_padding (`bool`, *optional*, defaults to `False`):
  432. Whether to preserve padding values (for boxes).
  433. """
  434. if preserve_padding:
  435. # For boxes: avoid normalizing pad values
  436. mask = tensor != self.point_pad_value
  437. coord_mask = mask.all(dim=-1, keepdim=True)
  438. for img_idx in range(len(original_sizes)):
  439. if img_idx < tensor.shape[0]:
  440. original_size = original_sizes[img_idx] if img_idx < len(original_sizes) else original_sizes[0]
  441. normalized_coords = self._normalize_coordinates(
  442. tensor[img_idx], original_size, is_bounding_box=is_bounding_box
  443. )
  444. if preserve_padding:
  445. # Only update non-padded values
  446. img_mask = coord_mask[img_idx]
  447. tensor[img_idx] = torch.where(
  448. img_mask.expand_as(tensor[img_idx]), normalized_coords, tensor[img_idx]
  449. )
  450. else:
  451. tensor[img_idx] = normalized_coords
  452. def post_process_semantic_segmentation(self, outputs, target_sizes=None, threshold=0.5):
  453. """
  454. Converts the output of [`Sam3Model`] into semantic segmentation maps.
  455. Args:
  456. outputs ([`Sam3ImageSegmentationOutput`]):
  457. Raw outputs of the model containing semantic_seg.
  458. target_sizes (`list[tuple]` of length `batch_size`, *optional*):
  459. List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
  460. predictions will not be resized.
  461. threshold (`float`, *optional*, defaults to 0.5):
  462. Threshold for binarizing the semantic segmentation masks.
  463. Returns:
  464. semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
  465. segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
  466. specified). Each entry is a binary mask (0 or 1).
  467. """
  468. return self.image_processor.post_process_semantic_segmentation(outputs, target_sizes, threshold)
  469. def post_process_object_detection(self, outputs, threshold=0.3, target_sizes=None):
  470. """
  471. Converts the raw output of [`Sam3Model`] into final bounding boxes in (top_left_x, top_left_y,
  472. bottom_right_x, bottom_right_y) format. This is a convenience wrapper around the image processor method.
  473. Args:
  474. outputs ([`Sam3ImageSegmentationOutput`]):
  475. Raw outputs of the model containing pred_boxes, pred_logits, and optionally presence_logits.
  476. threshold (`float`, *optional*, defaults to 0.3):
  477. Score threshold to keep object detection predictions.
  478. target_sizes (`list[tuple[int, int]]`, *optional*):
  479. List of tuples (`tuple[int, int]`) containing the target size `(height, width)` of each image in the
  480. batch. If unset, predictions will not be resized.
  481. Returns:
  482. `list[dict]`: A list of dictionaries, each dictionary containing the following keys:
  483. - **scores** (`torch.Tensor`): The confidence scores for each predicted box on the image.
  484. - **boxes** (`torch.Tensor`): Image bounding boxes in (top_left_x, top_left_y, bottom_right_x,
  485. bottom_right_y) format.
  486. Example:
  487. ```python
  488. >>> from transformers import AutoModel, AutoProcessor
  489. >>> from PIL import Image
  490. >>> import httpx
  491. >>> from io import BytesIO
  492. >>> model = AutoModel.from_pretrained("facebook/sam3-base")
  493. >>> processor = AutoProcessor.from_pretrained("facebook/sam3-base")
  494. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  495. >>> with httpx.stream("GET", url) as response:
  496. ... image = Image.open(BytesIO(response.read()))
  497. >>> inputs = processor(images=image, text="cat", return_tensors="pt")
  498. >>> outputs = model(**inputs)
  499. >>> # Post-process to get bounding boxes
  500. >>> results = processor.post_process_object_detection(outputs, threshold=0.3, target_sizes=[image.size[::-1]])
  501. >>> boxes = results[0]["boxes"]
  502. >>> scores = results[0]["scores"]
  503. ```
  504. """
  505. return self.image_processor.post_process_object_detection(outputs, threshold, target_sizes)
  506. def post_process_instance_segmentation(
  507. self,
  508. outputs,
  509. threshold=0.3,
  510. mask_threshold=0.5,
  511. target_sizes=None,
  512. ):
  513. """
  514. Converts the raw output of [`Sam3Model`] into instance segmentation predictions with bounding boxes and masks.
  515. This is a convenience wrapper around the image processor method.
  516. Args:
  517. outputs ([`Sam3ImageSegmentationOutput`]):
  518. Raw outputs of the model containing pred_boxes, pred_logits, pred_masks, and optionally
  519. presence_logits.
  520. threshold (`float`, *optional*, defaults to 0.3):
  521. Score threshold to keep instance predictions.
  522. mask_threshold (`float`, *optional*, defaults to 0.5):
  523. Threshold for binarizing the predicted masks.
  524. target_sizes (`list[tuple[int, int]]`, *optional*):
  525. List of tuples (`tuple[int, int]`) containing the target size `(height, width)` of each image in the
  526. batch. If unset, predictions will not be resized.
  527. Returns:
  528. `list[dict]`: A list of dictionaries, each dictionary containing the following keys:
  529. - **scores** (`torch.Tensor`): The confidence scores for each predicted instance on the image.
  530. - **boxes** (`torch.Tensor`): Image bounding boxes in (top_left_x, top_left_y, bottom_right_x,
  531. bottom_right_y) format.
  532. - **masks** (`torch.Tensor`): Binary segmentation masks for each instance, shape (num_instances,
  533. height, width).
  534. Example:
  535. ```python
  536. >>> from transformers import AutoModel, AutoProcessor
  537. >>> from PIL import Image
  538. >>> import httpx
  539. >>> from io import BytesIO
  540. >>> model = AutoModel.from_pretrained("facebook/sam3-base")
  541. >>> processor = AutoProcessor.from_pretrained("facebook/sam3-base")
  542. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  543. >>> with httpx.stream("GET", url) as response:
  544. ... image = Image.open(BytesIO(response.read()))
  545. >>> inputs = processor(images=image, text="cat", return_tensors="pt")
  546. >>> outputs = model(**inputs)
  547. >>> # Post-process to get instance segmentation
  548. >>> results = processor.post_process_instance_segmentation(
  549. ... outputs, threshold=0.3, target_sizes=[image.size[::-1]]
  550. ... )
  551. >>> masks = results[0]["masks"]
  552. >>> boxes = results[0]["boxes"]
  553. >>> scores = results[0]["scores"]
  554. ```
  555. """
  556. return self.image_processor.post_process_instance_segmentation(
  557. outputs, threshold, mask_threshold, target_sizes
  558. )
# Names exported by this module.
__all__ = ["Sam3Processor"]