image.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from functools import wraps
  18. from typing import Any, Callable, List, Optional
  19. import torch
  20. from torch import nn
  21. from torch.nn import functional as F
  22. from kornia.core import Tensor
  23. def image_to_tensor(image: Any, keepdim: bool = True) -> Tensor:
  24. """Convert a numpy image to a PyTorch 4d tensor image.
  25. Args:
  26. image: image of the form :math:`(H, W, C)`, :math:`(H, W)` or
  27. :math:`(B, H, W, C)`.
  28. keepdim: If ``False`` unsqueeze the input image to match the shape
  29. :math:`(B, H, W, C)`.
  30. Returns:
  31. tensor of the form :math:`(B, C, H, W)` if keepdim is ``False``,
  32. :math:`(C, H, W)` otherwise.
  33. Example:
  34. >>> img = np.ones((3, 3))
  35. >>> image_to_tensor(img).shape
  36. torch.Size([1, 3, 3])
  37. >>> img = np.ones((4, 4, 1))
  38. >>> image_to_tensor(img).shape
  39. torch.Size([1, 4, 4])
  40. >>> img = np.ones((4, 4, 3))
  41. >>> image_to_tensor(img, keepdim=False).shape
  42. torch.Size([1, 3, 4, 4])
  43. """
  44. if len(image.shape) > 4 or len(image.shape) < 2:
  45. raise ValueError("Input size must be a two, three or four dimensional array")
  46. input_shape = image.shape
  47. tensor: Tensor = torch.from_numpy(image)
  48. if len(input_shape) == 2:
  49. # (H, W) -> (1, H, W)
  50. tensor = tensor.unsqueeze(0)
  51. elif len(input_shape) == 3:
  52. # (H, W, C) -> (C, H, W)
  53. tensor = tensor.permute(2, 0, 1)
  54. elif len(input_shape) == 4:
  55. # (B, H, W, C) -> (B, C, H, W)
  56. tensor = tensor.permute(0, 3, 1, 2)
  57. keepdim = True # no need to unsqueeze
  58. else:
  59. raise ValueError(f"Cannot process image with shape {input_shape}")
  60. return tensor.unsqueeze(0) if not keepdim else tensor
  61. def image_list_to_tensor(images: List[Any]) -> Tensor:
  62. """Convert a list of numpy images to a PyTorch 4d tensor image.
  63. Args:
  64. images: list of images, each of the form :math:`(H, W, C)`.
  65. Image shapes must be consistent
  66. Returns:
  67. tensor of the form :math:`(B, C, H, W)`.
  68. Example:
  69. >>> imgs = [np.ones((4, 4, 1)), np.zeros((4, 4, 1))]
  70. >>> image_list_to_tensor(imgs).shape
  71. torch.Size([2, 1, 4, 4])
  72. """
  73. if not images:
  74. raise ValueError("Input list of images is empty")
  75. images_t = []
  76. for img in images:
  77. if not torch.is_tensor(img):
  78. img = torch.as_tensor(img)
  79. images_t.append(img)
  80. shape = images_t[0].shape
  81. if len(shape) != 3:
  82. raise ValueError("Each image must have shape (H, W, C)")
  83. if any(img.shape != shape for img in images_t):
  84. raise ValueError("All images must have the same shape")
  85. # Stack into (N, H, W, C) then permute to (N, C, H, W)
  86. return torch.stack(images_t, dim=0).permute(0, 3, 1, 2)
  87. def _to_bchw(tensor: Tensor) -> Tensor:
  88. """Convert a PyTorch tensor image to BCHW format.
  89. Args:
  90. tensor (torch.Tensor): image of the form :math:`(*, H, W)`.
  91. Returns:
  92. input tensor of the form :math:`(B, C, H, W)`.
  93. """
  94. if not isinstance(tensor, Tensor):
  95. raise TypeError(f"Input type is not a Tensor. Got {type(tensor)}")
  96. if len(tensor.shape) < 2:
  97. raise ValueError(f"Input size must be a two, three or four dimensional tensor. Got {tensor.shape}")
  98. if len(tensor.shape) == 2:
  99. tensor = tensor.unsqueeze(0)
  100. if len(tensor.shape) == 3:
  101. tensor = tensor.unsqueeze(0)
  102. if len(tensor.shape) > 4:
  103. tensor = tensor.view(-1, tensor.shape[-3], tensor.shape[-2], tensor.shape[-1])
  104. return tensor
  105. def _to_bcdhw(tensor: Tensor) -> Tensor:
  106. """Convert a PyTorch tensor image to BCDHW format.
  107. Args:
  108. tensor (torch.Tensor): image of the form :math:`(*, D, H, W)`.
  109. Returns:
  110. input tensor of the form :math:`(B, C, D, H, W)`.
  111. """
  112. if not isinstance(tensor, Tensor):
  113. raise TypeError(f"Input type is not a Tensor. Got {type(tensor)}")
  114. if len(tensor.shape) < 3:
  115. raise ValueError(f"Input size must be a three, four or five dimensional tensor. Got {tensor.shape}")
  116. if len(tensor.shape) == 3:
  117. tensor = tensor.unsqueeze(0)
  118. if len(tensor.shape) == 4:
  119. tensor = tensor.unsqueeze(0)
  120. if len(tensor.shape) > 5:
  121. tensor = tensor.view(-1, tensor.shape[-4], tensor.shape[-3], tensor.shape[-2], tensor.shape[-1])
  122. return tensor
  123. def tensor_to_image(tensor: Tensor, keepdim: bool = False, force_contiguous: bool = False) -> Any:
  124. """Convert a PyTorch tensor image to a numpy image.
  125. In case the tensor is in the GPU, it will be copied back to CPU.
  126. Args:
  127. tensor: image of the form :math:`(H, W)`, :math:`(C, H, W)` or
  128. :math:`(B, C, H, W)`.
  129. keepdim: If ``False`` squeeze the input image to match the shape
  130. :math:`(H, W, C)` or :math:`(H, W)`.
  131. force_contiguous: If ``True`` call `contiguous` to the tensor before
  132. Returns:
  133. image of the form :math:`(H, W)`, :math:`(H, W, C)` or :math:`(B, H, W, C)`.
  134. Example:
  135. >>> img = torch.ones(1, 3, 3)
  136. >>> tensor_to_image(img).shape
  137. (3, 3)
  138. >>> img = torch.ones(3, 4, 4)
  139. >>> tensor_to_image(img).shape
  140. (4, 4, 3)
  141. """
  142. if not isinstance(tensor, Tensor):
  143. raise TypeError(f"Input type is not a Tensor. Got {type(tensor)}")
  144. if len(tensor.shape) > 4 or len(tensor.shape) < 2:
  145. raise ValueError("Input size must be a two, three or four dimensional tensor")
  146. input_shape = tensor.shape
  147. image = tensor.cpu().detach()
  148. if len(input_shape) == 2:
  149. # (H, W) -> (H, W)
  150. pass
  151. elif len(input_shape) == 3:
  152. # (C, H, W) -> (H, W, C)
  153. if input_shape[0] == 1:
  154. # Grayscale for proper plt.imshow needs to be (H,W)
  155. image = image.squeeze()
  156. else:
  157. image = image.permute(1, 2, 0)
  158. elif len(input_shape) == 4:
  159. # (B, C, H, W) -> (B, H, W, C)
  160. image = image.permute(0, 2, 3, 1)
  161. if input_shape[0] == 1 and not keepdim:
  162. image = image.squeeze(0)
  163. if input_shape[1] == 1:
  164. image = image.squeeze(-1)
  165. else:
  166. raise ValueError(f"Cannot process tensor with shape {input_shape}")
  167. # make sure the image is contiguous
  168. if force_contiguous:
  169. image = image.contiguous()
  170. return image.numpy()
  171. class ImageToTensor(nn.Module):
  172. """Converts a numpy image to a PyTorch 4d tensor image.
  173. Args:
  174. keepdim: If ``False`` unsqueeze the input image to match the shape :math:`(B, H, W, C)`.
  175. """
  176. def __init__(self, keepdim: bool = False) -> None:
  177. super().__init__()
  178. self.keepdim = keepdim
  179. def forward(self, x: Any) -> Tensor:
  180. return image_to_tensor(x, keepdim=self.keepdim)
  181. def make_grid(tensor: Tensor, n_row: Optional[int] = None, padding: int = 2) -> Tensor:
  182. """Convert a batched tensor to one image with padding in between.
  183. Args:
  184. tensor: A batched tensor of shape (B, C, H, W).
  185. n_row: Number of images displayed in each row of the grid.
  186. padding: The amount of padding to add between images.
  187. Returns:
  188. Tensor: The combined image grid.
  189. """
  190. if not isinstance(tensor, torch.Tensor):
  191. raise TypeError("Input tensor must be a PyTorch tensor.")
  192. B, C, H, W = tensor.shape
  193. if n_row is None:
  194. n_row = int(torch.sqrt(torch.tensor(B, dtype=torch.float32)).ceil().item())
  195. n_col = (B + n_row - 1) // n_row
  196. padded_H = H + padding
  197. padded_W = W + padding
  198. # pad each image on right and bottom with `padding` zeros
  199. tensor_padded = F.pad(tensor, (0, padding, 0, padding))
  200. total = n_row * n_col
  201. if total > B:
  202. pad_tiles = torch.zeros((total - B, C, padded_H, padded_W), dtype=tensor.dtype, device=tensor.device)
  203. tensor_padded = torch.cat((tensor_padded, pad_tiles), dim=0)
  204. # ensure contiguous memory layout before reshaping / permuting
  205. tensor_padded = tensor_padded.contiguous()
  206. # reshape into (n_row, n_col, C, padded_H, padded_W)
  207. grid = tensor_padded.view(n_row, n_col, C, padded_H, padded_W)
  208. # permute to (C, n_row, padded_H, n_col, padded_W) then collapse
  209. grid = grid.permute(2, 0, 3, 1, 4).contiguous()
  210. combined = grid.view(C, n_row * padded_H, n_col * padded_W)
  211. # crop trailing right/bottom padding to match original
  212. combined_H = n_row * padded_H - padding
  213. combined_W = n_col * padded_W - padding
  214. combined = combined[:, :combined_H, :combined_W]
  215. return combined
  216. def perform_keep_shape_image(f: Callable[..., Tensor]) -> Callable[..., Tensor]:
  217. """Apply `f` to an image of arbitrary leading dimensions `(*, C, H, W)`.
  218. It works by first viewing the image as `(B, C, H, W)`, applying the function and re-viewing the image as original
  219. shape.
  220. """
  221. @wraps(f)
  222. def _wrapper(input: Tensor, *args: Any, **kwargs: Any) -> Tensor:
  223. if not isinstance(input, Tensor):
  224. raise TypeError(f"Input input type is not a Tensor. Got {type(input)}")
  225. if input.shape.numel() == 0:
  226. raise ValueError("Invalid input tensor, it is empty.")
  227. input_shape = input.shape
  228. input = _to_bchw(input) # view input as (B, C, H, W)
  229. output = f(input, *args, **kwargs)
  230. if len(input_shape) == 3:
  231. output = output[0]
  232. if len(input_shape) == 2:
  233. output = output[0, 0]
  234. if len(input_shape) > 4:
  235. output = output.view(*(input_shape[:-3] + output.shape[-3:]))
  236. return output
  237. return _wrapper
  238. def perform_keep_shape_video(f: Callable[..., Tensor]) -> Callable[..., Tensor]:
  239. """Apply `f` to an image of arbitrary leading dimensions `(*, C, D, H, W)`.
  240. It works by first viewing the image as `(B, C, D, H, W)`, applying the function and re-viewing the image as original
  241. shape.
  242. """
  243. @wraps(f)
  244. def _wrapper(input: Tensor, *args: Any, **kwargs: Any) -> Tensor:
  245. if not isinstance(input, Tensor):
  246. raise TypeError(f"Input input type is not a Tensor. Got {type(input)}")
  247. if input.numel() == 0:
  248. raise ValueError("Invalid input tensor, it is empty.")
  249. input_shape = input.shape
  250. input = _to_bcdhw(input) # view input as (B, C, D, H, W)
  251. output = f(input, *args, **kwargs)
  252. if len(input_shape) == 4:
  253. output = output[0]
  254. if len(input_shape) == 3:
  255. output = output[0, 0]
  256. if len(input_shape) > 5:
  257. output = output.view(*(input_shape[:-4] + output.shape[-4:]))
  258. return output
  259. return _wrapper