io.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from enum import Enum
  19. from pathlib import Path
  20. from typing import Any
  21. import kornia_rs
  22. import torch
  23. import kornia
  24. from kornia.core import Device, Tensor
  25. from kornia.core.check import KORNIA_CHECK
  26. from kornia.utils import image_to_tensor, tensor_to_image
  27. class ImageLoadType(Enum):
  28. r"""Enum to specify the desired image type."""
  29. UNCHANGED = 0
  30. GRAY8 = 1
  31. RGB8 = 2
  32. RGBA8 = 3
  33. GRAY32 = 4
  34. RGB32 = 5
  35. def _load_image_to_tensor(path_file: Path, device: Device) -> Tensor:
  36. """Read an image file and decode using the Kornia Rust backend.
  37. The decoded image is returned as numpy array with shape HxWxC.
  38. Args:
  39. path_file: Path to a valid image file.
  40. device: the device where you want to get your image placed.
  41. Return:
  42. Image tensor with shape :math:`(3,H,W)`.
  43. """
  44. # read image and return as `np.ndarray` with shape HxWxC
  45. if path_file.suffix.lower() in [".jpg", ".jpeg"]:
  46. img = kornia_rs.read_image_jpegturbo(str(path_file))
  47. else:
  48. img = kornia_rs.read_image_any(str(path_file))
  49. # convert the image to tensor with shape CxHxW
  50. img_t = image_to_tensor(img, keepdim=True)
  51. # move the tensor to the desired device,
  52. dev = device if isinstance(device, torch.device) or device is None else torch.device(device)
  53. return img_t.to(device=dev)
  54. def _to_float32(image: Tensor) -> Tensor:
  55. """Convert an image tensor to float32."""
  56. KORNIA_CHECK(image.dtype == torch.uint8)
  57. return image.float() / 255.0
  58. def _to_uint8(image: Tensor) -> Tensor:
  59. """Convert an image tensor to uint8."""
  60. KORNIA_CHECK(image.dtype == torch.float32)
  61. return image.mul(255.0).byte()
  62. def load_image(
  63. path_file: str | Path, desired_type: ImageLoadType = ImageLoadType.RGB32, device: Device = "cpu"
  64. ) -> Tensor:
  65. """Read an image file and decode using the Kornia Rust backend.
  66. Args:
  67. path_file: Path to a valid image file.
  68. desired_type: the desired image type, defined by color space and dtype.
  69. device: the device where you want to get your image placed.
  70. Return:
  71. Image tensor with shape :math:`(3,H,W)`.
  72. """
  73. if not isinstance(path_file, Path):
  74. path_file = Path(path_file)
  75. # read the image using the kornia_rs package
  76. image: Tensor = _load_image_to_tensor(path_file, device) # CxHxW
  77. if desired_type == ImageLoadType.UNCHANGED:
  78. return image
  79. elif desired_type == ImageLoadType.GRAY8:
  80. if image.shape[0] == 1 and image.dtype == torch.uint8:
  81. return image
  82. elif image.shape[0] == 3 and image.dtype == torch.uint8:
  83. gray8 = kornia.color.rgb_to_grayscale(image)
  84. return gray8
  85. elif image.shape[0] == 4 and image.dtype == torch.uint8:
  86. gray32 = kornia.color.rgb_to_grayscale(kornia.color.rgba_to_rgb(_to_float32(image)))
  87. return _to_uint8(gray32)
  88. elif desired_type == ImageLoadType.RGB8:
  89. if image.shape[0] == 3 and image.dtype == torch.uint8:
  90. return image
  91. elif image.shape[0] == 1 and image.dtype == torch.uint8:
  92. rgb8 = kornia.color.grayscale_to_rgb(image)
  93. return rgb8
  94. elif desired_type == ImageLoadType.RGBA8:
  95. if image.shape[0] == 3 and image.dtype == torch.uint8:
  96. rgba32 = kornia.color.rgb_to_rgba(_to_float32(image), 0.0)
  97. return _to_uint8(rgba32)
  98. elif desired_type == ImageLoadType.GRAY32:
  99. if image.shape[0] == 1 and image.dtype == torch.uint8:
  100. return _to_float32(image)
  101. elif image.shape[0] == 3 and image.dtype == torch.uint8:
  102. gray32 = kornia.color.rgb_to_grayscale(_to_float32(image))
  103. return gray32
  104. elif image.shape[0] == 4 and image.dtype == torch.uint8:
  105. gray32 = kornia.color.rgb_to_grayscale(kornia.color.rgba_to_rgb(_to_float32(image)))
  106. return gray32
  107. elif desired_type == ImageLoadType.RGB32:
  108. if image.shape[0] == 3 and image.dtype == torch.uint8:
  109. return _to_float32(image)
  110. elif image.shape[0] == 1 and image.dtype == torch.uint8:
  111. rgb32 = kornia.color.grayscale_to_rgb(_to_float32(image))
  112. return rgb32
  113. raise NotImplementedError(f"Unknown type: {desired_type}")
  114. def _write_uint8_image(path_file: Path, img_np: Any, quality: int) -> None:
  115. """Write uint8 image to file."""
  116. if path_file.suffix.lower() in [".jpg", ".jpeg"]:
  117. mode = "mono" if img_np.ndim == 2 or (img_np.ndim == 3 and img_np.shape[-1] == 1) else "rgb"
  118. kornia_rs.write_image_jpeg(str(path_file), img_np, mode=mode, quality=quality)
  119. elif path_file.suffix.lower() == ".png":
  120. mode = "mono" if img_np.ndim == 2 or (img_np.ndim == 3 and img_np.shape[-1] == 1) else "rgb"
  121. kornia_rs.write_image_png_u8(str(path_file), img_np, mode=mode)
  122. elif path_file.suffix.lower() == ".tiff":
  123. mode = "mono" if img_np.ndim == 2 or (img_np.ndim == 3 and img_np.shape[-1] == 1) else "rgb"
  124. kornia_rs.write_image_tiff_u8(str(path_file), img_np, mode=mode)
  125. else:
  126. raise NotImplementedError(f"Unsupported file extension: {path_file.suffix} for uint8 image")
  127. def _write_uint16_image(path_file: Path, img_np: Any) -> None:
  128. """Write uint16 image to file."""
  129. mode = "mono" if img_np.ndim == 2 or (img_np.ndim == 3 and img_np.shape[-1] == 1) else "rgb"
  130. if path_file.suffix.lower() == ".png":
  131. kornia_rs.write_image_png_u16(str(path_file), img_np, mode=mode)
  132. elif path_file.suffix.lower() == ".tiff":
  133. kornia_rs.write_image_tiff_u16(str(path_file), img_np, mode=mode)
  134. else:
  135. raise NotImplementedError(f"Unsupported file extension: {path_file.suffix} for uint16 image")
  136. def _write_float32_image(path_file: Path, img_np: Any) -> None:
  137. """Write float32 image to file."""
  138. mode = "mono" if img_np.ndim == 2 or (img_np.ndim == 3 and img_np.shape[-1] == 1) else "rgb"
  139. if path_file.suffix.lower() == ".tiff":
  140. kornia_rs.write_image_tiff_f32(str(path_file), img_np, mode=mode)
  141. else:
  142. raise NotImplementedError(f"Unsupported file extension: {path_file.suffix} for float32 image")
  143. def write_image(path_file: str | Path, image: Tensor, quality: int = 80) -> None:
  144. """Save an image file using the Kornia Rust backend.
  145. Args:
  146. path_file: Path to a valid image file.
  147. image: Image tensor with shape :math:`(3,H,W)`, `(1,H,W)` and `(H,W)`.
  148. quality: The quality of the JPEG encoding. If the file extension is .png or .tiff, the quality is ignored.
  149. Return:
  150. None.
  151. """
  152. if not isinstance(path_file, Path):
  153. path_file = Path(path_file)
  154. KORNIA_CHECK(
  155. path_file.suffix in [".jpg", ".jpeg", ".png", ".tiff"],
  156. f"Invalid file extension: {path_file}, only .jpg, .jpeg, .png and .tiff are supported.",
  157. )
  158. KORNIA_CHECK(image.dim() >= 2, f"Invalid image shape: {image.shape}. Must be at least 2D.")
  159. img_np = tensor_to_image(image, keepdim=True, force_contiguous=True) # HxWxC
  160. if img_np.ndim == 2:
  161. img_np = img_np[..., None] # ensures channel dimension
  162. if image.dtype == torch.uint8:
  163. _write_uint8_image(path_file, img_np, quality)
  164. elif image.dtype == torch.uint16:
  165. _write_uint16_image(path_file, img_np)
  166. elif image.dtype == torch.float32:
  167. _write_float32_image(path_file, img_np)
  168. else:
  169. raise NotImplementedError(f"Unsupported image dtype: {image.dtype}")