| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- from __future__ import annotations
- from typing import Any
- import numpy as np
- from PIL import Image
- from tqdm.auto import tqdm
- from ultralytics.engine.results import Results
- from ultralytics.models.yolo.pose import PosePredictor
- from ultralytics.utils.plotting import Annotator
- import wandb
- from wandb.integration.ultralytics.bbox_utils import (
- get_boxes,
- get_ground_truth_bbox_annotations,
- )
def annotate_keypoint_results(result: Results, visualize_skeleton: bool):
    """Draw every predicted keypoint set contained in ``result`` onto its image.

    Args:
        result: An Ultralytics ``Results`` object whose ``keypoints.data`` can be
            converted with ``.numpy()`` (i.e. it must already live on the CPU).
        visualize_skeleton: Forwarded to ``Annotator.kpts`` as ``kpt_line``.

    Returns:
        The annotator's image (``Annotator.im``) with all keypoints drawn.
    """
    # Reverse the last (channel) axis; presumably converts OpenCV-style BGR to
    # RGB for W&B -- TODO confirm orig_img channel order.
    rgb_image = np.ascontiguousarray(result.orig_img[:, :, ::-1])
    annotator = Annotator(rgb_image)
    # Iterating the array walks its first axis, one keypoint set per step.
    for instance_keypoints in result.keypoints.data.numpy():
        annotator.kpts(instance_keypoints, kpt_line=visualize_skeleton)
    return annotator.im
def annotate_keypoint_batch(image_path: str, keypoints: Any, visualize_skeleton: bool):
    """Draw a single keypoint set onto the image stored at ``image_path``.

    Args:
        image_path: Path of the image file, opened with PIL.
        keypoints: Object exposing ``.numpy()``; the resulting array is passed
            to ``Annotator.kpts``. Presumably a CPU tensor taken from a
            dataloader batch -- TODO confirm its shape and coordinate space
            against the caller.
        visualize_skeleton: Forwarded to ``Annotator.kpts`` as ``kpt_line``.

    Returns:
        The annotator's image (``Annotator.im``) with the keypoints drawn.
    """
    # Convert to an array while the file is open; the PIL handle is closed on
    # leaving the ``with`` block.
    with Image.open(image_path) as pil_image:
        pixel_array = np.ascontiguousarray(pil_image)
    painter = Annotator(pixel_array)
    painter.kpts(keypoints.numpy(), kpt_line=visualize_skeleton)
    return painter.im
def plot_pose_predictions(
    result: Results,
    model_name: str,
    visualize_skeleton: bool,
    table: wandb.Table | None = None,
):
    """Build the W&B row describing one pose prediction.

    The row is ``[model_name, prediction image (keypoints drawn, predicted
    boxes overlaid), number of predicted boxes, mean confidence map, speed]``.

    Args:
        result: Ultralytics ``Results`` for a single image; copied to the CPU.
        model_name: Value placed in the first column of the row.
        visualize_skeleton: Forwarded to the keypoint annotator.
        table: Optional table. When given, the row is appended to it.

    Returns:
        ``table`` after appending the row if a table was given, otherwise the
        row itself as a list.
    """
    cpu_result = result.to("cpu")
    boxes, mean_confidence_map = get_boxes(cpu_result)
    prediction_image = wandb.Image(
        annotate_keypoint_results(cpu_result, visualize_skeleton),
        boxes=boxes,
    )
    row = [
        model_name,
        prediction_image,
        len(boxes["predictions"]["box_data"]),
        mean_confidence_map,
        cpu_result.speed,
    ]
    if table is None:
        return row
    table.add_data(*row)
    return table
def plot_pose_validation_results(
    dataloader,
    class_label_map,
    model_name: str,
    predictor: PosePredictor,
    visualize_skeleton: bool,
    table: wandb.Table,
    max_validation_batches: int,
    epoch: int | None = None,
) -> wandb.Table:
    """Log ground-truth vs. predicted pose visualizations for validation batches.

    For each image in up to ``max_validation_batches`` batches of
    ``dataloader``, one row is appended to ``table``:
    ``[model_name, epoch (only when epoch is not None), data_idx, batch_idx,
    ground-truth image, prediction image, number of predicted boxes,
    mean confidence map, speed]``.

    Args:
        dataloader: Iterable of batch dicts, exposing ``dataset`` and
            ``batch_size``. Each batch must provide ``"im_file"`` and
            ``"keypoints"``, plus whatever
            ``get_ground_truth_bbox_annotations`` reads.
        class_label_map: Class-id to label mapping used for ground-truth boxes.
        model_name: Value written to the first column of every row.
        predictor: Pose predictor, called with the batch's image paths.
        visualize_skeleton: Forwarded to the keypoint annotators.
        table: W&B table that rows are appended to.
        max_validation_batches: Upper bound on the number of batches to
            visualize. Values <= 0 visualize nothing.
        epoch: Optional epoch number, inserted as the second column when set.

    Returns:
        ``table`` with the new rows appended.
    """
    batch_size = dataloader.batch_size
    # Ceiling division: a trailing partial batch is still a batch. Floor
    # division returned 0 for datasets smaller than one batch, which silently
    # disabled the batch limit below.
    num_dataloader_batches = (len(dataloader.dataset) + batch_size - 1) // batch_size
    max_validation_batches = min(max_validation_batches, num_dataloader_batches)
    if max_validation_batches <= 0:
        return table

    data_idx = 0
    for batch_idx, batch in enumerate(dataloader):
        image_files = batch["im_file"]
        prediction_results = predictor(image_files)
        progress_bar_result_iterable = tqdm(
            enumerate(prediction_results),
            total=len(prediction_results),
            desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}",
        )
        for img_idx, prediction_result in progress_bar_result_iterable:
            image_file = image_files[img_idx]
            # plot_pose_predictions moves the result to the CPU itself.
            prediction_row = plot_pose_predictions(
                prediction_result, model_name, visualize_skeleton
            )
            # NOTE(review): batch["keypoints"] is indexed by the image's
            # position in the batch. If the collate step concatenates keypoints
            # across all instances (with a separate per-instance batch index),
            # this picks one instance rather than this image's keypoints --
            # confirm against the dataloader.
            ground_truth_image = wandb.Image(
                annotate_keypoint_batch(
                    image_file,
                    batch["keypoints"][img_idx],
                    visualize_skeleton,
                ),
                boxes={
                    "ground-truth": {
                        "box_data": get_ground_truth_bbox_annotations(
                            img_idx, image_file, batch, class_label_map
                        ),
                        "class_labels": class_label_map,
                    },
                },
            )
            table_row = [model_name]
            if epoch is not None:
                table_row.append(epoch)
            # Drop the prediction row's leading model name; it is already first.
            table_row += [data_idx, batch_idx, ground_truth_image, *prediction_row[1:]]
            table.add_data(*table_row)
            data_idx += 1
        # Stop right after the last requested batch so no extra batch is fetched.
        if batch_idx + 1 >= max_validation_batches:
            break
    return table
|