| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131 |
- from __future__ import annotations
- from typing import Any, Literal
- import tensorflow as tf # type: ignore
- from tensorflow.keras import callbacks
- import wandb
- from wandb.integration.keras.keras import patch_tf_keras
- from wandb.sdk.lib import telemetry
# String values accepted for `WandbMetricsLogger(log_freq=...)`; an integer
# batch interval is also accepted there (see the `log_freq` annotation).
LogStrategy = Literal["epoch", "batch"]

# Executed once, as a side effect of importing this module.
# NOTE(review): presumably patches `tf.keras` for use with the wandb Keras
# integration -- confirm against `wandb.integration.keras.keras`.
patch_tf_keras()
class WandbMetricsLogger(callbacks.Callback):
    """Logger that sends system metrics to W&B.

    `WandbMetricsLogger` automatically logs the `logs` dictionary that callback
    methods take as argument to wandb.

    This callback automatically logs the following to a W&B run page:

    * system (CPU/GPU/TPU) metrics,
    * train and validation metrics defined in `model.compile`,
    * learning rate (both for a fixed value or a learning rate scheduler)

    Notes:
        If you resume training by passing `initial_epoch` to `model.fit` and
        you are using a learning rate scheduler, make sure to pass
        `initial_global_step` to `WandbMetricsLogger`. The
        `initial_global_step` is `step_size * initial_epoch`, where `step_size`
        is the number of training steps (batches) per epoch, typically the
        number of training samples divided by the batch size.

    Args:
        log_freq: ("epoch", "batch", or int) if "epoch", logs metrics
            at the end of each epoch. If "batch", logs metrics at the end
            of each batch. If an integer, logs metrics at the end of that
            many batches. Defaults to "epoch".
        initial_global_step: (int) Use this argument to correctly log the
            learning rate when you resume training from some `initial_epoch`,
            and a learning rate scheduler is used. This can be computed as
            `step_size * initial_epoch`. Defaults to 0.

    Raises:
        wandb.Error: If no W&B run is active when the callback is created.
        ValueError: If `log_freq` is an integer smaller than 1.
    """

    def __init__(
        self,
        log_freq: LogStrategy | int = "epoch",
        initial_global_step: int = 0,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)

        if wandb.run is None:
            raise wandb.Error(
                "You must call `wandb.init()` before WandbMetricsLogger()"
            )

        with telemetry.context(run=wandb.run) as tel:
            tel.feature.keras_metrics_logger = True

        # "batch" is shorthand for "log after every single batch".
        if log_freq == "batch":
            log_freq = 1

        # Any integer selects batch-wise logging; every other value (including
        # the default "epoch") selects epoch-wise logging only.
        self.logging_batch_wise = isinstance(log_freq, int)

        # A zero interval would raise ZeroDivisionError in `on_batch_end`
        # (`batch % log_freq`) and a negative one would make the batch step
        # axis run backwards, so reject both up front.
        if self.logging_batch_wise and log_freq < 1:
            raise ValueError(
                f"`log_freq` must be a positive integer, got {log_freq!r}"
            )

        self.log_freq: Any = log_freq if self.logging_batch_wise else None
        # x-axis value for batch-wise metrics; advanced by `log_freq` each
        # time a batch is logged.
        self.global_batch = 0
        # Number of batches seen so far (offset by `initial_global_step`);
        # passed as `step` when the learning rate is a callable.
        self.global_step = initial_global_step

        if self.logging_batch_wise:
            # define custom x-axis for batch logging.
            wandb.define_metric("batch/batch_step")
            # set all batch metrics to be logged against batch_step.
            wandb.define_metric("batch/*", step_metric="batch/batch_step")
        else:
            # define custom x-axis for epoch-wise logging.
            wandb.define_metric("epoch/epoch")
            # set all epoch-wise metrics to be logged against epoch.
            wandb.define_metric("epoch/*", step_metric="epoch/epoch")

    def _get_lr(self) -> float | None:
        """Return the optimizer's current learning rate, if it can be read.

        Returns:
            The learning rate as a float, or `None` when evaluating a callable
            learning rate fails (an error is reported via `wandb.termerror`).
        """
        learning_rate = self.model.optimizer.learning_rate

        # Plain value: a TF variable/tensor, or anything exposing a scalar
        # (empty-tuple) shape.
        is_scalar_like = (
            hasattr(learning_rate, "shape") and learning_rate.shape == ()
        )
        if isinstance(learning_rate, (tf.Variable, tf.Tensor)) or is_scalar_like:
            return float(learning_rate.numpy().item())

        # Otherwise treat it as a callable (presumably a learning rate
        # schedule) and evaluate it at the current global step.
        try:
            return float(learning_rate(step=self.global_step).numpy().item())
        except Exception as e:
            # Best effort: failing to read the learning rate must not abort
            # training, so report once and skip the metric.
            wandb.termerror(f"Unable to log learning rate: {e}", repeat=False)
            return None

    def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
        """Called at the end of an epoch.

        Logs every entry of `logs` under an ``epoch/`` prefix, together with
        the epoch index and, when available, the current learning rate.
        """
        logs = {} if logs is None else {f"epoch/{k}": v for k, v in logs.items()}

        logs["epoch/epoch"] = epoch

        lr = self._get_lr()
        if lr is not None:
            logs["epoch/learning_rate"] = lr

        wandb.log(logs)

    def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
        """An alias for `on_train_batch_end` for backwards compatibility.

        Counts the batch and, when batch-wise logging is enabled and `batch`
        is a multiple of `log_freq`, logs `logs` under a ``batch/`` prefix
        together with the batch step and the current learning rate.
        """
        # Count every batch, logged or not, so a callable learning rate is
        # evaluated at the right step.
        self.global_step += 1

        if self.logging_batch_wise and batch % self.log_freq == 0:
            logs = {f"batch/{k}": v for k, v in logs.items()} if logs else {}

            logs["batch/batch_step"] = self.global_batch

            lr = self._get_lr()
            if lr is not None:
                logs["batch/learning_rate"] = lr

            wandb.log(logs)

            # Advances by the logging interval rather than by batches actually
            # seen, so it only matches the true batch count when every logged
            # batch is exactly `log_freq` batches after the previous one.
            self.global_batch += self.log_freq

    def on_train_batch_end(
        self, batch: int, logs: dict[str, Any] | None = None
    ) -> None:
        """Called at the end of a training batch in `fit` methods."""
        self.on_batch_end(batch, logs or {})
|