functional.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. """Module containing functional implementations of 3D transformations.
  2. This module provides a collection of utility functions for manipulating and transforming
  3. 3D volumetric data (such as medical imaging volumes). The functions here implement the core
  4. algorithms for operations like padding, cropping, rotation, and other spatial manipulations
  5. specifically designed for 3D data.
  6. """
  7. from __future__ import annotations
  8. import random
  9. from typing import Literal
  10. import numpy as np
  11. from albumentations.augmentations.utils import handle_empty_array
  12. from albumentations.core.type_definitions import NUM_VOLUME_DIMENSIONS
  13. def adjust_padding_by_position3d(
  14. paddings: list[tuple[int, int]], # [(front, back), (top, bottom), (left, right)]
  15. position: Literal["center", "random"],
  16. py_random: random.Random,
  17. ) -> tuple[int, int, int, int, int, int]:
  18. """Adjust padding values based on desired position for 3D data.
  19. Args:
  20. paddings (list[tuple[int, int]]): List of tuples containing padding pairs
  21. for each dimension [(d_pad), (h_pad), (w_pad)]
  22. position (Literal["center", "random"]): Position of the image after padding.
  23. py_random (random.Random): Random number generator
  24. Returns:
  25. tuple[int, int, int, int, int, int]: Final padding values (d_front, d_back, h_top, h_bottom, w_left, w_right)
  26. """
  27. if position == "center":
  28. return (
  29. paddings[0][0], # d_front
  30. paddings[0][1], # d_back
  31. paddings[1][0], # h_top
  32. paddings[1][1], # h_bottom
  33. paddings[2][0], # w_left
  34. paddings[2][1], # w_right
  35. )
  36. # For random position, redistribute padding for each dimension
  37. d_pad = sum(paddings[0])
  38. h_pad = sum(paddings[1])
  39. w_pad = sum(paddings[2])
  40. return (
  41. py_random.randint(0, d_pad), # d_front
  42. d_pad - py_random.randint(0, d_pad), # d_back
  43. py_random.randint(0, h_pad), # h_top
  44. h_pad - py_random.randint(0, h_pad), # h_bottom
  45. py_random.randint(0, w_pad), # w_left
  46. w_pad - py_random.randint(0, w_pad), # w_right
  47. )
  48. def pad_3d_with_params(
  49. volume: np.ndarray,
  50. padding: tuple[int, int, int, int, int, int],
  51. value: tuple[float, ...] | float,
  52. ) -> np.ndarray:
  53. """Pad 3D volume with given parameters.
  54. Args:
  55. volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
  56. padding (tuple[int, int, int, int, int, int]): Padding values in format:
  57. (depth_front, depth_back, height_top, height_bottom, width_left, width_right)
  58. where:
  59. - depth_front/back: padding at start/end of depth axis (z)
  60. - height_top/bottom: padding at start/end of height axis (y)
  61. - width_left/right: padding at start/end of width axis (x)
  62. value (tuple[float, ...] | float): Value to fill the padding
  63. Returns:
  64. np.ndarray: Padded volume with same number of dimensions as input
  65. Note:
  66. The padding order matches the volume dimensions (depth, height, width).
  67. For each dimension, the first value is padding at the start (smaller indices),
  68. and the second value is padding at the end (larger indices).
  69. """
  70. depth_front, depth_back, height_top, height_bottom, width_left, width_right = padding
  71. # Skip if no padding is needed
  72. if all(p == 0 for p in padding):
  73. return volume
  74. # Handle both 3D and 4D arrays
  75. pad_width = [
  76. (depth_front, depth_back), # depth (z) padding
  77. (height_top, height_bottom), # height (y) padding
  78. (width_left, width_right), # width (x) padding
  79. ]
  80. # Add channel padding if 4D array
  81. if volume.ndim == NUM_VOLUME_DIMENSIONS:
  82. pad_width.append((0, 0)) # no padding for channels
  83. return np.pad(
  84. volume,
  85. pad_width=pad_width,
  86. mode="constant",
  87. constant_values=value,
  88. )
  89. def crop3d(
  90. volume: np.ndarray,
  91. crop_coords: tuple[int, int, int, int, int, int],
  92. ) -> np.ndarray:
  93. """Crop 3D volume using coordinates.
  94. Args:
  95. volume (np.ndarray): Input volume with shape (z, y, x) or (z, y, x, channels)
  96. crop_coords (tuple[int, int, int, int, int, int]):
  97. (z_min, z_max, y_min, y_max, x_min, x_max) coordinates for cropping
  98. Returns:
  99. np.ndarray: Cropped volume with same number of dimensions as input
  100. """
  101. z_min, z_max, y_min, y_max, x_min, x_max = crop_coords
  102. return volume[z_min:z_max, y_min:y_max, x_min:x_max]
  103. def cutout3d(volume: np.ndarray, holes: np.ndarray, fill: tuple[float, ...] | float) -> np.ndarray:
  104. """Cut out holes in 3D volume and fill them with a given value.
  105. Args:
  106. volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
  107. holes (np.ndarray): Array of holes with shape (num_holes, 6).
  108. Each hole is represented as [z1, y1, x1, z2, y2, x2]
  109. fill (tuple[float, ...] | float): Value to fill the holes
  110. Returns:
  111. np.ndarray: Volume with holes filled with the given value
  112. """
  113. volume = volume.copy()
  114. for z1, y1, x1, z2, y2, x2 in holes:
  115. volume[z1:z2, y1:y2, x1:x2] = fill
  116. return volume
