| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- from __future__ import annotations
- import os
- import string
- from typing import Any, Literal
- import tensorflow as tf # type: ignore
- from tensorflow.keras import callbacks # type: ignore
- import wandb
- from wandb.sdk.lib import telemetry
- from wandb.sdk.lib.paths import StrPath
- from ..keras import patch_tf_keras
- Mode = Literal["auto", "min", "max"]
- SaveStrategy = Literal["epoch"]
- patch_tf_keras()
- class WandbModelCheckpoint(callbacks.ModelCheckpoint):
- """A checkpoint that periodically saves a Keras model or model weights.
- Saved weights are uploaded to W&B as a `wandb.Artifact`.
- Since this callback is subclassed from `tf.keras.callbacks.ModelCheckpoint`, the
- checkpointing logic is taken care of by the parent callback. You can learn more
- here: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint
- This callback is to be used in conjunction with training using `model.fit()` to save
- a model or weights (in a checkpoint file) at some interval. The model checkpoints
- will be logged as W&B Artifacts. You can learn more here:
- https://docs.wandb.ai/models/artifacts
- This callback provides the following features:
- - Save the model that has achieved "best performance" based on "monitor".
- - Save the model at the end of every epoch regardless of the performance.
- - Save the model at the end of epoch or after a fixed number of training batches.
- - Save only model weights, or save the whole model.
- - Save the model either in SavedModel format or in `.h5` format.
- Args:
- filepath: (Union[str, os.PathLike]) path to save the model file. `filepath`
- can contain named formatting options, which will be filled by the value
- of `epoch` and keys in `logs` (passed in `on_epoch_end`). For example:
- if `filepath` is `model-{epoch:02d}-{val_loss:.2f}`, then the
- model checkpoints will be saved with the epoch number and the
- validation loss in the filename.
- monitor: (str) The metric name to monitor. Default to "val_loss".
- verbose: (int) Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1
- displays messages when the callback takes an action.
- save_best_only: (bool) if `save_best_only=True`, it only saves when the model
- is considered the "best" and the latest best model according to the
- quantity monitored will not be overwritten. If `filepath` doesn't contain
- formatting options like `{epoch}` then `filepath` will be overwritten by
- each new better model locally. The model logged as an artifact will still be
- associated with the correct `monitor`. Artifacts will be uploaded
- continuously and versioned separately as a new best model is found.
- save_weights_only: (bool) if True, then only the model's weights will be saved.
- mode: (Mode) one of {'auto', 'min', 'max'}. For `val_acc`, this should be `max`,
- for `val_loss` this should be `min`, etc.
- save_freq: (Union[SaveStrategy, int]) `epoch` or integer. When using `'epoch'`,
- the callback saves the model after each epoch. When using an integer, the
- callback saves the model at end of this many batches.
- Note that when monitoring validation metrics such as `val_acc` or `val_loss`,
- save_freq must be set to "epoch" as those metrics are only available at the
- end of an epoch.
- initial_value_threshold: (Optional[float]) Floating point initial "best" value of the metric
- to be monitored.
- """
- def __init__(
- self,
- filepath: StrPath,
- monitor: str = "val_loss",
- verbose: int = 0,
- save_best_only: bool = False,
- save_weights_only: bool = False,
- mode: Mode = "auto",
- save_freq: SaveStrategy | int = "epoch",
- initial_value_threshold: float | None = None,
- **kwargs: Any,
- ) -> None:
- super().__init__(
- filepath=filepath,
- monitor=monitor,
- verbose=verbose,
- save_best_only=save_best_only,
- save_weights_only=save_weights_only,
- mode=mode,
- save_freq=save_freq,
- initial_value_threshold=initial_value_threshold,
- **kwargs,
- )
- if wandb.run is None:
- raise wandb.Error(
- "You must call `wandb.init()` before `WandbModelCheckpoint()`"
- )
- with telemetry.context(run=wandb.run) as tel:
- tel.feature.keras_model_checkpoint = True
- self.save_weights_only = save_weights_only
- # User-friendly warning when trying to save the best model.
- if self.save_best_only:
- self._check_filepath()
- self._is_old_tf_keras_version: bool | None = None
- def on_train_batch_end(
- self, batch: int, logs: dict[str, float] | None = None
- ) -> None:
- if self._should_save_on_batch(batch):
- if self.is_old_tf_keras_version:
- # Save the model and get filepath
- self._save_model(epoch=self._current_epoch, logs=logs)
- filepath = self._get_file_path(epoch=self._current_epoch, logs=logs)
- else:
- # Save the model and get filepath
- self._save_model(epoch=self._current_epoch, batch=batch, logs=logs)
- filepath = self._get_file_path(
- epoch=self._current_epoch, batch=batch, logs=logs
- )
- # Log the model as artifact
- aliases = ["latest", f"epoch_{self._current_epoch}_batch_{batch}"]
- self._log_ckpt_as_artifact(filepath, aliases=aliases)
- def on_epoch_end(self, epoch: int, logs: dict[str, float] | None = None) -> None:
- super().on_epoch_end(epoch, logs)
- # Check if model checkpoint is created at the end of epoch.
- if self.save_freq == "epoch":
- # Get filepath where the model checkpoint is saved.
- if self.is_old_tf_keras_version:
- filepath = self._get_file_path(epoch=epoch, logs=logs)
- else:
- filepath = self._get_file_path(epoch=epoch, batch=None, logs=logs)
- # Log the model as artifact
- aliases = ["latest", f"epoch_{epoch}"]
- self._log_ckpt_as_artifact(filepath, aliases=aliases)
- def _log_ckpt_as_artifact(
- self, filepath: str, aliases: list[str] | None = None
- ) -> None:
- """Log model checkpoint as W&B Artifact."""
- try:
- assert wandb.run is not None
- model_checkpoint_artifact = wandb.Artifact(
- f"run_{wandb.run.id}_model", type="model"
- )
- if os.path.isfile(filepath):
- model_checkpoint_artifact.add_file(filepath)
- elif os.path.isdir(filepath):
- model_checkpoint_artifact.add_dir(filepath)
- else:
- raise FileNotFoundError(f"No such file or directory {filepath}")
- wandb.log_artifact(model_checkpoint_artifact, aliases=aliases or [])
- except ValueError:
- # This error occurs when `save_best_only=True` and the model
- # checkpoint is not saved for that epoch/batch. Since TF/Keras
- # is giving friendly log, we can avoid clustering the stdout.
- pass
- def _check_filepath(self) -> None:
- placeholders = []
- for tup in string.Formatter().parse(self.filepath):
- if tup[1] is not None:
- placeholders.append(tup[1])
- if len(placeholders) == 0:
- wandb.termwarn(
- "When using `save_best_only`, ensure that the `filepath` argument "
- "contains formatting placeholders like `{epoch:02d}` or `{batch:02d}`. "
- "This ensures correct interpretation of the logged artifacts.",
- repeat=False,
- )
- @property
- def is_old_tf_keras_version(self) -> bool | None:
- if self._is_old_tf_keras_version is None:
- from packaging.version import parse
- try:
- if parse(tf.keras.__version__) < parse("2.6.0"):
- self._is_old_tf_keras_version = True
- else:
- self._is_old_tf_keras_version = False
- except AttributeError:
- self._is_old_tf_keras_version = False
- return self._is_old_tf_keras_version
|