wandb_watch.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. """watch."""
  2. from __future__ import annotations
  3. import logging
  4. from collections.abc import Sequence
  5. from typing import TYPE_CHECKING
  6. try:
  7. from typing import Literal
  8. except ImportError:
  9. from typing_extensions import Literal # type: ignore
  10. import wandb
  11. from .lib import telemetry
  12. if TYPE_CHECKING:
  13. import torch # type: ignore [import-not-found]
  14. logger = logging.getLogger("wandb")
  15. _global_watch_idx = 0
  16. def _watch(
  17. run: wandb.Run,
  18. models: torch.nn.Module | Sequence[torch.nn.Module],
  19. criterion: torch.F | None = None,
  20. log: Literal["gradients", "parameters", "all"] | None = "gradients",
  21. log_freq: int = 1000,
  22. idx: int | None = None,
  23. log_graph: bool = False,
  24. ):
  25. """Hooks into the given PyTorch model(s) to monitor gradients and the model's computational graph.
  26. This function can track parameters, gradients, or both during training. It should be
  27. extended to support arbitrary machine learning models in the future.
  28. Args:
  29. run (wandb.Run): The run object to log to.
  30. models (Union[torch.nn.Module, Sequence[torch.nn.Module]]):
  31. A single model or a sequence of models to be monitored.
  32. criterion (Optional[torch.F]):
  33. The loss function being optimized (optional).
  34. log (Optional[Literal["gradients", "parameters", "all"]]):
  35. Specifies whether to log "gradients", "parameters", or "all".
  36. Set to None to disable logging. (default="gradients")
  37. log_freq (int):
  38. Frequency (in batches) to log gradients and parameters. (default=1000)
  39. idx (Optional[int]):
  40. Index used when tracking multiple models with `wandb.watch`. (default=None)
  41. log_graph (bool):
  42. Whether to log the model's computational graph. (default=False)
  43. Returns:
  44. wandb.Graph:
  45. The graph object, which will be populated after the first backward pass.
  46. Raises:
  47. ValueError: If `wandb.init` has not been called.
  48. TypeError: If any of the models are not instances of `torch.nn.Module`.
  49. """
  50. global _global_watch_idx
  51. with telemetry.context() as tel:
  52. tel.feature.watch = True
  53. logger.info("Watching")
  54. if log not in {"gradients", "parameters", "all", None}:
  55. raise ValueError("log must be one of 'gradients', 'parameters', 'all', or None")
  56. log_parameters = log in {"parameters", "all"}
  57. log_gradients = log in {"gradients", "all"}
  58. if not isinstance(models, (tuple, list)):
  59. models = (models,)
  60. torch = wandb.util.get_module(
  61. "torch", required="wandb.watch only works with pytorch, couldn't import torch."
  62. )
  63. for model in models:
  64. if not isinstance(model, torch.nn.Module):
  65. raise TypeError(
  66. f"Expected a pytorch model (torch.nn.Module). Received {type(model)}"
  67. )
  68. graphs = []
  69. prefix = ""
  70. if idx is None:
  71. idx = _global_watch_idx
  72. for local_idx, model in enumerate(models):
  73. global_idx = idx + local_idx
  74. _global_watch_idx += 1
  75. if global_idx > 0:
  76. # TODO: this makes ugly chart names like gradients/graph_1conv1d.bias
  77. prefix = f"graph_{global_idx}"
  78. if log_parameters:
  79. run._torch.add_log_parameters_hook(
  80. model,
  81. prefix=prefix,
  82. log_freq=log_freq,
  83. )
  84. if log_gradients:
  85. run._torch.add_log_gradients_hook(
  86. model,
  87. prefix=prefix,
  88. log_freq=log_freq,
  89. )
  90. if log_graph:
  91. graph = run._torch.hook_torch(model, criterion, graph_idx=global_idx)
  92. graphs.append(graph)
  93. # NOTE: the graph is set in run.summary by hook_torch on the backward pass
  94. return graphs
  95. def _unwatch(
  96. run: wandb.Run, models: torch.nn.Module | Sequence[torch.nn.Module] | None = None
  97. ) -> None:
  98. """Remove pytorch model topology, gradient and parameter hooks.
  99. Args:
  100. run (wandb.Run):
  101. The run object to log to.
  102. models (torch.nn.Module | Sequence[torch.nn.Module]):
  103. Optional list of pytorch models that have had watch called on them
  104. """
  105. if models:
  106. if not isinstance(models, (tuple, list)):
  107. models = (models,)
  108. for model in models:
  109. if not hasattr(model, "_wandb_hook_names"):
  110. wandb.termwarn(f"{model} model has not been watched")
  111. else:
  112. for name in model._wandb_hook_names:
  113. run._torch.unhook(name)
  114. delattr(model, "_wandb_hook_names")
  115. # TODO: we should also remove recursively model._wandb_watch_called
  116. else:
  117. run._torch.unhook_all()