# functions.py (29 KB)
  1. from __future__ import annotations
  2. from functools import wraps
  3. from typing import Any, Callable, Literal
  4. import cv2
  5. import numpy as np
  6. import simsimd as ss
  7. import stringzilla as sz
  8. from albucore.decorators import contiguous, preserve_channel_dim
  9. from albucore.utils import (
  10. MAX_OPENCV_WORKING_CHANNELS,
  11. MAX_VALUES_BY_DTYPE,
  12. MONO_CHANNEL_DIMENSIONS,
  13. NormalizationType,
  14. ValueType,
  15. clip,
  16. clipped,
  17. convert_value,
  18. get_max_value,
  19. get_num_channels,
  20. )
  21. np_operations = {"multiply": np.multiply, "add": np.add, "power": np.power}
  22. cv2_operations = {"multiply": cv2.multiply, "add": cv2.add, "power": cv2.pow}
  23. def add_weighted_simsimd(img1: np.ndarray, weight1: float, img2: np.ndarray, weight2: float) -> np.ndarray:
  24. original_shape = img1.shape
  25. original_dtype = img1.dtype
  26. if img2.dtype != original_dtype:
  27. img2 = clip(img2.astype(original_dtype, copy=False), original_dtype, inplace=True)
  28. return np.frombuffer(
  29. ss.wsum(img1.reshape(-1), img2.astype(original_dtype, copy=False).reshape(-1), alpha=weight1, beta=weight2),
  30. dtype=original_dtype,
  31. ).reshape(
  32. original_shape,
  33. )
  34. def add_array_simsimd(img: np.ndarray, value: np.ndarray) -> np.ndarray:
  35. return add_weighted_simsimd(img, 1, value, 1)
  36. def multiply_by_constant_simsimd(img: np.ndarray, value: float) -> np.ndarray:
  37. return add_weighted_simsimd(img, value, np.zeros_like(img), 0)
  38. def add_constant_simsimd(img: np.ndarray, value: float) -> np.ndarray:
  39. return add_weighted_simsimd(img, 1, (np.ones_like(img) * value).astype(img.dtype, copy=False), 1)
  40. def create_lut_array(
  41. dtype: type[np.number],
  42. value: float | np.ndarray,
  43. operation: Literal["add", "multiply", "power"],
  44. ) -> np.ndarray:
  45. max_value = MAX_VALUES_BY_DTYPE[dtype]
  46. if dtype == np.uint8 and operation == "add":
  47. value = np.trunc(value)
  48. value = np.array(value, dtype=np.float32).reshape(-1, 1)
  49. lut = np.arange(0, max_value + 1, dtype=np.float32)
  50. if operation in np_operations:
  51. return np_operations[operation](lut, value)
  52. raise ValueError(f"Unsupported operation: {operation}")
  53. @contiguous
  54. def sz_lut(img: np.ndarray, lut: np.ndarray, inplace: bool = True) -> np.ndarray:
  55. if not inplace:
  56. img = img.copy()
  57. sz.translate(memoryview(img), memoryview(lut), inplace=True)
  58. return img
  59. def apply_lut(
  60. img: np.ndarray,
  61. value: float | np.ndarray,
  62. operation: Literal["add", "multiply", "power"],
  63. inplace: bool,
  64. ) -> np.ndarray:
  65. dtype = img.dtype
  66. if isinstance(value, (int, float)):
  67. lut = create_lut_array(dtype, value, operation)
  68. return sz_lut(img, clip(lut, dtype, inplace=False), False)
  69. num_channels = img.shape[-1]
  70. luts = clip(create_lut_array(dtype, value, operation), dtype, inplace=False)
  71. return cv2.merge([sz_lut(img[:, :, i], luts[i], inplace) for i in range(num_channels)])
  72. def prepare_value_opencv(
  73. img: np.ndarray,
  74. value: np.ndarray | float,
  75. operation: Literal["add", "multiply"],
  76. ) -> np.ndarray:
  77. return (
  78. _prepare_scalar_value(img, value, operation)
  79. if isinstance(value, (int, float))
  80. else _prepare_array_value(img, value, operation)
  81. )
  82. def _prepare_scalar_value(
  83. img: np.ndarray,
  84. value: float,
  85. operation: Literal["add", "multiply"],
  86. ) -> np.ndarray | float:
  87. if operation == "add" and img.dtype == np.uint8:
  88. value = int(value)
  89. num_channels = get_num_channels(img)
  90. if num_channels > MAX_OPENCV_WORKING_CHANNELS:
  91. if operation == "add":
  92. # Cast to float32 if value is negative to handle potential underflow issues
  93. cast_type = np.float32 if value < 0 else img.dtype
  94. return np.full(img.shape, value, dtype=cast_type)
  95. if operation == "multiply":
  96. return np.full(img.shape, value, dtype=np.float32)
  97. return value
  98. def _prepare_array_value(
  99. img: np.ndarray,
  100. value: np.ndarray,
  101. operation: Literal["add", "multiply"],
  102. ) -> np.ndarray:
  103. if value.dtype == np.float64:
  104. value = value.astype(np.float32, copy=False)
  105. if value.ndim == 1:
  106. value = value.reshape(1, 1, -1)
  107. value = np.broadcast_to(value, img.shape)
  108. if operation == "add" and img.dtype == np.uint8:
  109. if np.all(value >= 0):
  110. return clip(value, np.uint8, inplace=False)
  111. return np.trunc(value).astype(np.float32, copy=False)
  112. return value
  113. def apply_numpy(
  114. img: np.ndarray,
  115. value: float | np.ndarray,
  116. operation: Literal["add", "multiply", "power"],
  117. ) -> np.ndarray:
  118. if operation == "add" and img.dtype == np.uint8:
  119. value = np.int16(value)
  120. return np_operations[operation](img.astype(np.float32, copy=False), value)
  121. def multiply_lut(img: np.ndarray, value: np.ndarray | float, inplace: bool) -> np.ndarray:
  122. return apply_lut(img, value, "multiply", inplace)
  123. @preserve_channel_dim
  124. def multiply_opencv(img: np.ndarray, value: np.ndarray | float) -> np.ndarray:
  125. value = prepare_value_opencv(img, value, "multiply")
  126. if img.dtype == np.uint8:
  127. return cv2.multiply(img.astype(np.float32, copy=False), value)
  128. return cv2.multiply(img, value)
  129. def multiply_numpy(img: np.ndarray, value: float | np.ndarray) -> np.ndarray:
  130. return apply_numpy(img, value, "multiply")
  131. @clipped
  132. def multiply_by_constant(img: np.ndarray, value: float, inplace: bool) -> np.ndarray:
  133. if img.dtype == np.uint8:
  134. return multiply_lut(img, value, inplace)
  135. if img.dtype == np.float32:
  136. return multiply_numpy(img, value)
  137. return multiply_opencv(img, value)
  138. @clipped
  139. def multiply_by_vector(img: np.ndarray, value: np.ndarray, num_channels: int, inplace: bool) -> np.ndarray:
  140. # Handle uint8 images separately to use 1a lookup table for performance
  141. if img.dtype == np.uint8:
  142. return multiply_lut(img, value, inplace)
  143. # Check if the number of channels exceeds the maximum that OpenCV can handle
  144. if num_channels > MAX_OPENCV_WORKING_CHANNELS:
  145. return multiply_numpy(img, value)
  146. return multiply_opencv(img, value)
  147. @clipped
  148. def multiply_by_array(img: np.ndarray, value: np.ndarray) -> np.ndarray:
  149. return multiply_opencv(img, value)
  150. def multiply(img: np.ndarray, value: ValueType, inplace: bool = False) -> np.ndarray:
  151. num_channels = get_num_channels(img)
  152. value = convert_value(value, num_channels)
  153. if isinstance(value, (float, int)):
  154. return multiply_by_constant(img, value, inplace)
  155. if isinstance(value, np.ndarray) and value.ndim == 1:
  156. return multiply_by_vector(img, value, num_channels, inplace)
  157. return multiply_by_array(img, value)
  158. @preserve_channel_dim
  159. def add_opencv(img: np.ndarray, value: np.ndarray | float, inplace: bool = False) -> np.ndarray:
  160. value = prepare_value_opencv(img, value, "add")
  161. # Convert to float32 if:
  162. # 1. uint8 image with negative scalar value
  163. # 2. uint8 image with non-uint8 array value
  164. needs_float = img.dtype == np.uint8 and (
  165. (isinstance(value, (int, float)) and value < 0) or (isinstance(value, np.ndarray) and value.dtype != np.uint8)
  166. )
  167. if needs_float:
  168. return cv2.add(
  169. img.astype(np.float32, copy=False),
  170. value if isinstance(value, (int, float)) else value.astype(np.float32, copy=False),
  171. )
  172. # Use img as the destination array if inplace=True
  173. dst = img if inplace else None
  174. return cv2.add(img, value, dst=dst)
  175. def add_numpy(img: np.ndarray, value: float | np.ndarray) -> np.ndarray:
  176. return apply_numpy(img, value, "add")
  177. def add_lut(img: np.ndarray, value: np.ndarray | float, inplace: bool) -> np.ndarray:
  178. return apply_lut(img, value, "add", inplace)
  179. @clipped
  180. def add_constant(img: np.ndarray, value: float, inplace: bool = False) -> np.ndarray:
  181. return add_opencv(img, value, inplace)
  182. @clipped
  183. def add_vector(img: np.ndarray, value: np.ndarray, inplace: bool) -> np.ndarray:
  184. if img.dtype == np.uint8:
  185. return add_lut(img, value, inplace)
  186. return add_opencv(img, value, inplace)
  187. @clipped
  188. def add_array(img: np.ndarray, value: np.ndarray, inplace: bool = False) -> np.ndarray:
  189. return add_opencv(img, value, inplace)
  190. def add(img: np.ndarray, value: ValueType, inplace: bool = False) -> np.ndarray:
  191. num_channels = get_num_channels(img)
  192. value = convert_value(value, num_channels)
  193. if isinstance(value, (float, int)):
  194. if value == 0:
  195. return img
  196. if img.dtype == np.uint8:
  197. value = int(value)
  198. return add_constant(img, value, inplace)
  199. return add_vector(img, value, inplace) if value.ndim == 1 else add_array(img, value, inplace)
  200. def normalize_numpy(img: np.ndarray, mean: float | np.ndarray, denominator: float | np.ndarray) -> np.ndarray:
  201. img = img.astype(np.float32, copy=False)
  202. img -= mean
  203. return img * denominator
  204. @preserve_channel_dim
  205. def normalize_opencv(img: np.ndarray, mean: float | np.ndarray, denominator: float | np.ndarray) -> np.ndarray:
  206. img = img.astype(np.float32, copy=False)
  207. mean_img = np.zeros_like(img, dtype=np.float32)
  208. denominator_img = np.zeros_like(img, dtype=np.float32)
  209. # If mean or denominator are scalar, convert them to arrays
  210. if isinstance(mean, (float, int)):
  211. mean = np.full(img.shape, mean, dtype=np.float32)
  212. if isinstance(denominator, (float, int)):
  213. denominator = np.full(img.shape, denominator, dtype=np.float32)
  214. # Ensure the shapes match for broadcasting
  215. mean_img = (mean_img + mean).astype(np.float32, copy=False)
  216. denominator_img = denominator_img + denominator
  217. result = cv2.subtract(img, mean_img)
  218. return cv2.multiply(result, denominator_img, dtype=cv2.CV_32F)
  219. @preserve_channel_dim
  220. def normalize_lut(img: np.ndarray, mean: float | np.ndarray, denominator: float | np.ndarray) -> np.ndarray:
  221. dtype = img.dtype
  222. max_value = MAX_VALUES_BY_DTYPE[dtype]
  223. num_channels = get_num_channels(img)
  224. if isinstance(denominator, (float, int)) and isinstance(mean, (float, int)):
  225. lut = (np.arange(0, max_value + 1, dtype=np.float32) - mean) * denominator
  226. return cv2.LUT(img, lut)
  227. if isinstance(denominator, np.ndarray) and denominator.shape != ():
  228. denominator = denominator.reshape(-1, 1)
  229. if isinstance(mean, np.ndarray):
  230. mean = mean.reshape(-1, 1)
  231. luts = (np.arange(0, max_value + 1, dtype=np.float32) - mean) * denominator
  232. return cv2.merge([cv2.LUT(img[:, :, i], luts[i]) for i in range(num_channels)])
  233. def normalize(img: np.ndarray, mean: ValueType, denominator: ValueType) -> np.ndarray:
  234. num_channels = get_num_channels(img)
  235. denominator = convert_value(denominator, num_channels)
  236. mean = convert_value(mean, num_channels)
  237. if img.dtype == np.uint8:
  238. return normalize_lut(img, mean, denominator)
  239. return normalize_opencv(img, mean, denominator)
  240. def power_numpy(img: np.ndarray, exponent: float | np.ndarray) -> np.ndarray:
  241. return apply_numpy(img, exponent, "power")
  242. @preserve_channel_dim
  243. def power_opencv(img: np.ndarray, value: float) -> np.ndarray:
  244. """Handle the 'power' operation for OpenCV."""
  245. if img.dtype == np.float32:
  246. # For float32 images, cv2.pow works directly
  247. return cv2.pow(img, value)
  248. if img.dtype == np.uint8 and int(value) == value:
  249. # For uint8 images, cv2.pow works directly if value is actual integer, even if it's type is float
  250. return cv2.pow(img, value)
  251. if img.dtype == np.uint8 and isinstance(value, float):
  252. # For uint8 images, convert to float32, apply power, then convert back to uint8
  253. img_float = img.astype(np.float32, copy=False)
  254. return cv2.pow(img_float, value)
  255. raise ValueError(f"Unsupported image type {img.dtype} for power operation with value {value}")
  256. # @preserve_channel_dim
  257. def power_lut(img: np.ndarray, exponent: float | np.ndarray, inplace: bool = False) -> np.ndarray:
  258. return apply_lut(img, exponent, "power", inplace)
  259. @clipped
  260. def power(img: np.ndarray, exponent: ValueType, inplace: bool = False) -> np.ndarray:
  261. num_channels = get_num_channels(img)
  262. exponent = convert_value(exponent, num_channels)
  263. if img.dtype == np.uint8:
  264. return power_lut(img, exponent, inplace)
  265. if isinstance(exponent, (float, int)):
  266. return power_opencv(img, exponent)
  267. return power_numpy(img, exponent)
  268. def add_weighted_numpy(img1: np.ndarray, weight1: float, img2: np.ndarray, weight2: float) -> np.ndarray:
  269. return img1.astype(np.float32, copy=False) * weight1 + img2.astype(np.float32, copy=False) * weight2
  270. @preserve_channel_dim
  271. def add_weighted_opencv(img1: np.ndarray, weight1: float, img2: np.ndarray, weight2: float) -> np.ndarray:
  272. return cv2.addWeighted(img1, weight1, img2, weight2, 0)
  273. @preserve_channel_dim
  274. def add_weighted_lut(
  275. img1: np.ndarray,
  276. weight1: float,
  277. img2: np.ndarray,
  278. weight2: float,
  279. inplace: bool = False,
  280. ) -> np.ndarray:
  281. dtype = img1.dtype
  282. max_value = MAX_VALUES_BY_DTYPE[dtype]
  283. if weight1 == 1 and weight2 == 0:
  284. return img1
  285. if weight1 == 0 and weight2 == 1:
  286. return img2
  287. if weight1 == 0 and weight2 == 0:
  288. return np.zeros_like(img1)
  289. if weight1 == 1 and weight2 == 1:
  290. return add_array(img1, img2, inplace)
  291. lut1 = np.arange(0, max_value + 1, dtype=np.float32) * weight1
  292. result1 = cv2.LUT(img1, lut1)
  293. lut2 = np.arange(0, max_value + 1, dtype=np.float32) * weight2
  294. result2 = cv2.LUT(img2, lut2)
  295. return add_opencv(result1, result2, inplace)
  296. @clipped
  297. def add_weighted(img1: np.ndarray, weight1: float, img2: np.ndarray, weight2: float) -> np.ndarray:
  298. if img1.shape != img2.shape:
  299. raise ValueError(f"The input images must have the same shape. Got {img1.shape} and {img2.shape}.")
  300. return add_weighted_simsimd(img1, weight1, img2, weight2)
  301. def multiply_add_numpy(img: np.ndarray, factor: ValueType, value: ValueType) -> np.ndarray:
  302. if isinstance(value, (int, float)) and value == 0 and isinstance(factor, (int, float)) and factor == 0:
  303. return np.zeros_like(img, dtype=img.dtype)
  304. result = np.multiply(img, factor) if factor != 0 else np.zeros_like(img)
  305. return result if value == 0 else np.add(result, value)
  306. @preserve_channel_dim
  307. def multiply_add_opencv(img: np.ndarray, factor: ValueType, value: ValueType) -> np.ndarray:
  308. if isinstance(value, (int, float)) and value == 0 and isinstance(factor, (int, float)) and factor == 0:
  309. return np.zeros_like(img)
  310. result = img.astype(np.float32, copy=False)
  311. result = (
  312. cv2.multiply(result, np.ones_like(result) * factor, dtype=cv2.CV_64F)
  313. if factor != 0
  314. else np.zeros_like(result, dtype=img.dtype)
  315. )
  316. return result if value == 0 else cv2.add(result, np.ones_like(result) * value, dtype=cv2.CV_64F)
  317. def multiply_add_lut(img: np.ndarray, factor: ValueType, value: ValueType, inplace: bool) -> np.ndarray:
  318. dtype = img.dtype
  319. max_value = MAX_VALUES_BY_DTYPE[dtype]
  320. num_channels = get_num_channels(img)
  321. if isinstance(factor, (float, int)) and isinstance(value, (float, int)):
  322. lut = clip(np.arange(0, max_value + 1, dtype=np.float32) * factor + value, dtype, inplace=False)
  323. return sz_lut(img, lut, inplace)
  324. if isinstance(factor, np.ndarray) and factor.shape != ():
  325. factor = factor.reshape(-1, 1)
  326. if isinstance(value, np.ndarray) and value.shape != ():
  327. value = value.reshape(-1, 1)
  328. luts = clip(np.arange(0, max_value + 1, dtype=np.float32) * factor + value, dtype, inplace=True)
  329. return cv2.merge([sz_lut(img[:, :, i], luts[i], inplace) for i in range(num_channels)])
  330. @clipped
  331. def multiply_add(img: np.ndarray, factor: ValueType, value: ValueType, inplace: bool = False) -> np.ndarray:
  332. num_channels = get_num_channels(img)
  333. factor = convert_value(factor, num_channels)
  334. value = convert_value(value, num_channels)
  335. if img.dtype == np.uint8:
  336. return multiply_add_lut(img, factor, value, inplace)
  337. return multiply_add_opencv(img, factor, value)
  338. @preserve_channel_dim
  339. def normalize_per_image_opencv(img: np.ndarray, normalization: NormalizationType) -> np.ndarray:
  340. img = img.astype(np.float32, copy=False)
  341. eps = 1e-4
  342. if img.ndim == MONO_CHANNEL_DIMENSIONS:
  343. img = np.expand_dims(img, axis=-1)
  344. if normalization == "image" or (img.shape[-1] == 1 and normalization == "image_per_channel"):
  345. mean = img.mean().item()
  346. std = img.std().item() + eps
  347. if img.shape[-1] > MAX_OPENCV_WORKING_CHANNELS:
  348. mean = np.full_like(img, mean)
  349. std = np.full_like(img, std)
  350. normalized_img = cv2.divide(cv2.subtract(img, mean), std)
  351. return np.clip(normalized_img, -20, 20, out=normalized_img)
  352. if normalization == "image_per_channel":
  353. mean, std = cv2.meanStdDev(img)
  354. mean = mean[:, 0]
  355. std = std[:, 0]
  356. if img.shape[-1] > MAX_OPENCV_WORKING_CHANNELS:
  357. mean = np.full_like(img, mean)
  358. std = np.full_like(img, std)
  359. normalized_img = cv2.divide(cv2.subtract(img, mean), std, dtype=cv2.CV_32F)
  360. return np.clip(normalized_img, -20, 20, out=normalized_img)
  361. if normalization == "min_max" or (img.shape[-1] == 1 and normalization == "min_max_per_channel"):
  362. img_min = img.min()
  363. img_max = img.max()
  364. return cv2.normalize(img, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
  365. if normalization == "min_max_per_channel":
  366. img_min = img.min(axis=(0, 1))
  367. img_max = img.max(axis=(0, 1))
  368. if img.shape[-1] > MAX_OPENCV_WORKING_CHANNELS:
  369. img_min = np.full_like(img, img_min)
  370. img_max = np.full_like(img, img_max)
  371. return np.clip(
  372. cv2.divide(cv2.subtract(img, img_min), (img_max - img_min + eps), dtype=cv2.CV_32F),
  373. -20,
  374. 20,
  375. out=img,
  376. )
  377. raise ValueError(f"Unknown normalization method: {normalization}")
  378. @preserve_channel_dim
  379. def normalize_per_image_numpy(img: np.ndarray, normalization: NormalizationType) -> np.ndarray:
  380. img = img.astype(np.float32, copy=False)
  381. eps = 1e-4
  382. if img.ndim == MONO_CHANNEL_DIMENSIONS:
  383. img = np.expand_dims(img, axis=-1)
  384. if normalization == "image":
  385. mean = img.mean()
  386. std = img.std() + eps
  387. normalized_img = (img - mean) / std
  388. return np.clip(normalized_img, -20, 20, out=normalized_img)
  389. if normalization == "image_per_channel":
  390. pixel_mean = img.mean(axis=(0, 1))
  391. pixel_std = img.std(axis=(0, 1)) + eps
  392. normalized_img = (img - pixel_mean) / pixel_std
  393. return np.clip(normalized_img, -20, 20, out=normalized_img)
  394. if normalization == "min_max":
  395. img_min = img.min()
  396. img_max = img.max()
  397. return np.clip((img - img_min) / (img_max - img_min + eps), -20, 20, out=img)
  398. if normalization == "min_max_per_channel":
  399. img_min = img.min(axis=(0, 1))
  400. img_max = img.max(axis=(0, 1))
  401. return np.clip((img - img_min) / (img_max - img_min + eps), -20, 20, out=img)
  402. raise ValueError(f"Unknown normalization method: {normalization}")
  403. @preserve_channel_dim
  404. def normalize_per_image_lut(img: np.ndarray, normalization: NormalizationType) -> np.ndarray:
  405. dtype = img.dtype
  406. max_value = MAX_VALUES_BY_DTYPE[dtype]
  407. eps = 1e-4
  408. num_channels = get_num_channels(img)
  409. if img.ndim == MONO_CHANNEL_DIMENSIONS:
  410. img = np.expand_dims(img, axis=-1)
  411. if normalization == "image" or (img.shape[-1] == 1 and normalization == "image_per_channel"):
  412. mean = img.mean()
  413. std = img.std() + eps
  414. lut = (np.arange(0, max_value + 1, dtype=np.float32) - mean) / std
  415. return cv2.LUT(img, lut).clip(-20, 20)
  416. if normalization == "image_per_channel":
  417. pixel_mean = img.mean(axis=(0, 1))
  418. pixel_std = img.std(axis=(0, 1)) + eps
  419. luts = [
  420. (np.arange(0, max_value + 1, dtype=np.float32) - pixel_mean[c]) / pixel_std[c] for c in range(num_channels)
  421. ]
  422. return cv2.merge([cv2.LUT(img[:, :, i], luts[i]).clip(-20, 20) for i in range(num_channels)])
  423. if normalization == "min_max" or (img.shape[-1] == 1 and normalization == "min_max_per_channel"):
  424. img_min = img.min()
  425. img_max = img.max()
  426. lut = (np.arange(0, max_value + 1, dtype=np.float32) - img_min) / (img_max - img_min + eps)
  427. return cv2.LUT(img, lut).clip(-20, 20)
  428. if normalization == "min_max_per_channel":
  429. img_min = img.min(axis=(0, 1))
  430. img_max = img.max(axis=(0, 1))
  431. luts = [
  432. (np.arange(0, max_value + 1, dtype=np.float32) - img_min[c]) / (img_max[c] - img_min[c] + eps)
  433. for c in range(num_channels)
  434. ]
  435. return cv2.merge([cv2.LUT(img[:, :, i], luts[i]) for i in range(num_channels)])
  436. raise ValueError(f"Unknown normalization method: {normalization}")
  437. def normalize_per_image(img: np.ndarray, normalization: NormalizationType) -> np.ndarray:
  438. if img.dtype == np.uint8 and normalization != "per_image_per_channel":
  439. return normalize_per_image_lut(img, normalization)
  440. return normalize_per_image_opencv(img, normalization)
  441. def to_float_numpy(img: np.ndarray, max_value: float | None = None) -> np.ndarray:
  442. if max_value is None:
  443. max_value = get_max_value(img.dtype)
  444. return (img / max_value).astype(np.float32, copy=False)
  445. @preserve_channel_dim
  446. def to_float_opencv(img: np.ndarray, max_value: float | None = None) -> np.ndarray:
  447. if max_value is None:
  448. max_value = get_max_value(img.dtype)
  449. img_float = img.astype(np.float32, copy=False)
  450. num_channels = get_num_channels(img)
  451. if num_channels > MAX_OPENCV_WORKING_CHANNELS:
  452. # For images with more than 4 channels, create a full-sized divisor
  453. max_value_array = np.full_like(img_float, max_value)
  454. return cv2.divide(img_float, max_value_array)
  455. # For images with 4 or fewer channels, use scalar division
  456. return cv2.divide(img_float, max_value)
  457. @preserve_channel_dim
  458. def to_float_lut(img: np.ndarray, max_value: float | None = None) -> np.ndarray:
  459. if img.dtype != np.uint8:
  460. raise ValueError("LUT method is only applicable for uint8 images")
  461. if max_value is None:
  462. max_value = MAX_VALUES_BY_DTYPE[img.dtype]
  463. lut = np.arange(256, dtype=np.float32) / max_value
  464. return cv2.LUT(img, lut)
  465. def to_float(img: np.ndarray, max_value: float | None = None) -> np.ndarray:
  466. if img.dtype == np.float64:
  467. return img.astype(np.float32, copy=False)
  468. if img.dtype == np.float32:
  469. return img
  470. if img.dtype == np.uint8:
  471. return to_float_lut(img, max_value)
  472. return to_float_numpy(img, max_value)
  473. def from_float_numpy(img: np.ndarray, target_dtype: np.dtype, max_value: float | None = None) -> np.ndarray:
  474. if max_value is None:
  475. max_value = get_max_value(target_dtype)
  476. return clip(np.rint(img * max_value), target_dtype, inplace=True)
  477. @preserve_channel_dim
  478. def from_float_opencv(img: np.ndarray, target_dtype: np.dtype, max_value: float | None = None) -> np.ndarray:
  479. if max_value is None:
  480. max_value = get_max_value(target_dtype)
  481. img_float = img.astype(np.float32, copy=False)
  482. num_channels = get_num_channels(img)
  483. if num_channels > MAX_OPENCV_WORKING_CHANNELS:
  484. # For images with more than 4 channels, create a full-sized multiplier
  485. max_value_array = np.full_like(img_float, max_value)
  486. return clip(np.rint(cv2.multiply(img_float, max_value_array)), target_dtype, inplace=False)
  487. # For images with 4 or fewer channels, use scalar multiplication
  488. return clip(np.rint(img * max_value), target_dtype, inplace=False)
  489. def from_float(img: np.ndarray, target_dtype: np.dtype, max_value: float | None = None) -> np.ndarray:
  490. """Convert a floating-point image to the specified target data type.
  491. This function converts an input floating-point image to the specified target data type,
  492. scaling the values appropriately based on the max_value parameter or the maximum value
  493. of the target data type.
  494. Args:
  495. img (np.ndarray): Input floating-point image array.
  496. target_dtype (np.dtype): Target numpy data type for the output image.
  497. max_value (float | None, optional): Maximum value to use for scaling. If None,
  498. the maximum value of the target data type will be used. Defaults to None.
  499. Returns:
  500. np.ndarray: Image converted to the target data type.
  501. Notes:
  502. - If the input image is of type float32, the function uses OpenCV for faster processing.
  503. - For other input types, it falls back to a numpy-based implementation.
  504. - The function clips values to ensure they fit within the range of the target data type.
  505. """
  506. if target_dtype == np.float32:
  507. return img
  508. if target_dtype == np.float64:
  509. return img.astype(np.float32, copy=False)
  510. if img.dtype == np.float32:
  511. return from_float_opencv(img, target_dtype, max_value)
  512. return from_float_numpy(img, target_dtype, max_value)
  513. @contiguous
  514. def hflip_numpy(img: np.ndarray) -> np.ndarray:
  515. return img[:, ::-1, ...]
  516. @preserve_channel_dim
  517. def hflip_cv2(img: np.ndarray) -> np.ndarray:
  518. # OpenCV's flip function has a limitation of 512 channels
  519. if img.ndim > 2 and img.shape[2] > 512:
  520. return _flip_multichannel(img, flip_code=1)
  521. return cv2.flip(img, 1)
  522. def hflip(img: np.ndarray) -> np.ndarray:
  523. return hflip_cv2(img)
  524. @preserve_channel_dim
  525. def vflip_cv2(img: np.ndarray) -> np.ndarray:
  526. # OpenCV's flip function has a limitation of 512 channels
  527. if img.ndim > 2 and img.shape[2] > 512:
  528. return _flip_multichannel(img, flip_code=0)
  529. return cv2.flip(img, 0)
  530. @contiguous
  531. def vflip_numpy(img: np.ndarray) -> np.ndarray:
  532. return img[::-1, ...]
  533. def vflip(img: np.ndarray) -> np.ndarray:
  534. return vflip_cv2(img)
  535. def _flip_multichannel(img: np.ndarray, flip_code: int) -> np.ndarray:
  536. """Process images with more than 512 channels by splitting into chunks.
  537. OpenCV's flip function has a limitation where it can only handle images with up to 512 channels.
  538. This function works around that limitation by splitting the image into chunks of 512 channels,
  539. flipping each chunk separately, and then concatenating the results.
  540. Args:
  541. img: Input image with many channels
  542. flip_code: OpenCV flip code (0 for vertical, 1 for horizontal, -1 for both)
  543. Returns:
  544. Flipped image with all channels preserved
  545. """
  546. # Get image dimensions
  547. height, width = img.shape[:2]
  548. num_channels = 1 if img.ndim == 2 else img.shape[2]
  549. # If the image has 2 dimensions or fewer than 512 channels, use cv2.flip directly
  550. if img.ndim == 2 or num_channels <= 512:
  551. return cv2.flip(img, flip_code)
  552. # Process in chunks of 512 channels
  553. chunk_size = 512
  554. result_chunks = []
  555. for i in range(0, num_channels, chunk_size):
  556. end_idx = min(i + chunk_size, num_channels)
  557. chunk = img[:, :, i:end_idx]
  558. flipped_chunk = cv2.flip(chunk, flip_code)
  559. # Ensure the chunk maintains its dimensionality
  560. # This is needed when the last chunk has only one channel and cv2.flip reduces the dimensions
  561. if flipped_chunk.ndim == 2 and img.ndim == 3:
  562. flipped_chunk = np.expand_dims(flipped_chunk, axis=2)
  563. result_chunks.append(flipped_chunk)
  564. # Concatenate the chunks along the channel dimension
  565. return np.concatenate(result_chunks, axis=2)
  566. def float32_io(func: Callable[..., np.ndarray]) -> Callable[..., np.ndarray]:
  567. """Decorator to ensure float32 input/output for image processing functions.
  568. This decorator converts the input image to float32 before passing it to the wrapped function,
  569. and then converts the result back to the original dtype if it wasn't float32.
  570. Args:
  571. func (Callable[..., np.ndarray]): The image processing function to be wrapped.
  572. Returns:
  573. Callable[..., np.ndarray]: A wrapped function that handles float32 conversion.
  574. Example:
  575. @float32_io
  576. def some_image_function(img: np.ndarray) -> np.ndarray:
  577. # Function implementation
  578. return processed_img
  579. """
  580. @wraps(func)
  581. def float32_wrapper(img: np.ndarray, *args: Any, **kwargs: Any) -> np.ndarray:
  582. input_dtype = img.dtype
  583. if input_dtype != np.float32:
  584. img = to_float(img)
  585. result = func(img, *args, **kwargs)
  586. return from_float(result, target_dtype=input_dtype) if input_dtype != np.float32 else result
  587. return float32_wrapper
  588. def uint8_io(func: Callable[..., np.ndarray]) -> Callable[..., np.ndarray]:
  589. """Decorator to ensure uint8 input/output for image processing functions.
  590. This decorator converts the input image to uint8 before passing it to the wrapped function,
  591. and then converts the result back to the original dtype if it wasn't uint8.
  592. Args:
  593. func (Callable[..., np.ndarray]): The image processing function to be wrapped.
  594. Returns:
  595. Callable[..., np.ndarray]: A wrapped function that handles uint8 conversion.
  596. Example:
  597. @uint8_io
  598. def some_image_function(img: np.ndarray) -> np.ndarray:
  599. # Function implementation
  600. return processed_img
  601. """
  602. @wraps(func)
  603. def uint8_wrapper(img: np.ndarray, *args: Any, **kwargs: Any) -> np.ndarray:
  604. input_dtype = img.dtype
  605. if input_dtype != np.uint8:
  606. img = from_float(img, target_dtype=np.uint8)
  607. result = func(img, *args, **kwargs)
  608. return to_float(result) if input_dtype != np.uint8 else result
  609. return uint8_wrapper