transforms.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832
  1. """Transforms that combine multiple images and their associated annotations.
  2. This module contains transformations that take multiple input sources (e.g., a primary image
  3. and additional images provided via metadata) and combine them into a single output.
  4. Examples include overlaying elements (`OverlayElements`) or creating complex compositions
  5. like `Mosaic`.
  6. """
  7. from __future__ import annotations
  8. import random
  9. from copy import deepcopy
  10. from typing import Annotated, Any, Literal, cast
  11. import cv2
  12. import numpy as np
  13. from pydantic import AfterValidator, model_validator
  14. from typing_extensions import Self
  15. from albumentations.augmentations.mixing import functional as fmixing
  16. from albumentations.core.bbox_utils import BboxProcessor, check_bboxes, denormalize_bboxes, filter_bboxes
  17. from albumentations.core.keypoints_utils import KeypointsProcessor
  18. from albumentations.core.pydantic import check_range_bounds, nondecreasing
  19. from albumentations.core.transforms_interface import BaseTransformInitSchema, DualTransform
  20. from albumentations.core.type_definitions import LENGTH_RAW_BBOX, Targets
# Public API of this module: the two mixing transforms defined below.
__all__ = ["Mosaic", "OverlayElements"]
  22. class OverlayElements(DualTransform):
  23. """Apply overlay elements such as images and masks onto an input image. This transformation can be used to add
  24. various objects (e.g., stickers, logos) to images with optional masks and bounding boxes for better placement
  25. control.
  26. Args:
  27. metadata_key (str): Additional target key for metadata. Default `overlay_metadata`.
  28. p (float): Probability of applying the transformation. Default: 0.5.
  29. Possible Metadata Fields:
  30. - image (np.ndarray): The overlay image to be applied. This is a required field.
  31. - bbox (list[int]): The bounding box specifying the region where the overlay should be applied. It should
  32. contain four floats: [y_min, x_min, y_max, x_max]. If `label_id` is provided, it should
  33. be appended as the fifth element in the bbox. BBox should be in Albumentations format,
  34. that is the same as normalized Pascal VOC format
  35. [x_min / width, y_min / height, x_max / width, y_max / height]
  36. - mask (np.ndarray): An optional mask that defines the non-rectangular region of the overlay image. If not
  37. provided, the entire overlay image is used.
  38. - mask_id (int): An optional identifier for the mask. If provided, the regions specified by the mask will
  39. be labeled with this identifier in the output mask.
  40. Targets:
  41. image, mask
  42. Image types:
  43. uint8, float32
  44. References:
  45. doc-augmentation: https://github.com/danaaubakirova/doc-augmentation
  46. Examples:
  47. >>> import numpy as np
  48. >>> import albumentations as A
  49. >>> import cv2
  50. >>>
  51. >>> # Prepare primary data (base image and mask)
  52. >>> image = np.zeros((300, 300, 3), dtype=np.uint8)
  53. >>> mask = np.zeros((300, 300), dtype=np.uint8)
  54. >>>
  55. >>> # 1. Create a simple overlay image (a red square)
  56. >>> overlay_image1 = np.zeros((50, 50, 3), dtype=np.uint8)
  57. >>> overlay_image1[:, :, 0] = 255 # Red color
  58. >>>
  59. >>> # 2. Create another overlay with a mask (a blue circle with transparency)
  60. >>> overlay_image2 = np.zeros((80, 80, 3), dtype=np.uint8)
  61. >>> overlay_image2[:, :, 2] = 255 # Blue color
  62. >>> overlay_mask2 = np.zeros((80, 80), dtype=np.uint8)
  63. >>> # Create a circular mask
  64. >>> center = (40, 40)
  65. >>> radius = 30
  66. >>> for i in range(80):
  67. ... for j in range(80):
  68. ... if (i - center[0])**2 + (j - center[1])**2 < radius**2:
  69. ... overlay_mask2[i, j] = 255
  70. >>>
  71. >>> # 3. Create an overlay with both bbox and mask_id
  72. >>> overlay_image3 = np.zeros((60, 120, 3), dtype=np.uint8)
  73. >>> overlay_image3[:, :, 1] = 255 # Green color
  74. >>> # Create a rectangular mask with rounded corners
  75. >>> overlay_mask3 = np.zeros((60, 120), dtype=np.uint8)
  76. >>> cv2.rectangle(overlay_mask3, (10, 10), (110, 50), 255, -1)
  77. >>>
  78. >>> # Create the metadata list - each item is a dictionary with overlay information
  79. >>> overlay_metadata = [
  80. ... {
  81. ... 'image': overlay_image1,
  82. ... # No bbox provided - will be placed randomly
  83. ... },
  84. ... {
  85. ... 'image': overlay_image2,
  86. ... 'bbox': [0.6, 0.1, 0.9, 0.4], # Normalized coordinates [x_min, y_min, x_max, y_max]
  87. ... 'mask': overlay_mask2,
  88. ... 'mask_id': 1 # This overlay will update the mask with id 1
  89. ... },
  90. ... {
  91. ... 'image': overlay_image3,
  92. ... 'bbox': [0.1, 0.7, 0.5, 0.9], # Bottom left placement
  93. ... 'mask': overlay_mask3,
  94. ... 'mask_id': 2 # This overlay will update the mask with id 2
  95. ... }
  96. ... ]
  97. >>>
  98. >>> # Create the transform
  99. >>> transform = A.Compose([
  100. ... A.OverlayElements(p=1.0),
  101. ... ])
  102. >>>
  103. >>> # Apply the transform
  104. >>> result = transform(
  105. ... image=image,
  106. ... mask=mask,
  107. ... overlay_metadata=overlay_metadata # Pass metadata using the default key
  108. ... )
  109. >>>
  110. >>> # Get results with overlays applied
  111. >>> result_image = result['image'] # Image with the three overlays applied
  112. >>> result_mask = result['mask'] # Mask with regions labeled using the mask_id values
  113. >>>
  114. >>> # Let's verify the mask contains the specified mask_id values
  115. >>> has_mask_id_1 = np.any(result_mask == 1) # Should be True
  116. >>> has_mask_id_2 = np.any(result_mask == 2) # Should be True
  117. """
  118. _targets = (Targets.IMAGE, Targets.MASK)
  119. class InitSchema(BaseTransformInitSchema):
  120. metadata_key: str
  121. def __init__(
  122. self,
  123. metadata_key: str = "overlay_metadata",
  124. p: float = 0.5,
  125. ):
  126. super().__init__(p=p)
  127. self.metadata_key = metadata_key
  128. @property
  129. def targets_as_params(self) -> list[str]:
  130. """Get list of targets that should be passed as parameters to transforms.
  131. Returns:
  132. list[str]: List containing the metadata key name
  133. """
  134. return [self.metadata_key]
  135. @staticmethod
  136. def preprocess_metadata(
  137. metadata: dict[str, Any],
  138. img_shape: tuple[int, int],
  139. random_state: random.Random,
  140. ) -> dict[str, Any]:
  141. """Process overlay metadata to prepare for application.
  142. Args:
  143. metadata (dict[str, Any]): Dictionary containing overlay data such as image, mask, bbox
  144. img_shape (tuple[int, int]): Shape of the target image as (height, width)
  145. random_state (random.Random): Random state object for reproducible randomness
  146. Returns:
  147. dict[str, Any]: Processed overlay data including resized overlay image, mask,
  148. offset coordinates, and bounding box information
  149. """
  150. overlay_image = metadata["image"]
  151. overlay_height, overlay_width = overlay_image.shape[:2]
  152. image_height, image_width = img_shape[:2]
  153. if "bbox" in metadata:
  154. bbox = metadata["bbox"]
  155. bbox_np = np.array([bbox])
  156. check_bboxes(bbox_np)
  157. denormalized_bbox = denormalize_bboxes(bbox_np, img_shape[:2])[0]
  158. x_min, y_min, x_max, y_max = (int(x) for x in denormalized_bbox[:4])
  159. if "mask" in metadata:
  160. mask = metadata["mask"]
  161. mask = cv2.resize(mask, (x_max - x_min, y_max - y_min), interpolation=cv2.INTER_NEAREST)
  162. else:
  163. mask = np.ones((y_max - y_min, x_max - x_min), dtype=np.uint8)
  164. overlay_image = cv2.resize(overlay_image, (x_max - x_min, y_max - y_min), interpolation=cv2.INTER_AREA)
  165. offset = (y_min, x_min)
  166. if len(bbox) == LENGTH_RAW_BBOX and "bbox_id" in metadata:
  167. bbox = [x_min, y_min, x_max, y_max, metadata["bbox_id"]]
  168. else:
  169. bbox = (x_min, y_min, x_max, y_max, *bbox[4:])
  170. else:
  171. if image_height < overlay_height or image_width < overlay_width:
  172. overlay_image = cv2.resize(overlay_image, (image_width, image_height), interpolation=cv2.INTER_AREA)
  173. overlay_height, overlay_width = overlay_image.shape[:2]
  174. mask = metadata["mask"] if "mask" in metadata else np.ones_like(overlay_image, dtype=np.uint8)
  175. max_x_offset = image_width - overlay_width
  176. max_y_offset = image_height - overlay_height
  177. offset_x = random_state.randint(0, max_x_offset)
  178. offset_y = random_state.randint(0, max_y_offset)
  179. offset = (offset_y, offset_x)
  180. bbox = [
  181. offset_x,
  182. offset_y,
  183. offset_x + overlay_width,
  184. offset_y + overlay_height,
  185. ]
  186. if "bbox_id" in metadata:
  187. bbox = [*bbox, metadata["bbox_id"]]
  188. result = {
  189. "overlay_image": overlay_image,
  190. "overlay_mask": mask,
  191. "offset": offset,
  192. "bbox": bbox,
  193. }
  194. if "mask_id" in metadata:
  195. result["mask_id"] = metadata["mask_id"]
  196. return result
  197. def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
  198. """Generate parameters for overlay transform based on input data.
  199. Args:
  200. params (dict[str, Any]): Dictionary of existing parameters
  201. data (dict[str, Any]): Dictionary containing input data with image and metadata
  202. Returns:
  203. dict[str, Any]: Dictionary containing processed overlay data ready for application
  204. """
  205. metadata = data[self.metadata_key]
  206. img_shape = params["shape"]
  207. if isinstance(metadata, list):
  208. overlay_data = [self.preprocess_metadata(md, img_shape, self.py_random) for md in metadata]
  209. else:
  210. overlay_data = [self.preprocess_metadata(metadata, img_shape, self.py_random)]
  211. return {
  212. "overlay_data": overlay_data,
  213. }
  214. def apply(
  215. self,
  216. img: np.ndarray,
  217. overlay_data: list[dict[str, Any]],
  218. **params: Any,
  219. ) -> np.ndarray:
  220. """Apply overlay elements to the input image.
  221. Args:
  222. img (np.ndarray): Input image
  223. overlay_data (list[dict[str, Any]]): List of dictionaries containing overlay information
  224. **params (Any): Additional parameters
  225. Returns:
  226. np.ndarray: Image with overlays applied
  227. """
  228. for data in overlay_data:
  229. overlay_image = data["overlay_image"]
  230. overlay_mask = data["overlay_mask"]
  231. offset = data["offset"]
  232. img = fmixing.copy_and_paste_blend(img, overlay_image, overlay_mask, offset=offset)
  233. return img
  234. def apply_to_mask(
  235. self,
  236. mask: np.ndarray,
  237. overlay_data: list[dict[str, Any]],
  238. **params: Any,
  239. ) -> np.ndarray:
  240. """Apply overlay masks to the input mask.
  241. Args:
  242. mask (np.ndarray): Input mask
  243. overlay_data (list[dict[str, Any]]): List of dictionaries containing overlay information
  244. **params (Any): Additional parameters
  245. Returns:
  246. np.ndarray: Mask with overlay masks applied using the specified mask_id values
  247. """
  248. for data in overlay_data:
  249. if "mask_id" in data and data["mask_id"] is not None:
  250. overlay_mask = data["overlay_mask"]
  251. offset = data["offset"]
  252. mask_id = data["mask_id"]
  253. y_min, x_min = offset
  254. y_max = y_min + overlay_mask.shape[0]
  255. x_max = x_min + overlay_mask.shape[1]
  256. mask_section = mask[y_min:y_max, x_min:x_max]
  257. mask_section[overlay_mask > 0] = mask_id
  258. return mask
  259. class Mosaic(DualTransform):
  260. """Combine multiple images and their annotations into a single image using a mosaic grid layout.
  261. This transform takes a primary input image (and its annotations) and combines it with
  262. additional images/annotations provided via metadata. It calculates the geometry for
  263. a mosaic grid, selects additional items, preprocesses annotations consistently
  264. (handling label encoding updates), applies geometric transformations, and assembles
  265. the final output.
  266. Args:
  267. grid_yx (tuple[int, int]): The number of rows (y) and columns (x) in the mosaic grid.
  268. Determines the maximum number of images involved (grid_yx[0] * grid_yx[1]).
  269. Default: (2, 2).
  270. target_size (tuple[int, int]): The desired output (height, width) for the final mosaic image.
  271. after cropping the mosaic grid.
  272. cell_shape (tuple[int, int]): cell shape of each cell in the mosaic grid.
  273. metadata_key (str): Key in the input dictionary specifying the list of additional data dictionaries
  274. for the mosaic. Each dictionary in the list should represent one potential additional item.
  275. Expected keys: 'image' (required, np.ndarray), and optionally 'mask' (np.ndarray),
  276. 'bboxes' (np.ndarray), 'keypoints' (np.ndarray), and any relevant label fields
  277. (e.g., 'class_labels') corresponding to those specified in `Compose`'s `bbox_params` or
  278. `keypoint_params`. Default: "mosaic_metadata".
  279. center_range (tuple[float, float]): Range [0.0-1.0] to sample the center point of the mosaic view
  280. relative to the valid central region of the conceptual large grid. This affects which parts
  281. of the assembled grid are visible in the final crop. Default: (0.3, 0.7).
  282. interpolation (int): OpenCV interpolation flag used for resizing images during geometric processing.
  283. Default: cv2.INTER_LINEAR.
  284. mask_interpolation (int): OpenCV interpolation flag used for resizing masks during geometric processing.
  285. Default: cv2.INTER_NEAREST.
  286. fill (tuple[float, ...] | float): Value used for padding images if needed during geometric processing.
  287. Default: 0.
  288. fill_mask (tuple[float, ...] | float): Value used for padding masks if needed during geometric processing.
  289. Default: 0.
  290. p (float): Probability of applying the transform. Default: 0.5.
  291. Workflow (`get_params_dependent_on_data`):
  292. 1. Calculate Geometry & Visible Cells: Determine which grid cells are visible in the final
  293. `target_size` crop and their placement coordinates on the output canvas.
  294. 2. Validate Raw Additional Metadata: Filter the list provided via `metadata_key`,
  295. keeping only valid items (dicts with an 'image' key).
  296. 3. Select Subset of Raw Additional Metadata: Choose a subset of the valid raw items based
  297. on the number of visible cells requiring additional data.
  298. 4. Preprocess Selected Raw Additional Items: Preprocess bboxes/keypoints for the *selected*
  299. additional items *only*. This uses shared processors from `Compose`, updating their
  300. internal state (e.g., `LabelEncoder`) based on labels in these selected items.
  301. 5. Prepare Primary Data: Extract preprocessed primary data fields from the input `data` dictionary
  302. into a `primary` dictionary.
  303. 6. Determine & Perform Replication: If fewer additional items were selected than needed,
  304. replicate the preprocessed primary data as required.
  305. 7. Combine Final Items: Create the list of all preprocessed items (primary, selected additional,
  306. replicated primary) that will be used.
  307. 8. Assign Items to VISIBLE Grid Cells
  308. 9. Process Geometry & Shift Coordinates: For each assigned item:
  309. a. Apply geometric transforms (Crop, Resize, Pad) to image/mask.
  310. b. Apply geometric shift to the *preprocessed* bboxes/keypoints based on cell placement.
  311. 10. Return Parameters: Return the processed cell data (image, mask, shifted bboxes, shifted kps)
  312. keyed by placement coordinates.
  313. Label Handling:
  314. - The transform relies on `bbox_processor` and `keypoint_processor` provided by `Compose`.
  315. - `Compose.preprocess` initially fits the processors' `LabelEncoder` on the primary data.
  316. - This transform (`Mosaic`) preprocesses the *selected* additional raw items using the same
  317. processors. If new labels are found, the shared `LabelEncoder` state is updated via its
  318. `update` method.
  319. - `Compose.postprocess` uses the final updated encoder state to decode all labels present
  320. in the mosaic output for the current `Compose` call.
  321. - The encoder state is transient per `Compose` call.
  322. Targets:
  323. image, mask, bboxes, keypoints
  324. Image types:
  325. uint8, float32
  326. Reference:
  327. YOLOv4: Optimal Speed and Accuracy of Object Detection: https://arxiv.org/pdf/2004.10934
  328. Examples:
  329. >>> import numpy as np
  330. >>> import albumentations as A
  331. >>> import cv2
  332. >>>
  333. >>> # Prepare primary data
  334. >>> primary_image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
  335. >>> primary_mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
  336. >>> primary_bboxes = np.array([[10, 10, 40, 40], [50, 50, 90, 90]], dtype=np.float32)
  337. >>> primary_labels = [1, 2]
  338. >>>
  339. >>> # Prepare additional images for mosaic
  340. >>> additional_image1 = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
  341. >>> additional_mask1 = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
  342. >>> additional_bboxes1 = np.array([[20, 20, 60, 60]], dtype=np.float32)
  343. >>> additional_labels1 = [3]
  344. >>>
  345. >>> additional_image2 = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
  346. >>> additional_mask2 = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
  347. >>> additional_bboxes2 = np.array([[30, 30, 70, 70]], dtype=np.float32)
  348. >>> additional_labels2 = [4]
  349. >>>
  350. >>> additional_image3 = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
  351. >>> additional_mask3 = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
  352. >>> additional_bboxes3 = np.array([[5, 5, 45, 45]], dtype=np.float32)
  353. >>> additional_labels3 = [5]
  354. >>>
  355. >>> # Create metadata for additional images - structured as a list of dicts
  356. >>> mosaic_metadata = [
  357. ... {
  358. ... 'image': additional_image1,
  359. ... 'mask': additional_mask1,
  360. ... 'bboxes': additional_bboxes1,
  361. ... 'labels': additional_labels1
  362. ... },
  363. ... {
  364. ... 'image': additional_image2,
  365. ... 'mask': additional_mask2,
  366. ... 'bboxes': additional_bboxes2,
  367. ... 'labels': additional_labels2
  368. ... },
  369. ... {
  370. ... 'image': additional_image3,
  371. ... 'mask': additional_mask3,
  372. ... 'bboxes': additional_bboxes3,
  373. ... 'labels': additional_labels3
  374. ... }
  375. ... ]
  376. >>>
  377. >>> # Create the transform with Mosaic
  378. >>> transform = A.Compose([
  379. ... A.Mosaic(
  380. ... grid_yx=(2, 2),
  381. ... target_size=(200, 200),
  382. ... cell_shape=(120, 120),
  383. ... center_range=(0.4, 0.6),
  384. ... fit_mode="cover",
  385. ... p=1.0
  386. ... ),
  387. ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']))
  388. >>>
  389. >>> # Apply the transform
  390. >>> transformed = transform(
  391. ... image=primary_image,
  392. ... mask=primary_mask,
  393. ... bboxes=primary_bboxes,
  394. ... labels=primary_labels,
  395. ... mosaic_metadata=mosaic_metadata # Pass the metadata using the default key
  396. ... )
  397. >>>
  398. >>> # Access the transformed data
  399. >>> mosaic_image = transformed['image'] # Combined mosaic image
  400. >>> mosaic_mask = transformed['mask'] # Combined mosaic mask
  401. >>> mosaic_bboxes = transformed['bboxes'] # Combined and repositioned bboxes
  402. >>> mosaic_labels = transformed['labels'] # Combined labels from all images
  403. """
  404. _targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)
  405. class InitSchema(BaseTransformInitSchema):
  406. grid_yx: tuple[int, int]
  407. target_size: Annotated[
  408. tuple[int, int],
  409. AfterValidator(check_range_bounds(1, None)),
  410. ]
  411. cell_shape: Annotated[
  412. tuple[int, int],
  413. AfterValidator(check_range_bounds(1, None)),
  414. ]
  415. metadata_key: str
  416. center_range: Annotated[
  417. tuple[float, float],
  418. AfterValidator(check_range_bounds(0, 1)),
  419. AfterValidator(nondecreasing),
  420. ]
  421. interpolation: Literal[
  422. cv2.INTER_NEAREST,
  423. cv2.INTER_NEAREST_EXACT,
  424. cv2.INTER_LINEAR,
  425. cv2.INTER_CUBIC,
  426. cv2.INTER_AREA,
  427. cv2.INTER_LANCZOS4,
  428. cv2.INTER_LINEAR_EXACT,
  429. ]
  430. mask_interpolation: Literal[
  431. cv2.INTER_NEAREST,
  432. cv2.INTER_NEAREST_EXACT,
  433. cv2.INTER_LINEAR,
  434. cv2.INTER_CUBIC,
  435. cv2.INTER_AREA,
  436. cv2.INTER_LANCZOS4,
  437. cv2.INTER_LINEAR_EXACT,
  438. ]
  439. fill: tuple[float, ...] | float
  440. fill_mask: tuple[float, ...] | float
  441. fit_mode: Literal["cover", "contain"]
  442. @model_validator(mode="after")
  443. def _check_cell_shape(self) -> Self:
  444. if (
  445. self.cell_shape[0] * self.grid_yx[0] < self.target_size[0]
  446. or self.cell_shape[1] * self.grid_yx[1] < self.target_size[1]
  447. ):
  448. raise ValueError("Target size should be smaller than cell cell_size * grid_yx")
  449. return self
  450. def __init__(
  451. self,
  452. grid_yx: tuple[int, int] = (2, 2),
  453. target_size: tuple[int, int] = (512, 512),
  454. cell_shape: tuple[int, int] = (512, 512),
  455. center_range: tuple[float, float] = (0.3, 0.7),
  456. fit_mode: Literal["cover", "contain"] = "cover",
  457. interpolation: Literal[
  458. cv2.INTER_NEAREST,
  459. cv2.INTER_NEAREST_EXACT,
  460. cv2.INTER_LINEAR,
  461. cv2.INTER_CUBIC,
  462. cv2.INTER_AREA,
  463. cv2.INTER_LANCZOS4,
  464. cv2.INTER_LINEAR_EXACT,
  465. ] = cv2.INTER_LINEAR,
  466. mask_interpolation: Literal[
  467. cv2.INTER_NEAREST,
  468. cv2.INTER_NEAREST_EXACT,
  469. cv2.INTER_LINEAR,
  470. cv2.INTER_CUBIC,
  471. cv2.INTER_AREA,
  472. cv2.INTER_LANCZOS4,
  473. cv2.INTER_LINEAR_EXACT,
  474. ] = cv2.INTER_NEAREST,
  475. fill: tuple[float, ...] | float = 0,
  476. fill_mask: tuple[float, ...] | float = 0,
  477. metadata_key: str = "mosaic_metadata",
  478. p: float = 0.5,
  479. ) -> None:
  480. super().__init__(p=p)
  481. self.grid_yx = grid_yx
  482. self.target_size = target_size
  483. self.metadata_key = metadata_key
  484. self.center_range = center_range
  485. self.interpolation = interpolation
  486. self.mask_interpolation = mask_interpolation
  487. self.fill = fill
  488. self.fill_mask = fill_mask
  489. self.fit_mode = fit_mode
  490. self.cell_shape = cell_shape
  491. @property
  492. def targets_as_params(self) -> list[str]:
  493. """Get list of targets that should be passed as parameters to transforms.
  494. Returns:
  495. list[str]: List containing the metadata key name
  496. """
  497. return [self.metadata_key]
  498. def _calculate_geometry(self, data: dict[str, Any]) -> list[tuple[int, int, int, int]]:
  499. # Step 1: Calculate Geometry & Cell Placements
  500. center_xy = fmixing.calculate_mosaic_center_point(
  501. grid_yx=self.grid_yx,
  502. cell_shape=self.cell_shape,
  503. target_size=self.target_size,
  504. center_range=self.center_range,
  505. py_random=self.py_random,
  506. )
  507. return fmixing.calculate_cell_placements(
  508. grid_yx=self.grid_yx,
  509. cell_shape=self.cell_shape,
  510. target_size=self.target_size,
  511. center_xy=center_xy,
  512. )
  513. def _select_additional_items(self, data: dict[str, Any], num_additional_needed: int) -> list[dict[str, Any]]:
  514. valid_items = fmixing.filter_valid_metadata(data.get(self.metadata_key), self.metadata_key, data)
  515. if len(valid_items) > num_additional_needed:
  516. return self.py_random.sample(valid_items, num_additional_needed)
  517. return valid_items
  518. def _preprocess_additional_items(
  519. self,
  520. additional_items: list[dict[str, Any]],
  521. data: dict[str, Any],
  522. ) -> list[fmixing.ProcessedMosaicItem]:
  523. if "bboxes" in data or "keypoints" in data:
  524. bbox_processor = cast("BboxProcessor", self.get_processor("bboxes"))
  525. keypoint_processor = cast("KeypointsProcessor", self.get_processor("keypoints"))
  526. return fmixing.preprocess_selected_mosaic_items(additional_items, bbox_processor, keypoint_processor)
  527. return cast("list[fmixing.ProcessedMosaicItem]", list(additional_items))
  528. def _prepare_final_items(
  529. self,
  530. primary: fmixing.ProcessedMosaicItem,
  531. additional_items: list[fmixing.ProcessedMosaicItem],
  532. num_needed: int,
  533. ) -> list[fmixing.ProcessedMosaicItem]:
  534. num_replications = max(0, num_needed - len(additional_items))
  535. replicated = [deepcopy(primary) for _ in range(num_replications)]
  536. return [primary, *additional_items, *replicated]
  537. def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
  538. """Orchestrates the steps to calculate mosaic parameters by calling helper methods."""
  539. cell_placements = self._calculate_geometry(data)
  540. num_cells = len(cell_placements)
  541. num_additional_needed = max(0, num_cells - 1)
  542. additional_items = self._select_additional_items(data, num_additional_needed)
  543. preprocessed_additional = self._preprocess_additional_items(additional_items, data)
  544. primary = self.get_primary_data(data)
  545. final_items = self._prepare_final_items(primary, preprocessed_additional, num_additional_needed)
  546. placement_to_item_index = fmixing.assign_items_to_grid_cells(
  547. num_items=len(final_items),
  548. cell_placements=cell_placements,
  549. py_random=self.py_random,
  550. )
  551. processed_cells = fmixing.process_all_mosaic_geometries(
  552. canvas_shape=self.target_size,
  553. cell_shape=self.cell_shape,
  554. placement_to_item_index=placement_to_item_index,
  555. final_items_for_grid=final_items,
  556. fill=self.fill,
  557. fill_mask=self.fill_mask if self.fill_mask is not None else self.fill,
  558. fit_mode=self.fit_mode,
  559. interpolation=self.interpolation,
  560. mask_interpolation=self.mask_interpolation,
  561. )
  562. if "bboxes" in data or "keypoints" in data:
  563. processed_cells = fmixing.shift_all_coordinates(processed_cells, canvas_shape=self.target_size)
  564. result = {"processed_cells": processed_cells, "target_shape": self._get_target_shape(data["image"].shape)}
  565. if "mask" in data:
  566. result["target_mask_shape"] = self._get_target_shape(data["mask"].shape)
  567. return result
  568. @staticmethod
  569. def get_primary_data(data: dict[str, Any]) -> fmixing.ProcessedMosaicItem:
  570. """Get a copy of the primary data (data passed in `data` parameter) to avoid modifying the original data.
  571. Args:
  572. data (dict[str, Any]): Dictionary containing the primary data.
  573. Returns:
  574. fmixing.ProcessedMosaicItem: A copy of the primary data.
  575. """
  576. mask = data.get("mask")
  577. if mask is not None:
  578. mask = mask.copy()
  579. bboxes = data.get("bboxes")
  580. if bboxes is not None:
  581. bboxes = bboxes.copy()
  582. keypoints = data.get("keypoints")
  583. if keypoints is not None:
  584. keypoints = keypoints.copy()
  585. return {
  586. "image": data["image"],
  587. "mask": mask,
  588. "bboxes": bboxes,
  589. "keypoints": keypoints,
  590. }
  591. def _get_target_shape(self, np_shape: tuple[int, ...]) -> list[int]:
  592. target_shape = list(np_shape)
  593. target_shape[0] = self.target_size[0]
  594. target_shape[1] = self.target_size[1]
  595. return target_shape
  596. def apply(
  597. self,
  598. img: np.ndarray,
  599. processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
  600. target_shape: tuple[int, int],
  601. **params: Any,
  602. ) -> np.ndarray:
  603. """Apply mosaic transformation to the input image.
  604. Args:
  605. img (np.ndarray): Input image
  606. processed_cells (dict[tuple[int, int, int, int], dict[str, Any]]): Dictionary of processed cell data
  607. target_shape (tuple[int, int]): Shape of the target image.
  608. **params (Any): Additional parameters
  609. Returns:
  610. np.ndarray: Mosaic transformed image
  611. """
  612. return fmixing.assemble_mosaic_from_processed_cells(
  613. processed_cells=processed_cells,
  614. target_shape=target_shape,
  615. dtype=img.dtype,
  616. data_key="image",
  617. fill=self.fill,
  618. )
  619. def apply_to_mask(
  620. self,
  621. mask: np.ndarray,
  622. processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
  623. target_mask_shape: tuple[int, int],
  624. **params: Any,
  625. ) -> np.ndarray:
  626. """Apply mosaic transformation to the input mask.
  627. Args:
  628. mask (np.ndarray): Input mask.
  629. processed_cells (dict): Dictionary of processed cell data containing cropped/padded mask segments.
  630. target_mask_shape (tuple[int, int]): Shape of the target mask.
  631. **params (Any): Additional parameters (unused).
  632. Returns:
  633. np.ndarray: Mosaic transformed mask.
  634. """
  635. return fmixing.assemble_mosaic_from_processed_cells(
  636. processed_cells=processed_cells,
  637. target_shape=target_mask_shape,
  638. dtype=mask.dtype,
  639. data_key="mask",
  640. fill=self.fill_mask,
  641. )
  642. def apply_to_bboxes(
  643. self,
  644. bboxes: np.ndarray, # Original bboxes - ignored
  645. processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
  646. **params: Any,
  647. ) -> np.ndarray:
  648. """Applies mosaic transformation to bounding boxes.
  649. Args:
  650. bboxes (np.ndarray): Original bounding boxes (ignored).
  651. processed_cells (dict): Dictionary mapping placement coords to processed cell data
  652. (containing shifted bboxes in absolute pixel coords).
  653. **params (Any): Additional parameters (unused).
  654. Returns:
  655. np.ndarray: Final combined, filtered, bounding boxes.
  656. """
  657. all_shifted_bboxes = []
  658. for cell_data in processed_cells.values():
  659. shifted_bboxes = cell_data["bboxes"]
  660. if shifted_bboxes.size > 0:
  661. all_shifted_bboxes.append(shifted_bboxes)
  662. if not all_shifted_bboxes:
  663. return np.empty((0, bboxes.shape[1]), dtype=bboxes.dtype)
  664. # Concatenate (these are absolute pixel coordinates)
  665. combined_bboxes = np.concatenate(all_shifted_bboxes, axis=0)
  666. # Apply filtering using processor parameters
  667. bbox_processor = cast("BboxProcessor", self.get_processor("bboxes"))
  668. # Assume processor exists if bboxes are being processed
  669. shape_dict: dict[Literal["depth", "height", "width"], int] = {
  670. "height": self.target_size[0],
  671. "width": self.target_size[1],
  672. }
  673. return filter_bboxes(
  674. combined_bboxes,
  675. shape_dict,
  676. min_area=bbox_processor.params.min_area,
  677. min_visibility=bbox_processor.params.min_visibility,
  678. min_width=bbox_processor.params.min_width,
  679. min_height=bbox_processor.params.min_height,
  680. max_accept_ratio=bbox_processor.params.max_accept_ratio,
  681. )
  682. def apply_to_keypoints(
  683. self,
  684. keypoints: np.ndarray, # Original keypoints - ignored
  685. processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
  686. **params: Any,
  687. ) -> np.ndarray:
  688. """Applies mosaic transformation to keypoints.
  689. Args:
  690. keypoints (np.ndarray): Original keypoints (ignored).
  691. processed_cells (dict): Dictionary mapping placement coords to processed cell data
  692. (containing shifted keypoints).
  693. **params (Any): Additional parameters (unused).
  694. Returns:
  695. np.ndarray: Final combined, filtered keypoints.
  696. """
  697. all_shifted_keypoints = []
  698. for cell_data in processed_cells.values():
  699. shifted_keypoints = cell_data["keypoints"]
  700. if shifted_keypoints.size > 0:
  701. all_shifted_keypoints.append(shifted_keypoints)
  702. if not all_shifted_keypoints:
  703. return np.empty((0, keypoints.shape[1]), dtype=keypoints.dtype)
  704. combined_keypoints = np.concatenate(all_shifted_keypoints, axis=0)
  705. # Filter out keypoints outside the target canvas boundaries
  706. target_h, target_w = self.target_size
  707. valid_indices = (
  708. (combined_keypoints[:, 0] >= 0)
  709. & (combined_keypoints[:, 0] < target_w)
  710. & (combined_keypoints[:, 1] >= 0)
  711. & (combined_keypoints[:, 1] < target_h)
  712. )
  713. return combined_keypoints[valid_indices]