transforms.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. """Module containing PyTorch-specific transforms for Albumentations.
  2. This module provides transforms that convert NumPy arrays to PyTorch tensors in
  3. the appropriate format. It handles both 2D image data and 3D volumetric data,
  4. ensuring that the tensor dimensions are correctly arranged according to PyTorch's
  5. expected format (channels first). These transforms are typically used as the final
  6. step in an augmentation pipeline before feeding data to a PyTorch model.
  7. """
  8. from __future__ import annotations
  9. from typing import Any, overload
  10. import numpy as np
  11. import torch
  12. from albumentations.core.transforms_interface import BasicTransform
  13. from albumentations.core.type_definitions import (
  14. MONO_CHANNEL_DIMENSIONS,
  15. NUM_MULTI_CHANNEL_DIMENSIONS,
  16. NUM_VOLUME_DIMENSIONS,
  17. Targets,
  18. )
# Public API of this module: the names exported by a star-import.
__all__ = ["ToTensor3D", "ToTensorV2"]
  20. class ToTensorV2(BasicTransform):
  21. """Converts images/masks to PyTorch Tensors, inheriting from BasicTransform.
  22. For images:
  23. - If input is in `HWC` format, converts to PyTorch `CHW` format
  24. - If input is in `HW` format, converts to PyTorch `1HW` format (adds channel dimension)
  25. Attributes:
  26. transpose_mask (bool): If True, transposes 3D input mask dimensions from `[height, width, num_channels]` to
  27. `[num_channels, height, width]`.
  28. p (float): Probability of applying the transform. Default: 1.0.
  29. """
  30. _targets = (Targets.IMAGE, Targets.MASK)
  31. def __init__(self, transpose_mask: bool = False, p: float = 1.0):
  32. super().__init__(p=p)
  33. self.transpose_mask = transpose_mask
  34. @property
  35. def targets(self) -> dict[str, Any]:
  36. """Define mapping of target name to target function.
  37. Returns:
  38. dict[str, Any]: Dictionary mapping target names to corresponding transform functions.
  39. """
  40. return {
  41. "image": self.apply,
  42. "images": self.apply_to_images,
  43. "mask": self.apply_to_mask,
  44. "masks": self.apply_to_masks,
  45. }
  46. def apply(self, img: np.ndarray, **params: Any) -> torch.Tensor:
  47. """Convert a 2D image array to a PyTorch tensor.
  48. Converts image from HWC or HW format to CHW format, handling both
  49. single-channel and multi-channel images.
  50. Args:
  51. img (np.ndarray): Image as a numpy array of shape (H,W) or (H,W,C)
  52. **params (Any): Additional parameters
  53. Returns:
  54. torch.Tensor: PyTorch tensor in CHW format
  55. Raises:
  56. ValueError: If image dimensions are neither HW nor HWC
  57. """
  58. if img.ndim not in {MONO_CHANNEL_DIMENSIONS, NUM_MULTI_CHANNEL_DIMENSIONS}:
  59. msg = "Albumentations only supports images in HW or HWC format"
  60. raise ValueError(msg)
  61. if img.ndim == MONO_CHANNEL_DIMENSIONS:
  62. img = np.expand_dims(img, 2)
  63. return torch.from_numpy(img.transpose(2, 0, 1))
  64. def apply_to_mask(self, mask: np.ndarray, **params: Any) -> torch.Tensor:
  65. """Convert a mask array to a PyTorch tensor.
  66. If transpose_mask is True and mask has 3 dimensions (H,W,C),
  67. converts mask to channels-first format (C,H,W).
  68. Args:
  69. mask (np.ndarray): Mask as a numpy array
  70. **params (Any): Additional parameters
  71. Returns:
  72. torch.Tensor: PyTorch tensor of mask
  73. """
  74. if self.transpose_mask and mask.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
  75. mask = mask.transpose(2, 0, 1)
  76. return torch.from_numpy(mask)
  77. @overload
  78. def apply_to_masks(self, masks: list[np.ndarray], **params: Any) -> list[torch.Tensor]: ...
  79. @overload
  80. def apply_to_masks(self, masks: np.ndarray, **params: Any) -> torch.Tensor: ...
  81. def apply_to_masks(self, masks: np.ndarray | list[np.ndarray], **params: Any) -> torch.Tensor | list[torch.Tensor]:
  82. """Convert numpy array or list of numpy array masks to torch tensor(s).
  83. Args:
  84. masks (np.ndarray | list[np.ndarray]): Numpy array of shape (N, H, W) or (N, H, W, C),
  85. or a list of numpy arrays with shape (H, W) or (H, W, C).
  86. **params (Any): Additional parameters.
  87. Returns:
  88. torch.Tensor | list[torch.Tensor]: If transpose_mask is True and input is (N, H, W, C),
  89. returns tensor of shape (N, C, H, W). If transpose_mask is True and input is (H, W, C), r
  90. eturns a list of tensors with shape (C, H, W). Otherwise, returns tensors with the same shape as input.
  91. """
  92. if isinstance(masks, list):
  93. return [self.apply_to_mask(mask, **params) for mask in masks]
  94. if self.transpose_mask and masks.ndim == NUM_VOLUME_DIMENSIONS: # (N, H, W, C)
  95. masks = np.transpose(masks, (0, 3, 1, 2)) # -> (N, C, H, W)
  96. return torch.from_numpy(masks)
  97. def apply_to_images(self, images: np.ndarray, **params: Any) -> torch.Tensor:
  98. """Convert batch of images from (N, H, W, C) to (N, C, H, W)."""
  99. if images.ndim != NUM_VOLUME_DIMENSIONS: # N,H,W,C
  100. raise ValueError(f"Expected 4D array (N,H,W,C), got {images.ndim}D array")
  101. return torch.from_numpy(images.transpose(0, 3, 1, 2)) # -> (N,C,H,W)
  102. class ToTensor3D(BasicTransform):
  103. """Convert 3D volumes and masks to PyTorch tensors.
  104. This transform is designed for 3D medical imaging data. It converts numpy arrays
  105. to PyTorch tensors and ensures consistent channel positioning.
  106. For all inputs (volumes and masks):
  107. - Input: (D, H, W, C) or (D, H, W) - depth, height, width, [channels]
  108. - Output: (C, D, H, W) - channels first format for PyTorch
  109. For single-channel input, adds C=1 dimension
  110. Note:
  111. This transform always moves channels to first position as this is
  112. the standard PyTorch format. For masks that need to stay in DHWC format,
  113. use a different transform or handle the transposition after this transform.
  114. Args:
  115. p (float): Probability of applying the transform. Default: 1.0
  116. """
  117. _targets = (Targets.IMAGE, Targets.MASK)
  118. def __init__(self, p: float = 1.0):
  119. super().__init__(p=p)
  120. @property
  121. def targets(self) -> dict[str, Any]:
  122. """Define mapping of target name to target function.
  123. Returns:
  124. dict[str, Any]: Dictionary mapping target names to corresponding transform functions
  125. """
  126. return {
  127. "volume": self.apply_to_volume,
  128. "mask3d": self.apply_to_mask3d,
  129. }
  130. def apply_to_volume(self, volume: np.ndarray, **params: Any) -> torch.Tensor:
  131. """Convert 3D volume to channels-first tensor."""
  132. if volume.ndim == NUM_VOLUME_DIMENSIONS: # D,H,W,C
  133. return torch.from_numpy(volume.transpose(3, 0, 1, 2))
  134. if volume.ndim == NUM_VOLUME_DIMENSIONS - 1: # D,H,W
  135. return torch.from_numpy(volume[np.newaxis, ...])
  136. raise ValueError(f"Expected 3D or 4D array (D,H,W) or (D,H,W,C), got {volume.ndim}D array")
  137. def apply_to_mask3d(self, mask3d: np.ndarray, **params: Any) -> torch.Tensor:
  138. """Convert 3D mask to channels-first tensor."""
  139. return self.apply_to_volume(mask3d, **params)