__init__.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. """Hooks that add fast.ai v1 Learners to Weights & Biases through a callback.
  2. Requested logged data can be configured through the callback constructor.
  3. Examples:
  4. WandbCallback can be used when initializing the Learner::
  5. ```
  6. from wandb.fastai import WandbCallback
  7. [...]
  8. learn = Learner(data, ..., callback_fns=WandbCallback)
  9. learn.fit(epochs)
  10. ```
  11. Custom parameters can be given using functools.partial::
  12. ```
  13. from wandb.fastai import WandbCallback
  14. from functools import partial
  15. [...]
  16. learn = Learner(data, ..., callback_fns=partial(WandbCallback, ...))
  17. learn.fit(epochs)
  18. ```
  19. Finally, it is possible to use WandbCallback only when starting
  20. training. In this case it must be instantiated::
  21. ```
  22. learn.fit(..., callbacks=WandbCallback(learn))
  23. ```
  24. or, with custom parameters::
  25. ```
  26. learn.fit(..., callbacks=WandbCallback(learn, ...))
  27. ```
  28. """
  29. from __future__ import annotations
  30. import random
  31. from pathlib import Path
  32. from typing import Any, Literal
  33. import fastai
  34. from fastai.callbacks import TrackerCallback
  35. import wandb
  36. from wandb.sdk.lib import ipython
# matplotlib is an optional dependency; it is only used here to render sample
# predictions through `plt`.
try:
    import matplotlib

    # Outside of Jupyter, select the backend before pyplot gets imported below.
    if not ipython.in_jupyter():
        matplotlib.use("Agg")  # non-interactive backend (avoid tkinter issues)
    import matplotlib.pyplot as plt
