equalization.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. """In this module several equalization methods are exposed: he, ahe, clahe."""
  18. import math
  19. from typing import Tuple
  20. import torch
  21. import torch.nn.functional as F
  22. from kornia.utils.helpers import _torch_histc_cast
  23. from kornia.utils.image import perform_keep_shape_image
  24. from .histogram import histogram
  25. def _compute_tiles(
  26. imgs: torch.Tensor, grid_size: Tuple[int, int], even_tile_size: bool = False
  27. ) -> Tuple[torch.Tensor, torch.Tensor]:
  28. r"""Compute tiles on an image according to a grid size.
  29. Note that padding can be added to the image in order to crop properly the image.
  30. So, the grid_size (GH, GW) x tile_size (TH, TW) >= image_size (H, W)
  31. Args:
  32. imgs: batch of 2D images with shape (B, C, H, W) or (C, H, W).
  33. grid_size: number of tiles to be cropped in each direction (GH, GW)
  34. even_tile_size: Determine if the width and height of the tiles must be even.
  35. Returns:
  36. tensor with tiles (B, GH, GW, C, TH, TW). B = 1 in case of a single image is provided.
  37. tensor with the padded batch of 2D imageswith shape (B, C, H', W').
  38. """
  39. batch: torch.Tensor = imgs # B x C x H x W
  40. # compute stride and kernel size
  41. h, w = batch.shape[-2:]
  42. kernel_vert: int = math.ceil(h / grid_size[0])
  43. kernel_horz: int = math.ceil(w / grid_size[1])
  44. if even_tile_size:
  45. kernel_vert += 1 if kernel_vert % 2 else 0
  46. kernel_horz += 1 if kernel_horz % 2 else 0
  47. # add padding (with that kernel size we could need some extra cols and rows...)
  48. pad_vert = kernel_vert * grid_size[0] - h
  49. pad_horz = kernel_horz * grid_size[1] - w
  50. # add the padding in the last coluns and rows
  51. if pad_vert > batch.shape[-2] or pad_horz > batch.shape[-1]:
  52. raise ValueError("Cannot compute tiles on the image according to the given grid size")
  53. if pad_vert > 0 or pad_horz > 0:
  54. batch = F.pad(batch, [0, pad_horz, 0, pad_vert], mode="reflect") # B x C x H' x W'
  55. # compute tiles
  56. c: int = batch.shape[-3]
  57. tiles: torch.Tensor = (
  58. batch.unfold(1, c, c) # unfold(dimension, size, step)
  59. .unfold(2, kernel_vert, kernel_vert)
  60. .unfold(3, kernel_horz, kernel_horz)
  61. .squeeze(1)
  62. ).contiguous() # GH x GW x C x TH x TW
  63. if tiles.shape[-5] != grid_size[0]:
  64. raise AssertionError
  65. if tiles.shape[-4] != grid_size[1]:
  66. raise AssertionError
  67. return tiles, batch
  68. def _compute_interpolation_tiles(padded_imgs: torch.Tensor, tile_size: Tuple[int, int]) -> torch.Tensor:
  69. r"""Compute interpolation tiles on a properly padded set of images.
  70. Note that images must be padded. So, the tile_size (TH, TW) * grid_size (GH, GW) = image_size (H, W)
  71. Args:
  72. padded_imgs: batch of 2D images with shape (B, C, H, W) already padded to extract tiles
  73. of size (TH, TW).
  74. tile_size: shape of the current tiles (TH, TW).
  75. Returns:
  76. tensor with the interpolation tiles (B, 2GH, 2GW, C, TH/2, TW/2).
  77. """
  78. if padded_imgs.dim() != 4:
  79. raise AssertionError("Images Tensor must be 4D.")
  80. if padded_imgs.shape[-2] % tile_size[0] != 0:
  81. raise AssertionError("Images are not correctly padded.")
  82. if padded_imgs.shape[-1] % tile_size[1] != 0:
  83. raise AssertionError("Images are not correctly padded.")
  84. # tiles to be interpolated are built by dividing in 4 each already existing
  85. interp_kernel_vert: int = tile_size[0] // 2
  86. interp_kernel_horz: int = tile_size[1] // 2
  87. c: int = padded_imgs.shape[-3]
  88. interp_tiles: torch.Tensor = (
  89. padded_imgs.unfold(1, c, c)
  90. .unfold(2, interp_kernel_vert, interp_kernel_vert)
  91. .unfold(3, interp_kernel_horz, interp_kernel_horz)
  92. .squeeze(1)
  93. ).contiguous() # 2GH x 2GW x C x TH/2 x TW/2
  94. if interp_tiles.shape[-3] != c:
  95. raise AssertionError
  96. if interp_tiles.shape[-2] != tile_size[0] / 2:
  97. raise AssertionError
  98. if interp_tiles.shape[-1] != tile_size[1] / 2:
  99. raise AssertionError
  100. return interp_tiles
  101. def _my_histc(tiles: torch.Tensor, bins: int) -> torch.Tensor:
  102. return _torch_histc_cast(tiles, bins=bins, min=0, max=1)
  103. def _compute_luts(
  104. tiles_x_im: torch.Tensor, num_bins: int = 256, clip: float = 40.0, diff: bool = False
  105. ) -> torch.Tensor:
  106. r"""Compute luts for a batched set of tiles.
  107. Same approach as in OpenCV (https://github.com/opencv/opencv/blob/master/modules/imgproc/src/clahe.cpp)
  108. Args:
  109. tiles_x_im: set of tiles per image to apply the lut. (B, GH, GW, C, TH, TW)
  110. num_bins: number of bins. default: 256
  111. clip: threshold value for contrast limiting. If it is 0 then the clipping is disabled.
  112. diff: denote if the differentiable histagram will be used. Default: False
  113. Returns:
  114. Lut for each tile (B, GH, GW, C, 256).
  115. """
  116. if tiles_x_im.dim() != 6:
  117. raise AssertionError("Tensor must be 6D.")
  118. b, gh, gw, c, th, tw = tiles_x_im.shape
  119. pixels: int = th * tw
  120. tiles: torch.Tensor = tiles_x_im.view(-1, pixels) # test with view # T x (THxTW)
  121. if not diff:
  122. if torch.jit.is_scripting():
  123. histos = torch.stack([_torch_histc_cast(tile, bins=num_bins, min=0, max=1) for tile in tiles])
  124. else:
  125. histos = torch.stack(list(map(_my_histc, tiles, [num_bins] * len(tiles))))
  126. else:
  127. bins: torch.Tensor = torch.linspace(0, 1, num_bins, device=tiles.device)
  128. histos = histogram(tiles, bins, torch.tensor(0.001)).squeeze()
  129. histos *= pixels
  130. if clip > 0.0:
  131. max_val: float = max(clip * pixels // num_bins, 1)
  132. histos.clamp_(max=max_val)
  133. clipped: torch.Tensor = pixels - histos.sum(1)
  134. residual: torch.Tensor = torch.remainder(clipped, num_bins)
  135. redist: torch.Tensor = (clipped - residual).div(num_bins)
  136. histos += redist[None].transpose(0, 1)
  137. # trick to avoid using a loop to assign the residual
  138. v_range: torch.Tensor = torch.arange(num_bins, device=histos.device)
  139. mat_range: torch.Tensor = v_range.repeat(histos.shape[0], 1)
  140. histos += mat_range < residual[None].transpose(0, 1)
  141. lut_scale: float = (num_bins - 1) / pixels
  142. luts: torch.Tensor = torch.cumsum(histos, 1) * lut_scale
  143. luts = luts.clamp(0, num_bins - 1)
  144. if not diff:
  145. luts = luts.floor() # to get the same values as converting to int maintaining the type
  146. luts = luts.view((b, gh, gw, c, num_bins))
  147. return luts
  148. def _map_luts(interp_tiles: torch.Tensor, luts: torch.Tensor) -> torch.Tensor:
  149. r"""Assign the required luts to each tile.
  150. Args:
  151. interp_tiles: set of interpolation tiles. (B, 2GH, 2GW, C, TH/2, TW/2)
  152. luts: luts for each one of the original tiles. (B, GH, GW, C, 256)
  153. Returns:
  154. mapped luts (B, 2GH, 2GW, 4, C, 256)
  155. """
  156. if interp_tiles.dim() != 6:
  157. raise AssertionError("interp_tiles tensor must be 6D.")
  158. if luts.dim() != 5:
  159. raise AssertionError("luts tensor must be 5D.")
  160. # gh, gw -> 2x the number of tiles used to compute the histograms
  161. # th, tw -> /2 the sizes of the tiles used to compute the histograms
  162. num_imgs, gh, gw, c, _, _ = interp_tiles.shape
  163. # precompute idxs for non corner regions (doing it in cpu seems slightly faster)
  164. j_idxs = torch.empty(0, 4, dtype=torch.long)
  165. if gh > 2:
  166. j_floor = torch.arange(1, gh - 1).view(gh - 2, 1).div(2, rounding_mode="trunc")
  167. j_idxs = torch.tensor([[0, 0, 1, 1], [-1, -1, 0, 0]] * ((gh - 2) // 2)) # reminder + j_idxs[:, 0:2] -= 1
  168. j_idxs += j_floor
  169. i_idxs = torch.empty(0, 4, dtype=torch.long)
  170. if gw > 2:
  171. i_floor = torch.arange(1, gw - 1).view(gw - 2, 1).div(2, rounding_mode="trunc")
  172. i_idxs = torch.tensor([[0, 1, 0, 1], [-1, 0, -1, 0]] * ((gw - 2) // 2)) # reminder + i_idxs[:, [0, 2]] -= 1
  173. i_idxs += i_floor
  174. # selection of luts to interpolate each patch
  175. # create a tensor with dims: interp_patches height and width x 4 x num channels x bins in the histograms
  176. # the tensor is init to -1 to denote non init hists
  177. luts_x_interp_tiles: torch.Tensor = torch.full( # B x GH x GW x 4 x C x 256
  178. (num_imgs, gh, gw, 4, c, luts.shape[-1]), -1, dtype=interp_tiles.dtype, device=interp_tiles.device
  179. )
  180. # corner regions
  181. luts_x_interp_tiles[:, 0 :: gh - 1, 0 :: gw - 1, 0] = luts[:, 0 :: max(gh // 2 - 1, 1), 0 :: max(gw // 2 - 1, 1)]
  182. # border region (h)
  183. luts_x_interp_tiles[:, 1:-1, 0 :: gw - 1, 0] = luts[:, j_idxs[:, 0], 0 :: max(gw // 2 - 1, 1)]
  184. luts_x_interp_tiles[:, 1:-1, 0 :: gw - 1, 1] = luts[:, j_idxs[:, 2], 0 :: max(gw // 2 - 1, 1)]
  185. # border region (w)
  186. luts_x_interp_tiles[:, 0 :: gh - 1, 1:-1, 0] = luts[:, 0 :: max(gh // 2 - 1, 1), i_idxs[:, 0]]
  187. luts_x_interp_tiles[:, 0 :: gh - 1, 1:-1, 1] = luts[:, 0 :: max(gh // 2 - 1, 1), i_idxs[:, 1]]
  188. # internal region
  189. luts_x_interp_tiles[:, 1:-1, 1:-1, :] = luts[
  190. :, j_idxs.repeat(max(gh - 2, 1), 1, 1).permute(1, 0, 2), i_idxs.repeat(max(gw - 2, 1), 1, 1)
  191. ]
  192. return luts_x_interp_tiles
  193. def _compute_equalized_tiles(interp_tiles: torch.Tensor, luts: torch.Tensor) -> torch.Tensor:
  194. r"""Equalize the tiles.
  195. Args:
  196. interp_tiles: set of interpolation tiles, values must be in the range [0, 1].
  197. (B, 2GH, 2GW, C, TH/2, TW/2)
  198. luts: luts for each one of the original tiles. (B, GH, GW, C, 256)
  199. Returns:
  200. equalized tiles (B, 2GH, 2GW, C, TH/2, TW/2)
  201. """
  202. if interp_tiles.dim() != 6:
  203. raise AssertionError("interp_tiles tensor must be 6D.")
  204. if luts.dim() != 5:
  205. raise AssertionError("luts tensor must be 5D.")
  206. mapped_luts: torch.Tensor = _map_luts(interp_tiles, luts) # Bx2GHx2GWx4xCx256
  207. # gh, gw -> 2x the number of tiles used to compute the histograms
  208. # th, tw -> /2 the sizes of the tiles used to compute the histograms
  209. num_imgs, gh, gw, c, th, tw = interp_tiles.shape
  210. # equalize tiles
  211. flatten_interp_tiles: torch.Tensor = (interp_tiles * 255).long().flatten(-2, -1) # B x GH x GW x 4 x C x (THxTW)
  212. flatten_interp_tiles = flatten_interp_tiles.unsqueeze(-3).expand(num_imgs, gh, gw, 4, c, th * tw)
  213. preinterp_tiles_equalized = (
  214. torch.gather(mapped_luts, 5, flatten_interp_tiles) # B x GH x GW x 4 x C x TH x TW
  215. .to(interp_tiles)
  216. .reshape(num_imgs, gh, gw, 4, c, th, tw)
  217. )
  218. # interp tiles
  219. tiles_equalized: torch.Tensor = torch.zeros_like(interp_tiles)
  220. # compute the interpolation weights (shapes are 2 x TH x TW because they must be applied to 2 interp tiles)
  221. ih = (
  222. torch.arange(2 * th - 1, -1, -1, dtype=interp_tiles.dtype, device=interp_tiles.device)
  223. .div(2.0 * th - 1)[None]
  224. .transpose(-2, -1)
  225. .expand(2 * th, tw)
  226. )
  227. ih = ih.unfold(0, th, th).unfold(1, tw, tw) # 2 x 1 x TH x TW
  228. iw = (
  229. torch.arange(2 * tw - 1, -1, -1, dtype=interp_tiles.dtype, device=interp_tiles.device)
  230. .div(2.0 * tw - 1)
  231. .expand(th, 2 * tw)
  232. )
  233. iw = iw.unfold(0, th, th).unfold(1, tw, tw) # 1 x 2 x TH x TW
  234. # compute row and column interpolation weights
  235. tiw = iw.expand((gw - 2) // 2, 2, th, tw).reshape(gw - 2, 1, th, tw).unsqueeze(0) # 1 x GW-2 x 1 x TH x TW
  236. tih = ih.repeat((gh - 2) // 2, 1, 1, 1).unsqueeze(1) # GH-2 x 1 x 1 x TH x TW
  237. # internal regions
  238. tl, tr, bl, br = preinterp_tiles_equalized[:, 1:-1, 1:-1].unbind(3)
  239. t = torch.addcmul(tr, tiw, torch.sub(tl, tr))
  240. b = torch.addcmul(br, tiw, torch.sub(bl, br))
  241. tiles_equalized[:, 1:-1, 1:-1] = torch.addcmul(b, tih, torch.sub(t, b))
  242. # corner regions
  243. tiles_equalized[:, 0 :: gh - 1, 0 :: gw - 1] = preinterp_tiles_equalized[:, 0 :: gh - 1, 0 :: gw - 1, 0]
  244. # border region (h)
  245. t, b, _, _ = preinterp_tiles_equalized[:, 1:-1, 0].unbind(2)
  246. tiles_equalized[:, 1:-1, 0] = torch.addcmul(b, tih.squeeze(1), torch.sub(t, b))
  247. t, b, _, _ = preinterp_tiles_equalized[:, 1:-1, gh - 1].unbind(2)
  248. tiles_equalized[:, 1:-1, gh - 1] = torch.addcmul(b, tih.squeeze(1), torch.sub(t, b))
  249. # border region (w)
  250. left, right, _, _ = preinterp_tiles_equalized[:, 0, 1:-1].unbind(2)
  251. tiles_equalized[:, 0, 1:-1] = torch.addcmul(right, tiw, torch.sub(left, right))
  252. left, right, _, _ = preinterp_tiles_equalized[:, gw - 1, 1:-1].unbind(2)
  253. tiles_equalized[:, gw - 1, 1:-1] = torch.addcmul(right, tiw, torch.sub(left, right))
  254. # same type as the input
  255. return tiles_equalized.div(255.0)
  256. @perform_keep_shape_image
  257. def equalize_clahe(
  258. input: torch.Tensor,
  259. clip_limit: float = 40.0,
  260. grid_size: Tuple[int, int] = (8, 8),
  261. slow_and_differentiable: bool = False,
  262. ) -> torch.Tensor:
  263. r"""Apply clahe equalization on the input tensor.
  264. .. image:: _static/img/equalize_clahe.png
  265. NOTE: Lut computation uses the same approach as in OpenCV, in next versions this can change.
  266. Args:
  267. input: images tensor to equalize with values in the range [0, 1] and shape :math:`(*, C, H, W)`.
  268. clip_limit: threshold value for contrast limiting. If 0 clipping is disabled.
  269. grid_size: number of tiles to be cropped in each direction (GH, GW).
  270. slow_and_differentiable: flag to select implementation
  271. Returns:
  272. Equalized image or images with shape as the input.
  273. Examples:
  274. >>> img = torch.rand(1, 10, 20)
  275. >>> res = equalize_clahe(img)
  276. >>> res.shape
  277. torch.Size([1, 10, 20])
  278. >>> img = torch.rand(2, 3, 10, 20)
  279. >>> res = equalize_clahe(img)
  280. >>> res.shape
  281. torch.Size([2, 3, 10, 20])
  282. """
  283. if not isinstance(clip_limit, float):
  284. raise TypeError(f"Input clip_limit type is not float. Got {type(clip_limit)}")
  285. if not isinstance(grid_size, tuple):
  286. raise TypeError(f"Input grid_size type is not Tuple. Got {type(grid_size)}")
  287. if len(grid_size) != 2:
  288. raise TypeError(f"Input grid_size is not a Tuple with 2 elements. Got {len(grid_size)}")
  289. if isinstance(grid_size[0], float) or isinstance(grid_size[1], float):
  290. raise TypeError("Input grid_size type is not valid, must be a Tuple[int, int].")
  291. if grid_size[0] <= 0 or grid_size[1] <= 0:
  292. raise ValueError(f"Input grid_size elements must be positive. Got {grid_size}")
  293. imgs: torch.Tensor = input # B x C x H x W
  294. # hist_tiles: torch.Tensor # B x GH x GW x C x TH x TW # not supported by JIT
  295. # img_padded: torch.Tensor # B x C x H' x W' # not supported by JIT
  296. # the size of the tiles must be even in order to divide them into 4 tiles for the interpolation
  297. hist_tiles, img_padded = _compute_tiles(imgs, grid_size, True)
  298. tile_size: Tuple[int, int] = (hist_tiles.shape[-2], hist_tiles.shape[-1])
  299. interp_tiles: torch.Tensor = _compute_interpolation_tiles(img_padded, tile_size) # B x 2GH x 2GW x C x TH/2 x TW/2
  300. luts: torch.Tensor = _compute_luts(
  301. hist_tiles, clip=clip_limit, diff=slow_and_differentiable
  302. ) # B x GH x GW x C x 256
  303. equalized_tiles: torch.Tensor = _compute_equalized_tiles(interp_tiles, luts) # B x 2GH x 2GW x C x TH/2 x TW/2
  304. # reconstruct the images form the tiles
  305. # try permute + contiguous + view
  306. eq_imgs: torch.Tensor = equalized_tiles.permute(0, 3, 1, 4, 2, 5).reshape_as(img_padded)
  307. h, w = imgs.shape[-2:]
  308. eq_imgs = eq_imgs[..., :h, :w] # crop imgs if they were padded
  309. # remove batch if the input was not in batch form
  310. if input.dim() != eq_imgs.dim():
  311. eq_imgs = eq_imgs.squeeze(0)
  312. return eq_imgs