def transform_cube(cube: np.ndarray, index: int) -> np.ndarray:
    """Apply one of 48 cube symmetries (rotations, optionally after a reflection).

    Indices 0-23 are rotations built from ``np.rot90`` calls; some entries are
    written as an equivalent slice or transpose. Indices 24-47 are intended to
    equal entry ``index - 24`` applied to the width-reflected input
    ``x[:, :, ::-1]``; several are written in an equivalent transpose/slice form.
    Together the 48 entries are meant to cover the full symmetry group of the cube.

    Args:
        cube (np.ndarray): Input array with shape (D, H, W) or (D, H, W, C).
            Only the first three axes are rotated/reflected. Trailing axes
            (e.g. channels) are kept in place by the ``*range(3, x.ndim)``
            terms in the transposes.
        index (int): Integer from 0 to 47 specifying which transformation to apply.

    Returns:
        np.ndarray: Transformed array. Its first three dimensions may be permuted
        relative to the input, so the shape is only guaranteed to match the input
        when D == H == W. The input is copied before transforming, so the result
        never shares memory with ``cube``.

    Raises:
        ValueError: If ``index`` is not in [0, 47].
    """
    if not (0 <= index < 48):
        raise ValueError("Index must be between 0 and 47")
    transformations = {
        # First 4: rotate around axis 0 (indices 0-3)
        0: lambda x: x,
        1: lambda x: np.rot90(x, k=1, axes=(1, 2)),
        2: lambda x: np.rot90(x, k=2, axes=(1, 2)),
        3: lambda x: np.rot90(x, k=3, axes=(1, 2)),
        # Next 4: flip 180° about axis 1, then rotate around axis 0 (indices 4-7)
        4: lambda x: x[::-1, :, ::-1],  # == np.rot90(x, k=2, axes=(0, 2)); was: np.flip(x, axis=(0, 2))
        5: lambda x: np.rot90(np.rot90(x, k=2, axes=(0, 2)), k=1, axes=(1, 2)),
        6: lambda x: x[::-1, ::-1, :],  # 180° about axis 1, then 180° about axis 0; was: np.flip(x, axis=(0, 1))
        7: lambda x: np.rot90(np.rot90(x, k=2, axes=(0, 2)), k=3, axes=(1, 2)),
        # Next 8: split between 90° and 270° about axis 1, then rotate around axis 2 (indices 8-15)
        8: lambda x: np.rot90(x, k=1, axes=(0, 2)),
        9: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 2)), k=1, axes=(0, 1)),
        10: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 2)), k=2, axes=(0, 1)),
        11: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim)),  # == rot90(rot90(x, 1, (0, 2)), 3, (0, 1))
        12: lambda x: np.rot90(x, k=-1, axes=(0, 2)),
        13: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 2)), k=1, axes=(0, 1)),
        14: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 2)), k=2, axes=(0, 1)),
        15: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 2)), k=3, axes=(0, 1)),
        # Final 8: split between 90° and 270° about axis 2, then rotate around axis 1 (indices 16-23)
        16: lambda x: np.rot90(x, k=1, axes=(0, 1)),
        17: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 1)), k=1, axes=(0, 2)),
        18: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 1)), k=2, axes=(0, 2)),
        19: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim)),  # == rot90(rot90(x, 1, (0, 1)), 3, (0, 2))
        20: lambda x: np.rot90(x, k=-1, axes=(0, 1)),
        21: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 1)), k=1, axes=(0, 2)),
        22: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 1)), k=2, axes=(0, 2)),
        23: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 1)), k=3, axes=(0, 2)),
        # Reflected versions (24-47): entry 24 + i is intended to be entry i applied
        # to the width-reflected input x[:, :, ::-1] (transform_cube_keypoints
        # relies on this correspondence).
        24: lambda x: x[:, :, ::-1],  # was: np.flip(x, axis=2)
        25: lambda x: x.transpose(0, 2, 1, *range(3, x.ndim)),
        26: lambda x: x[:, ::-1, :],  # was: np.flip(x, axis=1)
        27: lambda x: np.rot90(x[:, :, ::-1], k=3, axes=(1, 2)),
        28: lambda x: x[::-1, :, :],  # was: np.flip(x, axis=0)
        29: lambda x: np.rot90(x[::-1, :, :], k=1, axes=(1, 2)),
        30: lambda x: x[::-1, ::-1, ::-1],  # was: np.flip(x, axis=(0, 1, 2))
        31: lambda x: np.rot90(x[::-1, :, :], k=-1, axes=(1, 2)),
        32: lambda x: x.transpose(2, 1, 0, *range(3, x.ndim)),
        33: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[::-1, :, :],
        34: lambda x: x.transpose(2, 1, 0, *range(3, x.ndim))[::-1, ::-1, :],
        35: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[:, ::-1, :],
        36: lambda x: np.rot90(x[:, :, ::-1], k=-1, axes=(0, 2)),
        37: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[::-1, ::-1, ::-1],
        38: lambda x: x.transpose(2, 1, 0, *range(3, x.ndim))[:, ::-1, ::-1],
        39: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[:, :, ::-1],
        40: lambda x: np.rot90(x[:, :, ::-1], k=1, axes=(0, 1)),
        41: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[:, :, ::-1],
        42: lambda x: x.transpose(1, 0, 2, *range(3, x.ndim)),
        43: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[::-1, :, :],
        44: lambda x: np.rot90(x[:, :, ::-1], k=-1, axes=(0, 1)),
        45: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[:, ::-1, :],
        46: lambda x: x.transpose(1, 0, 2, *range(3, x.ndim))[::-1, ::-1, :],
        47: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[::-1, ::-1, ::-1],
    }
    # Copy first: most entries return views, so this keeps the result independent of `cube`.
    return transformations[index](cube.copy())
  183. @handle_empty_array("keypoints")
  184. def filter_keypoints_in_holes3d(keypoints: np.ndarray, holes: np.ndarray) -> np.ndarray:
  185. """Filter out keypoints that are inside any of the 3D holes.
  186. Args:
  187. keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
  188. The first three columns are x, y, z coordinates.
  189. holes (np.ndarray): Array of holes with shape (num_holes, 6).
  190. Each hole is represented as [z1, y1, x1, z2, y2, x2].
  191. Returns:
  192. np.ndarray: Array of keypoints that are not inside any hole.
  193. """
  194. if holes.size == 0:
  195. return keypoints
  196. # Broadcast keypoints and holes for vectorized comparison
  197. # Convert keypoints from XYZ to ZYX for comparison with holes
  198. kp_z = keypoints[:, 2][:, np.newaxis] # Shape: (num_keypoints, 1)
  199. kp_y = keypoints[:, 1][:, np.newaxis] # Shape: (num_keypoints, 1)
  200. kp_x = keypoints[:, 0][:, np.newaxis] # Shape: (num_keypoints, 1)
  201. # Extract hole coordinates (in ZYX order)
  202. hole_z1 = holes[:, 0] # Shape: (num_holes,)
  203. hole_y1 = holes[:, 1]
  204. hole_x1 = holes[:, 2]
  205. hole_z2 = holes[:, 3]
  206. hole_y2 = holes[:, 4]
  207. hole_x2 = holes[:, 5]
  208. # Check if each keypoint is inside each hole
  209. inside_hole = (
  210. (kp_z >= hole_z1)
  211. & (kp_z < hole_z2)
  212. & (kp_y >= hole_y1)
  213. & (kp_y < hole_y2)
  214. & (kp_x >= hole_x1)
  215. & (kp_x < hole_x2)
  216. )
  217. # A keypoint is valid if it's not inside any hole
  218. valid_keypoints = ~np.any(inside_hole, axis=1)
  219. # Return filtered keypoints with same dtype as input
  220. result = keypoints[valid_keypoints]
  221. if len(result) == 0:
  222. # Ensure empty result has correct shape and dtype
  223. return np.array([], dtype=keypoints.dtype).reshape(0, keypoints.shape[1])
  224. return result
  225. def keypoints_rot90(
  226. keypoints: np.ndarray,
  227. k: int,
  228. axes: tuple[int, int],
  229. volume_shape: tuple[int, int, int],
  230. ) -> np.ndarray:
  231. """Rotate keypoints 90 degrees k times around the specified axes.
  232. Args:
  233. keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
  234. The first three columns are x, y, z coordinates.
  235. k (int): Number of times to rotate by 90 degrees.
  236. axes (tuple[int, int]): Axes to rotate around.
  237. volume_shape (tuple[int, int, int]): Shape of the volume (depth, height, width).
  238. Returns:
  239. np.ndarray: Rotated keypoints with same shape as input.
  240. """
  241. if k == 0 or len(keypoints) == 0:
  242. return keypoints
  243. # Normalize factor to range [0, 3]
  244. k = ((k % 4) + 4) % 4
  245. result = keypoints.copy()
  246. # Get dimensions for the rotation axes
  247. dims = [volume_shape[ax] for ax in axes]
  248. # Get coordinates to rotate
  249. coords1 = result[:, axes[0]].copy()
  250. coords2 = result[:, axes[1]].copy()
  251. # Apply rotation based on factor (counterclockwise)
  252. if k == 1: # 90 degrees CCW
  253. result[:, axes[0]] = (dims[1] - 1) - coords2
  254. result[:, axes[1]] = coords1
  255. elif k == 2: # 180 degrees
  256. result[:, axes[0]] = (dims[0] - 1) - coords1
  257. result[:, axes[1]] = (dims[1] - 1) - coords2
  258. elif k == 3: # 270 degrees CCW
  259. result[:, axes[0]] = coords2
  260. result[:, axes[1]] = (dims[0] - 1) - coords1
  261. return result
  262. @handle_empty_array("keypoints")
  263. def transform_cube_keypoints(
  264. keypoints: np.ndarray,
  265. index: int,
  266. volume_shape: tuple[int, int, int],
  267. ) -> np.ndarray:
  268. """Transform keypoints according to the cube transformation specified by index.
  269. Args:
  270. keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
  271. The first three columns are x, y, z coordinates.
  272. index (int): Integer from 0 to 47 specifying which transformation to apply.
  273. volume_shape (tuple[int, int, int]): Shape of the volume (depth, height, width).
  274. Returns:
  275. np.ndarray: Transformed keypoints with same shape as input.
  276. """
  277. if not (0 <= index < 48):
  278. raise ValueError("Index must be between 0 and 47")
  279. # Create working copy preserving all columns
  280. working_points = keypoints.copy()
  281. # Convert only XYZ coordinates to HWD, keeping other columns unchanged
  282. xyz = working_points[:, :3] # Get first 3 columns (XYZ)
  283. xyz = xyz[:, [2, 1, 0]] # XYZ -> HWD
  284. working_points[:, :3] = xyz # Put back transformed coordinates
  285. current_shape = volume_shape
  286. # Handle reflection first (indices 24-47)
  287. if index >= 24:
  288. working_points[:, 2] = current_shape[2] - 1 - working_points[:, 2] # Reflect W axis
  289. rotation_index = index % 24
  290. # Apply the same rotation logic as transform_cube
  291. if rotation_index < 4:
  292. # First 4: rotate around axis 0
  293. result = keypoints_rot90(working_points, k=rotation_index, axes=(1, 2), volume_shape=current_shape)
  294. elif rotation_index < 8:
  295. # Next 4: flip 180° about axis 1, then rotate around axis 0
  296. temp = keypoints_rot90(working_points, k=2, axes=(0, 2), volume_shape=current_shape)
  297. result = keypoints_rot90(temp, k=rotation_index - 4, axes=(1, 2), volume_shape=volume_shape)
  298. elif rotation_index < 16:
  299. if rotation_index < 12:
  300. temp = keypoints_rot90(working_points, k=1, axes=(0, 2), volume_shape=current_shape)
  301. temp_shape = (current_shape[2], current_shape[1], current_shape[0])
  302. result = keypoints_rot90(temp, k=rotation_index - 8, axes=(0, 1), volume_shape=temp_shape)
  303. else:
  304. temp = keypoints_rot90(working_points, k=3, axes=(0, 2), volume_shape=current_shape)
  305. temp_shape = (current_shape[2], current_shape[1], current_shape[0])
  306. result = keypoints_rot90(temp, k=rotation_index - 12, axes=(0, 1), volume_shape=temp_shape)
  307. elif rotation_index < 20:
  308. temp = keypoints_rot90(working_points, k=1, axes=(0, 1), volume_shape=current_shape)
  309. temp_shape = (current_shape[1], current_shape[0], current_shape[2])
  310. result = keypoints_rot90(temp, k=rotation_index - 16, axes=(0, 2), volume_shape=temp_shape)
  311. else:
  312. temp = keypoints_rot90(working_points, k=3, axes=(0, 1), volume_shape=current_shape)
  313. temp_shape = (current_shape[1], current_shape[0], current_shape[2])
  314. result = keypoints_rot90(temp, k=rotation_index - 20, axes=(0, 2), volume_shape=temp_shape)
  315. # Convert back from HWD to XYZ coordinates for first 3 columns only
  316. xyz = result[:, :3]
  317. xyz = xyz[:, [2, 1, 0]] # HWD -> XYZ
  318. result[:, :3] = xyz
  319. return result