# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch


def histogram_matching(source: torch.Tensor, template: torch.Tensor) -> torch.Tensor:
    """Adjust the pixel values of an image to match its histogram towards a target image.

    Histogram matching is the transformation of an image so that its histogram
    matches a specified histogram. In this implementation, the histogram is
    computed over the flattened image array. The implementation is adapted from
    an external reference implementation.

    Args:
        source: Image to transform.
        template: Template image. It can have different dimensions to source.

    Returns:
        The transformed output image with the same shape as the source image.

    Note:
        This function does not match histograms element-wise over a batch: a
        batched tensor is flattened and treated as a single image.

    """
    oldshape = source.shape
    source = source.ravel()
    template = template.ravel()

    # The pixel counts are accumulated into quantiles, so they need a floating
    # dtype. Re-using an integer image dtype (e.g. uint8) would cast the counts
    # to that dtype first and silently wrap them around.
    quantile_dtype = source.dtype if source.is_floating_point() else torch.get_default_dtype()

    # get the set of unique pixel values and their corresponding indices and counts.
    _, bin_idx, s_counts = torch.unique(source, return_inverse=True, return_counts=True)
    t_values, t_counts = torch.unique(template, return_counts=True)

    # take the cumsum of the counts and normalize by the number of pixels to
    # get the empirical cumulative distribution functions for the source and
    # template images (maps pixel value --> quantile)
    s_quantiles = torch.cumsum(s_counts, dim=0, dtype=quantile_dtype)
    s_quantiles = s_quantiles / s_quantiles[-1]
    t_quantiles = torch.cumsum(t_counts, dim=0, dtype=quantile_dtype)
    t_quantiles = t_quantiles / t_quantiles[-1]

    # interpolate linearly to find the pixel values in the template image
    # that correspond most closely to the quantiles in the source image
    interp_t_values = interp(s_quantiles, t_quantiles, t_values)

    # bin_idx maps every source pixel to its unique-value slot, which restores
    # the per-pixel layout before reshaping back to the input shape.
    return interp_t_values[bin_idx].reshape(oldshape)


def interp(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor:
    """One-dimensional linear interpolation for monotonically increasing sample points.

    Returns the one-dimensional piecewise linear interpolant to a function
    with given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.

    This is confirmed to be a correct implementation.
    See https://github.com/pytorch/pytorch/issues/1552#issuecomment-979998307

    Args:
        x: the :math:`x`-coordinates at which to evaluate the interpolated values.
        xp: the :math:`x`-coordinates of the data points, must be increasing.
        fp: the :math:`y`-coordinates of the data points, same length as `xp`.

    Returns:
        the interpolated values, same size as `x`.

    Note:
        Points of `x` outside ``[xp[0], xp[-1]]`` are linearly extrapolated from
        the first / last segment. When `xp` holds a single sample point there is
        no segment to interpolate on, and ``fp[0]`` is returned for every `x`.

    """
    if xp.shape[0] == 1:
        # A single sample point defines a constant function. The general
        # formula below would evaluate 0 / 0 and yield NaN here.
        out_dtype = torch.promote_types(torch.promote_types(x.dtype, xp.dtype), fp.dtype)
        if not (out_dtype.is_floating_point or out_dtype.is_complex):
            # the general path uses true division, which yields a float result
            out_dtype = torch.get_default_dtype()
        return fp[0].to(out_dtype).expand(x.shape).clone()

    # Index of the right end of the segment containing each x, clamped so that
    # both `i - 1` and `i` are always valid positions in xp / fp.
    i = torch.clip(torch.searchsorted(xp, x, right=True), 1, len(xp) - 1)

    return (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])