| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from typing import List, Tuple, Union
- from torch import Tensor, nn
- from .augment import AugmentationSequential
class ManyToManyAugmentationDispather(nn.Module):
    r"""Dispatches different augmentations to different inputs element-wisely.

    The i-th positional input given to :meth:`forward` is unpacked into the i-th
    augmentation. Inputs and augmentations are paired with ``zip``, so surplus items on
    either side are silently ignored.

    Args:
        augmentations: a list or a sequence of kornia AugmentationSequential modules.

    Examples:
        >>> import torch
        >>> import kornia
        >>> input_1, input_2 = torch.randn(2, 3, 5, 6), torch.randn(2, 3, 5, 6)
        >>> mask_1, mask_2 = torch.ones(2, 3, 5, 6), torch.ones(2, 3, 5, 6)
        >>> aug_list = ManyToManyAugmentationDispather(
        ...    AugmentationSequential(
        ...        kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...        kornia.augmentation.RandomAffine(360, p=1.0),
        ...        data_keys=["input", "mask",],
        ...    ),
        ...    AugmentationSequential(
        ...        kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...        kornia.augmentation.RandomAffine(360, p=1.0),
        ...        data_keys=["input", "mask",],
        ...    )
        ... )
        >>> output = aug_list((input_1, mask_1), (input_2, mask_2))

    """

    def __init__(self, *augmentations: AugmentationSequential) -> None:
        super().__init__()
        self._check_consistency(*augmentations)
        # A plain tuple attribute is invisible to nn.Module. nn.ModuleList registers every
        # pipeline as a submodule so .to(), .train()/.eval(), parameters() and state_dict()
        # reach them, while iteration, len() and indexing keep working as before.
        self.augmentations = nn.ModuleList(augmentations)

    def _check_consistency(self, *augmentations: AugmentationSequential) -> bool:
        """Validate that every given augmentation is an ``AugmentationSequential``.

        Args:
            augmentations: the candidate augmentation pipelines.

        Returns:
            ``True`` when every element passes the check.

        Raises:
            ValueError: if any element is not an ``AugmentationSequential`` instance.

        """
        for i, aug in enumerate(augmentations):
            if not isinstance(aug, AugmentationSequential):
                raise ValueError(f"Please wrap your augmentations[`{i}`] with `AugmentationSequential`.")
        return True

    def forward(self, *input: Union[List[Tensor], List[Tuple[Tensor]]]) -> Union[List[Tensor], List[Tuple[Tensor]]]:
        """Apply the i-th augmentation to the i-th input.

        Args:
            input: one group of arguments per augmentation; each group is unpacked into the
                call of the augmentation at the same position.

        Returns:
            A list holding the result of each augmentation call, in order.

        """
        # zip stops at the shorter sequence: unmatched inputs or augmentations are dropped.
        return [aug(*inp) for inp, aug in zip(input, self.augmentations)]
class ManyToOneAugmentationDispather(nn.Module):
    r"""Dispatches different augmentations to a single input and returns a list.

    Same `datakeys` must be applied across different augmentations. By default, with input
    (image, mask), the augmentations must not mess it as (mask, image) to avoid unexpected
    errors. This check can be cancelled with `strict=False` if needed.

    Args:
        augmentations: a list or a sequence of kornia AugmentationSequential modules.
        strict: if True, raise an error when two consecutive augmentations declare
            different `data_keys`. If False, that comparison is skipped.

    Examples:
        >>> import torch
        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)
        >>> mask = torch.ones(2, 3, 5, 6)
        >>> aug_list = ManyToOneAugmentationDispather(
        ...    AugmentationSequential(
        ...        kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...        kornia.augmentation.RandomAffine(360, p=1.0),
        ...        data_keys=["input", "mask",],
        ...    ),
        ...    AugmentationSequential(
        ...        kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...        kornia.augmentation.RandomAffine(360, p=1.0),
        ...        data_keys=["input", "mask",],
        ...    )
        ... )
        >>> output = aug_list(input, mask)

    """

    def __init__(self, *augmentations: AugmentationSequential, strict: bool = True) -> None:
        super().__init__()
        # `strict` must be set before the consistency check, which reads it.
        self.strict = strict
        self._check_consistency(*augmentations)
        # A plain tuple attribute is invisible to nn.Module. nn.ModuleList registers every
        # pipeline as a submodule so .to(), .train()/.eval(), parameters() and state_dict()
        # reach them, while iteration, len() and indexing keep working as before.
        self.augmentations = nn.ModuleList(augmentations)

    def _check_consistency(self, *augmentations: AugmentationSequential) -> bool:
        """Validate the augmentation types and, in strict mode, their `data_keys`.

        Args:
            augmentations: the candidate augmentation pipelines.

        Returns:
            ``True`` when every element passes the checks.

        Raises:
            ValueError: if any element is not an ``AugmentationSequential`` instance.
            RuntimeError: if ``self.strict`` is set and an element's `data_keys` differ
                from those of the element right before it.

        """
        for i, aug in enumerate(augmentations):
            if not isinstance(aug, AugmentationSequential):
                raise ValueError(f"Please wrap your augmentations[`{i}`] with `AugmentationSequential`.")
            # Comparing each element with its predecessor is enough to make all of them equal.
            if self.strict and i != 0 and aug.data_keys != augmentations[i - 1].data_keys:
                # Report the keys in the same order as the indices named in the message.
                raise RuntimeError(
                    f"Different `data_keys` between {i - 1} and {i} elements, "
                    f"got {augmentations[i - 1].data_keys} and {aug.data_keys}."
                )
        return True

    def forward(self, *input: Union[Tensor, Tuple[Tensor]]) -> Union[List[Tensor], List[Tuple[Tensor]]]:
        """Apply every augmentation to the same inputs.

        Args:
            input: the arguments forwarded unchanged to each augmentation.

        Returns:
            A list with one entry per augmentation, in registration order.

        """
        return [aug(*input) for aug in self.augmentations]
|