# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar

import torch
from torch import nn
from torch.autograd import Function
from torch.distributions import Bernoulli, RelaxedBernoulli
from typing_extensions import Self

from kornia.augmentation.base import _AugmentationBase
from kornia.core import Module, Tensor

T = TypeVar("T", bound="OperationBase")


class OperationBase(Module):
    """Base class of differentiable augmentation operations.

    Wraps a Kornia augmentation and exposes its application probability (and optionally one
    magnitude factor) as learnable parameters. The output of :meth:`forward` is a blend of the
    augmented and the original input, weighted by the sampled ``batch_prob``.

    Args:
        operation: Kornia augmentation module.
        initial_magnitude: targeted magnitude parameter name and its initial magnitude value.
            The magnitude parameter name shall align with the attribute inside the random_generator
            in each augmentation. If None, the augmentation will be randomly applied according to
            the augmentation sampling range.
        temperature: temperature for RelaxedBernoulli distribution used during training.
            Must be strictly positive.
        is_batch_operation: determine if to obtain the probability from `p` or `p_batch`.
            Set to True for most non-shape-persistent operations (e.g. cropping).
        magnitude_fn: optional callable applied to the magnitude factor tensor in
            :meth:`forward_parameters` before the parameters are returned. The identity is used
            if None.
        gradient_estimator: optional ``torch.autograd.Function`` subclass. When provided, the wrapped
            operation runs under ``torch.no_grad()`` and the estimator's ``apply`` is called with the
            magnitude (or the input, if no magnitude is learned) and the blended output.
        symmetric_megnitude: if True, the sign of the (post-``magnitude_fn``) magnitude is flipped at
            random for each batch element. The keyword keeps its historical spelling for
            backward compatibility.

    Raises:
        ValueError: if ``operation`` is not a Kornia augmentation, if ``temperature`` is not
            strictly positive, or if ``initial_magnitude`` is malformed.
        NotImplementedError: if ``initial_magnitude`` does not contain exactly one entry.
    """

    def __init__(
        self,
        operation: _AugmentationBase,
        initial_magnitude: Optional[List[Tuple[str, Optional[float]]]] = None,
        temperature: float = 0.1,
        is_batch_operation: bool = False,
        magnitude_fn: Optional[Callable[[Tensor], Tensor]] = None,
        gradient_estimator: Optional[Type[Function]] = None,
        symmetric_megnitude: bool = False,
    ) -> None:
        super().__init__()
        if not isinstance(operation, _AugmentationBase):
            raise ValueError(f"Only Kornia augmentations supported. Got {operation}.")

        self.op = operation

        self._init_magnitude(initial_magnitude)

        # Avoid skipping the sampling in `__batch_prob_generator__`
        self.probability_range = (1e-7, 1 - 1e-7)
        self._is_batch_operation = is_batch_operation
        if is_batch_operation:
            self._probability = nn.Parameter(torch.empty(1).fill_(self.op.p_batch))
        else:
            self._probability = nn.Parameter(torch.empty(1).fill_(self.op.p))

        # The temperature is used as a divisor by the relaxed sampler, so zero must be
        # rejected as well as negative values (as the error message already states).
        if temperature <= 0:
            raise ValueError(f"Expect temperature value greater than 0. Got {temperature}.")
        self.register_buffer("temperature", torch.empty(1).fill_(temperature))

        self.symmetric_megnitude = symmetric_megnitude
        self._magnitude_fn = self._init_magnitude_fn(magnitude_fn)
        self._gradient_estimator = gradient_estimator

    def _init_magnitude_fn(self, magnitude_fn: Optional[Callable[[Tensor], Tensor]]) -> Callable[[Tensor], Tensor]:
        """Build the callable that post-processes the magnitude factor.

        Args:
            magnitude_fn: user-provided transformation, or None for the identity.

        Returns:
            ``magnitude_fn`` (or the identity), additionally wrapped with a random per-sample
            sign flip when ``self.symmetric_megnitude`` is True.
        """

        def _identity(x: Tensor) -> Tensor:
            return x

        def _random_flip(fn: Callable[[Tensor], Tensor]) -> Callable[[Tensor], Tensor]:
            def f(x: Tensor) -> Tensor:
                # One draw per element along the first dimension of `x`.
                flip = torch.rand((x.shape[0],), device=x.device) > 0.5
                out = fn(x)
                # Negate the selected elements. Multiplying by the boolean mask would zero
                # them out instead of mirroring the magnitude around zero.
                return torch.where(flip, -out, out)

            return f

        if magnitude_fn is None:
            magnitude_fn = _identity

        if self.symmetric_megnitude:
            return _random_flip(magnitude_fn)

        return magnitude_fn

    def _init_magnitude(self, initial_magnitude: Optional[List[Tuple[str, Optional[float]]]]) -> None:
        """Validate ``initial_magnitude`` and set the magnitude-related attributes.

        Sets ``self._factor_name``, ``self.magnitude_range`` and ``self._magnitude``. All three are
        None when ``initial_magnitude`` is None. ``self._magnitude`` is also None when the
        provided initial value is None, in which case no magnitude parameter is learned.

        Args:
            initial_magnitude: None, or a one-element sequence of ``(factor_name, initial_value)``.

        Raises:
            ValueError: if the entries are not 2-element sequences, or if the wrapped operation
                has no parameter generator.
            NotImplementedError: if the number of entries is not exactly one.
        """
        if isinstance(initial_magnitude, (list, tuple)):
            if not all(isinstance(ini_mag, (list, tuple)) and len(ini_mag) == 2 for ini_mag in initial_magnitude):
                raise ValueError(f"`initial_magnitude` shall be a list of 2-element tuples. Got {initial_magnitude}")
            if len(initial_magnitude) != 1:
                raise NotImplementedError("Multi magnitudes operations are not yet supported.")

        if initial_magnitude is None:
            self._factor_name = None
            self._magnitude = None
            self.magnitude_range = None
        else:
            self._factor_name = initial_magnitude[0][0]
            if self.op._param_generator is not None:
                # The range is read from the generator attribute that shares the factor's name.
                self.magnitude_range = getattr(self.op._param_generator, self._factor_name)
            else:
                raise ValueError(f"No valid magnitude `{self._factor_name}` found in `{self.op._param_generator}`.")

            self._magnitude = None
            if initial_magnitude[0][1] is not None:
                self._magnitude = nn.Parameter(torch.empty(1).fill_(initial_magnitude[0][1]))

    def _update_probability_gen(self, relaxation: bool) -> None:
        """Install a fresh probability sampler on the wrapped operation.

        Args:
            relaxation: if True, use ``RelaxedBernoulli`` with ``self.temperature``;
                otherwise use a plain ``Bernoulli``. The sampler is written to ``_p_batch_gen``
                for batch operations and to ``_p_gen`` otherwise.
        """
        if relaxation:
            if self._is_batch_operation:
                self.op._p_batch_gen = RelaxedBernoulli(self.temperature, self.probability)
            else:
                self.op._p_gen = RelaxedBernoulli(self.temperature, self.probability)
        elif self._is_batch_operation:
            self.op._p_batch_gen = Bernoulli(self.probability)
        else:
            self.op._p_gen = Bernoulli(self.probability)

    def train(self, mode: bool = True) -> Self:
        """Switch the training mode and pick the matching probability sampler.

        Args:
            mode: True for training (relaxed sampler), False for evaluation (Bernoulli sampler).

        Returns:
            The module itself.
        """
        self._update_probability_gen(relaxation=mode)

        return super().train(mode=mode)

    def eval(self) -> Self:
        """Set the module to evaluation mode; equivalent to ``train(False)``."""
        return self.train(False)

    def forward_parameters(self, batch_shape: torch.Size, mag: Optional[Tensor] = None) -> Dict[str, Tensor]:
        """Sample the parameters of the wrapped operation and inject the magnitude.

        Args:
            batch_shape: shape forwarded to the wrapped operation's ``forward_parameters``.
            mag: magnitude overriding the learned one. Defaults to ``self.magnitude``.

        Returns:
            The parameter dictionary produced by the wrapped operation, with the magnitude factor
            (if any) replaced by ``mag`` and passed through the magnitude function.

        Raises:
            RuntimeError: if a magnitude is available but no factor name was configured.
        """
        if mag is None:
            mag = self.magnitude
        # Need to setup the sampler again for each update.
        # Otherwise, an error for updating the same graph twice will be thrown.
        # NOTE(review): the relaxed sampler is installed regardless of `self.training`, which
        # replaces the Bernoulli sampler set by `eval()` - confirm this is intended.
        self._update_probability_gen(relaxation=True)
        params = self.op.forward_parameters(batch_shape)
        if mag is not None:
            if self._factor_name is None:
                raise RuntimeError("No factor found in the params while `mag` is provided.")
            # For single factor operations, this is equivalent to `same_on_batch=True`
            params[self._factor_name] = params[self._factor_name].zero_() + mag

        if self._factor_name is not None:
            params[self._factor_name] = self._magnitude_fn(params[self._factor_name])

        return params

    def forward(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None) -> Tensor:
        """Apply the operation and blend the result with the input using ``batch_prob``.

        Args:
            input: tensor to augment.
            params: pre-sampled parameters; sampled with :meth:`forward_parameters` if None.
                Must contain the key ``"batch_prob"``.

        Returns:
            ``batch_prob * op(input) + (1 - batch_prob) * input``, routed through the gradient
            estimator when one was provided.
        """
        if params is None:
            params = self.forward_parameters(input.shape)

        # Append trailing singleton dimensions so `batch_prob` broadcasts against `input`
        # (assumes `batch_prob` is 1-D over the batch - TODO confirm).
        batch_prob = params["batch_prob"][(...,) + ((None,) * (len(input.shape) - 1))].to(device=input.device)

        if self._gradient_estimator is not None:
            # skip the gradient computation if gradient estimator is provided.
            with torch.no_grad():
                output = self.op(input, params=params)
                output = batch_prob * output + (1 - batch_prob) * input
            if self.magnitude is None:
                # If magnitude is None, make the grad w.r.t the input
                return self._gradient_estimator.apply(input, output)
            # If magnitude is not None, make the grad w.r.t the magnitude
            return self._gradient_estimator.apply(self.magnitude, output)

        return batch_prob * self.op(input, params=params) + (1 - batch_prob) * input

    @property
    def transform_matrix(self) -> Optional[Tensor]:
        """Transformation matrix of the wrapped operation, or None if it does not expose one."""
        if hasattr(self.op, "transform_matrix"):
            return self.op.transform_matrix
        return None

    @property
    def magnitude(self) -> Optional[Tensor]:
        """Learned magnitude clamped to ``magnitude_range``, or None if no magnitude is learned."""
        if self._magnitude is None:
            return None
        mag = self._magnitude
        if self.magnitude_range is not None:
            # presumably `magnitude_range` unpacks to (min, max) - verify against the generator.
            return mag.clamp(*self.magnitude_range)
        return mag

    @property
    def probability(self) -> Tensor:
        """Learned application probability clamped to ``probability_range``."""
        p = self._probability.clamp(*self.probability_range)
        return p