lambda_transform.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. """Lambda transform module for creating custom user-defined transformations.
  2. This module provides a flexible transform class that allows users to define their own
  3. custom transformation functions for different targets (image, mask, keypoints, bboxes).
  4. It's particularly useful for implementing custom logic that isn't available in the
  5. standard transforms.
  6. The Lambda transform accepts different callable functions for each target type and
  7. applies them when the transform is executed. This allows for maximum flexibility
  8. while maintaining compatibility with the Albumentations pipeline structure.
  9. Key features:
  10. - Apply different custom functions to different target types
  11. - Compatible with all Albumentations pipeline features
  12. - Support for all image types and formats
  13. - Ability to handle any number of channels
  14. - Warning system for lambda expressions and multiprocessing compatibility
  15. Note that using actual lambda expressions (rather than named functions) can cause
  16. issues with multiprocessing, as lambdas cannot be properly pickled.
  17. """
  18. from __future__ import annotations
  19. import warnings
  20. from types import LambdaType
  21. from typing import Any, Callable
  22. import numpy as np
  23. from albumentations.augmentations.pixel import functional as fpixel
  24. from albumentations.core.transforms_interface import NoOp
  25. from albumentations.core.utils import format_args
  26. __all__ = ["Lambda"]
  27. class Lambda(NoOp):
  28. """A flexible transformation class for using user-defined transformation functions per targets.
  29. Function signature must include **kwargs to accept optional arguments like interpolation method, image size, etc:
  30. Args:
  31. image (Callable[..., Any] | None): Image transformation function.
  32. mask (Callable[..., Any] | None): Mask transformation function.
  33. keypoints (Callable[..., Any] | None): Keypoints transformation function.
  34. bboxes (Callable[..., Any] | None): BBoxes transformation function.
  35. p (float): probability of applying the transform. Default: 1.0.
  36. Targets:
  37. image, mask, bboxes, keypoints, volume, mask3d
  38. Image types:
  39. uint8, float32
  40. Number of channels:
  41. Any
  42. """
  43. def __init__(
  44. self,
  45. image: Callable[..., Any] | None = None,
  46. mask: Callable[..., Any] | None = None,
  47. keypoints: Callable[..., Any] | None = None,
  48. bboxes: Callable[..., Any] | None = None,
  49. name: str | None = None,
  50. p: float = 1.0,
  51. ):
  52. super().__init__(p=p)
  53. self.name = name
  54. self.custom_apply_fns = dict.fromkeys(("image", "mask", "keypoints", "bboxes"), fpixel.noop)
  55. for target_name, custom_apply_fn in {
  56. "image": image,
  57. "mask": mask,
  58. "keypoints": keypoints,
  59. "bboxes": bboxes,
  60. }.items():
  61. if custom_apply_fn is not None:
  62. if isinstance(custom_apply_fn, LambdaType) and custom_apply_fn.__name__ == "<lambda>":
  63. warnings.warn(
  64. "Using lambda is incompatible with multiprocessing. "
  65. "Consider using regular functions or partial().",
  66. stacklevel=2,
  67. )
  68. self.custom_apply_fns[target_name] = custom_apply_fn
  69. def apply(self, img: np.ndarray, **params: Any) -> np.ndarray:
  70. """Apply the Lambda transform to the input image.
  71. Args:
  72. img (np.ndarray): The input image to apply the Lambda transform to.
  73. **params (Any): Additional parameters (not used in this transform).
  74. Returns:
  75. np.ndarray: The image with the applied Lambda transform.
  76. """
  77. fn = self.custom_apply_fns["image"]
  78. return fn(img, **params)
  79. def apply_to_mask(self, mask: np.ndarray, **params: Any) -> np.ndarray:
  80. """Apply the Lambda transform to the input mask.
  81. Args:
  82. mask (np.ndarray): The input mask to apply the Lambda transform to.
  83. **params (Any): Additional parameters (not used in this transform).
  84. Returns:
  85. np.ndarray: The mask with the applied Lambda transform.
  86. """
  87. fn = self.custom_apply_fns["mask"]
  88. return fn(mask, **params)
  89. def apply_to_bboxes(self, bboxes: np.ndarray, **params: Any) -> np.ndarray:
  90. """Apply the Lambda transform to the input bounding boxes.
  91. Args:
  92. bboxes (np.ndarray): The input bounding boxes to apply the Lambda transform to.
  93. **params (Any): Additional parameters (not used in this transform).
  94. Returns:
  95. np.ndarray: The bounding boxes with the applied Lambda transform.
  96. """
  97. fn = self.custom_apply_fns["bboxes"]
  98. return fn(bboxes, **params)
  99. def apply_to_keypoints(self, keypoints: np.ndarray, **params: Any) -> np.ndarray:
  100. """Apply the Lambda transform to the input keypoints.
  101. Args:
  102. keypoints (np.ndarray): The input keypoints to apply the Lambda transform to.
  103. **params (Any): Additional parameters (not used in this transform).
  104. Returns:
  105. np.ndarray: The keypoints with the applied Lambda transform.
  106. """
  107. fn = self.custom_apply_fns["keypoints"]
  108. return fn(keypoints, **params)
  109. @classmethod
  110. def is_serializable(cls) -> bool:
  111. """Check if the Lambda transform is serializable.
  112. Returns:
  113. bool: True if the transform is serializable, False otherwise.
  114. """
  115. return False
  116. def to_dict_private(self) -> dict[str, Any]:
  117. """Convert the Lambda transform to a dictionary.
  118. Returns:
  119. dict[str, Any]: The dictionary representation of the transform.
  120. """
  121. if self.name is None:
  122. msg = (
  123. "To make a Lambda transform serializable you should provide the `name` argument, "
  124. "e.g. `Lambda(name='my_transform', image=<some func>, ...)`."
  125. )
  126. raise ValueError(msg)
  127. return {"__class_fullname__": self.get_class_fullname(), "__name__": self.name}
  128. def __repr__(self) -> str:
  129. """Return the string representation of the Lambda transform.
  130. Returns:
  131. str: The string representation of the Lambda transform.
  132. """
  133. state = {"name": self.name}
  134. state.update(self.custom_apply_fns.items()) # type: ignore[arg-type]
  135. state.update(self.get_base_init_args())
  136. return f"{self.__class__.__name__}({format_args(state)})"