xy_masking.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. """Implementation of XY masking for time-frequency domain transformations.
  2. This module provides the XYMasking transform, which applies masking strips along the X and Y axes
  3. of an image. This is particularly useful for audio spectrograms, time-series data visualizations,
  4. and other grid-like data representations where masking in specific directions (time or frequency)
  5. can improve model robustness and generalization.
  6. """
  7. from __future__ import annotations
  8. from typing import Any, Literal, cast
  9. import numpy as np
  10. from pydantic import model_validator
  11. from typing_extensions import Self
  12. from albumentations.augmentations.dropout.transforms import BaseDropout
  13. from albumentations.core.pydantic import NonNegativeIntRangeType
  14. from albumentations.core.transforms_interface import BaseTransformInitSchema
# Public API of this module: `from <module> import *` exports only the XYMasking transform.
__all__ = ["XYMasking"]
  16. class XYMasking(BaseDropout):
  17. """Applies masking strips to an image, either horizontally (X axis) or vertically (Y axis),
  18. simulating occlusions. This transform is useful for training models to recognize images
  19. with varied visibility conditions. It's particularly effective for spectrogram images,
  20. allowing spectral and frequency masking to improve model robustness.
  21. At least one of `max_x_length` or `max_y_length` must be specified, dictating the mask's
  22. maximum size along each axis.
  23. Args:
  24. num_masks_x (int | tuple[int, int]): Number or range of horizontal regions to mask. Defaults to 0.
  25. num_masks_y (int | tuple[int, int]): Number or range of vertical regions to mask. Defaults to 0.
  26. mask_x_length (int | tuple[int, int]): Specifies the length of the masks along
  27. the X (horizontal) axis. If an integer is provided, it sets a fixed mask length.
  28. If a tuple of two integers (min, max) is provided,
  29. the mask length is randomly chosen within this range for each mask.
  30. This allows for variable-length masks in the horizontal direction.
  31. mask_y_length (int | tuple[int, int]): Specifies the height of the masks along
  32. the Y (vertical) axis. Similar to `mask_x_length`, an integer sets a fixed mask height,
  33. while a tuple (min, max) allows for variable-height masks, chosen randomly
  34. within the specified range for each mask. This flexibility facilitates creating masks of various
  35. sizes in the vertical direction.
  36. fill (tuple[float, float] | float | Literal["random", "random_uniform", "inpaint_telea", "inpaint_ns"]):
  37. Value for the dropped pixels. Can be:
  38. - int or float: all channels are filled with this value
  39. - tuple: tuple of values for each channel
  40. - 'random': each pixel is filled with random values
  41. - 'random_uniform': each hole is filled with a single random color
  42. - 'inpaint_telea': uses OpenCV Telea inpainting method
  43. - 'inpaint_ns': uses OpenCV Navier-Stokes inpainting method
  44. Default: 0
  45. fill_mask (tuple[float, float] | float | None): Fill value for dropout regions in the mask.
  46. If None, mask regions corresponding to image dropouts are unchanged. Default: None
  47. p (float): Probability of applying the transform. Defaults to 0.5.
  48. Targets:
  49. image, mask, bboxes, keypoints, volume, mask3d
  50. Image types:
  51. uint8, float32
  52. Note: Either `max_x_length` or `max_y_length` or both must be defined.
  53. """
  54. class InitSchema(BaseTransformInitSchema):
  55. num_masks_x: NonNegativeIntRangeType
  56. num_masks_y: NonNegativeIntRangeType
  57. mask_x_length: NonNegativeIntRangeType
  58. mask_y_length: NonNegativeIntRangeType
  59. fill: tuple[float, ...] | float | Literal["random", "random_uniform", "inpaint_telea", "inpaint_ns"]
  60. fill_mask: tuple[float, ...] | float | None
  61. @model_validator(mode="after")
  62. def _check_mask_length(self) -> Self:
  63. if (
  64. isinstance(self.mask_x_length, int)
  65. and self.mask_x_length <= 0
  66. and isinstance(self.mask_y_length, int)
  67. and self.mask_y_length <= 0
  68. ):
  69. msg = "At least one of `mask_x_length` or `mask_y_length` Should be a positive number."
  70. raise ValueError(msg)
  71. return self
  72. def __init__(
  73. self,
  74. num_masks_x: tuple[int, int] | int = 0,
  75. num_masks_y: tuple[int, int] | int = 0,
  76. mask_x_length: tuple[int, int] | int = 0,
  77. mask_y_length: tuple[int, int] | int = 0,
  78. fill: tuple[float, ...] | float | Literal["random", "random_uniform", "inpaint_telea", "inpaint_ns"] = 0,
  79. fill_mask: tuple[float, ...] | float | None = None,
  80. p: float = 0.5,
  81. ):
  82. super().__init__(p=p, fill=fill, fill_mask=fill_mask)
  83. self.num_masks_x = cast("tuple[int, int]", num_masks_x)
  84. self.num_masks_y = cast("tuple[int, int]", num_masks_y)
  85. self.mask_x_length = cast("tuple[int, int]", mask_x_length)
  86. self.mask_y_length = cast("tuple[int, int]", mask_y_length)
  87. def _validate_mask_length(
  88. self,
  89. mask_length: tuple[int, int] | None,
  90. dimension_size: int,
  91. dimension_name: str,
  92. ) -> None:
  93. """Validate the mask length against the corresponding image dimension size."""
  94. if mask_length is not None:
  95. if isinstance(mask_length, (tuple, list)):
  96. if mask_length[0] < 0 or mask_length[1] > dimension_size:
  97. raise ValueError(
  98. f"{dimension_name} range {mask_length} is out of valid range [0, {dimension_size}]",
  99. )
  100. elif mask_length < 0 or mask_length > dimension_size:
  101. raise ValueError(f"{dimension_name} {mask_length} exceeds image {dimension_name} {dimension_size}")
  102. def get_params_dependent_on_data(
  103. self,
  104. params: dict[str, Any],
  105. data: dict[str, Any],
  106. ) -> dict[str, np.ndarray]:
  107. """Get parameters dependent on the data.
  108. Args:
  109. params (dict[str, Any]): Dictionary containing parameters.
  110. data (dict[str, Any]): Dictionary containing data.
  111. Returns:
  112. dict[str, np.ndarray]: Dictionary with parameters for transformation.
  113. """
  114. image_shape = params["shape"][:2]
  115. height, width = image_shape
  116. self._validate_mask_length(self.mask_x_length, width, "mask_x_length")
  117. self._validate_mask_length(self.mask_y_length, height, "mask_y_length")
  118. masks_x = self._generate_masks(self.num_masks_x, image_shape, self.mask_x_length, axis="x")
  119. masks_y = self._generate_masks(self.num_masks_y, image_shape, self.mask_y_length, axis="y")
  120. holes = np.array(masks_x + masks_y)
  121. return {"holes": holes, "seed": self.random_generator.integers(0, 2**32 - 1)}
  122. def _generate_mask_size(self, mask_length: tuple[int, int]) -> int:
  123. return self.py_random.randint(*mask_length)
  124. def _generate_masks(
  125. self,
  126. num_masks: tuple[int, int],
  127. image_shape: tuple[int, int],
  128. max_length: tuple[int, int] | None,
  129. axis: str,
  130. ) -> list[tuple[int, int, int, int]]:
  131. if max_length is None or max_length == 0 or (isinstance(num_masks, (int, float)) and num_masks == 0):
  132. return []
  133. masks = []
  134. num_masks_integer = (
  135. num_masks if isinstance(num_masks, int) else self.py_random.randint(num_masks[0], num_masks[1])
  136. )
  137. height, width = image_shape
  138. for _ in range(num_masks_integer):
  139. length = self._generate_mask_size(max_length)
  140. if axis == "x":
  141. x_min = self.py_random.randint(0, width - length)
  142. y_min = 0
  143. x_max, y_max = x_min + length, height
  144. else: # axis == 'y'
  145. y_min = self.py_random.randint(0, height - length)
  146. x_min = 0
  147. x_max, y_max = width, y_min + length
  148. masks.append((x_min, y_min, x_max, y_max))
  149. return masks