# modular_yolos.py
  1. from typing import Optional
  2. import numpy as np
  3. import torch
  4. from torch import nn
  5. from torchvision.transforms.v2 import functional as tvF
  6. from transformers.models.detr.image_processing_detr import DetrImageProcessor
  7. from transformers.models.detr.image_processing_pil_detr import DetrImageProcessorPil
  8. from ...image_transforms import center_to_corners_format
  9. from ...image_utils import PILImageResampling, SizeDict, get_image_size_for_max_height_width
  10. from ...utils import TensorType, logging, requires_backends
  11. logger = logging.get_logger(__name__)
  12. def get_size_with_aspect_ratio_yolos(
  13. image_size: tuple[int, int], size: int, max_size: int | None = None, mod_size: int = 16
  14. ) -> tuple[int, int]:
  15. """
  16. Computes the output image size given the input image size and the desired output size, while ensuring that both
  17. height and width are multiples of `mod_size`.
  18. This mirrors the YOLOS-specific behavior used in the torch/fast backends and is required so that all YOLOS
  19. image processing backends (PIL, torchvision, fast) produce identical output shapes.
  20. """
  21. height, width = image_size
  22. raw_size = None
  23. if max_size is not None:
  24. min_original_size = float(min((height, width)))
  25. max_original_size = float(max((height, width)))
  26. if max_original_size / min_original_size * size > max_size:
  27. raw_size = max_size * min_original_size / max_original_size
  28. size = int(round(raw_size))
  29. if width < height:
  30. ow = size
  31. if max_size is not None and raw_size is not None:
  32. oh = int(raw_size * height / width)
  33. else:
  34. oh = int(size * height / width)
  35. elif (height <= width and height == size) or (width <= height and width == size):
  36. oh, ow = height, width
  37. else:
  38. oh = size
  39. if max_size is not None and raw_size is not None:
  40. ow = int(raw_size * width / height)
  41. else:
  42. ow = int(size * width / height)
  43. if mod_size is not None:
  44. ow = ow - (ow % mod_size)
  45. oh = oh - (oh % mod_size)
  46. return (oh, ow)
  47. class YolosImageProcessor(DetrImageProcessor):
  48. def resize(
  49. self,
  50. image: torch.Tensor,
  51. size: SizeDict,
  52. resample: Optional["PILImageResampling | tvF.InterpolationMode | int"] = None,
  53. **kwargs,
  54. ) -> torch.Tensor:
  55. """
  56. Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
  57. int, smaller edge of the image will be matched to this number.
  58. Args:
  59. image (`torch.Tensor`):
  60. Image to resize.
  61. size (`SizeDict`):
  62. Size of the image's `(height, width)` dimensions after resizing. Available options are:
  63. - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
  64. Do NOT keep the aspect ratio.
  65. - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
  66. the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
  67. less or equal to `longest_edge`.
  68. - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
  69. aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
  70. `max_width`.
  71. resample (`PILImageResampling | tvF.InterpolationMode | int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
  72. Resampling filter to use if resizing the image.
  73. """
  74. if size.shortest_edge and size.longest_edge:
  75. # Resize the image so that the shortest edge or the longest edge is of the given size
  76. # while maintaining the aspect ratio of the original image.
  77. new_size = get_size_with_aspect_ratio_yolos(image.shape[-2:], size.shortest_edge, size.longest_edge)
  78. elif size.max_height and size.max_width:
  79. new_size = get_image_size_for_max_height_width(image.shape[-2:], size.max_height, size.max_width)
  80. elif size.height and size.width:
  81. new_size = (size.height, size.width)
  82. else:
  83. raise ValueError(
  84. f"Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got {size}."
  85. )
  86. image = super().resize(
  87. image, size=SizeDict(height=new_size[0], width=new_size[1]), resample=resample, **kwargs
  88. )
  89. return image
  90. def post_process_object_detection(
  91. self, outputs, threshold: float = 0.5, target_sizes: TensorType | list[tuple] = None
  92. ):
  93. """
  94. Converts the raw output of [`YolosForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
  95. bottom_right_x, bottom_right_y) format. Only supports PyTorch.
  96. Args:
  97. outputs ([`YolosObjectDetectionOutput`]):
  98. Raw outputs of the model.
  99. threshold (`float`, *optional*):
  100. Score threshold to keep object detection predictions.
  101. target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
  102. Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
  103. `(height, width)` of each image in the batch. If unset, predictions will not be resized.
  104. Returns:
  105. `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
  106. in the batch as predicted by the model.
  107. """
  108. out_logits, out_bbox = outputs.logits, outputs.pred_boxes
  109. if target_sizes is not None:
  110. if len(out_logits) != len(target_sizes):
  111. raise ValueError(
  112. "Make sure that you pass in as many target sizes as the batch dimension of the logits"
  113. )
  114. prob = nn.functional.softmax(out_logits, -1)
  115. scores, labels = prob[..., :-1].max(-1)
  116. # Convert to [x0, y0, x1, y1] format
  117. boxes = center_to_corners_format(out_bbox)
  118. # Convert from relative [0, 1] to absolute [0, height] coordinates
  119. if target_sizes is not None:
  120. if isinstance(target_sizes, list):
  121. img_h = torch.Tensor([i[0] for i in target_sizes])
  122. img_w = torch.Tensor([i[1] for i in target_sizes])
  123. else:
  124. img_h, img_w = target_sizes.unbind(1)
  125. scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
  126. boxes = boxes * scale_fct[:, None, :]
  127. results = []
  128. for s, l, b in zip(scores, labels, boxes):
  129. score = s[s > threshold]
  130. label = l[s > threshold]
  131. box = b[s > threshold]
  132. results.append({"scores": score, "labels": label, "boxes": box})
  133. return results
  134. def post_process_instance_segmentation(self):
  135. raise NotImplementedError("Segmentation post-processing is not implemented for Deformable DETR yet.")
  136. def post_process_semantic_segmentation(self):
  137. raise NotImplementedError("Semantic segmentation post-processing is not implemented for Deformable DETR yet.")
  138. def post_process_panoptic_segmentation(self):
  139. raise NotImplementedError("Panoptic segmentation post-processing is not implemented for Deformable DETR yet.")
  140. class YolosImageProcessorPil(DetrImageProcessorPil):
  141. def resize(
  142. self,
  143. image: np.ndarray,
  144. size: SizeDict,
  145. resample: Optional["PILImageResampling"] = None,
  146. **kwargs,
  147. ) -> np.ndarray:
  148. """
  149. Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
  150. int, smaller edge of the image will be matched to this number.
  151. Args:
  152. image (`np.ndarray`):
  153. Image to resize.
  154. size (`SizeDict`):
  155. Size of the image's `(height, width)` dimensions after resizing. Available options are:
  156. - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
  157. Do NOT keep the aspect ratio.
  158. - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
  159. the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
  160. less or equal to `longest_edge`.
  161. - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
  162. aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
  163. `max_width`.
  164. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
  165. Resampling filter to use if resizing the image.
  166. """
  167. resample = resample if resample is not None else self.resample
  168. if size.shortest_edge and size.longest_edge:
  169. # Resize the image so that the shortest edge or the longest edge is of the given size
  170. # while maintaining the aspect ratio of the original image.
  171. new_size = get_size_with_aspect_ratio_yolos(
  172. image.shape[-2:],
  173. size.shortest_edge,
  174. size.longest_edge or size.shortest_edge,
  175. )
  176. elif size.max_height and size.max_width:
  177. new_size = get_image_size_for_max_height_width(image.shape[-2:], size.max_height, size.max_width)
  178. elif size.height and size.width:
  179. new_size = (size.height, size.width)
  180. else:
  181. raise ValueError(
  182. f"Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got {size}."
  183. )
  184. image = super().resize(
  185. image,
  186. size=SizeDict(height=new_size[0], width=new_size[1]),
  187. resample=resample,
  188. **kwargs,
  189. )
  190. return image
  191. def post_process_object_detection(
  192. self, outputs, threshold: float = 0.5, target_sizes: TensorType | list[tuple] = None
  193. ):
  194. """
  195. Converts the raw output of [`YolosForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
  196. bottom_right_x, bottom_right_y) format. Only supports PyTorch.
  197. Args:
  198. outputs ([`YolosObjectDetectionOutput`]):
  199. Raw outputs of the model.
  200. threshold (`float`, *optional*):
  201. Score threshold to keep object detection predictions.
  202. target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
  203. Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
  204. `(height, width)` of each image in the batch. If unset, predictions will not be resized.
  205. Returns:
  206. `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
  207. in the batch as predicted by the model.
  208. """
  209. requires_backends(self, ["torch"])
  210. out_logits, out_bbox = outputs.logits, outputs.pred_boxes
  211. if target_sizes is not None:
  212. if len(out_logits) != len(target_sizes):
  213. raise ValueError(
  214. "Make sure that you pass in as many target sizes as the batch dimension of the logits"
  215. )
  216. prob = nn.functional.softmax(out_logits, -1)
  217. scores, labels = prob[..., :-1].max(-1)
  218. # Convert to [x0, y0, x1, y1] format
  219. boxes = center_to_corners_format(out_bbox)
  220. # Convert from relative [0, 1] to absolute [0, height] coordinates
  221. if target_sizes is not None:
  222. if isinstance(target_sizes, list):
  223. img_h = torch.Tensor([i[0] for i in target_sizes])
  224. img_w = torch.Tensor([i[1] for i in target_sizes])
  225. else:
  226. img_h, img_w = target_sizes.unbind(1)
  227. scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
  228. boxes = boxes * scale_fct[:, None, :]
  229. results = []
  230. for s, l, b in zip(scores, labels, boxes):
  231. score = s[s > threshold]
  232. label = l[s > threshold]
  233. box = b[s > threshold]
  234. results.append({"scores": score, "labels": label, "boxes": box})
  235. return results
  236. def post_process_instance_segmentation(self):
  237. raise NotImplementedError("Segmentation post-processing is not implemented for Deformable DETR yet.")
  238. def post_process_semantic_segmentation(self):
  239. raise NotImplementedError("Semantic segmentation post-processing is not implemented for Deformable DETR yet.")
  240. def post_process_panoptic_segmentation(self):
  241. raise NotImplementedError("Panoptic segmentation post-processing is not implemented for Deformable DETR yet.")
# Public names exported by this module.
__all__ = ["YolosImageProcessor", "YolosImageProcessorPil"]