base.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import datetime
  18. import logging
  19. import os
  20. from typing import List, Optional, Union
  21. from kornia.core import Module, Tensor, stack
  22. from kornia.core.external import PILImage as Image
  23. from kornia.core.external import numpy as np
  24. from kornia.core.mixin.onnx import ONNXExportMixin
  25. from kornia.io import write_image
  26. from kornia.utils.image import tensor_to_image
# Module-level logger named after this module; used to report where model outputs are written.
logger: logging.Logger = logging.getLogger(__name__)
  28. class ModelBaseMixin:
  29. name: str = "model"
  30. def _tensor_to_type(
  31. self, output: Union[Tensor, List[Tensor]], output_type: str, is_batch: bool = False
  32. ) -> Union[Tensor, List[Tensor], List["Image.Image"]]: # type: ignore
  33. """Convert the output tensor to the desired type.
  34. Args:
  35. output: The output tensor or list of tensors.
  36. output_type: The desired output type. Accepted values are "torch" and "pil".
  37. is_batch: If True, the output is expected to be a batch of tensors.
  38. Returns:
  39. The converted output tensor or list of tensors.
  40. Raises:
  41. RuntimeError: If the output type is not supported.
  42. """
  43. if output_type == "torch":
  44. if is_batch and not isinstance(output, Tensor):
  45. return stack(output)
  46. elif is_batch and isinstance(output, Tensor):
  47. return output
  48. elif not is_batch and isinstance(output, Tensor):
  49. return list(output)
  50. elif not is_batch and not isinstance(output, Tensor):
  51. return output
  52. return output
  53. elif output_type == "pil":
  54. out = [Image.fromarray((tensor_to_image(out_img) * 255).astype(np.uint8)) for out_img in output] # type: ignore
  55. return list(out)
  56. raise RuntimeError(f"Unsupported output type `{output_type}`.")
  57. def _save_outputs(
  58. self, outputs: Union[Tensor, List[Tensor]], directory: Optional[str] = None, suffix: str = ""
  59. ) -> None:
  60. """Save the output image(s) to a directory.
  61. Args:
  62. outputs: output tensor.
  63. directory: directory to save the images.
  64. suffix: filename suffix.
  65. """
  66. if directory is None:
  67. name = f"{self.name}_{datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y%m%d%H%M%S')!s}"
  68. directory = os.path.join("kornia_outputs", name)
  69. os.makedirs(directory, exist_ok=True)
  70. for i, out_image in enumerate(outputs):
  71. write_image(
  72. os.path.join(directory, f"{str(i).zfill(6)}{suffix}.jpg"),
  73. out_image.mul(255.0).byte(),
  74. )
  75. logger.info(f"Outputs are saved in {directory}")
  76. class ModelBase(Module, ONNXExportMixin, ModelBaseMixin):
  77. """Wrap a model and perform pre-processing and post-processing."""
  78. def __init__(
  79. self, model: Module, pre_processor: Module, post_processor: Module, name: Optional[str] = None
  80. ) -> None:
  81. """Construct an Object Detector object.
  82. Args:
  83. model: an object detection model.
  84. pre_processor: a pre-processing module
  85. post_processor: a post-processing module.
  86. name: name of a model.
  87. """
  88. super().__init__()
  89. self.model = model.eval()
  90. self.pre_processor = pre_processor.eval()
  91. self.post_processor = post_processor.eval()
  92. if name is not None:
  93. self.name = name