model_checkpoint.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. from __future__ import annotations
  2. import os
  3. import string
  4. from typing import Any, Literal
  5. import tensorflow as tf # type: ignore
  6. from tensorflow.keras import callbacks # type: ignore
  7. import wandb
  8. from wandb.sdk.lib import telemetry
  9. from wandb.sdk.lib.paths import StrPath
  10. from ..keras import patch_tf_keras
# Accepted values for the `mode` argument of `WandbModelCheckpoint`.
Mode = Literal["auto", "min", "max"]
# The only string accepted for `save_freq`; an integer batch count is the
# alternative (see the `save_freq: SaveStrategy | int` parameter).
SaveStrategy = Literal["epoch"]
# Runs once at import time, before the callback class is defined.
# NOTE(review): presumably patches TF/Keras for W&B integration — confirm
# against the `..keras` package, which is not visible from this file.
patch_tf_keras()
  14. class WandbModelCheckpoint(callbacks.ModelCheckpoint):
  15. """A checkpoint that periodically saves a Keras model or model weights.
  16. Saved weights are uploaded to W&B as a `wandb.Artifact`.
  17. Since this callback is subclassed from `tf.keras.callbacks.ModelCheckpoint`, the
  18. checkpointing logic is taken care of by the parent callback. You can learn more
  19. here: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint
  20. This callback is to be used in conjunction with training using `model.fit()` to save
  21. a model or weights (in a checkpoint file) at some interval. The model checkpoints
  22. will be logged as W&B Artifacts. You can learn more here:
  23. https://docs.wandb.ai/models/artifacts
  24. This callback provides the following features:
  25. - Save the model that has achieved "best performance" based on "monitor".
  26. - Save the model at the end of every epoch regardless of the performance.
  27. - Save the model at the end of epoch or after a fixed number of training batches.
  28. - Save only model weights, or save the whole model.
  29. - Save the model either in SavedModel format or in `.h5` format.
  30. Args:
  31. filepath: (Union[str, os.PathLike]) path to save the model file. `filepath`
  32. can contain named formatting options, which will be filled by the value
  33. of `epoch` and keys in `logs` (passed in `on_epoch_end`). For example:
  34. if `filepath` is `model-{epoch:02d}-{val_loss:.2f}`, then the
  35. model checkpoints will be saved with the epoch number and the
  36. validation loss in the filename.
  37. monitor: (str) The metric name to monitor. Default to "val_loss".
  38. verbose: (int) Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1
  39. displays messages when the callback takes an action.
  40. save_best_only: (bool) if `save_best_only=True`, it only saves when the model
  41. is considered the "best" and the latest best model according to the
  42. quantity monitored will not be overwritten. If `filepath` doesn't contain
  43. formatting options like `{epoch}` then `filepath` will be overwritten by
  44. each new better model locally. The model logged as an artifact will still be
  45. associated with the correct `monitor`. Artifacts will be uploaded
  46. continuously and versioned separately as a new best model is found.
  47. save_weights_only: (bool) if True, then only the model's weights will be saved.
  48. mode: (Mode) one of {'auto', 'min', 'max'}. For `val_acc`, this should be `max`,
  49. for `val_loss` this should be `min`, etc.
  50. save_freq: (Union[SaveStrategy, int]) `epoch` or integer. When using `'epoch'`,
  51. the callback saves the model after each epoch. When using an integer, the
  52. callback saves the model at end of this many batches.
  53. Note that when monitoring validation metrics such as `val_acc` or `val_loss`,
  54. save_freq must be set to "epoch" as those metrics are only available at the
  55. end of an epoch.
  56. initial_value_threshold: (Optional[float]) Floating point initial "best" value of the metric
  57. to be monitored.
  58. """
  59. def __init__(
  60. self,
  61. filepath: StrPath,
  62. monitor: str = "val_loss",
  63. verbose: int = 0,
  64. save_best_only: bool = False,
  65. save_weights_only: bool = False,
  66. mode: Mode = "auto",
  67. save_freq: SaveStrategy | int = "epoch",
  68. initial_value_threshold: float | None = None,
  69. **kwargs: Any,
  70. ) -> None:
  71. super().__init__(
  72. filepath=filepath,
  73. monitor=monitor,
  74. verbose=verbose,
  75. save_best_only=save_best_only,
  76. save_weights_only=save_weights_only,
  77. mode=mode,
  78. save_freq=save_freq,
  79. initial_value_threshold=initial_value_threshold,
  80. **kwargs,
  81. )
  82. if wandb.run is None:
  83. raise wandb.Error(
  84. "You must call `wandb.init()` before `WandbModelCheckpoint()`"
  85. )
  86. with telemetry.context(run=wandb.run) as tel:
  87. tel.feature.keras_model_checkpoint = True
  88. self.save_weights_only = save_weights_only
  89. # User-friendly warning when trying to save the best model.
  90. if self.save_best_only:
  91. self._check_filepath()
  92. self._is_old_tf_keras_version: bool | None = None
  93. def on_train_batch_end(
  94. self, batch: int, logs: dict[str, float] | None = None
  95. ) -> None:
  96. if self._should_save_on_batch(batch):
  97. if self.is_old_tf_keras_version:
  98. # Save the model and get filepath
  99. self._save_model(epoch=self._current_epoch, logs=logs)
  100. filepath = self._get_file_path(epoch=self._current_epoch, logs=logs)
  101. else:
  102. # Save the model and get filepath
  103. self._save_model(epoch=self._current_epoch, batch=batch, logs=logs)
  104. filepath = self._get_file_path(
  105. epoch=self._current_epoch, batch=batch, logs=logs
  106. )
  107. # Log the model as artifact
  108. aliases = ["latest", f"epoch_{self._current_epoch}_batch_{batch}"]
  109. self._log_ckpt_as_artifact(filepath, aliases=aliases)
  110. def on_epoch_end(self, epoch: int, logs: dict[str, float] | None = None) -> None:
  111. super().on_epoch_end(epoch, logs)
  112. # Check if model checkpoint is created at the end of epoch.
  113. if self.save_freq == "epoch":
  114. # Get filepath where the model checkpoint is saved.
  115. if self.is_old_tf_keras_version:
  116. filepath = self._get_file_path(epoch=epoch, logs=logs)
  117. else:
  118. filepath = self._get_file_path(epoch=epoch, batch=None, logs=logs)
  119. # Log the model as artifact
  120. aliases = ["latest", f"epoch_{epoch}"]
  121. self._log_ckpt_as_artifact(filepath, aliases=aliases)
  122. def _log_ckpt_as_artifact(
  123. self, filepath: str, aliases: list[str] | None = None
  124. ) -> None:
  125. """Log model checkpoint as W&B Artifact."""
  126. try:
  127. assert wandb.run is not None
  128. model_checkpoint_artifact = wandb.Artifact(
  129. f"run_{wandb.run.id}_model", type="model"
  130. )
  131. if os.path.isfile(filepath):
  132. model_checkpoint_artifact.add_file(filepath)
  133. elif os.path.isdir(filepath):
  134. model_checkpoint_artifact.add_dir(filepath)
  135. else:
  136. raise FileNotFoundError(f"No such file or directory {filepath}")
  137. wandb.log_artifact(model_checkpoint_artifact, aliases=aliases or [])
  138. except ValueError:
  139. # This error occurs when `save_best_only=True` and the model
  140. # checkpoint is not saved for that epoch/batch. Since TF/Keras
  141. # is giving friendly log, we can avoid clustering the stdout.
  142. pass
  143. def _check_filepath(self) -> None:
  144. placeholders = []
  145. for tup in string.Formatter().parse(self.filepath):
  146. if tup[1] is not None:
  147. placeholders.append(tup[1])
  148. if len(placeholders) == 0:
  149. wandb.termwarn(
  150. "When using `save_best_only`, ensure that the `filepath` argument "
  151. "contains formatting placeholders like `{epoch:02d}` or `{batch:02d}`. "
  152. "This ensures correct interpretation of the logged artifacts.",
  153. repeat=False,
  154. )
  155. @property
  156. def is_old_tf_keras_version(self) -> bool | None:
  157. if self._is_old_tf_keras_version is None:
  158. from packaging.version import parse
  159. try:
  160. if parse(tf.keras.__version__) < parse("2.6.0"):
  161. self._is_old_tf_keras_version = True
  162. else:
  163. self._is_old_tf_keras_version = False
  164. except AttributeError:
  165. self._is_old_tf_keras_version = False
  166. return self._is_old_tf_keras_version