yolov8.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. from __future__ import annotations
  2. from typing import Any, Callable
  3. from ultralytics.yolo.engine.model import YOLO
  4. from ultralytics.yolo.engine.trainer import BaseTrainer
  5. try:
  6. from ultralytics.yolo.utils import RANK
  7. from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
  8. except ModuleNotFoundError:
  9. from ultralytics.utils import RANK
  10. from ultralytics.utils.torch_utils import get_flops, get_num_params
  11. from ultralytics.yolo.v8.classify.train import ClassificationTrainer
  12. import wandb
  13. from wandb.sdk.lib import telemetry
  14. class WandbCallback:
  15. """An internal YOLO model wrapper that tracks metrics, and logs models to Weights & Biases.
  16. Usage:
  17. ```python
  18. from wandb.integration.yolov8.yolov8 import WandbCallback
  19. model = YOLO("yolov8n.pt")
  20. wandb_logger = WandbCallback(
  21. model,
  22. )
  23. for event, callback_fn in wandb_logger.callbacks.items():
  24. model.add_callback(event, callback_fn)
  25. ```
  26. """
  27. def __init__(
  28. self,
  29. yolo: YOLO,
  30. run_name: str | None = None,
  31. project: str | None = None,
  32. tags: list[str] | None = None,
  33. resume: str | None = None,
  34. **kwargs: Any | None,
  35. ) -> None:
  36. """A utility class to manage wandb run and various callbacks for the ultralytics YOLOv8 framework.
  37. Args:
  38. yolo: A YOLOv8 model that's inherited from `:class:ultralytics.yolo.engine.model.YOLO`
  39. run_name, str: The name of the Weights & Biases run, defaults to an auto generated run_name if `trainer.args.name` is not defined.
  40. project, str: The name of the Weights & Biases project, defaults to `"YOLOv8"` if `trainer.args.project` is not defined.
  41. tags, List[str]: A list of tags to be added to the Weights & Biases run, defaults to `["YOLOv8"]`.
  42. resume, str: Whether to resume a previous run on Weights & Biases, defaults to `None`.
  43. **kwargs: Additional arguments to be passed to `wandb.init()`.
  44. """
  45. self.yolo = yolo
  46. self.run_name = run_name
  47. self.project = project
  48. self.tags = tags
  49. self.resume = resume
  50. self.kwargs = kwargs
  51. def on_pretrain_routine_start(self, trainer: BaseTrainer) -> None:
  52. """Starts a new wandb run to track the training process and log to Weights & Biases.
  53. Args:
  54. trainer: A task trainer that's inherited from `:class:ultralytics.yolo.engine.trainer.BaseTrainer`
  55. that contains the model training and optimization routine.
  56. """
  57. if wandb.run is None:
  58. self.run = wandb.init(
  59. name=self.run_name if self.run_name else trainer.args.name,
  60. project=self.project
  61. if self.project
  62. else trainer.args.project or "YOLOv8",
  63. tags=self.tags if self.tags else ["YOLOv8"],
  64. config=vars(trainer.args),
  65. resume=self.resume if self.resume else None,
  66. **self.kwargs,
  67. )
  68. else:
  69. self.run = wandb.run
  70. assert self.run is not None
  71. self.run.define_metric("epoch", hidden=True)
  72. self.run.define_metric(
  73. "train/*", step_metric="epoch", step_sync=True, summary="min"
  74. )
  75. self.run.define_metric(
  76. "val/*", step_metric="epoch", step_sync=True, summary="min"
  77. )
  78. self.run.define_metric(
  79. "metrics/*", step_metric="epoch", step_sync=True, summary="max"
  80. )
  81. self.run.define_metric(
  82. "lr/*", step_metric="epoch", step_sync=True, summary="last"
  83. )
  84. with telemetry.context(run=wandb.run) as tel:
  85. tel.feature.ultralytics_yolov8 = True
  86. def on_pretrain_routine_end(self, trainer: BaseTrainer) -> None:
  87. assert self.run is not None
  88. self.run.summary.update(
  89. {
  90. "model/parameters": get_num_params(trainer.model),
  91. "model/GFLOPs": round(get_flops(trainer.model), 3),
  92. }
  93. )
  94. def on_train_epoch_start(self, trainer: BaseTrainer) -> None:
  95. """On train epoch start we only log epoch number to the Weights & Biases run."""
  96. # We log the epoch number here to commit the previous step,
  97. assert self.run is not None
  98. self.run.log({"epoch": trainer.epoch + 1})
  99. def on_train_epoch_end(self, trainer: BaseTrainer) -> None:
  100. """On train epoch end we log all the metrics to the Weights & Biases run."""
  101. assert self.run is not None
  102. self.run.log(
  103. {
  104. **trainer.metrics,
  105. **trainer.label_loss_items(trainer.tloss, prefix="train"),
  106. **trainer.lr,
  107. },
  108. )
  109. # Currently only the detection and segmentation trainers save images to the save_dir
  110. if not isinstance(trainer, ClassificationTrainer):
  111. self.run.log(
  112. {
  113. "train_batch_images": [
  114. wandb.Image(str(image_path), caption=image_path.stem)
  115. for image_path in trainer.save_dir.glob("train_batch*.jpg")
  116. ]
  117. }
  118. )
  119. def on_fit_epoch_end(self, trainer: BaseTrainer) -> None:
  120. """On fit epoch end we log all the best metrics and model detail to Weights & Biases run summary."""
  121. assert self.run is not None
  122. if trainer.epoch == 0:
  123. speeds = [
  124. trainer.validator.speed.get(
  125. key,
  126. )
  127. for key in (1, "inference")
  128. ]
  129. speed = speeds[0] if speeds[0] else speeds[1]
  130. if speed:
  131. self.run.summary.update(
  132. {
  133. "model/speed(ms/img)": round(speed, 3),
  134. }
  135. )
  136. if trainer.best_fitness == trainer.fitness:
  137. self.run.summary.update(
  138. {
  139. "best/epoch": trainer.epoch + 1,
  140. **{f"best/{key}": val for key, val in trainer.metrics.items()},
  141. }
  142. )
  143. def on_train_end(self, trainer: BaseTrainer) -> None:
  144. """On train end we log all the media, including plots, images and best model artifact to Weights & Biases."""
  145. # Currently only the detection and segmentation trainers save images to the save_dir
  146. assert self.run is not None
  147. if not isinstance(trainer, ClassificationTrainer):
  148. assert self.run is not None
  149. self.run.log(
  150. {
  151. "plots": [
  152. wandb.Image(str(image_path), caption=image_path.stem)
  153. for image_path in trainer.save_dir.glob("*.png")
  154. ],
  155. "val_images": [
  156. wandb.Image(str(image_path), caption=image_path.stem)
  157. for image_path in trainer.validator.save_dir.glob("val*.jpg")
  158. ],
  159. },
  160. )
  161. if trainer.best.exists():
  162. assert self.run is not None
  163. self.run.log_artifact(
  164. str(trainer.best),
  165. type="model",
  166. name=f"{self.run.name}_{trainer.args.task}.pt",
  167. aliases=["best", f"epoch_{trainer.epoch + 1}"],
  168. )
  169. def on_model_save(self, trainer: BaseTrainer) -> None:
  170. """On model save we log the model as an artifact to Weights & Biases."""
  171. assert self.run is not None
  172. self.run.log_artifact(
  173. str(trainer.last),
  174. type="model",
  175. name=f"{self.run.name}_{trainer.args.task}.pt",
  176. aliases=["last", f"epoch_{trainer.epoch + 1}"],
  177. )
  178. def teardown(self, _trainer: BaseTrainer) -> None:
  179. """On teardown, we finish the Weights & Biases run and set it to None."""
  180. assert self.run is not None
  181. self.run.finish()
  182. self.run = None
  183. @property
  184. def callbacks(
  185. self,
  186. ) -> dict[str, Callable]:
  187. """Property contains all the relevant callbacks to add to the YOLO model for the Weights & Biases logging."""
  188. return {
  189. "on_pretrain_routine_start": self.on_pretrain_routine_start,
  190. "on_pretrain_routine_end": self.on_pretrain_routine_end,
  191. "on_train_epoch_start": self.on_train_epoch_start,
  192. "on_train_epoch_end": self.on_train_epoch_end,
  193. "on_fit_epoch_end": self.on_fit_epoch_end,
  194. "on_train_end": self.on_train_end,
  195. "on_model_save": self.on_model_save,
  196. "teardown": self.teardown,
  197. }
  198. def add_callbacks(
  199. yolo: YOLO,
  200. run_name: str | None = None,
  201. project: str | None = None,
  202. tags: list[str] | None = None,
  203. resume: str | None = None,
  204. **kwargs: Any | None,
  205. ) -> YOLO:
  206. """A YOLO model wrapper that tracks metrics, and logs models to Weights & Biases.
  207. Args:
  208. yolo: A YOLOv8 model that's inherited from `:class:ultralytics.yolo.engine.model.YOLO`
  209. run_name, str: The name of the Weights & Biases run, defaults to an auto generated name if `trainer.args.name` is not defined.
  210. project, str: The name of the Weights & Biases project, defaults to `"YOLOv8"` if `trainer.args.project` is not defined.
  211. tags, List[str]: A list of tags to be added to the Weights & Biases run, defaults to `["YOLOv8"]`.
  212. resume, str: Whether to resume a previous run on Weights & Biases, defaults to `None`.
  213. **kwargs: Additional arguments to be passed to `wandb.init()`.
  214. Usage:
  215. ```python
  216. from wandb.integration.yolov8 import add_callbacks as add_wandb_callbacks
  217. model = YOLO("yolov8n.pt")
  218. add_wandb_callbacks(
  219. model,
  220. )
  221. model.train(
  222. data="coco128.yaml",
  223. epochs=3,
  224. imgsz=640,
  225. )
  226. ```
  227. """
  228. wandb.termwarn(
  229. """The wandb callback is currently in beta and is subject to change based on updates to `ultralytics yolov8`.
  230. The callback is tested and supported for ultralytics v8.0.43 and above.
  231. Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.
  232. """,
  233. repeat=False,
  234. )
  235. wandb.termwarn(
  236. """This wandb callback is no longer functional and would be deprecated in the near future.
  237. We recommend you to use the updated callback using `from wandb.integration.ultralytics import add_wandb_callback`.
  238. The updated callback is tested and supported for ultralytics 8.0.167 and above.
  239. You can refer to https://docs.wandb.ai/models/integrations/ultralytics for the updated documentation.
  240. Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.
  241. """,
  242. repeat=False,
  243. )
  244. if RANK in [-1, 0]:
  245. wandb_logger = WandbCallback(
  246. yolo, run_name=run_name, project=project, tags=tags, resume=resume, **kwargs
  247. )
  248. for event, callback_fn in wandb_logger.callbacks.items():
  249. yolo.add_callback(event, callback_fn)
  250. return yolo
  251. else:
  252. wandb.termerror(
  253. "The RANK of the process to add the callbacks was neither 0 or -1."
  254. "No Weights & Biases callbacks were added to this instance of the YOLO model."
  255. )
  256. return yolo