bbox_utils.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. from __future__ import annotations
  2. from typing import Any
  3. import torch
  4. from tqdm.auto import tqdm
  5. from ultralytics.engine.results import Results
  6. from ultralytics.models.yolo.detect import DetectionPredictor
  7. from ultralytics.utils import ops
  8. import wandb
  9. def scale_bounding_box_to_original_image_shape(
  10. box: torch.Tensor,
  11. resized_image_shape: tuple,
  12. original_image_shape: tuple,
  13. ratio_pad: bool,
  14. ) -> list[int]:
  15. """YOLOv8 resizes images during training and the label values are normalized based on this resized shape.
  16. This function rescales the bounding box labels to the original
  17. image shape.
  18. Reference: https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/utils/callbacks/comet.py#L105
  19. """
  20. resized_image_height, resized_image_width = resized_image_shape
  21. # Convert normalized xywh format predictions to xyxy in resized scale format
  22. box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
  23. # Scale box predictions from resized image scale back to original image scale
  24. box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
  25. # # Convert bounding box format from xyxy to xywh for Comet logging
  26. box = ops.xyxy2xywh(box)
  27. return box.tolist()
  28. def get_ground_truth_bbox_annotations(
  29. img_idx: int, image_path: str, batch: dict, class_name_map: dict = None
  30. ) -> list[dict[str, Any]]:
  31. """Get ground truth bounding box annotation data in the form required for `wandb.Image` overlay system."""
  32. indices = batch["batch_idx"] == img_idx
  33. bboxes = batch["bboxes"][indices]
  34. if len(batch["cls"][indices]):
  35. cls_labels = batch["cls"][indices].squeeze(1).tolist()
  36. else:
  37. cls_labels = []
  38. class_name_map_reverse = {v: k for k, v in class_name_map.items()}
  39. if len(bboxes) == 0:
  40. wandb.termwarn(
  41. f"Image: {image_path} has no bounding boxes labels", repeat=False
  42. )
  43. return None
  44. if len(batch["cls"][indices]):
  45. cls_labels = batch["cls"][indices].squeeze(1).tolist()
  46. else:
  47. cls_labels = []
  48. if class_name_map:
  49. cls_labels = [str(class_name_map[label]) for label in cls_labels]
  50. original_image_shape = batch["ori_shape"][img_idx]
  51. resized_image_shape = batch["resized_shape"][img_idx]
  52. ratio_pad = batch["ratio_pad"][img_idx]
  53. data = []
  54. for box, label in zip(bboxes, cls_labels):
  55. box = scale_bounding_box_to_original_image_shape(
  56. box, resized_image_shape, original_image_shape, ratio_pad
  57. )
  58. data.append(
  59. {
  60. "position": {
  61. "middle": [int(box[0]), int(box[1])],
  62. "width": int(box[2]),
  63. "height": int(box[3]),
  64. },
  65. "domain": "pixel",
  66. "class_id": class_name_map_reverse[label],
  67. "box_caption": label,
  68. }
  69. )
  70. return data
  71. def get_mean_confidence_map(
  72. classes: list, confidence: list, class_id_to_label: dict
  73. ) -> dict[str, float]:
  74. """Get Mean-confidence map from the predictions to be logged into a `wandb.Table`."""
  75. confidence_map = {v: [] for _, v in class_id_to_label.items()}
  76. for class_idx, confidence_value in zip(classes, confidence):
  77. confidence_map[class_id_to_label[class_idx]].append(confidence_value)
  78. updated_confidence_map = {}
  79. for label, confidence_list in confidence_map.items():
  80. if len(confidence_list) > 0:
  81. updated_confidence_map[label] = sum(confidence_list) / len(confidence_list)
  82. else:
  83. updated_confidence_map[label] = 0
  84. return updated_confidence_map
  85. def get_boxes(result: Results) -> tuple[dict, dict]:
  86. """Convert an ultralytics prediction result into metadata for the `wandb.Image` overlay system."""
  87. boxes = result.boxes.xywh.long().numpy()
  88. classes = result.boxes.cls.long().numpy()
  89. confidence = result.boxes.conf.numpy()
  90. class_id_to_label = {int(k): str(v) for k, v in result.names.items()}
  91. mean_confidence_map = get_mean_confidence_map(
  92. classes, confidence, class_id_to_label
  93. )
  94. box_data = []
  95. for idx in range(len(boxes)):
  96. box_data.append(
  97. {
  98. "position": {
  99. "middle": [int(boxes[idx][0]), int(boxes[idx][1])],
  100. "width": int(boxes[idx][2]),
  101. "height": int(boxes[idx][3]),
  102. },
  103. "domain": "pixel",
  104. "class_id": int(classes[idx]),
  105. "box_caption": class_id_to_label[int(classes[idx])],
  106. "scores": {"confidence": float(confidence[idx])},
  107. }
  108. )
  109. boxes = {
  110. "predictions": {
  111. "box_data": box_data,
  112. "class_labels": class_id_to_label,
  113. },
  114. }
  115. return boxes, mean_confidence_map
  116. def plot_bbox_predictions(
  117. result: Results, model_name: str, table: wandb.Table | None = None
  118. ) -> wandb.Table | tuple[wandb.Image, dict, dict]:
  119. """Plot the images with the W&B overlay system.
  120. The `wandb.Image` is either added to a `wandb.Table` or returned.
  121. """
  122. result = result.to("cpu")
  123. boxes, mean_confidence_map = get_boxes(result)
  124. image = wandb.Image(result.orig_img[:, :, ::-1], boxes=boxes)
  125. if table is not None:
  126. table.add_data(
  127. model_name,
  128. image,
  129. len(boxes["predictions"]["box_data"]),
  130. mean_confidence_map,
  131. result.speed,
  132. )
  133. return table
  134. return image, boxes["predictions"], mean_confidence_map
  135. def plot_detection_validation_results(
  136. dataloader: Any,
  137. class_label_map: dict,
  138. model_name: str,
  139. predictor: DetectionPredictor,
  140. table: wandb.Table,
  141. max_validation_batches: int,
  142. epoch: int | None = None,
  143. ) -> wandb.Table:
  144. """Plot validation results in a table."""
  145. data_idx = 0
  146. num_dataloader_batches = len(dataloader.dataset) // dataloader.batch_size
  147. max_validation_batches = min(max_validation_batches, num_dataloader_batches)
  148. for batch_idx, batch in enumerate(dataloader):
  149. prediction_results = predictor(batch["im_file"])
  150. progress_bar_result_iterable = tqdm(
  151. enumerate(prediction_results),
  152. total=len(prediction_results),
  153. desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}",
  154. )
  155. for img_idx, prediction_result in progress_bar_result_iterable:
  156. prediction_result = prediction_result.to("cpu")
  157. _, prediction_box_data, mean_confidence_map = plot_bbox_predictions(
  158. prediction_result, model_name
  159. )
  160. try:
  161. ground_truth_data = get_ground_truth_bbox_annotations(
  162. img_idx, batch["im_file"][img_idx], batch, class_label_map
  163. )
  164. wandb_image = wandb.Image(
  165. batch["im_file"][img_idx],
  166. boxes={
  167. "ground-truth": {
  168. "box_data": ground_truth_data,
  169. "class_labels": class_label_map,
  170. },
  171. "predictions": {
  172. "box_data": prediction_box_data["box_data"],
  173. "class_labels": class_label_map,
  174. },
  175. },
  176. )
  177. table_rows = [
  178. data_idx,
  179. batch_idx,
  180. wandb_image,
  181. mean_confidence_map,
  182. prediction_result.speed,
  183. ]
  184. table_rows = [epoch] + table_rows if epoch is not None else table_rows
  185. table_rows = [model_name] + table_rows
  186. table.add_data(*table_rows)
  187. data_idx += 1
  188. except TypeError:
  189. pass
  190. if batch_idx + 1 == max_validation_batches:
  191. break
  192. return table