| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393 |
- """Module containing functional implementations of 3D transformations.
- This module provides a collection of utility functions for manipulating and transforming
- 3D volumetric data (such as medical imaging volumes). The functions here implement the core
- algorithms for operations like padding, cropping, rotation, and other spatial manipulations
- specifically designed for 3D data.
- """
- from __future__ import annotations
- import random
- from typing import Literal
- import numpy as np
- from albumentations.augmentations.utils import handle_empty_array
- from albumentations.core.type_definitions import NUM_VOLUME_DIMENSIONS
def adjust_padding_by_position3d(
    paddings: list[tuple[int, int]],  # [(front, back), (top, bottom), (left, right)]
    position: Literal["center", "random"],
    py_random: random.Random,
) -> tuple[int, int, int, int, int, int]:
    """Adjust padding values based on desired position for 3D data.

    Args:
        paddings (list[tuple[int, int]]): List of tuples containing padding pairs
            for each dimension [(d_pad), (h_pad), (w_pad)]
        position (Literal["center", "random"]): Position of the image after padding.
            "center" returns the given pairs unchanged; any other value is treated
            as "random".
        py_random (random.Random): Random number generator

    Returns:
        tuple[int, int, int, int, int, int]: Final padding values
            (d_front, d_back, h_top, h_bottom, w_left, w_right)

    Note:
        For the random position the total padding of every axis is preserved:
        the amount placed at the start of the axis is drawn once and the
        remainder goes to the end, so ``start + end == sum(pair)``.
    """
    if position == "center":
        return (
            paddings[0][0],  # d_front
            paddings[0][1],  # d_back
            paddings[1][0],  # h_top
            paddings[1][1],  # h_bottom
            paddings[2][0],  # w_left
            paddings[2][1],  # w_right
        )

    # For random position, redistribute the total padding of each dimension
    d_pad = sum(paddings[0])
    h_pad = sum(paddings[1])
    w_pad = sum(paddings[2])

    # Draw exactly once per axis. Using two independent draws for the two sides
    # would make the sides add up to something other than the requested total.
    d_front = py_random.randint(0, d_pad)
    h_top = py_random.randint(0, h_pad)
    w_left = py_random.randint(0, w_pad)

    return (
        d_front,
        d_pad - d_front,  # d_back
        h_top,
        h_pad - h_top,  # h_bottom
        w_left,
        w_pad - w_left,  # w_right
    )
def pad_3d_with_params(
    volume: np.ndarray,
    padding: tuple[int, int, int, int, int, int],
    value: tuple[float, ...] | float,
) -> np.ndarray:
    """Pad a 3D volume with a constant value.

    Args:
        volume (np.ndarray): Input volume with shape (depth, height, width) or
            (depth, height, width, channels)
        padding (tuple[int, int, int, int, int, int]): Amounts to add, ordered as
            (depth_front, depth_back, height_top, height_bottom, width_left, width_right).
            For every axis the first number pads the start (smaller indices) and
            the second pads the end (larger indices).
        value (tuple[float, ...] | float): Fill value, forwarded to ``np.pad`` as
            ``constant_values``

    Returns:
        np.ndarray: Padded volume with the same number of dimensions as the input.
            When every padding amount is zero the input array itself is returned
            (no copy is made).
    """
    # Unpack before anything else so a padding tuple of the wrong length
    # raises immediately, even when all of its entries are zero.
    d_front, d_back, h_top, h_bottom, w_left, w_right = padding

    if not any(amount != 0 for amount in padding):
        return volume

    spatial_pads = [
        (d_front, d_back),  # depth (z)
        (h_top, h_bottom),  # height (y)
        (w_left, w_right),  # width (x)
    ]
    # A volume with a trailing channel axis gets an extra (0, 0) entry so the
    # channels are left untouched.
    # NOTE(review): NUM_VOLUME_DIMENSIONS is presumably 4 - confirm in type_definitions.
    channel_pads = [(0, 0)] if volume.ndim == NUM_VOLUME_DIMENSIONS else []

    return np.pad(
        volume,
        pad_width=spatial_pads + channel_pads,
        mode="constant",
        constant_values=value,
    )
def crop3d(
    volume: np.ndarray,
    crop_coords: tuple[int, int, int, int, int, int],
) -> np.ndarray:
    """Extract a box-shaped region from a 3D volume.

    Args:
        volume (np.ndarray): Input volume with shape (z, y, x) or (z, y, x, channels)
        crop_coords (tuple[int, int, int, int, int, int]):
            (z_min, z_max, y_min, y_max, x_min, x_max); each max is exclusive,
            following normal slice semantics

    Returns:
        np.ndarray: The selected region, with the same number of dimensions as
            the input. Produced by basic slicing, so any trailing axes are kept whole.
    """
    z_min, z_max, y_min, y_max, x_min, x_max = crop_coords
    window = (
        slice(z_min, z_max),
        slice(y_min, y_max),
        slice(x_min, x_max),
    )
    return volume[window]
