transforms.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. from typing import Dict
  2. import numpy as np
  3. import torch
  4. import kornia.augmentation as K
  5. from kornia.geometry.transform import warp_perspective
  6. # Adapted from Kornia
  7. class GeometricSequential:
  8. def __init__(self, *transforms, align_corners=True) -> None:
  9. self.transforms = transforms
  10. self.align_corners = align_corners
  11. def __call__(self, x, mode="bilinear"):
  12. b, c, h, w = x.shape
  13. M = torch.eye(3, device=x.device)[None].expand(b, 3, 3)
  14. for t in self.transforms:
  15. if np.random.rand() < t.p:
  16. M = M.matmul(
  17. t.compute_transformation(x, t.generate_parameters((b, c, h, w)), None)
  18. )
  19. return (
  20. warp_perspective(
  21. x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners
  22. ),
  23. M,
  24. )
  25. def apply_transform(self, x, M, mode="bilinear"):
  26. b, c, h, w = x.shape
  27. return warp_perspective(
  28. x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode
  29. )
  30. class RandomPerspective(K.RandomPerspective):
  31. def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
  32. distortion_scale = torch.as_tensor(
  33. self.distortion_scale, device=self._device, dtype=self._dtype
  34. )
  35. return self.random_perspective_generator(
  36. batch_shape[0],
  37. batch_shape[-2],
  38. batch_shape[-1],
  39. distortion_scale,
  40. self.same_on_batch,
  41. self.device,
  42. self.dtype,
  43. )
  44. def random_perspective_generator(
  45. self,
  46. batch_size: int,
  47. height: int,
  48. width: int,
  49. distortion_scale: torch.Tensor,
  50. same_on_batch: bool = False,
  51. device: torch.device = torch.device("cpu"),
  52. dtype: torch.dtype = torch.float32,
  53. ) -> Dict[str, torch.Tensor]:
  54. r"""Get parameters for ``perspective`` for a random perspective transform.
  55. Args:
  56. batch_size (int): the tensor batch size.
  57. height (int) : height of the image.
  58. width (int): width of the image.
  59. distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1.
  60. same_on_batch (bool): apply the same transformation across the batch. Default: False.
  61. device (torch.device): the device on which the random numbers will be generated. Default: cpu.
  62. dtype (torch.dtype): the data type of the generated random numbers. Default: float32.
  63. Returns:
  64. params Dict[str, torch.Tensor]: parameters to be passed for transformation.
  65. - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2).
  66. - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2).
  67. Note:
  68. The generated random numbers are not reproducible across different devices and dtypes.
  69. """
  70. if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
  71. raise AssertionError(
  72. f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
  73. )
  74. if not (
  75. type(height) is int and height > 0 and type(width) is int and width > 0
  76. ):
  77. raise AssertionError(
  78. f"'height' and 'width' must be integers. Got {height}, {width}."
  79. )
  80. start_points: torch.Tensor = torch.tensor(
  81. [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
  82. device=distortion_scale.device,
  83. dtype=distortion_scale.dtype,
  84. ).expand(batch_size, -1, -1)
  85. # generate random offset not larger than half of the image
  86. fx = distortion_scale * width / 2
  87. fy = distortion_scale * height / 2
  88. factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)
  89. offset = (torch.rand_like(start_points) - 0.5) * 2
  90. end_points = start_points + factor * offset
  91. return dict(start_points=start_points, end_points=end_points)
  92. class RandomErasing:
  93. def __init__(self, p = 0., scale = 0.) -> None:
  94. self.p = p
  95. self.scale = scale
  96. self.random_eraser = K.RandomErasing(scale = (0.02, scale), p = p)
  97. def __call__(self, image, depth):
  98. if self.p > 0:
  99. image = self.random_eraser(image)
  100. depth = self.random_eraser(depth, params=self.random_eraser._params)
  101. return image, depth