functional.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878
  1. """Functional implementations for image mixing operations.
  2. This module provides utility functions for blending and combining images,
  3. such as copy-and-paste operations with masking.
  4. """
  5. from __future__ import annotations
  6. import random
  7. from collections.abc import Sequence
  8. from typing import Any, Literal, TypedDict, cast
  9. from warnings import warn
  10. import cv2
  11. import numpy as np
  12. import albumentations.augmentations.geometric.functional as fgeometric
  13. from albumentations.augmentations.crops.transforms import Crop
  14. from albumentations.augmentations.geometric.resize import LongestMaxSize, SmallestMaxSize
  15. from albumentations.core.bbox_utils import BboxProcessor, denormalize_bboxes, normalize_bboxes
  16. from albumentations.core.composition import Compose
  17. from albumentations.core.keypoints_utils import KeypointsProcessor
  18. from albumentations.core.type_definitions import (
  19. NUM_BBOXES_COLUMNS_IN_ALBUMENTATIONS,
  20. NUM_KEYPOINTS_COLUMNS_IN_ALBUMENTATIONS,
  21. )
# Type definition for a processed mosaic item
class ProcessedMosaicItem(TypedDict):
    """Represents a single data item (primary or additional) after preprocessing.

    Includes the original image/mask and the *preprocessed* annotations.
    All four keys are declared required (the class is total); optional content
    is signalled by a ``None`` value rather than by omitting the key.
    """

    image: np.ndarray  # Image is mandatory
    mask: np.ndarray | None  # None when the source item has no mask
    bboxes: np.ndarray | None  # Bounding boxes, or None when absent; presumably in the bbox processor's internal format — confirm
    keypoints: np.ndarray | None  # Keypoints, or None when absent; presumably in the keypoint processor's internal format — confirm
  31. def copy_and_paste_blend(
  32. base_image: np.ndarray,
  33. overlay_image: np.ndarray,
  34. overlay_mask: np.ndarray,
  35. offset: tuple[int, int],
  36. ) -> np.ndarray:
  37. """Blend images by copying pixels from an overlay image to a base image using a mask.
  38. This function copies pixels from the overlay image to the base image only where
  39. the mask has non-zero values. The overlay is placed at the specified offset
  40. from the top-left corner of the base image.
  41. Args:
  42. base_image (np.ndarray): The destination image that will be modified.
  43. overlay_image (np.ndarray): The source image containing pixels to copy.
  44. overlay_mask (np.ndarray): Binary mask indicating which pixels to copy from the overlay.
  45. Pixels are copied where mask > 0.
  46. offset (tuple[int, int]): The (y, x) offset specifying where to place the
  47. top-left corner of the overlay relative to the base image.
  48. Returns:
  49. np.ndarray: The blended image with the overlay applied to the base image.
  50. """
  51. y_offset, x_offset = offset
  52. blended_image = base_image.copy()
  53. mask_indices = np.where(overlay_mask > 0)
  54. blended_image[mask_indices[0] + y_offset, mask_indices[1] + x_offset] = overlay_image[
  55. mask_indices[0],
  56. mask_indices[1],
  57. ]
  58. return blended_image
  59. def calculate_mosaic_center_point(
  60. grid_yx: tuple[int, int],
  61. cell_shape: tuple[int, int],
  62. target_size: tuple[int, int],
  63. center_range: tuple[float, float],
  64. py_random: random.Random,
  65. ) -> tuple[int, int]:
  66. """Calculates the center point for the mosaic crop using proportional sampling within the valid zone.
  67. Ensures the center point allows a crop of target_size to overlap
  68. all grid cells, applying randomness based on center_range proportionally
  69. within the valid region where the center can lie.
  70. Args:
  71. grid_yx (tuple[int, int]): The (rows, cols) of the mosaic grid.
  72. cell_shape (tuple[int, int]): Shape of each cell in the mosaic grid.
  73. target_size (tuple[int, int]): The final output (height, width).
  74. center_range (tuple[float, float]): Range [0.0-1.0] for sampling center proportionally
  75. within the valid zone.
  76. py_random (random.Random): Random state instance.
  77. Returns:
  78. tuple[int, int]: The calculated (x, y) center point relative to the
  79. top-left of the conceptual large grid.
  80. """
  81. rows, cols = grid_yx
  82. cell_h, cell_w = cell_shape
  83. target_h, target_w = target_size
  84. large_grid_h = rows * cell_h
  85. large_grid_w = cols * cell_w
  86. # Define valid center range bounds (inclusive)
  87. # The center must be far enough from edges so the crop window fits
  88. min_cx = target_w // 2
  89. max_cx = large_grid_w - (target_w + 1) // 2
  90. min_cy = target_h // 2
  91. max_cy = large_grid_h - (target_h + 1) // 2
  92. # Calculate valid range dimensions (size of the safe zone)
  93. valid_w = max_cx - min_cx + 1
  94. valid_h = max_cy - min_cy + 1
  95. # Sample relative position within the valid range using center_range
  96. rel_x = py_random.uniform(*center_range)
  97. rel_y = py_random.uniform(*center_range)
  98. # Calculate center coordinates by scaling relative position within valid range
  99. # Add the minimum bound to shift the range start
  100. center_x = min_cx + int(valid_w * rel_x)
  101. center_y = min_cy + int(valid_h * rel_y)
  102. # Ensure the result is strictly within the calculated bounds after int conversion
  103. # (This clip is mostly a safety measure, shouldn't be needed with correct int conversion)
  104. center_x = max(min_cx, min(center_x, max_cx))
  105. center_y = max(min_cy, min(center_y, max_cy))
  106. return center_x, center_y
  107. def calculate_cell_placements(
  108. grid_yx: tuple[int, int],
  109. cell_shape: tuple[int, int],
  110. target_size: tuple[int, int],
  111. center_xy: tuple[int, int],
  112. ) -> list[tuple[int, int, int, int]]:
  113. """Calculates placements by clipping arange-defined grid lines to the crop window.
  114. Args:
  115. grid_yx (tuple[int, int]): The (rows, cols) of the mosaic grid.
  116. cell_shape (tuple[int, int]): Shape of each cell in the mosaic grid.
  117. target_size (tuple[int, int]): The final output (height, width).
  118. center_xy (tuple[int, int]): The calculated (x, y) center of the final crop window,
  119. relative to the top-left of the conceptual large grid.
  120. Returns:
  121. list[tuple[int, int, int, int]]:
  122. A list containing placement coordinates `(x_min, y_min, x_max, y_max)`
  123. for each resulting cell part on the final output canvas.
  124. """
  125. rows, cols = grid_yx
  126. cell_h, cell_w = cell_shape
  127. target_h, target_w = target_size
  128. center_x, center_y = center_xy
  129. # 1. Generate grid line coordinates using arange for the large grid
  130. y_coords_large = np.arange(rows + 1) * cell_h
  131. x_coords_large = np.arange(cols + 1) * cell_w
  132. # 2. Calculate Crop Window boundaries
  133. crop_x_min = center_x - target_w // 2
  134. crop_y_min = center_y - target_h // 2
  135. crop_x_max = crop_x_min + target_w
  136. crop_y_max = crop_y_min + target_h
  137. def _clip_coords(coords: np.ndarray, min_val: int, max_val: int) -> np.ndarray:
  138. clipped_coords = np.clip(coords, min_val, max_val)
  139. # Subtract min_val to convert absolute clipped coordinates
  140. # into coordinates relative to the crop window's origin (min_val becomes 0).
  141. return np.unique(clipped_coords) - min_val
  142. y_coords_clipped = _clip_coords(y_coords_large, crop_y_min, crop_y_max)
  143. x_coords_clipped = _clip_coords(x_coords_large, crop_x_min, crop_x_max)
  144. # 4. Form all cell coordinates efficiently
  145. num_x_intervals = len(x_coords_clipped) - 1
  146. num_y_intervals = len(y_coords_clipped) - 1
  147. result = []
  148. for y_idx in range(num_y_intervals):
  149. y_min = y_coords_clipped[y_idx]
  150. y_max = y_coords_clipped[y_idx + 1]
  151. for x_idx in range(num_x_intervals):
  152. x_min = x_coords_clipped[x_idx]
  153. x_max = x_coords_clipped[x_idx + 1]
  154. result.append((int(x_min), int(y_min), int(x_max), int(y_max)))
  155. return result
  156. def _check_data_compatibility(
  157. primary_data: np.ndarray | None,
  158. item_data: np.ndarray | None,
  159. data_key: Literal["image", "mask"],
  160. ) -> tuple[bool, str | None]: # Returns (is_compatible, error_message)
  161. """Checks if the dimensions and channels of item_data match primary_data."""
  162. # 1. Check if item has the required data (image is always required)
  163. if item_data is None:
  164. if data_key == "image":
  165. return False, "Item is missing required key 'image'"
  166. # Mask is optional, missing is compatible
  167. return True, None
  168. # 2. If item data exists, check against primary data (if primary data exists)
  169. if primary_data is None: # No primary data to compare against
  170. return True, None
  171. # Both primary and item data exist, compare them
  172. primary_ndim = primary_data.ndim
  173. item_ndim = item_data.ndim
  174. if primary_ndim != item_ndim:
  175. return False, (
  176. f"Item '{data_key}' has {item_ndim} dimensions, but primary has {primary_ndim}. "
  177. f"Primary shape: {primary_data.shape}, Item shape: {item_data.shape}"
  178. )
  179. if primary_ndim == 3:
  180. primary_channels = primary_data.shape[-1]
  181. item_channels = item_data.shape[-1]
  182. if primary_channels != item_channels:
  183. return False, (
  184. f"Item '{data_key}' has {item_channels} channels, but primary has {primary_channels}. "
  185. f"Primary shape: {primary_data.shape}, Item shape: {item_data.shape}"
  186. )
  187. # Dimensions match (either both 2D or both 3D with same channels)
  188. return True, None
  189. def filter_valid_metadata(
  190. metadata_input: Sequence[dict[str, Any]] | None,
  191. metadata_key_name: str,
  192. data: dict[str, Any],
  193. ) -> list[dict[str, Any]]:
  194. """Filters a list of metadata dicts, keeping only valid ones based on data compatibility."""
  195. if not isinstance(metadata_input, Sequence):
  196. warn(
  197. f"Metadata under key '{metadata_key_name}' is not a Sequence (e.g., list or tuple). "
  198. f"Returning empty list for additional items.",
  199. UserWarning,
  200. stacklevel=3,
  201. )
  202. return []
  203. valid_items = []
  204. primary_image = data.get("image")
  205. primary_mask = data.get("mask")
  206. for i, item in enumerate(metadata_input):
  207. if not isinstance(item, dict):
  208. warn(
  209. f"Item at index {i} in '{metadata_key_name}' is not a dict and will be skipped.",
  210. UserWarning,
  211. stacklevel=4,
  212. )
  213. continue
  214. item_is_valid = True # Assume valid initially
  215. for target_key, primary_target_data in [
  216. ("image", primary_image),
  217. ("mask", primary_mask),
  218. ]:
  219. item_target_data = item.get(target_key)
  220. is_compatible, error_msg = _check_data_compatibility(
  221. primary_target_data,
  222. item_target_data,
  223. cast("Literal['image', 'mask']", target_key),
  224. )
  225. if not is_compatible:
  226. msg = (
  227. f"Item at index {i} in '{metadata_key_name}' skipped due "
  228. f"to incompatibility in '{target_key}': {error_msg}"
  229. )
  230. warn(msg, UserWarning, stacklevel=4)
  231. item_is_valid = False
  232. break # Stop checking other targets for this item
  233. if item_is_valid:
  234. valid_items.append(item)
  235. return valid_items
  236. def assign_items_to_grid_cells(
  237. num_items: int,
  238. cell_placements: list[tuple[int, int, int, int]],
  239. py_random: random.Random,
  240. ) -> dict[tuple[int, int, int, int], int]:
  241. """Assigns item indices to placement coordinate tuples.
  242. Assigns the primary item (index 0) to the placement with the largest area,
  243. and assigns the remaining items (indices 1 to num_items-1) randomly to the
  244. remaining placements.
  245. Args:
  246. num_items (int): The total number of items to assign (primary + additional + replicas).
  247. cell_placements (list[tuple[int, int, int, int]]): List of placement
  248. coords (x1, y1, x2, y2) for cells to be filled.
  249. py_random (random.Random): Random state instance.
  250. Returns:
  251. dict[tuple[int, int, int, int], int]: Dict mapping placement coords (x1, y1, x2, y2)
  252. to assigned item index.
  253. """
  254. if not cell_placements:
  255. return {}
  256. # Find the placement tuple with the largest area for primary assignment
  257. primary_placement = max(
  258. cell_placements,
  259. key=lambda coords: (coords[2] - coords[0]) * (coords[3] - coords[1]),
  260. )
  261. placement_to_item_index: dict[tuple[int, int, int, int], int] = {
  262. primary_placement: 0,
  263. }
  264. # Use list comprehension for potentially better performance
  265. remaining_placements = [coords for coords in cell_placements if coords != primary_placement]
  266. # Indices for additional/replicated items start from 1
  267. remaining_item_indices = list(range(1, num_items))
  268. py_random.shuffle(remaining_item_indices)
  269. num_to_assign = min(len(remaining_placements), len(remaining_item_indices))
  270. for i in range(num_to_assign):
  271. placement_to_item_index[remaining_placements[i]] = remaining_item_indices[i]
  272. return placement_to_item_index
  273. def _preprocess_item_annotations(
  274. item: dict[str, Any],
  275. processor: BboxProcessor | KeypointsProcessor | None,
  276. data_key: Literal["bboxes", "keypoints"],
  277. ) -> np.ndarray | None:
  278. """Helper to preprocess annotations (bboxes or keypoints) for a single item."""
  279. original_data = item.get(data_key)
  280. # Check if processor exists and the relevant data key is in the item
  281. if processor and data_key in item and item.get(data_key) is not None:
  282. # === Add validation for required label fields ===
  283. required_labels = processor.params.label_fields
  284. if required_labels and [field for field in required_labels if field not in item]:
  285. raise ValueError(
  286. f"Item contains '{data_key}' but is missing required label "
  287. "fields: {[field for field in required_labels if field not in item]}. "
  288. f"Ensure all label fields declared in {type(processor.params).__name__} "
  289. f"({required_labels}) are present in the item dictionary when '{data_key}' is present.",
  290. )
  291. # === End validation ===
  292. # Create a temporary minimal dict for the processor
  293. temp_data = {
  294. "image": item["image"],
  295. data_key: item[data_key],
  296. }
  297. # Add declared label fields if they exist in the item (already validated above)
  298. if required_labels:
  299. for field in required_labels:
  300. # Check again just in case validation logic changes, avoids KeyError
  301. if field in item:
  302. temp_data[field] = item[field]
  303. # Preprocess modifies temp_data in-place
  304. processor.preprocess(temp_data)
  305. # Return the potentially modified data from the temp dict
  306. return temp_data.get(data_key)
  307. # Return original data if no processor or data key wasn't in item
  308. return original_data
  309. def preprocess_selected_mosaic_items(
  310. selected_raw_items: list[dict[str, Any]],
  311. bbox_processor: BboxProcessor | None, # Allow None
  312. keypoint_processor: KeypointsProcessor | None, # Allow None
  313. ) -> list[ProcessedMosaicItem]:
  314. """Preprocesses bboxes/keypoints for selected raw additional items.
  315. Iterates through items, preprocesses annotations individually using processors
  316. (updating label encoders), and returns a list of dicts with original image/mask
  317. and the corresponding preprocessed bboxes/keypoints.
  318. """
  319. if not selected_raw_items:
  320. return []
  321. result_data_items: list[ProcessedMosaicItem] = []
  322. for item in selected_raw_items:
  323. processed_bboxes = _preprocess_item_annotations(item, bbox_processor, "bboxes")
  324. processed_keypoints = _preprocess_item_annotations(item, keypoint_processor, "keypoints")
  325. # Construct the final processed item dict
  326. processed_item_dict: ProcessedMosaicItem = {
  327. "image": item["image"],
  328. "mask": item.get("mask"),
  329. "bboxes": processed_bboxes, # Already np.ndarray or None
  330. "keypoints": processed_keypoints, # Already np.ndarray or None
  331. }
  332. result_data_items.append(processed_item_dict)
  333. return result_data_items
  334. def get_opposite_crop_coords(
  335. cell_size: tuple[int, int],
  336. crop_size: tuple[int, int],
  337. cell_position: Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"],
  338. ) -> tuple[int, int, int, int]:
  339. """Calculates crop coordinates positioned opposite to the specified cell_position.
  340. Given a cell of `cell_size`, this function determines the top-left (x_min, y_min)
  341. and bottom-right (x_max, y_max) coordinates for a crop of `crop_size`, such
  342. that the crop is located in the corner or center opposite to `cell_position`.
  343. For example, if `cell_position` is "top_left", the crop coordinates will
  344. correspond to the bottom-right region of the cell.
  345. Args:
  346. cell_size: The (height, width) of the cell from which to crop.
  347. crop_size: The (height, width) of the desired crop.
  348. cell_position: The reference position within the cell. The crop will be
  349. taken from the opposite position.
  350. Returns:
  351. tuple[int, int, int, int]: (x_min, y_min, x_max, y_max) representing the crop coordinates.
  352. Raises:
  353. ValueError: If crop_size is larger than cell_size in either dimension.
  354. """
  355. cell_h, cell_w = cell_size
  356. crop_h, crop_w = crop_size
  357. if crop_h > cell_h or crop_w > cell_w:
  358. raise ValueError(f"Crop size {crop_size} cannot be larger than cell size {cell_size}")
  359. # Determine top-left corner (x_min, y_min) based on the OPPOSITE position
  360. if cell_position == "top_left": # Crop from bottom_right
  361. x_min = cell_w - crop_w
  362. y_min = cell_h - crop_h
  363. elif cell_position == "top_right": # Crop from bottom_left
  364. x_min = 0
  365. y_min = cell_h - crop_h
  366. elif cell_position == "bottom_left": # Crop from top_right
  367. x_min = cell_w - crop_w
  368. y_min = 0
  369. elif cell_position == "bottom_right": # Crop from top_left
  370. x_min = 0
  371. y_min = 0
  372. elif cell_position == "center": # Crop from center
  373. x_min = (cell_w - crop_w) // 2
  374. y_min = (cell_h - crop_h) // 2
  375. else:
  376. # Should be unreachable due to Literal type hint, but good practice
  377. raise ValueError(f"Invalid cell_position: {cell_position}")
  378. # Calculate bottom-right corner
  379. x_max = x_min + crop_w
  380. y_max = y_min + crop_h
  381. return x_min, y_min, x_max, y_max
  382. def process_cell_geometry(
  383. cell_shape: tuple[int, int],
  384. item: ProcessedMosaicItem,
  385. target_shape: tuple[int, int],
  386. fill: float | tuple[float, ...],
  387. fill_mask: float | tuple[float, ...],
  388. fit_mode: Literal["cover", "contain"],
  389. interpolation: int,
  390. mask_interpolation: int,
  391. cell_position: Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"],
  392. ) -> ProcessedMosaicItem:
  393. """Applies geometric transformations (padding and/or cropping) to a single mosaic item.
  394. Uses a Compose pipeline with PadIfNeeded and Crop to ensure the output
  395. matches the target cell dimensions exactly, handling both padding and cropping cases.
  396. Args:
  397. cell_shape: (tuple[int, int]): Shape of the cell.
  398. item: (ProcessedMosaicItem): The preprocessed mosaic item dictionary.
  399. target_shape: (tuple[int, int]): Target shape of the cell.
  400. fill: (float | tuple[float, ...]): Fill value for image padding.
  401. fill_mask: (float | tuple[float, ...]): Fill value for mask padding.
  402. fit_mode: (Literal["cover", "contain"]): Fit mode for the mosaic.
  403. interpolation: (int): Interpolation method for image.
  404. mask_interpolation: (int): Interpolation method for mask.
  405. cell_position: (Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"]): Position
  406. of the cell.
  407. Returns: (ProcessedMosaicItem): Dictionary containing the geometrically processed image,
  408. mask, bboxes, and keypoints, fitting the target dimensions.
  409. """
  410. # Define the pipeline: PadIfNeeded first, then Crop
  411. compose_kwargs: dict[str, Any] = {"p": 1.0}
  412. if item.get("bboxes") is not None:
  413. compose_kwargs["bbox_params"] = {"format": "albumentations"}
  414. if item.get("keypoints") is not None:
  415. compose_kwargs["keypoint_params"] = {"format": "albumentations"}
  416. crop_coords = get_opposite_crop_coords(cell_shape, target_shape, cell_position)
  417. if fit_mode == "cover":
  418. geom_pipeline = Compose(
  419. [
  420. SmallestMaxSize(
  421. max_size_hw=cell_shape,
  422. interpolation=interpolation,
  423. mask_interpolation=mask_interpolation,
  424. p=1.0,
  425. ),
  426. Crop(
  427. x_min=crop_coords[0],
  428. y_min=crop_coords[1],
  429. x_max=crop_coords[2],
  430. y_max=crop_coords[3],
  431. ),
  432. ],
  433. **compose_kwargs,
  434. )
  435. elif fit_mode == "contain":
  436. geom_pipeline = Compose(
  437. [
  438. LongestMaxSize(
  439. max_size_hw=cell_shape,
  440. interpolation=interpolation,
  441. mask_interpolation=mask_interpolation,
  442. p=1.0,
  443. ),
  444. Crop(
  445. x_min=crop_coords[0],
  446. y_min=crop_coords[1],
  447. x_max=crop_coords[2],
  448. y_max=crop_coords[3],
  449. pad_if_needed=True,
  450. fill=fill,
  451. fill_mask=fill_mask,
  452. p=1.0,
  453. ),
  454. ],
  455. **compose_kwargs,
  456. )
  457. else:
  458. raise ValueError(f"Invalid fit_mode: {fit_mode}. Must be 'cover' or 'contain'.")
  459. # Prepare input data for the pipeline
  460. geom_input = {"image": item["image"]}
  461. if item.get("mask") is not None:
  462. geom_input["mask"] = item["mask"]
  463. if item.get("bboxes") is not None:
  464. # Compose expects bboxes in a specific format, ensure it's compatible
  465. # Assuming item['bboxes'] is already preprocessed correctly
  466. geom_input["bboxes"] = item["bboxes"]
  467. if item.get("keypoints") is not None:
  468. geom_input["keypoints"] = item["keypoints"]
  469. # Apply the pipeline
  470. processed_item = geom_pipeline(**geom_input)
  471. # Ensure output dict has the same structure as ProcessedMosaicItem
  472. # Compose might not return None for missing keys, handle explicitly
  473. return {
  474. "image": processed_item["image"],
  475. "mask": processed_item.get("mask"),
  476. "bboxes": processed_item.get("bboxes"),
  477. "keypoints": processed_item.get("keypoints"),
  478. }
  479. def shift_cell_coordinates(
  480. processed_item_geom: ProcessedMosaicItem,
  481. placement_coords: tuple[int, int, int, int],
  482. ) -> ProcessedMosaicItem:
  483. """Shifts the coordinates of geometrically processed bboxes and keypoints.
  484. Args:
  485. processed_item_geom: (ProcessedMosaicItem): The output from process_cell_geometry.
  486. placement_coords: (tuple[int, int, int, int]): The (x1, y1, x2, y2) placement on the final canvas.
  487. Returns: (ProcessedMosaicItem): A dictionary with keys 'bboxes' and 'keypoints', containing the shifted
  488. numpy arrays (potentially empty).
  489. """
  490. tgt_x1, tgt_y1, _, _ = placement_coords
  491. shifted_bboxes = None
  492. shifted_keypoints = None
  493. bboxes_geom = processed_item_geom.get("bboxes")
  494. if bboxes_geom is not None and np.asarray(bboxes_geom).size > 0:
  495. bboxes_geom_arr = np.asarray(bboxes_geom) # Ensure it's an array
  496. bbox_shift_vector = np.array([tgt_x1, tgt_y1, tgt_x1, tgt_y1], dtype=np.int32)
  497. shifted_bboxes = fgeometric.shift_bboxes(bboxes_geom_arr, bbox_shift_vector)
  498. keypoints_geom = processed_item_geom.get("keypoints")
  499. if keypoints_geom is not None and np.asarray(keypoints_geom).size > 0:
  500. keypoints_geom_arr = np.asarray(keypoints_geom) # Ensure it's an array
  501. kp_shift_vector = np.array([tgt_x1, tgt_y1, 0], dtype=keypoints_geom_arr.dtype)
  502. shifted_keypoints = fgeometric.shift_keypoints(keypoints_geom_arr, kp_shift_vector)
  503. return {
  504. "bboxes": shifted_bboxes,
  505. "keypoints": shifted_keypoints,
  506. "image": processed_item_geom["image"],
  507. "mask": processed_item_geom.get("mask"),
  508. }
  509. def assemble_mosaic_from_processed_cells(
  510. processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
  511. target_shape: tuple[int, ...], # Use full canvas shape (H, W) or (H, W, C)
  512. dtype: np.dtype,
  513. data_key: Literal["image", "mask"],
  514. fill: float | tuple[float, ...] | None, # Value for image fill or mask fill
  515. ) -> np.ndarray:
  516. """Assembles the final mosaic image or mask from processed cell data onto a canvas.
  517. Initializes the canvas with the fill value and overwrites with processed segments.
  518. Handles potentially multi-channel masks.
  519. Addresses potential broadcasting errors if mask segments have unexpected dimensions.
  520. Assumes input data is valid and correctly sized.
  521. Args:
  522. processed_cells (dict[tuple[int, int, int, int], dict[str, Any]]): Dictionary mapping
  523. placement coords to processed cell data.
  524. target_shape (tuple[int, ...]): The target shape of the output canvas (e.g., (H, W) or (H, W, C)).
  525. dtype (np.dtype): NumPy dtype for the canvas.
  526. data_key (Literal["image", "mask"]): Specifies whether to assemble 'image' or 'mask'.
  527. fill (float | tuple[float, ...] | None): Value used to initialize the canvas (image fill or mask fill).
  528. Should be a float/int or a tuple matching the number of channels.
  529. If None, defaults to 0.
  530. Returns:
  531. np.ndarray: The assembled mosaic canvas.
  532. """
  533. # Use 0 as default fill if None is provided
  534. actual_fill = fill if fill is not None else 0
  535. # Convert fill to numpy array to handle broadcasting in np.full
  536. fill_value = np.array(actual_fill, dtype=dtype)
  537. # Initialize canvas with the fill value.
  538. # If fill_value shape is incompatible with target_shape, np.full will raise ValueError.
  539. canvas = np.full(target_shape, fill_value=fill_value, dtype=dtype)
  540. # Iterate and paste segments onto the pre-filled canvas
  541. for placement_coords, cell_data in processed_cells.items():
  542. segment = cell_data.get(data_key)
  543. # If segment exists, paste it over the filled background
  544. if segment is not None:
  545. tgt_x1, tgt_y1, tgt_x2, tgt_y2 = placement_coords
  546. canvas[tgt_y1:tgt_y2, tgt_x1:tgt_x2] = segment
  547. return canvas
  548. def process_all_mosaic_geometries(
  549. canvas_shape: tuple[int, int],
  550. cell_shape: tuple[int, int],
  551. placement_to_item_index: dict[tuple[int, int, int, int], int],
  552. final_items_for_grid: list[ProcessedMosaicItem],
  553. fill: float | tuple[float, ...],
  554. fill_mask: float | tuple[float, ...],
  555. fit_mode: Literal["cover", "contain"],
  556. interpolation: Literal[
  557. cv2.INTER_NEAREST,
  558. cv2.INTER_NEAREST_EXACT,
  559. cv2.INTER_LINEAR,
  560. cv2.INTER_CUBIC,
  561. cv2.INTER_AREA,
  562. cv2.INTER_LANCZOS4,
  563. cv2.INTER_LINEAR_EXACT,
  564. ],
  565. mask_interpolation: Literal[
  566. cv2.INTER_NEAREST,
  567. cv2.INTER_NEAREST_EXACT,
  568. cv2.INTER_LINEAR,
  569. cv2.INTER_CUBIC,
  570. cv2.INTER_AREA,
  571. cv2.INTER_LANCZOS4,
  572. cv2.INTER_LINEAR_EXACT,
  573. ],
  574. ) -> dict[tuple[int, int, int, int], ProcessedMosaicItem]:
  575. """Processes the geometry (cropping/padding) for all assigned mosaic cells.
  576. Iterates through assigned placements, applies geometric transforms via process_cell_geometry,
  577. and returns a dictionary mapping final placement coordinates to the processed item data.
  578. The bbox/keypoint coordinates in the returned dict are *not* shifted yet.
  579. Args:
  580. canvas_shape (tuple[int, int]): The shape of the canvas.
  581. cell_shape (tuple[int, int]): Shape of each cell in the mosaic grid.
  582. placement_to_item_index (dict[tuple[int, int, int, int], int]): Mapping from placement
  583. coordinates (x1, y1, x2, y2) to assigned item index.
  584. final_items_for_grid (list[ProcessedMosaicItem]): List of all preprocessed items available.
  585. fill (float | tuple[float, ...]): Fill value for image padding.
  586. fill_mask (float | tuple[float, ...]): Fill value for mask padding.
  587. fit_mode (Literal["cover", "contain"]): Fit mode for the mosaic.
  588. interpolation (int): Interpolation method for image.
  589. mask_interpolation (int): Interpolation method for mask.
  590. Returns:
  591. dict[tuple[int, int, int, int], ProcessedMosaicItem]: Dictionary mapping final placement
  592. coordinates (x1, y1, x2, y2) to the geometrically processed item data (image, mask, un-shifted bboxes/kps).
  593. """
  594. processed_cells_geom: dict[tuple[int, int, int, int], ProcessedMosaicItem] = {}
  595. # Iterate directly over placements and their assigned item indices
  596. for placement_coords, item_idx in placement_to_item_index.items():
  597. item = final_items_for_grid[item_idx]
  598. tgt_x1, tgt_y1, tgt_x2, tgt_y2 = placement_coords
  599. target_h = tgt_y2 - tgt_y1
  600. target_w = tgt_x2 - tgt_x1
  601. cell_position = get_cell_relative_position(placement_coords, canvas_shape)
  602. # Apply geometric processing (crop/pad)
  603. processed_cells_geom[placement_coords] = process_cell_geometry(
  604. cell_shape=cell_shape,
  605. item=item,
  606. target_shape=(target_h, target_w),
  607. fill=fill,
  608. fill_mask=fill_mask,
  609. fit_mode=fit_mode,
  610. interpolation=interpolation,
  611. mask_interpolation=mask_interpolation,
  612. cell_position=cell_position,
  613. )
  614. return processed_cells_geom
  615. def get_cell_relative_position(
  616. placement_coords: tuple[int, int, int, int],
  617. target_shape: tuple[int, int],
  618. ) -> Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"]:
  619. """Determines the position of a cell relative to the center of the target canvas.
  620. Compares the cell center to the canvas center and returns its quadrant
  621. or "center" if it lies on or very close to a central axis.
  622. Args:
  623. placement_coords (tuple[int, int, int, int]): The (x_min, y_min, x_max, y_max) coordinates
  624. of the cell.
  625. target_shape (tuple[int, int]): The (height, width) of the overall target canvas.
  626. Returns:
  627. Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"]:
  628. The position of the cell relative to the center of the target canvas.
  629. """
  630. target_h, target_w = target_shape
  631. x1, y1, x2, y2 = placement_coords
  632. canvas_center_x = target_w / 2.0
  633. canvas_center_y = target_h / 2.0
  634. cell_center_x = (x1 + x2) / 2.0
  635. cell_center_y = (y1 + y2) / 2.0
  636. # Determine vertical position
  637. if cell_center_y < canvas_center_y:
  638. v_pos = "top"
  639. elif cell_center_y > canvas_center_y:
  640. v_pos = "bottom"
  641. else: # Exactly on the horizontal center line
  642. v_pos = "center"
  643. # Determine horizontal position
  644. if cell_center_x < canvas_center_x:
  645. h_pos = "left"
  646. elif cell_center_x > canvas_center_x:
  647. h_pos = "right"
  648. else: # Exactly on the vertical center line
  649. h_pos = "center"
  650. # Map positions to the final string
  651. position_map = {
  652. ("top", "left"): "top_left",
  653. ("top", "right"): "top_right",
  654. ("bottom", "left"): "bottom_left",
  655. ("bottom", "right"): "bottom_right",
  656. }
  657. # Default to "center" if the combination is not in the map
  658. # (which happens if either v_pos or h_pos is "center")
  659. return cast(
  660. "Literal['top_left', 'top_right', 'center', 'bottom_left', 'bottom_right']",
  661. position_map.get((v_pos, h_pos), "center"),
  662. )
  663. def shift_all_coordinates(
  664. processed_cells_geom: dict[tuple[int, int, int, int], ProcessedMosaicItem],
  665. canvas_shape: tuple[int, int],
  666. ) -> dict[tuple[int, int, int, int], ProcessedMosaicItem]: # Return type matches input, but values are updated
  667. """Shifts coordinates for all geometrically processed cells.
  668. Iterates through the processed cells (keyed by placement coords), applies coordinate
  669. shifting to bboxes/keypoints, and returns a new dictionary with the same keys
  670. but updated ProcessedMosaicItem values containing the *shifted* coordinates.
  671. Args:
  672. processed_cells_geom (dict[tuple[int, int, int, int], ProcessedMosaicItem]):
  673. Output from process_all_mosaic_geometries (keyed by placement coords).
  674. canvas_shape (tuple[int, int]): The shape of the canvas.
  675. Returns:
  676. dict[tuple[int, int, int, int], ProcessedMosaicItem]: Final dictionary mapping
  677. placement coords (x1, y1, x2, y2) to processed cell data with shifted coordinates.
  678. """
  679. final_processed_cells: dict[tuple[int, int, int, int], ProcessedMosaicItem] = {}
  680. canvas_h, canvas_w = canvas_shape
  681. for placement_coords, cell_data_geom in processed_cells_geom.items():
  682. tgt_x1, tgt_y1 = placement_coords[:2]
  683. cell_width = placement_coords[2] - placement_coords[0]
  684. cell_height = placement_coords[3] - placement_coords[1]
  685. # Extract geometrically processed bboxes/keypoints
  686. bboxes_geom = cell_data_geom.get("bboxes")
  687. keypoints_geom = cell_data_geom.get("keypoints")
  688. final_cell_data = {
  689. "image": cell_data_geom["image"],
  690. "mask": cell_data_geom.get("mask"),
  691. }
  692. # Perform shifting if data exists
  693. if bboxes_geom is not None and bboxes_geom.size > 0:
  694. bboxes_geom_arr = np.asarray(bboxes_geom)
  695. bbox_denoramlized = denormalize_bboxes(bboxes_geom_arr, {"height": cell_height, "width": cell_width})
  696. bbox_shift_vector = np.array([tgt_x1, tgt_y1, tgt_x1, tgt_y1], dtype=np.float32)
  697. shifted_bboxes_denormalized = fgeometric.shift_bboxes(bbox_denoramlized, bbox_shift_vector)
  698. shifted_bboxes = normalize_bboxes(shifted_bboxes_denormalized, {"height": canvas_h, "width": canvas_w})
  699. final_cell_data["bboxes"] = shifted_bboxes
  700. else:
  701. final_cell_data["bboxes"] = np.empty((0, NUM_BBOXES_COLUMNS_IN_ALBUMENTATIONS))
  702. if keypoints_geom is not None and keypoints_geom.size > 0:
  703. keypoints_geom_arr = np.asarray(keypoints_geom)
  704. # Ensure shift vector matches keypoint dtype (usually float)
  705. kp_shift_vector = np.array([tgt_x1, tgt_y1, 0], dtype=keypoints_geom_arr.dtype)
  706. shifted_keypoints = fgeometric.shift_keypoints(keypoints_geom_arr, kp_shift_vector)
  707. final_cell_data["keypoints"] = shifted_keypoints
  708. else:
  709. final_cell_data["keypoints"] = np.empty((0, NUM_KEYPOINTS_COLUMNS_IN_ALBUMENTATIONS))
  710. final_processed_cells[placement_coords] = cast("ProcessedMosaicItem", final_cell_data)
  711. return final_processed_cells