channel_dropout.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. """Implementation of the Channel Dropout transform for multi-channel images.
  2. This module provides the ChannelDropout transform, which randomly drops (sets to a fill value)
  3. one or more channels in multi-channel images. This augmentation can help models become more
  4. robust to missing or corrupted channel information and encourage learning from all available
  5. channels rather than relying on a subset.
  6. """
  7. from __future__ import annotations
  8. from typing import Annotated, Any
  9. import numpy as np
  10. from albucore import get_num_channels
  11. from pydantic import AfterValidator
  12. from albumentations.core.pydantic import check_range_bounds
  13. from albumentations.core.transforms_interface import BaseTransformInitSchema, ImageOnlyTransform
  14. from .functional import channel_dropout
  15. __all__ = ["ChannelDropout"]
  16. MIN_DROPOUT_CHANNEL_LIST_LENGTH = 2
  17. class ChannelDropout(ImageOnlyTransform):
  18. """Randomly drop channels in the input image.
  19. This transform randomly selects a number of channels to drop from the input image
  20. and replaces them with a specified fill value. This can improve model robustness
  21. to missing or corrupted channels.
  22. The technique is conceptually similar to:
  23. - Dropout layers in neural networks, which randomly set input units to 0 during training.
  24. - CoarseDropout augmentation, which drops out regions in the spatial dimensions of the image.
  25. However, ChannelDropout operates on the channel dimension, effectively "dropping out"
  26. entire color channels or feature maps.
  27. Args:
  28. channel_drop_range (tuple[int, int]): Range from which to choose the number
  29. of channels to drop. The actual number will be randomly selected from
  30. the inclusive range [min, max]. Default: (1, 1).
  31. fill (float): Pixel value used to fill the dropped channels.
  32. Default: 0.
  33. p (float): Probability of applying the transform. Must be in the range
  34. [0, 1]. Default: 0.5.
  35. Raises:
  36. NotImplementedError: If the input image has only one channel.
  37. ValueError: If the upper bound of channel_drop_range is greater than or
  38. equal to the number of channels in the input image.
  39. Targets:
  40. image, volume
  41. Image types:
  42. uint8, float32
  43. Examples:
  44. >>> import numpy as np
  45. >>> import albumentations as A
  46. >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
  47. >>> transform = A.ChannelDropout(channel_drop_range=(1, 2), fill=128, p=1.0)
  48. >>> result = transform(image=image)
  49. >>> dropped_image = result['image']
  50. >>> assert dropped_image.shape == image.shape
  51. >>> assert np.any(dropped_image != image) # Some channels should be different
  52. Note:
  53. - The number of channels to drop is randomly chosen within the specified range.
  54. - Channels are randomly selected for dropping.
  55. - This transform is not applicable to single-channel (grayscale) images.
  56. - The transform will raise an error if it's not possible to drop the specified
  57. number of channels (e.g., trying to drop 3 channels from an RGB image).
  58. - This augmentation can be particularly useful for training models to be robust
  59. against missing or corrupted channel data in multi-spectral or hyperspectral imagery.
  60. """
  61. class InitSchema(BaseTransformInitSchema):
  62. channel_drop_range: Annotated[tuple[int, int], AfterValidator(check_range_bounds(1, None))]
  63. fill: float
  64. def __init__(
  65. self,
  66. channel_drop_range: tuple[int, int] = (1, 1),
  67. fill: float = 0,
  68. p: float = 0.5,
  69. ):
  70. super().__init__(p=p)
  71. self.channel_drop_range = channel_drop_range
  72. self.fill = fill
  73. def apply(self, img: np.ndarray, channels_to_drop: list[int], **params: Any) -> np.ndarray:
  74. """Apply channel dropout to the image.
  75. Args:
  76. img (np.ndarray): Image to apply channel dropout to.
  77. channels_to_drop (list[int]): List of channel indices to drop.
  78. **params (Any): Additional parameters.
  79. Returns:
  80. np.ndarray: Image with dropped channels.
  81. """
  82. return channel_dropout(img, channels_to_drop, self.fill)
  83. def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, list[int]]:
  84. """Get parameters that depend on input data.
  85. Args:
  86. params (dict[str, Any]): Parameters.
  87. data (dict[str, Any]): Input data.
  88. Returns:
  89. dict[str, list[int]]: Dictionary with channels to drop.
  90. """
  91. image = data["image"] if "image" in data else data["images"][0]
  92. num_channels = get_num_channels(image)
  93. if num_channels == 1:
  94. msg = "Images has one channel. ChannelDropout is not defined."
  95. raise NotImplementedError(msg)
  96. if self.channel_drop_range[1] >= num_channels:
  97. msg = "Can not drop all channels in ChannelDropout."
  98. raise ValueError(msg)
  99. num_drop_channels = self.py_random.randint(*self.channel_drop_range)
  100. channels_to_drop = self.py_random.sample(range(num_channels), k=num_drop_channels)
  101. return {"channels_to_drop": channels_to_drop}