| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import torch
- def histogram_matching(source: torch.Tensor, template: torch.Tensor) -> torch.Tensor:
- """Adjust the pixel values of an image to match its histogram towards a target image.
- `Histogram matching <https://en.wikipedia.org/wiki/Histogram_matching>`_ is the transformation
- of an image so that its histogram matches a specified histogram. In this implementation, the
- histogram is computed over the flattened image array. Code referred to
- `here <https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x>`_.
- Args:
- source: Image to transform.
- template: Template image. It can have different dimensions to source.
- Returns:
- The transformed output image as the same shape as the source image.
- Note:
- This function does not matches histograms element-wisely if input a batched tensor.
- """
- oldshape = source.shape
- source = source.ravel()
- template = template.ravel()
- # get the set of unique pixel values and their corresponding indices and counts.
- _, bin_idx, s_counts = torch.unique(source, return_inverse=True, return_counts=True)
- t_values, t_counts = torch.unique(template, return_counts=True)
- # take the cumsum of the counts and normalize by the number of pixels to
- # get the empirical cumulative distribution functions for the source and
- # template images (maps pixel value --> quantile)
- s_quantiles = torch.cumsum(s_counts, dim=0, dtype=source.dtype)
- s_quantiles = s_quantiles / s_quantiles[-1]
- t_quantiles = torch.cumsum(t_counts, dim=0, dtype=source.dtype)
- t_quantiles = t_quantiles / t_quantiles[-1]
- # interpolate linearly to find the pixel values in the template image
- # that correspond most closely to the quantiles in the source image
- interp_t_values = interp(s_quantiles, t_quantiles, t_values)
- return interp_t_values[bin_idx].reshape(oldshape)
- def interp(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor:
- """One-dimensional linear interpolation for monotonically increasing sample points.
- Returns the one-dimensional piecewise linear interpolant to a function with
- given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.
- This is confirmed to be a correct implementation.
- See https://github.com/pytorch/pytorch/issues/1552#issuecomment-979998307
- Args:
- x: the :math:`x`-coordinates at which to evaluate the interpolated
- values.
- xp: the :math:`x`-coordinates of the data points, must be increasing.
- fp: the :math:`y`-coordinates of the data points, same length as `xp`.
- Returns:
- the interpolated values, same size as `x`.
- """
- i = torch.clip(torch.searchsorted(xp, x, right=True), 1, len(xp) - 1)
- return (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])
|