affwarp.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import warnings
  18. from typing import Optional, Tuple, Union
  19. import torch
  20. from kornia.core import ImageModule as Module
  21. from kornia.core import Tensor, ones, ones_like, zeros
  22. from kornia.filters import gaussian_blur2d
  23. from kornia.utils import _extract_device_dtype
  24. from kornia.utils.image import perform_keep_shape_image
  25. from kornia.utils.misc import eye_like
  26. from .imgwarp import get_affine_matrix2d, get_projective_transform, get_rotation_matrix2d, warp_affine, warp_affine3d
  27. __all__ = [
  28. "Affine",
  29. "Rescale",
  30. "Resize",
  31. "Rotate",
  32. "Scale",
  33. "Shear",
  34. "Translate",
  35. "affine",
  36. "affine3d",
  37. "rescale",
  38. "resize",
  39. "resize_to_be_divisible",
  40. "rotate",
  41. "rotate3d",
  42. "scale",
  43. "shear",
  44. "translate",
  45. ]
  46. # utilities to compute affine matrices
  47. def _compute_tensor_center(tensor: Tensor) -> Tensor:
  48. """Compute the center of tensor plane for (H, W), (C, H, W) and (B, C, H, W)."""
  49. if not 2 <= len(tensor.shape) <= 4:
  50. raise AssertionError(f"Must be a 3D tensor as HW, CHW and BCHW. Got {tensor.shape}.")
  51. height, width = tensor.shape[-2:]
  52. center_x: float = float(width - 1) / 2
  53. center_y: float = float(height - 1) / 2
  54. center: Tensor = torch.tensor([center_x, center_y], device=tensor.device, dtype=tensor.dtype)
  55. return center
  56. def _compute_tensor_center3d(tensor: Tensor) -> Tensor:
  57. """Compute the center of tensor plane for (D, H, W), (C, D, H, W) and (B, C, D, H, W)."""
  58. if not 3 <= len(tensor.shape) <= 5:
  59. raise AssertionError(f"Must be a 3D tensor as DHW, CDHW and BCDHW. Got {tensor.shape}.")
  60. depth, height, width = tensor.shape[-3:]
  61. center_x: float = float(width - 1) / 2
  62. center_y: float = float(height - 1) / 2
  63. center_z: float = float(depth - 1) / 2
  64. center: Tensor = torch.tensor([center_x, center_y, center_z], device=tensor.device, dtype=tensor.dtype)
  65. return center
  66. def _compute_rotation_matrix(angle: Tensor, center: Tensor) -> Tensor:
  67. """Compute a pure affine rotation matrix."""
  68. scale: Tensor = ones_like(center)
  69. matrix: Tensor = get_rotation_matrix2d(center, angle, scale)
  70. return matrix
  71. def _compute_rotation_matrix3d(yaw: Tensor, pitch: Tensor, roll: Tensor, center: Tensor) -> Tensor:
  72. """Compute a pure affine rotation matrix."""
  73. if len(yaw.shape) == len(pitch.shape) == len(roll.shape) == 0:
  74. yaw = yaw.unsqueeze(dim=0)
  75. pitch = pitch.unsqueeze(dim=0)
  76. roll = roll.unsqueeze(dim=0)
  77. if len(yaw.shape) == len(pitch.shape) == len(roll.shape) == 1:
  78. yaw = yaw.unsqueeze(dim=1)
  79. pitch = pitch.unsqueeze(dim=1)
  80. roll = roll.unsqueeze(dim=1)
  81. if not (len(yaw.shape) == len(pitch.shape) == len(roll.shape) == 2):
  82. raise AssertionError(f"Expected yaw, pitch, roll to be (B, 1). Got {yaw.shape}, {pitch.shape}, {roll.shape}.")
  83. angles: Tensor = torch.cat([yaw, pitch, roll], dim=1)
  84. scales: Tensor = ones_like(yaw)
  85. matrix: Tensor = get_projective_transform(center, angles, scales)
  86. return matrix
  87. def _compute_translation_matrix(translation: Tensor) -> Tensor:
  88. """Compute affine matrix for translation."""
  89. matrix: Tensor = eye_like(3, translation, shared_memory=False)
  90. dx, dy = torch.chunk(translation, chunks=2, dim=-1)
  91. matrix[..., 0, 2:3] += dx
  92. matrix[..., 1, 2:3] += dy
  93. return matrix
  94. def _compute_scaling_matrix(scale: Tensor, center: Tensor) -> Tensor:
  95. """Compute affine matrix for scaling."""
  96. angle: Tensor = zeros(scale.shape[:1], device=scale.device, dtype=scale.dtype)
  97. matrix: Tensor = get_rotation_matrix2d(center, angle, scale)
  98. return matrix
  99. def _compute_shear_matrix(shear: Tensor) -> Tensor:
  100. """Compute affine matrix for shearing."""
  101. matrix: Tensor = eye_like(3, shear, shared_memory=False)
  102. shx, shy = torch.chunk(shear, chunks=2, dim=-1)
  103. matrix[..., 0, 1:2] += shx
  104. matrix[..., 1, 0:1] += shy
  105. return matrix
  106. # based on:
  107. # https://github.com/anibali/tvl/blob/master/src/tvl/transforms.py#L166
  108. def affine(
  109. tensor: Tensor,
  110. matrix: Tensor,
  111. mode: str = "bilinear",
  112. padding_mode: str = "zeros",
  113. align_corners: bool = True,
  114. ) -> Tensor:
  115. r"""Apply an affine transformation to the image.
  116. .. image:: _static/img/warp_affine.png
  117. Args:
  118. tensor: The image tensor to be warped in shapes of
  119. :math:`(H, W)`, :math:`(D, H, W)` and :math:`(B, C, H, W)`.
  120. matrix: The 2x3 affine transformation matrix.
  121. mode: interpolation mode to calculate output values ``'bilinear'`` | ``'nearest'``.
  122. padding_mode: padding mode for outside grid values
  123. ``'zeros'`` | ``'border'`` | ``'reflection'``.
  124. align_corners: interpolation flag.
  125. Returns:
  126. The warped image with the same shape as the input.
  127. Example:
  128. >>> img = torch.rand(1, 2, 3, 5)
  129. >>> aff = torch.eye(2, 3)[None]
  130. >>> out = affine(img, aff)
  131. >>> print(out.shape)
  132. torch.Size([1, 2, 3, 5])
  133. """
  134. # warping needs data in the shape of BCHW
  135. is_unbatched: bool = tensor.ndimension() == 3
  136. if is_unbatched:
  137. tensor = torch.unsqueeze(tensor, dim=0)
  138. # we enforce broadcasting since by default grid_sample it does not
  139. # give support for that
  140. matrix = matrix.expand(tensor.shape[0], -1, -1)
  141. # warp the input tensor
  142. height: int = tensor.shape[-2]
  143. width: int = tensor.shape[-1]
  144. warped: Tensor = warp_affine(tensor, matrix, (height, width), mode, padding_mode, align_corners)
  145. # return in the original shape
  146. if is_unbatched:
  147. warped = torch.squeeze(warped, dim=0)
  148. return warped
  149. def affine3d(
  150. tensor: Tensor,
  151. matrix: Tensor,
  152. mode: str = "bilinear",
  153. padding_mode: str = "zeros",
  154. align_corners: bool = False,
  155. ) -> Tensor:
  156. r"""Apply an affine transformation to the 3d volume.
  157. Args:
  158. tensor: The image tensor to be warped in shapes of
  159. :math:`(D, H, W)`, :math:`(C, D, H, W)` and :math:`(B, C, D, H, W)`.
  160. matrix: The affine transformation matrix with shape :math:`(B, 3, 4)`.
  161. mode: interpolation mode to calculate output values
  162. ``'bilinear'`` | ``'nearest'``.
  163. padding_mode: padding mode for outside grid values
  164. `` 'zeros'`` | ``'border'`` | ``'reflection'``.
  165. align_corners: interpolation flag.
  166. Returns:
  167. The warped image.
  168. Example:
  169. >>> img = torch.rand(1, 2, 4, 3, 5)
  170. >>> aff = torch.eye(3, 4)[None]
  171. >>> out = affine3d(img, aff)
  172. >>> print(out.shape)
  173. torch.Size([1, 2, 4, 3, 5])
  174. """
  175. # warping needs data in the shape of BCDHW
  176. is_unbatched: bool = tensor.ndimension() == 4
  177. if is_unbatched:
  178. tensor = torch.unsqueeze(tensor, dim=0)
  179. # we enforce broadcasting since by default grid_sample it does not
  180. # give support for that
  181. matrix = matrix.expand(tensor.shape[0], -1, -1)
  182. # warp the input tensor
  183. depth: int = tensor.shape[-3]
  184. height: int = tensor.shape[-2]
  185. width: int = tensor.shape[-1]
  186. warped: Tensor = warp_affine3d(tensor, matrix, (depth, height, width), mode, padding_mode, align_corners)
  187. # return in the original shape
  188. if is_unbatched:
  189. warped = torch.squeeze(warped, dim=0)
  190. return warped
  191. # based on:
  192. # https://github.com/anibali/tvl/blob/master/src/tvl/transforms.py#L185
  193. def rotate(
  194. tensor: Tensor,
  195. angle: Tensor,
  196. center: Union[None, Tensor] = None,
  197. mode: str = "bilinear",
  198. padding_mode: str = "zeros",
  199. align_corners: bool = True,
  200. ) -> Tensor:
  201. r"""Rotate the tensor anti-clockwise about the center.
  202. .. image:: _static/img/rotate.png
  203. Args:
  204. tensor: The image tensor to be warped in shapes of :math:`(B, C, H, W)`.
  205. angle: The angle through which to rotate. The tensor
  206. must have a shape of (B), where B is batch size.
  207. center: The center through which to rotate. The tensor
  208. must have a shape of (B, 2), where B is batch size and last
  209. dimension contains cx and cy.
  210. mode: interpolation mode to calculate output values
  211. ``'bilinear'`` | ``'nearest'``.
  212. padding_mode: padding mode for outside grid values
  213. ``'zeros'`` | ``'border'`` | ``'reflection'``.
  214. align_corners: interpolation flag.
  215. Returns:
  216. The rotated tensor with shape as input.
  217. .. note::
  218. See a working example `here <https://kornia.github.io/tutorials/nbs/rotate_affine.html>`__.
  219. Example:
  220. >>> img = torch.rand(1, 3, 4, 4)
  221. >>> angle = torch.tensor([90.])
  222. >>> out = rotate(img, angle)
  223. >>> print(out.shape)
  224. torch.Size([1, 3, 4, 4])
  225. """
  226. if not isinstance(tensor, Tensor):
  227. raise TypeError(f"Input tensor type is not a Tensor. Got {type(tensor)}")
  228. if not isinstance(angle, Tensor):
  229. raise TypeError(f"Input angle type is not a Tensor. Got {type(angle)}")
  230. if center is not None and not isinstance(center, Tensor):
  231. raise TypeError(f"Input center type is not a Tensor. Got {type(center)}")
  232. if len(tensor.shape) not in (3, 4):
  233. raise ValueError(f"Invalid tensor shape, we expect CxHxW or BxCxHxW. Got: {tensor.shape}")
  234. # compute the rotation center
  235. if center is None:
  236. center = _compute_tensor_center(tensor)
  237. # compute the rotation matrix
  238. # TODO: add broadcasting to get_rotation_matrix2d for center
  239. angle = angle.expand(tensor.shape[0])
  240. center = center.expand(tensor.shape[0], -1)
  241. rotation_matrix: Tensor = _compute_rotation_matrix(angle, center)
  242. # warp using the affine transform
  243. return affine(tensor, rotation_matrix[..., :2, :3], mode, padding_mode, align_corners)
  244. def rotate3d(
  245. tensor: Tensor,
  246. yaw: Tensor,
  247. pitch: Tensor,
  248. roll: Tensor,
  249. center: Union[None, Tensor] = None,
  250. mode: str = "bilinear",
  251. padding_mode: str = "zeros",
  252. align_corners: bool = False,
  253. ) -> Tensor:
  254. r"""Rotate 3D the tensor anti-clockwise about the centre.
  255. Args:
  256. tensor: The image tensor to be warped in shapes of :math:`(B, C, D, H, W)`.
  257. yaw: The yaw angle through which to rotate. The tensor
  258. must have a shape of (B), where B is batch size.
  259. pitch: The pitch angle through which to rotate. The tensor
  260. must have a shape of (B), where B is batch size.
  261. roll: The roll angle through which to rotate. The tensor
  262. must have a shape of (B), where B is batch size.
  263. center: The center through which to rotate. The tensor
  264. must have a shape of (B, 2), where B is batch size and last
  265. dimension contains cx and cy.
  266. mode: interpolation mode to calculate output values
  267. ``'bilinear'`` | ``'nearest'``.
  268. padding_mode: padding mode for outside grid values
  269. ``'zeros'`` | ``'border'`` | ``'reflection'``.
  270. align_corners: interpolation flag.
  271. Returns:
  272. Tensor: The rotated tensor with shape as input.
  273. """
  274. if not isinstance(tensor, Tensor):
  275. raise TypeError(f"Input tensor type is not a Tensor. Got {type(tensor)}")
  276. if not isinstance(yaw, Tensor):
  277. raise TypeError(f"yaw is not a Tensor. Got {type(yaw)}")
  278. if not isinstance(pitch, Tensor):
  279. raise TypeError(f"pitch is not a Tensor. Got {type(pitch)}")
  280. if not isinstance(roll, Tensor):
  281. raise TypeError(f"roll is not a Tensor. Got {type(roll)}")
  282. if center is not None and not isinstance(center, Tensor):
  283. raise TypeError(f"Input center type is not a Tensor. Got {type(center)}")
  284. if len(tensor.shape) not in (4, 5):
  285. raise ValueError(f"Invalid tensor shape, we expect CxDxHxW or BxCxDxHxW. Got: {tensor.shape}")
  286. # compute the rotation center
  287. if center is None:
  288. center = _compute_tensor_center3d(tensor)
  289. # compute the rotation matrix
  290. # TODO: add broadcasting to get_rotation_matrix2d for center
  291. yaw = yaw.expand(tensor.shape[0])
  292. pitch = pitch.expand(tensor.shape[0])
  293. roll = roll.expand(tensor.shape[0])
  294. center = center.expand(tensor.shape[0], -1)
  295. rotation_matrix: Tensor = _compute_rotation_matrix3d(yaw, pitch, roll, center)
  296. # warp using the affine transform
  297. return affine3d(tensor, rotation_matrix[..., :3, :4], mode, padding_mode, align_corners)
  298. def translate(
  299. tensor: Tensor,
  300. translation: Tensor,
  301. mode: str = "bilinear",
  302. padding_mode: str = "zeros",
  303. align_corners: bool = True,
  304. ) -> Tensor:
  305. r"""Translate the tensor in pixel units.
  306. .. image:: _static/img/translate.png
  307. Args:
  308. tensor: The image tensor to be warped in shapes of :math:`(B, C, H, W)`.
  309. translation: tensor containing the amount of pixels to
  310. translate in the x and y direction. The tensor must have a shape of
  311. (B, 2), where B is batch size, last dimension contains dx dy.
  312. mode: interpolation mode to calculate output values
  313. ``'bilinear'`` | ``'nearest'``.
  314. padding_mode: padding mode for outside grid values
  315. ``'zeros'`` | ``'border'`` | ``'reflection'``.
  316. align_corners: interpolation flag.
  317. Returns:
  318. The translated tensor with shape as input.
  319. Example:
  320. >>> img = torch.rand(1, 3, 4, 4)
  321. >>> translation = torch.tensor([[1., 0.]])
  322. >>> out = translate(img, translation)
  323. >>> print(out.shape)
  324. torch.Size([1, 3, 4, 4])
  325. """
  326. if not isinstance(tensor, Tensor):
  327. raise TypeError(f"Input tensor type is not a Tensor. Got {type(tensor)}")
  328. if not isinstance(translation, Tensor):
  329. raise TypeError(f"Input translation type is not a Tensor. Got {type(translation)}")
  330. if len(tensor.shape) not in (3, 4):
  331. raise ValueError(f"Invalid tensor shape, we expect CxHxW or BxCxHxW. Got: {tensor.shape}")
  332. # compute the translation matrix
  333. translation_matrix: Tensor = _compute_translation_matrix(translation)
  334. # warp using the affine transform
  335. return affine(tensor, translation_matrix[..., :2, :3], mode, padding_mode, align_corners)
  336. def scale(
  337. tensor: Tensor,
  338. scale_factor: Tensor,
  339. center: Union[None, Tensor] = None,
  340. mode: str = "bilinear",
  341. padding_mode: str = "zeros",
  342. align_corners: bool = True,
  343. ) -> Tensor:
  344. r"""Scale the tensor by a factor.
  345. .. image:: _static/img/scale.png
  346. Args:
  347. tensor: The image tensor to be warped in shapes of :math:`(B, C, H, W)`.
  348. scale_factor: The scale factor apply. The tensor
  349. must have a shape of (B) or (B, 2), where B is batch size.
  350. If (B), isotropic scaling will perform.
  351. If (B, 2), x-y-direction specific scaling will perform.
  352. center: The center through which to scale. The tensor
  353. must have a shape of (B, 2), where B is batch size and last
  354. dimension contains cx and cy.
  355. mode: interpolation mode to calculate output values
  356. ``'bilinear'`` | ``'nearest'``.
  357. padding_mode: padding mode for outside grid values
  358. ``'zeros'`` | ``'border'`` | ``'reflection'``.
  359. align_corners: interpolation flag.
  360. Returns:
  361. The scaled tensor with the same shape as the input.
  362. Example:
  363. >>> img = torch.rand(1, 3, 4, 4)
  364. >>> scale_factor = torch.tensor([[2., 2.]])
  365. >>> out = scale(img, scale_factor)
  366. >>> print(out.shape)
  367. torch.Size([1, 3, 4, 4])
  368. """
  369. if not isinstance(tensor, Tensor):
  370. raise TypeError(f"Input tensor type is not a Tensor. Got {type(tensor)}")
  371. if not isinstance(scale_factor, Tensor):
  372. raise TypeError(f"Input scale_factor type is not a Tensor. Got {type(scale_factor)}")
  373. if len(scale_factor.shape) == 1:
  374. # convert isotropic scaling to x-y direction
  375. scale_factor = scale_factor.repeat(1, 2)
  376. # compute the tensor center
  377. if center is None:
  378. center = _compute_tensor_center(tensor)
  379. # compute the rotation matrix
  380. # TODO: add broadcasting to get_rotation_matrix2d for center
  381. center = center.expand(tensor.shape[0], -1)
  382. scale_factor = scale_factor.expand(tensor.shape[0], 2)
  383. scaling_matrix: Tensor = _compute_scaling_matrix(scale_factor, center)
  384. # warp using the affine transform
  385. return affine(tensor, scaling_matrix[..., :2, :3], mode, padding_mode, align_corners)
  386. def shear(
  387. tensor: Tensor,
  388. shear: Tensor,
  389. mode: str = "bilinear",
  390. padding_mode: str = "zeros",
  391. align_corners: bool = False,
  392. ) -> Tensor:
  393. r"""Shear the tensor.
  394. .. image:: _static/img/shear.png
  395. Args:
  396. tensor: The image tensor to be skewed with shape of :math:`(B, C, H, W)`.
  397. shear: tensor containing the angle to shear
  398. in the x and y direction. The tensor must have a shape of
  399. (B, 2), where B is batch size, last dimension contains shx shy.
  400. mode: interpolation mode to calculate output values
  401. ``'bilinear'`` | ``'nearest'``.
  402. padding_mode: padding mode for outside grid values
  403. ``'zeros'`` | ``'border'`` | ``'reflection'``.
  404. align_corners: interpolation flag.
  405. Returns:
  406. The skewed tensor with shape same as the input.
  407. Example:
  408. >>> img = torch.rand(1, 3, 4, 4)
  409. >>> shear_factor = torch.tensor([[0.5, 0.0]])
  410. >>> out = shear(img, shear_factor)
  411. >>> print(out.shape)
  412. torch.Size([1, 3, 4, 4])
  413. """
  414. if not isinstance(tensor, Tensor):
  415. raise TypeError(f"Input tensor type is not a Tensor. Got {type(tensor)}")
  416. if not isinstance(shear, Tensor):
  417. raise TypeError(f"Input shear type is not a Tensor. Got {type(shear)}")
  418. if len(tensor.shape) not in (3, 4):
  419. raise ValueError(f"Invalid tensor shape, we expect CxHxW or BxCxHxW. Got: {tensor.shape}")
  420. # compute the translation matrix
  421. shear_matrix: Tensor = _compute_shear_matrix(shear)
  422. # warp using the affine transform
  423. return affine(tensor, shear_matrix[..., :2, :3], mode, padding_mode, align_corners)
  424. def _side_to_image_size(side_size: int, aspect_ratio: float, side: str = "short") -> Tuple[int, int]:
  425. if side not in ("short", "long", "vert", "horz"):
  426. raise ValueError(f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{side}'")
  427. if side == "vert":
  428. return side_size, int(side_size * aspect_ratio)
  429. if side == "horz":
  430. return int(side_size / aspect_ratio), side_size
  431. if (side == "short") ^ (aspect_ratio < 1.0):
  432. return side_size, int(side_size * aspect_ratio)
  433. return int(side_size / aspect_ratio), side_size
  434. @perform_keep_shape_image
  435. def resize(
  436. input: Tensor,
  437. size: Union[int, Tuple[int, int]],
  438. interpolation: str = "bilinear",
  439. align_corners: Optional[bool] = None,
  440. side: str = "short",
  441. antialias: bool = False,
  442. ) -> Tensor:
  443. r"""Resize the input Tensor to the given size.
  444. .. image:: _static/img/resize.png
  445. Args:
  446. input: The image tensor to be skewed with shape of :math:`(..., H, W)`.
  447. `...` means there can be any number of dimensions.
  448. size: Desired output size. If size is a sequence like (h, w),
  449. output size will be matched to this. If size is an int, smaller edge of the image will
  450. be matched to this number. i.e, if height > width, then image will be rescaled
  451. to (size * height / width, size)
  452. interpolation: algorithm used for upsampling: ``'nearest'`` | ``'linear'`` | ``'bilinear'`` |
  453. 'bicubic' | 'trilinear' | 'area'.
  454. align_corners: interpolation flag.
  455. side: Corresponding side if ``size`` is an integer. Can be one of ``'short'``, ``'long'``, ``'vert'``,
  456. or ``'horz'``.
  457. antialias: if True, then image will be filtered with Gaussian before downscaling.
  458. No effect for upscaling.
  459. Returns:
  460. The resized tensor with the shape as the specified size.
  461. Example:
  462. >>> img = torch.rand(1, 3, 4, 4)
  463. >>> out = resize(img, (6, 8))
  464. >>> print(out.shape)
  465. torch.Size([1, 3, 6, 8])
  466. """
  467. if not isinstance(input, Tensor):
  468. raise TypeError(f"Input tensor type is not a Tensor. Got {type(input)}")
  469. if len(input.shape) < 2:
  470. raise ValueError(f"Input tensor must have at least two dimensions. Got {len(input.shape)}")
  471. input_size = h, w = input.shape[-2:]
  472. if isinstance(size, int):
  473. if torch.onnx.is_in_onnx_export():
  474. warnings.warn(
  475. "Please pass the size with a tuple when exporting to ONNX to correct the tracing.", stacklevel=1
  476. )
  477. aspect_ratio = w / h
  478. size = _side_to_image_size(size, aspect_ratio, side)
  479. # Skip this dangerous if-else when converting to ONNX.
  480. if not torch.onnx.is_in_onnx_export():
  481. if size == input_size:
  482. return input
  483. factors = (h / size[0], w / size[1])
  484. # We do bluring only for downscaling
  485. antialias = antialias and (max(factors) > 1)
  486. if antialias:
  487. # First, we have to determine sigma
  488. # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
  489. sigmas = (max((factors[0] - 1.0) / 2.0, 0.001), max((factors[1] - 1.0) / 2.0, 0.001))
  490. # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
  491. # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
  492. # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
  493. ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
  494. # Make sure it is odd
  495. if (ks[0] % 2) == 0:
  496. ks = ks[0] + 1, ks[1]
  497. if (ks[1] % 2) == 0:
  498. ks = ks[0], ks[1] + 1
  499. input = gaussian_blur2d(input, ks, sigmas)
  500. output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
  501. return output
  502. def resize_to_be_divisible(
  503. input: Tensor,
  504. divisible_factor: int,
  505. interpolation: str = "bilinear",
  506. align_corners: Optional[bool] = None,
  507. side: str = "short",
  508. antialias: bool = False,
  509. ) -> Tensor:
  510. """Resize the input tensor to be divisible by a certain factor.
  511. Args:
  512. input (Tensor): Input tensor to be resized.
  513. divisible_factor (int): The factor to which the image should be divisible.
  514. interpolation (str, optional): Interpolation flag. Defaults to "bilinear".
  515. align_corners (Optional[bool], optional):
  516. whether to align the corners of the input and output. Defaults to None.
  517. side (str, optional): Side to resize. Defaults to "short".
  518. antialias (bool, optional):
  519. If True, then image will be filtered with Gaussian before downscaling. Defaults to False.
  520. Returns:
  521. Tensor: The resized tensor.
  522. """
  523. if isinstance(input, Tensor) and len(input.shape) == 4:
  524. height, width = input.shape[2], input.shape[3]
  525. if isinstance(input, Tensor) and len(input.shape) == 3:
  526. height, width = input.shape[1], input.shape[2]
  527. height = round(height / divisible_factor) * divisible_factor
  528. width = round(width / divisible_factor) * divisible_factor
  529. return resize(input, (height, width), interpolation, align_corners, side, antialias)
  530. def rescale(
  531. input: Tensor,
  532. factor: Union[float, Tuple[float, float]],
  533. interpolation: str = "bilinear",
  534. align_corners: Optional[bool] = None,
  535. antialias: bool = False,
  536. ) -> Tensor:
  537. r"""Rescale the input Tensor with the given factor.
  538. .. image:: _static/img/rescale.png
  539. Args:
  540. input: The image tensor to be scale with shape of :math:`(B, C, H, W)`.
  541. factor: Desired scaling factor in each direction. If scalar, the value is used
  542. for both the x- and y-direction.
  543. interpolation: algorithm used for upsampling: ``'nearest'`` | ``'linear'`` | ``'bilinear'`` |
  544. ``'bicubic'`` | ``'trilinear'`` | ``'area'``.
  545. align_corners: interpolation flag.
  546. side: Corresponding side if ``size`` is an integer. Can be one of ``'short'``, ``'long'``, ``'vert'``,
  547. or ``'horz'``.
  548. antialias: if True, then image will be filtered with Gaussian before downscaling.
  549. No effect for upscaling.
  550. Returns:
  551. The rescaled tensor with the shape as the specified size.
  552. Example:
  553. >>> img = torch.rand(1, 3, 4, 4)
  554. >>> out = rescale(img, (2, 3))
  555. >>> print(out.shape)
  556. torch.Size([1, 3, 8, 12])
  557. """
  558. if isinstance(factor, float):
  559. factor_vert = factor_horz = factor
  560. else:
  561. factor_vert, factor_horz = factor
  562. height, width = input.size()[-2:]
  563. size = (int(height * factor_vert), int(width * factor_horz))
  564. return resize(input, size, interpolation=interpolation, align_corners=align_corners, antialias=antialias)
  565. class Resize(Module):
  566. r"""Resize the input Tensor to the given size.
  567. Args:
  568. size: Desired output size. If size is a sequence like (h, w),
  569. output size will be matched to this. If size is an int, smaller edge of the image will
  570. be matched to this number. i.e, if height > width, then image will be rescaled
  571. to (size * height / width, size)
  572. interpolation: algorithm used for upsampling: ``'nearest'`` | ``'linear'`` | ``'bilinear'`` |
  573. 'bicubic' | 'trilinear' | 'area'.
  574. align_corners: interpolation flag.
  575. side: Corresponding side if ``size`` is an integer. Can be one of ``'short'``, ``'long'``, ``'vert'``,
  576. or ``'horz'``.
  577. antialias: if True, then image will be filtered with Gaussian before downscaling.
  578. No effect for upscaling.
  579. Returns:
  580. The resized tensor with the shape of the given size.
  581. Example:
  582. >>> img = torch.rand(1, 3, 4, 4)
  583. >>> out = Resize((6, 8))(img)
  584. >>> print(out.shape)
  585. torch.Size([1, 3, 6, 8])
  586. .. raw:: html
  587. <gradio-app src="kornia/kornia-resize-antialias"></gradio-app>
  588. """
  589. def __init__(
  590. self,
  591. size: Union[int, Tuple[int, int]],
  592. interpolation: str = "bilinear",
  593. align_corners: Optional[bool] = None,
  594. side: str = "short",
  595. antialias: bool = False,
  596. ) -> None:
  597. super().__init__()
  598. self.size: Union[int, Tuple[int, int]] = size
  599. self.interpolation: str = interpolation
  600. self.align_corners: Optional[bool] = align_corners
  601. self.side: str = side
  602. self.antialias: bool = antialias
  603. def forward(self, input: Tensor) -> Tensor:
  604. return resize(
  605. input,
  606. self.size,
  607. self.interpolation,
  608. align_corners=self.align_corners,
  609. side=self.side,
  610. antialias=self.antialias,
  611. )
  class Affine(Module):
      r"""Apply multiple elementary affine transforms simultaneously.

      Args:
          angle: Angle in degrees for counter-clockwise rotation around the center. The tensor
              must have a shape of (B), where B is the batch size.
          translation: Amount of pixels for translation in x- and y-direction. The tensor must
              have a shape of (B, 2), where B is the batch size and the last dimension contains dx and dy.
          scale_factor: Factor for scaling. Its first dimension must be the batch size B. When omitted,
              a tensor of ones with shape (B, 2) is used.
          shear: Angles in degrees for shearing in x- and y-direction around the center. The
              tensor must have a shape of (B, 2), where B is the batch size and the last dimension contains sx and sy.
          center: Transformation center in pixels. The tensor must have a shape of (B, 2), where
              B is the batch size and the last dimension contains cx and cy. Defaults to the center of image to be
              transformed.
          mode: interpolation mode to calculate output values
              ``'bilinear'`` | ``'nearest'``.
          padding_mode: padding mode for outside grid values
              ``'zeros'`` | ``'border'`` | ``'reflection'``.
          align_corners: interpolation flag.

      Raises:
          RuntimeError: If none of ``angle``, ``translation``, ``scale_factor``, or ``shear`` is set, or if
              the first (batch) dimensions of the given parameters differ.

      Returns:
          The transformed tensor with same shape as input.

      Example:
          >>> img = torch.rand(1, 2, 3, 5)
          >>> angle = 90. * torch.rand(1)
          >>> out = Affine(angle)(img)
          >>> print(out.shape)
          torch.Size([1, 2, 3, 5])
      """

      def __init__(
          self,
          angle: Optional[Tensor] = None,
          translation: Optional[Tensor] = None,
          scale_factor: Optional[Tensor] = None,
          shear: Optional[Tensor] = None,
          center: Optional[Tensor] = None,
          mode: str = "bilinear",
          padding_mode: str = "zeros",
          align_corners: bool = True,
      ) -> None:
          # Initialise the Module machinery before any attribute is assigned.
          super().__init__()

          # The batch size is taken from the first dimension of every supplied parameter.
          batch_sizes = [arg.size()[0] for arg in (angle, translation, scale_factor, shear) if arg is not None]
          if not batch_sizes:
              msg = (
                  "Affine was created without any affine parameter. At least one of angle, translation, scale_factor, or "
                  "shear has to be set."
              )
              raise RuntimeError(msg)

          batch_size = batch_sizes[0]
          if not all(other == batch_size for other in batch_sizes[1:]):
              raise RuntimeError(f"The batch sizes of the affine parameters mismatch: {batch_sizes}")
          self._batch_size = batch_size

          # ``shear`` is included so that the defaults created below follow the device/dtype of the
          # supplied parameters even when ``shear`` is the only one given.
          device, dtype = _extract_device_dtype([angle, translation, scale_factor, shear])

          # Missing parameters become identity elements: no rotation, no translation, unit scale.
          if angle is None:
              angle = zeros(batch_size, device=device, dtype=dtype)
          self.angle = angle

          if translation is None:
              translation = zeros(batch_size, 2, device=device, dtype=dtype)
          self.translation = translation

          if scale_factor is None:
              scale_factor = ones(batch_size, 2, device=device, dtype=dtype)
          self.scale_factor = scale_factor

          # ``shear`` and ``center`` may stay None; they are resolved in ``forward``.
          self.shear = shear
          self.center = center
          self.mode = mode
          self.padding_mode = padding_mode
          self.align_corners = align_corners

      def forward(self, input: Tensor) -> Tensor:
          # Split the shear tensor into its x and y components along the last dimension.
          if self.shear is None:
              sx = sy = None
          else:
              sx, sy = self.shear[..., 0], self.shear[..., 1]

          # Without an explicit center, use the input's center repeated over its batch dimension.
          if self.center is None:
              center = _compute_tensor_center(input).expand(input.size()[0], -1)
          else:
              center = self.center

          # NOTE(review): the angle is negated here, presumably to match get_affine_matrix2d's rotation
          # sign convention to the counter-clockwise angle documented above -- confirm.
          matrix = get_affine_matrix2d(self.translation, center, self.scale_factor, -self.angle, sx=sx, sy=sy)
          # Only the top two rows / three columns of the matrix are handed to ``affine``.
          return affine(input, matrix[..., :2, :3], self.mode, self.padding_mode, self.align_corners)
  class Rescale(Module):
      r"""Rescale the input Tensor with the given factor.

      Args:
          factor: Desired scaling factor in each direction. If scalar, the value is used
              for both the x- and y-direction. A tuple is read as ``(vertical, horizontal)``;
              see the example below, where ``(2, 3)`` turns a 4x4 image into 8x12.
          interpolation: interpolation algorithm used for resampling: ``'nearest'`` | ``'linear'`` |
              ``'bilinear'`` | ``'bicubic'`` | ``'trilinear'`` | ``'area'``.
          align_corners: interpolation flag.
          antialias: if True, then image will be filtered with Gaussian before downscaling.
              No effect for upscaling.

      Returns:
          The rescaled tensor with the shape according to the given factor.

      Example:
          >>> img = torch.rand(1, 3, 4, 4)
          >>> out = Rescale((2, 3))(img)
          >>> print(out.shape)
          torch.Size([1, 3, 8, 12])
      """

      def __init__(
          self,
          factor: Union[float, Tuple[float, float]],
          interpolation: str = "bilinear",
          align_corners: bool = True,
          antialias: bool = False,
      ) -> None:
          super().__init__()
          self.factor: Union[float, Tuple[float, float]] = factor
          self.interpolation: str = interpolation
          self.align_corners: Optional[bool] = align_corners
          self.antialias: bool = antialias

      def forward(self, input: Tensor) -> Tensor:
          # Delegate to the functional ``rescale`` with the stored configuration.
          return rescale(
              input, self.factor, self.interpolation, align_corners=self.align_corners, antialias=self.antialias
          )
  class Rotate(Module):
      r"""Rotate a tensor anti-clockwise around a center point.

      Args:
          angle: rotation angle, one value per sample; the tensor must have shape (B),
              where B is the batch size.
          center: optional rotation center with shape (B, 2), where the last dimension
              holds cx and cy.
          mode: interpolation mode used to compute output values:
              ``'bilinear'`` | ``'nearest'``.
          padding_mode: how values outside the sampling grid are filled:
              ``'zeros'`` | ``'border'`` | ``'reflection'``.
          align_corners: interpolation flag.

      Returns:
          A rotated tensor with the same shape as the input.

      Example:
          >>> img = torch.rand(1, 3, 4, 4)
          >>> angle = torch.tensor([90.])
          >>> out = Rotate(angle)(img)
          >>> print(out.shape)
          torch.Size([1, 3, 4, 4])
      """

      def __init__(
          self,
          angle: Tensor,
          center: Union[None, Tensor] = None,
          mode: str = "bilinear",
          padding_mode: str = "zeros",
          align_corners: bool = True,
      ) -> None:
          super().__init__()
          # Geometry of the rotation.
          self.angle: Tensor = angle
          self.center: Union[None, Tensor] = center
          # Sampling options forwarded unchanged to the functional ``rotate``.
          self.align_corners: bool = align_corners
          self.padding_mode: str = padding_mode
          self.mode: str = mode

      def forward(self, input: Tensor) -> Tensor:
          sampling = (self.mode, self.padding_mode, self.align_corners)
          return rotate(input, self.angle, self.center, *sampling)
  class Translate(Module):
      r"""Shift a tensor by a per-sample offset measured in pixels.

      Args:
          translation: offsets along x and y with shape (B, 2), where B is the batch size
              and the last dimension holds dx and dy.
          mode: interpolation mode used to compute output values:
              ``'bilinear'`` | ``'nearest'``.
          padding_mode: how values outside the sampling grid are filled:
              ``'zeros'`` | ``'border'`` | ``'reflection'``.
          align_corners: interpolation flag.

      Returns:
          A translated tensor with the same shape as the input.

      Example:
          >>> img = torch.rand(1, 3, 4, 4)
          >>> translation = torch.tensor([[1., 0.]])
          >>> out = Translate(translation)(img)
          >>> print(out.shape)
          torch.Size([1, 3, 4, 4])
      """

      def __init__(
          self,
          translation: Tensor,
          mode: str = "bilinear",
          padding_mode: str = "zeros",
          align_corners: bool = True,
      ) -> None:
          super().__init__()
          self.translation: Tensor = translation
          # Sampling options forwarded unchanged to the functional ``translate``.
          self.align_corners: bool = align_corners
          self.padding_mode: str = padding_mode
          self.mode: str = mode

      def forward(self, input: Tensor) -> Tensor:
          mode, padding_mode, align_corners = self.mode, self.padding_mode, self.align_corners
          return translate(input, self.translation, mode, padding_mode, align_corners)
  class Scale(Module):
      r"""Scale a tensor about a center point by a per-sample factor.

      Args:
          scale_factor: scale factor to apply, with shape (B) or (B, 2), where B is the batch size.
              A (B) tensor gives isotropic scaling; a (B, 2) tensor scales the x- and y-direction
              independently.
          center: optional scaling center with shape (B, 2), where the last dimension holds
              cx and cy.
          mode: interpolation mode used to compute output values:
              ``'bilinear'`` | ``'nearest'``.
          padding_mode: how values outside the sampling grid are filled:
              ``'zeros'`` | ``'border'`` | ``'reflection'``.
          align_corners: interpolation flag.

      Returns:
          A scaled tensor with the same shape as the input.

      Example:
          >>> img = torch.rand(1, 3, 4, 4)
          >>> scale_factor = torch.tensor([[2., 2.]])
          >>> out = Scale(scale_factor)(img)
          >>> print(out.shape)
          torch.Size([1, 3, 4, 4])
      """

      def __init__(
          self,
          scale_factor: Tensor,
          center: Union[None, Tensor] = None,
          mode: str = "bilinear",
          padding_mode: str = "zeros",
          align_corners: bool = True,
      ) -> None:
          super().__init__()
          # Geometry of the scaling.
          self.scale_factor: Tensor = scale_factor
          self.center: Union[None, Tensor] = center
          # Sampling options forwarded unchanged to the functional ``scale``.
          self.align_corners: bool = align_corners
          self.padding_mode: str = padding_mode
          self.mode: str = mode

      def forward(self, input: Tensor) -> Tensor:
          sampling = (self.mode, self.padding_mode, self.align_corners)
          return scale(input, self.scale_factor, self.center, *sampling)
  class Shear(Module):
      r"""Shear (skew) a tensor along the x and y directions.

      Args:
          shear: shear angles along x and y with shape (B, 2), where B is the batch size
              and the last dimension holds shx and shy.
          mode: interpolation mode used to compute output values:
              ``'bilinear'`` | ``'nearest'``.
          padding_mode: how values outside the sampling grid are filled:
              ``'zeros'`` | ``'border'`` | ``'reflection'``.
          align_corners: interpolation flag.

      Returns:
          A skewed tensor with the same shape as the input.

      Example:
          >>> img = torch.rand(1, 3, 4, 4)
          >>> shear_factor = torch.tensor([[0.5, 0.0]])
          >>> out = Shear(shear_factor)(img)
          >>> print(out.shape)
          torch.Size([1, 3, 4, 4])
      """

      def __init__(
          self,
          shear: Tensor,
          mode: str = "bilinear",
          padding_mode: str = "zeros",
          align_corners: bool = True,
      ) -> None:
          super().__init__()
          self.shear: Tensor = shear
          # Sampling options forwarded unchanged to the functional ``shear``.
          self.align_corners: bool = align_corners
          self.padding_mode: str = padding_mode
          self.mode: str = mode

      def forward(self, input: Tensor) -> Tensor:
          # ``shear`` below resolves to the module-level function; ``self.shear`` is the parameter tensor.
          mode, padding_mode, align_corners = self.mode, self.padding_mode, self.align_corners
          return shear(input, self.shear, mode, padding_mode, align_corners)