data_pytorch.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. """Support for PyTorch datatypes.
  2. May raise MissingDependencyError on import.
  3. """
  4. from __future__ import annotations
  5. from typing_extensions import Any, TypeIs
  6. import wandb
  7. from . import errors
  8. try:
  9. import torch
  10. import torch.nn as nn
  11. except ImportError as e:
  12. warning = (
  13. "`torch` (PyTorch) not installed >>"
  14. " @wandb_log(models=True) may not auto log your model!"
  15. )
  16. raise errors.MissingDependencyError(warning=warning) from e
  17. def is_nn_module(data: Any) -> TypeIs[nn.Module]:
  18. """Returns whether the data is a PyTorch nn.Module."""
  19. return isinstance(data, nn.Module)
  20. def use_nn_module(
  21. name: str,
  22. run: wandb.Run | None,
  23. testing: bool = False,
  24. ) -> str | None:
  25. """Log a dependency on a PyTorch model input.
  26. Args:
  27. name: Name of the input.
  28. run: The run to update.
  29. testing: True in unit tests.
  30. """
  31. if testing:
  32. return "models"
  33. assert run
  34. wandb.termlog(f"Using artifact: {name} (PyTorch nn.Module)")
  35. run.use_artifact(f"{name}:latest")
  36. return None
  37. def track_nn_module(
  38. name: str,
  39. data: nn.Module,
  40. run: wandb.Run | None,
  41. testing: bool = False,
  42. ) -> str | None:
  43. """Log a PyTorch model output as an artifact.
  44. Args:
  45. name: The output's name.
  46. data: The output's value.
  47. run: The run to update.
  48. testing: True in unit tests.
  49. """
  50. if testing:
  51. return "nn.Module"
  52. assert run
  53. artifact = wandb.Artifact(name, type="model")
  54. with artifact.new_file(f"{name}.pkl", "wb") as f:
  55. torch.save(data, f)
  56. wandb.termlog(f"Logging artifact: {name} (PyTorch nn.Module)")
  57. run.log_artifact(artifact)
  58. return None