transforms.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. """Transform classes for dropout-based augmentations.
  2. This module contains transform classes for various dropout techniques used in image
  3. augmentation. It provides the base dropout class and specialized implementations like
  4. PixelDropout. These transforms randomly remove or modify pixels, channels, or regions
  5. in images, which can help models become more robust to occlusions and missing information.
  6. """
  7. from __future__ import annotations
  8. from typing import Any, Literal, cast
  9. import numpy as np
  10. from albucore import get_num_channels
  11. from pydantic import Field
  12. from albumentations.augmentations.dropout import functional as fdropout
  13. from albumentations.augmentations.dropout.functional import (
  14. cutout,
  15. cutout_on_volume,
  16. cutout_on_volumes,
  17. filter_bboxes_by_holes,
  18. filter_keypoints_in_holes,
  19. )
  20. from albumentations.augmentations.pixel import functional as fpixel
  21. from albumentations.core.bbox_utils import BboxProcessor, denormalize_bboxes, normalize_bboxes
  22. from albumentations.core.keypoints_utils import KeypointsProcessor
  23. from albumentations.core.transforms_interface import BaseTransformInitSchema, DualTransform
  24. from albumentations.core.type_definitions import ALL_TARGETS, Targets
  # BaseDropout is exported too: it is a public base class intended to be subclassed.
  __all__ = ["BaseDropout", "PixelDropout"]
  class BaseDropout(DualTransform):
      """Base class for dropout-style transformations.

      This class provides common functionality for various dropout techniques:
      cutting holes out of images, image batches, volumes, masks and 3D masks,
      and removing bounding boxes and keypoints covered by those holes.

      Subclasses must implement ``get_params_dependent_on_data`` and return a
      dictionary containing:

      - ``"holes"``: an integer array with one ``[x1, y1, x2, y2]`` row per hole
        (as in the example below). An empty array means nothing is dropped and
        every target is returned unchanged.
      - ``"seed"``: an integer used to seed the ``np.random.default_rng``
        generator handed to the cutout functions.

      Args:
          fill (tuple[float, ...] | float | Literal["random", "random_uniform", "inpaint_telea", "inpaint_ns"]):
              Value to fill dropped regions. The inpainting modes
              (``"inpaint_telea"``, ``"inpaint_ns"``) only accept inputs with
              1 or 3 channels.
          fill_mask (tuple[float, ...] | float | None): Value to fill
              dropped regions in the mask. If None, the mask is not modified.
          p (float): Probability of applying the transform.

      Raises:
          ValueError: If an inpainting ``fill`` is applied to an input that does
              not have 1 or 3 channels.

      Targets:
          image, mask, bboxes, keypoints, volume, mask3d

      Image types:
          uint8, float32

      Examples:
          >>> import numpy as np
          >>> import albumentations as A
          >>>
          >>> # Example of a custom dropout transform inheriting from BaseDropout
          >>> class CustomDropout(A.BaseDropout):
          ...     def __init__(self, num_holes_range=(4, 8), hole_size_range=(10, 20), **kwargs):
          ...         super().__init__(**kwargs)
          ...         self.num_holes_range = num_holes_range
          ...         self.hole_size_range = hole_size_range
          ...
          ...     def get_params_dependent_on_data(self, params, data):
          ...         height, width = data["image"].shape[:2]
          ...
          ...         # Generate random square holes
          ...         num_holes = self.py_random.randint(*self.num_holes_range)
          ...         holes = []
          ...         for _ in range(num_holes):
          ...             size = self.py_random.randint(*self.hole_size_range)
          ...             x1 = self.py_random.randint(0, max(0, width - size))
          ...             y1 = self.py_random.randint(0, max(0, height - size))
          ...             holes.append([x1, y1, min(width, x1 + size), min(height, y1 + size)])
          ...
          ...         # Return holes and a seed for the fill generator
          ...         return {
          ...             "holes": np.array(holes, dtype=np.int32) if holes else np.empty((0, 4), dtype=np.int32),
          ...             "seed": int(self.random_generator.integers(0, 100000)),
          ...         }
          >>>
          >>> # Prepare sample data
          >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
          >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
          >>> # Normalized [x_min, y_min, x_max, y_max] boxes ('albumentations' format)
          >>> bboxes = np.array([[0.1, 0.1, 0.4, 0.4], [0.6, 0.6, 0.9, 0.9]])
          >>>
          >>> # Create a transform with custom dropout
          >>> transform = A.Compose([
          ...     CustomDropout(
          ...         num_holes_range=(3, 6),  # Generate 3-6 random holes
          ...         hole_size_range=(5, 15),  # Holes of size 5-15 pixels
          ...         fill=0,  # Fill holes with black
          ...         fill_mask=1,  # Fill mask holes with 1
          ...         p=1.0,  # Always apply for this example
          ...     )
          ... ], bbox_params=A.BboxParams(format='albumentations', min_visibility=0.3))
          >>>
          >>> # Apply the transform
          >>> transformed = transform(image=image, mask=mask, bboxes=bboxes)
          >>>
          >>> # Get the transformed data
          >>> dropout_image = transformed["image"]  # Image with random holes filled with 0
          >>> dropout_mask = transformed["mask"]  # Mask with same holes filled with 1
          >>> dropout_bboxes = transformed["bboxes"]  # Bboxes filtered by visibility threshold
      """

      _targets: tuple[Targets, ...] | Targets = ALL_TARGETS

      class InitSchema(BaseTransformInitSchema):
          fill: tuple[float, ...] | float | Literal["random", "random_uniform", "inpaint_telea", "inpaint_ns"]
          fill_mask: tuple[float, ...] | float | None

      def __init__(
          self,
          fill: tuple[float, ...] | float | Literal["random", "random_uniform", "inpaint_telea", "inpaint_ns"],
          fill_mask: tuple[float, ...] | float | None,
          p: float,
      ):
          super().__init__(p=p)
          self.fill = fill  # type: ignore[assignment]
          self.fill_mask = fill_mask

      def _uses_inpainting(self) -> bool:
          """Return True when ``fill`` selects one of the inpainting modes."""
          # The isinstance guard keeps unhashable fills (e.g. a list) from raising
          # TypeError in the set-membership test.
          return isinstance(self.fill, str) and self.fill in {"inpaint_telea", "inpaint_ns"}

      @staticmethod
      def _validate_inpaint_channels(num_channels: int) -> None:
          """Raise ValueError unless inpainting can handle ``num_channels`` channels."""
          if num_channels not in {1, 3}:
              raise ValueError("Inpainting works only for 1 or 3 channel images")

      def apply(self, img: np.ndarray, holes: np.ndarray, seed: int, **params: Any) -> np.ndarray:
          """Cut ``holes`` out of a single image using ``fill``."""
          if holes.size == 0:
              return img
          if self._uses_inpainting():
              self._validate_inpaint_channels(get_num_channels(img))
          return cutout(img, holes, self.fill, np.random.default_rng(seed))

      def apply_to_images(self, images: np.ndarray, holes: np.ndarray, seed: int, **params: Any) -> np.ndarray:
          """Cut the same ``holes`` out of every image in a stacked batch."""
          if holes.size == 0:
              return images
          if self._uses_inpainting():
              # A 3D batch has no channel axis and is treated as single-channel.
              self._validate_inpaint_channels(images.shape[3] if images.ndim == 4 else 1)
          # Images (N, H, W, C) have the same structure as volumes (D, H, W, C)
          return cutout_on_volume(images, holes, self.fill, np.random.default_rng(seed))

      def apply_to_volume(self, volume: np.ndarray, holes: np.ndarray, seed: int, **params: Any) -> np.ndarray:
          """Cut ``holes`` out of every slice of a volume."""
          # Volume (D, H, W, C) has the same structure as images (N, H, W, C)
          # We can reuse the same logic
          return self.apply_to_images(volume, holes, seed, **params)

      def apply_to_volumes(self, volumes: np.ndarray, holes: np.ndarray, seed: int, **params: Any) -> np.ndarray:
          """Cut the same ``holes`` out of every slice of every volume in a batch."""
          if holes.size == 0:
              return volumes
          if self._uses_inpainting():
              # A 4D batch of volumes has no channel axis and is treated as single-channel.
              self._validate_inpaint_channels(volumes.shape[4] if volumes.ndim == 5 else 1)
          return cutout_on_volumes(volumes, holes, self.fill, np.random.default_rng(seed))

      def apply_to_mask3d(self, mask: np.ndarray, holes: np.ndarray, seed: int, **params: Any) -> np.ndarray:
          """Fill ``holes`` in a 3D mask with ``fill_mask`` (unchanged if ``fill_mask`` is None)."""
          if self.fill_mask is None or holes.size == 0:
              return mask
          return cutout_on_volume(mask, holes, self.fill_mask, np.random.default_rng(seed))

      def apply_to_masks3d(self, mask: np.ndarray, holes: np.ndarray, seed: int, **params: Any) -> np.ndarray:
          """Fill ``holes`` in a batch of 3D masks with ``fill_mask`` (unchanged if None)."""
          if self.fill_mask is None or holes.size == 0:
              return mask
          return cutout_on_volumes(mask, holes, self.fill_mask, np.random.default_rng(seed))

      def apply_to_mask(self, mask: np.ndarray, holes: np.ndarray, seed: int, **params: Any) -> np.ndarray:
          """Fill ``holes`` in a 2D mask with ``fill_mask`` (unchanged if ``fill_mask`` is None)."""
          if self.fill_mask is None or holes.size == 0:
              return mask
          return cutout(mask, holes, self.fill_mask, np.random.default_rng(seed))

      def apply_to_bboxes(
          self,
          bboxes: np.ndarray,
          holes: np.ndarray,
          **params: Any,
      ) -> np.ndarray:
          """Drop bounding boxes covered by ``holes``.

          Uses the bbox processor's ``min_area`` and ``min_visibility`` settings.
          The boxes are returned unchanged when there are no holes or no bbox
          processor is configured.
          """
          if holes.size == 0:
              return bboxes
          processor = cast("BboxProcessor", self.get_processor("bboxes"))
          if processor is None:
              return bboxes
          image_shape = params["shape"][:2]
          # Hole coordinates are in pixels, so compare against pixel-space boxes.
          denormalized_bboxes = denormalize_bboxes(bboxes, image_shape)
          return normalize_bboxes(
              filter_bboxes_by_holes(
                  denormalized_bboxes,
                  holes,
                  image_shape,
                  min_area=processor.params.min_area,
                  min_visibility=processor.params.min_visibility,
              ),
              image_shape,
          )

      def apply_to_keypoints(
          self,
          keypoints: np.ndarray,
          holes: np.ndarray,
          **params: Any,
      ) -> np.ndarray:
          """Remove keypoints inside ``holes``, but only if the keypoint processor removes invisible points."""
          if holes.size == 0:
              return keypoints
          processor = cast("KeypointsProcessor", self.get_processor("keypoints"))
          if processor is None or not processor.params.remove_invisible:
              return keypoints
          return filter_keypoints_in_holes(keypoints, holes)

      def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
          """Return ``{"holes": ..., "seed": ...}``; must be implemented by subclasses.

          Raises:
              NotImplementedError: Always, on the base class.
          """
          raise NotImplementedError("Subclasses must implement this method.")
  class PixelDropout(DualTransform):
      """Drops random pixels from the image.

      This transform randomly sets pixels in the image to a specified value, effectively "dropping out" those pixels.
      It can be applied to both the image and its corresponding mask.

      Args:
          dropout_prob (float): Probability of dropping out each pixel. Should be in the range [0, 1].
              Default: 0.01
          per_channel (bool): If True, the dropout mask will be generated independently for each channel.
              If False, the same dropout mask will be applied to all channels.
              Default: False
          drop_value (float | tuple[float, ...] | None): Value to assign to the dropped pixels.
              If None, the value will be randomly sampled for each application:
                  - For uint8 images: Random integer in [0, 255]
                  - For float32 images: Random float in [0, 1]
              If a single number, that value will be used for all dropped pixels.
              If a sequence, it should contain one value per channel.
              Default: 0
          mask_drop_value (float | tuple[float, ...] | None): Value to assign to dropped pixels in the mask.
              If None, the mask will remain unchanged.
              If a single number, that value will be used for all dropped pixels in the mask.
              If a sequence, it should contain one value per channel.
              Default: None
          p (float): Probability of applying the transform. Should be in the range [0, 1].
              Default: 0.5

      Targets:
          image, mask, bboxes, keypoints, volume, mask3d

      Image types:
          uint8, float32

      Note:
          - Mask pixels are dropped at the same locations as image pixels. If the image drop
            mask has a channel axis that the mask lacks (e.g. ``per_channel=True`` on an RGB
            image with a 2D mask), a mask pixel is dropped when it is dropped in any image
            channel. If the spatial sizes of the image and the mask differ, an independent
            drop mask is sampled for the mask instead.
          - When ``per_channel=False``, bounding boxes are filtered against the drop mask
            using the bbox processor's ``min_area`` and ``min_visibility`` settings. When
            ``per_channel=True``, bounding boxes are returned unchanged.
          - Keypoints are always returned unchanged.

      Examples:
          >>> import numpy as np
          >>> import albumentations as A
          >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
          >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
          >>> transform = A.PixelDropout(dropout_prob=0.1, per_channel=True, p=1.0)
          >>> result = transform(image=image, mask=mask)
          >>> dropped_image, dropped_mask = result['image'], result['mask']
      """

      class InitSchema(BaseTransformInitSchema):
          dropout_prob: float = Field(ge=0, le=1)
          per_channel: bool
          drop_value: tuple[float, ...] | float | None
          mask_drop_value: tuple[float, ...] | float | None

      _targets = ALL_TARGETS

      def __init__(
          self,
          dropout_prob: float = 0.01,
          per_channel: bool = False,
          drop_value: tuple[float, ...] | float | None = 0,
          mask_drop_value: tuple[float, ...] | float | None = None,
          p: float = 0.5,
      ):
          super().__init__(p=p)
          self.dropout_prob = dropout_prob
          self.per_channel = per_channel
          self.drop_value = drop_value
          self.mask_drop_value = mask_drop_value

      def apply(
          self,
          img: np.ndarray,
          drop_mask: np.ndarray,
          drop_values: np.ndarray,
          **params: Any,
      ) -> np.ndarray:
          """Apply pixel dropout to the image.

          Args:
              img (np.ndarray): The image to apply the transform to.
              drop_mask (np.ndarray): The dropout mask.
              drop_values (np.ndarray): The values to assign to the dropped pixels.
              **params (Any): Additional parameters for the transform.

          Returns:
              np.ndarray: The transformed image.
          """
          return fpixel.pixel_dropout(img, drop_mask, drop_values)

      def apply_to_mask(
          self,
          mask: np.ndarray,
          mask_drop_mask: np.ndarray,
          mask_drop_values: float | np.ndarray,
          **params: Any,
      ) -> np.ndarray:
          """Apply pixel dropout to the mask.

          Args:
              mask (np.ndarray): The mask to apply the transform to.
              mask_drop_mask (np.ndarray): The dropout mask for the mask.
              mask_drop_values (float | np.ndarray): The values to assign to the dropped pixels in the mask.
              **params (Any): Additional parameters for the transform.

          Returns:
              np.ndarray: The transformed mask, or the input mask unchanged when
              ``mask_drop_value`` is None or no mask drop pattern was prepared.
          """
          # mask_drop_mask is None when no mask was found while sampling parameters;
          # there is nothing to apply in that case.
          if self.mask_drop_value is None or mask_drop_mask is None:
              return mask
          return fpixel.pixel_dropout(mask, mask_drop_mask, mask_drop_values)

      def apply_to_bboxes(
          self,
          bboxes: np.ndarray,
          drop_mask: np.ndarray | None,
          **params: Any,
      ) -> np.ndarray:
          """Apply pixel dropout to the bounding boxes.

          Boxes are only filtered for spatial (``per_channel=False``) dropout; with
          per-channel dropout, or without a bbox processor, they are returned unchanged.

          Args:
              bboxes (np.ndarray): The bounding boxes to apply the transform to.
              drop_mask (np.ndarray | None): The dropout mask for the bounding boxes.
              **params (Any): Additional parameters for the transform.

          Returns:
              np.ndarray: The transformed bounding boxes.
          """
          if drop_mask is None or self.per_channel:
              return bboxes

          processor = cast("BboxProcessor", self.get_processor("bboxes"))
          if processor is None:
              return bboxes

          image_shape = params["shape"][:2]
          denormalized_bboxes = denormalize_bboxes(bboxes, image_shape)
          result = fdropout.mask_dropout_bboxes(
              denormalized_bboxes,
              drop_mask,
              image_shape,
              processor.params.min_area,
              processor.params.min_visibility,
          )
          return normalize_bboxes(result, image_shape)

      def apply_to_keypoints(
          self,
          keypoints: np.ndarray,
          **params: Any,
      ) -> np.ndarray:
          """Return the keypoints unchanged; pixel dropout does not move or remove keypoints.

          Args:
              keypoints (np.ndarray): The keypoints to apply the transform to.
              **params (Any): Additional parameters for the transform.

          Returns:
              np.ndarray: The input keypoints.
          """
          return keypoints

      @staticmethod
      def _mask_drop_mask_from_image(
          drop_mask: np.ndarray,
          mask_shape: tuple[int, ...],
      ) -> np.ndarray | None:
          """Derive a mask drop pattern whose dropped pixels coincide with the image's.

          Args:
              drop_mask (np.ndarray): Drop mask generated for the reference image.
              mask_shape (tuple[int, ...]): Shape of the mask that will be modified.

          Returns:
              np.ndarray | None: A drop mask of shape ``mask_shape``, or None when the
              spatial sizes of the image drop mask and the mask do not match.
          """
          mask_shape = tuple(mask_shape)
          if drop_mask.shape == mask_shape:
              return drop_mask
          # Collapse a channel axis: a pixel counts as dropped if it is dropped in any channel.
          spatial = drop_mask if drop_mask.ndim == 2 else np.any(drop_mask, axis=-1)
          if spatial.shape != mask_shape[:2]:
              return None
          if len(mask_shape) == 2:
              return spatial
          # Repeat the 2D pattern over the mask's trailing axes (e.g. mask channels).
          expanded = spatial.reshape(spatial.shape + (1,) * (len(mask_shape) - 2))
          return np.broadcast_to(expanded, mask_shape).copy()

      def get_params_dependent_on_data(
          self,
          params: dict[str, Any],
          data: dict[str, Any],
      ) -> dict[str, Any]:
          """Generate parameters for pixel dropout based on input data.

          Args:
              params (dict[str, Any]): Transform parameters
              data (dict[str, Any]): Input data dictionary

          Returns:
              dict[str, Any]: ``drop_mask`` and ``drop_values`` for the image, plus
              ``mask_drop_mask`` and ``mask_drop_values`` for the mask (both None
              when ``mask_drop_value`` is None or no mask is present).
          """
          reference_array = data["image"] if "image" in data else data["images"][0]

          # Generate drop mask and values for all targets
          drop_mask = fpixel.get_drop_mask(
              reference_array.shape,
              self.per_channel,
              self.dropout_prob,
              self.random_generator,
          )
          drop_values = fpixel.prepare_drop_values(
              reference_array,
              self.drop_value,
              self.random_generator,
          )

          # Handle mask drop values if specified
          mask_drop_mask = None
          mask_drop_values = None
          mask = fpixel.get_mask_array(data)
          if self.mask_drop_value is not None and mask is not None:
              # Reuse the image's drop pattern so the mask stays aligned with the image.
              mask_drop_mask = self._mask_drop_mask_from_image(drop_mask, mask.shape)
              if mask_drop_mask is None:
                  # Spatial sizes differ, so alignment is impossible: sample independently.
                  mask_drop_mask = fpixel.get_drop_mask(
                      mask.shape,
                      self.per_channel,
                      self.dropout_prob,
                      self.random_generator,
                  )
              mask_drop_values = fpixel.prepare_drop_values(
                  mask,
                  self.mask_drop_value,
                  self.random_generator,
              )

          return {
              "drop_mask": drop_mask,
              "drop_values": drop_values,
              "mask_drop_mask": mask_drop_mask,
              "mask_drop_values": mask_drop_values,
          }