# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional, Union

import torch

from kornia.core import Tensor
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.core.external import PILImage as Image
from kornia.core.external import onnx
from kornia.models.base import ModelBase
from kornia.utils.draw import draw_rectangle

__all__ = [
    "BoundingBox",
    "BoundingBoxDataFormat",
    "ObjectDetector",
    "ObjectDetectorResult",
    "results_from_detections",
]


class BoundingBoxDataFormat(Enum):
    """Enum class that maps bounding box data format."""

    XYWH = 0
    XYXY = 1
    CXCYWH = 2
    # Same value as CXCYWH, so Enum treats this name as an alias of that member.
    CENTER_XYWH = 2


# NOTE: we should probably use a more generic name such as BoundingBox2D and add
# a BoundingBox3D class for 3D bounding boxes. For serialization we should also
# have an explicit class per format to make it more production ready, especially
# when serializing to protobuf, so that it does not saturate at high rates.
@dataclass(frozen=True)
class BoundingBox:
    """Bounding box data class.

    Useful for representing bounding boxes in different formats for object detection.

    Args:
        data: tuple with the four values of the bounding box. How the values are
            to be interpreted is given by ``data_format``.
        data_format: bounding box data format.

    """

    data: tuple[float, float, float, float]
    data_format: BoundingBoxDataFormat


@dataclass(frozen=True)
class ObjectDetectorResult:
    """Object detection result.

    Args:
        class_id: class id of the detected object.
        confidence: confidence score of the detected object.
        bbox: bounding box of the detected object. Its layout is described by
            ``bbox.data_format``.

    """

    class_id: int
    confidence: float
    bbox: BoundingBox


def results_from_detections(detections: Tensor, format: str | BoundingBoxDataFormat) -> list[ObjectDetectorResult]:
    """Convert a detection tensor to a list of :py:class:`ObjectDetectorResult`.

    Args:
        detections: tensor with shape :math:`(D, 6)`, where :math:`D` is the number of detections in the given image,
            :math:`6` represents class id, score, and the four bounding box values.
        format: detection format, either a :py:class:`BoundingBoxDataFormat` member or its name as a
            (case-insensitive) string. The four box values are stored as-is and tagged with this format.

    Returns:
        list of :py:class:`ObjectDetectorResult`, one per row of ``detections``.

    Raises:
        KeyError: if ``format`` is a string that does not name a :py:class:`BoundingBoxDataFormat` member.

    """
    KORNIA_CHECK_SHAPE(detections, ["D", "6"])

    if isinstance(format, str):
        format = BoundingBoxDataFormat[format.upper()]

    # The shape check above guarantees six values per row, so the whole tensor can be
    # converted to Python numbers in one go instead of once per detection.
    return [
        ObjectDetectorResult(
            class_id=int(class_id),
            confidence=confidence,
            bbox=BoundingBox(data=(b0, b1, b2, b3), data_format=format),
        )
        for class_id, confidence, b0, b1, b2, b3 in detections.tolist()
    ]


