base.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from dataclasses import dataclass
  19. from enum import Enum
  20. from typing import Any, Optional, Union
  21. import torch
  22. from kornia.core import Tensor
  23. from kornia.core.check import KORNIA_CHECK_SHAPE
  24. from kornia.core.external import PILImage as Image
  25. from kornia.core.external import onnx
  26. from kornia.models.base import ModelBase
  27. from kornia.utils.draw import draw_rectangle
  28. __all__ = [
  29. "BoundingBox",
  30. "BoundingBoxDataFormat",
  31. "ObjectDetector",
  32. "ObjectDetectorResult",
  33. "results_from_detections",
  34. ]
  35. class BoundingBoxDataFormat(Enum):
  36. """Enum class that maps bounding box data format."""
  37. XYWH = 0
  38. XYXY = 1
  39. CXCYWH = 2
  40. CENTER_XYWH = 2
  41. # NOTE: probably we should use a more generic name like BoundingBox2D
  42. # and add a BoundingBox3D class for 3D bounding boxes. Also for serialization
  43. # we should have an explicit class for each format to make it more production ready
  44. # specially to serialize to protobuf and not saturate at a high rates.
  45. @dataclass(frozen=True)
  46. class BoundingBox:
  47. """Bounding box data class.
  48. Useful for representing bounding boxes in different formats for object detection.
  49. Args:
  50. data: tuple of bounding box data. The length of the tuple depends on the data format.
  51. data_format: bounding box data format.
  52. """
  53. data: tuple[float, float, float, float]
  54. data_format: BoundingBoxDataFormat
  55. @dataclass(frozen=True)
  56. class ObjectDetectorResult:
  57. """Object detection result.
  58. Args:
  59. class_id: class id of the detected object.
  60. confidence: confidence score of the detected object.
  61. bbox: bounding box of the detected object in xywh format.
  62. """
  63. class_id: int
  64. confidence: float
  65. bbox: BoundingBox
  66. def results_from_detections(detections: Tensor, format: str | BoundingBoxDataFormat) -> list[ObjectDetectorResult]:
  67. """Convert a detection tensor to a list of :py:class:`ObjectDetectorResult`.
  68. Args:
  69. detections: tensor with shape :math:`(D, 6)`, where :math:`D` is the number of detections in the given image,
  70. :math:`6` represents class id, score, and `xywh` bounding box.
  71. format: detection format.
  72. Returns:
  73. list of :py:class:`ObjectDetectorResult`.
  74. """
  75. KORNIA_CHECK_SHAPE(detections, ["D", "6"])
  76. if isinstance(format, str):
  77. format = BoundingBoxDataFormat[format.upper()]
  78. results: list[ObjectDetectorResult] = []
  79. for det in detections:
  80. det = det.squeeze().tolist()
  81. if len(det) != 6:
  82. continue
  83. results.append(
  84. ObjectDetectorResult(
  85. class_id=int(det[0]),
  86. confidence=det[1],
  87. bbox=BoundingBox(data=(det[2], det[3], det[4], det[5]), data_format=format),
  88. )
  89. )
  90. return results
  91. class ObjectDetector(ModelBase):
  92. """Wrap an object detection model and perform pre-processing and post-processing."""
  93. name: str = "detection"
  94. @torch.inference_mode()
  95. def forward(self, images: Union[Tensor, list[Tensor]]) -> Union[Tensor, list[Tensor]]:
  96. """Detect objects in a given list of images.
  97. Args:
  98. images: If list of RGB images. Each image is a Tensor with shape :math:`(3, H, W)`.
  99. If Tensor, a Tensor with shape :math:`(B, 3, H, W)`.
  100. Returns:
  101. list of detections found in each image. For item in a batch, shape is :math:`(D, 6)`, where :math:`D` is the
  102. number of detections in the given image, :math:`6` represents class id, score, and `xywh` bounding box.
  103. """
  104. images, images_sizes = self.pre_processor(images)
  105. logits, boxes = self.model(images)
  106. detections = self.post_processor(logits, boxes, images_sizes)
  107. return detections
  108. def visualize(
  109. self, images: Union[Tensor, list[Tensor]], detections: Optional[Tensor] = None, output_type: str = "torch"
  110. ) -> Union[Tensor, list[Tensor], list[Image.Image]]: # type: ignore
  111. """Very simple drawing.
  112. Needs to be more fancy later.
  113. """
  114. dets = detections or self.forward(images)
  115. output = []
  116. for image, detection in zip(images, dets):
  117. out_img = image[None].clone()
  118. for out in detection:
  119. out_img = draw_rectangle(
  120. out_img,
  121. torch.Tensor([[[out[-4], out[-3], out[-4] + out[-2], out[-3] + out[-1]]]]),
  122. )
  123. output.append(out_img[0])
  124. return self._tensor_to_type(output, output_type, is_batch=isinstance(images, Tensor))
  125. def save(
  126. self, images: Union[Tensor, list[Tensor]], detections: Optional[Tensor] = None, directory: Optional[str] = None
  127. ) -> None:
  128. """Save the output image(s) to a directory.
  129. Args:
  130. images: input tensor.
  131. detections: detection tensor.
  132. directory: directory to save the images.
  133. """
  134. outputs = self.visualize(images, detections)
  135. self._save_outputs(outputs, directory)
  136. def to_onnx( # type: ignore[override]
  137. self,
  138. onnx_name: Optional[str] = None,
  139. image_size: Optional[int] = 640,
  140. include_pre_and_post_processor: bool = True,
  141. save: bool = True,
  142. additional_metadata: Optional[list[tuple[str, str]]] = None,
  143. **kwargs: Any,
  144. ) -> onnx.ModelProto: # type: ignore
  145. """Export an RT-DETR object detection model to ONNX format.
  146. Either `model_name` or `config` must be provided. If neither is provided,
  147. a default pretrained model (`rtdetr_r18vd`) will be built.
  148. Args:
  149. onnx_name:
  150. The name of the output ONNX file. If not provided, a default name in the
  151. format "Kornia-<ClassName>.onnx" will be used.
  152. image_size:
  153. The size to which input images will be resized during preprocessing.
  154. If None, image_size will be dynamic.
  155. For RTDETR, recommended scales include [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800].
  156. include_pre_and_post_processor:
  157. Whether to include the pre-processor and post-processor in the exported model.
  158. save:
  159. If to save the model or load it.
  160. additional_metadata:
  161. Additional metadata to add to the ONNX model.
  162. kwargs: Additional arguments to convert to onnx.
  163. """
  164. if onnx_name is None:
  165. onnx_name = f"kornia_{self.name}_{image_size}.onnx"
  166. return super().to_onnx(
  167. onnx_name,
  168. input_shape=[-1, 3, image_size or -1, image_size or -1],
  169. output_shape=[-1, -1, 6],
  170. pseudo_shape=[1, 3, image_size or 352, image_size or 352],
  171. model=self if include_pre_and_post_processor else self.model,
  172. save=save,
  173. additional_metadata=additional_metadata,
  174. **kwargs,
  175. )
  176. def compile(
  177. self,
  178. *,
  179. fullgraph: bool = False,
  180. dynamic: bool = False,
  181. backend: str = "inductor",
  182. mode: Optional[str] = None,
  183. options: Optional[dict[str, str | int | bool]] = None,
  184. disable: bool = False,
  185. ) -> None:
  186. """Compile the internal object detection model with :py:func:`torch.compile()`."""
  187. self.model = torch.compile( # type: ignore
  188. self.model,
  189. fullgraph=fullgraph,
  190. dynamic=dynamic,
  191. backend=backend,
  192. mode=mode,
  193. options=options,
  194. disable=disable,
  195. )