mask_dropout.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. """Implementation of mask-based dropout augmentation.
  2. This module provides the MaskDropout transform, which identifies objects in a segmentation mask
  3. and drops out random objects completely. This augmentation is particularly useful for instance
  4. segmentation and object detection tasks, as it simulates occlusions or missing objects in a
  5. semantically meaningful way, rather than dropping out random pixels or regions.
  6. """
  7. from __future__ import annotations
  8. from typing import Any, Literal, cast
  9. import cv2
  10. import numpy as np
  11. import albumentations.augmentations.dropout.functional as fdropout
  12. from albumentations.core.bbox_utils import BboxProcessor, denormalize_bboxes, normalize_bboxes
  13. from albumentations.core.keypoints_utils import KeypointsProcessor
  14. from albumentations.core.pydantic import OnePlusIntRangeType
  15. from albumentations.core.transforms_interface import BaseTransformInitSchema, DualTransform
  16. from albumentations.core.type_definitions import ALL_TARGETS
  17. __all__ = ["MaskDropout"]
  class MaskDropout(DualTransform):
      """Apply dropout to random objects in a mask, zeroing out the corresponding regions in both the image and mask.

      This transform identifies objects in the mask (where each unique non-zero value represents a distinct object),
      randomly selects a number of these objects, and fills their corresponding regions in both the image and mask.
      It can also handle bounding boxes and keypoints, removing or adjusting them based on the dropout regions.

      Args:
          max_objects (int | tuple[int, int]): Maximum number of objects to dropout. If a single int is provided,
              it's treated as the upper bound. If a tuple of two ints is provided, it's treated as a range [min, max].
              Default: (1, 1).
          fill (float | Literal["inpaint_telea", "inpaint_ns"]): Value to fill dropped out regions in the image.
              Can be one of:
              - float: Constant value to fill the regions (e.g., 0 for black, 255 for white)
              - "inpaint_telea": Use Telea inpainting algorithm (for 3-channel images only)
              - "inpaint_ns": Use Navier-Stokes inpainting algorithm (for 3-channel images only)
              Default: 0.
          fill_mask (float): Value to fill the dropped out regions in the mask. Default: 0.
          p (float): Probability of applying the transform. Default: 0.5.

      Targets:
          image, mask, bboxes, keypoints, volume, mask3d

      Image types:
          uint8, float32

      Note:
          - The mask should be a single-channel image where 0 represents the background and non-zero values represent
            different object instances.
          - For bounding box and keypoint augmentation, make sure to set up the corresponding processors in the pipeline.
          - `min_area` and `min_visibility` are NOT constructor arguments of this transform. When bounding boxes are
            processed, both values are read from the pipeline's bbox processor params (e.g. `A.BboxParams`).
          - Keypoints are only passed through the dropout filter when the keypoint processor is configured with
            `remove_invisible=True`; otherwise they are returned unchanged.

      Examples:
          >>> import numpy as np
          >>> import albumentations as A
          >>>
          >>> # Prepare sample data
          >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
          >>> mask = np.zeros((100, 100), dtype=np.uint8)
          >>> mask[20:40, 20:40] = 1  # Object 1
          >>> mask[60:80, 60:80] = 2  # Object 2
          >>> bboxes = np.array([[20, 20, 40, 40], [60, 60, 80, 80]], dtype=np.float32)
          >>> bbox_labels = [1, 2]
          >>> keypoints = np.array([[30, 30], [70, 70]], dtype=np.float32)
          >>> keypoint_labels = [0, 1]
          >>>
          >>> # Define the transform with tuple for max_objects
          >>> transform = A.Compose(
          ...     transforms=[
          ...         A.MaskDropout(
          ...             max_objects=(1, 2),  # Using tuple to specify min and max objects to drop
          ...             fill=0,  # Fill value for dropped regions in image
          ...             fill_mask=0,  # Fill value for dropped regions in mask
          ...             p=1.0
          ...         ),
          ...     ],
          ...     bbox_params=A.BboxParams(
          ...         format='pascal_voc',
          ...         label_fields=['bbox_labels'],
          ...         min_area=1,
          ...         min_visibility=0.1
          ...     ),
          ...     keypoint_params=A.KeypointParams(
          ...         format='xy',
          ...         label_fields=['keypoint_labels'],
          ...         remove_invisible=True
          ...     )
          ... )
          >>>
          >>> # Apply the transform
          >>> transformed = transform(
          ...     image=image,
          ...     mask=mask,
          ...     bboxes=bboxes,
          ...     bbox_labels=bbox_labels,
          ...     keypoints=keypoints,
          ...     keypoint_labels=keypoint_labels
          ... )
          >>>
          >>> # Get the transformed data
          >>> transformed_image = transformed['image']  # Image with dropped out regions
          >>> transformed_mask = transformed['mask']  # Mask with dropped out regions
          >>> transformed_bboxes = transformed['bboxes']  # Remaining bboxes after dropout
          >>> transformed_bbox_labels = transformed['bbox_labels']  # Labels for remaining bboxes
          >>> transformed_keypoints = transformed['keypoints']  # Remaining keypoints after dropout
          >>> transformed_keypoint_labels = transformed['keypoint_labels']  # Labels for remaining keypoints

      """

      _targets = ALL_TARGETS

      class InitSchema(BaseTransformInitSchema):
          # Validation schema for the constructor arguments.
          max_objects: OnePlusIntRangeType
          fill: float | Literal["inpaint_telea", "inpaint_ns"]
          fill_mask: float

      def __init__(
          self,
          max_objects: tuple[int, int] | int = (1, 1),
          fill: float | Literal["inpaint_telea", "inpaint_ns"] = 0,
          fill_mask: float = 0,
          p: float = 0.5,
      ):
          super().__init__(p=p)
          # Stored as a (min, max) pair: it is unpacked into `randint(*self.max_objects)` when sampling.
          self.max_objects = cast("tuple[int, int]", max_objects)
          self.fill = fill  # type: ignore[assignment]
          self.fill_mask = fill_mask

      @property
      def targets_as_params(self) -> list[str]:
          """Get targets as parameters.

          Returns:
              list[str]: List of targets as parameters. Only the mask is needed to build the dropout region.

          """
          return ["mask"]

      def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
          """Get parameters dependent on the data.

          Labels the objects in `data["mask"]` with `fdropout.label`, samples how many of them to drop from the
          `max_objects` range (capped at the number of labelled objects) and builds a boolean mask of the
          pixels that belong to the selected objects.

          Args:
              params (dict[str, Any]): Dictionary containing parameters.
              data (dict[str, Any]): Dictionary containing data. Must contain the key "mask".

          Returns:
              dict[str, Any]: Dictionary with the key "dropout_mask". The value is a boolean array marking the
                  pixels to drop, or None when no objects were found in the mask.

          """
          mask = data["mask"]

          label_image, num_labels = fdropout.label(mask, return_num=True)

          if num_labels == 0:
              dropout_mask = None
          else:
              objects_to_drop = self.py_random.randint(*self.max_objects)
              objects_to_drop = min(num_labels, objects_to_drop)

              if objects_to_drop == num_labels:
                  # Every object is dropped, so no label sampling is needed.
                  dropout_mask = mask > 0
              else:
                  # Labels are sampled from 1..num_labels inclusive.
                  labels_index = self.py_random.sample(range(1, num_labels + 1), objects_to_drop)
                  dropout_mask = np.zeros(mask.shape[:2], dtype=bool)
                  for label_index in labels_index:
                      dropout_mask |= label_image == label_index

          return {"dropout_mask": dropout_mask}

      def apply(self, img: np.ndarray, dropout_mask: np.ndarray | None, **params: Any) -> np.ndarray:
          """Apply dropout to the image.

          Args:
              img (np.ndarray): The image to apply the transform to.
              dropout_mask (np.ndarray | None): The dropout mask for the image. None means nothing is dropped.
              **params (Any): Additional parameters for the transform.

          Returns:
              np.ndarray: The transformed image. The input is returned as-is when `dropout_mask` is None;
                  otherwise a new array is returned (inpainted, or a copy filled with `self.fill`).

          """
          if dropout_mask is None:
              return img

          if self.fill in {"inpaint_telea", "inpaint_ns"}:
              inpaint_mask = dropout_mask.astype(np.uint8)
              _, _, width, height = cv2.boundingRect(inpaint_mask)
              radius = min(3, max(width, height) // 2)
              # cv2.inpaint expects an integer algorithm flag, not the user-facing string name.
              inpaint_flag = cv2.INPAINT_TELEA if self.fill == "inpaint_telea" else cv2.INPAINT_NS
              return cv2.inpaint(img, inpaint_mask, radius, inpaint_flag)

          img = img.copy()
          img[dropout_mask] = self.fill

          return img

      def apply_to_mask(self, mask: np.ndarray, dropout_mask: np.ndarray | None, **params: Any) -> np.ndarray:
          """Apply dropout to the mask.

          Args:
              mask (np.ndarray): The mask to apply the transform to.
              dropout_mask (np.ndarray | None): The dropout mask for the mask.
              **params (Any): Additional parameters for the transform.

          Returns:
              np.ndarray: The transformed mask. The input is returned as-is when `dropout_mask` or
                  `self.fill_mask` is None; otherwise a copy with the dropped pixels set to `self.fill_mask`.

          """
          if dropout_mask is None or self.fill_mask is None:
              return mask

          mask = mask.copy()
          mask[dropout_mask] = self.fill_mask
          return mask

      def apply_to_bboxes(self, bboxes: np.ndarray, dropout_mask: np.ndarray | None, **params: Any) -> np.ndarray:
          """Apply dropout to bounding boxes.

          Args:
              bboxes (np.ndarray): The bounding boxes to apply the transform to.
              dropout_mask (np.ndarray | None): The dropout mask for the bounding boxes.
              **params (Any): Additional parameters for the transform. Must contain "shape".

          Returns:
              np.ndarray: The transformed bounding boxes. The input is returned unchanged when there is no
                  dropout mask or no bbox processor is configured.

          """
          if dropout_mask is None:
              return bboxes

          processor = cast("BboxProcessor", self.get_processor("bboxes"))
          if processor is None:
              return bboxes

          image_shape = params["shape"][:2]

          # The dropout helper is given denormalized boxes; the result is normalized again before returning.
          denormalized_bboxes = denormalize_bboxes(bboxes, image_shape)

          result = fdropout.mask_dropout_bboxes(
              denormalized_bboxes,
              dropout_mask,
              image_shape,
              processor.params.min_area,
              processor.params.min_visibility,
          )

          return normalize_bboxes(result, image_shape)

      def apply_to_keypoints(self, keypoints: np.ndarray, dropout_mask: np.ndarray | None, **params: Any) -> np.ndarray:
          """Apply dropout to keypoints.

          Args:
              keypoints (np.ndarray): The keypoints to apply the transform to.
              dropout_mask (np.ndarray | None): The dropout mask for the keypoints.
              **params (Any): Additional parameters for the transform.

          Returns:
              np.ndarray: The transformed keypoints. The input is returned unchanged when there is no dropout
                  mask, no keypoint processor is configured, or the processor has `remove_invisible` disabled.

          """
          if dropout_mask is None:
              return keypoints

          processor = cast("KeypointsProcessor", self.get_processor("keypoints"))

          if processor is None or not processor.params.remove_invisible:
              return keypoints

          return fdropout.mask_dropout_keypoints(keypoints, dropout_mask)