image.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. from enum import Enum
  2. from typing import Union
  3. import torch
  4. from ..extension import _load_library
  5. from ..utils import _log_api_usage_once
  def _has_image_ops():
      # Default answer, used when the compiled image extension was not loaded.
      return False


  if _load_library("image"):
      # `_load_library("image")` returned a truthy value: rebind the probe so
      # that it reports True from now on.
      # NOTE(review): the function is redefined instead of reading a
      # module-level flag, presumably so that its body stays a plain constant
      # return that TorchScript can compile -- confirm before restructuring.
      def _has_image_ops():  # noqa: F811
          return True
  def _assert_has_image_ops():
      """Raise ``RuntimeError`` unless ``_has_image_ops()`` reports the image extension as available."""
      if _has_image_ops():
          return
      raise RuntimeError(
          "Couldn't load the image extension. "
          "If you built torchvision from source, make sure libjpeg and libpng were found. "
          "Set TORCHVISION_WARN_WHEN_EXTENSION_LOADING_FAILS=1 and retry to get more details."
      )
  class ImageReadMode(Enum):
      """Color mode an image can be converted to while it is being decoded.

      .. note::

          Using this enum is optional: ``mode`` parameters can be given the
          member name as a string instead, e.g. ``mode="RGB"``.

      The available modes are:

      - UNCHANGED: loads the image as-is
      - GRAY: converts to grayscale
      - GRAY_ALPHA: converts to grayscale with transparency
      - RGB: converts to RGB
      - RGB_ALPHA: converts to RGB with transparency (``RGBA`` is an alias)

      .. note::

          Some decoders won't support all possible values, e.g. GRAY and
          GRAY_ALPHA are only supported for PNG and JPEG images.
      """

      UNCHANGED = 0
      GRAY = 1
      GRAY_ALPHA = 2
      RGB = 3
      RGB_ALPHA = 4
      # Shares RGB_ALPHA's value, so Enum exposes RGBA as an alias of that member.
      RGBA = RGB_ALPHA
  def read_file(path: str) -> torch.Tensor:
      """
      Load the raw bytes of a file into a one dimensional uint8 Tensor.

      Args:
          path (str or ``pathlib.Path``): location of the file to read

      Returns:
          data (Tensor)
      """
      if not torch.jit.is_scripting() and not torch.jit.is_tracing():
          _log_api_usage_once(read_file)
      _assert_has_image_ops()
      # The op is handed ``str(path)`` so non-string path objects are stringified first.
      return torch.ops.image.read_file(str(path))
  def write_file(filename: str, data: torch.Tensor) -> None:
      """
      Dump a one dimensional uint8 tensor to a file, byte for byte.

      Args:
          filename (str or ``pathlib.Path``): destination path of the output file
          data (Tensor): the bytes to store in that file
      """
      if not torch.jit.is_scripting() and not torch.jit.is_tracing():
          _log_api_usage_once(write_file)
      _assert_has_image_ops()
      # The op is handed ``str(filename)`` so non-string path objects are stringified first.
      destination = str(filename)
      torch.ops.image.write_file(destination, data)
  def decode_png(
      input: torch.Tensor,
      mode: ImageReadMode = ImageReadMode.UNCHANGED,
      apply_exif_orientation: bool = False,
  ) -> torch.Tensor:
      """
      Decode the raw bytes of a PNG image into a 3 dimensional RGB or grayscale Tensor.

      In most cases the output is a uint8 tensor with values in [0, 255]. For a
      16-bit png the output is a uint16 tensor with values in [0, 65535]
      (supported from torchvision ``0.21``). Since uint16 support is limited in
      pytorch, we recommend calling
      :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
      after this function to convert the decoded image into a uint8 or float
      tensor.

      Args:
          input (Tensor[1]): a one dimensional uint8 tensor containing
              the raw bytes of the PNG image.
          mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
              Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
              for available modes.
          apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
              Default: False.

      Returns:
          output (Tensor[image_channels, image_height, image_width])
      """
      if not torch.jit.is_scripting() and not torch.jit.is_tracing():
          _log_api_usage_once(decode_png)
      _assert_has_image_ops()
      if isinstance(mode, str):
          # Mode names are matched case-insensitively against the enum members.
          mode = ImageReadMode[mode.upper()]
      return torch.ops.image.decode_png(input, mode.value, apply_exif_orientation)
  def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
      """
      Encode a CHW image tensor as PNG and return the bytes of the resulting
      PNG file as a buffer.

      Args:
          input (Tensor[channels, image_height, image_width]): int8 image tensor of
              ``c`` channels, where ``c`` must be 3 or 1.
          compression_level (int): Compression factor for the resulting file, it must be a number
              between 0 and 9. Default: 6

      Returns:
          Tensor[1]: A one dimensional int8 tensor that contains the raw bytes of the
              PNG file.
      """
      # NOTE(review): the docstring says "int8", while the sibling JPEG encoder
      # and ``write_file`` document uint8 -- confirm which dtype the op expects.
      if not torch.jit.is_scripting() and not torch.jit.is_tracing():
          _log_api_usage_once(encode_png)
      _assert_has_image_ops()
      return torch.ops.image.encode_png(input, compression_level)
  def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
      """
      Encode an image tensor in CHW layout (or HW in the case of grayscale
      images) as PNG and store the result in a file.

      Args:
          input (Tensor[channels, image_height, image_width]): int8 image tensor of
              ``c`` channels, where ``c`` must be 1 or 3.
          filename (str or ``pathlib.Path``): Path to save the image.
          compression_level (int): Compression factor for the resulting file, it must be a number
              between 0 and 9. Default: 6
      """
      if not torch.jit.is_scripting() and not torch.jit.is_tracing():
          _log_api_usage_once(write_png)
      # Encode first, then hand the encoded bytes to the generic file writer.
      png_bytes = encode_png(input, compression_level)
      write_file(filename, png_bytes)
  def decode_jpeg(
      input: Union[torch.Tensor, list[torch.Tensor]],
      mode: ImageReadMode = ImageReadMode.UNCHANGED,
      device: Union[str, torch.device] = "cpu",
      apply_exif_orientation: bool = False,
  ) -> Union[torch.Tensor, list[torch.Tensor]]:
      """Decode JPEG image(s) into 3D RGB or grayscale Tensor(s), on CPU or CUDA.

      The decoded tensor values are uint8 in the range 0 to 255.

      .. note::
          On a CUDA device, decoding a list of tensors in one call is more efficient than calling
          ``decode_jpeg`` once per image. On CPU the performance is equivalent.
          The CUDA version of this function has explicitly been designed with thread-safety in mind.
          This function does not return partial results in case of an error.

      Args:
          input (Tensor[1] or list[Tensor[1]]): a (list of) one dimensional uint8 tensor(s) containing
              the raw bytes of the JPEG image. The tensor(s) must be on CPU,
              regardless of the ``device`` parameter.
          mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
              Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
              for available modes.
          device (str or torch.device): The device on which the decoded image will
              be stored. If a cuda device is specified, the image will be decoded
              with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
              supported for CUDA version >= 10.1

              .. betastatus:: device parameter

              .. warning::
                  There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
                  Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.
          apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
              Default: False. Only implemented for JPEG format on CPU.

      Returns:
          output (Tensor[image_channels, image_height, image_width] or list[Tensor[image_channels, image_height, image_width]]):
              The values of the output tensor(s) are uint8 between 0 and 255.
              ``output.device`` will be set to the specified ``device``
      """
      if not torch.jit.is_scripting() and not torch.jit.is_tracing():
          _log_api_usage_once(decode_jpeg)
      _assert_has_image_ops()

      # Normalize the string forms of ``device`` and ``mode`` before dispatching.
      if isinstance(device, str):
          device = torch.device(device)
      if isinstance(mode, str):
          mode = ImageReadMode[mode.upper()]

      if isinstance(input, list):
          # Batch input: validate the whole list before decoding anything.
          if len(input) == 0:
              raise ValueError("Input list must contain at least one element")
          if not all(isinstance(encoded, torch.Tensor) for encoded in input):
              raise ValueError("All elements of the input list must be tensors.")
          if not all(encoded.device.type == "cpu" for encoded in input):
              raise ValueError("Input list must contain tensors on CPU.")
          if device.type == "cuda":
              # The CUDA op takes the whole batch; ``apply_exif_orientation`` is not passed to it.
              return torch.ops.image.decode_jpegs_cuda(input, mode.value, device)
          return [torch.ops.image.decode_jpeg(encoded, mode.value, apply_exif_orientation) for encoded in input]

      # Single tensor input.
      if input.device.type != "cpu":
          raise ValueError("Input tensor must be a CPU tensor")
      if device.type == "cuda":
          # Wrap in a one-element batch for the CUDA op and unwrap its single result.
          return torch.ops.image.decode_jpegs_cuda([input], mode.value, device)[0]
      return torch.ops.image.decode_jpeg(input, mode.value, apply_exif_orientation)
  def encode_jpeg(
      input: Union[torch.Tensor, list[torch.Tensor]], quality: int = 75
  ) -> Union[torch.Tensor, list[torch.Tensor]]:
      """Encode RGB tensor(s) into raw encoded jpeg bytes, on CPU or CUDA.

      .. note::
          Encoding a list of CUDA tensors in one call is more efficient than calling ``encode_jpeg``
          once per image. For CPU tensors the performance is equivalent.

      Args:
          input (Tensor[channels, image_height, image_width] or List[Tensor[channels, image_height, image_width]]):
              (list of) uint8 image tensor(s) of ``c`` channels, where ``c`` must be 1 or 3
          quality (int): Quality of the resulting JPEG file(s). Must be a number between
              1 and 100. Default: 75

      Returns:
          output (Tensor[1] or list[Tensor[1]]): A (list of) one dimensional uint8 tensor(s) that contain the raw bytes of the JPEG file.
      """
      if not torch.jit.is_scripting() and not torch.jit.is_tracing():
          _log_api_usage_once(encode_jpeg)
      _assert_has_image_ops()
      if quality > 100 or quality < 1:
          raise ValueError("Image quality should be a positive number between 1 and 100")

      if isinstance(input, list):
          if not input:
              raise ValueError("encode_jpeg requires at least one input tensor when a list is passed")
          # Only the first tensor's device is inspected to choose the code path for the batch.
          if input[0].device.type == "cuda":
              return torch.ops.image.encode_jpegs_cuda(input, quality)
          return [torch.ops.image.encode_jpeg(img, quality) for img in input]

      # Single tensor input.
      if input.device.type == "cuda":
          # Wrap in a one-element batch for the CUDA op and unwrap its single result.
          return torch.ops.image.encode_jpegs_cuda([input], quality)[0]
      return torch.ops.image.encode_jpeg(input, quality)
  def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
      """
      Encode an image tensor in CHW layout as JPEG and store the result in a file.

      Args:
          input (Tensor[channels, image_height, image_width]): int8 image tensor of ``c``
              channels, where ``c`` must be 1 or 3.
          filename (str or ``pathlib.Path``): Path to save the image.
          quality (int): Quality of the resulting JPEG file, it must be a number
              between 1 and 100. Default: 75
      """
      if not torch.jit.is_scripting() and not torch.jit.is_tracing():
          _log_api_usage_once(write_jpeg)
      jpeg_bytes = encode_jpeg(input, quality)
      # ``encode_jpeg`` is annotated as returning a tensor or a list of tensors;
      # the assert narrows that to a single tensor.
      assert isinstance(jpeg_bytes, torch.Tensor)  # Needed for torchscript
      write_file(filename, jpeg_bytes)
  def decode_image(
      input: Union[torch.Tensor, str],
      mode: ImageReadMode = ImageReadMode.UNCHANGED,
      apply_exif_orientation: bool = False,
  ) -> torch.Tensor:
      """Decode an image into a uint8 tensor, from a path or from raw encoded bytes.

      Currently supported image formats are jpeg, png, gif and webp.

      In most cases the output is a uint8 tensor with values in [0, 255].
      For a 16-bit png the output is a uint16 tensor with values in [0, 65535]
      (supported from torchvision ``0.21``). Since uint16 support is limited in
      pytorch, we recommend calling
      :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
      after this function to convert the decoded image into a uint8 or float
      tensor.

      .. note::

          ``decode_image()`` doesn't work yet on AVIF or HEIC images. For these
          formats, directly call :func:`~torchvision.io.decode_avif` or
          :func:`~torchvision.io.decode_heic`.

      Args:
          input (Tensor or str or ``pathlib.Path``): The image to decode. If a
              tensor is passed, it must be one dimensional uint8 tensor containing
              the raw bytes of the image. Otherwise, this must be a path to the image file.
          mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
              Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
              for available modes.
          apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
              Only applies to JPEG and PNG images. Default: False.

      Returns:
          output (Tensor[image_channels, image_height, image_width])
      """
      if not torch.jit.is_scripting() and not torch.jit.is_tracing():
          _log_api_usage_once(decode_image)
      _assert_has_image_ops()
      if not isinstance(input, torch.Tensor):
          # Anything that is not a tensor is treated as a path and read from disk.
          input = read_file(str(input))
      if isinstance(mode, str):
          # Mode names are matched case-insensitively against the enum members.
          mode = ImageReadMode[mode.upper()]
      return torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
  def read_image(
      path: str,
      mode: ImageReadMode = ImageReadMode.UNCHANGED,
      apply_exif_orientation: bool = False,
  ) -> torch.Tensor:
      """[OBSOLETE] Use :func:`~torchvision.io.decode_image` instead."""
      if not torch.jit.is_scripting() and not torch.jit.is_tracing():
          _log_api_usage_once(read_image)
      # Read the encoded bytes from disk, then delegate the decoding.
      encoded = read_file(path)
      return decode_image(encoded, mode, apply_exif_orientation=apply_exif_orientation)
  def decode_gif(input: torch.Tensor) -> torch.Tensor:
      """
      Decode the raw bytes of a GIF image into a 3 or 4 dimensional RGB Tensor.

      The decoded tensor values are uint8 in the range 0 to 255.
      A GIF holding a single image yields a tensor of shape ``(C, H, W)``; a GIF
      holding ``N`` images yields a tensor of shape ``(N, C, H, W)``.

      Args:
          input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
              the raw bytes of the GIF image.

      Returns:
          output (Tensor[image_channels, image_height, image_width] or Tensor[num_images, image_channels, image_height, image_width])
      """
      if not torch.jit.is_scripting() and not torch.jit.is_tracing():
          _log_api_usage_once(decode_gif)
      _assert_has_image_ops()
      decoded = torch.ops.image.decode_gif(input)
      return decoded
  def decode_webp(
      input: torch.Tensor,
      mode: ImageReadMode = ImageReadMode.UNCHANGED,
  ) -> torch.Tensor:
      """
      Decode the raw bytes of a WEBP image into a 3 dimensional RGB[A] Tensor.

      The decoded tensor values are uint8 in the range 0 to 255.

      Args:
          input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
              the raw bytes of the WEBP image.
          mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
              Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
              for available modes.

      Returns:
          Decoded image (Tensor[image_channels, image_height, image_width])
      """
      if not torch.jit.is_scripting() and not torch.jit.is_tracing():
          _log_api_usage_once(decode_webp)
      _assert_has_image_ops()
      if isinstance(mode, str):
          # Mode names are matched case-insensitively against the enum members.
          mode = ImageReadMode[mode.upper()]
      decoded = torch.ops.image.decode_webp(input, mode.value)
      return decoded
  # TODO_AVIF_HEIC: Better support for torchscript. Scripting decode_avif or
  # decode_heic currently fails, mainly because of the logic in
  # _load_extra_decoders_once() (using global variables, try/except statements,
  # etc.).
  326. # The ops (torch.ops.extra_decoders_ns.decode_*) are otherwise torchscript-able,
  327. # and users who need torchscript can always just wrap those.
  328. # TODO_AVIF_HEIC: decode_image() should work for those. The key technical issue
  329. # we have here is that the format detection logic of decode_image() is
  330. # implemented in torchvision, and torchvision has zero knowledge of
  331. # torchvision-extra-decoders, so we cannot call the AVIF/HEIC C++ decoders
  332. # (those in torchvision-extra-decoders) from there.
  333. # A trivial check that could be done within torchvision would be to check the
  334. # file extension, if a path was passed. We could also just implement the
  335. # AVIF/HEIC detection logic in Python as a fallback, if the file detection
  336. # didn't find any format. In any case: properly determining whether a file is
  337. # HEIC is far from trivial, and relying on libmagic would probably be best
  # Set to True after the extra decoders have been exposed once, so that later
  # calls to _load_extra_decoders_once() return immediately.
  _EXTRA_DECODERS_ALREADY_LOADED = False


  def _load_extra_decoders_once():
      """Import ``torchvision_extra_decoders`` and expose its ops, at most once.

      On success this calls ``torchvision_extra_decoders.expose_extra_decoders()``
      and sets the module-level ``_EXTRA_DECODERS_ALREADY_LOADED`` flag; later
      calls are no-ops.

      Raises:
          RuntimeError: if ``torchvision_extra_decoders`` cannot be imported, or
              if the installed package has no ``expose_extra_decoders``
              attribute.
      """
      global _EXTRA_DECODERS_ALREADY_LOADED
      if _EXTRA_DECODERS_ALREADY_LOADED:
          return

      install_hint = (
          "In order to enable the AVIF and HEIC decoding capabilities of "
          "torchvision, you need to `pip install torchvision-extra-decoders`. "
          "Just install the package, you don't need to update your code. "
          "This is only supported on Linux, and this feature is still in BETA stage. "
          "Please let us know of any issue: https://github.com/pytorch/vision/issues/new/choose. "
          "Note that `torchvision-extra-decoders` is released under the LGPL license. "
      )

      try:
          import torchvision_extra_decoders
      except ImportError as e:
          raise RuntimeError(install_hint) from e

      # torchvision-extra-decoders only supports linux for now. BUT, users on
      # e.g. MacOS can still install it: they will get the pure-python
      # 0.0.0.dev version:
      # https://pypi.org/project/torchvision-extra-decoders/0.0.0.dev0, which
      # is a dummy version that was created to reserve the namespace on PyPI.
      # We have to check that expose_extra_decoders() exists for those users,
      # so we can properly error on non-Linux archs.
      # This is an explicit check rather than an `assert`: asserts are removed
      # when Python runs with -O, which would turn this case into an
      # AttributeError on the call below instead of the install hint.
      if not hasattr(torchvision_extra_decoders, "expose_extra_decoders"):
          raise RuntimeError(install_hint)

      # This will expose torch.ops.extra_decoders_ns.decode_avif and torch.ops.extra_decoders_ns.decode_heic
      torchvision_extra_decoders.expose_extra_decoders()
      _EXTRA_DECODERS_ALREADY_LOADED = True
  def decode_avif(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
      """Decode an AVIF image into a 3 dimensional RGB[A] Tensor.

      .. warning::
          In order to enable the AVIF decoding capabilities of torchvision, you
          first need to run ``pip install torchvision-extra-decoders``. Just
          install the package, you don't need to update your code. This is only
          supported on Linux, and this feature is still in BETA stage. Please let
          us know of any issue:
          https://github.com/pytorch/vision/issues/new/choose. Note that
          `torchvision-extra-decoders
          <https://github.com/meta-pytorch/torchvision-extra-decoders/>`_ is
          released under the LGPL license.

      The values of the output tensor are in uint8 in [0, 255] for most images. If
      the image has a bit-depth of more than 8, then the output tensor is uint16
      in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
      calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
      ``scale=True`` after this function to convert the decoded image into a uint8
      or float tensor.

      Args:
          input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
              the raw bytes of the AVIF image.
          mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
              Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
              for available modes.

      Returns:
          Decoded image (Tensor[image_channels, image_height, image_width])

      Raises:
          RuntimeError: if the extra decoders cannot be loaded, or if ``input``
              is not a uint8 tensor.
          KeyError: if ``mode`` is a string that does not name an
              ``ImageReadMode`` member.
      """
      _load_extra_decoders_once()
      if input.dtype != torch.uint8:
          raise RuntimeError(f"Input tensor must have uint8 data type, got {input.dtype}")
      if isinstance(mode, str):
          # Accept mode names as documented, like the other decoders in this
          # module do; without this, ``mode.value`` fails on a plain string.
          mode = ImageReadMode[mode.upper()]
      return torch.ops.extra_decoders_ns.decode_avif(input, mode.value)
  def decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
      """Decode an HEIC image into a 3 dimensional RGB[A] Tensor.

      .. warning::
          In order to enable the HEIC decoding capabilities of torchvision, you
          first need to run ``pip install torchvision-extra-decoders``. Just
          install the package, you don't need to update your code. This is only
          supported on Linux, and this feature is still in BETA stage. Please let
          us know of any issue:
          https://github.com/pytorch/vision/issues/new/choose. Note that
          `torchvision-extra-decoders
          <https://github.com/meta-pytorch/torchvision-extra-decoders/>`_ is
          released under the LGPL license.

      The values of the output tensor are in uint8 in [0, 255] for most images. If
      the image has a bit-depth of more than 8, then the output tensor is uint16
      in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
      calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
      ``scale=True`` after this function to convert the decoded image into a uint8
      or float tensor.

      Args:
          input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
              the raw bytes of the HEIC image.
          mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
              Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
              for available modes.

      Returns:
          Decoded image (Tensor[image_channels, image_height, image_width])

      Raises:
          RuntimeError: if the extra decoders cannot be loaded, or if ``input``
              is not a uint8 tensor.
          KeyError: if ``mode`` is a string that does not name an
              ``ImageReadMode`` member.
      """
      _load_extra_decoders_once()
      if input.dtype != torch.uint8:
          raise RuntimeError(f"Input tensor must have uint8 data type, got {input.dtype}")
      if isinstance(mode, str):
          # Accept mode names as documented, like the other decoders in this
          # module do; without this, ``mode.value`` fails on a plain string.
          mode = ImageReadMode[mode.upper()]
      return torch.ops.extra_decoders_ns.decode_heic(input, mode.value)