utils.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. from __future__ import annotations
  2. import sys
  3. from functools import wraps
  4. from typing import Any, Callable, Literal, Union
  5. if sys.version_info >= (3, 10):
  6. from typing import Concatenate, ParamSpec
  7. else:
  8. from typing_extensions import Concatenate, ParamSpec
  9. import cv2
  10. import numpy as np
  11. NUM_RGB_CHANNELS = 3
  12. MONO_CHANNEL_DIMENSIONS = 2
  13. NUM_MULTI_CHANNEL_DIMENSIONS = 3
  14. FOUR = 4
  15. TWO = 2
  16. MAX_OPENCV_WORKING_CHANNELS = 4
  17. NormalizationType = Literal["image", "image_per_channel", "min_max", "min_max_per_channel"]
  18. P = ParamSpec("P")
  # Upper bound of the valid value range per dtype, read by clip() and get_max_value().
  # Integer types map to their largest representable value; floating types map to 1.0
  # (images of those dtypes are treated as normalized to [0, 1]).
  # Each dtype is listed both as an ``np.dtype`` instance (what ``array.dtype`` yields) and
  # as the NumPy scalar type, because the two forms are distinct dict keys. int32 was
  # previously only registered as ``np.int32``, so lookups with an int32 array's ``.dtype``
  # raised KeyError.
  MAX_VALUES_BY_DTYPE = {
      np.dtype("uint8"): 255,
      np.dtype("uint16"): 65535,
      np.dtype("uint32"): 4294967295,
      np.dtype("float16"): 1.0,
      np.dtype("float32"): 1.0,
      np.dtype("float64"): 1.0,
      np.dtype("int32"): 2147483647,
      np.uint8: 255,
      np.uint16: 65535,
      np.uint32: 4294967295,
      np.float16: 1.0,
      np.float32: 1.0,
      np.float64: 1.0,
      np.int32: 2147483647,
  }
  # OpenCV depth constant for each supported NumPy scalar type. The update() below registers
  # the same depths under the matching ``np.dtype`` instances, so get_opencv_dtype_from_numpy
  # succeeds whether it receives a scalar type (e.g. ``np.uint8``) or an array's ``.dtype``.
  NPDTYPE_TO_OPENCV_DTYPE = {
      np.uint8: cv2.CV_8U,
      np.uint16: cv2.CV_16U,
      np.float32: cv2.CV_32F,
      np.float64: cv2.CV_64F,
      np.int32: cv2.CV_32S,
  }
  NPDTYPE_TO_OPENCV_DTYPE.update(
      {np.dtype(scalar_type): depth for scalar_type, depth in NPDTYPE_TO_OPENCV_DTYPE.items()}
  )
  def maybe_process_in_chunks(
      process_fn: Callable[Concatenate[np.ndarray, P], np.ndarray],
      *args: P.args,
      **kwargs: P.kwargs,
  ) -> Callable[[np.ndarray], np.ndarray]:
      """Wrap an OpenCV function so it can process images with more than 4 channels.

      Images with at most ``MAX_OPENCV_WORKING_CHANNELS`` channels are passed straight to
      ``process_fn``. Wider images are split along the channel axis into chunks of up to
      ``MAX_OPENCV_WORKING_CHANNELS`` channels. Each chunk goes through ``process_fn`` and
      the results are stacked back together with ``np.dstack``. A trailing remainder of
      exactly two channels is processed one channel at a time.

      Limitations:
          This wrapper requires the image to be the first argument of ``process_fn``;
          passing the remaining arguments by keyword is recommended.

      Args:
          process_fn: Transform function (e.g. ``cv2.resize``).
          args: Positional arguments bound now; passed to ``process_fn`` after the image.
          kwargs: Keyword arguments bound now; passed to ``process_fn``.

      Returns:
          A callable ``fn(img, *extra_args, **extra_kwargs)`` that returns the transformed
          image. ``extra_args`` are appended after ``args``; ``extra_kwargs`` override
          ``kwargs`` on key collisions.
      """

      @wraps(process_fn)
      def __process_fn(img: np.ndarray, *process_args: P.args, **process_kwargs: P.kwargs) -> np.ndarray:
          # Wrap-time arguments come first; call-time keyword arguments win on collisions.
          all_args = (*args, *process_args)
          all_kwargs: dict[str, Any] = kwargs | process_kwargs

          num_channels = get_num_channels(img)
          if num_channels <= MAX_OPENCV_WORKING_CHANNELS:
              return process_fn(img, *all_args, **all_kwargs)

          chunks = []
          for index in range(0, num_channels, MAX_OPENCV_WORKING_CHANNELS):
              if num_channels - index == TWO:
                  # Many OpenCV functions cannot work with 2-channel images,
                  # so a trailing pair of channels is processed one channel at a time.
                  channel_slices = [slice(index + i, index + i + 1) for i in range(TWO)]
              else:
                  channel_slices = [slice(index, index + MAX_OPENCV_WORKING_CHANNELS)]
              for channel_slice in channel_slices:
                  chunks.append(process_fn(img[:, :, channel_slice], *all_args, **all_kwargs))

          # np.dstack promotes 2-D (H, W) results to (H, W, 1), so single-channel chunks
          # stack correctly whether process_fn drops the channel axis (as cv2 typically
          # does) or keeps it. The former explicit expand_dims produced a 4-D chunk in the
          # latter case and made dstack fail.
          return np.dstack(chunks)

      return __process_fn
  def clip(img: np.ndarray, dtype: Any, inplace: bool = False) -> np.ndarray:
      """Clip ``img`` to the range ``[0, MAX_VALUES_BY_DTYPE[dtype]]``.

      Args:
          img: Array to clip.
          dtype: Key into MAX_VALUES_BY_DTYPE that selects the upper bound. When
              ``inplace`` is False it is also the dtype of the returned array.
          inplace: If True, the clipped values are written back into ``img`` and ``img``
              itself is returned (keeping its own dtype).

      Returns:
          The clipped array.

      Raises:
          KeyError: If ``dtype`` is not a key of MAX_VALUES_BY_DTYPE.
      """
      upper_bound = MAX_VALUES_BY_DTYPE[dtype]
      if not inplace:
          return np.clip(img, 0, upper_bound).astype(dtype, copy=False)
      return np.clip(img, 0, upper_bound, out=img)
  def clipped(func: Callable[Concatenate[np.ndarray, P], np.ndarray]) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
      """Decorate ``func`` so its output is clipped to the input image's value range.

      The wrapper records the input dtype and then calls ``func``. A uint8 result is
      returned untouched. Any other result goes through ``clip`` with the input dtype,
      which bounds it to ``[0, MAX_VALUES_BY_DTYPE[input dtype]]`` and casts it to that
      dtype.
      """

      @wraps(func)
      def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
          input_dtype = img.dtype
          output = func(img, *args, **kwargs)
          return output if output.dtype == np.uint8 else clip(output, input_dtype)

      return wrapped_function
  def get_num_channels(image: np.ndarray) -> int:
      """Return the number of channels in ``image``.

      A 3-D array is read as (H, W, C) and ``shape[2]`` is returned. An array of any
      other rank counts as a single channel.
      """
      if image.ndim != NUM_MULTI_CHANNEL_DIMENSIONS:
          return 1
      return image.shape[2]
  def is_grayscale_image(image: np.ndarray) -> bool:
      """Return True when get_num_channels reports exactly one channel for ``image``."""
      channel_count = get_num_channels(image)
      return channel_count == 1
  def get_opencv_dtype_from_numpy(value: np.ndarray | int | np.dtype | object) -> int:
      """Look up the OpenCV depth constant for a NumPy array, dtype, or scalar type.

      An ndarray is resolved through its ``.dtype``. Any other value is used directly as
      the key into NPDTYPE_TO_OPENCV_DTYPE.

      Raises:
          KeyError: If the resolved key has no OpenCV mapping.
      """
      lookup_key = value.dtype if isinstance(value, np.ndarray) else value
      return NPDTYPE_TO_OPENCV_DTYPE[lookup_key]
  def is_rgb_image(image: np.ndarray) -> bool:
      """Return True when ``image`` has exactly NUM_RGB_CHANNELS channels."""
      return NUM_RGB_CHANNELS == get_num_channels(image)
  def is_multispectral_image(image: np.ndarray) -> bool:
      """Return True unless ``image`` has exactly 1 or 3 channels.

      Under this rule, 2- and 4-channel images also count as multispectral.
      """
      channel_count = get_num_channels(image)
      return channel_count != 1 and channel_count != 3
  111. def convert_value(value: np.ndarray | float, num_channels: int) -> float | np.ndarray:
  112. """Convert a value to a float or numpy array based on its shape and number of channels.
  113. Args:
  114. value: Input value to convert (numpy array, float, or int)
  115. num_channels: Number of channels in the target image
  116. Returns:
  117. float: If value is a scalar or 1D array that should be converted to scalar
  118. np.ndarray: If value is a multi-dimensional array or channel vector
  119. Raises:
  120. TypeError: If value is of unsupported type
  121. """
  122. # Handle scalar types
  123. if isinstance(value, (float, int, np.float32, np.float64)):
  124. return float(value) if isinstance(value, (float, int)) else value.item()
  125. # Handle numpy arrays
  126. if isinstance(value, np.ndarray):
  127. # Return scalars and 0-dim arrays as float
  128. if value.ndim == 0:
  129. return value.item()
  130. # Return multi-dimensional arrays as-is
  131. if value.ndim > 1:
  132. return value
  133. # Handle 1D arrays
  134. if len(value) == 1 or num_channels == 1 or len(value) < num_channels:
  135. return float(value[0])
  136. return value[:num_channels]
  137. raise TypeError(f"Unsupported value type: {type(value)}")
  # Alias for an array, float, or int value. It uses typing.Union rather than the ``|``
  # operator so it still evaluates at runtime on Python < 3.10, which this module supports.
  ValueType = Union[np.ndarray, float, int]


  def get_max_value(dtype: np.dtype) -> float:
      """Return the maximum value registered for ``dtype`` in MAX_VALUES_BY_DTYPE.

      Raises:
          RuntimeError: If ``dtype`` has no registered maximum. The caller must then
              pass the max_value argument explicitly.
      """
      max_value = MAX_VALUES_BY_DTYPE.get(dtype)
      if max_value is not None:
          return max_value
      msg = (
          f"Can't infer the maximum value for dtype {dtype}. "
          "You need to specify the maximum value manually by passing the max_value argument."
      )
      raise RuntimeError(msg)