utils.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807
  1. import collections
  2. import math
  3. import pathlib
  4. import warnings
  5. from itertools import repeat
  6. from types import FunctionType
  7. from typing import Any, BinaryIO, Optional, Union
  8. import numpy as np
  9. import torch
  10. from PIL import __version__ as PILLOW_VERSION_STRING, Image, ImageColor, ImageDraw, ImageFont
# Names exported by a star-import of this module.
__all__ = [
    "_Image_fromarray",
    "make_grid",
    "save_image",
    "draw_bounding_boxes",
    "draw_segmentation_masks",
    "draw_keypoints",
    "flow_to_image",
]
  20. @torch.no_grad()
  21. def make_grid(
  22. tensor: Union[torch.Tensor, list[torch.Tensor]],
  23. nrow: int = 8,
  24. padding: int = 2,
  25. normalize: bool = False,
  26. value_range: Optional[tuple[int, int]] = None,
  27. scale_each: bool = False,
  28. pad_value: float = 0.0,
  29. ) -> torch.Tensor:
  30. """
  31. Make a grid of images.
  32. Args:
  33. tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
  34. or a list of images all of the same size.
  35. nrow (int, optional): Number of images displayed in each row of the grid.
  36. The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
  37. padding (int, optional): amount of padding. Default: ``2``.
  38. normalize (bool, optional): If True, shift the image to the range (0, 1),
  39. by the min and max values specified by ``value_range``. Default: ``False``.
  40. value_range (tuple, optional): tuple (min, max) where min and max are numbers,
  41. then these numbers are used to normalize the image. By default, min and max
  42. are computed from the tensor.
  43. scale_each (bool, optional): If ``True``, scale each image in the batch of
  44. images separately rather than the (min, max) over all images. Default: ``False``.
  45. pad_value (float, optional): Value for the padded pixels. Default: ``0``.
  46. Returns:
  47. grid (Tensor): the tensor containing grid of images.
  48. """
  49. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  50. _log_api_usage_once(make_grid)
  51. if not torch.is_tensor(tensor):
  52. if isinstance(tensor, list):
  53. for t in tensor:
  54. if not torch.is_tensor(t):
  55. raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}")
  56. else:
  57. raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
  58. # if list of tensors, convert to a 4D mini-batch Tensor
  59. if isinstance(tensor, list):
  60. tensor = torch.stack(tensor, dim=0)
  61. if tensor.dim() == 2: # single image H x W
  62. tensor = tensor.unsqueeze(0)
  63. if tensor.dim() == 3: # single image
  64. if tensor.size(0) == 1: # if single-channel, convert to 3-channel
  65. tensor = torch.cat((tensor, tensor, tensor), 0)
  66. tensor = tensor.unsqueeze(0)
  67. if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
  68. tensor = torch.cat((tensor, tensor, tensor), 1)
  69. if normalize is True:
  70. tensor = tensor.clone() # avoid modifying tensor in-place
  71. if value_range is not None and not isinstance(value_range, tuple):
  72. raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")
  73. def norm_ip(img, low, high):
  74. img.clamp_(min=low, max=high)
  75. img.sub_(low).div_(max(high - low, 1e-5))
  76. def norm_range(t, value_range):
  77. if value_range is not None:
  78. norm_ip(t, value_range[0], value_range[1])
  79. else:
  80. norm_ip(t, float(t.min()), float(t.max()))
  81. if scale_each is True:
  82. for t in tensor: # loop over mini-batch dimension
  83. norm_range(t, value_range)
  84. else:
  85. norm_range(tensor, value_range)
  86. if not isinstance(tensor, torch.Tensor):
  87. raise TypeError("tensor should be of type torch.Tensor")
  88. if tensor.size(0) == 1:
  89. return tensor.squeeze(0)
  90. # make the mini-batch of images into a grid
  91. nmaps = tensor.size(0)
  92. xmaps = min(nrow, nmaps)
  93. ymaps = int(math.ceil(float(nmaps) / xmaps))
  94. height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
  95. num_channels = tensor.size(1)
  96. grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
  97. k = 0
  98. for y in range(ymaps):
  99. for x in range(xmaps):
  100. if k >= nmaps:
  101. break
  102. # Tensor.copy_() is a valid method but seems to be missing from the stubs
  103. # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
  104. grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined]
  105. 2, x * width + padding, width - padding
  106. ).copy_(tensor[k])
  107. k = k + 1
  108. return grid
  109. class _ImageDrawTV(ImageDraw.ImageDraw):
  110. """
  111. A wrapper around PIL.ImageDraw to add functionalities for drawing rotated bounding boxes.
  112. """
  113. def oriented_rectangle(self, xy, fill=None, outline=None, width=1):
  114. self.dashed_line(((xy[0], xy[1]), (xy[2], xy[3])), width=width, fill=outline)
  115. for i in range(2, len(xy), 2):
  116. self.line(
  117. ((xy[i], xy[i + 1]), (xy[(i + 2) % len(xy)], xy[(i + 3) % len(xy)])),
  118. width=width,
  119. fill=outline,
  120. )
  121. self.polygon(xy, fill=fill, outline=None, width=0)
  122. def dashed_line(self, xy, fill=None, width=0, joint=None, dash_length=5, space_length=5):
  123. # Calculate the total length of the line
  124. total_length = 0
  125. for i in range(1, len(xy)):
  126. x1, y1 = xy[i - 1]
  127. x2, y2 = xy[i]
  128. total_length += ((x2 - x1) ** 2 + (y2 - y1) ** 2) ** 0.5
  129. # Initialize the current position and the current dash
  130. current_position = 0
  131. current_dash = True
  132. # Iterate over the coordinates of the line
  133. for i in range(1, len(xy)):
  134. x1, y1 = xy[i - 1]
  135. x2, y2 = xy[i]
  136. # Calculate the length of this segment
  137. segment_length = ((x2 - x1) ** 2 + (y2 - y1) ** 2) ** 0.5
  138. # While there are still dashes to draw on this segment
  139. while segment_length > 0:
  140. # Calculate the length of this dash
  141. dash_length_to_draw = min(segment_length, dash_length if current_dash else space_length)
  142. # Calculate the end point of this dash
  143. dx = x2 - x1
  144. dy = y2 - y1
  145. angle = math.atan2(dy, dx)
  146. end_x = x1 + math.cos(angle) * dash_length_to_draw
  147. end_y = y1 + math.sin(angle) * dash_length_to_draw
  148. # If this is a dash, draw it
  149. if current_dash:
  150. self.line([(x1, y1), (end_x, end_y)], fill, width, joint)
  151. # Update the current position and the current dash
  152. current_position += dash_length_to_draw
  153. segment_length -= dash_length_to_draw
  154. x1, y1 = end_x, end_y
  155. current_dash = not current_dash
  156. def _Image_fromarray(
  157. obj: np.ndarray,
  158. mode: str,
  159. ) -> Image.Image:
  160. """
  161. A wrapper around PIL.Image.fromarray to mitigate the deprecation of the
  162. mode paramter. See:
  163. https://pillow.readthedocs.io/en/stable/releasenotes/11.3.0.html#image-fromarray-mode-parameter
  164. """
  165. # This may throw if the version string is from an install that comes from a
  166. # non-stable or development version. We'll fall back to the old behavior in
  167. # such cases.
  168. try:
  169. PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION_STRING.split("."))
  170. except Exception:
  171. PILLOW_VERSION = None
  172. if PILLOW_VERSION is not None and PILLOW_VERSION >= (11, 3):
  173. # The actual PR that implements the deprecation has more context for why
  174. # it was done, and also points out some problems:
  175. #
  176. # https://github.com/python-pillow/Pillow/pull/9018
  177. #
  178. # Our use case falls into those problems. We actually rely on the old
  179. # behavior of Image.fromarray():
  180. #
  181. # new behavior: PIL will infer the image mode from the data passed
  182. # in. That is, the type and shape determines the mode.
  183. #
  184. # old behiavor: The mode will change how PIL reads the image,
  185. # regardless of the data. That is, it will make the
  186. # data work with the mode.
  187. #
  188. # Our uses of Image.fromarray() are effectively a "turn into PIL image
  189. # AND convert the kind" operation. In particular, in
  190. # functional.to_pil_image() and transforms.ToPILImage.
  191. #
  192. # However, Image.frombuffer() still performs this conversion. The code
  193. # below is lifted from the new implementation of Image.fromarray(). We
  194. # omit the code that infers the mode, and use the code that figures out
  195. # from the data passed in (obj) what the correct parameters are to
  196. # Image.frombuffer().
  197. #
  198. # Note that the alternate solution below does not work:
  199. #
  200. # img = Image.fromarray(obj)
  201. # img = img.convert(mode)
  202. #
  203. # The resulting image has very different actual pixel values than before.
  204. #
  205. # TODO: Issue #9151. Pillow has an open PR to restore the functionality
  206. # we rely on:
  207. #
  208. # https://github.com/python-pillow/Pillow/pull/9063
  209. #
  210. # When that is part of a release, we can revisit this hack below.
  211. arr = obj.__array_interface__
  212. shape = arr["shape"]
  213. ndim = len(shape)
  214. size = 1 if ndim == 1 else shape[1], shape[0]
  215. strides = arr.get("strides", None)
  216. contiguous_obj: Union[np.ndarray, bytes] = obj
  217. if strides is not None:
  218. # We require that the data is contiguous; if it is not, we need to
  219. # convert it into a contiguous format.
  220. if hasattr(obj, "tobytes"):
  221. contiguous_obj = obj.tobytes()
  222. elif hasattr(obj, "tostring"):
  223. contiguous_obj = obj.tostring()
  224. else:
  225. raise ValueError("Unable to convert obj into contiguous format")
  226. return Image.frombuffer(mode, size, contiguous_obj, "raw", mode, 0, 1)
  227. else:
  228. return Image.fromarray(obj, mode)
  229. @torch.no_grad()
  230. def save_image(
  231. tensor: Union[torch.Tensor, list[torch.Tensor]],
  232. fp: Union[str, pathlib.Path, BinaryIO],
  233. format: Optional[str] = None,
  234. **kwargs,
  235. ) -> None:
  236. """
  237. Save a given Tensor into an image file.
  238. Args:
  239. tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
  240. saves the tensor as a grid of images by calling ``make_grid``.
  241. fp (string or file object): A filename or a file object
  242. format(Optional): If omitted, the format to use is determined from the filename extension.
  243. If a file object was used instead of a filename, this parameter should always be used.
  244. **kwargs: Other arguments are documented in ``make_grid``.
  245. """
  246. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  247. _log_api_usage_once(save_image)
  248. grid = make_grid(tensor, **kwargs)
  249. # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
  250. ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
  251. im = Image.fromarray(ndarr)
  252. im.save(fp, format=format)
  253. @torch.no_grad()
  254. def draw_bounding_boxes(
  255. image: torch.Tensor,
  256. boxes: torch.Tensor,
  257. labels: Optional[list[str]] = None,
  258. colors: Optional[Union[list[Union[str, tuple[int, int, int]]], str, tuple[int, int, int]]] = None,
  259. fill: Optional[bool] = False,
  260. width: int = 1,
  261. font: Optional[str] = None,
  262. font_size: Optional[int] = None,
  263. label_colors: Optional[Union[list[Union[str, tuple[int, int, int]]], str, tuple[int, int, int]]] = None,
  264. label_background_colors: Optional[Union[list[Union[str, tuple[int, int, int]]], str, tuple[int, int, int]]] = None,
  265. fill_labels: bool = False,
  266. ) -> torch.Tensor:
  267. """
  268. Draws bounding boxes on given RGB image.
  269. The image values should be uint8 in [0, 255] or float in [0, 1].
  270. If fill is True, Resulting Tensor should be saved as PNG image.
  271. Args:
  272. image (Tensor): Tensor of shape (C, H, W) and dtype uint8 or float.
  273. boxes (Tensor): Tensor of size (N, 4) or (N, 8) containing bounding boxes.
  274. For (N, 4), the format is (xmin, ymin, xmax, ymax) and the boxes are absolute coordinates with respect to the image.
  275. In other words: `0 <= xmin < xmax < W` and `0 <= ymin < ymax < H`.
  276. For (N, 8), the format is (x1, y1, x2, y2, x3, y3, x4, y4) and the boxes are absolute coordinates with respect to the underlying
  277. object, so no need to verify the latter inequalities.
  278. labels (List[str]): List containing the labels of bounding boxes.
  279. colors (color or list of colors, optional): List containing the colors
  280. of the boxes or single color for all boxes. The color can be represented as
  281. PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
  282. By default, random colors are generated for boxes.
  283. fill (bool): If `True` fills the bounding box with specified color.
  284. width (int): Width of bounding box.
  285. font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
  286. also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
  287. `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
  288. font_size (int): The requested font size in points.
  289. label_colors (color or list of colors, optional): Colors for the label text. See the description of the
  290. `colors` argument for details. Defaults to the same colors used for the boxes, or to black if ``fill_labels`` is True.
  291. label_background_colors (color or list of colors, optional): Colors for the label text box fill. Defaults to the
  292. same colors used for the boxes. Ignored when ``fill_labels`` is False.
  293. fill_labels (bool): If `True` fills the label background with specified color (from the ``label_background_colors`` parameter,
  294. or from the ``colors`` parameter if not specified). Default: False.
  295. Returns:
  296. img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
  297. """
  298. import torchvision.transforms.v2.functional as F # noqa
  299. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  300. _log_api_usage_once(draw_bounding_boxes)
  301. if not isinstance(image, torch.Tensor):
  302. raise TypeError(f"Tensor expected, got {type(image)}")
  303. elif not (image.dtype == torch.uint8 or image.is_floating_point()):
  304. raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
  305. elif image.dim() != 3:
  306. raise ValueError("Pass individual images, not batches")
  307. elif image.size(0) not in {1, 3}:
  308. raise ValueError("Only grayscale and RGB images are supported")
  309. elif boxes.shape[-1] == 4 and ((boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any()):
  310. raise ValueError(
  311. "Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them"
  312. )
  313. num_boxes = boxes.shape[0]
  314. if num_boxes == 0:
  315. warnings.warn("boxes doesn't contain any box. No box was drawn")
  316. return image
  317. if labels is None:
  318. labels: Union[list[str], list[None]] = [None] * num_boxes # type: ignore[no-redef]
  319. elif len(labels) != num_boxes:
  320. raise ValueError(
  321. f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
  322. )
  323. colors = _parse_colors(colors, num_objects=num_boxes) # type: ignore[assignment]
  324. if label_colors or fill_labels:
  325. label_colors = _parse_colors(label_colors if label_colors else "black", num_objects=num_boxes) # type: ignore[assignment]
  326. else:
  327. label_colors = colors.copy() # type: ignore[assignment]
  328. if fill_labels and label_background_colors:
  329. label_background_colors = _parse_colors(label_background_colors, num_objects=num_boxes) # type: ignore[assignment]
  330. else:
  331. label_background_colors = colors.copy() # type: ignore[assignment]
  332. if font is None:
  333. if font_size is not None:
  334. warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.")
  335. txt_font = ImageFont.load_default()
  336. else:
  337. txt_font = ImageFont.truetype(font=font, size=font_size or 10)
  338. # Handle Grayscale images
  339. if image.size(0) == 1:
  340. image = torch.tile(image, (3, 1, 1))
  341. original_dtype = image.dtype
  342. if original_dtype.is_floating_point:
  343. image = F.to_dtype(image, dtype=torch.uint8, scale=True)
  344. img_to_draw = F.to_pil_image(image)
  345. img_boxes = boxes.to(torch.int64).tolist()
  346. if fill:
  347. draw = _ImageDrawTV(img_to_draw, "RGBA")
  348. else:
  349. draw = _ImageDrawTV(img_to_draw)
  350. for bbox, color, label, label_color, label_bg_color in zip(img_boxes, colors, labels, label_colors, label_background_colors): # type: ignore[arg-type]
  351. draw_method = draw.oriented_rectangle if len(bbox) > 4 else draw.rectangle
  352. fill_color = color + (100,) if fill else None
  353. draw_method(bbox, width=width, outline=color, fill=fill_color)
  354. if label is not None:
  355. box_margin = 1
  356. margin = width + box_margin
  357. if fill_labels:
  358. left, top, right, bottom = draw.textbbox((bbox[0] + margin, bbox[1] + margin), label, font=txt_font)
  359. draw.rectangle(
  360. (left - box_margin, top - box_margin, right + box_margin, bottom + box_margin), fill=label_bg_color # type: ignore[arg-type]
  361. )
  362. draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=label_color, font=txt_font) # type: ignore[arg-type]
  363. out = F.pil_to_tensor(img_to_draw)
  364. if original_dtype.is_floating_point:
  365. out = F.to_dtype(out, dtype=original_dtype, scale=True)
  366. return out
  367. @torch.no_grad()
  368. def draw_segmentation_masks(
  369. image: torch.Tensor,
  370. masks: torch.Tensor,
  371. alpha: float = 0.8,
  372. colors: Optional[Union[list[Union[str, tuple[int, int, int]]], str, tuple[int, int, int]]] = None,
  373. ) -> torch.Tensor:
  374. """
  375. Draws segmentation masks on given RGB image.
  376. The image values should be uint8 in [0, 255] or float in [0, 1].
  377. Args:
  378. image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
  379. masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
  380. alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
  381. 0 means full transparency, 1 means no transparency.
  382. colors (color or list of colors, optional): List containing the colors
  383. of the masks or single color for all masks. The color can be represented as
  384. PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
  385. By default, random colors are generated for each mask.
  386. Returns:
  387. img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
  388. """
  389. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  390. _log_api_usage_once(draw_segmentation_masks)
  391. if not isinstance(image, torch.Tensor):
  392. raise TypeError(f"The image must be a tensor, got {type(image)}")
  393. elif not (image.dtype == torch.uint8 or image.is_floating_point()):
  394. raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
  395. elif image.dim() != 3:
  396. raise ValueError("Pass individual images, not batches")
  397. elif image.size()[0] != 3:
  398. raise ValueError("Pass an RGB image. Other Image formats are not supported")
  399. if masks.ndim == 2:
  400. masks = masks[None, :, :]
  401. if masks.ndim != 3:
  402. raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
  403. if masks.dtype != torch.bool:
  404. raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
  405. if masks.shape[-2:] != image.shape[-2:]:
  406. raise ValueError("The image and the masks must have the same height and width")
  407. num_masks = masks.size()[0]
  408. overlapping_masks = masks.sum(dim=0) > 1
  409. if num_masks == 0:
  410. warnings.warn("masks doesn't contain any mask. No mask was drawn")
  411. return image
  412. original_dtype = image.dtype
  413. colors = [
  414. torch.tensor(color, dtype=original_dtype, device=image.device)
  415. for color in _parse_colors(colors, num_objects=num_masks, dtype=original_dtype)
  416. ]
  417. img_to_draw = image.detach().clone()
  418. # TODO: There might be a way to vectorize this
  419. for mask, color in zip(masks, colors):
  420. img_to_draw[:, mask] = color[:, None]
  421. img_to_draw[:, overlapping_masks] = 0
  422. out = image * (1 - alpha) + img_to_draw * alpha
  423. # Note: at this point, out is a float tensor in [0, 1] or [0, 255] depending on original_dtype
  424. return out.to(original_dtype)
  425. @torch.no_grad()
  426. def draw_keypoints(
  427. image: torch.Tensor,
  428. keypoints: torch.Tensor,
  429. connectivity: Optional[list[tuple[int, int]]] = None,
  430. colors: Optional[Union[str, tuple[int, int, int]]] = None,
  431. radius: int = 2,
  432. width: int = 3,
  433. visibility: Optional[torch.Tensor] = None,
  434. ) -> torch.Tensor:
  435. """
  436. Draws Keypoints on given RGB image.
  437. The image values should be uint8 in [0, 255] or float in [0, 1].
  438. Keypoints can be drawn for multiple instances at a time.
  439. This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint.
  440. Args:
  441. image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
  442. keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances,
  443. in the format [x, y].
  444. connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints
  445. to be connected.
  446. If at least one of the two connected keypoints has a ``visibility`` of False,
  447. this specific connection is not drawn.
  448. Exclusions due to invisibility are computed per-instance.
  449. colors (str, Tuple): The color can be represented as
  450. PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
  451. radius (int): Integer denoting radius of keypoint.
  452. width (int): Integer denoting width of line connecting keypoints.
  453. visibility (Tensor): Tensor of shape (num_instances, K) specifying the visibility of the K
  454. keypoints for each of the N instances.
  455. True means that the respective keypoint is visible and should be drawn.
  456. False means invisible, so neither the point nor possible connections containing it are drawn.
  457. The input tensor will be cast to bool.
  458. Default ``None`` means that all the keypoints are visible.
  459. For more details, see :ref:`draw_keypoints_with_visibility`.
  460. Returns:
  461. img (Tensor[C, H, W]): Image Tensor with keypoints drawn.
  462. """
  463. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  464. _log_api_usage_once(draw_keypoints)
  465. # validate image
  466. if not isinstance(image, torch.Tensor):
  467. raise TypeError(f"The image must be a tensor, got {type(image)}")
  468. elif not (image.dtype == torch.uint8 or image.is_floating_point()):
  469. raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
  470. elif image.dim() != 3:
  471. raise ValueError("Pass individual images, not batches")
  472. elif image.size()[0] != 3:
  473. raise ValueError("Pass an RGB image. Other Image formats are not supported")
  474. # validate keypoints
  475. if keypoints.ndim != 3:
  476. raise ValueError("keypoints must be of shape (num_instances, K, 2)")
  477. # validate visibility
  478. if visibility is None: # set default
  479. visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool)
  480. if visibility.ndim == 3:
  481. # If visibility was passed as pred.split([2, 1], dim=-1), it will be of shape (num_instances, K, 1).
  482. # We make sure it is of shape (num_instances, K). This isn't documented, we're just being nice.
  483. visibility = visibility.squeeze(-1)
  484. if visibility.ndim != 2:
  485. raise ValueError(f"visibility must be of shape (num_instances, K). Got ndim={visibility.ndim}")
  486. if visibility.shape != keypoints.shape[:-1]:
  487. raise ValueError(
  488. "keypoints and visibility must have the same dimensionality for num_instances and K. "
  489. f"Got {visibility.shape=} and {keypoints.shape=}"
  490. )
  491. original_dtype = image.dtype
  492. if original_dtype.is_floating_point:
  493. from torchvision.transforms.v2.functional import to_dtype # noqa
  494. image = to_dtype(image, dtype=torch.uint8, scale=True)
  495. ndarr = image.permute(1, 2, 0).cpu().numpy()
  496. img_to_draw = Image.fromarray(ndarr)
  497. draw = ImageDraw.Draw(img_to_draw)
  498. img_kpts = keypoints.to(torch.int64).tolist()
  499. img_vis = visibility.cpu().bool().tolist()
  500. for kpt_inst, vis_inst in zip(img_kpts, img_vis):
  501. for kpt_coord, kp_vis in zip(kpt_inst, vis_inst):
  502. if not kp_vis:
  503. continue
  504. x1 = kpt_coord[0] - radius
  505. x2 = kpt_coord[0] + radius
  506. y1 = kpt_coord[1] - radius
  507. y2 = kpt_coord[1] + radius
  508. draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)
  509. if connectivity:
  510. for connection in connectivity:
  511. if (not vis_inst[connection[0]]) or (not vis_inst[connection[1]]):
  512. continue
  513. start_pt_x = kpt_inst[connection[0]][0]
  514. start_pt_y = kpt_inst[connection[0]][1]
  515. end_pt_x = kpt_inst[connection[1]][0]
  516. end_pt_y = kpt_inst[connection[1]][1]
  517. draw.line(
  518. ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
  519. width=width,
  520. )
  521. out = torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
  522. if original_dtype.is_floating_point:
  523. out = to_dtype(out, dtype=original_dtype, scale=True)
  524. return out
  525. # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
  526. @torch.no_grad()
  527. def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
  528. """
  529. Converts a flow to an RGB image.
  530. Args:
  531. flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
  532. Returns:
  533. img (Tensor): Image Tensor of dtype uint8 where each color corresponds
  534. to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
  535. """
  536. if flow.dtype != torch.float:
  537. raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
  538. orig_shape = flow.shape
  539. if flow.ndim == 3:
  540. flow = flow[None] # Add batch dim
  541. if flow.ndim != 4 or flow.shape[1] != 2:
  542. raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
  543. max_norm = torch.sum(flow**2, dim=1).sqrt().max()
  544. epsilon = torch.finfo((flow).dtype).eps
  545. normalized_flow = flow / (max_norm + epsilon)
  546. img = _normalized_flow_to_image(normalized_flow)
  547. if len(orig_shape) == 3:
  548. img = img[0] # Remove batch dim
  549. return img
  550. @torch.no_grad()
  551. def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
  552. """
  553. Converts a batch of normalized flow to an RGB image.
  554. Args:
  555. normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
  556. Returns:
  557. img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
  558. """
  559. N, _, H, W = normalized_flow.shape
  560. device = normalized_flow.device
  561. flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
  562. colorwheel = _make_colorwheel().to(device) # shape [55x3]
  563. num_cols = colorwheel.shape[0]
  564. norm = torch.sum(normalized_flow**2, dim=1).sqrt()
  565. a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
  566. fk = (a + 1) / 2 * (num_cols - 1)
  567. k0 = torch.floor(fk).to(torch.long)
  568. k1 = k0 + 1
  569. k1[k1 == num_cols] = 0
  570. f = fk - k0
  571. for c in range(colorwheel.shape[1]):
  572. tmp = colorwheel[:, c]
  573. col0 = tmp[k0] / 255.0
  574. col1 = tmp[k1] / 255.0
  575. col = (1 - f) * col0 + f * col1
  576. col = 1 - norm * (1 - col)
  577. flow_image[:, c, :, :] = torch.floor(255 * col)
  578. return flow_image
  579. def _make_colorwheel() -> torch.Tensor:
  580. """
  581. Generates a color wheel for optical flow visualization as presented in:
  582. Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
  583. URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
  584. Returns:
  585. colorwheel (Tensor[55, 3]): Colorwheel Tensor.
  586. """
  587. RY = 15
  588. YG = 6
  589. GC = 4
  590. CB = 11
  591. BM = 13
  592. MR = 6
  593. ncols = RY + YG + GC + CB + BM + MR
  594. colorwheel = torch.zeros((ncols, 3))
  595. col = 0
  596. # RY
  597. colorwheel[0:RY, 0] = 255
  598. colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
  599. col = col + RY
  600. # YG
  601. colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
  602. colorwheel[col : col + YG, 1] = 255
  603. col = col + YG
  604. # GC
  605. colorwheel[col : col + GC, 1] = 255
  606. colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
  607. col = col + GC
  608. # CB
  609. colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB)
  610. colorwheel[col : col + CB, 2] = 255
  611. col = col + CB
  612. # BM
  613. colorwheel[col : col + BM, 2] = 255
  614. colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
  615. col = col + BM
  616. # MR
  617. colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR)
  618. colorwheel[col : col + MR, 0] = 255
  619. return colorwheel
  620. def _generate_color_palette(num_objects: int):
  621. palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
  622. return [tuple((i * palette) % 255) for i in range(num_objects)]
  623. def _parse_colors(
  624. colors: Union[None, str, tuple[int, int, int], list[Union[str, tuple[int, int, int]]]],
  625. *,
  626. num_objects: int,
  627. dtype: torch.dtype = torch.uint8,
  628. ) -> list[tuple[int, int, int]]:
  629. """
  630. Parses a specification of colors for a set of objects.
  631. Args:
  632. colors: A specification of colors for the objects. This can be one of the following:
  633. - None: to generate a color palette automatically.
  634. - A list of colors: where each color is either a string (specifying a named color) or an RGB tuple.
  635. - A string or an RGB tuple: to use the same color for all objects.
  636. If `colors` is a tuple, it should be a 3-tuple specifying the RGB values of the color.
  637. If `colors` is a list, it should have at least as many elements as the number of objects to color.
  638. num_objects (int): The number of objects to color.
  639. Returns:
  640. A list of 3-tuples, specifying the RGB values of the colors.
  641. Raises:
  642. ValueError: If the number of colors in the list is less than the number of objects to color.
  643. If `colors` is not a list, tuple, string or None.
  644. """
  645. if colors is None:
  646. colors = _generate_color_palette(num_objects)
  647. elif isinstance(colors, list):
  648. if len(colors) < num_objects:
  649. raise ValueError(
  650. f"Number of colors must be equal or larger than the number of objects, but got {len(colors)} < {num_objects}."
  651. )
  652. elif not isinstance(colors, (tuple, str)):
  653. raise ValueError(f"colors must be a tuple or a string, or a list thereof, but got {colors}.")
  654. elif isinstance(colors, tuple) and len(colors) != 3:
  655. raise ValueError(f"If passed as tuple, colors should be an RGB triplet, but got {colors}.")
  656. else: # colors specifies a single color for all objects
  657. colors = [colors] * num_objects
  658. colors = [ImageColor.getrgb(color) if isinstance(color, str) else color for color in colors]
  659. if dtype.is_floating_point: # [0, 255] -> [0, 1]
  660. colors = [tuple(v / 255 for v in color) for color in colors] # type: ignore[union-attr]
  661. return colors # type: ignore[return-value]
  662. def _log_api_usage_once(obj: Any) -> None:
  663. """
  664. Logs API usage(module and name) within an organization.
  665. In a large ecosystem, it's often useful to track the PyTorch and
  666. TorchVision APIs usage. This API provides the similar functionality to the
  667. logging module in the Python stdlib. It can be used for debugging purpose
  668. to log which methods are used and by default it is inactive, unless the user
  669. manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
  670. Please note it is triggered only once for the same API call within a process.
  671. It does not collect any data from open-source users since it is no-op by default.
  672. For more information, please refer to
  673. * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
  674. * Logging policy: https://github.com/pytorch/vision/issues/5052;
  675. Args:
  676. obj (class instance or method): an object to extract info from.
  677. """
  678. module = obj.__module__
  679. if not module.startswith("torchvision"):
  680. module = f"torchvision.internal.{module}"
  681. name = obj.__class__.__name__
  682. if isinstance(obj, FunctionType):
  683. name = obj.__name__
  684. torch._C._log_api_usage_once(f"{module}.{name}")
  685. def _make_ntuple(x: Any, n: int) -> tuple[Any, ...]:
  686. """
  687. Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
  688. Otherwise, we will make a tuple of length n, all with value of x.
  689. reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
  690. Args:
  691. x (Any): input value
  692. n (int): length of the resulting tuple
  693. """
  694. if isinstance(x, collections.abc.Iterable):
  695. return tuple(x)
  696. return tuple(repeat(x, n))