base.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import Any, Optional, Union
  19. import torch
  20. from kornia.color.gray import grayscale_to_rgb
  21. from kornia.core import Tensor
  22. from kornia.core.external import PILImage as Image
  23. from kornia.core.external import onnx
  24. from kornia.models.base import ModelBase
  25. __all__ = ["EdgeDetector"]
  26. class EdgeDetector(ModelBase):
  27. """EdgeDetector is a module that wraps an edge detection model."""
  28. name: str = "edge_detection"
  29. @torch.inference_mode()
  30. def forward(self, images: Union[Tensor, list[Tensor]]) -> Union[Tensor, list[Tensor]]:
  31. """Forward pass of the edge detection model.
  32. Args:
  33. images: If list of RGB images. Each image is a Tensor with shape :math:`(3, H, W)`.
  34. If Tensor, a Tensor with shape :math:`(B, 3, H, W)`.
  35. Returns:
  36. output tensor.
  37. """
  38. images, image_sizes = self.pre_processor(images)
  39. out_images = self.model(images)
  40. return self.post_processor(out_images, image_sizes)
  41. def visualize(
  42. self,
  43. images: Union[Tensor, list[Tensor]],
  44. edge_maps: Optional[Union[Tensor, list[Tensor]]] = None,
  45. output_type: str = "torch",
  46. ) -> Union[Tensor, list[Tensor], list[Image.Image]]: # type: ignore
  47. """Draw the edge detection results.
  48. Args:
  49. images: input tensor.
  50. edge_maps: detected edges.
  51. output_type: type of the output.
  52. Returns:
  53. output tensor.
  54. """
  55. if edge_maps is None:
  56. edge_maps = self.forward(images)
  57. output = []
  58. for edge_map in edge_maps:
  59. output.append(grayscale_to_rgb(edge_map)[0])
  60. return self._tensor_to_type(output, output_type, is_batch=isinstance(images, Tensor))
  61. def save(
  62. self,
  63. images: Union[Tensor, list[Tensor]],
  64. edge_maps: Optional[Union[Tensor, list[Tensor]]] = None,
  65. directory: Optional[str] = None,
  66. output_type: str = "torch",
  67. ) -> None:
  68. """Save the edge detection results.
  69. Args:
  70. images: input tensor.
  71. edge_maps: detected edges.
  72. output_type: type of the output.
  73. directory: where to save outputs.
  74. Returns:
  75. output tensor.
  76. """
  77. outputs = self.visualize(images, edge_maps, output_type)
  78. self._save_outputs(images, directory, suffix="_src")
  79. self._save_outputs(outputs, directory, suffix="_edge")
  80. def to_onnx( # type: ignore[override]
  81. self,
  82. onnx_name: Optional[str] = None,
  83. image_size: Optional[int] = 352,
  84. include_pre_and_post_processor: bool = True,
  85. save: bool = True,
  86. additional_metadata: Optional[list[tuple[str, str]]] = None,
  87. **kwargs: Any,
  88. ) -> onnx.ModelProto: # type: ignore
  89. """Export the current edge detection model to an ONNX model file.
  90. Args:
  91. onnx_name:
  92. The name of the output ONNX file. If not provided, a default name in the
  93. format "Kornia-<ClassName>.onnx" will be used.
  94. image_size:
  95. The size to which input images will be resized during preprocessing.
  96. If None, image_size will be dynamic. For DexiNed, recommended scale is 352.
  97. include_pre_and_post_processor:
  98. Whether to include the pre-processor and post-processor in the exported model.
  99. save:
  100. If to save the model or load it.
  101. additional_metadata:
  102. Additional metadata to add to the ONNX model.
  103. kwargs: Additional arguments to convert to onnx.
  104. """
  105. if onnx_name is None:
  106. onnx_name = f"kornia_{self.name}_{image_size}.onnx"
  107. return super().to_onnx(
  108. onnx_name,
  109. input_shape=[-1, 3, image_size or -1, image_size or -1],
  110. output_shape=[-1, 1, image_size or -1, image_size or -1],
  111. pseudo_shape=[1, 3, image_size or 352, image_size or 352],
  112. model=self if include_pre_and_post_processor else self.model,
  113. save=save,
  114. additional_metadata=additional_metadata,
  115. **kwargs,
  116. )