except ImportError:
    # Only warn: the module stays importable, but note that `plt` is left
    # undefined in this case.
    wandb.termwarn("matplotlib required if logging sample image predictions")
  44. class WandbCallback(TrackerCallback):
  45. """Callback for saving model topology, losses & metrics.
  46. Optionally logs weights, gradients, sample predictions and best trained model.
  47. Args:
  48. learn (fastai.basic_train.Learner): the fast.ai learner to hook.
  49. log (str): "gradients", "parameters", "all", or None. Losses & metrics are always logged.
  50. save_model (bool): save model at the end of each epoch. It will also load best model at the end of training.
  51. monitor (str): metric to monitor for saving best model. None uses default TrackerCallback monitor value.
  52. mode (str): "auto", "min" or "max" to compare "monitor" values and define best model.
  53. input_type (str): "images" or None. Used to display sample predictions.
  54. validation_data (list): data used for sample predictions if input_type is set.
  55. predictions (int): number of predictions to make if input_type is set and validation_data is None.
  56. seed (int): initialize random generator for sample predictions if input_type is set and validation_data is None.
  57. """
  58. # Record if watch has been called previously (even in another instance)
  59. _watch_called = False
  60. def __init__(
  61. self,
  62. learn: fastai.basic_train.Learner,
  63. log: Literal["gradients", "parameters", "all"] | None = "gradients",
  64. save_model: bool = True,
  65. monitor: str | None = None,
  66. mode: Literal["auto", "min", "max"] = "auto",
  67. input_type: Literal["images"] | None = None,
  68. validation_data: list | None = None,
  69. predictions: int = 36,
  70. seed: int = 12345,
  71. ) -> None:
  72. # Check if wandb.init has been called
  73. if wandb.run is None:
  74. raise ValueError("You must call wandb.init() before WandbCallback()")
  75. # Adapted from fast.ai "SaveModelCallback"
  76. if monitor is None:
  77. # use default TrackerCallback monitor value
  78. super().__init__(learn, mode=mode)
  79. else:
  80. super().__init__(learn, monitor=monitor, mode=mode)
  81. self.save_model = save_model
  82. self.model_path = Path(wandb.run.dir) / "bestmodel.pth"
  83. self.log = log
  84. self.input_type = input_type
  85. self.best = None
  86. # Select items for sample predictions to see evolution along training
  87. self.validation_data = validation_data
  88. if input_type and not self.validation_data:
  89. wandb_random = random.Random(seed) # For repeatability
  90. predictions = min(predictions, len(learn.data.valid_ds))
  91. indices = wandb_random.sample(range(len(learn.data.valid_ds)), predictions)
  92. self.validation_data = [learn.data.valid_ds[i] for i in indices]
  93. def on_train_begin(self, **kwargs: Any) -> None:
  94. """Call watch method to log model topology, gradients & weights."""
  95. # Set self.best, method inherited from "TrackerCallback" by "SaveModelCallback"
  96. super().on_train_begin()
  97. # Ensure we don't call "watch" multiple times
  98. if not WandbCallback._watch_called:
  99. WandbCallback._watch_called = True
  100. # Logs model topology and optionally gradients and weights
  101. wandb.watch(self.learn.model, log=self.log)
  102. def on_epoch_end(
  103. self, epoch: int, smooth_loss: float, last_metrics: list, **kwargs: Any
  104. ) -> None:
  105. """Log training loss, validation loss and custom metrics & log prediction samples & save model."""
  106. if self.save_model:
  107. # Adapted from fast.ai "SaveModelCallback"
  108. current = self.get_monitor_value()
  109. if current is not None and self.operator(current, self.best):
  110. wandb.termlog(
  111. f"Better model found at epoch {epoch} with {self.monitor} value: {current}."
  112. )
  113. self.best = current
  114. # Save within wandb folder
  115. with self.model_path.open("wb") as model_file:
  116. self.learn.save(model_file)
  117. # Log sample predictions if learn.predict is available
  118. if self.validation_data:
  119. try:
  120. self._wandb_log_predictions()
  121. except FastaiError as e:
  122. wandb.termwarn(e.message)
  123. self.validation_data = None # prevent from trying again on next loop
  124. except Exception as e:
  125. wandb.termwarn(f"Unable to log prediction samples.\n{e}")
  126. self.validation_data = None # prevent from trying again on next loop
  127. # Log losses & metrics
  128. # Adapted from fast.ai "CSVLogger"
  129. logs = {
  130. name: stat
  131. for name, stat in list(
  132. zip(self.learn.recorder.names, [epoch, smooth_loss] + last_metrics)
  133. )
  134. }
  135. wandb.log(logs)
  136. def on_train_end(self, **kwargs: Any) -> None:
  137. """Load the best model."""
  138. if self.save_model and self.model_path.is_file():
  139. # Adapted from fast.ai "SaveModelCallback"
  140. with self.model_path.open("rb") as model_file:
  141. self.learn.load(model_file, purge=False)
  142. wandb.termlog(f"Loaded best saved model from {self.model_path}")
  143. def _wandb_log_predictions(self) -> None:
  144. """Log prediction samples."""
  145. pred_log = []
  146. if self.validation_data is None:
  147. return
  148. for x, y in self.validation_data:
  149. try:
  150. pred = self.learn.predict(x)
  151. except Exception:
  152. raise FastaiError(
  153. 'Unable to run "predict" method from Learner to log prediction samples.'
  154. )
  155. # scalar -> likely to be a category
  156. # tensor of dim 1 -> likely to be multicategory
  157. if not pred[1].shape or pred[1].dim() == 1:
  158. pred_log.append(
  159. wandb.Image(
  160. x.data,
  161. caption=f"Ground Truth: {y}\nPrediction: {pred[0]}",
  162. )
  163. )
  164. # most vision datasets have a "show" function we can use
  165. elif hasattr(x, "show"):
  166. # log input data
  167. pred_log.append(wandb.Image(x.data, caption="Input data", grouping=3))
  168. # log label and prediction
  169. for im, capt in ((pred[0], "Prediction"), (y, "Ground Truth")):
  170. # Resize plot to image resolution
  171. # from https://stackoverflow.com/a/13714915
  172. my_dpi = 100
  173. fig = plt.figure(frameon=False, dpi=my_dpi)
  174. h, w = x.size
  175. fig.set_size_inches(w / my_dpi, h / my_dpi)
  176. ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
  177. ax.set_axis_off()
  178. fig.add_axes(ax)
  179. # Superpose label or prediction to input image
  180. x.show(ax=ax, y=im)
  181. pred_log.append(wandb.Image(fig, caption=capt))
  182. plt.close(fig)
  183. # likely to be an image
  184. elif hasattr(y, "shape") and (
  185. (len(y.shape) == 2) or (len(y.shape) == 3 and y.shape[0] in [1, 3, 4])
  186. ):
  187. pred_log.extend(
  188. [
  189. wandb.Image(x.data, caption="Input data", grouping=3),
  190. wandb.Image(pred[0].data, caption="Prediction"),
  191. wandb.Image(y.data, caption="Ground Truth"),
  192. ]
  193. )
  194. # we just log input data
  195. else:
  196. pred_log.append(wandb.Image(x.data, caption="Input data"))
  197. wandb.log({"Prediction Samples": pred_log}, commit=False)
  198. class FastaiError(wandb.Error):
  199. pass