pose_utils.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. from __future__ import annotations
  2. from typing import Any
  3. import numpy as np
  4. from PIL import Image
  5. from tqdm.auto import tqdm
  6. from ultralytics.engine.results import Results
  7. from ultralytics.models.yolo.pose import PosePredictor
  8. from ultralytics.utils.plotting import Annotator
  9. import wandb
  10. from wandb.integration.ultralytics.bbox_utils import (
  11. get_boxes,
  12. get_ground_truth_bbox_annotations,
  13. )
  14. def annotate_keypoint_results(result: Results, visualize_skeleton: bool):
  15. annotator = Annotator(np.ascontiguousarray(result.orig_img[:, :, ::-1]))
  16. key_points = result.keypoints.data.numpy()
  17. for idx in range(key_points.shape[0]):
  18. annotator.kpts(key_points[idx], kpt_line=visualize_skeleton)
  19. return annotator.im
  20. def annotate_keypoint_batch(image_path: str, keypoints: Any, visualize_skeleton: bool):
  21. with Image.open(image_path) as original_image:
  22. original_image = np.ascontiguousarray(original_image)
  23. annotator = Annotator(original_image)
  24. annotator.kpts(keypoints.numpy(), kpt_line=visualize_skeleton)
  25. return annotator.im
  26. def plot_pose_predictions(
  27. result: Results,
  28. model_name: str,
  29. visualize_skeleton: bool,
  30. table: wandb.Table | None = None,
  31. ):
  32. result = result.to("cpu")
  33. boxes, mean_confidence_map = get_boxes(result)
  34. annotated_image = annotate_keypoint_results(result, visualize_skeleton)
  35. prediction_image = wandb.Image(annotated_image, boxes=boxes)
  36. table_row = [
  37. model_name,
  38. prediction_image,
  39. len(boxes["predictions"]["box_data"]),
  40. mean_confidence_map,
  41. result.speed,
  42. ]
  43. if table is not None:
  44. table.add_data(*table_row)
  45. return table
  46. return table_row
  47. def plot_pose_validation_results(
  48. dataloader,
  49. class_label_map,
  50. model_name: str,
  51. predictor: PosePredictor,
  52. visualize_skeleton: bool,
  53. table: wandb.Table,
  54. max_validation_batches: int,
  55. epoch: int | None = None,
  56. ) -> wandb.Table:
  57. data_idx = 0
  58. num_dataloader_batches = len(dataloader.dataset) // dataloader.batch_size
  59. max_validation_batches = min(max_validation_batches, num_dataloader_batches)
  60. for batch_idx, batch in enumerate(dataloader):
  61. prediction_results = predictor(batch["im_file"])
  62. progress_bar_result_iterable = tqdm(
  63. enumerate(prediction_results),
  64. total=len(prediction_results),
  65. desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}",
  66. )
  67. for img_idx, prediction_result in progress_bar_result_iterable:
  68. prediction_result = prediction_result.to("cpu")
  69. table_row = plot_pose_predictions(
  70. prediction_result, model_name, visualize_skeleton
  71. )
  72. ground_truth_image = wandb.Image(
  73. annotate_keypoint_batch(
  74. batch["im_file"][img_idx],
  75. batch["keypoints"][img_idx],
  76. visualize_skeleton,
  77. ),
  78. boxes={
  79. "ground-truth": {
  80. "box_data": get_ground_truth_bbox_annotations(
  81. img_idx, batch["im_file"][img_idx], batch, class_label_map
  82. ),
  83. "class_labels": class_label_map,
  84. },
  85. },
  86. )
  87. table_row = [data_idx, batch_idx, ground_truth_image] + table_row[1:]
  88. table_row = [epoch] + table_row if epoch is not None else table_row
  89. table_row = [model_name] + table_row
  90. table.add_data(*table_row)
  91. data_idx += 1
  92. if batch_idx + 1 == max_validation_batches:
  93. break
  94. return table