ada.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import warnings
  18. from typing import Any, Dict, List, Optional, Tuple, Union, cast
  19. import torch
  20. from ...core import Device, Tensor # noqa: TID252
  21. from .. import ( # noqa: TID252
  22. AugmentationSequential,
  23. ColorJitter,
  24. ImageSequential,
  25. RandomAffine,
  26. RandomErasing,
  27. RandomGaussianNoise,
  28. RandomHorizontalFlip,
  29. RandomRotation90,
  30. )
  31. from ..base import _AugmentationBase # noqa: TID252
  32. from ..container.params import ParamItem # noqa: TID252
# Type of the `data_keys` arguments used throughout this module: a list of key names.
_data_keys_type = List[str]
# Type of the accepted inputs: either a single tensor or a dict mapping key names to tensors.
_inputs_type = Union[Tensor, Dict[str, Tensor]]
  35. class AdaptiveDiscriminatorAugmentation(AugmentationSequential):
  36. r"""Implementation of Adaptive Discriminator Augmentation for GANs training as introduced in :cite:`Karras2020ada`.
  37. adjust a global probability p over all augmentations list to select a subset of images to augment
  38. based on an exponential moving average of the Discriminator's accuracy labeling real samples.
  39. Args:
  40. *args: a list of kornia augmentation modules, set to a default list if not specified.
  41. initial_p: initial global probability `p` for applying the augmentations
  42. adjustment_speed: float
  43. step size for updating the global probability `p`
  44. max_p: maximum allowed value for `p`
  45. target_real_acc: target Discriminator accuracy to guide `p` adjustments
  46. ema_lambda: EMA smoothing factor. The real accuracy EMA is what's used to determine the `p` update
  47. update_every: `p` update frequency (in steps)
  48. erasing_scale: scale range used for `RandomErasing` if default augmentations are used
  49. erasing_ratio: aspect ratio range used for `RandomErasing` if default augmentations are used
  50. erasing_fill_value: fill value used in `RandomErasing`
  51. same_on_batch: apply the same transformation across the batch
  52. data_keys: input types to apply augmentations on
  53. **kwargs: Additional keyword arguments passed to `AugmentationSequential`
  54. Examples:
  55. >>> from kornia.augmentation.presets.ada import AdaptiveDiscriminatorAugmentation
  56. >>> original = torch.randn(2, 3, 16, 16)
  57. >>> ada = AdaptiveDiscriminatorAugmentation()
  58. >>> augmented = ada(original)
  59. This example demonstrates using default augmentations with AdaptiveDiscriminatorAugmentation in a GAN training loop.
  60. >>> import kornia.augmentation as K
  61. >>> from kornia.augmentation.presets.ada import AdaptiveDiscriminatorAugmentation
  62. >>> originals = torch.randn(2, 3, 5, 6)
  63. >>> aug_list = [
  64. ... K.RandomRotation90(times=(0, 3), p=1),
  65. ... K.RandomAffine(degrees=10, translate=(.1, .1), scale=(.9, 1.1), p=1),
  66. ... K.ColorJitter(brightness=.2, contrast=.2, saturation=.2, hue=.1, p=1),
  67. ... ]
  68. >>> ada = AdaptiveDiscriminatorAugmentation(*aug_list)
  69. >>> augmented = ada(original)
  70. This example demonstrates using custom augmentations with AdaptiveDiscriminatorAugmentation.
  71. """
  72. def __init__(
  73. self,
  74. *args: Union[_AugmentationBase, ImageSequential],
  75. initial_p: float = 1e-5,
  76. adjustment_speed: float = 1e-2,
  77. max_p: float = 0.8,
  78. target_real_acc: float = 0.85,
  79. ema_lambda: float = 0.99,
  80. update_every: int = 5,
  81. erasing_scale: Union[Tensor, Tuple[float, float]] = (0.02, 0.33),
  82. erasing_ratio: Union[Tensor, Tuple[float, float]] = (0.3, 3.3),
  83. erasing_fill_value: float = 0.0,
  84. data_keys: Optional[_data_keys_type] = None,
  85. same_on_batch: Optional[bool] = False,
  86. **kwargs: Any,
  87. ) -> None:
  88. if not args:
  89. args = self.default_ada_transfroms(erasing_scale, erasing_ratio, erasing_fill_value)
  90. super().__init__(
  91. *args,
  92. data_keys=data_keys
  93. if data_keys is not None
  94. else [
  95. "input",
  96. ],
  97. same_on_batch=same_on_batch,
  98. **kwargs,
  99. )
  100. if adjustment_speed <= 0:
  101. raise ValueError(f"Invalid `adjustment_speed` ({adjustment_speed}) — must be greater than 0")
  102. if not 0 <= target_real_acc <= 1:
  103. raise ValueError(f"Invalid `target_real_acc` ({target_real_acc}) — must be in [0, 1]")
  104. if not 0 <= ema_lambda <= 1:
  105. raise ValueError(f"Invalid `ema_lambda` ({ema_lambda}) — must be in [0, 1]")
  106. if update_every < 1:
  107. raise ValueError(f"Invalid `update_every` ({update_every}) — must be at least 1")
  108. if not 0 <= max_p <= 1:
  109. raise ValueError(f"Invalid `max_p` ({max_p}) — must be in [0, 1]")
  110. if not 0 <= initial_p <= 1:
  111. raise ValueError(f"Invalid `initial_p` ({initial_p}) — must be in [0, 1]")
  112. if initial_p > max_p:
  113. warnings.warn(
  114. f"`initial_p` ({initial_p}) is greater than `max_p` ({max_p}), resetting `initial_p` to `max_p`",
  115. stacklevel=2,
  116. )
  117. initial_p = max_p
  118. self.p = initial_p
  119. self.adjustment_speed = adjustment_speed
  120. self.max_p = max_p
  121. self.target_real_acc = target_real_acc
  122. self.ema_lambda = ema_lambda
  123. self.update_every = update_every
  124. self.real_acc_ema: float = 0.5
  125. self._num_calls = 0 # -update_every # to avoid updating in the first `update_every` steps
  126. def default_ada_transfroms(
  127. self, scale: Union[Tensor, Tuple[float, float]], ratio: Union[Tensor, Tuple[float, float]], value: float
  128. ) -> Tuple[Union[_AugmentationBase, ImageSequential], ...]:
  129. # if changed in the future, please change the expected transforms list in test_presets.py
  130. return (
  131. RandomHorizontalFlip(p=1),
  132. RandomRotation90(times=(0, 3), p=1.0),
  133. RandomErasing(scale=scale, ratio=ratio, value=value, p=0.9),
  134. RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1), p=1.0),
  135. ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=1.0),
  136. RandomGaussianNoise(std=0.1, p=1.0),
  137. )
  138. def update(self, real_acc: float) -> None:
  139. r"""Updates internal params `p` once every `update_every` calls based on discriminator accuracy.
  140. the update is based on an exponential moving average of `real_acc`
  141. `p` is updated by adding or subtracting `adjustment_speed` from it and clamp it at [0, `max_p`]
  142. Args:
  143. real_acc: the Discriminator's accuracy labeling real samples.
  144. """
  145. self._num_calls += 1
  146. if self._num_calls < self.update_every:
  147. return
  148. self._num_calls = 0
  149. self.real_acc_ema = self.ema_lambda * self.real_acc_ema + (1 - self.ema_lambda) * real_acc
  150. if self.real_acc_ema < self.target_real_acc:
  151. self.p = max(0, self.p - self.adjustment_speed)
  152. else:
  153. self.p = min(self.p + self.adjustment_speed, self.max_p)
  154. def _get_inputs_metadata(self, inputs: _inputs_type, data_keys: _data_keys_type) -> Tuple[int, Device]:
  155. if isinstance(inputs, dict):
  156. key = data_keys[0]
  157. batch_size = inputs[key].size(0)
  158. device = inputs[key].device
  159. else:
  160. batch_size = inputs.size(0)
  161. device = inputs.device
  162. return batch_size, device
  163. def _sample_inputs(self, inputs: _inputs_type, data_keys: _data_keys_type, p_tensor: Tensor) -> _inputs_type:
  164. if isinstance(inputs, dict):
  165. return {key: inputs[key][p_tensor] for key in data_keys}
  166. else:
  167. return inputs[p_tensor]
  168. def _merge_inputs(
  169. self,
  170. original: _inputs_type,
  171. augmented: _inputs_type,
  172. p_tensor: Tensor,
  173. ) -> _inputs_type:
  174. merged: _inputs_type
  175. if isinstance(original, dict) and isinstance(augmented, dict):
  176. merged = {}
  177. for key in original.keys():
  178. merged_tensor = original[key].clone()
  179. merged_tensor[p_tensor] = augmented[key]
  180. merged[key] = merged_tensor
  181. elif isinstance(original, Tensor) and isinstance(augmented, Tensor):
  182. merged = original.clone()
  183. merged[p_tensor] = augmented
  184. else:
  185. raise TypeError(
  186. f"original inputs and augmented inputs aren't of the same type "
  187. f"(type({type(original)}), type({type(augmented)}))"
  188. )
  189. return merged
  190. def forward( # type: ignore[override]
  191. self,
  192. inputs: _inputs_type,
  193. params: Optional[List[ParamItem]] = None,
  194. data_keys: Optional[_data_keys_type] = None,
  195. real_acc: Optional[float] = None,
  196. ) -> _inputs_type:
  197. r"""Apply augmentations to a subset of input tensors with global probability `p`.
  198. This method applies the augmentation pipeline to a subset of input samples, randomly selected
  199. via a Bernoulli distribution with probability `p`
  200. if `real_acc` is provided, the internal probability `p` is updated via the `update` method.
  201. Non-augmented samples retain their original values, and the output matches the input structure.
  202. `real_acc` is the Discriminator's accuracy on real images; for example,
  203. `(real_logits > 0).float().mean().item()` if using logits andn assuming real labels are positive.
  204. """
  205. if real_acc is not None:
  206. self.update(real_acc)
  207. if self.p == 0:
  208. return inputs
  209. if data_keys is None:
  210. data_keys = (
  211. [k.name for k in self.data_keys]
  212. if self.data_keys is not None
  213. else [
  214. "input",
  215. ]
  216. )
  217. batch_size, device = self._get_inputs_metadata(inputs, data_keys=data_keys)
  218. p_tensor = torch.bernoulli(torch.full((batch_size,), self.p, dtype=torch.float32, device=device)).bool()
  219. if not p_tensor.any():
  220. return inputs
  221. selected_inputs: _inputs_type = self._sample_inputs(inputs, data_keys=data_keys, p_tensor=p_tensor)
  222. augmented_inputs = cast(
  223. _inputs_type,
  224. super().forward(
  225. selected_inputs, # type: ignore[arg-type]
  226. params=params,
  227. data_keys=data_keys,
  228. ),
  229. )
  230. return self._merge_inputs(inputs, augmented_inputs, p_tensor)