| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527 |
- from enum import Enum
- from typing import Union
- import torch
- from ..extension import _load_library
- from ..utils import _log_api_usage_once
# ``_has_image_ops`` is bound exactly once, at import time, to a function that
# returns a constant describing whether ``_load_library("image")`` succeeded.
if _load_library("image"):

    def _has_image_ops():
        """Return True: the ``image`` extension library was loaded."""
        return True

else:

    def _has_image_ops():
        """Return False: the ``image`` extension library could not be loaded."""
        return False
def _assert_has_image_ops():
    """Fail loudly when the image extension ops are not available.

    Raises:
        RuntimeError: if ``_has_image_ops()`` reports that the extension
            was not loaded.
    """
    if _has_image_ops():
        return
    raise RuntimeError(
        "Couldn't load the image extension. "
        "If you built torchvision from source, make sure libjpeg and libpng were found. "
        "Set TORCHVISION_WARN_WHEN_EXTENSION_LOADING_FAILS=1 and retry to get more details."
    )
class ImageReadMode(Enum):
    """Target color mode applied automatically while an image is decoded.

    .. note::

        Using this enum is optional: every ``mode`` parameter also accepts
        the member name as a string, e.g. ``mode="RGB"``.

    Available modes:

    - UNCHANGED: keep the image exactly as stored
    - GRAY: convert to grayscale
    - GRAY_ALPHA: convert to grayscale with transparency
    - RGB: convert to RGB
    - RGB_ALPHA: convert to RGB with transparency (``RGBA`` is an alias)

    .. note::

        Not every decoder supports every value, e.g. GRAY and GRAY_ALPHA are
        only supported for PNG and JPEG images.
    """

    # The integer values are what gets forwarded to the decoding ops
    # (``mode.value``), so they must stay stable.
    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4
    # Same value as RGB_ALPHA, hence an Enum alias rather than a new member.
    RGBA = RGB_ALPHA
def read_file(path: str) -> torch.Tensor:
    """
    Load the raw bytes of a file into a one dimensional uint8 tensor.

    Args:
        path (str or ``pathlib.Path``): location of the file to read.

    Returns:
        data (Tensor): the bytes contents of the file.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_file)
    _assert_has_image_ops()
    # ``str()`` lets path-like objects through to the op.
    return torch.ops.image.read_file(str(path))
def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Dump the contents of a one dimensional uint8 tensor into a file.

    Args:
        filename (str or ``pathlib.Path``): location of the file to write.
        data (Tensor): the bytes to store in the output file.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_file)
    _assert_has_image_ops()
    # ``str()`` lets path-like objects through to the op.
    target = str(filename)
    torch.ops.image.write_file(target, data)
def decode_png(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """
    Decode the raw bytes of a PNG image into a 3 dimensional RGB or grayscale Tensor.

    In most cases the output values are uint8 in [0, 255]. For a 16-bit png
    the output tensor is uint16 in [0, 65535] (supported from torchvision
    ``0.21``). Since uint16 support is limited in pytorch, we recommend
    calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
    ``scale=True`` on the result to obtain a uint8 or float tensor.

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor holding the raw
            bytes of the PNG image.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Default: False.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_png)
    _assert_has_image_ops()
    # Mode names given as strings are looked up case-insensitively.
    if isinstance(mode, str):
        mode = ImageReadMode[mode.upper()]
    return torch.ops.image.decode_png(input, mode.value, apply_exif_orientation)
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Encode a CHW image tensor and return a buffer holding the contents of
    the corresponding PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): 8-bit integer image
            tensor of ``c`` channels, where ``c`` must be 1 or 3.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6

    Returns:
        Tensor[1]: A one dimensional 8-bit integer tensor that contains the raw
            bytes of the PNG file.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_png)
    _assert_has_image_ops()
    return torch.ops.image.encode_png(input, compression_level)
