| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import warnings
- from typing import List, Tuple, Union
- import torch
- from torch import Tensor
- from kornia.core import Module, concatenate
- from kornia.geometry.transform import resize
# Public API of this module (names exported by ``from ... import *``).
__all__ = ["OutputRangePostProcessor", "ResizePostProcessor", "ResizePreProcessor"]
class ResizePreProcessor(Module):
    """Resize a list of image tensors to the given size.

    Additionally, also returns the original image sizes for further post-processing.
    """

    def __init__(self, height: int, width: int, interpolation_mode: str = "bilinear") -> None:
        """Construct ResizePreProcessor module.

        Args:
            height: height of the resized image.
            width: width of the resized image.
            interpolation_mode: interpolation mode for image resizing. Supported values: ``nearest``, ``bilinear``,
                ``bicubic``, ``area``, and ``nearest-exact``.

        """
        super().__init__()
        self.size = (height, width)
        self.interpolation_mode = interpolation_mode

    def forward(self, imgs: Union[Tensor, List[Tensor]]) -> Tuple[Tensor, Tensor]:
        """Run forward.

        Args:
            imgs: either a list of images, or a tensor whose first dimension is iterated over as the images.
                The height and width of each image are read from its last two dimensions.

        Returns:
            resized_imgs: resized images concatenated into a single batch.
            original_sizes: tensor of shape ``(N, 2)`` with the original (height, width) of each image.

        Raises:
            ValueError: if ``imgs`` contains no images.

        """
        # TODO: support other input formats e.g. file path, numpy
        num_imgs = len(imgs) if isinstance(imgs, list) else imgs.shape[0]
        if num_imgs == 0:
            # Fail with a clear message instead of an IndexError from ``imgs[0]`` below.
            raise ValueError("ResizePreProcessor expects at least one image, got an empty input.")

        # NOTE: ``new_zeros`` gives ``original_sizes`` the dtype/device of the first image. Low-precision dtypes
        # (e.g. uint8, float16, bfloat16) cannot represent every image size exactly.
        original_sizes = imgs[0].new_zeros((num_imgs, 2))
        resized_imgs: List[Tensor] = []
        for i in range(num_imgs):
            img = imgs[i]
            original_sizes[i, 0] = img.shape[-2]  # Height
            original_sizes[i, 1] = img.shape[-1]  # Width
            # ``img[None]`` adds a leading batch dimension so the results can be concatenated.
            resized_imgs.append(resize(img[None], size=self.size, interpolation=self.interpolation_mode))
        return concatenate(resized_imgs), original_sizes
class ResizePostProcessor(Module):
    """Resize each image to its own target size given in ``original_sizes``.

    The ``original_sizes`` returned by ``ResizePreProcessor`` have the per-image (height, width) layout used here.
    """

    def __init__(self, interpolation_mode: str = "bilinear") -> None:
        """Construct ResizePostProcessor module.

        Args:
            interpolation_mode: interpolation mode passed to ``resize`` for every image.

        """
        super().__init__()
        self.interpolation_mode = interpolation_mode

    def forward(self, imgs: Union[Tensor, List[Tensor]], original_sizes: Tensor) -> Union[Tensor, List[Tensor]]:
        """Run forward.

        Args:
            imgs: either a list of images, or a tensor whose first dimension is iterated over as the images.
            original_sizes: tensor with one ``(height, width)`` row per image.

        Returns:
            A list with one resized image per input, each carrying a leading batch dimension of size 1.
            During ONNX export the resize is skipped: a warning is emitted and ``imgs`` is returned unchanged.

        """
        # TODO: support other input formats e.g. file path, numpy
        if torch.onnx.is_in_onnx_export():
            warnings.warn(
                "ResizePostProcessor is not supported in ONNX export. "
                "The output will not be resized back to the original size.",
                stacklevel=1,
            )
            return imgs

        num_imgs = len(imgs) if isinstance(imgs, list) else imgs.shape[0]
        resized_imgs: List[Tensor] = []
        for i in range(num_imgs):
            # ``Tensor.tolist`` yields plain Python ints directly (from any device); a NumPy round-trip is
            # unnecessary and raises on PyTorch builds without NumPy support.
            target_size = original_sizes[i].long().tolist()
            resized_imgs.append(resize(imgs[i][None], size=target_size, interpolation=self.interpolation_mode))
        return resized_imgs
class OutputRangePostProcessor(Module):
    """Clamp output images to the value range ``[min_val, max_val]``."""

    def __init__(self, min_val: float = 0.0, max_val: float = 1.0) -> None:
        """Construct OutputRangePostProcessor module.

        Args:
            min_val: lower bound of the output range.
            max_val: upper bound of the output range.

        """
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def forward(self, imgs: Union[Tensor, List[Tensor]]) -> Union[Tensor, List[Tensor]]:
        """Run forward.

        Args:
            imgs: a single tensor or a list of tensors.

        Returns:
            The clamped tensor, or a new list of clamped tensors when a list is given. Inputs are not modified.

        """
        if isinstance(imgs, Tensor):
            return torch.clamp(imgs, self.min_val, self.max_val)
        # Out-of-place, like the tensor branch, so the caller's tensors are left untouched. An in-place
        # ``clamp_`` would also fail on leaf tensors that require grad.
        return [img.clamp(self.min_val, self.max_val) for img in imgs]
|