def cutout3d(volume: np.ndarray, holes: np.ndarray, fill: tuple[float, ...] | float) -> np.ndarray:
    """Overwrite box-shaped regions of a 3D volume with a fill value.

    Args:
        volume (np.ndarray): Input volume with shape (depth, height, width) or
            (depth, height, width, channels). It is not modified.
        holes (np.ndarray): Array of holes with shape (num_holes, 6).
            Each row is [z1, y1, x1, z2, y2, x2]; the second corner is exclusive.
        fill (tuple[float, ...] | float): Value assigned to every voxel inside a hole

    Returns:
        np.ndarray: A copy of the volume in which every hole has been filled
    """
    filled = volume.copy()
    for hole in holes:
        z_start, y_start, x_start, z_stop, y_stop, x_stop = hole
        filled[z_start:z_stop, y_start:y_stop, x_start:x_stop] = fill
    return filled
def transform_cube(cube: np.ndarray, index: int) -> np.ndarray:
    """Transform cube by index (0-47)

    A lookup table maps every index to one fixed combination of axis flips,
    axis transposes and 90-degree rotations of the first three axes. The
    selected entry is applied to a copy of ``cube``, so the input array is
    never modified.

    Args:
        cube (np.ndarray): Input array with shape (D, H, W) or (D, H, W, C)
        index (int): Integer from 0 to 47 specifying which transformation to apply.
            Indices 0-23 are built from rotations only; indices 24-47 are labelled
            as their reflected counterparts.

    Returns:
        np.ndarray: Transformed cube. Many entries swap the first three axes, so
            the output shape equals the input shape only when D == H == W;
            otherwise the leading three extents may come back permuted. Any axes
            beyond the third (e.g. channels) stay in place.

    Raises:
        ValueError: If ``index`` is outside the range 0-47.
    """
    if not (0 <= index < 48):
        raise ValueError("Index must be between 0 and 47")

    # Several entries are written as plain slicing / transpose instead of the
    # equivalent np.rot90 or np.flip composition; the trailing "was:" comments
    # record the call they replaced. ``*range(3, x.ndim)`` in the transposes
    # keeps every axis after the third in its original position.
    transformations = {
        # First 4: rotate around axis 0 (indices 0-3)
        0: lambda x: x,
        1: lambda x: np.rot90(x, k=1, axes=(1, 2)),
        2: lambda x: np.rot90(x, k=2, axes=(1, 2)),
        3: lambda x: np.rot90(x, k=3, axes=(1, 2)),
        # Next 4: flip 180° about axis 1, then rotate around axis 0 (indices 4-7)
        4: lambda x: x[::-1, :, ::-1],  # was: np.flip(x, axis=(0, 2))
        5: lambda x: np.rot90(np.rot90(x, k=2, axes=(0, 2)), k=1, axes=(1, 2)),
        6: lambda x: x[::-1, ::-1, :],  # was: np.flip(x, axis=(0, 1))
        7: lambda x: np.rot90(np.rot90(x, k=2, axes=(0, 2)), k=3, axes=(1, 2)),
        # Next 8: split between 90° and 270° about axis 1, then rotate around axis 2 (indices 8-15)
        8: lambda x: np.rot90(x, k=1, axes=(0, 2)),
        9: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 2)), k=1, axes=(0, 1)),
        10: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 2)), k=2, axes=(0, 1)),
        11: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim)),
        12: lambda x: np.rot90(x, k=-1, axes=(0, 2)),
        13: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 2)), k=1, axes=(0, 1)),
        14: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 2)), k=2, axes=(0, 1)),
        15: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 2)), k=3, axes=(0, 1)),
        # Final 8: split between rotations about axis 2, then rotate around axis 1 (indices 16-23)
        16: lambda x: np.rot90(x, k=1, axes=(0, 1)),
        17: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 1)), k=1, axes=(0, 2)),
        18: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 1)), k=2, axes=(0, 2)),
        19: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim)),
        20: lambda x: np.rot90(x, k=-1, axes=(0, 1)),
        21: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 1)), k=1, axes=(0, 2)),
        22: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 1)), k=2, axes=(0, 2)),
        23: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 1)), k=3, axes=(0, 2)),
        # Reflected versions (24-47) - same as above but with initial reflection
        24: lambda x: x[:, :, ::-1],  # was: np.flip(x, axis=2)
        25: lambda x: x.transpose(0, 2, 1, *range(3, x.ndim)),
        26: lambda x: x[:, ::-1, :],  # was: np.flip(x, axis=1)
        27: lambda x: np.rot90(x[:, :, ::-1], k=3, axes=(1, 2)),
        28: lambda x: x[::-1, :, :],  # was: np.flip(x, axis=0)
        29: lambda x: np.rot90(x[::-1, :, :], k=1, axes=(1, 2)),
        30: lambda x: x[::-1, ::-1, ::-1],  # was: np.flip(x, axis=(0, 1, 2))
        31: lambda x: np.rot90(x[::-1, :, :], k=-1, axes=(1, 2)),
        32: lambda x: x.transpose(2, 1, 0, *range(3, x.ndim)),
        33: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[::-1, :, :],
        34: lambda x: x.transpose(2, 1, 0, *range(3, x.ndim))[::-1, ::-1, :],
        35: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[:, ::-1, :],
        36: lambda x: np.rot90(x[:, :, ::-1], k=-1, axes=(0, 2)),
        37: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[::-1, ::-1, ::-1],
        38: lambda x: x.transpose(2, 1, 0, *range(3, x.ndim))[:, ::-1, ::-1],
        39: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[:, :, ::-1],
        40: lambda x: np.rot90(x[:, :, ::-1], k=1, axes=(0, 1)),
        41: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[:, :, ::-1],
        42: lambda x: x.transpose(1, 0, 2, *range(3, x.ndim)),
        43: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[::-1, :, :],
        44: lambda x: np.rot90(x[:, :, ::-1], k=-1, axes=(0, 1)),
        45: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[:, ::-1, :],
        46: lambda x: x.transpose(1, 0, 2, *range(3, x.ndim))[::-1, ::-1, :],
        47: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[::-1, ::-1, ::-1],
    }

    # Work on a copy so that views produced by slicing/transposing never alias the caller's array.
    return transformations[index](cube.copy())
