composition.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470
  1. """Module for composing multiple transforms into augmentation pipelines.
  2. This module provides classes for combining multiple transformations into cohesive
  3. augmentation pipelines. It includes various composition strategies such as sequential
  4. application, random selection, and conditional application of transforms. These
  5. composition classes handle the coordination between different transforms, ensuring
  6. proper data flow and maintaining consistent behavior across the augmentation pipeline.
  7. """
  8. from __future__ import annotations
  9. import random
  10. import warnings
  11. from collections import defaultdict
  12. from collections.abc import Iterator, Sequence
  13. from typing import Any, Union, cast
  14. import cv2
  15. import numpy as np
  16. from .bbox_utils import BboxParams, BboxProcessor
  17. from .hub_mixin import HubMixin
  18. from .keypoints_utils import KeypointParams, KeypointsProcessor
  19. from .serialization import (
  20. SERIALIZABLE_REGISTRY,
  21. Serializable,
  22. get_shortest_class_fullname,
  23. instantiate_nonserializable,
  24. )
  25. from .transforms_interface import BasicTransform
  26. from .utils import DataProcessor, format_args, get_shape
# Public names exported by this module.
__all__ = [
    "BaseCompose",
    "BboxParams",
    "Compose",
    "KeypointParams",
    "OneOf",
    "OneOrOther",
    "RandomOrder",
    "ReplayCompose",
    "SelectiveChannelTransform",
    "Sequential",
    "SomeOf",
]

# NOTE(review): not referenced in the visible part of this module; confirm its meaning at the use site.
NUM_ONEOF_TRANSFORMS = 2
# Default indentation of `BaseCompose.indented_repr` and the amount added per nesting level.
REPR_INDENT_STEP = 2

# A pipeline element is either a single transform or another composition.
TransformType = Union[BasicTransform, "BaseCompose"]
TransformsSeqType = list[TransformType]

# Keys that `Compose` adds to its available keys when it is created without transforms.
AVAILABLE_KEYS = ("image", "mask", "masks", "bboxes", "keypoints", "volume", "volumes", "mask3d", "masks3d")

# Mask keys; `Compose._is_valid_key` always accepts them.
MASK_KEYS = (
    "mask",  # 2D mask
    "masks",  # Multiple 2D masks
    "mask3d",  # 3D mask
    "masks3d",  # Multiple 3D masks
)

# Keys related to image data; `Compose._is_valid_key` always accepts them.
IMAGE_KEYS = {"image", "images"}

