base.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, List, Optional, Tuple, Union
  18. import torch
  19. from kornia.core import Tensor
  20. from kornia.core.external import PILImage as Image
  21. from kornia.core.external import onnx
  22. from kornia.models.base import ModelBase
  23. __all__ = ["SuperResolution"]
  24. # TODO: support patching -> SR -> unpatching pipeline
  25. class SuperResolution(ModelBase):
  26. """SuperResolution is a module that wraps an super resolution model."""
  27. name: str = "super_resolution"
  28. input_image_size: Optional[int]
  29. output_image_size: Optional[int]
  30. pseudo_image_size: Optional[int]
  31. @torch.inference_mode()
  32. def forward(self, images: Union[Tensor, List[Tensor]]) -> Union[Tensor, List[Tensor]]:
  33. """Forward pass of the super resolution model.
  34. Args:
  35. images: If list of RGB images. Each image is a Tensor with shape :math:`(3, H, W)`.
  36. If Tensor, a Tensor with shape :math:`(B, 3, H, W)`.
  37. Returns:
  38. output tensor.
  39. """
  40. output = self.pre_processor(images)
  41. if isinstance(
  42. output,
  43. (
  44. list,
  45. tuple,
  46. ),
  47. ):
  48. images = output[0]
  49. else:
  50. images = output
  51. if isinstance(images, list):
  52. out_images = [self.model(image[None])[0] for image in images]
  53. else:
  54. out_images = self.model(images)
  55. return self.post_processor(out_images)
  56. def visualize(
  57. self,
  58. images: Union[Tensor, List[Tensor]],
  59. edge_maps: Optional[Union[Tensor, List[Tensor]]] = None,
  60. output_type: str = "torch",
  61. ) -> Union[Tensor, List[Tensor], List["Image.Image"]]: # type: ignore
  62. """Draw the super resolution results.
  63. Args:
  64. images: input tensor.
  65. edge_maps: detected edges.
  66. output_type: type of the output.
  67. Returns:
  68. output tensor.
  69. """
  70. if edge_maps is None:
  71. edge_maps = self.forward(images)
  72. output = []
  73. for edge_map in edge_maps:
  74. output.append(edge_map)
  75. return self._tensor_to_type(output, output_type, is_batch=isinstance(images, Tensor))
  76. def save(
  77. self,
  78. images: Union[Tensor, List[Tensor]],
  79. edge_maps: Optional[Union[Tensor, List[Tensor]]] = None,
  80. directory: Optional[str] = None,
  81. output_type: str = "torch",
  82. ) -> None:
  83. """Save the super resolution results.
  84. Args:
  85. images: input tensor.
  86. edge_maps: detected edges.
  87. output_type: type of the output.
  88. directory: where to save outputs.
  89. output_type: backend used to generate outputs.
  90. Returns:
  91. output tensor.
  92. """
  93. outputs = self.visualize(images, edge_maps, output_type)
  94. self._save_outputs(images, directory, suffix="_src")
  95. self._save_outputs(outputs, directory, suffix="_sr")
  96. def to_onnx( # type: ignore[override]
  97. self,
  98. onnx_name: Optional[str] = None,
  99. include_pre_and_post_processor: bool = True,
  100. save: bool = True,
  101. additional_metadata: Optional[List[Tuple[str, str]]] = None,
  102. **kwargs: Any,
  103. ) -> "onnx.ModelProto": # type: ignore
  104. """Export the current super resolution model to an ONNX model file.
  105. Args:
  106. onnx_name:
  107. The name of the output ONNX file. If not provided, a default name in the
  108. format "Kornia-<ClassName>.onnx" will be used.
  109. image_size:
  110. The size to which input images will be resized during preprocessing.
  111. If None, image_size will be dynamic. For DexiNed, recommended scale is 352.
  112. include_pre_and_post_processor:
  113. Whether to include the pre-processor and post-processor in the exported model.
  114. save:
  115. If to save the model or load it.
  116. additional_metadata:
  117. Additional metadata to add to the ONNX model.
  118. kwargs: Additional arguments for converting to onnx.
  119. """
  120. if onnx_name is None:
  121. onnx_name = f"kornia_{self.name}.onnx"
  122. return super().to_onnx(
  123. onnx_name,
  124. input_shape=[-1, 3, self.input_image_size or -1, self.input_image_size or -1],
  125. output_shape=[-1, 3, self.output_image_size or -1, self.output_image_size or -1],
  126. pseudo_shape=[1, 3, self.pseudo_image_size or 352, self.pseudo_image_size or 352],
  127. model=self if include_pre_and_post_processor else self.model,
  128. save=save,
  129. additional_metadata=additional_metadata,
  130. **kwargs,
  131. )