utils.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import warnings
  18. from typing import List, Tuple, Union
  19. import torch
  20. from torch import Tensor
  21. from kornia.core import Module, concatenate
  22. from kornia.geometry.transform import resize
  23. __all__ = ["OutputRangePostProcessor", "ResizePostProcessor", "ResizePreProcessor"]
  24. class ResizePreProcessor(Module):
  25. """Resize a list of image tensors to the given size.
  26. Additionally, also returns the original image sizes for further post-processing.
  27. """
  28. def __init__(self, height: int, width: int, interpolation_mode: str = "bilinear") -> None:
  29. """Construct ResizePreprocessor module.
  30. Args:
  31. height: height of the resized image.
  32. width: width of the resized image.
  33. interpolation_mode: interpolation mode for image resizing. Supported values: ``nearest``, ``bilinear``,
  34. ``bicubic``, ``area``, and ``nearest-exact``.
  35. """
  36. super().__init__()
  37. self.size = (height, width)
  38. self.interpolation_mode = interpolation_mode
  39. def forward(self, imgs: Union[Tensor, List[Tensor]]) -> Tuple[Tensor, Tensor]:
  40. """Run forward.
  41. Returns:
  42. resized_imgs: resized images in a batch.
  43. original_sizes: the original image sizes of (height, width).
  44. """
  45. # TODO: support other input formats e.g. file path, numpy
  46. resized_imgs: list[Tensor] = []
  47. iters = len(imgs) if isinstance(imgs, list) else imgs.shape[0]
  48. original_sizes = imgs[0].new_zeros((iters, 2))
  49. for i in range(iters):
  50. img = imgs[i]
  51. original_sizes[i, 0] = img.shape[-2] # Height
  52. original_sizes[i, 1] = img.shape[-1] # Width
  53. resized_imgs.append(resize(img[None], size=self.size, interpolation=self.interpolation_mode))
  54. return concatenate(resized_imgs), original_sizes
  55. class ResizePostProcessor(Module):
  56. def __init__(self, interpolation_mode: str = "bilinear") -> None:
  57. super().__init__()
  58. self.interpolation_mode = interpolation_mode
  59. def forward(self, imgs: Union[Tensor, List[Tensor]], original_sizes: Tensor) -> Union[Tensor, List[Tensor]]:
  60. """Run forward.
  61. Returns:
  62. resized_imgs: resized images in a batch.
  63. original_sizes: the original image sizes of (height, width).
  64. """
  65. # TODO: support other input formats e.g. file path, numpy
  66. resized_imgs: list[Tensor] = []
  67. if torch.onnx.is_in_onnx_export():
  68. warnings.warn(
  69. "ResizePostProcessor is not supported in ONNX export. "
  70. "The output will not be resized back to the original size.",
  71. stacklevel=1,
  72. )
  73. return imgs
  74. iters = len(imgs) if isinstance(imgs, list) else imgs.shape[0]
  75. for i in range(iters):
  76. img = imgs[i]
  77. size = original_sizes[i]
  78. resized_imgs.append(
  79. resize(img[None], size=size.cpu().long().numpy().tolist(), interpolation=self.interpolation_mode)
  80. )
  81. return resized_imgs
  82. class OutputRangePostProcessor(Module):
  83. def __init__(self, min_val: float = 0.0, max_val: float = 1.0) -> None:
  84. super().__init__()
  85. self.min_val = min_val
  86. self.max_val = max_val
  87. def forward(self, imgs: Union[Tensor, List[Tensor]]) -> Union[Tensor, List[Tensor]]:
  88. if isinstance(imgs, Tensor):
  89. return torch.clamp(imgs, self.min_val, self.max_val)
  90. return [img.clamp_(self.min_val, self.max_val) for img in imgs]