# Key groups for input checking. Only CHECKED_VOLUME and CHECKED_MASK3D are referenced in the
# visible part of this module (in `Compose.preprocess`, as the (D, H, W) targets).
# NOTE(review): the remaining groups are presumably consumed by helpers defined further down; confirm.
CHECKED_SINGLE = {"image", "mask"}
CHECKED_MULTI = {"masks", "images", "volumes", "masks3d"}
CHECK_BBOX_PARAM = {"bboxes"}
CHECK_KEYPOINTS_PARAM = {"keypoints"}
VOLUME_KEYS = {"volume", "volumes"}
CHECKED_VOLUME = {"volume"}
CHECKED_VOLUMES = {"volumes"}
CHECKED_MASK3D = {"mask3d"}
CHECKED_MASKS3D = {"masks3d"}
  62. class BaseCompose(Serializable):
  63. """Base class for composing multiple transforms together.
  64. This class serves as a foundation for creating compositions of transforms
  65. in the Albumentations library. It provides basic functionality for
  66. managing a sequence of transforms and applying them to data.
  67. Attributes:
  68. transforms (List[TransformType]): A list of transforms to be applied.
  69. p (float): Probability of applying the compose. Should be in the range [0, 1].
  70. replay_mode (bool): If True, the compose is in replay mode.
  71. _additional_targets (Dict[str, str]): Additional targets for transforms.
  72. _available_keys (Set[str]): Set of available keys for data.
  73. processors (Dict[str, Union[BboxProcessor, KeypointsProcessor]]): Processors for specific data types.
  74. Args:
  75. transforms (TransformsSeqType): A sequence of transforms to compose.
  76. p (float): Probability of applying the compose.
  77. Raises:
  78. ValueError: If an invalid additional target is specified.
  79. Note:
  80. - Subclasses should implement the __call__ method to define how
  81. the composition is applied to data.
  82. - The class supports serialization and deserialization of transforms.
  83. - It provides methods for adding targets, setting deterministic behavior,
  84. and checking data validity post-transform.
  85. """
  86. _transforms_dict: dict[int, BasicTransform] | None = None
  87. check_each_transform: tuple[DataProcessor, ...] | None = None
  88. main_compose: bool = True
  89. def __init__(
  90. self,
  91. transforms: TransformsSeqType,
  92. p: float,
  93. mask_interpolation: int | None = None,
  94. seed: int | None = None,
  95. save_applied_params: bool = False,
  96. **kwargs: Any,
  97. ):
  98. if isinstance(transforms, (BaseCompose, BasicTransform)):
  99. warnings.warn(
  100. "transforms is single transform, but a sequence is expected! Transform will be wrapped into list.",
  101. stacklevel=2,
  102. )
  103. transforms = [transforms]
  104. self.transforms = transforms
  105. self.p = p
  106. self.replay_mode = False
  107. self._additional_targets: dict[str, str] = {}
  108. self._available_keys: set[str] = set()
  109. self.processors: dict[str, BboxProcessor | KeypointsProcessor] = {}
  110. self._set_keys()
  111. self.set_mask_interpolation(mask_interpolation)
  112. self.set_random_seed(seed)
  113. self.save_applied_params = save_applied_params
  114. def _track_transform_params(self, transform: TransformType, data: dict[str, Any]) -> None:
  115. """Track transform parameters if tracking is enabled."""
  116. if "applied_transforms" in data and hasattr(transform, "params") and transform.params:
  117. data["applied_transforms"].append((transform.__class__.__name__, transform.params.copy()))
  118. def set_random_state(
  119. self,
  120. random_generator: np.random.Generator,
  121. py_random: random.Random,
  122. ) -> None:
  123. """Set random state directly from generators.
  124. Args:
  125. random_generator (np.random.Generator): numpy random generator to use
  126. py_random (random.Random): python random generator to use
  127. """
  128. self.random_generator = random_generator
  129. self.py_random = py_random
  130. # Propagate both random states to all transforms
  131. for transform in self.transforms:
  132. if isinstance(transform, (BasicTransform, BaseCompose)):
  133. transform.set_random_state(random_generator, py_random)
  134. def set_random_seed(self, seed: int | None) -> None:
  135. """Set random state from seed.
  136. Args:
  137. seed (int | None): Random seed to use
  138. """
  139. self.seed = seed
  140. self.random_generator = np.random.default_rng(seed)
  141. self.py_random = random.Random(seed)
  142. # Propagate seed to all transforms
  143. for transform in self.transforms:
  144. if isinstance(transform, (BasicTransform, BaseCompose)):
  145. transform.set_random_seed(seed)
  146. def set_mask_interpolation(self, mask_interpolation: int | None) -> None:
  147. """Set interpolation mode for mask resizing operations.
  148. Args:
  149. mask_interpolation (int | None): OpenCV interpolation flag to use for mask transforms.
  150. If None, default interpolation for masks will be used.
  151. """
  152. self.mask_interpolation = mask_interpolation
  153. self._set_mask_interpolation_recursive(self.transforms)
  154. def _set_mask_interpolation_recursive(self, transforms: TransformsSeqType) -> None:
  155. for transform in transforms:
  156. if isinstance(transform, BasicTransform):
  157. if hasattr(transform, "mask_interpolation") and self.mask_interpolation is not None:
  158. transform.mask_interpolation = self.mask_interpolation
  159. elif isinstance(transform, BaseCompose):
  160. transform.set_mask_interpolation(self.mask_interpolation)
  161. def __iter__(self) -> Iterator[TransformType]:
  162. return iter(self.transforms)
  163. def __len__(self) -> int:
  164. return len(self.transforms)
  165. def __call__(self, *args: Any, **data: Any) -> dict[str, Any]:
  166. """Apply transforms.
  167. Args:
  168. *args (Any): Positional arguments are not supported.
  169. **data (Any): Named parameters with data to transform.
  170. Returns:
  171. dict[str, Any]: Transformed data.
  172. Raises:
  173. NotImplementedError: This method must be implemented by subclasses.
  174. """
  175. raise NotImplementedError
  176. def __getitem__(self, item: int) -> TransformType:
  177. return self.transforms[item]
  178. def __repr__(self) -> str:
  179. return self.indented_repr()
  180. @property
  181. def additional_targets(self) -> dict[str, str]:
  182. """Get additional targets dictionary.
  183. Returns:
  184. dict[str, str]: Dictionary containing additional targets mapping.
  185. """
  186. return self._additional_targets
  187. @property
  188. def available_keys(self) -> set[str]:
  189. """Get set of available keys.
  190. Returns:
  191. set[str]: Set of string keys available for transforms.
  192. """
  193. return self._available_keys
  194. def indented_repr(self, indent: int = REPR_INDENT_STEP) -> str:
  195. """Get an indented string representation of the composition.
  196. Args:
  197. indent (int): Indentation level. Default: REPR_INDENT_STEP.
  198. Returns:
  199. str: Formatted string representation with proper indentation.
  200. """
  201. args = {k: v for k, v in self.to_dict_private().items() if not (k.startswith("__") or k == "transforms")}
  202. repr_string = self.__class__.__name__ + "(["
  203. for t in self.transforms:
  204. repr_string += "\n"
  205. t_repr = t.indented_repr(indent + REPR_INDENT_STEP) if hasattr(t, "indented_repr") else repr(t)
  206. repr_string += " " * indent + t_repr + ","
  207. repr_string += "\n" + " " * (indent - REPR_INDENT_STEP) + f"], {format_args(args)})"
  208. return repr_string
  209. @classmethod
  210. def get_class_fullname(cls) -> str:
  211. """Get the full qualified name of the class.
  212. Returns:
  213. str: The shortest class fullname.
  214. """
  215. return get_shortest_class_fullname(cls)
  216. @classmethod
  217. def is_serializable(cls) -> bool:
  218. """Check if the class is serializable.
  219. Returns:
  220. bool: True if the class is serializable, False otherwise.
  221. """
  222. return True
  223. def to_dict_private(self) -> dict[str, Any]:
  224. """Convert the composition to a dictionary for serialization.
  225. Returns:
  226. dict[str, Any]: Dictionary representation of the composition.
  227. """
  228. return {
  229. "__class_fullname__": self.get_class_fullname(),
  230. "p": self.p,
  231. "transforms": [t.to_dict_private() for t in self.transforms],
  232. }
  233. def get_dict_with_id(self) -> dict[str, Any]:
  234. """Get a dictionary representation with object IDs for replay mode.
  235. Returns:
  236. dict[str, Any]: Dictionary with composition data and object IDs.
  237. """
  238. return {
  239. "__class_fullname__": self.get_class_fullname(),
  240. "id": id(self),
  241. "params": None,
  242. "transforms": [t.get_dict_with_id() for t in self.transforms],
  243. }
  244. def add_targets(self, additional_targets: dict[str, str] | None) -> None:
  245. """Add additional targets to all transforms.
  246. Args:
  247. additional_targets (dict[str, str] | None): Dict of name -> type mapping for additional targets.
  248. If None, no additional targets will be added.
  249. """
  250. if additional_targets:
  251. for k, v in additional_targets.items():
  252. if k in self._additional_targets and v != self._additional_targets[k]:
  253. raise ValueError(
  254. f"Trying to overwrite existed additional targets. "
  255. f"Key={k} Exists={self._additional_targets[k]} New value: {v}",
  256. )
  257. self._additional_targets.update(additional_targets)
  258. for t in self.transforms:
  259. t.add_targets(additional_targets)
  260. for proc in self.processors.values():
  261. proc.add_targets(additional_targets)
  262. self._set_keys()
  263. def _set_keys(self) -> None:
  264. """Set _available_keys"""
  265. self._available_keys.update(self._additional_targets.keys())
  266. for t in self.transforms:
  267. self._available_keys.update(t.available_keys)
  268. if hasattr(t, "targets_as_params"):
  269. self._available_keys.update(t.targets_as_params)
  270. if self.processors:
  271. self._available_keys.update(["labels"])
  272. for proc in self.processors.values():
  273. if proc.default_data_name not in self._available_keys: # if no transform to process this data
  274. warnings.warn(
  275. f"Got processor for {proc.default_data_name}, but no transform to process it.",
  276. stacklevel=2,
  277. )
  278. self._available_keys.update(proc.data_fields)
  279. if proc.params.label_fields:
  280. self._available_keys.update(proc.params.label_fields)
  281. def set_deterministic(self, flag: bool, save_key: str = "replay") -> None:
  282. """Set deterministic mode for all transforms.
  283. Args:
  284. flag (bool): Whether to enable deterministic mode.
  285. save_key (str): Key to save replay parameters. Default: "replay".
  286. """
  287. for t in self.transforms:
  288. t.set_deterministic(flag, save_key)
  289. def check_data_post_transform(self, data: dict[str, Any]) -> dict[str, Any]:
  290. """Check and filter data after transformation.
  291. Args:
  292. data (dict[str, Any]): Dictionary containing transformed data
  293. Returns:
  294. dict[str, Any]: Filtered data dictionary
  295. """
  296. if self.check_each_transform:
  297. shape = get_shape(data)
  298. for proc in self.check_each_transform:
  299. for data_name, data_value in data.items():
  300. if data_name in proc.data_fields or (
  301. data_name in self._additional_targets
  302. and self._additional_targets[data_name] in proc.data_fields
  303. ):
  304. data[data_name] = proc.filter(data_value, shape)
  305. return data
  306. class Compose(BaseCompose, HubMixin):
  307. """Compose multiple transforms together and apply them sequentially to input data.
  308. This class allows you to chain multiple image augmentation transforms and apply them
  309. in a specified order. It also handles bounding box and keypoint transformations if
  310. the appropriate parameters are provided.
  311. Args:
  312. transforms (list[BasicTransform | BaseCompose]): A list of transforms to apply.
  313. bbox_params (dict[str, Any] | BboxParams | None): Parameters for bounding box transforms.
  314. Can be a dict of params or a BboxParams object. Default is None.
  315. keypoint_params (dict[str, Any] | KeypointParams | None): Parameters for keypoint transforms.
  316. Can be a dict of params or a KeypointParams object. Default is None.
  317. additional_targets (dict[str, str] | None): A dictionary mapping additional target names
  318. to their types. For example, {'image2': 'image'}. Default is None.
  319. p (float): Probability of applying all transforms. Should be in range [0, 1]. Default is 1.0.
  320. is_check_shapes (bool): If True, checks consistency of shapes for image/mask/masks on each call.
  321. Disable only if you are sure about your data consistency. Default is True.
  322. strict (bool): If True, enables strict mode which:
  323. 1. Validates that all input keys are known/expected
  324. 2. Validates that no transforms have invalid arguments
  325. 3. Raises ValueError if any validation fails
  326. If False, these validations are skipped. Default is False.
  327. mask_interpolation (int | None): Interpolation method for mask transforms. When defined,
  328. it overrides the interpolation method specified in individual transforms. Default is None.
  329. seed (int | None): Controls reproducibility of random augmentations. Compose uses
  330. its own internal random state, completely independent from global random seeds.
  331. When seed is set (int):
  332. - Creates a fixed internal random state
  333. - Two Compose instances with the same seed and transforms will produce identical
  334. sequences of augmentations
  335. - Each call to the same Compose instance still produces random augmentations,
  336. but these sequences are reproducible between different Compose instances
  337. - Example: transform1 = A.Compose([...], seed=137) and
  338. transform2 = A.Compose([...], seed=137) will produce identical sequences
  339. When seed is None (default):
  340. - Generates a new internal random state on each Compose creation
  341. - Different Compose instances will produce different sequences of augmentations
  342. - Example: transform = A.Compose([...]) # random results
  343. Important: Setting random seeds outside of Compose (like np.random.seed() or
  344. random.seed()) has no effect on augmentations as Compose uses its own internal
  345. random state.
  346. save_applied_params (bool): If True, saves the applied parameters of each transform. Default is False.
  347. You will need to use the `applied_transforms` key in the output dictionary to access the parameters.
  348. Example:
  349. >>> import albumentations as A
  350. >>> transform = A.Compose([
  351. ... A.RandomCrop(width=256, height=256),
  352. ... A.HorizontalFlip(p=0.5),
  353. ... A.RandomBrightnessContrast(p=0.2),
  354. ... ], seed=137)
  355. >>> transformed = transform(image=image)
  356. Note:
  357. - The class checks the validity of input data and shapes if is_check_args and is_check_shapes are True.
  358. - When bbox_params or keypoint_params are provided, it sets up the corresponding processors.
  359. - The transform can handle additional targets specified in the additional_targets dictionary.
  360. - When strict mode is enabled, it performs additional validation to ensure data and transform
  361. configuration correctness.
  362. """
  363. def __init__(
  364. self,
  365. transforms: TransformsSeqType,
  366. bbox_params: dict[str, Any] | BboxParams | None = None,
  367. keypoint_params: dict[str, Any] | KeypointParams | None = None,
  368. additional_targets: dict[str, str] | None = None,
  369. p: float = 1.0,
  370. is_check_shapes: bool = True,
  371. strict: bool = False,
  372. mask_interpolation: int | None = None,
  373. seed: int | None = None,
  374. save_applied_params: bool = False,
  375. ):
  376. super().__init__(
  377. transforms=transforms,
  378. p=p,
  379. mask_interpolation=mask_interpolation,
  380. seed=seed,
  381. save_applied_params=save_applied_params,
  382. )
  383. if bbox_params:
  384. if isinstance(bbox_params, dict):
  385. b_params = BboxParams(**bbox_params)
  386. elif isinstance(bbox_params, BboxParams):
  387. b_params = bbox_params
  388. else:
  389. msg = "unknown format of bbox_params, please use `dict` or `BboxParams`"
  390. raise ValueError(msg)
  391. self.processors["bboxes"] = BboxProcessor(b_params)
  392. if keypoint_params:
  393. if isinstance(keypoint_params, dict):
  394. k_params = KeypointParams(**keypoint_params)
  395. elif isinstance(keypoint_params, KeypointParams):
  396. k_params = keypoint_params
  397. else:
  398. msg = "unknown format of keypoint_params, please use `dict` or `KeypointParams`"
  399. raise ValueError(msg)
  400. self.processors["keypoints"] = KeypointsProcessor(k_params)
  401. for proc in self.processors.values():
  402. proc.ensure_transforms_valid(self.transforms)
  403. self.add_targets(additional_targets)
  404. if not self.transforms: # if no transforms -> do nothing, all keys will be available
  405. self._available_keys.update(AVAILABLE_KEYS)
  406. self.is_check_args = True
  407. self.strict = strict
  408. self.is_check_shapes = is_check_shapes
  409. self.check_each_transform = tuple( # processors that checks after each transform
  410. proc for proc in self.processors.values() if getattr(proc.params, "check_each_transform", False)
  411. )
  412. self._set_check_args_for_transforms(self.transforms)
  413. self._set_processors_for_transforms(self.transforms)
  414. self.save_applied_params = save_applied_params
  415. self._images_was_list = False
  416. self._masks_was_list = False
  417. @property
  418. def strict(self) -> bool:
  419. """Get the current strict mode setting.
  420. Returns:
  421. bool: True if strict mode is enabled, False otherwise.
  422. """
  423. return self._strict
  424. @strict.setter
  425. def strict(self, value: bool) -> None:
  426. # if value and not self._strict:
  427. if value:
  428. # Only validate when enabling strict mode
  429. self._validate_strict()
  430. self._strict = value
  431. def _validate_strict(self) -> None:
  432. """Validate that no transforms have invalid arguments when strict mode is enabled."""
  433. def check_transform(transform: TransformType) -> None:
  434. if hasattr(transform, "invalid_args") and transform.invalid_args:
  435. message = (
  436. f"Argument(s) '{', '.join(transform.invalid_args)}' "
  437. f"are not valid for transform {transform.__class__.__name__}"
  438. )
  439. raise ValueError(message)
  440. if isinstance(transform, BaseCompose):
  441. for t in transform.transforms:
  442. check_transform(t)
  443. for transform in self.transforms:
  444. check_transform(transform)
  445. def _set_processors_for_transforms(self, transforms: TransformsSeqType) -> None:
  446. for transform in transforms:
  447. if isinstance(transform, BasicTransform):
  448. if hasattr(transform, "set_processors"):
  449. transform.set_processors(self.processors)
  450. elif isinstance(transform, BaseCompose):
  451. self._set_processors_for_transforms(transform.transforms)
  452. def _set_check_args_for_transforms(self, transforms: TransformsSeqType) -> None:
  453. for transform in transforms:
  454. if isinstance(transform, BaseCompose):
  455. self._set_check_args_for_transforms(transform.transforms)
  456. transform.check_each_transform = self.check_each_transform
  457. transform.processors = self.processors
  458. if isinstance(transform, Compose):
  459. transform.disable_check_args_private()
  460. def disable_check_args_private(self) -> None:
  461. """Disable argument checking for transforms.
  462. This method disables strict mode and argument checking for all transforms in the composition.
  463. """
  464. self.is_check_args = False
  465. self.strict = False
  466. self.main_compose = False
  467. def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
  468. """Apply transformations to data.
  469. Args:
  470. *args (Any): Positional arguments are not supported.
  471. force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
  472. **data (Any): Dict with data to transform.
  473. Returns:
  474. dict[str, Any]: Dictionary with transformed data.
  475. Raises:
  476. KeyError: If positional arguments are provided.
  477. """
  478. if args:
  479. msg = "You have to pass data to augmentations as named arguments, for example: aug(image=image)"
  480. raise KeyError(msg)
  481. if not isinstance(force_apply, (bool, int)):
  482. msg = "force_apply must have bool or int type"
  483. raise TypeError(msg)
  484. # Initialize applied_transforms only in top-level Compose if requested
  485. if self.save_applied_params and self.main_compose:
  486. data["applied_transforms"] = []
  487. need_to_run = force_apply or self.py_random.random() < self.p
  488. if not need_to_run:
  489. return data
  490. self.preprocess(data)
  491. for t in self.transforms:
  492. data = t(**data)
  493. self._track_transform_params(t, data)
  494. data = self.check_data_post_transform(data)
  495. return self.postprocess(data)
  496. def preprocess(self, data: Any) -> None:
  497. """Preprocess input data before applying transforms."""
  498. # Always validate shapes if is_check_shapes is True, regardless of strict mode
  499. if self.is_check_shapes:
  500. shapes = [] # For H,W checks
  501. volume_shapes = [] # For D,H,W checks
  502. for data_name, data_value in data.items():
  503. internal_name = self._additional_targets.get(data_name, data_name)
  504. # Skip empty data
  505. if data_value is None:
  506. continue
  507. shape = self._get_data_shape(data_name, internal_name, data_value)
  508. if shape is not None:
  509. if internal_name in CHECKED_VOLUME | CHECKED_MASK3D:
  510. shapes.append(shape[1:3]) # H,W from (D,H,W)
  511. volume_shapes.append(shape[:3]) # D,H,W
  512. elif internal_name in {"volumes", "masks3d"}:
  513. shapes.append(shape[2:4]) # H,W from (N,D,H,W)
  514. volume_shapes.append(shape[1:4]) # D,H,W from (N,D,H,W)
  515. else:
  516. shapes.append(shape[:2]) # H,W
  517. self._check_shape_consistency(shapes, volume_shapes)
  518. # Do strict validation only if enabled
  519. if self.strict:
  520. self._validate_data(data)
  521. self._preprocess_processors(data)
  522. self._preprocess_arrays(data)
  523. def _validate_data(self, data: dict[str, Any]) -> None:
  524. """Validate input data keys and arguments."""
  525. if not self.strict:
  526. return
  527. for data_name in data:
  528. if not self._is_valid_key(data_name):
  529. raise ValueError(f"Key {data_name} is not in available keys.")
  530. if self.is_check_args:
  531. self._check_args(**data)
  532. def _is_valid_key(self, key: str) -> bool:
  533. """Check if the key is valid for processing."""
  534. return key in self._available_keys or key in MASK_KEYS or key in IMAGE_KEYS or key == "applied_transforms"
  535. def _preprocess_processors(self, data: dict[str, Any]) -> None:
  536. """Run preprocessors if this is the main compose."""
  537. if not self.main_compose:
  538. return
  539. for processor in self.processors.values():
  540. processor.ensure_data_valid(data)
  541. for processor in self.processors.values():
  542. processor.preprocess(data)
  543. def _preprocess_arrays(self, data: dict[str, Any]) -> None:
  544. """Convert lists to numpy arrays for images and masks, and ensure contiguity."""
  545. self._preprocess_images(data)
  546. self._preprocess_masks(data)
  547. def _preprocess_images(self, data: dict[str, Any]) -> None:
  548. """Convert image lists to numpy arrays."""
  549. if "images" not in data:
  550. return
  551. if isinstance(data["images"], (list, tuple)):
  552. self._images_was_list = True
  553. # Skip stacking for empty lists
  554. if not data["images"]:
  555. return
  556. data["images"] = np.stack(data["images"])
  557. else:
  558. self._images_was_list = False
  559. def _preprocess_masks(self, data: dict[str, Any]) -> None:
  560. """Convert mask lists to numpy arrays."""
  561. if "masks" not in data:
  562. return
  563. if isinstance(data["masks"], (list, tuple)):
  564. self._masks_was_list = True
  565. # Skip stacking for empty lists
  566. if not data["masks"]:
  567. return
  568. data["masks"] = np.stack(data["masks"])
  569. else:
  570. self._masks_was_list = False
  571. def postprocess(self, data: dict[str, Any]) -> dict[str, Any]:
  572. """Apply post-processing to data after all transforms have been applied.
  573. Args:
  574. data (dict[str, Any]): Data after transformation.
  575. Returns:
  576. dict[str, Any]: Post-processed data.
  577. """
  578. if self.main_compose:
  579. for p in self.processors.values():
  580. p.postprocess(data)
  581. # Convert back to list if original input was a list
  582. if "images" in data and self._images_was_list:
  583. data["images"] = list(data["images"])
  584. if "masks" in data and self._masks_was_list:
  585. data["masks"] = list(data["masks"])
  586. return data
  587. def to_dict_private(self) -> dict[str, Any]:
  588. """Convert the composition to a dictionary for serialization.
  589. Returns:
  590. dict[str, Any]: Dictionary representation of the composition.
  591. """
  592. dictionary = super().to_dict_private()
  593. bbox_processor = self.processors.get("bboxes")
  594. keypoints_processor = self.processors.get("keypoints")
  595. dictionary.update(
  596. {
  597. "bbox_params": bbox_processor.params.to_dict_private() if bbox_processor else None,
  598. "keypoint_params": (keypoints_processor.params.to_dict_private() if keypoints_processor else None),
  599. "additional_targets": self.additional_targets,
  600. "is_check_shapes": self.is_check_shapes,
  601. },
  602. )
  603. return dictionary
  604. def get_dict_with_id(self) -> dict[str, Any]:
  605. """Get a dictionary representation with object IDs for replay mode.
  606. Returns:
  607. dict[str, Any]: Dictionary with composition data and object IDs.
  608. """
  609. dictionary = super().get_dict_with_id()
  610. bbox_processor = self.processors.get("bboxes")
  611. keypoints_processor = self.processors.get("keypoints")
  612. dictionary.update(
  613. {
  614. "bbox_params": bbox_processor.params.to_dict_private() if bbox_processor else None,
  615. "keypoint_params": (keypoints_processor.params.to_dict_private() if keypoints_processor else None),
  616. "additional_targets": self.additional_targets,
  617. "params": None,
  618. "is_check_shapes": self.is_check_shapes,
  619. },
  620. )
  621. return dictionary
  622. @staticmethod
  623. def _check_single_data(data_name: str, data: Any) -> tuple[int, int]:
  624. if not isinstance(data, np.ndarray):
  625. raise TypeError(f"{data_name} must be numpy array type")
  626. return data.shape[:2]
  627. @staticmethod
  628. def _check_masks_data(data_name: str, data: Any) -> tuple[int, int] | None:
  629. """Check masks data format and return shape.
  630. Args:
  631. data_name (str): Name of the data field being checked
  632. data (Any): Input data in one of these formats:
  633. - List of numpy arrays, each of shape (H, W) or (H, W, C)
  634. - Numpy array of shape (N, H, W) or (N, H, W, C)
  635. - Empty list for cases where no masks are present
  636. Returns:
  637. tuple[int, int] | None: (height, width) of the first mask, or None if masks list is empty
  638. Raises:
  639. TypeError: If data format is invalid
  640. """
  641. if isinstance(data, np.ndarray):
  642. if data.ndim not in [3, 4]: # (N,H,W) or (N,H,W,C)
  643. raise TypeError(f"{data_name} as numpy array must be 3D or 4D")
  644. return data.shape[1:3] # Return (H,W)
  645. if isinstance(data, (list, tuple)):
  646. if not data:
  647. # Allow empty list/tuple of masks
  648. return None
  649. if not all(isinstance(m, np.ndarray) for m in data):
  650. raise TypeError(f"All elements in {data_name} must be numpy arrays")
  651. if any(m.ndim not in {2, 3} for m in data):
  652. raise TypeError(f"All masks in {data_name} must be 2D or 3D numpy arrays")
  653. return data[0].shape[:2]
  654. raise TypeError(f"{data_name} must be either a numpy array or a sequence of numpy arrays")
  655. @staticmethod
  656. def _check_multi_data(data_name: str, data: Any) -> tuple[int, int]:
  657. """Check multi-image data format and return shape.
  658. Args:
  659. data_name (str): Name of the data field being checked
  660. data (Any): Input data in one of these formats:
  661. - List-like of numpy arrays
  662. - Numpy array of shape (N, H, W, C) or (N, H, W)
  663. Returns:
  664. tuple[int, int]: (height, width) of the first image
  665. Raises:
  666. TypeError: If data format is invalid
  667. """
  668. if isinstance(data, np.ndarray):
  669. if data.ndim not in {3, 4}: # (N,H,W) or (N,H,W,C)
  670. raise TypeError(f"{data_name} as numpy array must be 3D or 4D")
  671. return data.shape[1:3] # Return (H,W)
  672. if not isinstance(data, Sequence) or not isinstance(data[0], np.ndarray):
  673. raise TypeError(f"{data_name} must be either a numpy array or a list of numpy arrays")
  674. return data[0].shape[:2]
  675. @staticmethod
  676. def _check_bbox_keypoint_params(internal_data_name: str, processors: dict[str, Any]) -> None:
  677. if internal_data_name in CHECK_BBOX_PARAM and processors.get("bboxes") is None:
  678. raise ValueError("bbox_params must be specified for bbox transformations")
  679. if internal_data_name in CHECK_KEYPOINTS_PARAM and processors.get("keypoints") is None:
  680. raise ValueError("keypoints_params must be specified for keypoint transformations")
  681. @staticmethod
  682. def _check_shapes(shapes: list[tuple[int, ...]], is_check_shapes: bool) -> None:
  683. if is_check_shapes and shapes and shapes.count(shapes[0]) != len(shapes):
  684. raise ValueError(
  685. "Height and Width of image, mask or masks should be equal. You can disable shapes check "
  686. "by setting a parameter is_check_shapes=False of Compose class (do it only if you are sure "
  687. "about your data consistency).",
  688. )
  689. def _check_args(self, **kwargs: Any) -> None:
  690. shapes = [] # For H,W checks
  691. volume_shapes = [] # For D,H,W checks
  692. for data_name, data in kwargs.items():
  693. internal_name = self._additional_targets.get(data_name, data_name)
  694. # For CHECKED_SINGLE, we must validate even if None
  695. if internal_name in CHECKED_SINGLE:
  696. if not isinstance(data, np.ndarray):
  697. raise TypeError(f"{data_name} must be numpy array type")
  698. shapes.append(data.shape[:2])
  699. continue
  700. # Skip empty data or non-array/list inputs for other types
  701. if data is None:
  702. continue
  703. if not isinstance(data, (np.ndarray, list)):
  704. continue
  705. self._check_bbox_keypoint_params(internal_name, self.processors)
  706. shape = self._get_data_shape(data_name, internal_name, data)
  707. if shape is None:
  708. continue
  709. # Handle different shape types
  710. if internal_name in CHECKED_VOLUME | CHECKED_MASK3D:
  711. shapes.append(shape[1:3]) # H,W from (D,H,W)
  712. volume_shapes.append(shape[:3]) # D,H,W
  713. elif internal_name in {"volumes", "masks3d"}:
  714. shapes.append(shape[2:4]) # H,W from (N,D,H,W)
  715. volume_shapes.append(shape[1:4]) # D,H,W from (N,D,H,W)
  716. else:
  717. shapes.append(shape[:2]) # H,W
  718. self._check_shape_consistency(shapes, volume_shapes)
  719. def _get_data_shape(self, data_name: str, internal_name: str, data: Any) -> tuple[int, ...] | None:
  720. """Get shape of data based on its type."""
  721. # Handle single images and masks
  722. if internal_name in CHECKED_SINGLE:
  723. return self._get_single_data_shape(data_name, data)
  724. # Handle volumes
  725. if internal_name in CHECKED_VOLUME:
  726. return self._check_volume_data(data_name, data)
  727. # Handle 3D masks
  728. if internal_name in CHECKED_MASK3D:
  729. return self._check_mask3d_data(data_name, data)
  730. # Handle multi-item data (masks, images, volumes)
  731. if internal_name in CHECKED_MULTI:
  732. return self._get_multi_data_shape(data_name, internal_name, data)
  733. return None
  734. def _get_single_data_shape(self, data_name: str, data: np.ndarray) -> tuple[int, ...]:
  735. """Get shape of single image or mask."""
  736. if not isinstance(data, np.ndarray):
  737. raise TypeError(f"{data_name} must be numpy array type")
  738. return data.shape
  739. def _get_multi_data_shape(self, data_name: str, internal_name: str, data: Any) -> tuple[int, ...] | None:
  740. """Get shape of multi-item data (masks, images, volumes)."""
  741. if internal_name == "masks":
  742. shape = self._check_masks_data(data_name, data)
  743. # Skip empty masks lists when returning shape
  744. return None if shape is None else shape
  745. if internal_name in {"volumes", "masks3d"}: # Group these together
  746. if not isinstance(data, np.ndarray):
  747. raise TypeError(f"{data_name} must be numpy array type")
  748. if data.ndim not in {4, 5}: # (N,D,H,W) or (N,D,H,W,C)
  749. raise TypeError(f"{data_name} must be 4D or 5D array")
  750. return data.shape # Return full shape
  751. return self._check_multi_data(data_name, data)
  752. def _check_shape_consistency(self, shapes: list[tuple[int, ...]], volume_shapes: list[tuple[int, ...]]) -> None:
  753. """Check consistency of shapes."""
  754. # Check H,W consistency
  755. self._check_shapes(shapes, self.is_check_shapes)
  756. # Check D,H,W consistency for volumes and 3D masks
  757. if self.is_check_shapes and volume_shapes and volume_shapes.count(volume_shapes[0]) != len(volume_shapes):
  758. raise ValueError(
  759. "Depth, Height and Width of volume, mask3d, volumes and masks3d should be equal. "
  760. "You can disable shapes check by setting is_check_shapes=False.",
  761. )
  762. @staticmethod
  763. def _check_volume_data(data_name: str, data: np.ndarray) -> tuple[int, int, int]:
  764. if data.ndim not in {3, 4}: # (D,H,W) or (D,H,W,C)
  765. raise TypeError(f"{data_name} must be 3D or 4D array")
  766. return data.shape[:3] # Return (D,H,W)
  767. @staticmethod
  768. def _check_volumes_data(data_name: str, data: np.ndarray) -> tuple[int, int, int]:
  769. if data.ndim not in {4, 5}: # (N,D,H,W) or (N,D,H,W,C)
  770. raise TypeError(f"{data_name} must be 4D or 5D array")
  771. return data.shape[1:4] # Return (D,H,W)
  772. @staticmethod
  773. def _check_mask3d_data(data_name: str, data: np.ndarray) -> tuple[int, int, int]:
  774. """Check single volumetric mask data format and return shape."""
  775. if data.ndim not in {3, 4}: # (D,H,W) or (D,H,W,C)
  776. raise TypeError(f"{data_name} must be 3D or 4D array")
  777. return data.shape[:3] # Return (D,H,W)
  778. @staticmethod
  779. def _check_masks3d_data(data_name: str, data: np.ndarray) -> tuple[int, int, int]:
  780. """Check multiple volumetric masks data format and return shape."""
  781. if data.ndim not in [4, 5]: # (N,D,H,W) or (N,D,H,W,C)
  782. raise TypeError(f"{data_name} must be 4D or 5D array")
  783. return data.shape[1:4] # Return (D,H,W)
  784. class OneOf(BaseCompose):
  785. """Select one of transforms to apply. Selected transform will be called with `force_apply=True`.
  786. Transforms probabilities will be normalized to one 1, so in this case transforms probabilities works as weights.
  787. Args:
  788. transforms (list): list of transformations to compose.
  789. p (float): probability of applying selected transform. Default: 0.5.
  790. """
  791. def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
  792. super().__init__(transforms=transforms, p=p)
  793. transforms_ps = [t.p for t in self.transforms]
  794. s = sum(transforms_ps)
  795. self.transforms_ps = [t / s for t in transforms_ps]
  796. def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
  797. """Apply the OneOf composition to the input data.
  798. Args:
  799. *args (Any): Positional arguments are not supported.
  800. force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
  801. **data (Any): Dict with data to transform.
  802. Returns:
  803. dict[str, Any]: Dictionary with transformed data.
  804. Raises:
  805. KeyError: If positional arguments are provided.
  806. """
  807. if self.replay_mode:
  808. for t in self.transforms:
  809. data = t(**data)
  810. return data
  811. if self.transforms_ps and (force_apply or self.py_random.random() < self.p):
  812. idx: int = self.random_generator.choice(len(self.transforms), p=self.transforms_ps)
  813. t = self.transforms[idx]
  814. data = t(force_apply=True, **data)
  815. self._track_transform_params(t, data)
  816. return data
  817. class SomeOf(BaseCompose):
  818. """Selects exactly `n` transforms from the given list and applies them.
  819. The selection of which `n` transforms to apply is done **uniformly at random**
  820. from the provided list. Each transform in the list has an equal chance of being selected.
  821. Once the `n` transforms are selected, each one is applied **based on its
  822. individual probability** `p`.
  823. Args:
  824. transforms (list[BasicTransform | BaseCompose]): A list of transforms to choose from.
  825. n (int): The exact number of transforms to select and potentially apply.
  826. If `replace=False` and `n` is greater than the number of available transforms,
  827. `n` will be capped at the number of transforms.
  828. replace (bool): Whether to sample transforms with replacement. If True, the same
  829. transform can be selected multiple times (up to `n` times).
  830. Default is False.
  831. p (float): The probability that this `SomeOf` composition will be applied.
  832. If applied, it will select `n` transforms and attempt to apply them.
  833. Default is 1.0.
  834. Note:
  835. - The overall probability `p` of the `SomeOf` block determines if *any* selection
  836. and application occurs.
  837. - The individual probability `p` of each transform inside the list determines if
  838. that specific transform runs *if it is selected*.
  839. - If `replace` is True, the same transform might be selected multiple times, and
  840. its individual probability `p` will be checked each time it's encountered.
  841. Example:
  842. >>> import albumentations as A
  843. >>> transform = A.SomeOf([
  844. ... A.HorizontalFlip(p=0.5), # 50% chance to apply if selected
  845. ... A.VerticalFlip(p=0.8), # 80% chance to apply if selected
  846. ... A.RandomRotate90(p=1.0), # 100% chance to apply if selected
  847. ... ], n=2, replace=False, p=1.0) # Always select 2 transforms uniformly
  848. # In each call, 2 transforms out of 3 are chosen uniformly.
  849. # For example, if HFlip and VFlip are chosen:
  850. # - HFlip runs if random() < 0.5
  851. # - VFlip runs if random() < 0.8
  852. # If VFlip and Rotate90 are chosen:
  853. # - VFlip runs if random() < 0.8
  854. # - Rotate90 runs if random() < 1.0 (always)
  855. """
  856. def __init__(self, transforms: TransformsSeqType, n: int = 1, replace: bool = False, p: float = 1):
  857. super().__init__(transforms, p)
  858. self.n = n
  859. if not replace and n > len(self.transforms):
  860. self.n = len(self.transforms)
  861. warnings.warn(
  862. f"`n` is greater than number of transforms. `n` will be set to {self.n}.",
  863. UserWarning,
  864. stacklevel=2,
  865. )
  866. self.replace = replace
  867. def __call__(self, *arg: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
  868. """Apply n randomly selected transforms from the list of transforms.
  869. Args:
  870. *arg (Any): Positional arguments are not supported.
  871. force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
  872. **data (Any): Dict with data to transform.
  873. Returns:
  874. dict[str, Any]: Dictionary with transformed data.
  875. """
  876. if self.replay_mode:
  877. for t in self.transforms:
  878. data = t(**data)
  879. data = self.check_data_post_transform(data)
  880. return data
  881. if self.py_random.random() < self.p: # Check overall SomeOf probability
  882. # Get indices uniformly
  883. indices_to_consider = self._get_idx()
  884. for i in indices_to_consider:
  885. t = self.transforms[i]
  886. # Apply the transform respecting its own probability `t.p`
  887. data = t(**data)
  888. self._track_transform_params(t, data)
  889. data = self.check_data_post_transform(data)
  890. return data
  891. def _get_idx(self) -> np.ndarray[np.int_]:
  892. # Use uniform probability for selection, ignore individual p values here
  893. idx = self.random_generator.choice(
  894. len(self.transforms),
  895. size=self.n,
  896. replace=self.replace,
  897. )
  898. idx.sort()
  899. return idx
  900. def to_dict_private(self) -> dict[str, Any]:
  901. """Convert the SomeOf composition to a dictionary for serialization.
  902. Returns:
  903. dict[str, Any]: Dictionary representation of the composition.
  904. """
  905. dictionary = super().to_dict_private()
  906. dictionary.update({"n": self.n, "replace": self.replace})
  907. return dictionary
  908. class RandomOrder(SomeOf):
  909. """Apply a random subset of transforms from the given list in a random order.
  910. Selects exactly `n` transforms uniformly at random from the list, and then applies
  911. the selected transforms in a random order. Each selected transform is applied
  912. based on its individual probability `p`.
  913. Attributes:
  914. transforms (TransformsSeqType): A list of transformations to choose from.
  915. n (int): The number of transforms to apply. If `n` is greater than the number of available transforms
  916. and `replace` is False, `n` will be set to the number of available transforms.
  917. replace (bool): Whether to sample transforms with replacement. If True, the same transform can be
  918. selected multiple times. Default is False.
  919. p (float): Probability of applying the selected transforms. Should be in the range [0, 1]. Default is 1.0.
  920. Example:
  921. >>> import albumentations as A
  922. >>> transform = A.RandomOrder([
  923. ... A.HorizontalFlip(p=0.5),
  924. ... A.VerticalFlip(p=1.0),
  925. ... A.RandomBrightnessContrast(p=0.8),
  926. ... ], n=2, replace=False, p=1.0)
  927. >>> # This will uniformly select 2 transforms and apply them in a random order,
  928. >>> # respecting their individual probabilities (0.5, 1.0, 0.8).
  929. Note:
  930. - Inherits from SomeOf, but overrides `_get_idx` to ensure random order without sorting.
  931. - Selection is uniform; application depends on individual transform probabilities.
  932. """
  933. def __init__(self, transforms: TransformsSeqType, n: int = 1, replace: bool = False, p: float = 1):
  934. # Initialize using SomeOf's logic (which now does uniform selection setup)
  935. super().__init__(transforms=transforms, n=n, replace=replace, p=p)
  936. def _get_idx(self) -> np.ndarray[np.int_]:
  937. # Perform uniform random selection without replacement, like SomeOf
  938. # Crucially, DO NOT sort the indices here to maintain random order.
  939. return self.random_generator.choice(
  940. len(self.transforms),
  941. size=self.n,
  942. replace=self.replace,
  943. )
  944. class OneOrOther(BaseCompose):
  945. """Select one or another transform to apply. Selected transform will be called with `force_apply=True`."""
  946. def __init__(
  947. self,
  948. first: TransformType | None = None,
  949. second: TransformType | None = None,
  950. transforms: TransformsSeqType | None = None,
  951. p: float = 0.5,
  952. ):
  953. if transforms is None:
  954. if first is None or second is None:
  955. msg = "You must set both first and second or set transforms argument."
  956. raise ValueError(msg)
  957. transforms = [first, second]
  958. super().__init__(transforms, p)
  959. if len(self.transforms) != NUM_ONEOF_TRANSFORMS:
  960. warnings.warn("Length of transforms is not equal to 2.", stacklevel=2)
  961. def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
  962. """Apply one or another transform to the input data.
  963. Args:
  964. *args (Any): Positional arguments are not supported.
  965. force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
  966. **data (Any): Dict with data to transform.
  967. Returns:
  968. dict[str, Any]: Dictionary with transformed data.
  969. """
  970. if self.replay_mode:
  971. for t in self.transforms:
  972. data = t(**data)
  973. self._track_transform_params(t, data)
  974. return data
  975. if self.py_random.random() < self.p:
  976. return self.transforms[0](force_apply=True, **data)
  977. return self.transforms[-1](force_apply=True, **data)
  978. class SelectiveChannelTransform(BaseCompose):
  979. """A transformation class to apply specified transforms to selected channels of an image.
  980. This class extends BaseCompose to allow selective application of transformations to
  981. specified image channels. It extracts the selected channels, applies the transformations,
  982. and then reinserts the transformed channels back into their original positions in the image.
  983. Args:
  984. transforms (TransformsSeqType):
  985. A sequence of transformations (from Albumentations) to be applied to the specified channels.
  986. channels (Sequence[int]):
  987. A sequence of integers specifying the indices of the channels to which the transforms should be applied.
  988. p (float): Probability that the transform will be applied; the default is 1.0 (always apply).
  989. Returns:
  990. dict[str, Any]: The transformed data dictionary, which includes the transformed 'image' key.
  991. """
  992. def __init__(
  993. self,
  994. transforms: TransformsSeqType,
  995. channels: Sequence[int] = (0, 1, 2),
  996. p: float = 1.0,
  997. ) -> None:
  998. super().__init__(transforms, p)
  999. self.channels = channels
  1000. def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
  1001. """Apply transforms to specific channels of the image.
  1002. Args:
  1003. *args (Any): Positional arguments are not supported.
  1004. force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
  1005. **data (Any): Dict with data to transform.
  1006. Returns:
  1007. dict[str, Any]: Dictionary with transformed data.
  1008. """
  1009. if force_apply or self.py_random.random() < self.p:
  1010. image = data["image"]
  1011. selected_channels = image[:, :, self.channels]
  1012. sub_image = np.ascontiguousarray(selected_channels)
  1013. for t in self.transforms:
  1014. sub_image = t(image=sub_image)["image"]
  1015. self._track_transform_params(t, sub_image)
  1016. transformed_channels = cv2.split(sub_image)
  1017. output_img = image.copy()
  1018. for idx, channel in zip(self.channels, transformed_channels):
  1019. output_img[:, :, idx] = channel
  1020. data["image"] = np.ascontiguousarray(output_img)
  1021. return data
  1022. class ReplayCompose(Compose):
  1023. """Composition class that enables transform replay functionality.
  1024. This class extends the Compose class with the ability to record and replay
  1025. transformations. This is useful for applying the same sequence of random
  1026. transformations to different data.
  1027. Args:
  1028. transforms (TransformsSeqType): List of transformations to compose.
  1029. bbox_params (dict[str, Any] | BboxParams | None): Parameters for bounding box transforms.
  1030. keypoint_params (dict[str, Any] | KeypointParams | None): Parameters for keypoint transforms.
  1031. additional_targets (dict[str, str] | None): Dictionary of additional targets.
  1032. p (float): Probability of applying the compose.
  1033. is_check_shapes (bool): Whether to check shapes of different targets.
  1034. save_key (str): Key for storing the applied transformations.
  1035. """
  1036. def __init__(
  1037. self,
  1038. transforms: TransformsSeqType,
  1039. bbox_params: dict[str, Any] | BboxParams | None = None,
  1040. keypoint_params: dict[str, Any] | KeypointParams | None = None,
  1041. additional_targets: dict[str, str] | None = None,
  1042. p: float = 1.0,
  1043. is_check_shapes: bool = True,
  1044. save_key: str = "replay",
  1045. ):
  1046. super().__init__(transforms, bbox_params, keypoint_params, additional_targets, p, is_check_shapes)
  1047. self.set_deterministic(True, save_key=save_key)
  1048. self.save_key = save_key
  1049. self._available_keys.add(save_key)
  1050. def __call__(self, *args: Any, force_apply: bool = False, **kwargs: Any) -> dict[str, Any]:
  1051. """Apply transforms and record parameters for future replay.
  1052. Args:
  1053. *args (Any): Positional arguments are not supported.
  1054. force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
  1055. **kwargs (Any): Dict with data to transform.
  1056. Returns:
  1057. dict[str, Any]: Dictionary with transformed data and replay information.
  1058. """
  1059. kwargs[self.save_key] = defaultdict(dict)
  1060. result = super().__call__(force_apply=force_apply, **kwargs)
  1061. serialized = self.get_dict_with_id()
  1062. self.fill_with_params(serialized, result[self.save_key])
  1063. self.fill_applied(serialized)
  1064. result[self.save_key] = serialized
  1065. return result
  1066. @staticmethod
  1067. def replay(saved_augmentations: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
  1068. """Replay previously saved augmentations.
  1069. Args:
  1070. saved_augmentations (dict[str, Any]): Previously saved augmentation parameters.
  1071. **kwargs (Any): Dict with data to transform.
  1072. Returns:
  1073. dict[str, Any]: Dictionary with transformed data using saved parameters.
  1074. """
  1075. augs = ReplayCompose._restore_for_replay(saved_augmentations)
  1076. return augs(force_apply=True, **kwargs)
  1077. @staticmethod
  1078. def _restore_for_replay(
  1079. transform_dict: dict[str, Any],
  1080. lambda_transforms: dict[str, Any] | None = None,
  1081. ) -> TransformType:
  1082. """Args:
  1083. transform_dict (dict[str, Any]): A dictionary that contains transform data.
  1084. lambda_transforms (dict): A dictionary that contains lambda transforms, that
  1085. is instances of the Lambda class.
  1086. This dictionary is required when you are restoring a pipeline that contains lambda transforms.
  1087. Keys in that dictionary should be named same as `name` arguments in respective lambda transforms
  1088. from a serialized pipeline.
  1089. """
  1090. applied = transform_dict["applied"]
  1091. params = transform_dict["params"]
  1092. lmbd = instantiate_nonserializable(transform_dict, lambda_transforms)
  1093. if lmbd:
  1094. transform = lmbd
  1095. else:
  1096. name = transform_dict["__class_fullname__"]
  1097. args = {k: v for k, v in transform_dict.items() if k not in ["__class_fullname__", "applied", "params"]}
  1098. cls = SERIALIZABLE_REGISTRY[name]
  1099. if "transforms" in args:
  1100. args["transforms"] = [
  1101. ReplayCompose._restore_for_replay(t, lambda_transforms=lambda_transforms)
  1102. for t in args["transforms"]
  1103. ]
  1104. transform = cls(**args)
  1105. transform = cast("BasicTransform", transform)
  1106. if isinstance(transform, BasicTransform):
  1107. transform.params = params
  1108. transform.replay_mode = True
  1109. transform.applied_in_replay = applied
  1110. return transform
  1111. def fill_with_params(self, serialized: dict[str, Any], all_params: Any) -> None:
  1112. """Fill serialized transform data with parameters for replay.
  1113. Args:
  1114. serialized (dict[str, Any]): Serialized transform data.
  1115. all_params (Any): Parameters to fill in.
  1116. """
  1117. params = all_params.get(serialized.get("id"))
  1118. serialized["params"] = params
  1119. del serialized["id"]
  1120. for transform in serialized.get("transforms", []):
  1121. self.fill_with_params(transform, all_params)
  1122. def fill_applied(self, serialized: dict[str, Any]) -> bool:
  1123. """Set 'applied' flag for transforms based on parameters.
  1124. Args:
  1125. serialized (dict[str, Any]): Serialized transform data.
  1126. Returns:
  1127. bool: True if any transform was applied, False otherwise.
  1128. """
  1129. if "transforms" in serialized:
  1130. applied = [self.fill_applied(t) for t in serialized["transforms"]]
  1131. serialized["applied"] = any(applied)
  1132. else:
  1133. serialized["applied"] = serialized.get("params") is not None
  1134. return serialized["applied"]
  1135. def to_dict_private(self) -> dict[str, Any]:
  1136. """Convert the ReplayCompose to a dictionary for serialization.
  1137. Returns:
  1138. dict[str, Any]: Dictionary representation of the composition.
  1139. """
  1140. dictionary = super().to_dict_private()
  1141. dictionary.update({"save_key": self.save_key})
  1142. return dictionary
  1143. class Sequential(BaseCompose):
  1144. """Sequentially applies all transforms to targets.
  1145. Note:
  1146. This transform is not intended to be a replacement for `Compose`. Instead, it should be used inside `Compose`
  1147. the same way `OneOf` or `OneOrOther` are used. For instance, you can combine `OneOf` with `Sequential` to
  1148. create an augmentation pipeline that contains multiple sequences of augmentations and applies one randomly
  1149. chose sequence to input data (see the `Example` section for an example definition of such pipeline).
  1150. Example:
  1151. >>> import albumentations as A
  1152. >>> transform = A.Compose([
  1153. >>> A.OneOf([
  1154. >>> A.Sequential([
  1155. >>> A.HorizontalFlip(p=0.5),
  1156. >>> A.ShiftScaleRotate(p=0.5),
  1157. >>> ]),
  1158. >>> A.Sequential([
  1159. >>> A.VerticalFlip(p=0.5),
  1160. >>> A.RandomBrightnessContrast(p=0.5),
  1161. >>> ]),
  1162. >>> ], p=1)
  1163. >>> ])
  1164. """
  1165. def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
  1166. super().__init__(transforms=transforms, p=p)
  1167. def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
  1168. """Apply all transforms in sequential order.
  1169. Args:
  1170. *args (Any): Positional arguments are not supported.
  1171. force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
  1172. **data (Any): Dict with data to transform.
  1173. Returns:
  1174. dict[str, Any]: Dictionary with transformed data.
  1175. """
  1176. if self.replay_mode or force_apply or self.py_random.random() < self.p:
  1177. for t in self.transforms:
  1178. data = t(**data)
  1179. self._track_transform_params(t, data)
  1180. data = self.check_data_post_transform(data)
  1181. return data