autoaugment.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Iterator, List, Optional, Tuple, Union
  18. from torch.distributions import Categorical
  19. from kornia.augmentation.auto.base import SUBPOLICY_CONFIG, PolicyAugmentBase
  20. from kornia.augmentation.auto.operations.policy import PolicySequential
  21. from kornia.augmentation.container.params import ParamItem
  22. from kornia.core import Module, tensor
  23. from . import ops
# AutoAugment ImageNet preset, selected by ``AutoAugment(policy="imagenet")``.
#
# Each inner list is one sub-policy: an ordered sequence of ``(op_name, probability, magnitude)``
# triples. ``AutoAugment.compose_subpolicy_sequential`` instantiates every triple as
# ``getattr(ops, op_name)(probability, magnitude)`` and chains the results in a
# ``PolicySequential``. ``probability`` is presumably the chance that each op fires, and
# ``magnitude`` is an integer level (0-9 in this table) or ``None`` for ops that are given no
# magnitude (equalize / invert / auto_contrast here) — TODO confirm semantics against ``ops``.
#
# NOTE(review): ``AutoAugment`` samples sub-policies with uniform weights, so the repeated rows
# below (e.g. equalize/solarize, solarize/auto_contrast, invert/equalize, color/contrast,
# equalize/equalize) are drawn twice as often as unique rows. They appear to mirror the published
# table; verify before deduplicating. The ``("equalize", 0.0, None)`` entry presumably never
# fires if ``probability`` is an apply-probability.
imagenet_policy: List[SUBPOLICY_CONFIG] = [
    [("posterize", 0.4, 8), ("rotate", 0.6, 9)],
    [("solarize", 0.6, 5), ("auto_contrast", 0.6, None)],
    [("equalize", 0.8, None), ("equalize", 0.6, None)],
    [("posterize", 0.6, 7), ("posterize", 0.6, 6)],
    [("equalize", 0.4, None), ("solarize", 0.2, 4)],
    [("equalize", 0.4, None), ("rotate", 0.8, 8)],
    [("solarize", 0.6, 3), ("equalize", 0.6, None)],
    [("posterize", 0.8, 5), ("equalize", 1.0, None)],
    [("rotate", 0.2, 3), ("solarize", 0.6, 8)],
    [("equalize", 0.6, None), ("posterize", 0.4, 6)],
    [("rotate", 0.8, 8), ("color", 0.4, 0)],
    [("rotate", 0.4, 9), ("equalize", 0.6, None)],
    [("equalize", 0.0, None), ("equalize", 0.8, None)],
    [("invert", 0.6, None), ("equalize", 1.0, None)],
    [("color", 0.6, 4), ("contrast", 1.0, 8)],
    [("rotate", 0.8, 8), ("color", 1.0, 2)],
    [("color", 0.8, 8), ("solarize", 0.8, 7)],
    [("sharpness", 0.4, 7), ("invert", 0.6, None)],
    [("shear_x", 0.6, 5), ("equalize", 1.0, None)],
    [("color", 0.4, 0), ("equalize", 0.6, None)],
    [("equalize", 0.4, None), ("solarize", 0.2, 4)],  # repeats an earlier row
    [("solarize", 0.6, 5), ("auto_contrast", 0.6, None)],  # repeats an earlier row
    [("invert", 0.6, None), ("equalize", 1.0, None)],  # repeats an earlier row
    [("color", 0.6, 4), ("contrast", 1.0, 8)],  # repeats an earlier row
    [("equalize", 0.8, None), ("equalize", 0.6, None)],  # repeats an earlier row
]
# AutoAugment CIFAR-10 preset, selected by ``AutoAugment(policy="cifar10")``.
#
# Each inner list is one sub-policy: an ordered sequence of ``(op_name, probability, magnitude)``
# triples. ``AutoAugment.compose_subpolicy_sequential`` instantiates every triple as
# ``getattr(ops, op_name)(probability, magnitude)`` and chains the results in a
# ``PolicySequential``. ``probability`` is presumably the chance that each op fires, and
# ``magnitude`` is an integer level (0-9 in this table) or ``None`` for ops that are given no
# magnitude — TODO confirm semantics against ``ops``.
#
# Sub-policies are sampled with uniform weights by ``AutoAugment``.
# NOTE(review): ``("invert", 0.0, None)`` presumably never fires if ``probability`` is an
# apply-probability; it appears to mirror the published table.
cifar10_policy: List[SUBPOLICY_CONFIG] = [
    [("invert", 0.1, None), ("contrast", 0.2, 6)],
    [("rotate", 0.7, 2), ("translate_x", 0.3, 9)],
    [("sharpness", 0.8, 1), ("sharpness", 0.9, 3)],
    [("shear_y", 0.5, 8), ("translate_y", 0.7, 9)],
    [("auto_contrast", 0.5, None), ("equalize", 0.9, None)],
    [("shear_y", 0.2, 7), ("posterize", 0.3, 7)],
    [("color", 0.4, 3), ("brightness", 0.6, 7)],
    [("sharpness", 0.3, 9), ("brightness", 0.7, 9)],
    [("equalize", 0.6, None), ("equalize", 0.5, None)],
    [("contrast", 0.6, 7), ("sharpness", 0.6, 5)],
    [("color", 0.7, 7), ("translate_x", 0.5, 8)],
    [("equalize", 0.3, None), ("auto_contrast", 0.4, None)],
    [("translate_y", 0.4, 3), ("sharpness", 0.2, 6)],
    [("brightness", 0.9, 6), ("color", 0.2, 8)],
    [("solarize", 0.5, 2), ("invert", 0.0, None)],
    [("equalize", 0.2, None), ("auto_contrast", 0.6, None)],
    [("equalize", 0.2, None), ("equalize", 0.6, None)],
    [("color", 0.9, 9), ("equalize", 0.6, None)],
    [("auto_contrast", 0.8, None), ("solarize", 0.2, 8)],
    [("brightness", 0.1, 3), ("color", 0.7, 0)],
    [("solarize", 0.4, 5), ("auto_contrast", 0.9, None)],
    [("translate_y", 0.9, 9), ("translate_y", 0.7, 9)],
    [("auto_contrast", 0.9, None), ("solarize", 0.8, 3)],
    [("equalize", 0.8, None), ("invert", 0.1, None)],
    [("translate_y", 0.7, 9), ("auto_contrast", 0.9, None)],
]
# AutoAugment SVHN preset, selected by ``AutoAugment(policy="svhn")``.
#
# Each inner list is one sub-policy: an ordered sequence of ``(op_name, probability, magnitude)``
# triples. ``AutoAugment.compose_subpolicy_sequential`` instantiates every triple as
# ``getattr(ops, op_name)(probability, magnitude)`` and chains the results in a
# ``PolicySequential``. ``probability`` is presumably the chance that each op fires, and
# ``magnitude`` is an integer level (0-9 in this table) or ``None`` for ops that are given no
# magnitude — TODO confirm semantics against ``ops``.
#
# NOTE(review): ``AutoAugment`` samples sub-policies with uniform weights, so the two repeated
# rows below (invert/equalize and equalize/rotate) are drawn twice as often as unique rows.
# ``("translate_y", 0.0, 2)`` presumably never fires if ``probability`` is an apply-probability.
# Both appear to mirror the published table; verify before changing.
svhn_policy: List[SUBPOLICY_CONFIG] = [
    [("shear_x", 0.9, 4), ("invert", 0.2, None)],
    [("shear_y", 0.9, 8), ("invert", 0.7, None)],
    [("equalize", 0.6, None), ("solarize", 0.6, 6)],
    [("invert", 0.9, None), ("equalize", 0.6, None)],
    [("equalize", 0.6, None), ("rotate", 0.9, 3)],
    [("shear_x", 0.9, 4), ("auto_contrast", 0.8, None)],
    [("shear_y", 0.9, 8), ("invert", 0.4, None)],
    [("shear_y", 0.9, 5), ("solarize", 0.2, 6)],
    [("invert", 0.9, None), ("auto_contrast", 0.8, None)],
    [("equalize", 0.6, None), ("rotate", 0.9, 3)],  # repeats an earlier row
    [("shear_x", 0.9, 4), ("solarize", 0.3, 3)],
    [("shear_y", 0.8, 8), ("invert", 0.7, None)],
    [("equalize", 0.9, None), ("translate_y", 0.6, 6)],
    [("invert", 0.9, None), ("equalize", 0.6, None)],  # repeats an earlier row
    [("contrast", 0.3, 3), ("rotate", 0.8, 4)],
    [("invert", 0.8, None), ("translate_y", 0.0, 2)],
    [("shear_y", 0.7, 6), ("solarize", 0.4, 8)],
    [("invert", 0.6, None), ("rotate", 0.8, 4)],
    [("shear_y", 0.3, 7), ("translate_x", 0.9, 3)],
    [("shear_x", 0.1, 6), ("invert", 0.6, None)],
    [("solarize", 0.7, 2), ("translate_y", 0.6, 7)],
    [("shear_y", 0.8, 4), ("invert", 0.8, None)],
    [("shear_x", 0.7, 9), ("translate_y", 0.8, 3)],
    [("shear_y", 0.8, 5), ("auto_contrast", 0.7, None)],
    [("shear_x", 0.7, 2), ("invert", 0.1, None)],
]
  105. class AutoAugment(PolicyAugmentBase):
  106. """Apply AutoAugment :cite:`cubuk2018autoaugment` searched strategies.
  107. Args:
  108. policy: a customized policy config or presets of "imagenet", "cifar10", and "svhn".
  109. transformation_matrix_mode: computation mode for the chained transformation matrix, via `.transform_matrix`
  110. attribute.
  111. If `silent`, transformation matrix will be computed silently and the non-rigid
  112. modules will be ignored as identity transformations.
  113. If `rigid`, transformation matrix will be computed silently and the non-rigid
  114. modules will trigger errors.
  115. If `skip`, transformation matrix will be totally ignored.
  116. Examples:
  117. >>> import torch
  118. >>> import kornia.augmentation as K
  119. >>> in_tensor = torch.rand(5, 3, 30, 30)
  120. >>> aug = K.AugmentationSequential(AutoAugment())
  121. >>> aug(in_tensor).shape
  122. torch.Size([5, 3, 30, 30])
  123. """
  124. def __init__(
  125. self, policy: Union[str, List[SUBPOLICY_CONFIG]] = "imagenet", transformation_matrix_mode: str = "silent"
  126. ) -> None:
  127. if policy == "imagenet":
  128. _policy = imagenet_policy
  129. elif policy == "cifar10":
  130. _policy = cifar10_policy
  131. elif policy == "svhn":
  132. _policy = svhn_policy
  133. elif isinstance(policy, (list, tuple)):
  134. _policy = policy
  135. else:
  136. raise NotImplementedError(f"Invalid policy `{policy}`.")
  137. super().__init__(_policy, transformation_matrix_mode=transformation_matrix_mode)
  138. selection_weights = tensor([1.0 / len(self)] * len(self))
  139. self.rand_selector = Categorical(selection_weights)
  140. def compose_subpolicy_sequential(self, subpolicy: SUBPOLICY_CONFIG) -> PolicySequential:
  141. return PolicySequential(*[getattr(ops, name)(prob, mag) for name, prob, mag in subpolicy])
  142. def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iterator[Tuple[str, Module]]:
  143. if params is None:
  144. idx = self.rand_selector.sample((1,))
  145. return self.get_children_by_indices(idx)
  146. return self.get_children_by_params(params)