| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from __future__ import annotations
- import datetime
- import logging
- import os
- from pathlib import Path
- from typing import Any, Optional, Union
- from kornia.config import kornia_config
- from kornia.core import Tensor, tensor
- from kornia.core.external import boxmot
- from kornia.core.external import numpy as np
- from kornia.io import write_image
- from kornia.models.detection.base import ObjectDetector
- from kornia.models.detection.rtdetr import RTDETRDetectorBuilder
- from kornia.utils.image import tensor_to_image
# Names exported by `from <this module> import *`.
__all__ = ["BoxMotTracker"]
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class BoxMotTracker:
    """Wrap an object detector and a BoxMot tracker behind a single interface.

    Every call to :meth:`update` runs the detector on one frame, converts the
    detections to the layout expected by the BoxMot library and feeds them,
    together with the frame, to the tracker.

    Args:
        detector: Either an ``ObjectDetector`` instance or the name of a detector to build.
            Only names starting with ``"rtdetr"`` can be built from a string.
        tracker_model_name: The name of the tracker model. Valid options are:
            - "BoTSORT"
            - "DeepOCSORT"
            - "OCSORT"
            - "HybridSORT"
            - "ByteTrack"
            - "StrongSORT"
            - "ImprAssoc"
        tracker_model_weights: File name of the model weights for ReID (Re-Identification).
            It is resolved relative to ``<kornia hub models dir>/boxmot``.
        device: Device on which to run the model (e.g., 'cpu' or 'cuda').
        fp16: Whether to use half-precision (fp16) for faster inference on compatible devices.
        **kwargs: Extra keyword arguments forwarded unchanged to the tracker constructor.
            Depending on the chosen tracker these may include:

            per_class: Whether to perform per-class tracking
            track_high_thresh: High threshold for detection confidence.
                Detections above this threshold are used in the first association round.
            track_low_thresh: Low threshold for detection confidence.
                Detections below this threshold are ignored.
            new_track_thresh: Threshold for creating a new track.
                Detections above this threshold will be considered as potential new tracks.
            track_buffer: Number of frames to keep a track alive after it was last detected.
            match_thresh: Threshold for the matching step in data association.
            proximity_thresh: Threshold for IoU (Intersection over Union) distance in first-round association.
            appearance_thresh: Threshold for appearance embedding distance in the ReID module.
            cmc_method: Method for correcting camera motion. Options include "sof" (simple optical flow).
            frame_rate: Frame rate of the video being processed. Used to scale the track buffer size.
            fuse_first_associate: Whether to fuse appearance and motion information during the
                first association step.
            with_reid: Whether to use ReID (Re-Identification) features for association.

    Raises:
        ValueError: If ``detector`` is a string that does not start with ``"rtdetr"``.

    .. code-block:: python

        import kornia

        image = kornia.utils.sample.get_sample_images()[0][None]
        model = BoxMotTracker()
        for i in range(4):  # At least 4 frames are needed to initialize the tracking position
            model.update(image)
        model.save(image)

    .. note::
        At least 4 frames are needed to initialize the tracking position.
    """

    # Prefix used for the default output directory created by `save`.
    name: str = "boxmot_tracker"

    def __init__(
        self,
        detector: Union[ObjectDetector, str] = "rtdetr_r18vd",
        tracker_model_name: str = "DeepOCSORT",
        tracker_model_weights: str = "osnet_x0_25_msmt17.pt",
        device: str = "cpu",
        fp16: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__()

        if isinstance(detector, str):
            if not detector.startswith("rtdetr"):
                raise ValueError(
                    f"Detector `{detector}` not available. You may pass an ObjectDetector instance instead."
                )
            detector = RTDETRDetectorBuilder.build(model_name=detector)
        self.detector = detector

        # ReID weights live in a dedicated sub-directory of the Kornia hub cache.
        weights_dir = Path(f"{kornia_config.hub_models_dir}") / "boxmot"
        weights_dir.mkdir(parents=True, exist_ok=True)

        # The tracker class is looked up by name on the (externally provided) boxmot module.
        self.tracker = getattr(boxmot, tracker_model_name)(
            model_weights=weights_dir / tracker_model_weights,
            device=device,
            fp16=fp16,
            **kwargs,
        )

    @staticmethod
    def _to_tracker_detections(detections: Any) -> Any:
        """Re-order detector rows into the ``(x1, y1, x2, y2, conf, cls)`` layout fed to the tracker.

        The input columns are read as ``(cls, conf, x, y, w, h)``: columns 2 and 3 are copied,
        columns 4 and 5 are added to them to obtain the opposite box corner.

        Args:
            detections: Array of detections for a single image, one row per detection.

        Returns:
            Array of shape ``(N, 6)``; an empty ``(0, 6)`` array when there are no detections.
        """
        # Check emptiness first so that an empty result without a second axis
        # does not fail on the column indexing below.
        if detections.shape[0] == 0:
            return np.empty((0, 6))  # type: ignore

        return np.stack(  # type: ignore
            [
                detections[:, 2],
                detections[:, 3],
                detections[:, 2] + detections[:, 4],
                detections[:, 3] + detections[:, 5],
                detections[:, 1],
                detections[:, 0],
            ],
            axis=1,
        )

    @staticmethod
    def _to_uint8_frame(image: Tensor) -> Any:
        """Convert an image tensor into a ``uint8`` numpy frame.

        NOTE(review): the scaling by 255 presumably expects values in ``[0, 1]`` - confirm with callers.
        """
        return (tensor_to_image(image) * 255).astype(np.uint8)

    def update(self, image: Tensor) -> Any:
        """Update the tracker with a new image.

        Args:
            image: The input image with shape :math:`(1, 3, H, W)` or :math:`(3, H, W)`.

        Returns:
            Whatever the underlying tracker's ``update`` returns; per the tracker convention this is
            an ``M x (x, y, x, y, id, conf, cls, ind)`` array.

        Raises:
            ValueError: If ``image`` is neither 3-dimensional nor a 4-dimensional batch of size one.
        """
        if image.ndim == 3:
            image = image.unsqueeze(0)
        elif not (image.ndim == 4 and image.shape[0] == 1):
            raise ValueError(f"Input tensor must be of shape (1, 3, H, W) or (3, H, W). Got {image.shape}")

        detections_raw: Union[Tensor, list[Tensor]] = self.detector(image)
        # Batch size is 1, so only the first entry is relevant. `detach` keeps the
        # numpy conversion working even if the detector output tracks gradients.
        detections = self._to_tracker_detections(detections_raw[0].detach().cpu().numpy())

        frame_raw = self._to_uint8_frame(image)
        # --> M X (x, y, x, y, id, conf, cls, ind)
        return self.tracker.update(detections, frame_raw)

    def visualize(self, image: Tensor, show_trajectories: bool = True) -> Tensor:
        """Visualize the results of the tracker.

        Args:
            image: The input image.
            show_trajectories: Whether to show the trajectories.

        Returns:
            The image with the results of the tracker, in channels-first layout.
        """
        frame_raw = self._to_uint8_frame(image)
        # The return value is not used: the tracker is expected to draw onto `frame_raw` in place.
        self.tracker.plot_results(frame_raw, show_trajectories=show_trajectories)
        # HWC -> CHW
        return tensor(frame_raw).permute(2, 0, 1)

    def save(self, image: Tensor, show_trajectories: bool = True, directory: Optional[str] = None) -> None:
        """Render the current tracking results on an image and save it as a JPEG file.

        The file is always written as ``000000.jpg`` inside ``directory``.

        Args:
            image: The input image.
            show_trajectories: Whether to visualize trajectories.
            directory: Where to save the file(s). Defaults to
                ``kornia_outputs/<name>_<UTC timestamp>`` when not given.
        """
        if directory is None:
            timestamp = datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y%m%d%H%M%S")
            directory = os.path.join("kornia_outputs", f"{self.name}_{timestamp}")

        output = self.visualize(image, show_trajectories=show_trajectories)

        os.makedirs(directory, exist_ok=True)
        write_image(
            os.path.join(directory, f"{0:06d}.jpg"),
            output.byte(),
        )
        logger.info("Outputs are saved in %s", directory)
|