def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
    """
    Save an input tensor in CHW layout (or HW in the case of grayscale
    images) as a PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): 8-bit integer image
            tensor of ``c`` channels, where ``c`` must be 1 or 3.
        filename (str or ``pathlib.Path``): Path to save the image.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_png)
    # Encode in memory first, then hand the resulting bytes to the file writer.
    png_bytes = encode_png(input, compression_level)
    write_file(filename, png_bytes)
def decode_jpeg(
    input: Union[torch.Tensor, list[torch.Tensor]],
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    device: Union[str, torch.device] = "cpu",
    apply_exif_orientation: bool = False,
) -> Union[torch.Tensor, list[torch.Tensor]]:
    """Decode JPEG image(s) into 3D RGB or grayscale Tensor(s), on CPU or CUDA.

    The output values are uint8 between 0 and 255.

    .. note::
        On a CUDA device, decoding a list of tensors in one call is more efficient
        than calling ``decode_jpeg`` once per image. On CPU the performance is equivalent.
        The CUDA version of this function has explicitly been designed with thread-safety in mind.
        This function does not return partial results in case of an error.

    Args:
        input (Tensor[1] or list[Tensor[1]]): a (list of) one dimensional uint8 tensor(s) containing
            the raw bytes of the JPEG image. The tensor(s) must be on CPU,
            regardless of the ``device`` parameter.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.
        device (str or torch.device): The device on which the decoded image will
            be stored. If a cuda device is specified, the image will be decoded
            with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
            supported for CUDA version >= 10.1

            .. betastatus:: device parameter

            .. warning::
                There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
                Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Default: False. Only implemented for JPEG format on CPU.

    Returns:
        output (Tensor[image_channels, image_height, image_width] or list[Tensor[image_channels, image_height, image_width]]):
            The values of the output tensor(s) are uint8 between 0 and 255.
            ``output.device`` will be set to the specified ``device``

    Raises:
        ValueError: if a list input is empty, contains non-tensors or contains
            non-CPU tensors, or if a single input tensor is not on CPU.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_jpeg)
    _assert_has_image_ops()

    # Normalize the string forms of ``device`` and ``mode``.
    if isinstance(device, str):
        device = torch.device(device)
    if isinstance(mode, str):
        mode = ImageReadMode[mode.upper()]

    if isinstance(input, list):
        # Batched input: validate the whole list before decoding anything.
        if len(input) == 0:
            raise ValueError("Input list must contain at least one element")
        if not all(isinstance(encoded, torch.Tensor) for encoded in input):
            raise ValueError("All elements of the input list must be tensors.")
        if not all(encoded.device.type == "cpu" for encoded in input):
            raise ValueError("Input list must contain tensors on CPU.")
        if device.type == "cuda":
            # The CUDA op decodes the full batch in a single call.
            return torch.ops.image.decode_jpegs_cuda(input, mode.value, device)
        return [torch.ops.image.decode_jpeg(encoded, mode.value, apply_exif_orientation) for encoded in input]
    else:  # input is tensor
        if input.device.type != "cpu":
            raise ValueError("Input tensor must be a CPU tensor")
        if device.type == "cuda":
            # Wrap the single image in a batch of one for the batched CUDA op.
            return torch.ops.image.decode_jpegs_cuda([input], mode.value, device)[0]
        return torch.ops.image.decode_jpeg(input, mode.value, apply_exif_orientation)
def encode_jpeg(
    input: Union[torch.Tensor, list[torch.Tensor]], quality: int = 75
) -> Union[torch.Tensor, list[torch.Tensor]]:
    """Encode RGB tensor(s) into raw encoded jpeg bytes, on CPU or CUDA.

    .. note::
        Encoding a list of CUDA tensors in one call is more efficient than
        calling ``encode_jpeg`` once per image.
        For CPU tensors the performance is equivalent.

    Args:
        input (Tensor[channels, image_height, image_width] or List[Tensor[channels, image_height, image_width]]):
            (list of) uint8 image tensor(s) of ``c`` channels, where ``c`` must be 1 or 3
        quality (int): Quality of the resulting JPEG file(s). Must be a number between
            1 and 100. Default: 75

    Returns:
        output (Tensor[1] or list[Tensor[1]]): A (list of) one dimensional uint8 tensor(s) that contain the raw bytes of the JPEG file.

    Raises:
        ValueError: if ``quality`` is outside [1, 100] or an empty list is passed.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_jpeg)
    _assert_has_image_ops()
    if quality < 1 or quality > 100:
        raise ValueError("Image quality should be a positive number between 1 and 100")

    if isinstance(input, list):
        if len(input) == 0:
            raise ValueError("encode_jpeg requires at least one input tensor when a list is passed")
        # Only the first element's device selects the backend for the whole list.
        if input[0].device.type == "cuda":
            return torch.ops.image.encode_jpegs_cuda(input, quality)
        return [torch.ops.image.encode_jpeg(img, quality) for img in input]
    else:  # single input tensor
        if input.device.type == "cuda":
            # Wrap the single image in a batch of one for the batched CUDA op.
            return torch.ops.image.encode_jpegs_cuda([input], quality)[0]
        return torch.ops.image.encode_jpeg(input, quality)
def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Save an input tensor in CHW layout as a JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): 8-bit integer image
            tensor of ``c`` channels, where ``c`` must be 1 or 3.
        filename (str or ``pathlib.Path``): Path to save the image.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_jpeg)
    jpeg_bytes = encode_jpeg(input, quality)
    # ``encode_jpeg`` is typed as returning a tensor or a list of tensors;
    # narrowing to a single tensor is needed for torchscript.
    assert isinstance(jpeg_bytes, torch.Tensor)
    write_file(filename, jpeg_bytes)
def decode_image(
    input: Union[torch.Tensor, str],
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """Decode an image into a uint8 tensor, from a path or from raw encoded bytes.

    Currently supported image formats are jpeg, png, gif and webp.

    In most cases the output values are uint8 in [0, 255]. For a 16-bit png
    the output tensor is uint16 in [0, 65535] (supported from torchvision
    ``0.21``). Since uint16 support is limited in pytorch, we recommend
    calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
    ``scale=True`` on the result to obtain a uint8 or float tensor.

    .. note::

        ``decode_image()`` doesn't work yet on AVIF or HEIC images. For these
        formats, directly call :func:`~torchvision.io.decode_avif` or
        :func:`~torchvision.io.decode_heic`.

    Args:
        input (Tensor or str or ``pathlib.Path``): The image to decode. If a
            tensor is passed, it must be one dimensional uint8 tensor containing
            the raw bytes of the image. Otherwise, this must be a path to the image file.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Only applies to JPEG and PNG images. Default: False.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_image)
    _assert_has_image_ops()
    # Anything that is not already a tensor is treated as a path and read from disk.
    if not isinstance(input, torch.Tensor):
        input = read_file(str(input))
    # Mode names given as strings are looked up case-insensitively.
    if isinstance(mode, str):
        mode = ImageReadMode[mode.upper()]
    return torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
