| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243 |
- """Hooks that add fast.ai v1 Learners to Weights & Biases through a callback.
- Requested logged data can be configured through the callback constructor.
- Examples:
- WandbCallback can be used when initializing the Learner::
- ```
- from wandb.fastai import WandbCallback
- [...]
- learn = Learner(data, ..., callback_fns=WandbCallback)
- learn.fit(epochs)
- ```
- Custom parameters can be given using functools.partial::
- ```
- from wandb.fastai import WandbCallback
- from functools import partial
- [...]
- learn = Learner(data, ..., callback_fns=partial(WandbCallback, ...))
- learn.fit(epochs)
- ```
- Finally, it is possible to use WandbCallback only when starting
- training. In this case it must be instantiated::
- ```
- learn.fit(..., callbacks=WandbCallback(learn))
- ```
- or, with custom parameters::
- ```
- learn.fit(..., callbacks=WandbCallback(learn, ...))
- ```
- """
- from __future__ import annotations
- import random
- from pathlib import Path
- from typing import Any, Literal
- import fastai
- from fastai.callbacks import TrackerCallback
- import wandb
- from wandb.sdk.lib import ipython
- try:
- import matplotlib
- if not ipython.in_jupyter():
- matplotlib.use("Agg") # non-interactive backend (avoid tkinter issues)
- import matplotlib.pyplot as plt
- except ImportError:
- wandb.termwarn("matplotlib required if logging sample image predictions")
class WandbCallback(TrackerCallback):
    """Callback for saving model topology, losses & metrics.

    Optionally logs weights, gradients, sample predictions and best trained model.

    Args:
        learn (fastai.basic_train.Learner): the fast.ai learner to hook.
        log (str): "gradients", "parameters", "all", or None. Losses & metrics are always logged.
        save_model (bool): save model at the end of each epoch. It will also load best model at the end of training.
        monitor (str): metric to monitor for saving best model. None uses default TrackerCallback monitor value.
        mode (str): "auto", "min" or "max" to compare "monitor" values and define best model.
        input_type (str): "images" or None. Used to display sample predictions.
        validation_data (list): data used for sample predictions if input_type is set.
        predictions (int): number of predictions to make if input_type is set and validation_data is None.
        seed (int): initialize random generator for sample predictions if input_type is set and validation_data is None.

    Raises:
        ValueError: if the callback is created while ``wandb.run`` is None.
    """

    # Record if watch has been called previously (even in another instance)
    _watch_called = False

    def __init__(
        self,
        learn: fastai.basic_train.Learner,
        log: Literal["gradients", "parameters", "all"] | None = "gradients",
        save_model: bool = True,
        monitor: str | None = None,
        mode: Literal["auto", "min", "max"] = "auto",
        input_type: Literal["images"] | None = None,
        validation_data: list | None = None,
        predictions: int = 36,
        seed: int = 12345,
    ) -> None:
        # Check if wandb.init has been called
        if wandb.run is None:
            raise ValueError("You must call wandb.init() before WandbCallback()")

        # Adapted from fast.ai "SaveModelCallback"
        if monitor is None:
            # use default TrackerCallback monitor value
            super().__init__(learn, mode=mode)
        else:
            super().__init__(learn, monitor=monitor, mode=mode)

        self.save_model = save_model
        # The best model is stored inside the current run directory
        self.model_path = Path(wandb.run.dir) / "bestmodel.pth"

        self.log = log
        self.input_type = input_type
        self.best = None

        # Select items for sample predictions to see evolution along training
        self.validation_data = validation_data
        if input_type and not self.validation_data:
            wandb_random = random.Random(seed)  # For repeatability
            valid_ds = learn.data.valid_ds
            # Never request more samples than the validation set contains
            predictions = min(predictions, len(valid_ds))
            indices = wandb_random.sample(range(len(valid_ds)), predictions)
            self.validation_data = [valid_ds[i] for i in indices]

    def on_train_begin(self, **kwargs: Any) -> None:
        """Call watch method to log model topology, gradients & weights."""
        # Set self.best, method inherited from "TrackerCallback" by "SaveModelCallback"
        super().on_train_begin()

        # Ensure we don't call "watch" multiple times
        if not WandbCallback._watch_called:
            WandbCallback._watch_called = True

            # Logs model topology and optionally gradients and weights
            wandb.watch(self.learn.model, log=self.log)

    def on_epoch_end(
        self, epoch: int, smooth_loss: float, last_metrics: list, **kwargs: Any
    ) -> None:
        """Log training loss, validation loss and custom metrics & log prediction samples & save model.

        Args:
            epoch: epoch number, logged as the first statistic.
            smooth_loss: smoothed training loss, logged as the second statistic.
            last_metrics: remaining statistics (validation loss and metrics).
                ``None`` is treated as an empty list.
            **kwargs: additional callback state, unused here.
        """
        if self.save_model:
            # Adapted from fast.ai "SaveModelCallback"
            current = self.get_monitor_value()
            if current is not None and self.operator(current, self.best):
                wandb.termlog(
                    f"Better model found at epoch {epoch} with {self.monitor} value: {current}."
                )
                self.best = current

                # Save within wandb folder
                with self.model_path.open("wb") as model_file:
                    self.learn.save(model_file)

        # Log sample predictions if learn.predict is available
        if self.validation_data:
            try:
                self._wandb_log_predictions()
            except FastaiError as e:
                wandb.termwarn(e.message)
                self.validation_data = None  # prevent from trying again on next loop
            except Exception as e:
                wandb.termwarn(f"Unable to log prediction samples.\n{e}")
                self.validation_data = None  # prevent from trying again on next loop

        # Log losses & metrics
        # Adapted from fast.ai "CSVLogger"; guard against a missing metrics list
        # (and non-list sequences) so the concatenation below cannot raise.
        metrics = [] if last_metrics is None else list(last_metrics)
        stats = [epoch, smooth_loss, *metrics]
        wandb.log(dict(zip(self.learn.recorder.names, stats)))

    def on_train_end(self, **kwargs: Any) -> None:
        """Load the best model."""
        if self.save_model and self.model_path.is_file():
            # Adapted from fast.ai "SaveModelCallback"
            with self.model_path.open("rb") as model_file:
                self.learn.load(model_file, purge=False)
                wandb.termlog(f"Loaded best saved model from {self.model_path}")

    def _wandb_log_predictions(self) -> None:
        """Log prediction samples.

        Runs ``self.learn.predict`` on every ``(x, y)`` pair of
        ``self.validation_data`` and logs the resulting images under the
        "Prediction Samples" key without committing the wandb step.

        Raises:
            FastaiError: if ``self.learn.predict`` fails on a sample.
        """
        if self.validation_data is None:
            return

        pred_log = []
        for x, y in self.validation_data:
            try:
                pred = self.learn.predict(x)
            except Exception as e:
                # Chain the original error so the root cause stays visible
                raise FastaiError(
                    'Unable to run "predict" method from Learner to log prediction samples.'
                ) from e

            # scalar -> likely to be a category
            # tensor of dim 1 -> likely to be multicategory
            if not pred[1].shape or pred[1].dim() == 1:
                pred_log.append(
                    wandb.Image(
                        x.data,
                        caption=f"Ground Truth: {y}\nPrediction: {pred[0]}",
                    )
                )

            # most vision datasets have a "show" function we can use
            elif hasattr(x, "show"):
                # log input data
                pred_log.append(wandb.Image(x.data, caption="Input data", grouping=3))

                # Resize plot to image resolution
                # from https://stackoverflow.com/a/13714915
                my_dpi = 100
                h, w = x.size

                # log label and prediction
                for im, capt in ((pred[0], "Prediction"), (y, "Ground Truth")):
                    fig = plt.figure(frameon=False, dpi=my_dpi)
                    try:
                        fig.set_size_inches(w / my_dpi, h / my_dpi)
                        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
                        ax.set_axis_off()
                        fig.add_axes(ax)

                        # Superpose label or prediction to input image
                        x.show(ax=ax, y=im)
                        pred_log.append(wandb.Image(fig, caption=capt))
                    finally:
                        # Always release the figure, even if drawing or logging fails
                        plt.close(fig)

            # likely to be an image
            elif hasattr(y, "shape") and (
                (len(y.shape) == 2) or (len(y.shape) == 3 and y.shape[0] in [1, 3, 4])
            ):
                pred_log.extend(
                    [
                        wandb.Image(x.data, caption="Input data", grouping=3),
                        wandb.Image(pred[0].data, caption="Prediction"),
                        wandb.Image(y.data, caption="Ground Truth"),
                    ]
                )

            # we just log input data
            else:
                pred_log.append(wandb.Image(x.data, caption="Input data"))

        wandb.log({"Prediction Samples": pred_log}, commit=False)
class FastaiError(wandb.Error):
    """Error raised when the Learner's "predict" method cannot be run to log prediction samples."""
|