image.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from pathlib import Path
  19. from typing import Any
  20. import torch
  21. from torch.utils.dlpack import from_dlpack, to_dlpack
  22. import kornia.color
  23. from kornia.core import Device, Dtype, Tensor
  24. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE
  25. from kornia.image.base import ChannelsOrder, ColorSpace, ImageLayout, ImageSize, PixelFormat
  26. from kornia.io.io import ImageLoadType, load_image, write_image
  27. from kornia.utils.image_print import image_to_string
  28. # placeholder for numpy
  29. np_ndarray = Any
  30. DLPack = Any
  31. class Image:
  32. r"""Class that holds an Image Tensor representation.
  33. .. note::
  34. Disclaimer: This class provides the minimum functionality for image manipulation. However, as soon
  35. as you start to experiment with advanced tensor manipulation, you might expect fancy
  36. polymorphic behaviours.
  37. .. warning::
  38. This API is experimental and might suffer changes in the future.
  39. Args:
  40. data: a torch tensor containing the image data.
  41. layout: a dataclass containing the image layout information.
  42. Examples:
  43. >>> # from a torch.tensor
  44. >>> data = torch.randint(0, 255, (3, 4, 5), dtype=torch.uint8) # CxHxW
  45. >>> pixel_format = PixelFormat(
  46. ... color_space=ColorSpace.RGB,
  47. ... bit_depth=8,
  48. ... )
  49. >>> layout = ImageLayout(
  50. ... image_size=ImageSize(4, 5),
  51. ... channels=3,
  52. ... channels_order=ChannelsOrder.CHANNELS_FIRST,
  53. ... )
  54. >>> img = Image(data, pixel_format, layout)
  55. >>> assert img.channels == 3
  56. >>> # from a numpy array (like opencv)
  57. >>> data = np.ones((4, 5, 3), dtype=np.uint8) # HxWxC
  58. >>> img = Image.from_numpy(data, color_space=ColorSpace.RGB)
  59. >>> assert img.channels == 3
  60. >>> assert img.width == 5
  61. >>> assert img.height == 4
  62. """
  63. def __init__(self, data: Tensor, pixel_format: PixelFormat, layout: ImageLayout) -> None:
  64. """Image constructor.
  65. Args:
  66. data: a torch tensor containing the image data.
  67. pixel_format: the pixel format of the image.
  68. layout: a dataclass containing the image layout information.
  69. """
  70. # TODO: move this to a function KORNIA_CHECK_IMAGE_LAYOUT
  71. if layout.channels_order == ChannelsOrder.CHANNELS_FIRST:
  72. shape = [str(layout.channels), str(layout.image_size.height), str(layout.image_size.width)]
  73. elif layout.channels_order == ChannelsOrder.CHANNELS_LAST:
  74. shape = [str(layout.image_size.height), str(layout.image_size.width), str(layout.channels)]
  75. else:
  76. raise NotImplementedError(f"Layout {layout.channels_order} not implemented.")
  77. KORNIA_CHECK_SHAPE(data, shape)
  78. KORNIA_CHECK(data.element_size() == pixel_format.bit_depth // 8, "Invalid bit depth.")
  79. self._data = data
  80. self._pixel_format = pixel_format
  81. self._layout = layout
  82. def __repr__(self) -> str:
  83. return f"Image data: {self.data}\nPixel Format: {self.pixel_format}\n Layout: {self.layout}"
  84. # TODO: explore use TensorWrapper
  85. def to(self, device: Device = None, dtype: Dtype = None) -> Image:
  86. """Move the image to the given device and dtype.
  87. Args:
  88. device: the device to move the image to.
  89. dtype: the data type to cast the image to.
  90. Returns:
  91. Image: the image moved to the given device and dtype.
  92. """
  93. if device is not None and isinstance(device, torch.dtype):
  94. dtype, device = device, None
  95. # put the data to the device and dtype
  96. self._data = self.data.to(device, dtype)
  97. return self
  98. # TODO: explore use TensorWrapper
  99. def clone(self) -> Image:
  100. """Return a copy of the image."""
  101. return Image(self.data.clone(), self.pixel_format, self.layout)
  102. @property
  103. def data(self) -> Tensor:
  104. """Return the underlying tensor data."""
  105. return self._data
  106. @property
  107. def shape(self) -> tuple[int, ...]:
  108. """Return the image shape."""
  109. return self.data.shape
  110. @property
  111. def dtype(self) -> torch.dtype:
  112. """Return the image data type."""
  113. return self.data.dtype
  114. @property
  115. def device(self) -> torch.device:
  116. """Return the image device."""
  117. return self.data.device
  118. @property
  119. def pixel_format(self) -> PixelFormat:
  120. """Return the pixel format."""
  121. return self._pixel_format
  122. @property
  123. def layout(self) -> ImageLayout:
  124. """Return the image layout."""
  125. return self._layout
  126. @property
  127. def channels(self) -> int:
  128. """Return the number channels of the image."""
  129. return self.layout.channels
  130. @property
  131. def image_size(self) -> ImageSize:
  132. """Return the image size."""
  133. return self.layout.image_size
  134. @property
  135. def height(self) -> int:
  136. """Return the image height (columns)."""
  137. return int(self.layout.image_size.height)
  138. @property
  139. def width(self) -> int:
  140. """Return the image width (rows)."""
  141. return int(self.layout.image_size.width)
  142. @property
  143. def channels_order(self) -> ChannelsOrder:
  144. """Return the channels order."""
  145. return self.layout.channels_order
  146. # TODO: figure out a better way map this function
  147. def float(self) -> Image:
  148. """Return the image as float."""
  149. self._data = self.data.float()
  150. return self
  151. def to_gray(self) -> Image:
  152. """Converts the image to grayscale."""
  153. src = self._pixel_format.color_space
  154. data = self._data
  155. if src == ColorSpace.GRAY:
  156. return self
  157. is_channels_last = self._layout.channels_order == ChannelsOrder.CHANNELS_LAST
  158. if is_channels_last:
  159. data = data.permute(0, 3, 1, 2) if data.ndim == 4 else data.permute(2, 0, 1)
  160. # Perform the color space conversion
  161. if src == ColorSpace.RGB:
  162. out = kornia.color.rgb_to_grayscale(data)
  163. elif src == ColorSpace.BGR:
  164. out = kornia.color.bgr_to_grayscale(data)
  165. else:
  166. raise ValueError(f"Unsupported source color space for to_gray(): {src}")
  167. if is_channels_last:
  168. if out.ndim == 4:
  169. out = out.permute(0, 2, 3, 1)
  170. elif out.ndim == 3:
  171. out = out.permute(1, 2, 0)
  172. else:
  173. raise ValueError(f"Unexpected shape after grayscale conversion: {out.shape}")
  174. new_pf = PixelFormat(color_space=ColorSpace.GRAY, bit_depth=self._pixel_format.bit_depth)
  175. new_layout = ImageLayout(self._layout.image_size, channels=1, channels_order=self._layout.channels_order)
  176. return Image(out, new_pf, new_layout)
  177. def to_rgb(self) -> Image:
  178. """Converts the image to RGB."""
  179. src = self._pixel_format.color_space
  180. data = self._data
  181. if src == ColorSpace.RGB:
  182. return self
  183. is_channels_last = self._layout.channels_order == ChannelsOrder.CHANNELS_LAST
  184. if is_channels_last:
  185. data = data.permute(0, 3, 1, 2) if data.ndim == 4 else data.permute(2, 0, 1)
  186. if src == ColorSpace.GRAY:
  187. out = kornia.color.grayscale_to_rgb(data)
  188. elif src == ColorSpace.BGR:
  189. out = data[:, [2, 1, 0], ...] if data.ndim == 4 else data[[2, 1, 0], ...]
  190. else:
  191. raise ValueError(f"Unsupported source color space for to_rgb(): {src}")
  192. if is_channels_last:
  193. out = out.permute(0, 2, 3, 1) if out.ndim == 4 else out.permute(1, 2, 0)
  194. new_pf = PixelFormat(color_space=ColorSpace.RGB, bit_depth=self._pixel_format.bit_depth)
  195. new_layout = ImageLayout(self._layout.image_size, channels=3, channels_order=self._layout.channels_order)
  196. return Image(out, new_pf, new_layout)
  197. def to_bgr(self) -> Image:
  198. """Converts the image to BGR."""
  199. src = self._pixel_format.color_space
  200. data = self._data
  201. if src == ColorSpace.BGR:
  202. return self
  203. is_channels_last = self._layout.channels_order == ChannelsOrder.CHANNELS_LAST
  204. if is_channels_last:
  205. data = data.permute(0, 3, 1, 2) if data.ndim == 4 else data.permute(2, 0, 1)
  206. if src == ColorSpace.GRAY:
  207. rgb_data = kornia.color.grayscale_to_rgb(data)
  208. out = rgb_data[:, [2, 1, 0], ...] if rgb_data.ndim == 4 else rgb_data[[2, 1, 0], ...]
  209. elif src == ColorSpace.RGB:
  210. out = data[:, [2, 1, 0], ...] if data.ndim == 4 else data[[2, 1, 0], ...]
  211. else:
  212. raise ValueError(f"Unsupported source color space for to_bgr(): {src}")
  213. if is_channels_last:
  214. out = out.permute(0, 2, 3, 1) if out.ndim == 4 else out.permute(1, 2, 0)
  215. new_pf = PixelFormat(color_space=ColorSpace.BGR, bit_depth=self._pixel_format.bit_depth)
  216. new_layout = ImageLayout(self._layout.image_size, channels=3, channels_order=self._layout.channels_order)
  217. return Image(out, new_pf, new_layout)
  218. @classmethod
  219. def from_numpy(
  220. cls,
  221. data: np_ndarray,
  222. color_space: ColorSpace = ColorSpace.RGB,
  223. channels_order: ChannelsOrder = ChannelsOrder.CHANNELS_LAST,
  224. ) -> Image:
  225. """Construct an image tensor from a numpy array.
  226. Args:
  227. data: a numpy array containing the image data.
  228. color_space: the color space of the image.
  229. pixel_format: the pixel format of the image.
  230. channels_order: what dimension the channels are in the image tensor.
  231. Example:
  232. >>> data = np.ones((4, 5, 3), dtype=np.uint8) # HxWxC
  233. >>> img = Image.from_numpy(data, color_space=ColorSpace.RGB)
  234. >>> assert img.channels == 3
  235. >>> assert img.width == 5
  236. >>> assert img.height == 4
  237. """
  238. if channels_order == ChannelsOrder.CHANNELS_LAST:
  239. image_size = ImageSize(height=data.shape[0], width=data.shape[1])
  240. channels = data.shape[2]
  241. elif channels_order == ChannelsOrder.CHANNELS_FIRST:
  242. image_size = ImageSize(height=data.shape[1], width=data.shape[2])
  243. channels = data.shape[0]
  244. else:
  245. raise ValueError("channels_order must be either `CHANNELS_LAST` or `CHANNELS_FIRST`")
  246. # create the pixel format based on the input data
  247. pixel_format = PixelFormat(color_space=color_space, bit_depth=data.itemsize * 8)
  248. # create the image layout based on the input data
  249. layout = ImageLayout(image_size=image_size, channels=channels, channels_order=channels_order)
  250. # create the image tensor
  251. return cls(torch.from_numpy(data), pixel_format, layout)
  252. def to_numpy(self) -> np_ndarray:
  253. """Return a numpy array in cpu from the image tensor."""
  254. return self.data.cpu().detach().numpy()
  255. @classmethod
  256. def from_dlpack(cls, data: DLPack) -> Image:
  257. """Construct an image tensor from a DLPack capsule.
  258. Args:
  259. data: a DLPack capsule from numpy, tvm or jax.
  260. Example:
  261. >>> x = np.ones((4, 5, 3))
  262. >>> img = Image.from_dlpack(x.__dlpack__())
  263. """
  264. _data: Tensor = from_dlpack(data)
  265. pixel_format = PixelFormat(color_space=ColorSpace.RGB, bit_depth=_data.element_size() * 8)
  266. # create the image layout based on the input data
  267. layout = ImageLayout(
  268. image_size=ImageSize(height=_data.shape[1], width=_data.shape[2]),
  269. channels=_data.shape[0],
  270. channels_order=ChannelsOrder.CHANNELS_FIRST,
  271. )
  272. return cls(_data, pixel_format, layout)
  273. def to_dlpack(self) -> DLPack:
  274. """Return a DLPack capsule from the image tensor."""
  275. return to_dlpack(self.data)
  276. @classmethod
  277. def from_file(cls, file_path: str | Path) -> Image:
  278. """Construct an image tensor from a file.
  279. Args:
  280. file_path: the path to the file to read the image from.
  281. """
  282. # TODO: allow user to specify the desired type and device
  283. data: Tensor = load_image(file_path, desired_type=ImageLoadType.RGB8, device="cpu")
  284. pixel_format = PixelFormat(color_space=ColorSpace.RGB, bit_depth=data.element_size() * 8)
  285. layout = ImageLayout(
  286. image_size=ImageSize(height=data.shape[1], width=data.shape[2]),
  287. channels=data.shape[0],
  288. channels_order=ChannelsOrder.CHANNELS_FIRST,
  289. )
  290. return cls(data, pixel_format, layout)
  291. def write(self, file_path: str | Path) -> None:
  292. """Write the image to a file.
  293. For now, only support writing to JPEG format.
  294. Args:
  295. file_path: the path to the file to write the image to.
  296. Example:
  297. >>> data = np.ones((4, 5, 3), dtype=np.uint8) # HxWxC
  298. >>> img = Image.from_numpy(data)
  299. >>> img.write("test.jpg")
  300. """
  301. data = self.data
  302. if self.channels_order == ChannelsOrder.CHANNELS_LAST:
  303. data = data.permute(2, 0, 1)
  304. write_image(file_path, data)
  305. def print(self, max_width: int = 256) -> None:
  306. """Print the image tensor to the console.
  307. Args:
  308. max_width: the maximum width of the image to print.
  309. .. code-block:: python
  310. img = Image.from_file("panda.png")
  311. img.print()
  312. .. image:: https://github.com/kornia/data/blob/main/print_image.png?raw=true
  313. """
  314. print(image_to_string(self.data, max_width))