transforms.py 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422
  1. """Module containing 3D transformation classes for volumetric data augmentation.
  2. This module provides a collection of transformation classes designed specifically for
  3. 3D volumetric data (such as medical CT/MRI scans). These transforms can manipulate properties
  4. such as spatial dimensions, apply dropout effects, and perform symmetry operations on
  5. 3D volumes, masks, and keypoints. Each transformation inherits from a base transform
  6. interface and implements specific 3D augmentation logic.
  7. """
  8. from __future__ import annotations
  9. from typing import Annotated, Any, Literal, Union, cast
  10. import numpy as np
  11. from pydantic import AfterValidator, field_validator, model_validator
  12. from typing_extensions import Self
  13. from albumentations.augmentations.geometric import functional as fgeometric
  14. from albumentations.augmentations.transforms3d import functional as f3d
  15. from albumentations.core.keypoints_utils import KeypointsProcessor
  16. from albumentations.core.pydantic import check_range_bounds, nondecreasing
  17. from albumentations.core.transforms_interface import BaseTransformInitSchema, Transform3D
  18. from albumentations.core.type_definitions import Targets
# Names exported by a star-import of this module. The helper base classes defined
# in this file (BasePad3D, BaseCropAndPad3D) are not part of this list.
__all__ = ["CenterCrop3D", "CoarseDropout3D", "CubicSymmetry", "Pad3D", "PadIfNeeded3D", "RandomCrop3D"]
# Length of a per-axis (depth, height, width) tuple; Pad3D compares the length of
# its ``padding`` argument against this to recognise the symmetric 3-tuple form.
NUM_DIMENSIONS = 3
  21. class BasePad3D(Transform3D):
  22. """Base class for 3D padding transforms.
  23. This class serves as a foundation for all 3D transforms that perform padding operations
  24. on volumetric data. It provides common functionality for padding 3D volumes, masks,
  25. and processing 3D keypoints during padding operations.
  26. The class handles different types of padding values (scalar or per-channel) and
  27. provides separate fill values for volumes and masks.
  28. Args:
  29. fill (tuple[float, ...] | float): Value to fill the padded voxels for volumes.
  30. Can be a single value for all channels or a tuple of values per channel.
  31. fill_mask (tuple[float, ...] | float): Value to fill the padded voxels for 3D masks.
  32. Can be a single value for all channels or a tuple of values per channel.
  33. p (float): Probability of applying the transform. Default: 1.0.
  34. Targets:
  35. volume, mask3d, keypoints
  36. Note:
  37. This is a base class and not intended to be used directly. Use its derivatives
  38. like Pad3D or PadIfNeeded3D instead, or create a custom padding transform
  39. by inheriting from this class.
  40. Examples:
  41. >>> import numpy as np
  42. >>> import albumentations as A
  43. >>>
  44. >>> # Example of a custom padding transform inheriting from BasePad3D
  45. >>> class CustomPad3D(A.BasePad3D):
  46. ... def __init__(self, padding_size: tuple[int, int, int] = (5, 5, 5), *args, **kwargs):
  47. ... super().__init__(*args, **kwargs)
  48. ... self.padding_size = padding_size
  49. ...
  50. ... def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
  51. ... # Create symmetric padding: same amount on all sides of each dimension
  52. ... pad_d, pad_h, pad_w = self.padding_size
  53. ... padding = (pad_d, pad_d, pad_h, pad_h, pad_w, pad_w)
  54. ... return {"padding": padding}
  55. >>>
  56. >>> # Prepare sample data
  57. >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8) # (D, H, W)
  58. >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8) # (D, H, W)
  59. >>> keypoints = np.array([[20, 30, 5], [60, 70, 8]], dtype=np.float32) # (x, y, z)
  60. >>> keypoint_labels = [1, 2] # Labels for each keypoint
  61. >>>
  62. >>> # Use the custom transform in a pipeline
  63. >>> transform = A.Compose([
  64. ... CustomPad3D(
  65. ... padding_size=(2, 10, 10),
  66. ... fill=0,
  67. ... fill_mask=1,
  68. ... p=1.0
  69. ... )
  70. ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
  71. >>>
  72. >>> # Apply the transform
  73. >>> transformed = transform(
  74. ... volume=volume,
  75. ... mask3d=mask3d,
  76. ... keypoints=keypoints,
  77. ... keypoint_labels=keypoint_labels
  78. ... )
  79. >>>
  80. >>> # Get the transformed data
  81. >>> transformed_volume = transformed["volume"] # Shape: (14, 120, 120)
  82. >>> transformed_mask3d = transformed["mask3d"] # Shape: (14, 120, 120)
  83. >>> transformed_keypoints = transformed["keypoints"] # Keypoints shifted by padding offsets
  84. >>> transformed_keypoint_labels = transformed["keypoint_labels"] # Labels remain unchanged
  85. """
  86. _targets = (Targets.VOLUME, Targets.MASK3D, Targets.KEYPOINTS)
  87. class InitSchema(Transform3D.InitSchema):
  88. fill: tuple[float, ...] | float
  89. fill_mask: tuple[float, ...] | float
  90. def __init__(
  91. self,
  92. fill: tuple[float, ...] | float = 0,
  93. fill_mask: tuple[float, ...] | float = 0,
  94. p: float = 1.0,
  95. ):
  96. super().__init__(p=p)
  97. self.fill = fill
  98. self.fill_mask = fill_mask
  99. def apply_to_volume(
  100. self,
  101. volume: np.ndarray,
  102. padding: tuple[int, int, int, int, int, int],
  103. **params: Any,
  104. ) -> np.ndarray:
  105. """Apply padding to a 3D volume.
  106. Args:
  107. volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
  108. padding (tuple[int, int, int, int, int, int]): Padding values in format:
  109. (depth_front, depth_back, height_top, height_bottom, width_left, width_right)
  110. **params (Any): Additional parameters
  111. Returns:
  112. np.ndarray: Padded volume with same number of dimensions as input
  113. """
  114. if padding == (0, 0, 0, 0, 0, 0):
  115. return volume
  116. return f3d.pad_3d_with_params(
  117. volume=volume,
  118. padding=padding,
  119. value=self.fill,
  120. )
  121. def apply_to_mask3d(
  122. self,
  123. mask3d: np.ndarray,
  124. padding: tuple[int, int, int, int, int, int],
  125. **params: Any,
  126. ) -> np.ndarray:
  127. """Apply padding to a 3D mask.
  128. Args:
  129. mask3d (np.ndarray): Input mask with shape (depth, height, width) or (depth, height, width, channels)
  130. padding (tuple[int, int, int, int, int, int]): Padding values in format:
  131. (depth_front, depth_back, height_top, height_bottom, width_left, width_right)
  132. **params (Any): Additional parameters
  133. Returns:
  134. np.ndarray: Padded mask with same number of dimensions as input
  135. """
  136. if padding == (0, 0, 0, 0, 0, 0):
  137. return mask3d
  138. return f3d.pad_3d_with_params(
  139. volume=mask3d,
  140. padding=padding,
  141. value=cast("Union[tuple[float, ...], float]", self.fill_mask),
  142. )
  143. def apply_to_keypoints(self, keypoints: np.ndarray, **params: Any) -> np.ndarray:
  144. """Apply padding to keypoints.
  145. Args:
  146. keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
  147. The first three columns are x, y, z coordinates.
  148. **params (Any): Additional parameters containing padding values
  149. Returns:
  150. np.ndarray: Shifted keypoints with same shape as input
  151. """
  152. padding = params["padding"]
  153. shift_vector = np.array([padding[4], padding[2], padding[0]])
  154. return fgeometric.shift_keypoints(keypoints, shift_vector)
  155. class Pad3D(BasePad3D):
  156. """Pad the sides of a 3D volume by specified number of voxels.
  157. Args:
  158. padding (int, tuple[int, int, int] or tuple[int, int, int, int, int, int]): Padding values. Can be:
  159. * int - pad all sides by this value
  160. * tuple[int, int, int] - symmetric padding (depth, height, width) where each value
  161. is applied to both sides of the corresponding dimension
  162. * tuple[int, int, int, int, int, int] - explicit padding per side in order:
  163. (depth_front, depth_back, height_top, height_bottom, width_left, width_right)
  164. fill (tuple[float, ...] | float): Padding value for image
  165. fill_mask (tuple[float, ...] | float): Padding value for mask
  166. p (float): probability of applying the transform. Default: 1.0.
  167. Targets:
  168. volume, mask3d, keypoints
  169. Image types:
  170. uint8, float32
  171. Note:
  172. Input volume should be a numpy array with dimensions ordered as (z, y, x) or (depth, height, width),
  173. with optional channel dimension as the last axis.
  174. Examples:
  175. >>> import numpy as np
  176. >>> import albumentations as A
  177. >>>
  178. >>> # Prepare sample data
  179. >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8) # (D, H, W)
  180. >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8) # (D, H, W)
  181. >>> keypoints = np.array([[20, 30, 5], [60, 70, 8]], dtype=np.float32) # (x, y, z)
  182. >>> keypoint_labels = [1, 2] # Labels for each keypoint
  183. >>>
  184. >>> # Create the transform with symmetric padding
  185. >>> transform = A.Compose([
  186. ... A.Pad3D(
  187. ... padding=(2, 5, 10), # (depth, height, width) applied symmetrically
  188. ... fill=0,
  189. ... fill_mask=1,
  190. ... p=1.0
  191. ... )
  192. ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
  193. >>>
  194. >>> # Apply the transform
  195. >>> transformed = transform(
  196. ... volume=volume,
  197. ... mask3d=mask3d,
  198. ... keypoints=keypoints,
  199. ... keypoint_labels=keypoint_labels
  200. ... )
  201. >>>
  202. >>> # Get the transformed data
  203. >>> padded_volume = transformed["volume"] # Shape: (14, 110, 120)
  204. >>> padded_mask3d = transformed["mask3d"] # Shape: (14, 110, 120)
  205. >>> padded_keypoints = transformed["keypoints"] # Keypoints shifted by padding
  206. >>> padded_keypoint_labels = transformed["keypoint_labels"] # Labels remain unchanged
  207. """
  208. class InitSchema(BasePad3D.InitSchema):
  209. padding: int | tuple[int, int, int] | tuple[int, int, int, int, int, int]
  210. @field_validator("padding")
  211. @classmethod
  212. def validate_padding(
  213. cls,
  214. v: int | tuple[int, int, int] | tuple[int, int, int, int, int, int],
  215. ) -> int | tuple[int, int, int] | tuple[int, int, int, int, int, int]:
  216. """Validate the padding parameter.
  217. Args:
  218. cls (type): The class object
  219. v (int | tuple[int, int, int] | tuple[int, int, int, int, int, int]): The padding value to validate,
  220. can be an integer or tuple of integers
  221. Returns:
  222. int | tuple[int, int, int] | tuple[int, int, int, int, int, int]: The validated padding value
  223. Raises:
  224. ValueError: If padding is negative or contains negative values
  225. """
  226. if isinstance(v, int) and v < 0:
  227. raise ValueError("Padding value must be non-negative")
  228. if isinstance(v, tuple) and not all(isinstance(i, int) and i >= 0 for i in v):
  229. raise ValueError("Padding tuple must contain non-negative integers")
  230. return v
  231. def __init__(
  232. self,
  233. padding: int | tuple[int, int, int] | tuple[int, int, int, int, int, int],
  234. fill: tuple[float, ...] | float = 0,
  235. fill_mask: tuple[float, ...] | float = 0,
  236. p: float = 1.0,
  237. ):
  238. super().__init__(fill=fill, fill_mask=fill_mask, p=p)
  239. self.padding = padding
  240. self.fill = fill
  241. self.fill_mask = fill_mask
  242. def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
  243. """Get parameters dependent on input data.
  244. Args:
  245. params (dict[str, Any]): Dictionary of existing parameters
  246. data (dict[str, Any]): Dictionary containing input data with volume, mask, etc.
  247. Returns:
  248. dict[str, Any]: Dictionary containing the padding parameter tuple in format:
  249. (depth_front, depth_back, height_top, height_bottom, width_left, width_right)
  250. """
  251. if isinstance(self.padding, int):
  252. pad_d = pad_h = pad_w = self.padding
  253. padding = (pad_d, pad_d, pad_h, pad_h, pad_w, pad_w)
  254. elif len(self.padding) == NUM_DIMENSIONS:
  255. pad_d, pad_h, pad_w = self.padding # type: ignore[misc]
  256. padding = (pad_d, pad_d, pad_h, pad_h, pad_w, pad_w)
  257. else:
  258. padding = self.padding # type: ignore[assignment]
  259. return {"padding": padding}
  260. class PadIfNeeded3D(BasePad3D):
  261. """Pads the sides of a 3D volume if its dimensions are less than specified minimum dimensions.
  262. If the pad_divisor_zyx is specified, the function additionally ensures that the volume
  263. dimensions are divisible by these values.
  264. Args:
  265. min_zyx (tuple[int, int, int] | None): Minimum desired size as (depth, height, width).
  266. Ensures volume dimensions are at least these values.
  267. If not specified, pad_divisor_zyx must be provided.
  268. pad_divisor_zyx (tuple[int, int, int] | None): If set, pads each dimension to make it
  269. divisible by corresponding value in format (depth_div, height_div, width_div).
  270. If not specified, min_zyx must be provided.
  271. position (Literal["center", "random"]): Position where the volume is to be placed after padding.
  272. Default is 'center'.
  273. fill (tuple[float, ...] | float): Value to fill the border voxels for volume. Default: 0
  274. fill_mask (tuple[float, ...] | float): Value to fill the border voxels for masks. Default: 0
  275. p (float): Probability of applying the transform. Default: 1.0
  276. Targets:
  277. volume, mask3d, keypoints
  278. Image types:
  279. uint8, float32
  280. Note:
  281. Input volume should be a numpy array with dimensions ordered as (z, y, x) or (depth, height, width),
  282. with optional channel dimension as the last axis.
  283. Examples:
  284. >>> import numpy as np
  285. >>> import albumentations as A
  286. >>>
  287. >>> # Prepare sample data
  288. >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8) # (D, H, W)
  289. >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8) # (D, H, W)
  290. >>> keypoints = np.array([[20, 30, 5], [60, 70, 8]], dtype=np.float32) # (x, y, z)
  291. >>> keypoint_labels = [1, 2] # Labels for each keypoint
  292. >>>
  293. >>> # Create a transform with both min_zyx and pad_divisor_zyx
  294. >>> transform = A.Compose([
  295. ... A.PadIfNeeded3D(
  296. ... min_zyx=(16, 128, 128), # Minimum size (depth, height, width)
  297. ... pad_divisor_zyx=(8, 16, 16), # Make dimensions divisible by these values
  298. ... position="center", # Center the volume in the padded space
  299. ... fill=0, # Fill value for volume
  300. ... fill_mask=1, # Fill value for mask
  301. ... p=1.0
  302. ... )
  303. ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
  304. >>>
  305. >>> # Apply the transform
  306. >>> transformed = transform(
  307. ... volume=volume,
  308. ... mask3d=mask3d,
  309. ... keypoints=keypoints,
  310. ... keypoint_labels=keypoint_labels
  311. ... )
  312. >>>
  313. >>> # Get the transformed data
  314. >>> padded_volume = transformed["volume"] # Shape: (16, 128, 128)
  315. >>> padded_mask3d = transformed["mask3d"] # Shape: (16, 128, 128)
  316. >>> padded_keypoints = transformed["keypoints"] # Keypoints shifted by padding
  317. >>> padded_keypoint_labels = transformed["keypoint_labels"] # Labels remain unchanged
  318. """
  319. class InitSchema(BasePad3D.InitSchema):
  320. min_zyx: Annotated[tuple[int, int, int] | None, AfterValidator(check_range_bounds(0, None))]
  321. pad_divisor_zyx: Annotated[tuple[int, int, int] | None, AfterValidator(check_range_bounds(1, None))]
  322. position: Literal["center", "random"]
  323. @model_validator(mode="after")
  324. def validate_params(self) -> Self:
  325. """Validate that either min_zyx or pad_divisor_zyx is provided.
  326. Returns:
  327. Self: Self reference for method chaining
  328. Raises:
  329. ValueError: If both min_zyx and pad_divisor_zyx are None
  330. """
  331. if self.min_zyx is None and self.pad_divisor_zyx is None:
  332. msg = "At least one of min_zyx or pad_divisor_zyx must be set"
  333. raise ValueError(msg)
  334. return self
  335. def __init__(
  336. self,
  337. min_zyx: tuple[int, int, int] | None = None,
  338. pad_divisor_zyx: tuple[int, int, int] | None = None,
  339. position: Literal["center", "random"] = "center",
  340. fill: tuple[float, ...] | float = 0,
  341. fill_mask: tuple[float, ...] | float = 0,
  342. p: float = 1.0,
  343. ):
  344. super().__init__(fill=fill, fill_mask=fill_mask, p=p)
  345. self.min_zyx = min_zyx
  346. self.pad_divisor_zyx = pad_divisor_zyx
  347. self.position = position
  348. def get_params_dependent_on_data(
  349. self,
  350. params: dict[str, Any],
  351. data: dict[str, Any],
  352. ) -> dict[str, Any]:
  353. """Calculate padding parameters based on input data dimensions.
  354. Args:
  355. params (dict[str, Any]): Dictionary of existing parameters
  356. data (dict[str, Any]): Dictionary containing input data with volume, mask, etc.
  357. Returns:
  358. dict[str, Any]: Dictionary containing calculated padding parameters
  359. """
  360. depth, height, width = data["volume"].shape[:3]
  361. sizes = (depth, height, width)
  362. paddings = [
  363. fgeometric.get_dimension_padding(
  364. current_size=size,
  365. min_size=self.min_zyx[i] if self.min_zyx else None,
  366. divisor=self.pad_divisor_zyx[i] if self.pad_divisor_zyx else None,
  367. )
  368. for i, size in enumerate(sizes)
  369. ]
  370. padding = f3d.adjust_padding_by_position3d(
  371. paddings=paddings,
  372. position=self.position,
  373. py_random=self.py_random,
  374. )
  375. return {"padding": padding}
  376. class BaseCropAndPad3D(Transform3D):
  377. """Base class for 3D transforms that need both cropping and padding.
  378. This class serves as a foundation for transforms that combine cropping and padding operations
  379. on 3D volumetric data. It provides functionality for calculating padding parameters,
  380. applying crop and pad operations to volumes, masks, and handling keypoint coordinate shifts.
  381. Args:
  382. pad_if_needed (bool): Whether to pad if the volume is smaller than target dimensions
  383. fill (tuple[float, ...] | float): Value to fill the padded voxels for volume
  384. fill_mask (tuple[float, ...] | float): Value to fill the padded voxels for mask
  385. pad_position (Literal["center", "random"]): How to distribute padding when needed
  386. "center" - equal amount on both sides, "random" - random distribution
  387. p (float): Probability of applying the transform. Default: 1.0
  388. Targets:
  389. volume, mask3d, keypoints
  390. Note:
  391. This is a base class and not intended to be used directly. Use its derivatives
  392. like CenterCrop3D or RandomCrop3D instead, or create a custom transform
  393. by inheriting from this class.
  394. Examples:
  395. >>> import numpy as np
  396. >>> import albumentations as A
  397. >>>
  398. >>> # Example of a custom crop transform inheriting from BaseCropAndPad3D
  399. >>> class CustomFixedCrop3D(A.BaseCropAndPad3D):
  400. ... def __init__(self, crop_size: tuple[int, int, int] = (8, 64, 64), *args, **kwargs):
  401. ... super().__init__(
  402. ... pad_if_needed=True,
  403. ... fill=0,
  404. ... fill_mask=0,
  405. ... pad_position="center",
  406. ... *args,
  407. ... **kwargs
  408. ... )
  409. ... self.crop_size = crop_size
  410. ...
  411. ... def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
  412. ... # Get the volume shape
  413. ... volume = data["volume"]
  414. ... z, h, w = volume.shape[:3]
  415. ... target_z, target_h, target_w = self.crop_size
  416. ...
  417. ... # Check if padding is needed and calculate parameters
  418. ... pad_params = self._get_pad_params(
  419. ... image_shape=(z, h, w),
  420. ... target_shape=self.crop_size,
  421. ... )
  422. ...
  423. ... # Update dimensions if padding is applied
  424. ... if pad_params is not None:
  425. ... z = z + pad_params["pad_front"] + pad_params["pad_back"]
  426. ... h = h + pad_params["pad_top"] + pad_params["pad_bottom"]
  427. ... w = w + pad_params["pad_left"] + pad_params["pad_right"]
  428. ...
  429. ... # Calculate fixed crop coordinates - always start at position (0,0,0)
  430. ... crop_coords = (0, target_z, 0, target_h, 0, target_w)
  431. ...
  432. ... return {
  433. ... "crop_coords": crop_coords,
  434. ... "pad_params": pad_params,
  435. ... }
  436. >>>
  437. >>> # Prepare sample data
  438. >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8) # (D, H, W)
  439. >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8) # (D, H, W)
  440. >>> keypoints = np.array([[20, 30, 5], [60, 70, 8]], dtype=np.float32) # (x, y, z)
  441. >>> keypoint_labels = [1, 2] # Labels for each keypoint
  442. >>>
  443. >>> # Use the custom transform in a pipeline
  444. >>> transform = A.Compose([
  445. ... CustomFixedCrop3D(
  446. ... crop_size=(8, 64, 64), # Crop first 8x64x64 voxels (with padding if needed)
  447. ... p=1.0
  448. ... )
  449. ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
  450. >>>
  451. >>> # Apply the transform
  452. >>> transformed = transform(
  453. ... volume=volume,
  454. ... mask3d=mask3d,
  455. ... keypoints=keypoints,
  456. ... keypoint_labels=keypoint_labels
  457. ... )
  458. >>>
  459. >>> # Get the transformed data
  460. >>> cropped_volume = transformed["volume"] # Shape: (8, 64, 64)
  461. >>> cropped_mask3d = transformed["mask3d"] # Shape: (8, 64, 64)
  462. >>> cropped_keypoints = transformed["keypoints"] # Keypoints shifted relative to crop
  463. >>> cropped_keypoint_labels = transformed["keypoint_labels"] # Labels remain unchanged
  464. """
  465. _targets = (Targets.VOLUME, Targets.MASK3D, Targets.KEYPOINTS)
  466. class InitSchema(Transform3D.InitSchema):
  467. pad_if_needed: bool
  468. fill: tuple[float, ...] | float
  469. fill_mask: tuple[float, ...] | float
  470. pad_position: Literal["center", "random"]
  471. def __init__(
  472. self,
  473. pad_if_needed: bool,
  474. fill: tuple[float, ...] | float,
  475. fill_mask: tuple[float, ...] | float,
  476. pad_position: Literal["center", "random"],
  477. p: float = 1.0,
  478. ):
  479. super().__init__(p=p)
  480. self.pad_if_needed = pad_if_needed
  481. self.fill = fill
  482. self.fill_mask = fill_mask
  483. self.pad_position = pad_position
  484. def _random_pad(self, pad: int) -> tuple[int, int]:
  485. """Generate random padding values.
  486. Args:
  487. pad (int): Total padding value to distribute
  488. Returns:
  489. tuple[int, int]: Random padding values (front, back)
  490. """
  491. if pad > 0:
  492. pad_start = self.py_random.randint(0, pad)
  493. pad_end = pad - pad_start
  494. else:
  495. pad_start = pad_end = 0
  496. return pad_start, pad_end
  497. def _center_pad(self, pad: int) -> tuple[int, int]:
  498. """Generate centered padding values.
  499. Args:
  500. pad (int): Total padding value to distribute
  501. Returns:
  502. tuple[int, int]: Centered padding values (front, back)
  503. """
  504. pad_start = pad // 2
  505. pad_end = pad - pad_start
  506. return pad_start, pad_end
  507. def _get_pad_params(
  508. self,
  509. image_shape: tuple[int, int, int],
  510. target_shape: tuple[int, int, int],
  511. ) -> dict[str, int] | None:
  512. """Calculate padding parameters to reach target shape.
  513. Args:
  514. image_shape (tuple[int, int, int]): Current shape (depth, height, width)
  515. target_shape (tuple[int, int, int]): Target shape (depth, height, width)
  516. Returns:
  517. dict[str, int] | None: Padding parameters or None if no padding needed
  518. """
  519. if not self.pad_if_needed:
  520. return None
  521. z, h, w = image_shape
  522. target_z, target_h, target_w = target_shape
  523. # Calculate total padding needed for each dimension
  524. z_pad = max(0, target_z - z)
  525. h_pad = max(0, target_h - h)
  526. w_pad = max(0, target_w - w)
  527. if z_pad == 0 and h_pad == 0 and w_pad == 0:
  528. return None
  529. # For center padding, split equally
  530. if self.pad_position == "center":
  531. z_front, z_back = self._center_pad(z_pad)
  532. h_top, h_bottom = self._center_pad(h_pad)
  533. w_left, w_right = self._center_pad(w_pad)
  534. # For random padding, randomly distribute the padding
  535. else: # random
  536. z_front, z_back = self._random_pad(z_pad)
  537. h_top, h_bottom = self._random_pad(h_pad)
  538. w_left, w_right = self._random_pad(w_pad)
  539. return {
  540. "pad_front": z_front,
  541. "pad_back": z_back,
  542. "pad_top": h_top,
  543. "pad_bottom": h_bottom,
  544. "pad_left": w_left,
  545. "pad_right": w_right,
  546. }
  547. def apply_to_volume(
  548. self,
  549. volume: np.ndarray,
  550. crop_coords: tuple[int, int, int, int, int, int],
  551. pad_params: dict[str, int] | None,
  552. **params: Any,
  553. ) -> np.ndarray:
  554. """Apply cropping and padding to a 3D volume.
  555. Args:
  556. volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
  557. crop_coords (tuple[int, int, int, int, int, int]): Crop coordinates (z1, z2, y1, y2, x1, x2)
  558. pad_params (dict[str, int] | None): Padding parameters or None if no padding needed
  559. **params (Any): Additional parameters
  560. Returns:
  561. np.ndarray: Cropped and padded volume with same number of dimensions as input
  562. """
  563. # First crop
  564. cropped = f3d.crop3d(volume, crop_coords)
  565. # Then pad if needed
  566. if pad_params is not None:
  567. padding = (
  568. pad_params["pad_front"],
  569. pad_params["pad_back"],
  570. pad_params["pad_top"],
  571. pad_params["pad_bottom"],
  572. pad_params["pad_left"],
  573. pad_params["pad_right"],
  574. )
  575. return f3d.pad_3d_with_params(
  576. cropped,
  577. padding=padding,
  578. value=self.fill,
  579. )
  580. return cropped
  581. def apply_to_mask3d(
  582. self,
  583. mask3d: np.ndarray,
  584. crop_coords: tuple[int, int, int, int, int, int],
  585. pad_params: dict[str, int] | None,
  586. **params: Any,
  587. ) -> np.ndarray:
  588. """Apply cropping and padding to a 3D mask.
  589. Args:
  590. mask3d (np.ndarray): Input mask with shape (depth, height, width) or (depth, height, width, channels)
  591. crop_coords (tuple[int, int, int, int, int, int]): Crop coordinates (z1, z2, y1, y2, x1, x2)
  592. pad_params (dict[str, int] | None): Padding parameters or None if no padding needed
  593. **params (Any): Additional parameters
  594. Returns:
  595. np.ndarray: Cropped and padded mask with same number of dimensions as input
  596. """
  597. # First crop
  598. cropped = f3d.crop3d(mask3d, crop_coords)
  599. # Then pad if needed
  600. if pad_params is not None:
  601. padding = (
  602. pad_params["pad_front"],
  603. pad_params["pad_back"],
  604. pad_params["pad_top"],
  605. pad_params["pad_bottom"],
  606. pad_params["pad_left"],
  607. pad_params["pad_right"],
  608. )
  609. return f3d.pad_3d_with_params(
  610. cropped,
  611. padding=padding,
  612. value=cast("Union[tuple[float, ...], float]", self.fill_mask),
  613. )
  614. return cropped
  def apply_to_keypoints(
      self,
      keypoints: np.ndarray,
      crop_coords: tuple[int, int, int, int, int, int],
      pad_params: dict[str, int] | None,
      **params: Any,
  ) -> np.ndarray:
      """Translate keypoints so they follow the crop (and optional padding) applied to the volume.

      Args:
          keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
              The first three columns are x, y, z coordinates.
          crop_coords (tuple[int, int, int, int, int, int]): Crop coordinates (z1, z2, y1, y2, x1, x2)
          pad_params (dict[str, int] | None): Padding parameters or None if no padding needed
          **params (Any): Additional parameters (not used by this method)

      Returns:
          np.ndarray: Shifted keypoints with same shape as input

      """
      # Only the start of each axis matters for the translation; the end coordinates are ignored.
      z_start, _, y_start, _, x_start, _ = crop_coords

      # Keypoints are stored as (x, y, z), so the offset is assembled in that order.
      offset = np.array([-x_start, -y_start, -z_start])

      if pad_params is not None:
          # Padding added before the data (left/top/front) moves keypoints forward along each axis.
          leading_sides = ("pad_left", "pad_top", "pad_front")
          offset += np.array([pad_params[side] for side in leading_sides])

      return fgeometric.shift_keypoints(keypoints, offset)
  class CenterCrop3D(BaseCropAndPad3D):
      """Crop the center of 3D volume.

      Args:
          size (tuple[int, int, int]): Desired output size of the crop in format (depth, height, width)
          pad_if_needed (bool): Whether to pad if the volume is smaller than desired crop size. Default: False
          fill (tuple[float, ...] | float): Padding value for image if pad_if_needed is True. Default: 0
          fill_mask (tuple[float, ...] | float): Padding value for mask if pad_if_needed is True. Default: 0
          p (float): probability of applying the transform. Default: 1.0

      Raises:
          ValueError: When the transform is applied to a volume that, after any padding, is still
              smaller than `size` along at least one axis.

      Targets:
          volume, mask3d, keypoints

      Image types:
          uint8, float32

      Note:
          If you want to perform cropping only in the XY plane while preserving all slices along
          the Z axis, consider using CenterCrop instead. CenterCrop will apply the same XY crop
          to each slice independently, maintaining the full depth of the volume.

      Examples:
          >>> import numpy as np
          >>> import albumentations as A
          >>>
          >>> # Prepare sample data
          >>> volume = np.random.randint(0, 256, (20, 200, 200), dtype=np.uint8)  # (D, H, W)
          >>> mask3d = np.random.randint(0, 2, (20, 200, 200), dtype=np.uint8)    # (D, H, W)
          >>> keypoints = np.array([[100, 100, 10], [150, 150, 15]], dtype=np.float32)  # (x, y, z)
          >>> keypoint_labels = [1, 2]  # Labels for each keypoint
          >>>
          >>> # Create the transform - crop to 16x128x128 from center
          >>> transform = A.Compose([
          ...     A.CenterCrop3D(
          ...         size=(16, 128, 128),        # Output size (depth, height, width)
          ...         pad_if_needed=True,         # Pad if input is smaller than crop size
          ...         fill=0,                     # Fill value for volume padding
          ...         fill_mask=1,                # Fill value for mask padding
          ...         p=1.0
          ...     )
          ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
          >>>
          >>> # Apply the transform
          >>> transformed = transform(
          ...     volume=volume,
          ...     mask3d=mask3d,
          ...     keypoints=keypoints,
          ...     keypoint_labels=keypoint_labels
          ... )
          >>>
          >>> # Get the transformed data
          >>> cropped_volume = transformed["volume"]           # Shape: (16, 128, 128)
          >>> cropped_mask3d = transformed["mask3d"]           # Shape: (16, 128, 128)
          >>> cropped_keypoints = transformed["keypoints"]     # Keypoints shifted relative to center crop
          >>> cropped_keypoint_labels = transformed["keypoint_labels"]  # Labels remain unchanged
          >>>
          >>> # Example with a small volume that requires padding
          >>> small_volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)
          >>> small_transform = A.Compose([
          ...     A.CenterCrop3D(
          ...         size=(16, 128, 128),
          ...         pad_if_needed=True,   # Will pad since the input is smaller
          ...         fill=0,
          ...         p=1.0
          ...     )
          ... ])
          >>> small_result = small_transform(volume=small_volume)
          >>> padded_and_cropped = small_result["volume"]  # Shape: (16, 128, 128), padded to size

      """

      class InitSchema(BaseTransformInitSchema):
          size: Annotated[tuple[int, int, int], AfterValidator(check_range_bounds(1, None))]
          pad_if_needed: bool
          fill: tuple[float, ...] | float
          fill_mask: tuple[float, ...] | float

      def __init__(
          self,
          size: tuple[int, int, int],
          pad_if_needed: bool = False,
          fill: tuple[float, ...] | float = 0,
          fill_mask: tuple[float, ...] | float = 0,
          p: float = 1.0,
      ):
          super().__init__(
              pad_if_needed=pad_if_needed,
              fill=fill,
              fill_mask=fill_mask,
              pad_position="center",  # Center crop always uses center padding
              p=p,
          )
          self.size = size

      def get_params_dependent_on_data(
          self,
          params: dict[str, Any],
          data: dict[str, Any],
      ) -> dict[str, Any]:
          """Calculate crop coordinates for center cropping.

          Args:
              params (dict[str, Any]): Dictionary of existing parameters
              data (dict[str, Any]): Dictionary containing input data with volume, mask, etc.

          Returns:
              dict[str, Any]: Dictionary containing crop coordinates and optional padding parameters

          Raises:
              ValueError: If the volume (after any padding) is smaller than the crop size along any axis.

          """
          volume = data["volume"]
          z, h, w = volume.shape[:3]
          target_z, target_h, target_w = self.size

          # Get padding params if needed
          pad_params = self._get_pad_params(
              image_shape=(z, h, w),
              target_shape=self.size,
          )

          # The crop is computed in the coordinate frame of the padded volume
          if pad_params is not None:
              z = z + pad_params["pad_front"] + pad_params["pad_back"]
              h = h + pad_params["pad_top"] + pad_params["pad_bottom"]
              w = w + pad_params["pad_left"] + pad_params["pad_right"]

          # Validate dimensions after padding
          if z < target_z or h < target_h or w < target_w:
              if pad_params is None:
                  # No padding was computed, so this is a usage error rather than an internal bug:
                  # the volume is simply too small for the requested crop.
                  msg = (
                      f"Crop size {self.size} is larger than volume size ({z}, {h}, {w}) and no padding "
                      f"was applied. Pass pad_if_needed=True to pad smaller volumes, or reduce the crop size."
                  )
              else:
                  msg = (
                      f"Crop size {self.size} is larger than padded image size ({z}, {h}, {w}). "
                      f"This should not happen - please report this as a bug."
                  )
              raise ValueError(msg)

          # For CenterCrop3D: leave an equal margin (floor-divided) before the crop on every axis
          z_start = (z - target_z) // 2
          h_start = (h - target_h) // 2
          w_start = (w - target_w) // 2

          crop_coords = (
              z_start,
              z_start + target_z,
              h_start,
              h_start + target_h,
              w_start,
              w_start + target_w,
          )

          return {
              "crop_coords": crop_coords,
              "pad_params": pad_params,
          }
  class RandomCrop3D(BaseCropAndPad3D):
      """Crop random part of 3D volume.

      Args:
          size (tuple[int, int, int]): Desired output size of the crop in format (depth, height, width)
          pad_if_needed (bool): Whether to pad if the volume is smaller than desired crop size. Default: False
          fill (tuple[float, ...] | float): Padding value for image if pad_if_needed is True. Default: 0
          fill_mask (tuple[float, ...] | float): Padding value for mask if pad_if_needed is True. Default: 0
          p (float): probability of applying the transform. Default: 1.0

      Targets:
          volume, mask3d, keypoints

      Image types:
          uint8, float32

      Note:
          If you want to perform random cropping only in the XY plane while preserving all slices along
          the Z axis, consider using RandomCrop instead. RandomCrop will apply the same XY crop
          to each slice independently, maintaining the full depth of the volume.

      Examples:
          >>> import numpy as np
          >>> import albumentations as A
          >>>
          >>> # Prepare sample data
          >>> volume = np.random.randint(0, 256, (20, 200, 200), dtype=np.uint8)  # (D, H, W)
          >>> mask3d = np.random.randint(0, 2, (20, 200, 200), dtype=np.uint8)    # (D, H, W)
          >>> keypoints = np.array([[100, 100, 10], [150, 150, 15]], dtype=np.float32)  # (x, y, z)
          >>> keypoint_labels = [1, 2]  # Labels for each keypoint
          >>>
          >>> # Create the transform with random crop and padding if needed
          >>> transform = A.Compose([
          ...     A.RandomCrop3D(
          ...         size=(16, 128, 128),        # Output size (depth, height, width)
          ...         pad_if_needed=True,         # Pad if input is smaller than crop size
          ...         fill=0,                     # Fill value for volume padding
          ...         fill_mask=1,                # Fill value for mask padding
          ...         p=1.0
          ...     )
          ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
          >>>
          >>> # Apply the transform
          >>> transformed = transform(
          ...     volume=volume,
          ...     mask3d=mask3d,
          ...     keypoints=keypoints,
          ...     keypoint_labels=keypoint_labels
          ... )
          >>>
          >>> # Get the transformed data
          >>> cropped_volume = transformed["volume"]           # Shape: (16, 128, 128)
          >>> cropped_mask3d = transformed["mask3d"]           # Shape: (16, 128, 128)
          >>> cropped_keypoints = transformed["keypoints"]     # Keypoints shifted relative to random crop
          >>> cropped_keypoint_labels = transformed["keypoint_labels"]  # Labels remain unchanged

      """

      class InitSchema(BaseTransformInitSchema):
          size: Annotated[tuple[int, int, int], AfterValidator(check_range_bounds(1, None))]
          pad_if_needed: bool
          fill: tuple[float, ...] | float
          fill_mask: tuple[float, ...] | float

      def __init__(
          self,
          size: tuple[int, int, int],
          pad_if_needed: bool = False,
          fill: tuple[float, ...] | float = 0,
          fill_mask: tuple[float, ...] | float = 0,
          p: float = 1.0,
      ):
          # Random crop pairs with a randomly positioned padding.
          super().__init__(
              pad_if_needed=pad_if_needed,
              fill=fill,
              fill_mask=fill_mask,
              pad_position="random",
              p=p,
          )
          self.size = size

      def get_params_dependent_on_data(
          self,
          params: dict[str, Any],
          data: dict[str, Any],
      ) -> dict[str, Any]:
          """Draw a random crop box (and padding, if any) for the given volume.

          Args:
              params (dict[str, Any]): Dictionary of existing parameters
              data (dict[str, Any]): Dictionary containing input data with volume, mask, etc.

          Returns:
              dict[str, Any]: Dictionary containing randomly generated crop coordinates and optional padding parameters

          """
          depth, height, width = data["volume"].shape[:3]
          target_z, target_h, target_w = self.size

          pad_params = self._get_pad_params(
              image_shape=(depth, height, width),
              target_shape=self.size,
          )

          # Crop coordinates live in the frame of the padded volume.
          if pad_params is not None:
              depth = depth + pad_params["pad_front"] + pad_params["pad_back"]
              height = height + pad_params["pad_top"] + pad_params["pad_bottom"]
              width = width + pad_params["pad_left"] + pad_params["pad_right"]

          # One draw per axis, in z, y, x order. The upper bound is clamped at 0, so an axis
          # that is shorter than the target always starts at 0.
          # NOTE(review): no size validation happens here; if the volume is still smaller than
          # `size`, the returned box extends past the volume - confirm how f3d.crop3d handles that.
          bounds: list[int] = []
          for available, wanted in (
              (depth, target_z),
              (height, target_h),
              (width, target_w),
          ):
              start = self.py_random.randint(0, max(0, available - wanted))
              bounds.extend((start, start + wanted))

          return {
              "crop_coords": tuple(bounds),
              "pad_params": pad_params,
          }
  class CoarseDropout3D(Transform3D):
      """CoarseDropout3D randomly drops out cuboid regions from a 3D volume and optionally,
      the corresponding regions in an associated 3D mask, to simulate occlusion and
      varied object sizes found in real-world volumetric data.

      Args:
          num_holes_range (tuple[int, int]): Range (min, max) for the number of cuboid
              regions to drop out. Default: (1, 1)
          hole_depth_range (tuple[float, float]): Range (min, max) for the depth
              of dropout regions as a fraction of the volume depth (between 0 and 1). Default: (0.1, 0.2)
          hole_height_range (tuple[float, float]): Range (min, max) for the height
              of dropout regions as a fraction of the volume height (between 0 and 1). Default: (0.1, 0.2)
          hole_width_range (tuple[float, float]): Range (min, max) for the width
              of dropout regions as a fraction of the volume width (between 0 and 1). Default: (0.1, 0.2)
          fill (tuple[float, ...] | float): Value for the dropped voxels. Can be:
              - int or float: all channels are filled with this value
              - tuple: tuple of values for each channel
              Default: 0
          fill_mask (tuple[float, ...] | float | None): Fill value for dropout regions in the 3D mask.
              If None, mask regions corresponding to volume dropouts are unchanged. Default: None
          p (float): Probability of applying the transform. Default: 0.5

      Targets:
          volume, mask3d, keypoints

      Image types:
          uint8, float32

      Note:
          - The actual number and size of dropout regions are randomly chosen within the specified ranges.
          - All values in hole_depth_range, hole_height_range and hole_width_range must be between 0 and 1.
          - If you want to apply dropout only in the XY plane while preserving the full depth dimension,
            consider using CoarseDropout instead. CoarseDropout will apply the same rectangular dropout
            to each slice independently, effectively creating cylindrical dropout regions that extend
            through the entire depth of the volume.

      Examples:
          >>> import numpy as np
          >>> import albumentations as A
          >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
          >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8)    # (D, H, W)
          >>> aug = A.CoarseDropout3D(
          ...     num_holes_range=(3, 6),
          ...     hole_depth_range=(0.1, 0.2),
          ...     hole_height_range=(0.1, 0.2),
          ...     hole_width_range=(0.1, 0.2),
          ...     fill=0,
          ...     p=1.0
          ... )
          >>> transformed = aug(volume=volume, mask3d=mask3d)
          >>> transformed_volume, transformed_mask3d = transformed["volume"], transformed["mask3d"]

      """

      _targets = (Targets.VOLUME, Targets.MASK3D, Targets.KEYPOINTS)

      class InitSchema(Transform3D.InitSchema):
          num_holes_range: Annotated[
              tuple[int, int],
              AfterValidator(check_range_bounds(0, None)),
              AfterValidator(nondecreasing),
          ]
          hole_depth_range: Annotated[
              tuple[float, float],
              AfterValidator(check_range_bounds(0, 1)),
              AfterValidator(nondecreasing),
          ]
          hole_height_range: Annotated[
              tuple[float, float],
              AfterValidator(check_range_bounds(0, 1)),
              AfterValidator(nondecreasing),
          ]
          hole_width_range: Annotated[
              tuple[float, float],
              AfterValidator(check_range_bounds(0, 1)),
              AfterValidator(nondecreasing),
          ]
          fill: tuple[float, ...] | float
          fill_mask: tuple[float, ...] | float | None

          @staticmethod
          def validate_range(range_value: tuple[float, float], range_name: str) -> None:
              """Validate that range values are between 0 and 1 and in non-decreasing order.

              Args:
                  range_value (tuple[float, float]): Tuple of (min, max) values to check
                  range_name (str): Name of the range for error reporting

              Raises:
                  ValueError: If range values are invalid

              """
              if not 0 <= range_value[0] <= range_value[1] <= 1:
                  raise ValueError(
                      f"All values in {range_name} should be in [0, 1] range and first value "
                      f"should be less or equal than the second value. Got: {range_value}",
                  )

          @model_validator(mode="after")
          def _check_ranges(self) -> Self:
              # Model-level re-check of the three fractional ranges with a single combined message.
              self.validate_range(self.hole_depth_range, "hole_depth_range")
              self.validate_range(self.hole_height_range, "hole_height_range")
              self.validate_range(self.hole_width_range, "hole_width_range")
              return self

      def __init__(
          self,
          num_holes_range: tuple[int, int] = (1, 1),
          hole_depth_range: tuple[float, float] = (0.1, 0.2),
          hole_height_range: tuple[float, float] = (0.1, 0.2),
          hole_width_range: tuple[float, float] = (0.1, 0.2),
          fill: tuple[float, ...] | float = 0,
          fill_mask: tuple[float, ...] | float | None = None,
          p: float = 0.5,
      ):
          super().__init__(p=p)
          self.num_holes_range = num_holes_range
          self.hole_depth_range = hole_depth_range
          self.hole_height_range = hole_height_range
          self.hole_width_range = hole_width_range
          self.fill = fill
          self.fill_mask = fill_mask

      def calculate_hole_dimensions(
          self,
          volume_shape: tuple[int, int, int],
          depth_range: tuple[float, float],
          height_range: tuple[float, float],
          width_range: tuple[float, float],
          size: int,
      ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
          """Calculate dimensions for dropout holes.

          Each dimension is drawn as a uniform fraction of the matching volume axis, rounded up,
          and clamped to at least one voxel.

          Args:
              volume_shape (tuple[int, int, int]): Shape of the volume (depth, height, width)
              depth_range (tuple[float, float]): Range for hole depth as fraction of volume depth
              height_range (tuple[float, float]): Range for hole height as fraction of volume height
              width_range (tuple[float, float]): Range for hole width as fraction of volume width
              size (int): Number of holes to generate

          Returns:
              tuple[np.ndarray, np.ndarray, np.ndarray]: Arrays of hole dimensions (depths, heights, widths)

          """
          depth, height, width = volume_shape[:3]

          hole_depths = np.maximum(1, np.ceil(depth * self.random_generator.uniform(*depth_range, size=size))).astype(int)
          hole_heights = np.maximum(1, np.ceil(height * self.random_generator.uniform(*height_range, size=size))).astype(
              int,
          )
          hole_widths = np.maximum(1, np.ceil(width * self.random_generator.uniform(*width_range, size=size))).astype(int)

          return hole_depths, hole_heights, hole_widths

      def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
          """Generate parameters for coarse dropout based on input data.

          Args:
              params (dict[str, Any]): Dictionary of existing parameters
              data (dict[str, Any]): Dictionary containing input data with volume, mask, etc.

          Returns:
              dict[str, Any]: Dictionary with key "holes": an array of shape (num_holes, 6)
                  whose rows are [z1, y1, x1, z2, y2, x2]

          """
          volume_shape = data["volume"].shape[:3]

          num_holes = self.py_random.randint(*self.num_holes_range)

          hole_depths, hole_heights, hole_widths = self.calculate_hole_dimensions(
              volume_shape,
              self.hole_depth_range,
              self.hole_height_range,
              self.hole_width_range,
              size=num_holes,
          )

          depth, height, width = volume_shape[:3]

          # Exclusive upper bound of (axis - hole + 1) keeps every hole fully inside the volume.
          z_min = self.random_generator.integers(0, depth - hole_depths + 1, size=num_holes)
          y_min = self.random_generator.integers(0, height - hole_heights + 1, size=num_holes)
          x_min = self.random_generator.integers(0, width - hole_widths + 1, size=num_holes)
          z_max = z_min + hole_depths
          y_max = y_min + hole_heights
          x_max = x_min + hole_widths

          holes = np.stack([z_min, y_min, x_min, z_max, y_max, x_max], axis=-1)

          return {"holes": holes}

      def apply_to_volume(self, volume: np.ndarray, holes: np.ndarray, **params: Any) -> np.ndarray:
          """Apply dropout to a 3D volume.

          Args:
              volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
              holes (np.ndarray): Array of holes with shape (num_holes, 6).
                  Each hole is represented as [z1, y1, x1, z2, y2, x2]
              **params (Any): Additional parameters

          Returns:
              np.ndarray: Volume with holes filled with the given value

          """
          if holes.size == 0:
              return volume

          return f3d.cutout3d(volume, holes, self.fill)

      def apply_to_mask3d(self, mask3d: np.ndarray, holes: np.ndarray, **params: Any) -> np.ndarray:
          """Apply dropout to a 3D mask.

          This is the hook named after the declared `mask3d` target (the same name the crop/pad
          3D transforms in this module implement), so `fill_mask` is honoured for 3D masks.
          NOTE(review): assumes the base class dispatches the `mask3d` target to
          `apply_to_mask3d` - confirm against Transform3D.

          Args:
              mask3d (np.ndarray): Input mask with shape (depth, height, width) or (depth, height, width, channels)
              holes (np.ndarray): Array of holes with shape (num_holes, 6).
                  Each hole is represented as [z1, y1, x1, z2, y2, x2]
              **params (Any): Additional parameters

          Returns:
              np.ndarray: The mask unchanged when `fill_mask` is None or there are no holes,
                  otherwise the mask with holes filled with `fill_mask`

          """
          if self.fill_mask is None or holes.size == 0:
              return mask3d

          return f3d.cutout3d(mask3d, holes, self.fill_mask)

      def apply_to_mask(self, mask: np.ndarray, holes: np.ndarray, **params: Any) -> np.ndarray:
          """Apply dropout to a 3D mask (kept for backward compatibility).

          Delegates to `apply_to_mask3d` so both entry points behave identically.

          Args:
              mask (np.ndarray): Input mask with shape (depth, height, width) or (depth, height, width, channels)
              holes (np.ndarray): Array of holes with shape (num_holes, 6).
                  Each hole is represented as [z1, y1, x1, z2, y2, x2]
              **params (Any): Additional parameters

          Returns:
              np.ndarray: Mask with holes filled with the given value

          """
          return self.apply_to_mask3d(mask, holes, **params)

      def apply_to_keypoints(
          self,
          keypoints: np.ndarray,
          holes: np.ndarray,
          **params: Any,
      ) -> np.ndarray:
          """Apply dropout to keypoints.

          Keypoints are only filtered when a keypoints processor is configured with
          `remove_invisible` enabled; otherwise they are returned untouched.

          Args:
              keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
                  The first three columns are x, y, z coordinates.
              holes (np.ndarray): Array of holes with shape (num_holes, 6).
                  Each hole is represented as [z1, y1, x1, z2, y2, x2]
              **params (Any): Additional parameters

          Returns:
              np.ndarray: The input keypoints, or the result of
                  `f3d.filter_keypoints_in_holes3d` (presumably with keypoints inside
                  holes removed - verify against that helper)

          """
          if holes.size == 0:
              return keypoints
          processor = cast("KeypointsProcessor", self.get_processor("keypoints"))

          if processor is None or not processor.params.remove_invisible:
              return keypoints
          return f3d.filter_keypoints_in_holes3d(keypoints, holes)
  class CubicSymmetry(Transform3D):
      """Applies a random cubic symmetry transformation to a 3D volume.

      This transform is a 3D extension of D4. While D4 handles the 8 symmetries
      of a square (4 rotations x 2 reflections), CubicSymmetry handles all 48 symmetries of a cube.
      Like D4, this transform does not create any interpolation artifacts as it only remaps voxels
      from one position to another without any interpolation.

      The 48 transformations consist of:
      - 24 rotations (orientation-preserving):
          * any of the 6 faces can be turned to face up, followed by one of 4 rotations
            about the vertical axis (6 faces x 4 rotations = 24)
      - 24 rotoreflections (orientation-reversing):
          * Reflection through a plane followed by any of the 24 rotations

      For a cube, these transformations map onto themselves:
      - The set of face centers (6)
      - The set of vertices (8)
      - The set of edge centers (12)

      works with 3D volumes and masks of the shape (D, H, W) or (D, H, W, C)

      Args:
          p (float): Probability of applying the transform. Default: 1.0

      Targets:
          volume, mask3d, keypoints

      Image types:
          uint8, float32

      Note:
          - This transform is particularly useful for data augmentation in 3D medical imaging,
            crystallography, and voxel-based 3D modeling where the object's orientation
            is arbitrary.
          - All transformations preserve the object's chirality (handedness) when using
            pure rotations (indices 0-23) and invert it when using rotoreflections
            (indices 24-47).

      Examples:
          >>> import numpy as np
          >>> import albumentations as A
          >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
          >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8)    # (D, H, W)
          >>> transform = A.CubicSymmetry(p=1.0)
          >>> transformed = transform(volume=volume, mask3d=mask3d)
          >>> transformed_volume = transformed["volume"]
          >>> transformed_mask3d = transformed["mask3d"]

      See Also:
          - D4: The 2D version that handles the 8 symmetries of a square

      """

      _targets = (Targets.VOLUME, Targets.MASK3D, Targets.KEYPOINTS)

      def __init__(self, p: float = 1.0):
          super().__init__(p=p)

      def get_params_dependent_on_data(
          self,
          params: dict[str, Any],
          data: dict[str, Any],
      ) -> dict[str, Any]:
          """Pick the symmetry to apply and record the volume shape for the keypoint transform.

          Args:
              params (dict[str, Any]): Dictionary of existing parameters
              data (dict[str, Any]): Dictionary containing input data with volume, mask, etc.

          Returns:
              dict[str, Any]: Dictionary with the randomly selected transformation "index" (0-47)
                  and the full "volume_shape" of the input volume

          """
          shape = data["volume"].shape
          # randint is inclusive on both ends, so this covers all 48 symmetries.
          symmetry_index = self.py_random.randint(0, 47)
          return {"index": symmetry_index, "volume_shape": shape}

      def apply_to_volume(self, volume: np.ndarray, index: int, **params: Any) -> np.ndarray:
          """Apply cubic symmetry transformation to a 3D volume.

          Args:
              volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
              index (int): Index of the transformation to apply (0-47)
              **params (Any): Additional parameters

          Returns:
              np.ndarray: Transformed volume.
                  NOTE(review): for non-cubic inputs the spatial axes may come back permuted -
                  confirm against f3d.transform_cube.

          """
          transformed = f3d.transform_cube(volume, index)
          return transformed

      def apply_to_keypoints(self, keypoints: np.ndarray, index: int, **params: Any) -> np.ndarray:
          """Apply cubic symmetry transformation to keypoints.

          Args:
              keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
                  The first three columns are x, y, z coordinates.
              index (int): Index of the transformation to apply (0-47)
              **params (Any): Additional parameters; must contain "volume_shape"

          Returns:
              np.ndarray: Transformed keypoints with same shape as input

          """
          source_shape = params["volume_shape"]
          return f3d.transform_cube_keypoints(keypoints, index, volume_shape=source_shape)