class ObjectDetector(ModelBase):
    """Wrap an object detection model and perform pre-processing and post-processing."""

    name: str = "detection"

    @torch.inference_mode()
    def forward(self, images: Union[Tensor, list[Tensor]]) -> Union[Tensor, list[Tensor]]:
        """Detect objects in a given list of images.

        Args:
            images: If list of RGB images. Each image is a Tensor with shape :math:`(3, H, W)`.
                If Tensor, a Tensor with shape :math:`(B, 3, H, W)`.

        Returns:
            list of detections found in each image. For item in a batch, shape is :math:`(D, 6)`, where :math:`D` is the
            number of detections in the given image, :math:`6` represents class id, score, and `xywh` bounding box.

        """
        images, images_sizes = self.pre_processor(images)
        logits, boxes = self.model(images)
        detections = self.post_processor(logits, boxes, images_sizes)
        return detections

    def visualize(
        self, images: Union[Tensor, list[Tensor]], detections: Optional[Tensor] = None, output_type: str = "torch"
    ) -> Union[Tensor, list[Tensor], list[Image.Image]]:  # type: ignore
        """Draw the detected bounding boxes on top of the images.

        Very simple drawing. Needs to be more fancy later.

        Args:
            images: input image(s), iterated one image at a time together with ``detections``.
            detections: detections for each image. If None, they are computed with :py:meth:`forward`.
                The last four values of every detection are read as ``x, y, w, h``.
            output_type: type of the returned images, forwarded to ``self._tensor_to_type``.

        Returns:
            the images with the rectangles drawn, converted by ``self._tensor_to_type``.

        """
        # Compare against None explicitly: the truth value of a multi-element tensor is
        # ambiguous and raises, and an explicitly passed empty result must not re-run the model.
        dets = detections if detections is not None else self.forward(images)
        output = []
        for image, detection in zip(images, dets):
            # draw_rectangle works on batched input, so add a batch dimension of one.
            out_img = image[None].clone()
            for out in detection:
                # Convert (x, y, w, h) to corner coordinates (x1, y1, x2, y2).
                out_img = draw_rectangle(
                    out_img,
                    torch.Tensor([[[out[-4], out[-3], out[-4] + out[-2], out[-3] + out[-1]]]]),
                )
            output.append(out_img[0])

        return self._tensor_to_type(output, output_type, is_batch=isinstance(images, Tensor))

    def save(
        self, images: Union[Tensor, list[Tensor]], detections: Optional[Tensor] = None, directory: Optional[str] = None
    ) -> None:
        """Save the output image(s) to a directory.

        Args:
            images: input tensor.
            detections: detection tensor. If None, the detections are computed by :py:meth:`visualize`.
            directory: directory to save the images.

        """
        outputs = self.visualize(images, detections)
        self._save_outputs(outputs, directory)

    def to_onnx(  # type: ignore[override]
        self,
        onnx_name: Optional[str] = None,
        image_size: Optional[int] = 640,
        include_pre_and_post_processor: bool = True,
        save: bool = True,
        additional_metadata: Optional[list[tuple[str, str]]] = None,
        **kwargs: Any,
    ) -> onnx.ModelProto:  # type: ignore
        """Export the object detection model to ONNX format.

        Args:
            onnx_name:
                The name of the output ONNX file. If not provided, a default name in the
                format "kornia_<name>_<image_size>.onnx" will be used.
            image_size:
                The size to which input images will be resized during preprocessing.
                If None, image_size will be dynamic.
                For RTDETR, recommended scales include [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800].
            include_pre_and_post_processor:
                Whether to include the pre-processor and post-processor in the exported model.
                If False, only the wrapped ``self.model`` is exported.
            save:
                If to save the model or load it.
            additional_metadata:
                Additional metadata to add to the ONNX model.
            kwargs: Additional arguments to convert to onnx.

        Returns:
            the value returned by the parent class ``to_onnx``.

        """
        if onnx_name is None:
            onnx_name = f"kornia_{self.name}_{image_size}.onnx"

        return super().to_onnx(
            onnx_name,
            # -1 marks a dynamic dimension; the spatial dims are dynamic when image_size is None.
            input_shape=[-1, 3, image_size or -1, image_size or -1],
            output_shape=[-1, -1, 6],
            # A concrete shape is still passed as pseudo_shape; fall back to 352 for dynamic sizes.
            pseudo_shape=[1, 3, image_size or 352, image_size or 352],
            model=self if include_pre_and_post_processor else self.model,
            save=save,
            additional_metadata=additional_metadata,
            **kwargs,
        )

    def compile(
        self,
        *,
        fullgraph: bool = False,
        dynamic: bool = False,
        backend: str = "inductor",
        mode: Optional[str] = None,
        options: Optional[dict[str, str | int | bool]] = None,
        disable: bool = False,
    ) -> None:
        """Compile the internal object detection model with :py:func:`torch.compile()`.

        All keyword arguments are forwarded unchanged to :py:func:`torch.compile()`.
        Only ``self.model`` is replaced by its compiled version; nothing is returned.
        """
        self.model = torch.compile(  # type: ignore
            self.model,
            fullgraph=fullgraph,
            dynamic=dynamic,
            backend=backend,
            mode=mode,
            options=options,
            disable=disable,
        )