crop2d.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Optional, Tuple, Union
  18. import torch
  19. from kornia.constants import Resample
  20. from kornia.core import Module, Tensor, as_tensor, pad, tensor
  21. from kornia.core.check import KORNIA_CHECK_SHAPE
  22. from kornia.geometry.bbox import infer_bbox_shape, validate_bbox
  23. from .affwarp import resize
  24. from .imgwarp import get_perspective_transform, warp_affine
  25. __all__ = [
  26. "CenterCrop2D",
  27. "center_crop",
  28. "crop_and_resize",
  29. "crop_by_boxes",
  30. "crop_by_indices",
  31. "crop_by_transform_mat",
  32. ]
  33. def crop_and_resize(
  34. input_tensor: Tensor,
  35. boxes: Tensor,
  36. size: Tuple[int, int],
  37. mode: str = "bilinear",
  38. padding_mode: str = "zeros",
  39. align_corners: bool = True,
  40. ) -> Tensor:
  41. r"""Extract crops from 2D images (4D tensor) and resize given a bounding box.
  42. Args:
  43. input_tensor: the 2D image tensor with shape (B, C, H, W).
  44. boxes : a tensor containing the coordinates of the bounding boxes to be extracted.
  45. The tensor must have the shape of Bx4x2, where each box is defined in the following (clockwise)
  46. order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in the x, y order.
  47. The coordinates would compose a rectangle with a shape of (N1, N2).
  48. size: a tuple with the height and width that will be
  49. used to resize the extracted patches.
  50. mode: interpolation mode to calculate output values
  51. ``'bilinear'`` | ``'nearest'``.
  52. padding_mode: padding mode for outside grid values
  53. ``'zeros'`` | ``'border'`` | 'reflection'.
  54. align_corners: mode for grid_generation.
  55. Returns:
  56. Tensor: tensor containing the patches with shape BxCxN1xN2.
  57. Example:
  58. >>> input = torch.tensor([[[
  59. ... [1., 2., 3., 4.],
  60. ... [5., 6., 7., 8.],
  61. ... [9., 10., 11., 12.],
  62. ... [13., 14., 15., 16.],
  63. ... ]]])
  64. >>> boxes = torch.tensor([[
  65. ... [1., 1.],
  66. ... [2., 1.],
  67. ... [2., 2.],
  68. ... [1., 2.],
  69. ... ]]) # 1x4x2
  70. >>> crop_and_resize(input, boxes, (2, 2), mode='nearest', align_corners=True)
  71. tensor([[[[ 6., 7.],
  72. [10., 11.]]]])
  73. """
  74. if not isinstance(input_tensor, Tensor):
  75. raise TypeError(f"Input tensor type is not a Tensor. Got {type(input_tensor)}")
  76. if not isinstance(boxes, Tensor):
  77. raise TypeError(f"Input boxes type is not a Tensor. Got {type(boxes)}")
  78. if not isinstance(size, (tuple, list)) and len(size) == 2:
  79. raise ValueError(f"Input size must be a tuple/list of length 2. Got {size}")
  80. if len(input_tensor.shape) != 4:
  81. raise AssertionError(f"Only tensor with shape (B, C, H, W) supported. Got {input_tensor.shape}.")
  82. # unpack input data
  83. dst_h, dst_w = size
  84. # [x, y] origin
  85. # top-left, top-right, bottom-right, bottom-left
  86. points_src = boxes.to(input_tensor)
  87. # [x, y] destination
  88. # top-left, top-right, bottom-right, bottom-left
  89. points_dst = tensor(
  90. [[[0, 0], [dst_w - 1, 0], [dst_w - 1, dst_h - 1], [0, dst_h - 1]]],
  91. device=input_tensor.device,
  92. dtype=input_tensor.dtype,
  93. ).expand(points_src.shape[0], -1, -1)
  94. return crop_by_boxes(input_tensor, points_src, points_dst, mode, padding_mode, align_corners)
  95. def center_crop(
  96. input_tensor: Tensor,
  97. size: Tuple[int, int],
  98. mode: str = "bilinear",
  99. padding_mode: str = "zeros",
  100. align_corners: bool = True,
  101. ) -> Tensor:
  102. r"""Crop the 2D images (4D tensor) from the center.
  103. Args:
  104. input_tensor: the 2D image tensor with shape (B, C, H, W).
  105. size: a tuple with the expected height and width
  106. of the output patch.
  107. mode: interpolation mode to calculate output values
  108. ``'bilinear'`` | ``'nearest'``.
  109. padding_mode: padding mode for outside grid values
  110. ``'zeros'`` | ``'border'`` | ``'reflection'``.
  111. align_corners: mode for grid_generation.
  112. Returns:
  113. the output tensor with patches.
  114. Examples:
  115. >>> input = torch.tensor([[[
  116. ... [1., 2., 3., 4.],
  117. ... [5., 6., 7., 8.],
  118. ... [9., 10., 11., 12.],
  119. ... [13., 14., 15., 16.],
  120. ... ]]])
  121. >>> center_crop(input, (2, 4), mode='nearest', align_corners=True)
  122. tensor([[[[ 5., 6., 7., 8.],
  123. [ 9., 10., 11., 12.]]]])
  124. """
  125. if not isinstance(input_tensor, Tensor):
  126. raise TypeError(f"Input tensor type is not a Tensor. Got {type(input_tensor)}")
  127. if not isinstance(size, (tuple, list)) and len(size) == 2:
  128. raise ValueError(f"Input size must be a tuple/list of length 2. Got {size}")
  129. if len(input_tensor.shape) != 4:
  130. raise AssertionError(f"Only tensor with shape (B, C, H, W) supported. Got {input_tensor.shape}.")
  131. # unpack input sizes
  132. dst_h, dst_w = size
  133. src_h, src_w = input_tensor.shape[-2:]
  134. # compute start/end offsets
  135. dst_h_half: float = dst_h / 2
  136. dst_w_half: float = dst_w / 2
  137. src_h_half: float = src_h / 2
  138. src_w_half: float = src_w / 2
  139. start_x: float = src_w_half - dst_w_half
  140. start_y: float = src_h_half - dst_h_half
  141. end_x: float = start_x + dst_w - 1
  142. end_y: float = start_y + dst_h - 1
  143. # [y, x] origin
  144. # top-left, top-right, bottom-right, bottom-left
  145. points_src: Tensor = tensor(
  146. [[[start_x, start_y], [end_x, start_y], [end_x, end_y], [start_x, end_y]]],
  147. device=input_tensor.device,
  148. dtype=input_tensor.dtype,
  149. )
  150. # [y, x] destination
  151. # top-left, top-right, bottom-right, bottom-left
  152. points_dst: Tensor = tensor(
  153. [[[0, 0], [dst_w - 1, 0], [dst_w - 1, dst_h - 1], [0, dst_h - 1]]],
  154. device=input_tensor.device,
  155. dtype=input_tensor.dtype,
  156. ).expand(points_src.shape[0], -1, -1)
  157. return crop_by_boxes(input_tensor, points_src, points_dst, mode, padding_mode, align_corners)
  158. def crop_by_boxes(
  159. input_tensor: Tensor,
  160. src_box: Tensor,
  161. dst_box: Tensor,
  162. mode: str = "bilinear",
  163. padding_mode: str = "zeros",
  164. align_corners: bool = True,
  165. validate_boxes: bool = True,
  166. ) -> Tensor:
  167. """Perform crop transform on 2D images (4D tensor) given two bounding boxes.
  168. Given an input tensor, this function selected the interested areas by the provided bounding boxes (src_box).
  169. Then the selected areas would be fitted into the targeted bounding boxes (dst_box) by a perspective transformation.
  170. So far, the ragged tensor is not supported by PyTorch right now. This function hereby requires the bounding boxes
  171. in a batch must be rectangles with same width and height.
  172. Args:
  173. input_tensor: the 2D image tensor with shape (B, C, H, W).
  174. src_box: a tensor with shape (B, 4, 2) containing the coordinates of the bounding boxes
  175. to be extracted. The tensor must have the shape of Bx4x2, where each box is defined in the clockwise
  176. order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in x, y order.
  177. dst_box: a tensor with shape (B, 4, 2) containing the coordinates of the bounding boxes
  178. to be placed. The tensor must have the shape of Bx4x2, where each box is defined in the clockwise
  179. order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in x, y order.
  180. mode: interpolation mode to calculate output values
  181. ``'bilinear'`` | ``'nearest'``.
  182. padding_mode: padding mode for outside grid values
  183. ``'zeros'`` | ``'border'`` | ``'reflection'``.
  184. align_corners: mode for grid_generation.
  185. validate_boxes: flag to perform validation on boxes.
  186. Returns:
  187. Tensor: the output tensor with patches.
  188. Examples:
  189. >>> input = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))
  190. >>> src_box = torch.tensor([[
  191. ... [1., 1.],
  192. ... [2., 1.],
  193. ... [2., 2.],
  194. ... [1., 2.],
  195. ... ]]) # 1x4x2
  196. >>> dst_box = torch.tensor([[
  197. ... [0., 0.],
  198. ... [1., 0.],
  199. ... [1., 1.],
  200. ... [0., 1.],
  201. ... ]]) # 1x4x2
  202. >>> crop_by_boxes(input, src_box, dst_box, align_corners=True)
  203. tensor([[[[ 5.0000, 6.0000],
  204. [ 9.0000, 10.0000]]]])
  205. Note:
  206. If the src_box is smaller than dst_box, the following error will be thrown.
  207. RuntimeError: solve_cpu: For batch 0: U(2,2) is zero, singular U.
  208. """
  209. if validate_boxes:
  210. validate_bbox(src_box)
  211. validate_bbox(dst_box)
  212. if len(input_tensor.shape) != 4:
  213. raise AssertionError(f"Only tensor with shape (B, C, H, W) supported. Got {input_tensor.shape}.")
  214. # compute transformation between points and warp
  215. # Note: Tensor.dtype must be float. "solve_cpu" not implemented for 'Long'
  216. dst_trans_src: Tensor = get_perspective_transform(src_box.to(input_tensor), dst_box.to(input_tensor))
  217. bbox: Tuple[Tensor, Tensor] = infer_bbox_shape(dst_box)
  218. if not ((bbox[0] == bbox[0][0]).all() and (bbox[1] == bbox[1][0]).all()):
  219. raise AssertionError(
  220. f"Cropping height, width and depth must be exact same in a batch. Got height {bbox[0]} and width {bbox[1]}."
  221. )
  222. h_out: int = int(bbox[0][0].item())
  223. w_out: int = int(bbox[1][0].item())
  224. return crop_by_transform_mat(
  225. input_tensor, dst_trans_src, (h_out, w_out), mode=mode, padding_mode=padding_mode, align_corners=align_corners
  226. )
  227. def crop_by_transform_mat(
  228. input_tensor: Tensor,
  229. transform: Tensor,
  230. out_size: Tuple[int, int],
  231. mode: str = "bilinear",
  232. padding_mode: str = "zeros",
  233. align_corners: bool = True,
  234. ) -> Tensor:
  235. """Perform crop transform on 2D images (4D tensor) given a perspective transformation matrix.
  236. Args:
  237. input_tensor: the 2D image tensor with shape (B, C, H, W).
  238. transform: a perspective transformation matrix with shape (B, 3, 3).
  239. out_size: size of the output image (height, width).
  240. mode: interpolation mode to calculate output values
  241. ``'bilinear'`` | ``'nearest'``.
  242. padding_mode (str): padding mode for outside grid values
  243. ``'zeros'`` | ``'border'`` | ``'reflection'``.
  244. align_corners: mode for grid_generation.
  245. Returns:
  246. the output tensor with patches.
  247. """
  248. # simulate broadcasting
  249. dst_trans_src = as_tensor(
  250. transform.expand(input_tensor.shape[0], -1, -1), device=input_tensor.device, dtype=input_tensor.dtype
  251. )
  252. patches: Tensor = warp_affine(
  253. input_tensor,
  254. dst_trans_src[:, :2, :],
  255. out_size,
  256. mode=mode,
  257. padding_mode=padding_mode,
  258. align_corners=align_corners,
  259. )
  260. return patches
  261. def crop_by_indices(
  262. input_tensor: Tensor,
  263. src_box: Tensor,
  264. size: Optional[Tuple[int, int]] = None,
  265. interpolation: str = "bilinear",
  266. align_corners: Optional[bool] = None,
  267. antialias: bool = False,
  268. shape_compensation: str = "resize",
  269. ) -> Tensor:
  270. """Crop tensors with naive indices.
  271. Args:
  272. input_tensor: the 2D image tensor with shape (B, C, H, W).
  273. src_box: a tensor with shape (B, 4, 2) containing the coordinates of the bounding boxes
  274. to be extracted. The tensor must have the shape of Bx4x2, where each box is defined in the clockwise
  275. order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in x, y order.
  276. size: output size. An auto resize or pad will be performed according to ``shape_compensation``
  277. if the cropped slice sizes are not exactly align `size`.
  278. If None, will auto-infer from src_box.
  279. interpolation: algorithm used for upsampling: ``'nearest'`` | ``'linear'`` | ``'bilinear'`` |
  280. 'bicubic' | 'trilinear' | 'area'.
  281. align_corners: interpolation flag.
  282. antialias: if True, then image will be filtered with Gaussian before downscaling.
  283. No effect for upscaling.
  284. shape_compensation: if the cropped slice sizes are not exactly align `size`, the image can either be padded
  285. or resized.
  286. """
  287. KORNIA_CHECK_SHAPE(input_tensor, ["B", "C", "H", "W"])
  288. KORNIA_CHECK_SHAPE(src_box, ["B", "4", "2"])
  289. B, C, _, _ = input_tensor.shape
  290. src = as_tensor(src_box, device=input_tensor.device, dtype=torch.long)
  291. x1 = src[:, 0, 0]
  292. x2 = src[:, 1, 0] + 1
  293. y1 = src[:, 0, 1]
  294. y2 = src[:, 3, 1] + 1
  295. if (
  296. len(x1.unique(sorted=False))
  297. == len(x2.unique(sorted=False))
  298. == len(y1.unique(sorted=False))
  299. == len(y2.unique(sorted=False))
  300. == 1
  301. ):
  302. out = input_tensor[..., int(y1[0]) : int(y2[0]), int(x1[0]) : int(x2[0])]
  303. if size is not None and out.shape[-2:] != size:
  304. return resize(
  305. out, size, interpolation=interpolation, align_corners=align_corners, side="short", antialias=antialias
  306. )
  307. if size is None:
  308. h, w = infer_bbox_shape(src)
  309. size = h.unique(sorted=False), w.unique(sorted=False)
  310. out = torch.empty(B, C, *size, device=input_tensor.device, dtype=input_tensor.dtype)
  311. # Find out the cropped shapes that need to be resized.
  312. for i, _ in enumerate(out):
  313. _out = input_tensor[i : i + 1, :, int(y1[i]) : int(y2[i]), int(x1[i]) : int(x2[i])]
  314. if _out.shape[-2:] != size:
  315. if shape_compensation == "resize":
  316. out[i] = resize(
  317. _out,
  318. size,
  319. interpolation=interpolation,
  320. align_corners=align_corners,
  321. side="short",
  322. antialias=antialias,
  323. )
  324. else:
  325. out[i] = pad(_out, [0, size[1] - _out.shape[-1], 0, size[0] - _out.shape[-2]])
  326. else:
  327. out[i] = _out
  328. return out
  329. class CenterCrop2D(Module):
  330. """Center crop the input tensor.
  331. Args:
  332. size: Size (h, w) in pixels of the resized region or just one side.
  333. align_corners: interpolation flag.
  334. resample: Resampling mode.
  335. cropping_mode: Cropping mode, "resample" or "slice".
  336. Note:
  337. For JIT, the cropping mode must be "resample".
  338. """
  339. def __init__(
  340. self,
  341. size: Union[int, Tuple[int, int]],
  342. align_corners: bool = True,
  343. resample: Union[str, int, Resample] = Resample.BILINEAR.name,
  344. cropping_mode: str = "slice",
  345. ) -> None:
  346. super().__init__()
  347. if isinstance(size, tuple):
  348. self.size = (size[0], size[1])
  349. elif isinstance(size, int):
  350. self.size = (size, size)
  351. else:
  352. raise Exception(f"Invalid size type. Expected (int, tuple(int, int). Got: {type(size)}.")
  353. dst_h, dst_w = self.size
  354. points_dst = torch.tensor([[[0, 0], [dst_w - 1, 0], [dst_w - 1, dst_h - 1], [0, dst_h - 1]]], dtype=torch.long)
  355. self.register_buffer("points_src", points_dst.clone())
  356. self.register_buffer("points_dst", points_dst)
  357. self.flags = {
  358. "resample": Resample.get(resample),
  359. "cropping_mode": cropping_mode,
  360. "align_corners": align_corners,
  361. "size": self.size,
  362. "padding_mode": "zeros",
  363. }
  364. def forward(self, input: Tensor) -> Tensor:
  365. batch_size = input.shape[0]
  366. dst_h, dst_w = self.size
  367. src_h, src_w = input.shape[-2:]
  368. dst_h_half, dst_w_half = dst_h / 2, dst_w / 2
  369. src_h_half, src_w_half = src_h / 2, src_w / 2
  370. start_x, start_y = int(src_w_half - dst_w_half), int(src_h_half - dst_h_half)
  371. end_x, end_y = start_x + dst_w - 1, start_y + dst_h - 1
  372. # [y, x] origin
  373. # top-left, top-right, bottom-right, bottom-left
  374. self.points_src[0, 0, 0] = start_x
  375. self.points_src[0, 0, 1] = start_y
  376. self.points_src[0, 1, 0] = end_x
  377. self.points_src[0, 1, 1] = start_y
  378. self.points_src[0, 2, 0] = end_x
  379. self.points_src[0, 2, 1] = end_y
  380. self.points_src[0, 3, 0] = start_x
  381. self.points_src[0, 3, 1] = end_y
  382. if self.flags["cropping_mode"] == "resample": # uses bilinear interpolation to crop
  383. transform = get_perspective_transform(
  384. self.points_src.expand(batch_size, -1, -1).to(input),
  385. self.points_dst.expand(batch_size, -1, -1).to(input),
  386. )
  387. transform = transform.expand(batch_size, -1, -1)
  388. return crop_by_transform_mat(
  389. input,
  390. transform[:, :2, :],
  391. self.size,
  392. self.flags["resample"].name.lower(), # type:ignore
  393. "zeros",
  394. self.flags["align_corners"], # type:ignore
  395. )
  396. if self.flags["cropping_mode"] == "slice": # uses advanced slicing to crop
  397. return crop_by_indices(input, self.points_src.expand(batch_size, -1, -1).to(input), self.flags["size"]) # type:ignore
  398. raise NotImplementedError(f"Not supported type: {self.flags['cropping_mode']}.")