| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from __future__ import annotations
- from pathlib import Path
- from typing import Any
- import torch
- from torch.utils.dlpack import from_dlpack, to_dlpack
- import kornia.color
- from kornia.core import Device, Dtype, Tensor
- from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE
- from kornia.image.base import ChannelsOrder, ColorSpace, ImageLayout, ImageSize, PixelFormat
- from kornia.io.io import ImageLoadType, load_image, write_image
- from kornia.utils.image_print import image_to_string
- # Placeholder type aliases: numpy arrays and DLPack capsules are annotated
- # through these names, and both currently resolve to `Any`.
- np_ndarray = Any
- DLPack = Any
class Image:
    r"""Class that holds an Image Tensor representation.

    .. note::
        Disclaimer: This class provides the minimum functionality for image manipulation. However, as soon
        as you start to experiment with advanced tensor manipulation, you might expect fancy
        polymorphic behaviours.

    .. warning::
        This API is experimental and might suffer changes in the future.

    Args:
        data: a torch tensor containing the image data.
        pixel_format: the pixel format (color space and bit depth) of the image.
        layout: a dataclass containing the image layout information.

    Examples:
        >>> # from a torch.tensor
        >>> data = torch.randint(0, 255, (3, 4, 5), dtype=torch.uint8)  # CxHxW
        >>> pixel_format = PixelFormat(
        ...     color_space=ColorSpace.RGB,
        ...     bit_depth=8,
        ... )
        >>> layout = ImageLayout(
        ...     image_size=ImageSize(4, 5),
        ...     channels=3,
        ...     channels_order=ChannelsOrder.CHANNELS_FIRST,
        ... )
        >>> img = Image(data, pixel_format, layout)
        >>> assert img.channels == 3

        >>> # from a numpy array (like opencv)
        >>> data = np.ones((4, 5, 3), dtype=np.uint8)  # HxWxC
        >>> img = Image.from_numpy(data, color_space=ColorSpace.RGB)
        >>> assert img.channels == 3
        >>> assert img.width == 5
        >>> assert img.height == 4

    """

    def __init__(self, data: Tensor, pixel_format: PixelFormat, layout: ImageLayout) -> None:
        """Image constructor.

        The tensor shape must match the layout exactly (``CxHxW`` for channels-first,
        ``HxWxC`` for channels-last) and the tensor element size must match the
        bit depth declared by the pixel format.

        Args:
            data: a torch tensor containing the image data.
            pixel_format: the pixel format of the image.
            layout: a dataclass containing the image layout information.

        Raises:
            NotImplementedError: if ``layout.channels_order`` is neither channels-first
                nor channels-last.

        """
        # TODO: move this to a function KORNIA_CHECK_IMAGE_LAYOUT
        height, width = str(layout.image_size.height), str(layout.image_size.width)
        if layout.channels_order == ChannelsOrder.CHANNELS_FIRST:
            shape = [str(layout.channels), height, width]
        elif layout.channels_order == ChannelsOrder.CHANNELS_LAST:
            shape = [height, width, str(layout.channels)]
        else:
            raise NotImplementedError(f"Layout {layout.channels_order} not implemented.")

        KORNIA_CHECK_SHAPE(data, shape)
        # bit_depth is expressed in bits, element_size() in bytes.
        KORNIA_CHECK(data.element_size() == pixel_format.bit_depth // 8, "Invalid bit depth.")

        self._data = data
        self._pixel_format = pixel_format
        self._layout = layout

    def __repr__(self) -> str:
        return f"Image data: {self.data}\nPixel Format: {self.pixel_format}\n Layout: {self.layout}"

    # ------------------------------------------------------------------ #
    # private helpers
    # ------------------------------------------------------------------ #
    def _set_data(self, data: Tensor) -> None:
        """Store ``data`` and keep the pixel format bit depth in sync with its dtype.

        Without this, casting the tensor (e.g. uint8 -> float32) leaves a stale bit depth
        behind and every later re-wrapping of the data (``clone``, ``to_gray``, ...) fails
        the constructor's bit depth check.
        """
        self._data = data
        bit_depth = data.element_size() * 8
        if bit_depth != self._pixel_format.bit_depth:
            self._pixel_format = PixelFormat(color_space=self._pixel_format.color_space, bit_depth=bit_depth)

    def _channels_first_data(self) -> Tensor:
        """Return the underlying data viewed as ``(..., C, H, W)``."""
        data = self._data
        if self._layout.channels_order == ChannelsOrder.CHANNELS_LAST:
            # (..., H, W, C) -> (..., C, H, W); works for both 3D and batched 4D data.
            data = data.movedim(-1, -3)
        return data

    def _from_channels_first(self, data: Tensor, color_space: ColorSpace, channels: int) -> Image:
        """Wrap converted ``(..., C, H, W)`` data in a new image using this image's channels order."""
        if self._layout.channels_order == ChannelsOrder.CHANNELS_LAST:
            # (..., C, H, W) -> (..., H, W, C)
            data = data.movedim(-3, -1)
        # Derive the bit depth from the actual output so it always matches the tensor dtype.
        pixel_format = PixelFormat(color_space=color_space, bit_depth=data.element_size() * 8)
        layout = ImageLayout(self._layout.image_size, channels=channels, channels_order=self._layout.channels_order)
        return Image(data, pixel_format, layout)

    @staticmethod
    def _swap_red_blue(data: Tensor) -> Tensor:
        """Swap the first and third channel of ``(..., C, H, W)`` data (RGB <-> BGR)."""
        return data[..., [2, 1, 0], :, :]

    # ------------------------------------------------------------------ #
    # tensor-like API
    # ------------------------------------------------------------------ #
    # TODO: explore use TensorWrapper
    def to(self, device: Device = None, dtype: Dtype = None) -> Image:
        """Move the image to the given device and dtype.

        The conversion happens in place: the wrapped tensor is replaced and ``self`` is
        returned. A dtype may also be passed as the first positional argument, as in
        ``img.to(torch.float32)``.

        Args:
            device: the device to move the image to.
            dtype: the data type to cast the image to.

        Returns:
            Image: the image moved to the given device and dtype.

        """
        if isinstance(device, torch.dtype):
            dtype, device = device, None

        # put the data to the device and dtype
        self._set_data(self.data.to(device, dtype))
        return self

    # TODO: explore use TensorWrapper
    def clone(self) -> Image:
        """Return a copy of the image."""
        return Image(self.data.clone(), self.pixel_format, self.layout)

    @property
    def data(self) -> Tensor:
        """Return the underlying tensor data."""
        return self._data

    @property
    def shape(self) -> tuple[int, ...]:
        """Return the image shape."""
        return self.data.shape

    @property
    def dtype(self) -> torch.dtype:
        """Return the image data type."""
        return self.data.dtype

    @property
    def device(self) -> torch.device:
        """Return the image device."""
        return self.data.device

    @property
    def pixel_format(self) -> PixelFormat:
        """Return the pixel format."""
        return self._pixel_format

    @property
    def layout(self) -> ImageLayout:
        """Return the image layout."""
        return self._layout

    @property
    def channels(self) -> int:
        """Return the number channels of the image."""
        return self.layout.channels

    @property
    def image_size(self) -> ImageSize:
        """Return the image size."""
        return self.layout.image_size

    @property
    def height(self) -> int:
        """Return the image height (number of rows)."""
        return int(self.layout.image_size.height)

    @property
    def width(self) -> int:
        """Return the image width (number of columns)."""
        return int(self.layout.image_size.width)

    @property
    def channels_order(self) -> ChannelsOrder:
        """Return the channels order."""
        return self.layout.channels_order

    # TODO: figure out a better way map this function
    def float(self) -> Image:
        """Cast the image data to float in place and return ``self``."""
        self._set_data(self.data.float())
        return self

    # ------------------------------------------------------------------ #
    # color space conversions
    # ------------------------------------------------------------------ #
    def to_gray(self) -> Image:
        """Convert the image to grayscale.

        Returns:
            ``self`` if the image is already grayscale, otherwise a new single-channel
            image with the same channels order.

        Raises:
            ValueError: if the source color space is neither RGB nor BGR.

        """
        src = self._pixel_format.color_space
        if src == ColorSpace.GRAY:
            return self

        # The kornia.color functions are fed channels-first data.
        data = self._channels_first_data()
        if src == ColorSpace.RGB:
            out = kornia.color.rgb_to_grayscale(data)
        elif src == ColorSpace.BGR:
            out = kornia.color.bgr_to_grayscale(data)
        else:
            raise ValueError(f"Unsupported source color space for to_gray(): {src}")

        return self._from_channels_first(out, ColorSpace.GRAY, channels=1)

    def to_rgb(self) -> Image:
        """Convert the image to RGB.

        Returns:
            ``self`` if the image is already RGB, otherwise a new three-channel image
            with the same channels order.

        Raises:
            ValueError: if the source color space is neither GRAY nor BGR.

        """
        src = self._pixel_format.color_space
        if src == ColorSpace.RGB:
            return self

        data = self._channels_first_data()
        if src == ColorSpace.GRAY:
            out = kornia.color.grayscale_to_rgb(data)
        elif src == ColorSpace.BGR:
            out = self._swap_red_blue(data)
        else:
            raise ValueError(f"Unsupported source color space for to_rgb(): {src}")

        return self._from_channels_first(out, ColorSpace.RGB, channels=3)

    def to_bgr(self) -> Image:
        """Convert the image to BGR.

        Returns:
            ``self`` if the image is already BGR, otherwise a new three-channel image
            with the same channels order.

        Raises:
            ValueError: if the source color space is neither GRAY nor RGB.

        """
        src = self._pixel_format.color_space
        if src == ColorSpace.BGR:
            return self

        data = self._channels_first_data()
        if src == ColorSpace.GRAY:
            out = self._swap_red_blue(kornia.color.grayscale_to_rgb(data))
        elif src == ColorSpace.RGB:
            out = self._swap_red_blue(data)
        else:
            raise ValueError(f"Unsupported source color space for to_bgr(): {src}")

        return self._from_channels_first(out, ColorSpace.BGR, channels=3)

    # ------------------------------------------------------------------ #
    # interoperability
    # ------------------------------------------------------------------ #
    @classmethod
    def from_numpy(
        cls,
        data: np_ndarray,
        color_space: ColorSpace = ColorSpace.RGB,
        channels_order: ChannelsOrder = ChannelsOrder.CHANNELS_LAST,
    ) -> Image:
        """Construct an image tensor from a numpy array.

        The bit depth is taken from the array item size; the image size and number of
        channels are read from the array shape according to ``channels_order``.

        Args:
            data: a numpy array containing the image data.
            color_space: the color space of the image.
            channels_order: what dimension the channels are in the image tensor.

        Raises:
            ValueError: if ``channels_order`` is neither channels-last nor channels-first.

        Example:
            >>> data = np.ones((4, 5, 3), dtype=np.uint8)  # HxWxC
            >>> img = Image.from_numpy(data, color_space=ColorSpace.RGB)
            >>> assert img.channels == 3
            >>> assert img.width == 5
            >>> assert img.height == 4

        """
        if channels_order == ChannelsOrder.CHANNELS_LAST:
            image_size = ImageSize(height=data.shape[0], width=data.shape[1])
            channels = data.shape[2]
        elif channels_order == ChannelsOrder.CHANNELS_FIRST:
            image_size = ImageSize(height=data.shape[1], width=data.shape[2])
            channels = data.shape[0]
        else:
            raise ValueError("channels_order must be either `CHANNELS_LAST` or `CHANNELS_FIRST`")

        # create the pixel format based on the input data
        pixel_format = PixelFormat(color_space=color_space, bit_depth=data.itemsize * 8)

        # create the image layout based on the input data
        layout = ImageLayout(image_size=image_size, channels=channels, channels_order=channels_order)

        # create the image tensor
        return cls(torch.from_numpy(data), pixel_format, layout)

    def to_numpy(self) -> np_ndarray:
        """Return a numpy array in cpu from the image tensor."""
        return self.data.cpu().detach().numpy()

    @classmethod
    def from_dlpack(cls, data: DLPack) -> Image:
        """Construct an image tensor from a DLPack capsule.

        The data is always interpreted as a channels-first (``CxHxW``) RGB image; the bit
        depth is taken from the element size of the imported tensor.

        Args:
            data: a DLPack capsule from numpy, tvm or jax.

        Example:
            >>> x = np.ones((3, 4, 5))  # CxHxW
            >>> img = Image.from_dlpack(x.__dlpack__())
            >>> assert img.channels == 3

        """
        _data: Tensor = from_dlpack(data)

        pixel_format = PixelFormat(color_space=ColorSpace.RGB, bit_depth=_data.element_size() * 8)

        # create the image layout based on the input data
        layout = ImageLayout(
            image_size=ImageSize(height=_data.shape[1], width=_data.shape[2]),
            channels=_data.shape[0],
            channels_order=ChannelsOrder.CHANNELS_FIRST,
        )

        return cls(_data, pixel_format, layout)

    def to_dlpack(self) -> DLPack:
        """Return a DLPack capsule from the image tensor."""
        return to_dlpack(self.data)

    # ------------------------------------------------------------------ #
    # file I/O and display
    # ------------------------------------------------------------------ #
    @classmethod
    def from_file(cls, file_path: str | Path) -> Image:
        """Construct an image tensor from a file.

        The file is loaded on the cpu as 8-bit RGB and wrapped as a channels-first image.

        Args:
            file_path: the path to the file to read the image from.

        """
        # TODO: allow user to specify the desired type and device
        data: Tensor = load_image(file_path, desired_type=ImageLoadType.RGB8, device="cpu")

        pixel_format = PixelFormat(color_space=ColorSpace.RGB, bit_depth=data.element_size() * 8)

        layout = ImageLayout(
            image_size=ImageSize(height=data.shape[1], width=data.shape[2]),
            channels=data.shape[0],
            channels_order=ChannelsOrder.CHANNELS_FIRST,
        )

        return cls(data, pixel_format, layout)

    def write(self, file_path: str | Path) -> None:
        """Write the image to a file.

        For now, only support writing to JPEG format.

        Args:
            file_path: the path to the file to write the image to.

        Example:
            >>> data = np.ones((4, 5, 3), dtype=np.uint8)  # HxWxC
            >>> img = Image.from_numpy(data)
            >>> img.write("test.jpg")

        """
        data = self.data
        if self.channels_order == ChannelsOrder.CHANNELS_LAST:
            # the writer is handed channels-first (CxHxW) data
            data = data.permute(2, 0, 1)
        write_image(file_path, data)

    def print(self, max_width: int = 256) -> None:
        """Print the image tensor to the console.

        Args:
            max_width: the maximum width of the image to print.

        .. code-block:: python

            img = Image.from_file("panda.png")
            img.print()

        .. image:: https://github.com/kornia/data/blob/main/print_image.png?raw=true

        """
        print(image_to_string(self.data, max_width))
|