histogram_matching.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  17. import torch
  18. def histogram_matching(source: torch.Tensor, template: torch.Tensor) -> torch.Tensor:
  19. """Adjust the pixel values of an image to match its histogram towards a target image.
  20. `Histogram matching <https://en.wikipedia.org/wiki/Histogram_matching>`_ is the transformation
  21. of an image so that its histogram matches a specified histogram. In this implementation, the
  22. histogram is computed over the flattened image array. Code referred to
  23. `here <https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x>`_.
  24. Args:
  25. source: Image to transform.
  26. template: Template image. It can have different dimensions to source.
  27. Returns:
  28. The transformed output image as the same shape as the source image.
  29. Note:
  30. This function does not matches histograms element-wisely if input a batched tensor.
  31. """
  32. oldshape = source.shape
  33. source = source.ravel()
  34. template = template.ravel()
  35. # get the set of unique pixel values and their corresponding indices and counts.
  36. _, bin_idx, s_counts = torch.unique(source, return_inverse=True, return_counts=True)
  37. t_values, t_counts = torch.unique(template, return_counts=True)
  38. # take the cumsum of the counts and normalize by the number of pixels to
  39. # get the empirical cumulative distribution functions for the source and
  40. # template images (maps pixel value --> quantile)
  41. s_quantiles = torch.cumsum(s_counts, dim=0, dtype=source.dtype)
  42. s_quantiles = s_quantiles / s_quantiles[-1]
  43. t_quantiles = torch.cumsum(t_counts, dim=0, dtype=source.dtype)
  44. t_quantiles = t_quantiles / t_quantiles[-1]
  45. # interpolate linearly to find the pixel values in the template image
  46. # that correspond most closely to the quantiles in the source image
  47. interp_t_values = interp(s_quantiles, t_quantiles, t_values)
  48. return interp_t_values[bin_idx].reshape(oldshape)
  49. def interp(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor:
  50. """One-dimensional linear interpolation for monotonically increasing sample points.
  51. Returns the one-dimensional piecewise linear interpolant to a function with
  52. given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.
  53. This is confirmed to be a correct implementation.
  54. See https://github.com/pytorch/pytorch/issues/1552#issuecomment-979998307
  55. Args:
  56. x: the :math:`x`-coordinates at which to evaluate the interpolated
  57. values.
  58. xp: the :math:`x`-coordinates of the data points, must be increasing.
  59. fp: the :math:`y`-coordinates of the data points, same length as `xp`.
  60. Returns:
  61. the interpolated values, same size as `x`.
  62. """
  63. i = torch.clip(torch.searchsorted(xp, x, right=True), 1, len(xp) - 1)
  64. return (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])