@handle_empty_array("keypoints")
def filter_keypoints_in_holes3d(keypoints: np.ndarray, holes: np.ndarray) -> np.ndarray:
    """Drop every keypoint that lies inside at least one 3D hole.

    Args:
        keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
            The first three columns are x, y, z coordinates.
        holes (np.ndarray): Array of holes with shape (num_holes, 6).
            Each hole is represented as [z1, y1, x1, z2, y2, x2]. The first
            corner is inclusive and the second exclusive.

    Returns:
        np.ndarray: The keypoints that fall outside all holes, with the input's
            dtype and column count (shape (0, num_columns) when none remain).
            If ``holes`` is empty the input array is returned as is.
    """
    # NOTE(review): the decorator presumably short-circuits empty keypoint arrays - confirm.
    if holes.size == 0:
        return keypoints

    # Reorder the coordinates from (x, y, z) to (z, y, x) so they line up with
    # the hole corners, and add a middle axis so every keypoint is compared
    # against every hole in a single broadcast.
    points_zyx = keypoints[:, [2, 1, 0]][:, np.newaxis, :]  # (num_keypoints, 1, 3)
    lower_corner = holes[:, 0:3]  # (num_holes, 3): z1, y1, x1
    upper_corner = holes[:, 3:6]  # (num_holes, 3): z2, y2, x2

    # (num_keypoints, num_holes, 3): per-axis containment test
    within_axis = (points_zyx >= lower_corner) & (points_zyx < upper_corner)
    # Inside a hole only when all three axes agree; rejected if inside any hole.
    hit_any_hole = within_axis.all(axis=2).any(axis=1)

    survivors = keypoints[~hit_any_hole]
    if survivors.shape[0] == 0:
        # Hand back an explicitly shaped empty array with the input dtype
        return np.empty((0, keypoints.shape[1]), dtype=keypoints.dtype)
    return survivors
