classification_utils.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. from __future__ import annotations
  2. from typing import Any
  3. import numpy as np
  4. from tqdm.auto import tqdm
  5. from ultralytics.engine.results import Results
  6. from ultralytics.models.yolo.classify import ClassificationPredictor
  7. import wandb
  8. def plot_classification_predictions(
  9. result: Results,
  10. model_name: str,
  11. table: wandb.Table | None = None,
  12. original_image: np.array | None = None,
  13. ):
  14. """Plot classification prediction results to a `wandb.Table` if the table is passed otherwise return the data."""
  15. result = result.to("cpu")
  16. probabilities = result.probs
  17. probabilities_list = probabilities.data.numpy().tolist()
  18. class_id_to_label = {int(k): str(v) for k, v in result.names.items()}
  19. original_image = (
  20. wandb.Image(original_image)
  21. if original_image is not None
  22. else wandb.Image(result.orig_img)
  23. )
  24. table_row = [
  25. model_name,
  26. original_image,
  27. class_id_to_label[int(probabilities.top1)],
  28. probabilities.top1conf,
  29. [class_id_to_label[int(class_idx)] for class_idx in list(probabilities.top5)],
  30. [probabilities_list[int(class_idx)] for class_idx in list(probabilities.top5)],
  31. {
  32. class_id_to_label[int(class_idx)]: probability
  33. for class_idx, probability in enumerate(probabilities_list)
  34. },
  35. result.speed,
  36. ]
  37. if table is not None:
  38. table.add_data(*table_row)
  39. return table
  40. return class_id_to_label, table_row
  41. def plot_classification_validation_results(
  42. dataloader: Any,
  43. model_name: str,
  44. predictor: ClassificationPredictor,
  45. table: wandb.Table,
  46. max_validation_batches: int,
  47. epoch: int | None = None,
  48. ) -> wandb.Table:
  49. """Plot classification results to a `wandb.Table`."""
  50. data_idx = 0
  51. num_dataloader_batches = len(dataloader.dataset) // dataloader.batch_size
  52. max_validation_batches = min(max_validation_batches, num_dataloader_batches)
  53. for batch_idx, batch in enumerate(dataloader):
  54. image_batch = [
  55. image for image in np.transpose(batch["img"].numpy(), (0, 2, 3, 1))
  56. ]
  57. ground_truth = batch["cls"].numpy().tolist()
  58. progress_bar_result_iterable = tqdm(
  59. range(max_validation_batches),
  60. desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}",
  61. )
  62. for img_idx in progress_bar_result_iterable:
  63. try:
  64. prediction_result = predictor(image_batch[img_idx])[0]
  65. class_id_to_label, table_row = plot_classification_predictions(
  66. prediction_result, model_name, original_image=image_batch[img_idx]
  67. )
  68. table_row = [data_idx, batch_idx] + table_row[1:]
  69. table_row.insert(3, class_id_to_label[ground_truth[img_idx]])
  70. table_row = [epoch] + table_row if epoch is not None else table_row
  71. table_row = [model_name] + table_row
  72. table.add_data(*table_row)
  73. data_idx += 1
  74. except Exception:
  75. pass
  76. if batch_idx + 1 == max_validation_batches:
  77. break
  78. return table