def read_image(
    path: str,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """[OBSOLETE] Use :func:`~torchvision.io.decode_image` instead."""
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_image)
    # Read the encoded bytes from disk, then decode them.
    encoded = read_file(path)
    return decode_image(encoded, mode, apply_exif_orientation=apply_exif_orientation)
def decode_gif(input: torch.Tensor) -> torch.Tensor:
    """
    Decode the raw bytes of a GIF image into a 3 or 4 dimensional RGB Tensor.

    The output values are uint8 between 0 and 255. A GIF holding a single
    image yields a ``(C, H, W)`` tensor; a GIF holding ``N`` images yields
    a ``(N, C, H, W)`` tensor.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the GIF image.

    Returns:
        output (Tensor[image_channels, image_height, image_width] or Tensor[num_images, image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_gif)
    _assert_has_image_ops()
    decoded = torch.ops.image.decode_gif(input)
    return decoded
def decode_webp(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
) -> torch.Tensor:
    """
    Decode the raw bytes of a WEBP image into a 3 dimensional RGB[A] Tensor.

    The output values are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the WEBP image.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.

    Returns:
        Decoded image (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_webp)
    _assert_has_image_ops()
    # Mode names given as strings are looked up case-insensitively.
    if isinstance(mode, str):
        mode = ImageReadMode[mode.upper()]
    decoded = torch.ops.image.decode_webp(input, mode.value)
    return decoded
- # TODO_AVIF_HEIC: Better support for torchscript. Scripting decode_avif or
- # decode_heic currently fails, mainly because of the logic in
- # _load_extra_decoders_once() (using global variables, try/except statements,
- # etc.).
- # The ops (torch.ops.extra_decoders_ns.decode_*) are otherwise torchscript-able,
- # and users who need torchscript can always just wrap those.
- # TODO_AVIF_HEIC: decode_image() should work for those. The key technical issue
- # we have here is that the format detection logic of decode_image() is
- # implemented in torchvision, and torchvision has zero knowledge of
- # torchvision-extra-decoders, so we cannot call the AVIF/HEIC C++ decoders
- # (those in torchvision-extra-decoders) from there.
- # A trivial check that could be done within torchvision would be to check the
- # file extension, if a path was passed. We could also just implement the
- # AVIF/HEIC detection logic in Python as a fallback, if the file detection
- # didn't find any format. In any case: properly determining whether a file is
- # HEIC is far from trivial, and relying on libmagic would probably be best
# Set to True after the extra decoders have been exposed, so that the loading
# logic below runs at most once per process.
_EXTRA_DECODERS_ALREADY_LOADED = False


