| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from __future__ import annotations
- from enum import Enum
- from pathlib import Path
- from typing import Any
- import kornia_rs
- import torch
- import kornia
- from kornia.core import Device, Tensor
- from kornia.core.check import KORNIA_CHECK
- from kornia.utils import image_to_tensor, tensor_to_image
- class ImageLoadType(Enum):
- r"""Enum to specify the desired image type."""
- UNCHANGED = 0
- GRAY8 = 1
- RGB8 = 2
- RGBA8 = 3
- GRAY32 = 4
- RGB32 = 5
- def _load_image_to_tensor(path_file: Path, device: Device) -> Tensor:
- """Read an image file and decode using the Kornia Rust backend.
- The decoded image is returned as numpy array with shape HxWxC.
- Args:
- path_file: Path to a valid image file.
- device: the device where you want to get your image placed.
- Return:
- Image tensor with shape :math:`(3,H,W)`.
- """
- # read image and return as `np.ndarray` with shape HxWxC
- if path_file.suffix.lower() in [".jpg", ".jpeg"]:
- img = kornia_rs.read_image_jpegturbo(str(path_file))
- else:
- img = kornia_rs.read_image_any(str(path_file))
- # convert the image to tensor with shape CxHxW
- img_t = image_to_tensor(img, keepdim=True)
- # move the tensor to the desired device,
- dev = device if isinstance(device, torch.device) or device is None else torch.device(device)
- return img_t.to(device=dev)
- def _to_float32(image: Tensor) -> Tensor:
- """Convert an image tensor to float32."""
- KORNIA_CHECK(image.dtype == torch.uint8)
- return image.float() / 255.0
- def _to_uint8(image: Tensor) -> Tensor:
- """Convert an image tensor to uint8."""
- KORNIA_CHECK(image.dtype == torch.float32)
- return image.mul(255.0).byte()
- def load_image(
- path_file: str | Path, desired_type: ImageLoadType = ImageLoadType.RGB32, device: Device = "cpu"
- ) -> Tensor:
- """Read an image file and decode using the Kornia Rust backend.
- Args:
- path_file: Path to a valid image file.
- desired_type: the desired image type, defined by color space and dtype.
- device: the device where you want to get your image placed.
- Return:
- Image tensor with shape :math:`(3,H,W)`.
- """
- if not isinstance(path_file, Path):
- path_file = Path(path_file)
- # read the image using the kornia_rs package
- image: Tensor = _load_image_to_tensor(path_file, device) # CxHxW
- if desired_type == ImageLoadType.UNCHANGED:
- return image
- elif desired_type == ImageLoadType.GRAY8:
- if image.shape[0] == 1 and image.dtype == torch.uint8:
- return image
- elif image.shape[0] == 3 and image.dtype == torch.uint8:
- gray8 = kornia.color.rgb_to_grayscale(image)
- return gray8
- elif image.shape[0] == 4 and image.dtype == torch.uint8:
- gray32 = kornia.color.rgb_to_grayscale(kornia.color.rgba_to_rgb(_to_float32(image)))
- return _to_uint8(gray32)
- elif desired_type == ImageLoadType.RGB8:
- if image.shape[0] == 3 and image.dtype == torch.uint8:
- return image
- elif image.shape[0] == 1 and image.dtype == torch.uint8:
- rgb8 = kornia.color.grayscale_to_rgb(image)
- return rgb8
- elif desired_type == ImageLoadType.RGBA8:
- if image.shape[0] == 3 and image.dtype == torch.uint8:
- rgba32 = kornia.color.rgb_to_rgba(_to_float32(image), 0.0)
- return _to_uint8(rgba32)
- elif desired_type == ImageLoadType.GRAY32:
- if image.shape[0] == 1 and image.dtype == torch.uint8:
- return _to_float32(image)
- elif image.shape[0] == 3 and image.dtype == torch.uint8:
- gray32 = kornia.color.rgb_to_grayscale(_to_float32(image))
- return gray32
- elif image.shape[0] == 4 and image.dtype == torch.uint8:
- gray32 = kornia.color.rgb_to_grayscale(kornia.color.rgba_to_rgb(_to_float32(image)))
- return gray32
- elif desired_type == ImageLoadType.RGB32:
- if image.shape[0] == 3 and image.dtype == torch.uint8:
- return _to_float32(image)
- elif image.shape[0] == 1 and image.dtype == torch.uint8:
- rgb32 = kornia.color.grayscale_to_rgb(_to_float32(image))
- return rgb32
- raise NotImplementedError(f"Unknown type: {desired_type}")
- def _write_uint8_image(path_file: Path, img_np: Any, quality: int) -> None:
- """Write uint8 image to file."""
- if path_file.suffix.lower() in [".jpg", ".jpeg"]:
- mode = "mono" if img_np.ndim == 2 or (img_np.ndim == 3 and img_np.shape[-1] == 1) else "rgb"
- kornia_rs.write_image_jpeg(str(path_file), img_np, mode=mode, quality=quality)
- elif path_file.suffix.lower() == ".png":
- mode = "mono" if img_np.ndim == 2 or (img_np.ndim == 3 and img_np.shape[-1] == 1) else "rgb"
- kornia_rs.write_image_png_u8(str(path_file), img_np, mode=mode)
- elif path_file.suffix.lower() == ".tiff":
- mode = "mono" if img_np.ndim == 2 or (img_np.ndim == 3 and img_np.shape[-1] == 1) else "rgb"
- kornia_rs.write_image_tiff_u8(str(path_file), img_np, mode=mode)
- else:
- raise NotImplementedError(f"Unsupported file extension: {path_file.suffix} for uint8 image")
- def _write_uint16_image(path_file: Path, img_np: Any) -> None:
- """Write uint16 image to file."""
- mode = "mono" if img_np.ndim == 2 or (img_np.ndim == 3 and img_np.shape[-1] == 1) else "rgb"
- if path_file.suffix.lower() == ".png":
- kornia_rs.write_image_png_u16(str(path_file), img_np, mode=mode)
- elif path_file.suffix.lower() == ".tiff":
- kornia_rs.write_image_tiff_u16(str(path_file), img_np, mode=mode)
- else:
- raise NotImplementedError(f"Unsupported file extension: {path_file.suffix} for uint16 image")
- def _write_float32_image(path_file: Path, img_np: Any) -> None:
- """Write float32 image to file."""
- mode = "mono" if img_np.ndim == 2 or (img_np.ndim == 3 and img_np.shape[-1] == 1) else "rgb"
- if path_file.suffix.lower() == ".tiff":
- kornia_rs.write_image_tiff_f32(str(path_file), img_np, mode=mode)
- else:
- raise NotImplementedError(f"Unsupported file extension: {path_file.suffix} for float32 image")
- def write_image(path_file: str | Path, image: Tensor, quality: int = 80) -> None:
- """Save an image file using the Kornia Rust backend.
- Args:
- path_file: Path to a valid image file.
- image: Image tensor with shape :math:`(3,H,W)`, `(1,H,W)` and `(H,W)`.
- quality: The quality of the JPEG encoding. If the file extension is .png or .tiff, the quality is ignored.
- Return:
- None.
- """
- if not isinstance(path_file, Path):
- path_file = Path(path_file)
- KORNIA_CHECK(
- path_file.suffix in [".jpg", ".jpeg", ".png", ".tiff"],
- f"Invalid file extension: {path_file}, only .jpg, .jpeg, .png and .tiff are supported.",
- )
- KORNIA_CHECK(image.dim() >= 2, f"Invalid image shape: {image.shape}. Must be at least 2D.")
- img_np = tensor_to_image(image, keepdim=True, force_contiguous=True) # HxWxC
- if img_np.ndim == 2:
- img_np = img_np[..., None] # ensures channel dimension
- if image.dtype == torch.uint8:
- _write_uint8_image(path_file, img_np, quality)
- elif image.dtype == torch.uint16:
- _write_uint16_image(path_file, img_np)
- elif image.dtype == torch.float32:
- _write_float32_image(path_file, img_np)
- else:
- raise NotImplementedError(f"Unsupported image dtype: {image.dtype}")
|