callback.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. from __future__ import annotations
  2. import copy
  3. from datetime import datetime
  4. from typing import Callable, Union
  5. from packaging import version
  6. try:
  7. import dill as pickle
  8. except ImportError:
  9. import pickle
  10. import wandb
  11. from wandb.sdk.lib import telemetry
  12. try:
  13. import torch
  14. import ultralytics
  15. from tqdm.auto import tqdm
  16. if version.parse(ultralytics.__version__) > version.parse("8.0.238"):
  17. wandb.termwarn(
  18. """This integration is tested and supported for ultralytics v8.0.238 and below.
  19. Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.""",
  20. repeat=False,
  21. )
  22. from ultralytics.models import YOLO
  23. from ultralytics.models.sam.predict import Predictor as SAMPredictor
  24. from ultralytics.models.yolo.classify import (
  25. ClassificationPredictor,
  26. ClassificationTrainer,
  27. ClassificationValidator,
  28. )
  29. from ultralytics.models.yolo.detect import (
  30. DetectionPredictor,
  31. DetectionTrainer,
  32. DetectionValidator,
  33. )
  34. from ultralytics.models.yolo.pose import PosePredictor, PoseTrainer, PoseValidator
  35. from ultralytics.models.yolo.segment import (
  36. SegmentationPredictor,
  37. SegmentationTrainer,
  38. SegmentationValidator,
  39. )
  40. from ultralytics.utils.torch_utils import de_parallel
  41. try:
  42. from ultralytics.yolo.utils import RANK, __version__
  43. except ModuleNotFoundError:
  44. from ultralytics.utils import RANK, __version__
  45. from wandb.integration.ultralytics.bbox_utils import (
  46. plot_bbox_predictions,
  47. plot_detection_validation_results,
  48. )
  49. from wandb.integration.ultralytics.classification_utils import (
  50. plot_classification_predictions,
  51. plot_classification_validation_results,
  52. )
  53. from wandb.integration.ultralytics.mask_utils import (
  54. plot_mask_predictions,
  55. plot_sam_predictions,
  56. plot_segmentation_validation_results,
  57. )
  58. from wandb.integration.ultralytics.pose_utils import (
  59. plot_pose_predictions,
  60. plot_pose_validation_results,
  61. )
  62. except Exception as e:
  63. wandb.Error(e)
  64. TRAINER_TYPE = Union[
  65. ClassificationTrainer, DetectionTrainer, SegmentationTrainer, PoseTrainer
  66. ]
  67. VALIDATOR_TYPE = Union[
  68. ClassificationValidator, DetectionValidator, SegmentationValidator, PoseValidator
  69. ]
  70. PREDICTOR_TYPE = Union[
  71. ClassificationPredictor,
  72. DetectionPredictor,
  73. SegmentationPredictor,
  74. PosePredictor,
  75. SAMPredictor,
  76. ]
  77. class WandBUltralyticsCallback:
  78. """Stateful callback for logging to W&B.
  79. In particular, it will log model checkpoints, predictions, and
  80. ground-truth annotations with interactive overlays for bounding boxes
  81. to Weights & Biases Tables during training, validation and prediction
  82. for a `ultratytics` workflow.
  83. Example:
  84. ```python
  85. from ultralytics.yolo.engine.model import YOLO
  86. from wandb.yolov8 import add_wandb_callback
  87. # initialize YOLO model
  88. model = YOLO("yolov8n.pt")
  89. # add wandb callback
  90. add_wandb_callback(
  91. model, max_validation_batches=2, enable_model_checkpointing=True
  92. )
  93. # train
  94. model.train(data="coco128.yaml", epochs=5, imgsz=640)
  95. # validate
  96. model.val()
  97. # perform inference
  98. model(["img1.jpeg", "img2.jpeg"])
  99. ```
  100. Args:
  101. model: (ultralytics.yolo.engine.model.YOLO) YOLO Model of type
  102. `ultralytics.yolo.engine.model.YOLO`.
  103. epoch_logging_interval: (int) interval to log the prediction visualizations
  104. during training.
  105. max_validation_batches: (int) maximum number of validation batches to log to
  106. a table per epoch.
  107. enable_model_checkpointing: (bool) enable logging model checkpoints as
  108. artifacts at the end of eveny epoch if set to `True`.
  109. visualize_skeleton: (bool) visualize pose skeleton by drawing lines connecting
  110. keypoints for human pose.
  111. """
  112. def __init__(
  113. self,
  114. model: YOLO,
  115. epoch_logging_interval: int = 1,
  116. max_validation_batches: int = 1,
  117. enable_model_checkpointing: bool = False,
  118. visualize_skeleton: bool = False,
  119. ) -> None:
  120. self.epoch_logging_interval = epoch_logging_interval
  121. self.max_validation_batches = max_validation_batches
  122. self.enable_model_checkpointing = enable_model_checkpointing
  123. self.visualize_skeleton = visualize_skeleton
  124. self.task = model.task
  125. self.task_map = model.task_map
  126. self.model_name = (
  127. str(model.overrides["model"]).split(".")[0]
  128. if "model" in model.overrides
  129. else None
  130. )
  131. self._make_tables()
  132. self._make_predictor(model)
  133. self.supported_tasks = ["detect", "segment", "pose", "classify"]
  134. self.prompts = None
  135. self.run_id = None
  136. self.train_epoch = None
  137. def _make_tables(self):
  138. if self.task in ["detect", "segment"]:
  139. validation_columns = [
  140. "Data-Index",
  141. "Batch-Index",
  142. "Image",
  143. "Mean-Confidence",
  144. "Speed",
  145. ]
  146. train_columns = ["Epoch"] + validation_columns
  147. self.train_validation_table = wandb.Table(
  148. columns=["Model-Name"] + train_columns
  149. )
  150. self.validation_table = wandb.Table(
  151. columns=["Model-Name"] + validation_columns
  152. )
  153. self.prediction_table = wandb.Table(
  154. columns=[
  155. "Model-Name",
  156. "Image",
  157. "Num-Objects",
  158. "Mean-Confidence",
  159. "Speed",
  160. ]
  161. )
  162. elif self.task == "classify":
  163. classification_columns = [
  164. "Image",
  165. "Predicted-Category",
  166. "Prediction-Confidence",
  167. "Top-5-Prediction-Categories",
  168. "Top-5-Prediction-Confindence",
  169. "Probabilities",
  170. "Speed",
  171. ]
  172. validation_columns = ["Data-Index", "Batch-Index"] + classification_columns
  173. validation_columns.insert(3, "Ground-Truth-Category")
  174. self.train_validation_table = wandb.Table(
  175. columns=["Model-Name", "Epoch"] + validation_columns
  176. )
  177. self.validation_table = wandb.Table(
  178. columns=["Model-Name"] + validation_columns
  179. )
  180. self.prediction_table = wandb.Table(
  181. columns=["Model-Name"] + classification_columns
  182. )
  183. elif self.task == "pose":
  184. validation_columns = [
  185. "Data-Index",
  186. "Batch-Index",
  187. "Image-Ground-Truth",
  188. "Image-Prediction",
  189. "Num-Instances",
  190. "Mean-Confidence",
  191. "Speed",
  192. ]
  193. train_columns = ["Epoch"] + validation_columns
  194. self.train_validation_table = wandb.Table(
  195. columns=["Model-Name"] + train_columns
  196. )
  197. self.validation_table = wandb.Table(
  198. columns=["Model-Name"] + validation_columns
  199. )
  200. self.prediction_table = wandb.Table(
  201. columns=[
  202. "Model-Name",
  203. "Image-Prediction",
  204. "Num-Instances",
  205. "Mean-Confidence",
  206. "Speed",
  207. ]
  208. )
  209. def _make_predictor(self, model: YOLO):
  210. overrides = copy.deepcopy(model.overrides)
  211. overrides["conf"] = 0.1
  212. self.predictor = self.task_map[self.task]["predictor"](overrides=overrides)
  213. self.predictor.callbacks = {}
  214. self.predictor.args.save = False
  215. self.predictor.args.save_txt = False
  216. self.predictor.args.save_crop = False
  217. self.predictor.args.verbose = None
  218. def _save_model(self, trainer: TRAINER_TYPE):
  219. model_checkpoint_artifact = wandb.Artifact(f"run_{wandb.run.id}_model", "model")
  220. checkpoint_dict = {
  221. "epoch": trainer.epoch,
  222. "best_fitness": trainer.best_fitness,
  223. "model": copy.deepcopy(de_parallel(self.model)).half(),
  224. "ema": copy.deepcopy(trainer.ema.ema).half(),
  225. "updates": trainer.ema.updates,
  226. "optimizer": trainer.optimizer.state_dict(),
  227. "train_args": vars(trainer.args),
  228. "date": datetime.now().isoformat(),
  229. "version": __version__,
  230. }
  231. checkpoint_path = trainer.wdir / f"epoch{trainer.epoch}.pt"
  232. torch.save(checkpoint_dict, checkpoint_path, pickle_module=pickle)
  233. model_checkpoint_artifact.add_file(checkpoint_path)
  234. wandb.log_artifact(
  235. model_checkpoint_artifact, aliases=[f"epoch_{trainer.epoch}"]
  236. )
  237. def on_train_start(self, trainer: TRAINER_TYPE):
  238. with telemetry.context(run=wandb.run) as tel:
  239. tel.feature.ultralytics_yolov8 = True
  240. wandb.config.train = vars(trainer.args)
  241. self.run_id = wandb.run.id
  242. @torch.no_grad()
  243. def on_fit_epoch_end(self, trainer: DetectionTrainer):
  244. if self.task in self.supported_tasks and self.train_epoch != trainer.epoch:
  245. self.train_epoch = trainer.epoch
  246. if (self.train_epoch + 1) % self.epoch_logging_interval == 0:
  247. validator = trainer.validator
  248. dataloader = validator.dataloader
  249. class_label_map = validator.names
  250. self.device = next(trainer.model.parameters()).device
  251. if isinstance(trainer.model, torch.nn.parallel.DistributedDataParallel):
  252. model = trainer.model.module
  253. else:
  254. model = trainer.model
  255. self.model = copy.deepcopy(model).eval().to(self.device)
  256. self.predictor.setup_model(model=self.model, verbose=False)
  257. if self.task == "pose":
  258. self.train_validation_table = plot_pose_validation_results(
  259. dataloader=dataloader,
  260. class_label_map=class_label_map,
  261. model_name=self.model_name,
  262. predictor=self.predictor,
  263. visualize_skeleton=self.visualize_skeleton,
  264. table=self.train_validation_table,
  265. max_validation_batches=self.max_validation_batches,
  266. epoch=trainer.epoch,
  267. )
  268. elif self.task == "segment":
  269. self.train_validation_table = plot_segmentation_validation_results(
  270. dataloader=dataloader,
  271. class_label_map=class_label_map,
  272. model_name=self.model_name,
  273. predictor=self.predictor,
  274. table=self.train_validation_table,
  275. max_validation_batches=self.max_validation_batches,
  276. epoch=trainer.epoch,
  277. )
  278. elif self.task == "detect":
  279. self.train_validation_table = plot_detection_validation_results(
  280. dataloader=dataloader,
  281. class_label_map=class_label_map,
  282. model_name=self.model_name,
  283. predictor=self.predictor,
  284. table=self.train_validation_table,
  285. max_validation_batches=self.max_validation_batches,
  286. epoch=trainer.epoch,
  287. )
  288. elif self.task == "classify":
  289. self.train_validation_table = (
  290. plot_classification_validation_results(
  291. dataloader=dataloader,
  292. model_name=self.model_name,
  293. predictor=self.predictor,
  294. table=self.train_validation_table,
  295. max_validation_batches=self.max_validation_batches,
  296. epoch=trainer.epoch,
  297. )
  298. )
  299. if self.enable_model_checkpointing:
  300. self._save_model(trainer)
  301. trainer.model.to(self.device)
  302. def on_train_end(self, trainer: TRAINER_TYPE):
  303. if self.task in self.supported_tasks:
  304. wandb.log({"Train-Table": self.train_validation_table}, commit=False)
  305. def on_val_start(self, validator: VALIDATOR_TYPE):
  306. wandb.run or wandb.init(
  307. project=validator.args.project or "YOLOv8",
  308. job_type="validation_" + validator.args.task,
  309. )
  310. @torch.no_grad()
  311. def on_val_end(self, trainer: VALIDATOR_TYPE):
  312. if self.task in self.supported_tasks:
  313. validator = trainer
  314. dataloader = validator.dataloader
  315. class_label_map = validator.names
  316. if self.task == "pose":
  317. self.validation_table = plot_pose_validation_results(
  318. dataloader=dataloader,
  319. class_label_map=class_label_map,
  320. model_name=self.model_name,
  321. predictor=self.predictor,
  322. visualize_skeleton=self.visualize_skeleton,
  323. table=self.validation_table,
  324. max_validation_batches=self.max_validation_batches,
  325. )
  326. elif self.task == "segment":
  327. self.validation_table = plot_segmentation_validation_results(
  328. dataloader=dataloader,
  329. class_label_map=class_label_map,
  330. model_name=self.model_name,
  331. predictor=self.predictor,
  332. table=self.validation_table,
  333. max_validation_batches=self.max_validation_batches,
  334. )
  335. elif self.task == "detect":
  336. self.validation_table = plot_detection_validation_results(
  337. dataloader=dataloader,
  338. class_label_map=class_label_map,
  339. model_name=self.model_name,
  340. predictor=self.predictor,
  341. table=self.validation_table,
  342. max_validation_batches=self.max_validation_batches,
  343. )
  344. elif self.task == "classify":
  345. self.validation_table = plot_classification_validation_results(
  346. dataloader=dataloader,
  347. model_name=self.model_name,
  348. predictor=self.predictor,
  349. table=self.validation_table,
  350. max_validation_batches=self.max_validation_batches,
  351. )
  352. wandb.log({"Validation-Table": self.validation_table}, commit=False)
  353. def on_predict_start(self, predictor: PREDICTOR_TYPE):
  354. wandb.run or wandb.init(
  355. project=predictor.args.project or "YOLOv8",
  356. config=vars(predictor.args),
  357. job_type="prediction_" + predictor.args.task,
  358. )
  359. if isinstance(predictor, SAMPredictor):
  360. self.prompts = copy.deepcopy(predictor.prompts)
  361. self.prediction_table = wandb.Table(columns=["Image"])
  362. def on_predict_end(self, predictor: PREDICTOR_TYPE):
  363. wandb.config.prediction_configs = vars(predictor.args)
  364. if self.task in self.supported_tasks:
  365. for result in tqdm(predictor.results):
  366. if self.task == "pose":
  367. self.prediction_table = plot_pose_predictions(
  368. result,
  369. self.model_name,
  370. self.visualize_skeleton,
  371. self.prediction_table,
  372. )
  373. elif self.task == "segment":
  374. if isinstance(predictor, SegmentationPredictor):
  375. self.prediction_table = plot_mask_predictions(
  376. result, self.model_name, self.prediction_table
  377. )
  378. elif isinstance(predictor, SAMPredictor):
  379. self.prediction_table = plot_sam_predictions(
  380. result, self.prompts, self.prediction_table
  381. )
  382. elif self.task == "detect":
  383. self.prediction_table = plot_bbox_predictions(
  384. result, self.model_name, self.prediction_table
  385. )
  386. elif self.task == "classify":
  387. self.prediction_table = plot_classification_predictions(
  388. result, self.model_name, self.prediction_table
  389. )
  390. wandb.log({"Prediction-Table": self.prediction_table}, commit=False)
  391. @property
  392. def callbacks(self) -> dict[str, Callable]:
  393. """Property contains all the relevant callbacks to add to the YOLO model for the Weights & Biases logging."""
  394. return {
  395. "on_train_start": self.on_train_start,
  396. "on_fit_epoch_end": self.on_fit_epoch_end,
  397. "on_train_end": self.on_train_end,
  398. "on_val_start": self.on_val_start,
  399. "on_val_end": self.on_val_end,
  400. "on_predict_start": self.on_predict_start,
  401. "on_predict_end": self.on_predict_end,
  402. }
  403. # TODO: Add epoch interval
  404. def add_wandb_callback(
  405. model: YOLO,
  406. epoch_logging_interval: int = 1,
  407. enable_model_checkpointing: bool = False,
  408. enable_train_validation_logging: bool = True,
  409. enable_validation_logging: bool = True,
  410. enable_prediction_logging: bool = True,
  411. max_validation_batches: int | None = 1,
  412. visualize_skeleton: bool | None = True,
  413. ):
  414. """Function to add the `WandBUltralyticsCallback` callback to the `YOLO` model.
  415. Example:
  416. ```python
  417. from ultralytics.yolo.engine.model import YOLO
  418. from wandb.yolov8 import add_wandb_callback
  419. # initialize YOLO model
  420. model = YOLO("yolov8n.pt")
  421. # add wandb callback
  422. add_wandb_callback(
  423. model, max_validation_batches=2, enable_model_checkpointing=True
  424. )
  425. # train
  426. model.train(data="coco128.yaml", epochs=5, imgsz=640)
  427. # validate
  428. model.val()
  429. # perform inference
  430. model(["img1.jpeg", "img2.jpeg"])
  431. ```
  432. Args:
  433. model: (ultralytics.yolo.engine.model.YOLO) YOLO Model of type
  434. `ultralytics.yolo.engine.model.YOLO`.
  435. epoch_logging_interval: (int) interval to log the prediction visualizations
  436. during training.
  437. enable_model_checkpointing: (bool) enable logging model checkpoints as
  438. artifacts at the end of eveny epoch if set to `True`.
  439. enable_train_validation_logging: (bool) enable logging the predictions and
  440. ground-truths as interactive image overlays on the images from
  441. the validation dataloader to a `wandb.Table` along with
  442. mean-confidence of the predictions per-class at the end of each
  443. training epoch.
  444. enable_validation_logging: (bool) enable logging the predictions and
  445. ground-truths as interactive image overlays on the images from the
  446. validation dataloader to a `wandb.Table` along with
  447. mean-confidence of the predictions per-class at the end of
  448. validation.
  449. enable_prediction_logging: (bool) enable logging the predictions and
  450. ground-truths as interactive image overlays on the images from the
  451. validation dataloader to a `wandb.Table` along with mean-confidence
  452. of the predictions per-class at the end of each prediction.
  453. max_validation_batches: (Optional[int]) maximum number of validation batches to log to
  454. a table per epoch.
  455. visualize_skeleton: (Optional[bool]) visualize pose skeleton by drawing lines connecting
  456. keypoints for human pose.
  457. Returns:
  458. An instance of `ultralytics.yolo.engine.model.YOLO` with the `WandBUltralyticsCallback`.
  459. """
  460. if RANK in [-1, 0]:
  461. wandb_callback = WandBUltralyticsCallback(
  462. copy.deepcopy(model),
  463. epoch_logging_interval,
  464. max_validation_batches,
  465. enable_model_checkpointing,
  466. visualize_skeleton,
  467. )
  468. callbacks = wandb_callback.callbacks
  469. if not enable_train_validation_logging:
  470. _ = callbacks.pop("on_fit_epoch_end")
  471. _ = callbacks.pop("on_train_end")
  472. if not enable_validation_logging:
  473. _ = callbacks.pop("on_val_start")
  474. _ = callbacks.pop("on_val_end")
  475. if not enable_prediction_logging:
  476. _ = callbacks.pop("on_predict_start")
  477. _ = callbacks.pop("on_predict_end")
  478. for event, callback_fn in callbacks.items():
  479. model.add_callback(event, callback_fn)
  480. else:
  481. wandb.termerror(
  482. "The RANK of the process to add the callbacks was neither 0 or "
  483. "-1. No Weights & Biases callbacks were added to this instance "
  484. "of the YOLO model."
  485. )
  486. return model