ops.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. """AutoAugment operation wrapper."""
  18. from kornia.augmentation.auto.operations import (
  19. AutoContrast,
  20. Brightness,
  21. Contrast,
  22. Equalize,
  23. Invert,
  24. OperationBase,
  25. Posterize,
  26. Rotate,
  27. Saturate,
  28. Sharpness,
  29. ShearX,
  30. ShearY,
  31. Solarize,
  32. TranslateX,
  33. TranslateY,
  34. )
  35. from kornia.core import linspace
  36. def shear_x(probability: float, magnitude: int) -> OperationBase:
  37. """Return ShearX op."""
  38. magnitudes = linspace(-0.3, 0.3, 11) * 180.0
  39. return ShearX(
  40. None,
  41. probability,
  42. magnitude_range=(magnitudes[magnitude].item(), magnitudes[magnitude + 1].item()),
  43. symmetric_megnitude=False,
  44. )
  45. def shear_y(probability: float, magnitude: int) -> OperationBase:
  46. """Return ShearY op."""
  47. magnitudes = linspace(-0.3, 0.3, 11) * 180.0
  48. return ShearY(
  49. None,
  50. probability,
  51. magnitude_range=(magnitudes[magnitude].item(), magnitudes[magnitude + 1].item()),
  52. symmetric_megnitude=False,
  53. )
  54. def translate_x(probability: float, magnitude: int) -> OperationBase:
  55. """Return TranslateX op."""
  56. magnitudes = linspace(-0.5, 0.5, 11)
  57. return TranslateX(
  58. None,
  59. probability,
  60. magnitude_range=(magnitudes[magnitude].item(), magnitudes[magnitude + 1].item()),
  61. symmetric_megnitude=False,
  62. )
  63. def translate_y(probability: float, magnitude: int) -> OperationBase:
  64. """Return TranslateY op."""
  65. magnitudes = linspace(-0.5, 0.5, 11)
  66. return TranslateY(
  67. None,
  68. probability,
  69. magnitude_range=(magnitudes[magnitude].item(), magnitudes[magnitude + 1].item()),
  70. symmetric_megnitude=False,
  71. )
  72. def rotate(probability: float, magnitude: int) -> OperationBase:
  73. """Return rotate op."""
  74. magnitudes = linspace(-30, 30, 11)
  75. return Rotate(
  76. None,
  77. probability,
  78. magnitude_range=(magnitudes[magnitude].item(), magnitudes[magnitude + 1].item()),
  79. symmetric_megnitude=False,
  80. )
  81. def auto_contrast(probability: float, _: int) -> OperationBase:
  82. """Return AutoConstrast op."""
  83. return AutoContrast(probability)
  84. def invert(probability: float, _: int) -> OperationBase:
  85. """Return invert op."""
  86. return Invert(probability)
  87. def equalize(probability: float, _: int) -> OperationBase:
  88. """Return equalize op."""
  89. return Equalize(probability)
  90. def solarize(probability: float, magnitude: int) -> OperationBase:
  91. """Return solarize op."""
  92. magnitudes = linspace(0, 255, 11) / 255.0
  93. return Solarize(None, probability, magnitude_range=(magnitudes[magnitude].item(), magnitudes[magnitude + 1].item()))
  94. def posterize(probability: float, magnitude: int) -> OperationBase:
  95. """Return posterize op."""
  96. magnitudes = linspace(4, 8, 11)
  97. return Posterize(
  98. None, probability, magnitude_range=(magnitudes[magnitude].item(), magnitudes[magnitude + 1].item())
  99. )
  100. def contrast(probability: float, magnitude: int) -> OperationBase:
  101. """Return contrast op."""
  102. magnitudes = linspace(0.1, 1.9, 11)
  103. return Contrast(None, probability, magnitude_range=(magnitudes[magnitude].item(), magnitudes[magnitude + 1].item()))
  104. def brightness(probability: float, magnitude: int) -> OperationBase:
  105. """Return brightness op."""
  106. magnitudes = linspace(0.1, 1.9, 11)
  107. return Brightness(
  108. None, probability, magnitude_range=(magnitudes[magnitude].item(), magnitudes[magnitude + 1].item())
  109. )
  110. def sharpness(probability: float, magnitude: int) -> OperationBase:
  111. """Return sharpness op."""
  112. magnitudes = linspace(0.1, 1.9, 11)
  113. return Sharpness(
  114. None, probability, magnitude_range=(magnitudes[magnitude].item(), magnitudes[magnitude + 1].item())
  115. )
  116. def color(probability: float, magnitude: int) -> OperationBase:
  117. """Return color op."""
  118. magnitudes = linspace(0.1, 1.9, 11)
  119. return Saturate(None, probability, magnitude_range=(magnitudes[magnitude].item(), magnitudes[magnitude + 1].item()))