base.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar
  18. import torch
  19. from torch import nn
  20. from torch.autograd import Function
  21. from torch.distributions import Bernoulli, RelaxedBernoulli
  22. from typing_extensions import Self
  23. from kornia.augmentation.base import _AugmentationBase
  24. from kornia.core import Module, Tensor
# Generic type variable bound to ``OperationBase`` and its subclasses. The bound is a
# string forward reference because the class is only defined below this line.
# NOTE(review): not referenced by the code shown in this module; presumably kept for
# importers — verify before removing.
T = TypeVar("T", bound="OperationBase")
  26. class OperationBase(Module):
  27. """Base class of differentiable augmentation operations.
  28. Args:
  29. operation: Kornia augmentation module.
  30. initial_magnitude: targeted magnitude parameter name and its initial magnitude value.
  31. The magnitude parameter name shall align with the attribute inside the random_generator
  32. in each augmentation. If None, the augmentation will be randomly applied according to
  33. the augmentation sampling range.
  34. temperature: temperature for RelaxedBernoulli distribution used during training.
  35. is_batch_operation: determine if to obtain the probability from `p` or `p_batch`.
  36. Set to True for most non-shape-persistent operations (e.g. cropping).
  37. """
  38. def __init__(
  39. self,
  40. operation: _AugmentationBase,
  41. initial_magnitude: Optional[List[Tuple[str, Optional[float]]]] = None,
  42. temperature: float = 0.1,
  43. is_batch_operation: bool = False,
  44. magnitude_fn: Optional[Callable[[Tensor], Tensor]] = None,
  45. gradient_estimator: Optional[Type[Function]] = None,
  46. symmetric_megnitude: bool = False,
  47. ) -> None:
  48. super().__init__()
  49. if not isinstance(operation, _AugmentationBase):
  50. raise ValueError(f"Only Kornia augmentations supported. Got {operation}.")
  51. self.op = operation
  52. self._init_magnitude(initial_magnitude)
  53. # Avoid skipping the sampling in `__batch_prob_generator__`
  54. self.probability_range = (1e-7, 1 - 1e-7)
  55. self._is_batch_operation = is_batch_operation
  56. if is_batch_operation:
  57. self._probability = nn.Parameter(torch.empty(1).fill_(self.op.p_batch))
  58. else:
  59. self._probability = nn.Parameter(torch.empty(1).fill_(self.op.p))
  60. if temperature < 0:
  61. raise ValueError(f"Expect temperature value greater than 0. Got {temperature}.")
  62. self.register_buffer("temperature", torch.empty(1).fill_(temperature))
  63. self.symmetric_megnitude = symmetric_megnitude
  64. self._magnitude_fn = self._init_magnitude_fn(magnitude_fn)
  65. self._gradient_estimator = gradient_estimator
  66. def _init_magnitude_fn(self, magnitude_fn: Optional[Callable[[Tensor], Tensor]]) -> Callable[[Tensor], Tensor]:
  67. def _identity(x: Tensor) -> Tensor:
  68. return x
  69. def _random_flip(fn: Callable[[Tensor], Tensor]) -> Callable[[Tensor], Tensor]:
  70. def f(x: Tensor) -> Tensor:
  71. flip = torch.rand((x.shape[0],), device=x.device) > 0.5
  72. return fn(x) * flip
  73. return f
  74. if magnitude_fn is None:
  75. magnitude_fn = _identity
  76. if self.symmetric_megnitude:
  77. return _random_flip(magnitude_fn)
  78. return magnitude_fn
  79. def _init_magnitude(self, initial_magnitude: Optional[List[Tuple[str, Optional[float]]]]) -> None:
  80. if isinstance(initial_magnitude, (list, tuple)):
  81. if not all(isinstance(ini_mag, (list, tuple)) and len(ini_mag) == 2 for ini_mag in initial_magnitude):
  82. raise ValueError(f"`initial_magnitude` shall be a list of 2-element tuples. Got {initial_magnitude}")
  83. if len(initial_magnitude) != 1:
  84. raise NotImplementedError("Multi magnitudes operations are not yet supported.")
  85. if initial_magnitude is None:
  86. self._factor_name = None
  87. self._magnitude = None
  88. self.magnitude_range = None
  89. else:
  90. self._factor_name = initial_magnitude[0][0]
  91. if self.op._param_generator is not None:
  92. self.magnitude_range = getattr(self.op._param_generator, self._factor_name)
  93. else:
  94. raise ValueError(f"No valid magnitude `{self._factor_name}` found in `{self.op._param_generator}`.")
  95. self._magnitude = None
  96. if initial_magnitude[0][1] is not None:
  97. self._magnitude = nn.Parameter(torch.empty(1).fill_(initial_magnitude[0][1]))
  98. def _update_probability_gen(self, relaxation: bool) -> None:
  99. if relaxation:
  100. if self._is_batch_operation:
  101. self.op._p_batch_gen = RelaxedBernoulli(self.temperature, self.probability)
  102. else:
  103. self.op._p_gen = RelaxedBernoulli(self.temperature, self.probability)
  104. elif self._is_batch_operation:
  105. self.op._p_batch_gen = Bernoulli(self.probability)
  106. else:
  107. self.op._p_gen = Bernoulli(self.probability)
  108. def train(self, mode: bool = True) -> Self:
  109. self._update_probability_gen(relaxation=mode)
  110. return super().train(mode=mode)
  111. def eval(self) -> Self:
  112. return self.train(False)
  113. def forward_parameters(self, batch_shape: torch.Size, mag: Optional[Tensor] = None) -> Dict[str, Tensor]:
  114. if mag is None:
  115. mag = self.magnitude
  116. # Need to setup the sampler again for each update.
  117. # Otherwise, an error for updating the same graph twice will be thrown.
  118. self._update_probability_gen(relaxation=True)
  119. params = self.op.forward_parameters(batch_shape)
  120. if mag is not None:
  121. if self._factor_name is None:
  122. raise RuntimeError("No factor found in the params while `mag` is provided.")
  123. # For single factor operations, this is equivalent to `same_on_batch=True`
  124. params[self._factor_name] = params[self._factor_name].zero_() + mag
  125. if self._factor_name is not None:
  126. params[self._factor_name] = self._magnitude_fn(params[self._factor_name])
  127. return params
  128. def forward(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None) -> Tensor:
  129. if params is None:
  130. params = self.forward_parameters(input.shape)
  131. batch_prob = params["batch_prob"][(...,) + ((None,) * (len(input.shape) - 1))].to(device=input.device)
  132. if self._gradient_estimator is not None:
  133. # skip the gradient computation if gradient estimator is provided.
  134. with torch.no_grad():
  135. output = self.op(input, params=params)
  136. output = batch_prob * output + (1 - batch_prob) * input
  137. if self.magnitude is None:
  138. # If magnitude is None, make the grad w.r.t the input
  139. return self._gradient_estimator.apply(input, output)
  140. # If magnitude is not None, make the grad w.r.t the magnitude
  141. return self._gradient_estimator.apply(self.magnitude, output)
  142. return batch_prob * self.op(input, params=params) + (1 - batch_prob) * input
  143. @property
  144. def transform_matrix(self) -> Optional[Tensor]:
  145. if hasattr(self.op, "transform_matrix"):
  146. return self.op.transform_matrix
  147. return None
  148. @property
  149. def magnitude(self) -> Optional[Tensor]:
  150. if self._magnitude is None:
  151. return None
  152. mag = self._magnitude
  153. if self.magnitude_range is not None:
  154. return mag.clamp(*self.magnitude_range)
  155. return mag
  156. @property
  157. def probability(self) -> Tensor:
  158. p = self._probability.clamp(*self.probability_range)
  159. return p