transforms.py 74 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962
  1. """Geometric transformation classes for image augmentation.
  2. This module provides a collection of transforms that modify the geometric properties
  3. of images and associated data (masks, bounding boxes, keypoints). Includes implementations
  4. for flipping, transposing, affine transformations, distortions, padding, and more complex
  5. transformations like grid shuffling and thin plate splines.
  6. """
  7. from __future__ import annotations
  8. import random
  9. from typing import Annotated, Any, Literal, cast
  10. from warnings import warn
  11. import cv2
  12. import numpy as np
  13. from albucore import batch_transform, is_grayscale_image, is_rgb_image
  14. from pydantic import (
  15. AfterValidator,
  16. Field,
  17. ValidationInfo,
  18. field_validator,
  19. model_validator,
  20. )
  21. from typing_extensions import Self
  22. from albumentations.augmentations.utils import check_range
  23. from albumentations.core.bbox_utils import (
  24. BboxProcessor,
  25. denormalize_bboxes,
  26. normalize_bboxes,
  27. )
  28. from albumentations.core.pydantic import (
  29. NonNegativeFloatRangeType,
  30. OnePlusIntRangeType,
  31. SymmetricRangeType,
  32. check_range_bounds,
  33. )
  34. from albumentations.core.transforms_interface import (
  35. BaseTransformInitSchema,
  36. DualTransform,
  37. )
  38. from albumentations.core.type_definitions import ALL_TARGETS
  39. from albumentations.core.utils import to_tuple
  40. from . import functional as fgeometric
  41. __all__ = [
  42. "Affine",
  43. "GridElasticDeform",
  44. "Morphological",
  45. "Perspective",
  46. "RandomGridShuffle",
  47. "ShiftScaleRotate",
  48. ]
  49. NUM_PADS_XY = 2
  50. NUM_PADS_ALL_SIDES = 4
  51. class Perspective(DualTransform):
  52. """Apply random four point perspective transformation to the input.
  53. Args:
  54. scale (float or tuple of float): Standard deviation of the normal distributions. These are used to sample
  55. the random distances of the subimage's corners from the full image's corners.
  56. If scale is a single float value, the range will be (0, scale).
  57. Default: (0.05, 0.1).
  58. keep_size (bool): Whether to resize image back to its original size after applying the perspective transform.
  59. If set to False, the resulting images may end up having different shapes.
  60. Default: True.
  61. border_mode (OpenCV flag): OpenCV border mode used for padding.
  62. Default: cv2.BORDER_CONSTANT.
  63. fill (tuple[float, ...] | float): Padding value if border_mode is cv2.BORDER_CONSTANT.
  64. Default: 0.
  65. fill_mask (tuple[float, ...] | float): Padding value for mask if border_mode is
  66. cv2.BORDER_CONSTANT. Default: 0.
  67. fit_output (bool): If True, the image plane size and position will be adjusted to still capture
  68. the whole image after perspective transformation. This is followed by image resizing if keep_size is set
  69. to True. If False, parts of the transformed image may be outside of the image plane.
  70. This setting should not be set to True when using large scale values as it could lead to very large images.
  71. Default: False.
  72. interpolation (int): Interpolation method to be used for image transformation. Should be one
  73. of the OpenCV interpolation types. Default: cv2.INTER_LINEAR
  74. mask_interpolation (int): Flag that is used to specify the interpolation algorithm for mask.
  75. Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
  76. Default: cv2.INTER_NEAREST.
  77. p (float): Probability of applying the transform. Default: 0.5.
  78. Targets:
  79. image, mask, keypoints, bboxes, volume, mask3d
  80. Image types:
  81. uint8, float32
  82. Note:
  83. This transformation creates a perspective effect by randomly moving the four corners of the image.
  84. The amount of movement is controlled by the 'scale' parameter.
  85. When 'keep_size' is True, the output image will have the same size as the input image,
  86. which may cause some parts of the transformed image to be cut off or padded.
  87. When 'fit_output' is True, the transformation ensures that the entire transformed image is visible,
  88. which may result in a larger output image if keep_size is False.
  89. Examples:
  90. >>> import numpy as np
  91. >>> import albumentations as A
  92. >>> import cv2
  93. >>>
  94. >>> # Prepare sample data
  95. >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
  96. >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
  97. >>> bboxes = np.array([[10, 10, 50, 50], [40, 40, 80, 80]], dtype=np.float32)
  98. >>> bbox_labels = [1, 2]
  99. >>> keypoints = np.array([[20, 30], [60, 70]], dtype=np.float32)
  100. >>> keypoint_labels = [0, 1]
  101. >>>
  102. >>> # Define transform with parameters as tuples when possible
  103. >>> transform = A.Compose([
  104. ... A.Perspective(
  105. ... scale=(0.05, 0.1),
  106. ... keep_size=True,
  107. ... fit_output=False,
  108. ... border_mode=cv2.BORDER_CONSTANT,
  109. ... p=1.0
  110. ... ),
  111. ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
  112. ... keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
  113. >>>
  114. >>> # Apply the transform
  115. >>> transformed = transform(
  116. ... image=image,
  117. ... mask=mask,
  118. ... bboxes=bboxes,
  119. ... bbox_labels=bbox_labels,
  120. ... keypoints=keypoints,
  121. ... keypoint_labels=keypoint_labels
  122. ... )
  123. >>>
  124. >>> # Get the transformed data
  125. >>> transformed_image = transformed['image'] # Perspective-transformed image
  126. >>> transformed_mask = transformed['mask'] # Perspective-transformed mask
  127. >>> transformed_bboxes = transformed['bboxes'] # Perspective-transformed bounding boxes
  128. >>> transformed_bbox_labels = transformed['bbox_labels'] # Labels for transformed bboxes
  129. >>> transformed_keypoints = transformed['keypoints'] # Perspective-transformed keypoints
  130. >>> transformed_keypoint_labels = transformed['keypoint_labels'] # Labels for transformed keypoints
  131. """
  132. _targets = ALL_TARGETS
  133. class InitSchema(BaseTransformInitSchema):
  134. scale: NonNegativeFloatRangeType
  135. keep_size: bool
  136. fit_output: bool
  137. interpolation: Literal[
  138. cv2.INTER_NEAREST,
  139. cv2.INTER_LINEAR,
  140. cv2.INTER_CUBIC,
  141. cv2.INTER_AREA,
  142. cv2.INTER_LANCZOS4,
  143. ]
  144. mask_interpolation: Literal[
  145. cv2.INTER_NEAREST,
  146. cv2.INTER_LINEAR,
  147. cv2.INTER_CUBIC,
  148. cv2.INTER_AREA,
  149. cv2.INTER_LANCZOS4,
  150. ]
  151. fill: tuple[float, ...] | float
  152. fill_mask: tuple[float, ...] | float
  153. border_mode: Literal[
  154. cv2.BORDER_CONSTANT,
  155. cv2.BORDER_REPLICATE,
  156. cv2.BORDER_REFLECT,
  157. cv2.BORDER_WRAP,
  158. cv2.BORDER_REFLECT_101,
  159. ]
  160. def __init__(
  161. self,
  162. scale: tuple[float, float] | float = (0.05, 0.1),
  163. keep_size: bool = True,
  164. fit_output: bool = False,
  165. interpolation: Literal[
  166. cv2.INTER_NEAREST,
  167. cv2.INTER_LINEAR,
  168. cv2.INTER_CUBIC,
  169. cv2.INTER_AREA,
  170. cv2.INTER_LANCZOS4,
  171. ] = cv2.INTER_LINEAR,
  172. mask_interpolation: Literal[
  173. cv2.INTER_NEAREST,
  174. cv2.INTER_LINEAR,
  175. cv2.INTER_CUBIC,
  176. cv2.INTER_AREA,
  177. cv2.INTER_LANCZOS4,
  178. ] = cv2.INTER_NEAREST,
  179. border_mode: Literal[
  180. cv2.BORDER_CONSTANT,
  181. cv2.BORDER_REPLICATE,
  182. cv2.BORDER_REFLECT,
  183. cv2.BORDER_WRAP,
  184. cv2.BORDER_REFLECT_101,
  185. ] = cv2.BORDER_CONSTANT,
  186. fill: tuple[float, ...] | float = 0,
  187. fill_mask: tuple[float, ...] | float = 0,
  188. p: float = 0.5,
  189. ):
  190. super().__init__(p)
  191. self.scale = cast("tuple[float, float]", scale)
  192. self.keep_size = keep_size
  193. self.border_mode = border_mode
  194. self.fill = fill
  195. self.fill_mask = fill_mask
  196. self.fit_output = fit_output
  197. self.interpolation = interpolation
  198. self.mask_interpolation = mask_interpolation
  199. def apply(
  200. self,
  201. img: np.ndarray,
  202. matrix: np.ndarray,
  203. max_height: int,
  204. max_width: int,
  205. **params: Any,
  206. ) -> np.ndarray:
  207. """Apply the perspective transform to an image.
  208. Args:
  209. img (np.ndarray): Image to be distorted.
  210. matrix (np.ndarray): Transformation matrix.
  211. max_height (int): Maximum height of the image.
  212. max_width (int): Maximum width of the image.
  213. **params (Any): Additional parameters.
  214. Returns:
  215. np.ndarray: Distorted image.
  216. """
  217. return fgeometric.perspective(
  218. img,
  219. matrix,
  220. max_width,
  221. max_height,
  222. self.fill,
  223. self.border_mode,
  224. self.keep_size,
  225. self.interpolation,
  226. )
  227. @batch_transform("spatial", has_batch_dim=True, has_depth_dim=False)
  228. def apply_to_images(self, images: np.ndarray, **params: Any) -> np.ndarray:
  229. """Apply the perspective transform to a batch of images.
  230. Args:
  231. images (np.ndarray): Batch of images to be distorted.
  232. **params (Any): Additional parameters.
  233. Returns:
  234. np.ndarray: Batch of distorted images.
  235. """
  236. return self.apply(images, **params)
  237. @batch_transform("spatial", has_batch_dim=False, has_depth_dim=True)
  238. def apply_to_volume(self, volume: np.ndarray, **params: Any) -> np.ndarray:
  239. """Apply the perspective transform to a volume.
  240. Args:
  241. volume (np.ndarray): Volume to be distorted.
  242. **params (Any): Additional parameters.
  243. Returns:
  244. np.ndarray: Distorted volume.
  245. """
  246. return self.apply(volume, **params)
  247. @batch_transform("spatial", has_batch_dim=True, has_depth_dim=True)
  248. def apply_to_volumes(self, volumes: np.ndarray, **params: Any) -> np.ndarray:
  249. """Apply the perspective transform to a batch of volumes.
  250. Args:
  251. volumes (np.ndarray): Batch of volumes to be distorted.
  252. **params (Any): Additional parameters.
  253. Returns:
  254. np.ndarray: Batch of distorted volumes.
  255. """
  256. return self.apply(volumes, **params)
  257. @batch_transform("spatial", has_batch_dim=True, has_depth_dim=False)
  258. def apply_to_mask3d(self, mask3d: np.ndarray, **params: Any) -> np.ndarray:
  259. """Apply the perspective transform to a 3D mask.
  260. Args:
  261. mask3d (np.ndarray): 3D mask to be distorted.
  262. **params (Any): Additional parameters.
  263. Returns:
  264. np.ndarray: Distorted 3D mask.
  265. """
  266. return self.apply_to_mask(mask3d, **params)
  267. def apply_to_mask(
  268. self,
  269. mask: np.ndarray,
  270. matrix: np.ndarray,
  271. max_height: int,
  272. max_width: int,
  273. **params: Any,
  274. ) -> np.ndarray:
  275. """Apply the perspective transform to a mask.
  276. Args:
  277. mask (np.ndarray): Mask to be distorted.
  278. matrix (np.ndarray): Transformation matrix.
  279. max_height (int): Maximum height of the mask.
  280. max_width (int): Maximum width of the mask.
  281. **params (Any): Additional parameters.
  282. Returns:
  283. np.ndarray: Distorted mask.
  284. """
  285. return fgeometric.perspective(
  286. mask,
  287. matrix,
  288. max_width,
  289. max_height,
  290. self.fill_mask,
  291. self.border_mode,
  292. self.keep_size,
  293. self.mask_interpolation,
  294. )
  295. def apply_to_bboxes(
  296. self,
  297. bboxes: np.ndarray,
  298. matrix_bbox: np.ndarray,
  299. max_height: int,
  300. max_width: int,
  301. **params: Any,
  302. ) -> np.ndarray:
  303. """Apply the perspective transform to a batch of bounding boxes.
  304. Args:
  305. bboxes (np.ndarray): Batch of bounding boxes to be distorted.
  306. matrix_bbox (np.ndarray): Transformation matrix.
  307. max_height (int): Maximum height of the bounding boxes.
  308. max_width (int): Maximum width of the bounding boxes.
  309. **params (Any): Additional parameters.
  310. Returns:
  311. np.ndarray: Batch of distorted bounding boxes.
  312. """
  313. return fgeometric.perspective_bboxes(
  314. bboxes,
  315. params["shape"],
  316. matrix_bbox,
  317. max_width,
  318. max_height,
  319. self.keep_size,
  320. )
  321. def apply_to_keypoints(
  322. self,
  323. keypoints: np.ndarray,
  324. matrix: np.ndarray,
  325. max_height: int,
  326. max_width: int,
  327. **params: Any,
  328. ) -> np.ndarray:
  329. """Apply the perspective transform to a batch of keypoints.
  330. Args:
  331. keypoints (np.ndarray): Batch of keypoints to be distorted.
  332. matrix (np.ndarray): Transformation matrix.
  333. max_height (int): Maximum height of the keypoints.
  334. max_width (int): Maximum width of the keypoints.
  335. **params (Any): Additional parameters.
  336. Returns:
  337. np.ndarray: Batch of distorted keypoints.
  338. """
  339. return fgeometric.perspective_keypoints(
  340. keypoints,
  341. params["shape"],
  342. matrix,
  343. max_width,
  344. max_height,
  345. self.keep_size,
  346. )
  347. def get_params_dependent_on_data(
  348. self,
  349. params: dict[str, Any],
  350. data: dict[str, Any],
  351. ) -> dict[str, Any]:
  352. """Get the parameters dependent on the data.
  353. Args:
  354. params (dict[str, Any]): Parameters.
  355. data (dict[str, Any]): Data.
  356. Returns:
  357. dict[str, Any]: Parameters.
  358. """
  359. image_shape = params["shape"][:2]
  360. scale = self.py_random.uniform(*self.scale)
  361. points = fgeometric.generate_perspective_points(
  362. image_shape,
  363. scale,
  364. self.random_generator,
  365. )
  366. points = fgeometric.order_points(points)
  367. matrix, max_width, max_height = fgeometric.compute_perspective_params(
  368. points,
  369. image_shape,
  370. )
  371. if self.fit_output:
  372. matrix, max_width, max_height = fgeometric.expand_transform(
  373. matrix,
  374. image_shape,
  375. )
  376. return {
  377. "matrix": matrix,
  378. "max_height": max_height,
  379. "max_width": max_width,
  380. "matrix_bbox": matrix,
  381. }
  382. class Affine(DualTransform):
  383. """Augmentation to apply affine transformations to images.
  384. Affine transformations involve:
  385. - Translation ("move" image on the x-/y-axis)
  386. - Rotation
  387. - Scaling ("zoom" in/out)
  388. - Shear (move one side of the image, turning a square into a trapezoid)
  389. All such transformations can create "new" pixels in the image without a defined content, e.g.
  390. if the image is translated to the left, pixels are created on the right.
  391. A method has to be defined to deal with these pixel values.
  392. The parameters `fill` and `fill_mask` of this class deal with this.
  393. Some transformations involve interpolations between several pixels
  394. of the input image to generate output pixel values. The parameters `interpolation` and
  395. `mask_interpolation` deal with the method of interpolation used for this.
  396. Args:
  397. scale (number, tuple of number or dict): Scaling factor to use, where ``1.0`` denotes "no change" and
  398. ``0.5`` is zoomed out to ``50`` percent of the original size.
  399. * If a single number, then that value will be used for all images.
  400. * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]``.
  401. The same range will be used for both x- and y-axis. To keep the aspect ratio, set
  402. ``keep_ratio=True``, then the same value will be used for both x- and y-axis.
  403. * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
  404. Each of these keys can have the same values as described above.
  405. Using a dictionary allows to set different values for the two axis and sampling will then happen
  406. *independently* per axis, resulting in samples that differ between the axes. Note that when
  407. the ``keep_ratio=True``, the x- and y-axis ranges should be the same.
  408. translate_percent (None, number, tuple of number or dict): Translation as a fraction of the image height/width
  409. (x-translation, y-translation), where ``0`` denotes "no change"
  410. and ``0.5`` denotes "half of the axis size".
  411. * If ``None`` then equivalent to ``0.0`` unless `translate_px` has a value other than ``None``.
  412. * If a single number, then that value will be used for all images.
  413. * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]``.
  414. That sampled fraction value will be used identically for both x- and y-axis.
  415. * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
  416. Each of these keys can have the same values as described above.
  417. Using a dictionary allows to set different values for the two axis and sampling will then happen
  418. *independently* per axis, resulting in samples that differ between the axes.
  419. translate_px (None, int, tuple of int or dict): Translation in pixels.
  420. * If ``None`` then equivalent to ``0`` unless `translate_percent` has a value other than ``None``.
  421. * If a single int, then that value will be used for all images.
  422. * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from
  423. the discrete interval ``[a..b]``. That number will be used identically for both x- and y-axis.
  424. * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
  425. Each of these keys can have the same values as described above.
  426. Using a dictionary allows to set different values for the two axis and sampling will then happen
  427. *independently* per axis, resulting in samples that differ between the axes.
  428. rotate (number or tuple of number): Rotation in degrees (**NOT** radians), i.e. expected value range is
  429. around ``[-360, 360]``. Rotation happens around the *center* of the image,
  430. not the top left corner as in some other frameworks.
  431. * If a number, then that value will be used for all images.
  432. * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]``
  433. and used as the rotation value.
  434. shear (number, tuple of number or dict): Shear in degrees (**NOT** radians), i.e. expected value range is
  435. around ``[-360, 360]``, with reasonable values being in the range of ``[-45, 45]``.
  436. * If a number, then that value will be used for all images as
  437. the shear on the x-axis (no shear on the y-axis will be done).
  438. * If a tuple ``(a, b)``, then two values will be uniformly sampled per image
  439. from the interval ``[a, b]`` and be used as the x- and y-shear value.
  440. * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
  441. Each of these keys can have the same values as described above.
  442. Using a dictionary allows to set different values for the two axis and sampling will then happen
  443. *independently* per axis, resulting in samples that differ between the axes.
  444. interpolation (int): OpenCV interpolation flag.
  445. mask_interpolation (int): OpenCV interpolation flag.
  446. fill (tuple[float, ...] | float): The constant value to use when filling in newly created pixels.
  447. (E.g. translating by 1px to the right will create a new 1px-wide column of pixels
  448. on the left of the image).
  449. The value is only used when `border_mode=cv2.BORDER_CONSTANT`. The expected value range is ``[0, 255]`` for ``uint8`` images.
  450. fill_mask (tuple[float, ...] | float): Same as fill but only for masks.
  451. border_mode (int): OpenCV border flag.
  452. fit_output (bool): If True, the image plane size and position will be adjusted to tightly capture
  453. the whole image after affine transformation (`translate_percent` and `translate_px` are ignored).
  454. Otherwise (``False``), parts of the transformed image may end up outside the image plane.
  455. Fitting the output shape can be useful to avoid corners of the image being outside the image plane
  456. after applying rotations. Default: False
  457. keep_ratio (bool): When True, the original aspect ratio will be kept when the random scale is applied.
  458. Default: False.
  459. rotate_method (Literal["largest_box", "ellipse"]): rotation method used for the bounding boxes.
  460. Should be one of "largest_box" or "ellipse"[1]. Default: "largest_box"
  461. balanced_scale (bool): When True, scaling factors are chosen to be either entirely below or above 1,
  462. ensuring balanced scaling. Default: False.
  463. This is important because without it, scaling tends to lean towards upscaling. For example, if we want
  464. the image to zoom in and out by 2x, we may pick an interval [0.5, 2]. Since the interval [0.5, 1] is
  465. three times smaller than [1, 2], values above 1 are picked three times more often if sampled directly
  466. from [0.5, 2]. With `balanced_scale`, the function ensures that half the time, the scaling
  467. factor is picked from below 1 (zooming out), and the other half from above 1 (zooming in).
  468. This makes the zooming in and out process more balanced.
  469. p (float): probability of applying the transform. Default: 0.5.
  470. Targets:
  471. image, mask, keypoints, bboxes, volume, mask3d
  472. Image types:
  473. uint8, float32
  474. References:
  475. Towards Rotation Invariance in Object Detection: https://arxiv.org/abs/2109.13488
  476. Examples:
  477. >>> import numpy as np
  478. >>> import albumentations as A
  479. >>> import cv2
  480. >>>
  481. >>> # Prepare sample data
  482. >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
  483. >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
  484. >>> bboxes = np.array([[10, 10, 50, 50], [40, 40, 80, 80]], dtype=np.float32)
  485. >>> bbox_labels = [1, 2]
  486. >>> keypoints = np.array([[20, 30], [60, 70]], dtype=np.float32)
  487. >>> keypoint_labels = [0, 1]
  488. >>>
  489. >>> # Define transform with different parameter types
  490. >>> transform = A.Compose([
  491. ... A.Affine(
  492. ... # Tuple for scale (will be used for both x and y)
  493. ... scale=(0.8, 1.2),
  494. ... # Dictionary with tuples for different x/y translations
  495. ... translate_percent={"x": (-0.2, 0.2), "y": (-0.1, 0.1)},
  496. ... # Tuple for rotation range
  497. ... rotate=(-30, 30),
  498. ... # Dictionary with tuples for different x/y shearing
  499. ... shear={"x": (-10, 10), "y": (-5, 5)},
  500. ... # Interpolation methods
  501. ... interpolation=cv2.INTER_LINEAR,
  502. ... mask_interpolation=cv2.INTER_NEAREST,
  503. ... # Other parameters
  504. ... fit_output=False,
  505. ... keep_ratio=True,
  506. ... rotate_method="largest_box",
  507. ... balanced_scale=True,
  508. ... border_mode=cv2.BORDER_CONSTANT,
  509. ... fill=0,
  510. ... fill_mask=0,
  511. ... p=1.0
  512. ... ),
  513. ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
  514. ... keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
  515. >>>
  516. >>> # Apply the transform
  517. >>> transformed = transform(
  518. ... image=image,
  519. ... mask=mask,
  520. ... bboxes=bboxes,
  521. ... bbox_labels=bbox_labels,
  522. ... keypoints=keypoints,
  523. ... keypoint_labels=keypoint_labels
  524. ... )
  525. >>>
  526. >>> # Get the transformed data
  527. >>> transformed_image = transformed['image'] # Image with affine transforms applied
  528. >>> transformed_mask = transformed['mask'] # Mask with affine transforms applied
  529. >>> transformed_bboxes = transformed['bboxes'] # Bounding boxes with affine transforms applied
  530. >>> transformed_bbox_labels = transformed['bbox_labels'] # Labels for transformed bboxes
  531. >>> transformed_keypoints = transformed['keypoints'] # Keypoints with affine transforms applied
  532. >>> transformed_keypoint_labels = transformed['keypoint_labels'] # Labels for transformed keypoints
  533. >>>
  534. >>> # Simpler example with only essential parameters
  535. >>> simple_transform = A.Compose([
  536. ... A.Affine(
  537. ... scale=1.1, # Single scalar value for scale
  538. ... rotate=15, # Single scalar value for rotation (degrees)
  539. ... translate_px=30, # Single scalar value for translation (pixels)
  540. ... p=1.0
  541. ... ),
  542. ... ])
  543. >>> simple_result = simple_transform(image=image)
  544. >>> simple_transformed = simple_result['image']
  545. """
  546. _targets = ALL_TARGETS
  547. class InitSchema(BaseTransformInitSchema):
  548. scale: tuple[float, float] | float | dict[str, float | tuple[float, float]]
  549. translate_percent: tuple[float, float] | float | dict[str, float | tuple[float, float]] | None
  550. translate_px: tuple[float, float] | float | dict[str, float | tuple[float, float]] | None
  551. rotate: tuple[float, float] | float
  552. shear: tuple[float, float] | float | dict[str, float | tuple[float, float]]
  553. interpolation: Literal[
  554. cv2.INTER_NEAREST,
  555. cv2.INTER_LINEAR,
  556. cv2.INTER_CUBIC,
  557. cv2.INTER_AREA,
  558. cv2.INTER_LANCZOS4,
  559. ]
  560. mask_interpolation: Literal[
  561. cv2.INTER_NEAREST,
  562. cv2.INTER_LINEAR,
  563. cv2.INTER_CUBIC,
  564. cv2.INTER_AREA,
  565. cv2.INTER_LANCZOS4,
  566. ]
  567. fill: tuple[float, ...] | float
  568. fill_mask: tuple[float, ...] | float
  569. border_mode: Literal[
  570. cv2.BORDER_CONSTANT,
  571. cv2.BORDER_REPLICATE,
  572. cv2.BORDER_REFLECT,
  573. cv2.BORDER_WRAP,
  574. cv2.BORDER_REFLECT_101,
  575. ]
  576. fit_output: bool
  577. keep_ratio: bool
  578. rotate_method: Literal["largest_box", "ellipse"]
  579. balanced_scale: bool
  580. @field_validator("shear", "scale")
  581. @classmethod
  582. def _process_shear(
  583. cls,
  584. value: tuple[float, float] | float | dict[str, float | tuple[float, float]],
  585. info: ValidationInfo,
  586. ) -> dict[str, tuple[float, float]]:
  587. return cls._handle_dict_arg(value, info.field_name)
  588. @field_validator("rotate")
  589. @classmethod
  590. def _process_rotate(
  591. cls,
  592. value: tuple[float, float] | float,
  593. ) -> tuple[float, float]:
  594. return to_tuple(value, value)
  595. @model_validator(mode="after")
  596. def _handle_translate(self) -> Self:
  597. if self.translate_percent is None and self.translate_px is None:
  598. self.translate_px = 0
  599. if self.translate_percent is not None and self.translate_px is not None:
  600. msg = "Expected either translate_percent or translate_px to be provided, but both were provided."
  601. raise ValueError(msg)
  602. if self.translate_percent is not None:
  603. self.translate_percent = self._handle_dict_arg(
  604. self.translate_percent,
  605. "translate_percent",
  606. default=0.0,
  607. ) # type: ignore[assignment]
  608. if self.translate_px is not None:
  609. self.translate_px = self._handle_dict_arg(
  610. self.translate_px,
  611. "translate_px",
  612. default=0,
  613. ) # type: ignore[assignment]
  614. return self
  615. @staticmethod
  616. def _handle_dict_arg(
  617. val: tuple[float, float]
  618. | dict[str, float | tuple[float, float]]
  619. | float
  620. | tuple[int, int]
  621. | dict[str, int | tuple[int, int]],
  622. name: str | None,
  623. default: float = 1.0,
  624. ) -> dict[str, tuple[float, float]]:
  625. if isinstance(val, float):
  626. return {"x": (val, val), "y": (val, val)}
  627. if isinstance(val, dict):
  628. if "x" not in val and "y" not in val:
  629. raise ValueError(
  630. f'Expected {name} dictionary to contain at least key "x" or key "y". Found neither of them.',
  631. )
  632. x = val.get("x", default)
  633. y = val.get("y", default)
  634. return {"x": to_tuple(x, x), "y": to_tuple(y, y)}
  635. return {"x": to_tuple(val, val), "y": to_tuple(val, val)}
  636. def __init__(
  637. self,
  638. scale: tuple[float, float] | float | dict[str, float | tuple[float, float]] = (1.0, 1.0),
  639. translate_percent: tuple[float, float] | float | dict[str, float | tuple[float, float]] | None = None,
  640. translate_px: tuple[int, int] | int | dict[str, int | tuple[int, int]] | None = None,
  641. rotate: tuple[float, float] | float = 0.0,
  642. shear: tuple[float, float] | float | dict[str, float | tuple[float, float]] = (0.0, 0.0),
  643. interpolation: Literal[
  644. cv2.INTER_NEAREST,
  645. cv2.INTER_LINEAR,
  646. cv2.INTER_CUBIC,
  647. cv2.INTER_AREA,
  648. cv2.INTER_LANCZOS4,
  649. ] = cv2.INTER_LINEAR,
  650. mask_interpolation: Literal[
  651. cv2.INTER_NEAREST,
  652. cv2.INTER_LINEAR,
  653. cv2.INTER_CUBIC,
  654. cv2.INTER_AREA,
  655. cv2.INTER_LANCZOS4,
  656. ] = cv2.INTER_NEAREST,
  657. fit_output: bool = False,
  658. keep_ratio: bool = False,
  659. rotate_method: Literal["largest_box", "ellipse"] = "largest_box",
  660. balanced_scale: bool = False,
  661. border_mode: Literal[
  662. cv2.BORDER_CONSTANT,
  663. cv2.BORDER_REPLICATE,
  664. cv2.BORDER_REFLECT,
  665. cv2.BORDER_WRAP,
  666. cv2.BORDER_REFLECT_101,
  667. ] = cv2.BORDER_CONSTANT,
  668. fill: tuple[float, ...] | float = 0,
  669. fill_mask: tuple[float, ...] | float = 0,
  670. p: float = 0.5,
  671. ):
  672. super().__init__(p=p)
  673. self.interpolation = interpolation
  674. self.mask_interpolation = mask_interpolation
  675. self.fill = fill
  676. self.fill_mask = fill_mask
  677. self.border_mode = border_mode
  678. self.scale = cast("dict[str, tuple[float, float]]", scale)
  679. self.translate_percent = cast("dict[str, tuple[float, float]]", translate_percent)
  680. self.translate_px = cast("dict[str, tuple[int, int]]", translate_px)
  681. self.rotate = cast("tuple[float, float]", rotate)
  682. self.fit_output = fit_output
  683. self.shear = cast("dict[str, tuple[float, float]]", shear)
  684. self.keep_ratio = keep_ratio
  685. self.rotate_method = rotate_method
  686. self.balanced_scale = balanced_scale
  687. if self.keep_ratio and self.scale["x"] != self.scale["y"]:
  688. raise ValueError(
  689. f"When keep_ratio is True, the x and y scale range should be identical. got {self.scale}",
  690. )
  691. def apply(
  692. self,
  693. img: np.ndarray,
  694. matrix: np.ndarray,
  695. output_shape: tuple[int, int],
  696. **params: Any,
  697. ) -> np.ndarray:
  698. """Apply the affine transform to an image.
  699. Args:
  700. img (np.ndarray): Image to be distorted.
  701. matrix (np.ndarray): Transformation matrix.
  702. output_shape (tuple[int, int]): Output shape.
  703. **params (Any): Additional parameters.
  704. Returns:
  705. np.ndarray: Distorted image.
  706. """
  707. return fgeometric.warp_affine(
  708. img,
  709. matrix,
  710. interpolation=self.interpolation,
  711. fill=self.fill,
  712. border_mode=self.border_mode,
  713. output_shape=output_shape,
  714. )
  715. def apply_to_mask(
  716. self,
  717. mask: np.ndarray,
  718. matrix: np.ndarray,
  719. output_shape: tuple[int, int],
  720. **params: Any,
  721. ) -> np.ndarray:
  722. """Apply the affine transform to a mask.
  723. Args:
  724. mask (np.ndarray): Mask to be distorted.
  725. matrix (np.ndarray): Transformation matrix.
  726. output_shape (tuple[int, int]): Output shape.
  727. **params (Any): Additional parameters.
  728. Returns:
  729. np.ndarray: Distorted mask.
  730. """
  731. return fgeometric.warp_affine(
  732. mask,
  733. matrix,
  734. interpolation=self.mask_interpolation,
  735. fill=self.fill_mask,
  736. border_mode=self.border_mode,
  737. output_shape=output_shape,
  738. )
  739. def apply_to_bboxes(
  740. self,
  741. bboxes: np.ndarray,
  742. bbox_matrix: np.ndarray,
  743. output_shape: tuple[int, int],
  744. **params: Any,
  745. ) -> np.ndarray:
  746. """Apply the affine transform to bounding boxes.
  747. Args:
  748. bboxes (np.ndarray): Bounding boxes to be distorted.
  749. bbox_matrix (np.ndarray): Transformation matrix.
  750. output_shape (tuple[int, int]): Output shape.
  751. **params (Any): Additional parameters.
  752. Returns:
  753. np.ndarray: Distorted bounding boxes.
  754. """
  755. return fgeometric.bboxes_affine(
  756. bboxes,
  757. bbox_matrix,
  758. self.rotate_method,
  759. params["shape"][:2],
  760. self.border_mode,
  761. output_shape,
  762. )
  763. def apply_to_keypoints(
  764. self,
  765. keypoints: np.ndarray,
  766. matrix: np.ndarray,
  767. scale: dict[str, float],
  768. **params: Any,
  769. ) -> np.ndarray:
  770. """Apply the affine transform to keypoints.
  771. Args:
  772. keypoints (np.ndarray): Keypoints to be distorted.
  773. matrix (np.ndarray): Transformation matrix.
  774. scale (dict[str, float]): Scale.
  775. **params (Any): Additional parameters.
  776. Returns:
  777. np.ndarray: Distorted keypoints.
  778. """
  779. return fgeometric.keypoints_affine(
  780. keypoints,
  781. matrix,
  782. params["shape"],
  783. scale,
  784. self.border_mode,
  785. )
  786. @batch_transform("spatial", has_batch_dim=True, has_depth_dim=False)
  787. def apply_to_images(self, images: np.ndarray, **params: Any) -> np.ndarray:
  788. """Apply the affine transform to a batch of images.
  789. Args:
  790. images (np.ndarray): Images to be distorted.
  791. **params (Any): Additional parameters.
  792. Returns:
  793. np.ndarray: Distorted images.
  794. """
  795. return self.apply(images, **params)
  796. @batch_transform("spatial", has_batch_dim=False, has_depth_dim=True)
  797. def apply_to_volume(self, volume: np.ndarray, **params: Any) -> np.ndarray:
  798. """Apply the affine transform to a volume.
  799. Args:
  800. volume (np.ndarray): Volume to be distorted.
  801. **params (Any): Additional parameters.
  802. Returns:
  803. np.ndarray: Distorted volume.
  804. """
  805. return self.apply(volume, **params)
  806. @batch_transform("spatial", has_batch_dim=True, has_depth_dim=True)
  807. def apply_to_volumes(self, volumes: np.ndarray, **params: Any) -> np.ndarray:
  808. """Apply the affine transform to a batch of volumes.
  809. Args:
  810. volumes (np.ndarray): Volumes to be distorted.
  811. **params (Any): Additional parameters.
  812. Returns:
  813. np.ndarray: Distorted volumes.
  814. """
  815. return self.apply(volumes, **params)
  816. @batch_transform("spatial", has_batch_dim=True, has_depth_dim=False)
  817. def apply_to_mask3d(self, mask3d: np.ndarray, **params: Any) -> np.ndarray:
  818. """Apply the affine transform to a 3D mask.
  819. Args:
  820. mask3d (np.ndarray): 3D mask to be distorted.
  821. **params (Any): Additional parameters.
  822. Returns:
  823. np.ndarray: Distorted 3D mask.
  824. """
  825. return self.apply_to_mask(mask3d, **params)
  826. @staticmethod
  827. def _get_scale(
  828. scale: dict[str, tuple[float, float]],
  829. keep_ratio: bool,
  830. balanced_scale: bool,
  831. random_state: random.Random,
  832. ) -> dict[str, float]:
  833. result_scale = {}
  834. for key, value in scale.items():
  835. if isinstance(value, (int, float)):
  836. result_scale[key] = float(value)
  837. elif isinstance(value, tuple):
  838. if balanced_scale:
  839. lower_interval = (value[0], 1.0) if value[0] < 1 else None
  840. upper_interval = (1.0, value[1]) if value[1] > 1 else None
  841. if lower_interval is not None and upper_interval is not None:
  842. selected_interval = random_state.choice(
  843. [lower_interval, upper_interval],
  844. )
  845. elif lower_interval is not None:
  846. selected_interval = lower_interval
  847. elif upper_interval is not None:
  848. selected_interval = upper_interval
  849. else:
  850. result_scale[key] = 1.0
  851. continue
  852. result_scale[key] = random_state.uniform(*selected_interval)
  853. else:
  854. result_scale[key] = random_state.uniform(*value)
  855. else:
  856. raise TypeError(
  857. f"Invalid scale value for key {key}: {value}. Expected a float or a tuple of two floats.",
  858. )
  859. if keep_ratio:
  860. result_scale["y"] = result_scale["x"]
  861. return result_scale
  862. def get_params_dependent_on_data(
  863. self,
  864. params: dict[str, Any],
  865. data: dict[str, Any],
  866. ) -> dict[str, Any]:
  867. """Get the parameters dependent on the data.
  868. Args:
  869. params (dict[str, Any]): Parameters.
  870. data (dict[str, Any]): Data.
  871. Returns:
  872. dict[str, Any]: Parameters.
  873. """
  874. image_shape = params["shape"][:2]
  875. translate = self._get_translate_params(image_shape)
  876. shear = self._get_shear_params()
  877. scale = self._get_scale(
  878. self.scale,
  879. self.keep_ratio,
  880. self.balanced_scale,
  881. self.py_random,
  882. )
  883. rotate = self.py_random.uniform(*self.rotate)
  884. image_shift = fgeometric.center(image_shape)
  885. bbox_shift = fgeometric.center_bbox(image_shape)
  886. matrix = fgeometric.create_affine_transformation_matrix(
  887. translate,
  888. shear,
  889. scale,
  890. rotate,
  891. image_shift,
  892. )
  893. bbox_matrix = fgeometric.create_affine_transformation_matrix(
  894. translate,
  895. shear,
  896. scale,
  897. rotate,
  898. bbox_shift,
  899. )
  900. if self.fit_output:
  901. matrix, output_shape = fgeometric.compute_affine_warp_output_shape(
  902. matrix,
  903. image_shape,
  904. )
  905. bbox_matrix, _ = fgeometric.compute_affine_warp_output_shape(
  906. bbox_matrix,
  907. image_shape,
  908. )
  909. else:
  910. output_shape = image_shape
  911. return {
  912. "rotate": rotate,
  913. "scale": scale,
  914. "matrix": matrix,
  915. "bbox_matrix": bbox_matrix,
  916. "output_shape": output_shape,
  917. }
  918. def _get_translate_params(self, image_shape: tuple[int, int]) -> dict[str, int]:
  919. height, width = image_shape[:2]
  920. if self.translate_px is not None:
  921. return {
  922. "x": self.py_random.randint(int(self.translate_px["x"][0]), int(self.translate_px["x"][1])),
  923. "y": self.py_random.randint(int(self.translate_px["y"][0]), int(self.translate_px["y"][1])),
  924. }
  925. if self.translate_percent is not None:
  926. translate = {key: self.py_random.uniform(*value) for key, value in self.translate_percent.items()}
  927. return cast(
  928. "dict[str, int]",
  929. {"x": int(translate["x"] * width), "y": int(translate["y"] * height)},
  930. )
  931. return cast("dict[str, int]", {"x": 0, "y": 0})
  932. def _get_shear_params(self) -> dict[str, float]:
  933. return {
  934. "x": -self.py_random.uniform(*self.shear["x"]),
  935. "y": -self.py_random.uniform(*self.shear["y"]),
  936. }
  937. class ShiftScaleRotate(Affine):
  938. """Randomly apply affine transforms: translate, scale and rotate the input.
  939. Args:
  940. shift_limit ((float, float) or float): shift factor range for both height and width. If shift_limit
  941. is a single float value, the range will be (-shift_limit, shift_limit). Absolute values for lower and
  942. upper bounds should lie in range [-1, 1]. Default: (-0.0625, 0.0625).
  943. scale_limit ((float, float) or float): scaling factor range. If scale_limit is a single float value, the
  944. range will be (-scale_limit, scale_limit). Note that the scale_limit will be biased by 1.
  945. If scale_limit is a tuple, like (low, high), sampling will be done from the range (1 + low, 1 + high).
  946. Default: (-0.1, 0.1).
  947. rotate_limit ((int, int) or int): rotation range. If rotate_limit is a single int value, the
  948. range will be (-rotate_limit, rotate_limit). Default: (-45, 45).
  949. interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
  950. cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
  951. Default: cv2.INTER_LINEAR.
  952. border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
  953. cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
  954. Default: cv2.BORDER_CONSTANT
  955. fill (tuple[float, ...] | float): padding value if border_mode is cv2.BORDER_CONSTANT.
  956. fill_mask (tuple[float, ...] | float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks.
  957. shift_limit_x ((float, float) or float): shift factor range for width. If it is set then this value
  958. instead of shift_limit will be used for shifting width. If shift_limit_x is a single float value,
  959. the range will be (-shift_limit_x, shift_limit_x). Absolute values for lower and upper bounds should lie in
  960. the range [-1, 1]. Default: None.
  961. shift_limit_y ((float, float) or float): shift factor range for height. If it is set then this value
  962. instead of shift_limit will be used for shifting height. If shift_limit_y is a single float value,
  963. the range will be (-shift_limit_y, shift_limit_y). Absolute values for lower and upper bounds should lie
  964. in the range [-, 1]. Default: None.
  965. rotate_method (str): rotation method used for the bounding boxes. Should be one of "largest_box" or "ellipse".
  966. Default: "largest_box"
  967. mask_interpolation (OpenCV flag): Flag that is used to specify the interpolation algorithm for mask.
  968. Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
  969. Default: cv2.INTER_NEAREST.
  970. p (float): probability of applying the transform. Default: 0.5.
  971. Targets:
  972. image, mask, keypoints, bboxes, volume, mask3d
  973. Image types:
  974. uint8, float32
  975. Examples:
  976. >>> import numpy as np
  977. >>> import albumentations as A
  978. >>> import cv2
  979. >>>
  980. >>> # Prepare sample data
  981. >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
  982. >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
  983. >>> bboxes = np.array([[10, 10, 50, 50], [40, 40, 80, 80]], dtype=np.float32)
  984. >>> bbox_labels = [1, 2]
  985. >>> keypoints = np.array([[20, 30], [60, 70]], dtype=np.float32)
  986. >>> keypoint_labels = [0, 1]
  987. >>>
  988. >>> # Define transform with parameters as tuples when possible
  989. >>> transform = A.Compose([
  990. ... A.ShiftScaleRotate(
  991. ... shift_limit=(-0.0625, 0.0625),
  992. ... scale_limit=(-0.1, 0.1),
  993. ... rotate_limit=(-45, 45),
  994. ... interpolation=cv2.INTER_LINEAR,
  995. ... border_mode=cv2.BORDER_CONSTANT,
  996. ... rotate_method="largest_box",
  997. ... p=1.0
  998. ... ),
  999. ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
  1000. ... keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
  1001. >>>
  1002. >>> # Apply the transform
  1003. >>> transformed = transform(
  1004. ... image=image,
  1005. ... mask=mask,
  1006. ... bboxes=bboxes,
  1007. ... bbox_labels=bbox_labels,
  1008. ... keypoints=keypoints,
  1009. ... keypoint_labels=keypoint_labels
  1010. ... )
  1011. >>>
  1012. >>> # Get the transformed data
  1013. >>> transformed_image = transformed['image'] # Shifted, scaled and rotated image
  1014. >>> transformed_mask = transformed['mask'] # Shifted, scaled and rotated mask
  1015. >>> transformed_bboxes = transformed['bboxes'] # Shifted, scaled and rotated bounding boxes
  1016. >>> transformed_bbox_labels = transformed['bbox_labels'] # Labels for transformed bboxes
  1017. >>> transformed_keypoints = transformed['keypoints'] # Shifted, scaled and rotated keypoints
  1018. >>> transformed_keypoint_labels = transformed['keypoint_labels'] # Labels for transformed keypoints
  1019. """
  1020. _targets = ALL_TARGETS
  1021. class InitSchema(BaseTransformInitSchema):
  1022. shift_limit: SymmetricRangeType
  1023. scale_limit: SymmetricRangeType
  1024. rotate_limit: SymmetricRangeType
  1025. interpolation: Literal[
  1026. cv2.INTER_NEAREST,
  1027. cv2.INTER_LINEAR,
  1028. cv2.INTER_CUBIC,
  1029. cv2.INTER_AREA,
  1030. cv2.INTER_LANCZOS4,
  1031. ]
  1032. border_mode: Literal[
  1033. cv2.BORDER_CONSTANT,
  1034. cv2.BORDER_REPLICATE,
  1035. cv2.BORDER_REFLECT,
  1036. cv2.BORDER_WRAP,
  1037. cv2.BORDER_REFLECT_101,
  1038. ]
  1039. fill: tuple[float, ...] | float
  1040. fill_mask: tuple[float, ...] | float
  1041. shift_limit_x: tuple[float, float] | float | None
  1042. shift_limit_y: tuple[float, float] | float | None
  1043. rotate_method: Literal["largest_box", "ellipse"]
  1044. mask_interpolation: Literal[
  1045. cv2.INTER_NEAREST,
  1046. cv2.INTER_LINEAR,
  1047. cv2.INTER_CUBIC,
  1048. cv2.INTER_AREA,
  1049. cv2.INTER_LANCZOS4,
  1050. ]
  1051. @model_validator(mode="after")
  1052. def _check_shift_limit(self) -> Self:
  1053. bounds = -1, 1
  1054. self.shift_limit_x = to_tuple(
  1055. self.shift_limit_x if self.shift_limit_x is not None else self.shift_limit,
  1056. )
  1057. check_range(self.shift_limit_x, *bounds, "shift_limit_x")
  1058. self.shift_limit_y = to_tuple(
  1059. self.shift_limit_y if self.shift_limit_y is not None else self.shift_limit,
  1060. )
  1061. check_range(self.shift_limit_y, *bounds, "shift_limit_y")
  1062. return self
  1063. @field_validator("scale_limit")
  1064. @classmethod
  1065. def _check_scale_limit(
  1066. cls,
  1067. value: tuple[float, float] | float,
  1068. info: ValidationInfo,
  1069. ) -> tuple[float, float]:
  1070. bounds = 0, float("inf")
  1071. result = to_tuple(value, bias=1.0)
  1072. check_range(result, *bounds, str(info.field_name))
  1073. return result
  1074. def __init__(
  1075. self,
  1076. shift_limit: tuple[float, float] | float = (-0.0625, 0.0625),
  1077. scale_limit: tuple[float, float] | float = (-0.1, 0.1),
  1078. rotate_limit: tuple[float, float] | float = (-45, 45),
  1079. interpolation: Literal[
  1080. cv2.INTER_NEAREST,
  1081. cv2.INTER_LINEAR,
  1082. cv2.INTER_CUBIC,
  1083. cv2.INTER_AREA,
  1084. cv2.INTER_LANCZOS4,
  1085. ] = cv2.INTER_LINEAR,
  1086. border_mode: int = cv2.BORDER_CONSTANT,
  1087. shift_limit_x: tuple[float, float] | float | None = None,
  1088. shift_limit_y: tuple[float, float] | float | None = None,
  1089. rotate_method: Literal["largest_box", "ellipse"] = "largest_box",
  1090. mask_interpolation: Literal[
  1091. cv2.INTER_NEAREST,
  1092. cv2.INTER_LINEAR,
  1093. cv2.INTER_CUBIC,
  1094. cv2.INTER_AREA,
  1095. cv2.INTER_LANCZOS4,
  1096. ] = cv2.INTER_NEAREST,
  1097. fill: tuple[float, ...] | float = 0,
  1098. fill_mask: tuple[float, ...] | float = 0,
  1099. p: float = 0.5,
  1100. ):
  1101. shift_limit_x = cast("tuple[float, float]", shift_limit_x)
  1102. shift_limit_y = cast("tuple[float, float]", shift_limit_y)
  1103. super().__init__(
  1104. scale=scale_limit,
  1105. translate_percent={"x": shift_limit_x, "y": shift_limit_y},
  1106. rotate=rotate_limit,
  1107. shear=(0, 0),
  1108. interpolation=interpolation,
  1109. mask_interpolation=mask_interpolation,
  1110. fill=fill,
  1111. fill_mask=fill_mask,
  1112. border_mode=border_mode,
  1113. fit_output=False,
  1114. keep_ratio=False,
  1115. rotate_method=rotate_method,
  1116. p=p,
  1117. )
  1118. warn(
  1119. "ShiftScaleRotate is a special case of Affine transform. Please use Affine transform instead.",
  1120. UserWarning,
  1121. stacklevel=2,
  1122. )
  1123. self.shift_limit_x = shift_limit_x
  1124. self.shift_limit_y = shift_limit_y
  1125. self.scale_limit = cast("tuple[float, float]", scale_limit)
  1126. self.rotate_limit = cast("tuple[int, int]", rotate_limit)
  1127. self.border_mode = border_mode
  1128. self.fill = fill
  1129. self.fill_mask = fill_mask
  1130. def get_transform_init_args(self) -> dict[str, Any]:
  1131. """Get the transform initialization arguments.
  1132. Returns:
  1133. dict[str, Any]: Transform initialization arguments.
  1134. """
  1135. return {
  1136. "shift_limit_x": self.shift_limit_x,
  1137. "shift_limit_y": self.shift_limit_y,
  1138. "scale_limit": to_tuple(self.scale_limit, bias=-1.0),
  1139. "rotate_limit": self.rotate_limit,
  1140. "interpolation": self.interpolation,
  1141. "border_mode": self.border_mode,
  1142. "fill": self.fill,
  1143. "fill_mask": self.fill_mask,
  1144. "rotate_method": self.rotate_method,
  1145. "mask_interpolation": self.mask_interpolation,
  1146. }
  1147. class GridElasticDeform(DualTransform):
  1148. """Apply elastic deformations to images, masks, bounding boxes, and keypoints using a grid-based approach.
  1149. This transformation overlays a grid on the input and applies random displacements to the grid points,
  1150. resulting in local elastic distortions. The granularity and intensity of the distortions can be
  1151. controlled using the dimensions of the overlaying distortion grid and the magnitude parameter.
  1152. Args:
  1153. num_grid_xy (tuple[int, int]): Number of grid cells along the width and height.
  1154. Specified as (grid_width, grid_height). Each value must be greater than 1.
  1155. magnitude (int): Maximum pixel-wise displacement for distortion. Must be greater than 0.
  1156. interpolation (int): Interpolation method to be used for the image transformation.
  1157. Default: cv2.INTER_LINEAR
  1158. mask_interpolation (int): Interpolation method to be used for mask transformation.
  1159. Default: cv2.INTER_NEAREST
  1160. p (float): Probability of applying the transform. Default: 1.0.
  1161. Targets:
  1162. image, mask, bboxes, keypoints, volume, mask3d
  1163. Image types:
  1164. uint8, float32
  1165. Number of channels:
  1166. 1, 3
  1167. Examples:
  1168. >>> import numpy as np
  1169. >>> import albumentations as A
  1170. >>> import cv2
  1171. >>>
  1172. >>> # Prepare sample data
  1173. >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
  1174. >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
  1175. >>> bboxes = np.array([[10, 10, 50, 50], [40, 40, 80, 80]], dtype=np.float32)
  1176. >>> bbox_labels = [1, 2]
  1177. >>> keypoints = np.array([[20, 30], [60, 70]], dtype=np.float32)
  1178. >>> keypoint_labels = [0, 1]
  1179. >>>
  1180. >>> # Define transform with parameters as tuples when possible
  1181. >>> transform = A.Compose([
  1182. ... A.GridElasticDeform(
  1183. ... num_grid_xy=(4, 4),
  1184. ... magnitude=10,
  1185. ... interpolation=cv2.INTER_LINEAR,
  1186. ... mask_interpolation=cv2.INTER_NEAREST,
  1187. ... p=1.0
  1188. ... ),
  1189. ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
  1190. ... keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
  1191. >>>
  1192. >>> # Apply the transform
  1193. >>> transformed = transform(
  1194. ... image=image,
  1195. ... mask=mask,
  1196. ... bboxes=bboxes,
  1197. ... bbox_labels=bbox_labels,
  1198. ... keypoints=keypoints,
  1199. ... keypoint_labels=keypoint_labels
  1200. ... )
  1201. >>>
  1202. >>> # Get the transformed data
  1203. >>> transformed_image = transformed['image'] # Elastically deformed image
  1204. >>> transformed_mask = transformed['mask'] # Elastically deformed mask
  1205. >>> transformed_bboxes = transformed['bboxes'] # Elastically deformed bounding boxes
  1206. >>> transformed_bbox_labels = transformed['bbox_labels'] # Labels for transformed bboxes
  1207. >>> transformed_keypoints = transformed['keypoints'] # Elastically deformed keypoints
  1208. >>> transformed_keypoint_labels = transformed['keypoint_labels'] # Labels for transformed keypoints
  1209. Note:
  1210. This transformation is particularly useful for data augmentation in medical imaging
  1211. and other domains where elastic deformations can simulate realistic variations.
  1212. """
  1213. _targets = ALL_TARGETS
  1214. class InitSchema(BaseTransformInitSchema):
  1215. num_grid_xy: Annotated[tuple[int, int], AfterValidator(check_range_bounds(1, None))]
  1216. magnitude: int = Field(gt=0)
  1217. interpolation: Literal[
  1218. cv2.INTER_NEAREST,
  1219. cv2.INTER_LINEAR,
  1220. cv2.INTER_CUBIC,
  1221. cv2.INTER_AREA,
  1222. cv2.INTER_LANCZOS4,
  1223. ]
  1224. mask_interpolation: Literal[
  1225. cv2.INTER_NEAREST,
  1226. cv2.INTER_LINEAR,
  1227. cv2.INTER_CUBIC,
  1228. cv2.INTER_AREA,
  1229. cv2.INTER_LANCZOS4,
  1230. ]
  1231. def __init__(
  1232. self,
  1233. num_grid_xy: tuple[int, int],
  1234. magnitude: int,
  1235. interpolation: Literal[
  1236. cv2.INTER_NEAREST,
  1237. cv2.INTER_LINEAR,
  1238. cv2.INTER_CUBIC,
  1239. cv2.INTER_AREA,
  1240. cv2.INTER_LANCZOS4,
  1241. ] = cv2.INTER_LINEAR,
  1242. mask_interpolation: Literal[
  1243. cv2.INTER_NEAREST,
  1244. cv2.INTER_LINEAR,
  1245. cv2.INTER_CUBIC,
  1246. cv2.INTER_AREA,
  1247. cv2.INTER_LANCZOS4,
  1248. ] = cv2.INTER_NEAREST,
  1249. p: float = 1.0,
  1250. ):
  1251. super().__init__(p=p)
  1252. self.num_grid_xy = num_grid_xy
  1253. self.magnitude = magnitude
  1254. self.interpolation = interpolation
  1255. self.mask_interpolation = mask_interpolation
  1256. @staticmethod
  1257. def _generate_mesh(polygons: np.ndarray, dimensions: np.ndarray) -> np.ndarray:
  1258. return np.hstack((dimensions.reshape(-1, 4), polygons))
  1259. def get_params_dependent_on_data(
  1260. self,
  1261. params: dict[str, Any],
  1262. data: dict[str, Any],
  1263. ) -> dict[str, Any]:
  1264. """Get the parameters dependent on the data.
  1265. Args:
  1266. params (dict[str, Any]): Parameters.
  1267. data (dict[str, Any]): Data.
  1268. Returns:
  1269. dict[str, Any]: Parameters.
  1270. """
  1271. image_shape = params["shape"][:2]
  1272. # Replace calculate_grid_dimensions with split_uniform_grid
  1273. tiles = fgeometric.split_uniform_grid(
  1274. image_shape,
  1275. self.num_grid_xy,
  1276. self.random_generator,
  1277. )
  1278. # Convert tiles to the format expected by generate_distorted_grid_polygons
  1279. dimensions = np.array(
  1280. [
  1281. [
  1282. tile[1],
  1283. tile[0],
  1284. tile[3],
  1285. tile[2],
  1286. ] # Reorder to [x_min, y_min, x_max, y_max]
  1287. for tile in tiles
  1288. ],
  1289. ).reshape(
  1290. self.num_grid_xy[::-1] + (4,),
  1291. ) # Reshape to (grid_height, grid_width, 4)
  1292. polygons = fgeometric.generate_distorted_grid_polygons(
  1293. dimensions,
  1294. self.magnitude,
  1295. self.random_generator,
  1296. )
  1297. generated_mesh = self._generate_mesh(polygons, dimensions)
  1298. return {"generated_mesh": generated_mesh}
  1299. def apply(
  1300. self,
  1301. img: np.ndarray,
  1302. generated_mesh: np.ndarray,
  1303. **params: Any,
  1304. ) -> np.ndarray:
  1305. """Apply the GridElasticDeform transform to an image.
  1306. Args:
  1307. img (np.ndarray): Image to be transformed.
  1308. generated_mesh (np.ndarray): Generated mesh.
  1309. **params (Any): Additional parameters.
  1310. """
  1311. if not is_rgb_image(img) and not is_grayscale_image(img):
  1312. raise ValueError("GridElasticDeform transform is only supported for RGB and grayscale images.")
  1313. return fgeometric.distort_image(img, generated_mesh, self.interpolation)
  1314. def apply_to_mask(
  1315. self,
  1316. mask: np.ndarray,
  1317. generated_mesh: np.ndarray,
  1318. **params: Any,
  1319. ) -> np.ndarray:
  1320. """Apply the GridElasticDeform transform to a mask.
  1321. Args:
  1322. mask (np.ndarray): Mask to be transformed.
  1323. generated_mesh (np.ndarray): Generated mesh.
  1324. **params (Any): Additional parameters.
  1325. """
  1326. return fgeometric.distort_image(mask, generated_mesh, self.mask_interpolation)
  1327. def apply_to_bboxes(
  1328. self,
  1329. bboxes: np.ndarray,
  1330. generated_mesh: np.ndarray,
  1331. **params: Any,
  1332. ) -> np.ndarray:
  1333. """Apply the GridElasticDeform transform to bounding boxes.
  1334. Args:
  1335. bboxes (np.ndarray): Bounding boxes to be transformed.
  1336. generated_mesh (np.ndarray): Generated mesh.
  1337. **params (Any): Additional parameters.
  1338. """
  1339. bboxes_denorm = denormalize_bboxes(bboxes, params["shape"][:2])
  1340. return normalize_bboxes(
  1341. fgeometric.bbox_distort_image(
  1342. bboxes_denorm,
  1343. generated_mesh,
  1344. params["shape"][:2],
  1345. ),
  1346. params["shape"][:2],
  1347. )
  1348. def apply_to_keypoints(
  1349. self,
  1350. keypoints: np.ndarray,
  1351. generated_mesh: np.ndarray,
  1352. **params: Any,
  1353. ) -> np.ndarray:
  1354. """Apply the GridElasticDeform transform to keypoints.
  1355. Args:
  1356. keypoints (np.ndarray): Keypoints to be transformed.
  1357. generated_mesh (np.ndarray): Generated mesh.
  1358. **params (Any): Additional parameters.
  1359. """
  1360. return fgeometric.distort_image_keypoints(
  1361. keypoints,
  1362. generated_mesh,
  1363. params["shape"][:2],
  1364. )
  1365. class RandomGridShuffle(DualTransform):
  1366. """Randomly shuffles the grid's cells on an image, mask, or keypoints,
  1367. effectively rearranging patches within the image.
  1368. This transformation divides the image into a grid and then permutes these grid cells based on a random mapping.
  1369. Args:
  1370. grid (tuple[int, int]): Size of the grid for splitting the image into cells. Each cell is shuffled randomly.
  1371. For example, (3, 3) will divide the image into a 3x3 grid, resulting in 9 cells to be shuffled.
  1372. Default: (3, 3)
  1373. p (float): Probability that the transform will be applied. Should be in the range [0, 1].
  1374. Default: 0.5
  1375. Targets:
  1376. image, mask, keypoints, bboxes, volume, mask3d
  1377. Image types:
  1378. uint8, float32
  1379. Note:
  1380. - This transform maintains consistency across all targets. If applied to an image and its corresponding
  1381. mask or keypoints, the same shuffling will be applied to all.
  1382. - The number of cells in the grid should be at least 2 (i.e., grid should be at least (1, 2), (2, 1), or (2, 2))
  1383. for the transform to have any effect.
  1384. - Keypoints are moved along with their corresponding grid cell.
  1385. - This transform could be useful when only micro features are important for the model, and memorizing
  1386. the global structure could be harmful. For example:
  1387. - Identifying the type of cell phone used to take a picture based on micro artifacts generated by
  1388. phone post-processing algorithms, rather than the semantic features of the photo.
  1389. See more at https://ieeexplore.ieee.org/abstract/document/8622031
  1390. - Identifying stress, glucose, hydration levels based on skin images.
  1391. Mathematical Formulation:
  1392. 1. The image is divided into a grid of size (m, n) as specified by the 'grid' parameter.
  1393. 2. A random permutation P of integers from 0 to (m*n - 1) is generated.
  1394. 3. Each cell in the grid is assigned a number from 0 to (m*n - 1) in row-major order.
  1395. 4. The cells are then rearranged according to the permutation P.
  1396. Examples:
  1397. >>> import numpy as np
  1398. >>> import albumentations as A
  1399. >>> # Prepare sample data
  1400. >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
  1401. >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
  1402. >>> bboxes = np.array([[10, 10, 50, 50], [40, 40, 80, 80]], dtype=np.float32)
  1403. >>> bbox_labels = [1, 2]
  1404. >>> keypoints = np.array([[20, 30], [60, 70]], dtype=np.float32)
  1405. >>> keypoint_labels = [0, 1]
  1406. >>>
  1407. >>> # Define transform with grid as a tuple
  1408. >>> transform = A.Compose([
  1409. ... A.RandomGridShuffle(grid=(3, 3), p=1.0),
  1410. ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
  1411. ... keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
  1412. >>>
  1413. >>> # Apply the transform
  1414. >>> transformed = transform(
  1415. ... image=image,
  1416. ... mask=mask,
  1417. ... bboxes=bboxes,
  1418. ... bbox_labels=bbox_labels,
  1419. ... keypoints=keypoints,
  1420. ... keypoint_labels=keypoint_labels
  1421. ... )
  1422. >>>
  1423. >>> # Get the transformed data
  1424. >>> transformed_image = transformed['image'] # Grid-shuffled image
  1425. >>> transformed_mask = transformed['mask'] # Grid-shuffled mask
  1426. >>> transformed_bboxes = transformed['bboxes'] # Grid-shuffled bounding boxes
  1427. >>> transformed_keypoints = transformed['keypoints'] # Grid-shuffled keypoints
  1428. >>>
  1429. >>> # Visualization example with a simpler grid
  1430. >>> simple_image = np.array([
  1431. ... [1, 1, 1, 2, 2, 2],
  1432. ... [1, 1, 1, 2, 2, 2],
  1433. ... [1, 1, 1, 2, 2, 2],
  1434. ... [3, 3, 3, 4, 4, 4],
  1435. ... [3, 3, 3, 4, 4, 4],
  1436. ... [3, 3, 3, 4, 4, 4]
  1437. ... ])
  1438. >>> simple_transform = A.RandomGridShuffle(grid=(2, 2), p=1.0)
  1439. >>> simple_result = simple_transform(image=simple_image)
  1440. >>> simple_transformed = simple_result['image']
  1441. >>> # The result could look like:
  1442. >>> # array([[4, 4, 4, 2, 2, 2],
  1443. >>> # [4, 4, 4, 2, 2, 2],
  1444. >>> # [4, 4, 4, 2, 2, 2],
  1445. >>> # [3, 3, 3, 1, 1, 1],
  1446. >>> # [3, 3, 3, 1, 1, 1],
  1447. >>> # [3, 3, 3, 1, 1, 1]])
  1448. """
  1449. class InitSchema(BaseTransformInitSchema):
  1450. grid: Annotated[tuple[int, int], AfterValidator(check_range_bounds(1, None))]
  1451. _targets = ALL_TARGETS
  1452. def __init__(
  1453. self,
  1454. grid: tuple[int, int] = (3, 3),
  1455. p: float = 0.5,
  1456. ):
  1457. super().__init__(p=p)
  1458. self.grid = grid
  1459. def apply(
  1460. self,
  1461. img: np.ndarray,
  1462. tiles: np.ndarray,
  1463. mapping: list[int],
  1464. **params: Any,
  1465. ) -> np.ndarray:
  1466. """Apply the RandomGridShuffle transform to an image.
  1467. Args:
  1468. img (np.ndarray): Image to be transformed.
  1469. tiles (np.ndarray): Tiles to be transformed.
  1470. mapping (list[int]): Mapping of the tiles.
  1471. **params (Any): Additional parameters.
  1472. """
  1473. return fgeometric.swap_tiles_on_image(img, tiles, mapping)
  1474. def apply_to_bboxes(
  1475. self,
  1476. bboxes: np.ndarray,
  1477. tiles: np.ndarray,
  1478. mapping: np.ndarray,
  1479. **params: Any,
  1480. ) -> np.ndarray:
  1481. """Apply the RandomGridShuffle transform to bounding boxes.
  1482. Args:
  1483. bboxes (np.ndarray): Bounding boxes to be transformed.
  1484. tiles (np.ndarray): Tiles to be transformed.
  1485. mapping (np.ndarray): Mapping of the tiles.
  1486. **params (Any): Additional parameters.
  1487. """
  1488. image_shape = params["shape"][:2]
  1489. bboxes_denorm = denormalize_bboxes(bboxes, image_shape)
  1490. processor = cast("BboxProcessor", self.get_processor("bboxes"))
  1491. if processor is None:
  1492. return bboxes
  1493. bboxes_returned = fgeometric.bboxes_grid_shuffle(
  1494. bboxes_denorm,
  1495. tiles,
  1496. mapping,
  1497. image_shape,
  1498. min_area=processor.params.min_area,
  1499. min_visibility=processor.params.min_visibility,
  1500. )
  1501. return normalize_bboxes(bboxes_returned, image_shape)
  1502. def apply_to_keypoints(
  1503. self,
  1504. keypoints: np.ndarray,
  1505. tiles: np.ndarray,
  1506. mapping: np.ndarray,
  1507. **params: Any,
  1508. ) -> np.ndarray:
  1509. """Apply the RandomGridShuffle transform to keypoints.
  1510. Args:
  1511. keypoints (np.ndarray): Keypoints to be transformed.
  1512. tiles (np.ndarray): Tiles to be transformed.
  1513. mapping (np.ndarray): Mapping of the tiles.
  1514. **params (Any): Additional parameters.
  1515. """
  1516. return fgeometric.swap_tiles_on_keypoints(keypoints, tiles, mapping)
  1517. @batch_transform("spatial", has_batch_dim=True, has_depth_dim=False)
  1518. def apply_to_images(self, images: np.ndarray, **params: Any) -> np.ndarray:
  1519. """Apply the RandomGridShuffle transform to a batch of images.
  1520. Args:
  1521. images (np.ndarray): Images to be transformed.
  1522. **params (Any): Additional parameters.
  1523. """
  1524. return self.apply(images, **params)
  1525. @batch_transform("spatial", has_batch_dim=False, has_depth_dim=True)
  1526. def apply_to_volume(self, volume: np.ndarray, **params: Any) -> np.ndarray:
  1527. """Apply the RandomGridShuffle transform to a volume.
  1528. Args:
  1529. volume (np.ndarray): Volume to be transformed.
  1530. **params (Any): Additional parameters.
  1531. """
  1532. return self.apply(volume, **params)
  1533. @batch_transform("spatial", has_batch_dim=True, has_depth_dim=True)
  1534. def apply_to_volumes(self, volumes: np.ndarray, **params: Any) -> np.ndarray:
  1535. """Apply the RandomGridShuffle transform to a batch of volumes.
  1536. Args:
  1537. volumes (np.ndarray): Volumes to be transformed.
  1538. **params (Any): Additional parameters.
  1539. """
  1540. return self.apply(volumes, **params)
  1541. @batch_transform("spatial", has_batch_dim=True, has_depth_dim=False)
  1542. def apply_to_mask3d(self, mask3d: np.ndarray, **params: Any) -> np.ndarray:
  1543. """Apply the RandomGridShuffle transform to a 3D mask.
  1544. Args:
  1545. mask3d (np.ndarray): 3D mask to be transformed.
  1546. **params (Any): Additional parameters.
  1547. """
  1548. return self.apply(mask3d, **params)
  1549. def get_params_dependent_on_data(
  1550. self,
  1551. params: dict[str, Any],
  1552. data: dict[str, Any],
  1553. ) -> dict[str, np.ndarray]:
  1554. """Get the parameters dependent on the data.
  1555. Args:
  1556. params (dict[str, Any]): Parameters.
  1557. data (dict[str, Any]): Data.
  1558. Returns:
  1559. dict[str, np.ndarray]: Parameters.
  1560. """
  1561. image_shape = params["shape"][:2]
  1562. original_tiles = fgeometric.split_uniform_grid(
  1563. image_shape,
  1564. self.grid,
  1565. self.random_generator,
  1566. )
  1567. shape_groups = fgeometric.create_shape_groups(original_tiles)
  1568. mapping = fgeometric.shuffle_tiles_within_shape_groups(
  1569. shape_groups,
  1570. self.random_generator,
  1571. )
  1572. return {"tiles": original_tiles, "mapping": mapping}
  1573. class Morphological(DualTransform):
  1574. """Apply a morphological operation (dilation or erosion) to an image,
  1575. with particular value for enhancing document scans.
  1576. Morphological operations modify the structure of the image.
  1577. Dilation expands the white (foreground) regions in a binary or grayscale image, while erosion shrinks them.
  1578. These operations are beneficial in document processing, for example:
  1579. - Dilation helps in closing up gaps within text or making thin lines thicker,
  1580. enhancing legibility for OCR (Optical Character Recognition).
  1581. - Erosion can remove small white noise and detach connected objects,
  1582. making the structure of larger objects more pronounced.
  1583. Args:
  1584. scale (int or tuple/list of int): Specifies the size of the structuring element (kernel) used for the operation.
  1585. - If an integer is provided, a square kernel of that size will be used.
  1586. - If a tuple or list is provided, it should contain two integers representing the minimum
  1587. and maximum sizes for the dilation kernel.
  1588. operation (Literal["erosion", "dilation"]): The morphological operation to apply.
  1589. Default is 'dilation'.
  1590. p (float, optional): The probability of applying this transformation. Default is 0.5.
  1591. Targets:
  1592. image, mask, keypoints, bboxes, volume, mask3d
  1593. Image types:
  1594. uint8, float32
  1595. References:
  1596. Nougat: https://github.com/facebookresearch/nougat
  1597. Examples:
  1598. >>> import numpy as np
  1599. >>> import albumentations as A
  1600. >>> import cv2
  1601. >>>
  1602. >>> # Create a document-like binary image with text
  1603. >>> image = np.ones((200, 500), dtype=np.uint8) * 255 # White background
  1604. >>> # Add some "text" (black pixels)
  1605. >>> cv2.putText(image, "Document Text", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, 0, 2)
  1606. >>> # Add some "noise" (small black dots)
  1607. >>> for _ in range(50):
  1608. ... x, y = np.random.randint(0, image.shape[1]), np.random.randint(0, image.shape[0])
  1609. ... cv2.circle(image, (x, y), 1, 0, -1)
  1610. >>>
  1611. >>> # Create a mask representing text regions
  1612. >>> mask = np.zeros_like(image)
  1613. >>> mask[image < 128] = 1 # Binary mask where text exists
  1614. >>>
  1615. >>> # Example 1: Apply dilation to thicken text and fill gaps
  1616. >>> dilation_transform = A.Morphological(
  1617. ... scale=3, # Size of the structuring element
  1618. ... operation="dilation", # Expand white regions (or black if inverted)
  1619. ... p=1.0 # Always apply
  1620. ... )
  1621. >>> result = dilation_transform(image=image, mask=mask)
  1622. >>> dilated_image = result['image'] # Text is thicker, gaps are filled
  1623. >>> dilated_mask = result['mask'] # Mask is expanded around text regions
  1624. >>>
  1625. >>> # Example 2: Apply erosion to thin text or remove noise
  1626. >>> erosion_transform = A.Morphological(
  1627. ... scale=(2, 3), # Random kernel size between 2 and 3
  1628. ... operation="erosion", # Shrink white regions (or expand black if inverted)
  1629. ... p=1.0 # Always apply
  1630. ... )
  1631. >>> result = erosion_transform(image=image, mask=mask)
  1632. >>> eroded_image = result['image'] # Text is thinner, small noise may be removed
  1633. >>> eroded_mask = result['mask'] # Mask is contracted around text regions
  1634. >>>
  1635. >>> # Note: For document processing, dilation often helps enhance readability for OCR
  1636. >>> # while erosion can help remove noise or separate connected components
  1637. """
  1638. _targets = ALL_TARGETS
  1639. class InitSchema(BaseTransformInitSchema):
  1640. scale: OnePlusIntRangeType
  1641. operation: Literal["erosion", "dilation"]
  1642. def __init__(
  1643. self,
  1644. scale: tuple[int, int] | int = (2, 3),
  1645. operation: Literal["erosion", "dilation"] = "dilation",
  1646. p: float = 0.5,
  1647. ):
  1648. super().__init__(p=p)
  1649. self.scale = cast("tuple[int, int]", scale)
  1650. self.operation = operation
  1651. def apply(
  1652. self,
  1653. img: np.ndarray,
  1654. kernel: tuple[int, int],
  1655. **params: Any,
  1656. ) -> np.ndarray:
  1657. """Apply the Morphological transform to the input image.
  1658. Args:
  1659. img (np.ndarray): The input image to apply the Morphological transform to.
  1660. kernel (tuple[int, int]): The structuring element (kernel) used for the operation.
  1661. **params (Any): Additional parameters for the transform.
  1662. """
  1663. return fgeometric.morphology(img, kernel, self.operation)
  1664. def apply_to_bboxes(
  1665. self,
  1666. bboxes: np.ndarray,
  1667. kernel: tuple[int, int],
  1668. **params: Any,
  1669. ) -> np.ndarray:
  1670. """Apply the Morphological transform to the input bounding boxes.
  1671. Args:
  1672. bboxes (np.ndarray): The input bounding boxes to apply the Morphological transform to.
  1673. kernel (tuple[int, int]): The structuring element (kernel) used for the operation.
  1674. **params (Any): Additional parameters for the transform.
  1675. """
  1676. image_shape = params["shape"]
  1677. denormalized_boxes = denormalize_bboxes(bboxes, image_shape)
  1678. result = fgeometric.bboxes_morphology(
  1679. denormalized_boxes,
  1680. kernel,
  1681. self.operation,
  1682. image_shape,
  1683. )
  1684. return normalize_bboxes(result, image_shape)
  1685. def apply_to_keypoints(
  1686. self,
  1687. keypoints: np.ndarray,
  1688. **params: Any,
  1689. ) -> np.ndarray:
  1690. """Apply the Morphological transform to the input keypoints.
  1691. Args:
  1692. keypoints (np.ndarray): The input keypoints to apply the Morphological transform to.
  1693. **params (Any): Additional parameters for the transform.
  1694. """
  1695. return keypoints
  1696. def get_params(self) -> dict[str, float]:
  1697. """Generate parameters for the Morphological transform.
  1698. Returns:
  1699. dict[str, float]: The parameters of the transform.
  1700. """
  1701. return {
  1702. "kernel": cv2.getStructuringElement(cv2.MORPH_ELLIPSE, self.scale),
  1703. }