boxmot_tracker.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. import datetime
  19. import logging
  20. import os
  21. from pathlib import Path
  22. from typing import Any, Optional, Union
  23. from kornia.config import kornia_config
  24. from kornia.core import Tensor, tensor
  25. from kornia.core.external import boxmot
  26. from kornia.core.external import numpy as np
  27. from kornia.io import write_image
  28. from kornia.models.detection.base import ObjectDetector
  29. from kornia.models.detection.rtdetr import RTDETRDetectorBuilder
  30. from kornia.utils.image import tensor_to_image
# Public names exported by this module.
__all__ = ["BoxMotTracker"]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  33. class BoxMotTracker:
  34. """BoxMotTracker is a module that wraps a detector and a tracker model.
  35. This module uses BoxMot library for tracking.
  36. Args:
  37. detector: ObjectDetector: The detector model.
  38. tracker_model_name: The name of the tracker model. Valid options are:
  39. - "BoTSORT"
  40. - "DeepOCSORT"
  41. - "OCSORT"
  42. - "HybridSORT"
  43. - "ByteTrack"
  44. - "StrongSORT"
  45. - "ImprAssoc"
  46. tracker_model_weights: Path to the model weights for ReID (Re-Identification).
  47. device: Device on which to run the model (e.g., 'cpu' or 'cuda').
  48. fp16: Whether to use half-precision (fp16) for faster inference on compatible devices.
  49. per_class: Whether to perform per-class tracking
  50. track_high_thresh: High threshold for detection confidence.
  51. Detections above this threshold are used in the first association round.
  52. track_low_thresh: Low threshold for detection confidence.
  53. Detections below this threshold are ignored.
  54. new_track_thresh: Threshold for creating a new track.
  55. Detections above this threshold will be considered as potential new tracks.
  56. track_buffer: Number of frames to keep a track alive after it was last detected.
  57. match_thresh: Threshold for the matching step in data association.
  58. proximity_thresh: Threshold for IoU (Intersection over Union) distance in first-round association.
  59. appearance_thresh: Threshold for appearance embedding distance in the ReID module.
  60. cmc_method: Method for correcting camera motion. Options include "sof" (simple optical flow).
  61. frame_rate: Frame rate of the video being processed. Used to scale the track buffer size.
  62. fuse_first_associate: Whether to fuse appearance and motion information during the first association step.
  63. with_reid: Whether to use ReID (Re-Identification) features for association.
  64. .. code-block:: python
  65. import kornia
  66. image = kornia.utils.sample.get_sample_images()[0][None]
  67. model = BoxMotTracker()
  68. for i in range(4): # At least 4 frames are needed to initialize the tracking position
  69. model.update(image)
  70. model.save(image)
  71. .. note::
  72. At least 4 frames are needed to initialize the tracking position.
  73. """
  74. name: str = "boxmot_tracker"
  75. def __init__(
  76. self,
  77. detector: Union[ObjectDetector, str] = "rtdetr_r18vd",
  78. tracker_model_name: str = "DeepOCSORT",
  79. tracker_model_weights: str = "osnet_x0_25_msmt17.pt",
  80. device: str = "cpu",
  81. fp16: bool = False,
  82. **kwargs: Any,
  83. ) -> None:
  84. super().__init__()
  85. if isinstance(detector, str):
  86. if detector.startswith("rtdetr"):
  87. detector = RTDETRDetectorBuilder.build(model_name=detector)
  88. else:
  89. raise ValueError(
  90. f"Detector `{detector}` not available. You may pass an ObjectDetector instance instead."
  91. )
  92. self.detector = detector
  93. os.makedirs(f"{kornia_config.hub_models_dir}/boxmot", exist_ok=True)
  94. self.tracker = getattr(boxmot, tracker_model_name)(
  95. model_weights=Path(os.path.join(f"{kornia_config.hub_models_dir}/boxmot", tracker_model_weights)),
  96. device=device,
  97. fp16=fp16,
  98. **kwargs,
  99. )
  100. def update(self, image: Tensor) -> None:
  101. """Update the tracker with a new image.
  102. Args:
  103. image: The input image.
  104. """
  105. if not (image.ndim == 4 and image.shape[0] == 1) and not image.ndim == 3:
  106. raise ValueError(f"Input tensor must be of shape (1, 3, H, W) or (3, H, W). Got {image.shape}")
  107. if image.ndim == 3:
  108. image = image.unsqueeze(0)
  109. detections_raw: Union[Tensor, list[Tensor]] = self.detector(image)
  110. detections = detections_raw[0].cpu().numpy() # Batch size is 1
  111. detections = np.array( # type: ignore
  112. [
  113. detections[:, 2],
  114. detections[:, 3],
  115. detections[:, 2] + detections[:, 4],
  116. detections[:, 3] + detections[:, 5],
  117. detections[:, 1],
  118. detections[:, 0],
  119. ]
  120. ).T
  121. if detections.shape[0] == 0:
  122. # empty N X (x, y, x, y, conf, cls)
  123. detections = np.empty((0, 6)) # type: ignore
  124. frame_raw = (tensor_to_image(image) * 255).astype(np.uint8)
  125. # --> M X (x, y, x, y, id, conf, cls, ind)
  126. return self.tracker.update(detections, frame_raw)
  127. def visualize(self, image: Tensor, show_trajectories: bool = True) -> Tensor:
  128. """Visualize the results of the tracker.
  129. Args:
  130. image: The input image.
  131. show_trajectories: Whether to show the trajectories.
  132. Returns:
  133. The image with the results of the tracker.
  134. """
  135. frame_raw = (tensor_to_image(image) * 255).astype(np.uint8)
  136. self.tracker.plot_results(frame_raw, show_trajectories=show_trajectories)
  137. return tensor(frame_raw).permute(2, 0, 1)
  138. def save(self, image: Tensor, show_trajectories: bool = True, directory: Optional[str] = None) -> None:
  139. """Save the model to ONNX format.
  140. Args:
  141. image: The input image.
  142. show_trajectories: Whether to visualize trajectories.
  143. directory: Where to save the file(s).
  144. """
  145. if directory is None:
  146. name = f"{self.name}_{datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y%m%d%H%M%S')!s}"
  147. directory = os.path.join("kornia_outputs", name)
  148. output = self.visualize(image, show_trajectories=show_trajectories)
  149. os.makedirs(directory, exist_ok=True)
  150. write_image(
  151. os.path.join(directory, f"{str(0).zfill(6)}.jpg"),
  152. output.byte(),
  153. )
  154. logger.info(f"Outputs are saved in {directory}")