| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import datetime
- import logging
- import os
- from typing import List, Optional, Union
- from kornia.core import Module, Tensor, stack
- from kornia.core.external import PILImage as Image
- from kornia.core.external import numpy as np
- from kornia.core.mixin.onnx import ONNXExportMixin
- from kornia.io import write_image
- from kornia.utils.image import tensor_to_image
- logger = logging.getLogger(__name__)
class ModelBaseMixin:
    """Helpers for converting model outputs to other types and saving them as images."""

    # Used as the prefix of the default output directory name in `_save_outputs`.
    name: str = "model"

    def _tensor_to_type(
        self, output: Union[Tensor, List[Tensor]], output_type: str, is_batch: bool = False
    ) -> Union[Tensor, List[Tensor], List["Image.Image"]]:  # type: ignore
        """Convert the output tensor to the desired type.

        Args:
            output: The output tensor or list of tensors.
            output_type: The desired output type. Accepted values are "torch" and "pil".
            is_batch: Only used for "torch". If True, the result is a single tensor
                (a list input is stacked, a tensor input is returned as is). If False,
                the result is a list (a tensor input is split along its first
                dimension, a list input is returned as is).

        Returns:
            For "torch", a tensor or a list of tensors depending on ``is_batch``.
            For "pil", a list holding one PIL image per element of ``output``.

        Raises:
            RuntimeError: If the output type is not supported.
        """
        if output_type == "torch":
            if isinstance(output, Tensor):
                # Iterating a tensor yields its slices along the first dimension.
                return output if is_batch else list(output)
            return stack(output) if is_batch else output

        if output_type == "pil":
            # NOTE(review): the * 255 scaling suggests values are expected in [0, 1];
            # nothing here clamps them before the uint8 cast - confirm with callers.
            return [
                Image.fromarray((tensor_to_image(out_img) * 255).astype(np.uint8))  # type: ignore
                for out_img in output
            ]

        raise RuntimeError(f"Unsupported output type `{output_type}`.")

    def _save_outputs(
        self, outputs: Union[Tensor, List[Tensor]], directory: Optional[str] = None, suffix: str = ""
    ) -> None:
        """Save the output image(s) to a directory.

        One file is written per element of ``outputs``, named with the element index
        zero-padded to six digits, followed by ``suffix`` and the ``.jpg`` extension.

        Args:
            outputs: output tensor, or list of tensors.
            directory: directory to save the images. If None, a directory named
                ``<self.name>_<UTC timestamp>`` under ``kornia_outputs`` is used.
            suffix: filename suffix.
        """
        if directory is None:
            timestamp = datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y%m%d%H%M%S")
            directory = os.path.join("kornia_outputs", f"{self.name}_{timestamp}")

        # Create the target in both cases so a caller-supplied path need not exist yet.
        os.makedirs(directory, exist_ok=True)
        for i, out_image in enumerate(outputs):
            write_image(
                os.path.join(directory, f"{i:06d}{suffix}.jpg"),
                # NOTE(review): no clamp before the byte cast; presumably inputs are in [0, 1].
                out_image.mul(255.0).byte(),
            )
        # Lazy %-style arguments: the message is only formatted if the record is emitted.
        logger.info("Outputs are saved in %s", directory)
class ModelBase(Module, ONNXExportMixin, ModelBaseMixin):
    """Bundle a model together with its pre-processing and post-processing modules."""

    def __init__(
        self, model: Module, pre_processor: Module, post_processor: Module, name: Optional[str] = None
    ) -> None:
        """Store the three modules after calling ``eval()`` on each of them.

        Args:
            model: the wrapped model.
            pre_processor: a pre-processing module.
            post_processor: a post-processing module.
            name: name of the model. If None, no instance-level ``name`` is assigned.
        """
        super().__init__()

        stages = (
            ("model", model),
            ("pre_processor", pre_processor),
            ("post_processor", post_processor),
        )
        for attr_name, module in stages:
            # Store whatever ``eval()`` returns, in the same order as the arguments.
            setattr(self, attr_name, module.eval())

        if name is None:
            return
        self.name = name
|