# metrics_logger.py
from __future__ import annotations

from typing import Any, Literal

import tensorflow as tf  # type: ignore
from tensorflow.keras import callbacks
import wandb
from wandb.integration.keras.keras import patch_tf_keras
from wandb.sdk.lib import telemetry
  8. LogStrategy = Literal["epoch", "batch"]
  9. patch_tf_keras()
  10. class WandbMetricsLogger(callbacks.Callback):
  11. """Logger that sends system metrics to W&B.
  12. `WandbMetricsLogger` automatically logs the `logs` dictionary that callback methods
  13. take as argument to wandb.
  14. This callback automatically logs the following to a W&B run page:
  15. * system (CPU/GPU/TPU) metrics,
  16. * train and validation metrics defined in `model.compile`,
  17. * learning rate (both for a fixed value or a learning rate scheduler)
  18. Notes:
  19. If you resume training by passing `initial_epoch` to `model.fit` and you are using a
  20. learning rate scheduler, make sure to pass `initial_global_step` to
  21. `WandbMetricsLogger`. The `initial_global_step` is `step_size * initial_step`, where
  22. `step_size` is number of training steps per epoch. `step_size` can be calculated as
  23. the product of the cardinality of the training dataset and the batch size.
  24. Args:
  25. log_freq: ("epoch", "batch", or int) if "epoch", logs metrics
  26. at the end of each epoch. If "batch", logs metrics at the end
  27. of each batch. If an integer, logs metrics at the end of that
  28. many batches. Defaults to "epoch".
  29. initial_global_step: (int) Use this argument to correctly log the
  30. learning rate when you resume training from some `initial_epoch`,
  31. and a learning rate scheduler is used. This can be computed as
  32. `step_size * initial_step`. Defaults to 0.
  33. """
  34. def __init__(
  35. self,
  36. log_freq: LogStrategy | int = "epoch",
  37. initial_global_step: int = 0,
  38. *args: Any,
  39. **kwargs: Any,
  40. ) -> None:
  41. super().__init__(*args, **kwargs)
  42. if wandb.run is None:
  43. raise wandb.Error(
  44. "You must call `wandb.init()` before WandbMetricsLogger()"
  45. )
  46. with telemetry.context(run=wandb.run) as tel:
  47. tel.feature.keras_metrics_logger = True
  48. if log_freq == "batch":
  49. log_freq = 1
  50. self.logging_batch_wise = isinstance(log_freq, int)
  51. self.log_freq: Any = log_freq if self.logging_batch_wise else None
  52. self.global_batch = 0
  53. self.global_step = initial_global_step
  54. if self.logging_batch_wise:
  55. # define custom x-axis for batch logging.
  56. wandb.define_metric("batch/batch_step")
  57. # set all batch metrics to be logged against batch_step.
  58. wandb.define_metric("batch/*", step_metric="batch/batch_step")
  59. else:
  60. # define custom x-axis for epoch-wise logging.
  61. wandb.define_metric("epoch/epoch")
  62. # set all epoch-wise metrics to be logged against epoch.
  63. wandb.define_metric("epoch/*", step_metric="epoch/epoch")
  64. def _get_lr(self) -> float | None:
  65. if isinstance(
  66. self.model.optimizer.learning_rate,
  67. (tf.Variable, tf.Tensor),
  68. ) or (
  69. hasattr(self.model.optimizer.learning_rate, "shape")
  70. and self.model.optimizer.learning_rate.shape == ()
  71. ):
  72. return float(self.model.optimizer.learning_rate.numpy().item())
  73. try:
  74. return float(
  75. self.model.optimizer.learning_rate(step=self.global_step).numpy().item()
  76. )
  77. except Exception as e:
  78. wandb.termerror(f"Unable to log learning rate: {e}", repeat=False)
  79. return None
  80. def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
  81. """Called at the end of an epoch."""
  82. logs = dict() if logs is None else {f"epoch/{k}": v for k, v in logs.items()}
  83. logs["epoch/epoch"] = epoch
  84. lr = self._get_lr()
  85. if lr is not None:
  86. logs["epoch/learning_rate"] = lr
  87. wandb.log(logs)
  88. def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
  89. self.global_step += 1
  90. """An alias for `on_train_batch_end` for backwards compatibility."""
  91. if self.logging_batch_wise and batch % self.log_freq == 0:
  92. logs = {f"batch/{k}": v for k, v in logs.items()} if logs else {}
  93. logs["batch/batch_step"] = self.global_batch
  94. lr = self._get_lr()
  95. if lr is not None:
  96. logs["batch/learning_rate"] = lr
  97. wandb.log(logs)
  98. self.global_batch += self.log_freq
  99. def on_train_batch_end(
  100. self, batch: int, logs: dict[str, Any] | None = None
  101. ) -> None:
  102. """Called at the end of a training batch in `fit` methods."""
  103. self.on_batch_end(batch, logs if logs else {})