def _load_extra_decoders_once():
    """Expose the AVIF/HEIC ops of ``torchvision_extra_decoders``, at most once.

    The first successful call imports ``torchvision_extra_decoders`` and calls
    its ``expose_extra_decoders()``; later calls return immediately.

    Raises:
        RuntimeError: if ``torchvision_extra_decoders`` cannot be imported, or
            if the installed package has no ``expose_extra_decoders``. The
            original ``ImportError`` / ``AttributeError`` is chained as the cause.
    """
    global _EXTRA_DECODERS_ALREADY_LOADED
    if _EXTRA_DECODERS_ALREADY_LOADED:
        return

    try:
        import torchvision_extra_decoders

        # torchvision-extra-decoders only supports linux for now. BUT, users on
        # e.g. MacOS can still install it: they will get the pure-python
        # 0.0.0.dev version:
        # https://pypi.org/project/torchvision-extra-decoders/0.0.0.dev0, which
        # is a dummy version that was created to reserve the namespace on PyPI.
        # We have to check that expose_extra_decoders() exists for those users,
        # so we can properly error on non-Linux archs.
        # This is an explicit attribute lookup rather than an ``assert``:
        # assertions are stripped under ``python -O``, which would skip the
        # check and surface a bare AttributeError instead of the hint below.
        expose_extra_decoders = torchvision_extra_decoders.expose_extra_decoders
    except (AttributeError, ImportError) as e:
        raise RuntimeError(
            "In order to enable the AVIF and HEIC decoding capabilities of "
            "torchvision, you need to `pip install torchvision-extra-decoders`. "
            "Just install the package, you don't need to update your code. "
            "This is only supported on Linux, and this feature is still in BETA stage. "
            "Please let us know of any issue: https://github.com/pytorch/vision/issues/new/choose. "
            "Note that `torchvision-extra-decoders` is released under the LGPL license. "
        ) from e

    # This will expose torch.ops.extra_decoders_ns.decode_avif and torch.ops.extra_decoders_ns.decode_heic
    expose_extra_decoders()
    _EXTRA_DECODERS_ALREADY_LOADED = True
def decode_avif(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """Decode an AVIF image into a 3 dimensional RGB[A] Tensor.

    .. warning::
        In order to enable the AVIF decoding capabilities of torchvision, you
        first need to run ``pip install torchvision-extra-decoders``. Just
        install the package, you don't need to update your code. This is only
        supported on Linux, and this feature is still in BETA stage. Please let
        us know of any issue:
        https://github.com/pytorch/vision/issues/new/choose. Note that
        `torchvision-extra-decoders
        <https://github.com/meta-pytorch/torchvision-extra-decoders/>`_ is
        released under the LGPL license.

    The values of the output tensor are in uint8 in [0, 255] for most images. If
    the image has a bit-depth of more than 8, then the output tensor is uint16
    in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
    calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
    ``scale=True`` after this function to convert the decoded image into a uint8
    or float tensor.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the AVIF image.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.

    Returns:
        Decoded image (Tensor[image_channels, image_height, image_width])

    Raises:
        RuntimeError: if the extra decoders cannot be loaded, or if ``input``
            is not a uint8 tensor.
        KeyError: if ``mode`` is a string that names no ``ImageReadMode`` member.
    """
    _load_extra_decoders_once()
    if input.dtype != torch.uint8:
        raise RuntimeError(f"Input tensor must have uint8 data type, got {input.dtype}")
    # Accept the documented string form of ``mode``, like the other decoders
    # do; without this a string would fail on ``mode.value`` below.
    if isinstance(mode, str):
        mode = ImageReadMode[mode.upper()]
    return torch.ops.extra_decoders_ns.decode_avif(input, mode.value)
def decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """Decode an HEIC image into a 3 dimensional RGB[A] Tensor.

    .. warning::
        In order to enable the HEIC decoding capabilities of torchvision, you
        first need to run ``pip install torchvision-extra-decoders``. Just
        install the package, you don't need to update your code. This is only
        supported on Linux, and this feature is still in BETA stage. Please let
        us know of any issue:
        https://github.com/pytorch/vision/issues/new/choose. Note that
        `torchvision-extra-decoders
        <https://github.com/meta-pytorch/torchvision-extra-decoders/>`_ is
        released under the LGPL license.

    The values of the output tensor are in uint8 in [0, 255] for most images. If
    the image has a bit-depth of more than 8, then the output tensor is uint16
    in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
    calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
    ``scale=True`` after this function to convert the decoded image into a uint8
    or float tensor.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the HEIC image.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.

    Returns:
        Decoded image (Tensor[image_channels, image_height, image_width])

    Raises:
        RuntimeError: if the extra decoders cannot be loaded, or if ``input``
            is not a uint8 tensor.
        KeyError: if ``mode`` is a string that names no ``ImageReadMode`` member.
    """
    _load_extra_decoders_once()
    if input.dtype != torch.uint8:
        raise RuntimeError(f"Input tensor must have uint8 data type, got {input.dtype}")
    # Accept the documented string form of ``mode``, like the other decoders
    # do; without this a string would fail on ``mode.value`` below.
    if isinstance(mode, str):
        mode = ImageReadMode[mode.upper()]
    return torch.ops.extra_decoders_ns.decode_heic(input, mode.value)
|