def keypoints_rot90(
    keypoints: np.ndarray,
    k: int,
    axes: tuple[int, int],
    volume_shape: tuple[int, int, int],
) -> np.ndarray:
    """Rotate keypoints by k quarter turns in the plane spanned by two axes.

    The update rules mirror what ``np.rot90(volume, k, axes)`` does to voxel
    indices of a volume with shape ``volume_shape``.

    Args:
        keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
            Columns are addressed directly by ``axes``, so the coordinate columns
            must follow the same axis order as ``volume_shape``.
        k (int): Number of 90-degree rotations; negative values and values above
            3 are reduced modulo 4.
        axes (tuple[int, int]): The two axes (column indices) defining the rotation plane.
        volume_shape (tuple[int, int, int]): Shape of the volume before rotation.

    Returns:
        np.ndarray: Rotated keypoints with same shape as input. The input object
            itself is returned when ``k == 0`` or there are no keypoints;
            otherwise a new array is returned and the input is left untouched.
    """
    if k == 0 or len(keypoints) == 0:
        return keypoints

    quarter_turns = k % 4  # Python's modulo already yields a value in [0, 3]
    axis_a, axis_b = axes[0], axes[1]

    # Largest valid index along each rotation axis
    last_a = volume_shape[axis_a] - 1
    last_b = volume_shape[axis_b] - 1

    # Read the source columns from the untouched input so the two writes below
    # cannot interfere with each other.
    col_a = keypoints[:, axis_a]
    col_b = keypoints[:, axis_b]

    rotated = keypoints.copy()
    if quarter_turns == 1:  # 90 degrees CCW
        new_a, new_b = last_b - col_b, col_a
    elif quarter_turns == 2:  # 180 degrees
        new_a, new_b = last_a - col_a, last_b - col_b
    elif quarter_turns == 3:  # 270 degrees CCW
        new_a, new_b = col_b, last_a - col_a
    else:  # non-zero multiple of four: coordinates stay as they are
        return rotated

    rotated[:, axis_a] = new_a
    rotated[:, axis_b] = new_b
    return rotated
@handle_empty_array("keypoints")
def transform_cube_keypoints(
    keypoints: np.ndarray,
    index: int,
    volume_shape: tuple[int, int, int],
) -> np.ndarray:
    """Move keypoints the way the cube transformation with the given index moves voxels.

    Args:
        keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
            The first three columns are x, y, z coordinates; further columns are
            carried through unchanged.
        index (int): Integer from 0 to 47 specifying which transformation to apply.
        volume_shape (tuple[int, int, int]): Shape of the volume (depth, height, width).

    Returns:
        np.ndarray: Transformed keypoints with same shape as input, again in
            x, y, z column order. The input array is not modified.

    Raises:
        ValueError: If ``index`` is outside the range 0-47.
    """
    if not (0 <= index < 48):
        raise ValueError("Index must be between 0 and 47")

    # Work on a copy whose first three columns are reordered from (x, y, z) to
    # (z, y, x), i.e. depth/height/width order, so that column i corresponds to
    # volume axis i. Fancy indexing on the right-hand side yields a fresh array,
    # so the in-place assignment is safe.
    points = keypoints.copy()
    points[:, :3] = points[:, [2, 1, 0]]

    # Indices 24-47 start with a mirror along the last (width) axis
    if index >= 24:
        points[:, 2] = volume_shape[2] - 1 - points[:, 2]

    # One row per group of four consecutive rotation indices:
    #   (k of first rotation, axes of first rotation, axes of second rotation,
    #    order of the volume extents after the first rotation)
    # The second rotation's k is the position inside the group (0-3).
    rotation_plans = (
        (0, (0, 2), (1, 2), (0, 1, 2)),  # 0-3: no first rotation, rotate around axis 0
        (2, (0, 2), (1, 2), (0, 1, 2)),  # 4-7: 180° about axis 1, then around axis 0
        (1, (0, 2), (0, 1), (2, 1, 0)),  # 8-11: 90° about axis 1, then around axis 2
        (3, (0, 2), (0, 1), (2, 1, 0)),  # 12-15: 270° about axis 1, then around axis 2
        (1, (0, 1), (0, 2), (1, 0, 2)),  # 16-19: 90° about axis 2, then around axis 1
        (3, (0, 1), (0, 2), (1, 0, 2)),  # 20-23: 270° about axis 2, then around axis 1
    )
    group, second_k = divmod(index % 24, 4)
    first_k, first_axes, second_axes, extent_order = rotation_plans[group]

    # A first rotation with k == 0 hands the points back unchanged
    intermediate = keypoints_rot90(points, k=first_k, axes=first_axes, volume_shape=volume_shape)
    # A quarter turn swaps the extents of the two rotated axes, so the second
    # rotation must use the permuted shape.
    intermediate_shape = tuple(volume_shape[axis] for axis in extent_order)
    result = keypoints_rot90(intermediate, k=second_k, axes=second_axes, volume_shape=intermediate_shape)

    # Restore the (x, y, z) column order for the first three columns only
    result[:, :3] = result[:, [2, 1, 0]]
    return result
|