grid_dropout.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. """Implementation of grid-based dropout augmentation.
  2. This module provides GridDropout, which creates a regular grid over the image and drops out
  3. rectangular regions according to the specified grid pattern. Unlike random dropout methods,
  4. grid dropout enforces a structured pattern of occlusions that can help models learn spatial
  5. relationships and context across the entire image space.
  6. """
  7. from __future__ import annotations
  8. from typing import Annotated, Any, Literal
  9. from pydantic import AfterValidator, Field
  10. import albumentations.augmentations.dropout.functional as fdropout
  11. from albumentations.augmentations.dropout.transforms import BaseDropout
  12. from albumentations.core.pydantic import check_range_bounds, nondecreasing
  13. __all__ = ["GridDropout"]
  14. class GridDropout(BaseDropout):
  15. """Apply GridDropout augmentation to images, masks, bounding boxes, and keypoints.
  16. GridDropout drops out rectangular regions of an image and the corresponding mask in a grid fashion.
  17. This technique can help improve model robustness by forcing the network to rely on a broader context
  18. rather than specific local features.
  19. Args:
  20. ratio (float): The ratio of the mask holes to the unit size (same for horizontal and vertical directions).
  21. Must be between 0 and 1. Default: 0.5.
  22. unit_size_range (tuple[int, int] | None): Range from which to sample grid size. Default: None.
  23. Must be between 2 and the image's shorter edge. If None, grid size is calculated based on image size.
  24. holes_number_xy (tuple[int, int] | None): The number of grid units in x and y directions.
  25. First value should be between 1 and image width//2,
  26. Second value should be between 1 and image height//2.
  27. Default: None. If provided, overrides unit_size_range.
  28. random_offset (bool): Whether to offset the grid randomly between 0 and (grid unit size - hole size).
  29. If True, entered shift_xy is ignored and set randomly. Default: True.
  30. fill (tuple[float, float] | float | Literal["random", "random_uniform", "inpaint_telea", "inpaint_ns"]):
  31. Value for the dropped pixels. Can be:
  32. - int or float: all channels are filled with this value
  33. - tuple: tuple of values for each channel
  34. - 'random': each pixel is filled with random values
  35. - 'random_uniform': each hole is filled with a single random color
  36. - 'inpaint_telea': uses OpenCV Telea inpainting method
  37. - 'inpaint_ns': uses OpenCV Navier-Stokes inpainting method
  38. Default: 0
  39. fill_mask (tuple[float, float] | float | None): Value for the dropped pixels in mask.
  40. If None, the mask is not modified. Default: None.
  41. shift_xy (tuple[int, int]): Offsets of the grid start in x and y directions from (0,0) coordinate.
  42. Only used when random_offset is False. Default: (0, 0).
  43. p (float): Probability of applying the transform. Default: 0.5.
  44. Targets:
  45. image, mask, bboxes, keypoints, volume, mask3d
  46. Image types:
  47. uint8, float32
  48. Note:
  49. - If both unit_size_range and holes_number_xy are None, the grid size is calculated based on the image size.
  50. - The actual number of dropped regions may differ slightly from holes_number_xy due to rounding.
  51. - Inpainting methods ('inpaint_telea', 'inpaint_ns') work only with grayscale or RGB images.
  52. - For 'random_uniform' fill, each grid cell gets a single random color, unlike 'random' where each pixel
  53. gets its own random value.
  54. Example:
  55. >>> import numpy as np
  56. >>> import albumentations as A
  57. >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
  58. >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
  59. >>> # Example with standard fill value
  60. >>> aug_basic = A.GridDropout(
  61. ... ratio=0.3,
  62. ... unit_size_range=(10, 20),
  63. ... random_offset=True,
  64. ... p=1.0
  65. ... )
  66. >>> # Example with random uniform fill
  67. >>> aug_random = A.GridDropout(
  68. ... ratio=0.3,
  69. ... unit_size_range=(10, 20),
  70. ... fill="random_uniform",
  71. ... p=1.0
  72. ... )
  73. >>> # Example with inpainting
  74. >>> aug_inpaint = A.GridDropout(
  75. ... ratio=0.3,
  76. ... unit_size_range=(10, 20),
  77. ... fill="inpaint_ns",
  78. ... p=1.0
  79. ... )
  80. >>> transformed = aug_random(image=image, mask=mask)
  81. >>> transformed_image, transformed_mask = transformed["image"], transformed["mask"]
  82. Reference:
  83. - Paper: https://arxiv.org/abs/2001.04086
  84. - OpenCV Inpainting methods: https://docs.opencv.org/master/df/d3d/tutorial_py_inpainting.html
  85. """
  86. class InitSchema(BaseDropout.InitSchema):
  87. ratio: float = Field(gt=0, le=1)
  88. random_offset: bool
  89. unit_size_range: (
  90. Annotated[tuple[int, int], AfterValidator(check_range_bounds(2, None)), AfterValidator(nondecreasing)]
  91. | None
  92. )
  93. shift_xy: Annotated[tuple[int, int], AfterValidator(check_range_bounds(0, None))]
  94. holes_number_xy: Annotated[tuple[int, int], AfterValidator(check_range_bounds(1, None))] | None
  95. def __init__(
  96. self,
  97. ratio: float = 0.5,
  98. random_offset: bool = True,
  99. unit_size_range: tuple[int, int] | None = None,
  100. holes_number_xy: tuple[int, int] | None = None,
  101. shift_xy: tuple[int, int] = (0, 0),
  102. fill: tuple[float, ...] | float | Literal["random", "random_uniform", "inpaint_telea", "inpaint_ns"] = 0,
  103. fill_mask: tuple[float, ...] | float | None = None,
  104. p: float = 0.5,
  105. ):
  106. super().__init__(fill=fill, fill_mask=fill_mask, p=p)
  107. self.ratio = ratio
  108. self.unit_size_range = unit_size_range
  109. self.holes_number_xy = holes_number_xy
  110. self.random_offset = random_offset
  111. self.shift_xy = shift_xy
  112. def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
  113. """Get parameters dependent on the data.
  114. Args:
  115. params (dict[str, Any]): Dictionary containing parameters.
  116. data (dict[str, Any]): Dictionary containing data.
  117. Returns:
  118. dict[str, Any]: Dictionary with parameters for transformation.
  119. """
  120. image_shape = params["shape"]
  121. if self.holes_number_xy:
  122. grid = self.holes_number_xy
  123. else:
  124. # Calculate grid based on unit_size_range or default
  125. unit_height, unit_width = fdropout.calculate_grid_dimensions(
  126. image_shape,
  127. self.unit_size_range,
  128. self.holes_number_xy,
  129. self.random_generator,
  130. )
  131. grid = (image_shape[0] // unit_height, image_shape[1] // unit_width)
  132. holes = fdropout.generate_grid_holes(
  133. image_shape,
  134. grid,
  135. self.ratio,
  136. self.random_offset,
  137. self.shift_xy,
  138. self.random_generator,
  139. )
  140. return {"holes": holes, "seed": self.random_generator.integers(0, 2**32 - 1)}