dispatcher.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import List, Tuple, Union
  18. from torch import Tensor, nn
  19. from .augment import AugmentationSequential
  20. class ManyToManyAugmentationDispather(nn.Module):
  21. r"""Dispatches different augmentations to different inputs element-wisely.
  22. Args:
  23. augmentations: a list or a sequence of kornia AugmentationSequential modules.
  24. Examples:
  25. >>> import torch
  26. >>> input_1, input_2 = torch.randn(2, 3, 5, 6), torch.randn(2, 3, 5, 6)
  27. >>> mask_1, mask_2 = torch.ones(2, 3, 5, 6), torch.ones(2, 3, 5, 6)
  28. >>> aug_list = ManyToManyAugmentationDispather(
  29. ... AugmentationSequential(
  30. ... kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
  31. ... kornia.augmentation.RandomAffine(360, p=1.0),
  32. ... data_keys=["input", "mask",],
  33. ... ),
  34. ... AugmentationSequential(
  35. ... kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
  36. ... kornia.augmentation.RandomAffine(360, p=1.0),
  37. ... data_keys=["input", "mask",],
  38. ... )
  39. ... )
  40. >>> output = aug_list((input_1, mask_1), (input_2, mask_2))
  41. """
  42. def __init__(self, *augmentations: AugmentationSequential) -> None:
  43. super().__init__()
  44. self._check_consistency(*augmentations)
  45. self.augmentations = augmentations
  46. def _check_consistency(self, *augmentations: AugmentationSequential) -> bool:
  47. for i, aug in enumerate(augmentations):
  48. if not isinstance(aug, AugmentationSequential):
  49. raise ValueError(f"Please wrap your augmentations[`{i}`] with `AugmentationSequentials`.")
  50. return True
  51. def forward(self, *input: Union[List[Tensor], List[Tuple[Tensor]]]) -> Union[List[Tensor], List[Tuple[Tensor]]]:
  52. return [aug(*inp) for inp, aug in zip(input, self.augmentations)]
  53. class ManyToOneAugmentationDispather(nn.Module):
  54. r"""Dispatches different augmentations to a single input and returns a list.
  55. Same `datakeys` must be applied across different augmentations. By default, with input
  56. (image, mask), the augmentations must not mess it as (mask, image) to avoid unexpected
  57. errors. This check can be cancelled with `strict=False` if needed.
  58. Args:
  59. augmentations: a list or a sequence of kornia AugmentationSequential modules.
  60. Examples:
  61. >>> import torch
  62. >>> input = torch.randn(2, 3, 5, 6)
  63. >>> mask = torch.ones(2, 3, 5, 6)
  64. >>> aug_list = ManyToOneAugmentationDispather(
  65. ... AugmentationSequential(
  66. ... kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
  67. ... kornia.augmentation.RandomAffine(360, p=1.0),
  68. ... data_keys=["input", "mask",],
  69. ... ),
  70. ... AugmentationSequential(
  71. ... kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
  72. ... kornia.augmentation.RandomAffine(360, p=1.0),
  73. ... data_keys=["input", "mask",],
  74. ... )
  75. ... )
  76. >>> output = aug_list(input, mask)
  77. """
  78. def __init__(self, *augmentations: AugmentationSequential, strict: bool = True) -> None:
  79. super().__init__()
  80. self.strict = strict
  81. self._check_consistency(*augmentations)
  82. self.augmentations = augmentations
  83. def _check_consistency(self, *augmentations: AugmentationSequential) -> bool:
  84. for i, aug in enumerate(augmentations):
  85. if not isinstance(aug, AugmentationSequential):
  86. raise ValueError(f"Please wrap your augmentations[`{i}`] with `AugmentationSequentials`.")
  87. if self.strict and i != 0 and aug.data_keys != augmentations[i - 1].data_keys:
  88. raise RuntimeError(
  89. f"Different `data_keys` between {i - 1} and {i} elements, "
  90. f"got {aug.data_keys} and {augmentations[i - 1].data_keys}."
  91. )
  92. return True
  93. def forward(self, *input: Union[Tensor, Tuple[Tensor]]) -> Union[List[Tensor], List[Tuple[Tensor]]]:
  94. return [aug(*input) for aug in self.augmentations]