# data_sklearn.py (1.7 KB)

"""Support for sklearn datatypes.

May raise MissingDependencyError on import.
"""
  4. from __future__ import annotations
  5. import pickle
  6. from typing_extensions import Any, TypeIs
  7. import wandb
  8. from . import errors
  9. try:
  10. from sklearn.base import BaseEstimator
  11. except ImportError as e:
  12. warning = (
  13. "`sklearn` not installed >>"
  14. " @wandb_log(models=True) may not auto log your model!"
  15. )
  16. raise errors.MissingDependencyError(warning=warning) from e
  17. def is_estimator(data: Any) -> TypeIs[BaseEstimator]:
  18. """Returns whether the data is an sklearn BaseEstimator."""
  19. return isinstance(data, BaseEstimator)
  20. def use_estimator(
  21. name: str,
  22. run: wandb.Run | None,
  23. testing: bool = False,
  24. ) -> str | None:
  25. """Log a dependency on an sklearn estimator.
  26. Args:
  27. name: Name of the input.
  28. run: The run to update.
  29. testing: True in unit tests.
  30. """
  31. if testing:
  32. return "models"
  33. assert run
  34. wandb.termlog(f"Using artifact: {name} (sklearn BaseEstimator)")
  35. run.use_artifact(f"{name}:latest")
  36. return None
  37. def track_estimator(
  38. name: str,
  39. data: BaseEstimator,
  40. run: wandb.Run | None,
  41. testing: bool = False,
  42. ) -> str | None:
  43. """Log an sklearn estimator output as an artifact.
  44. Args:
  45. name: The output's name.
  46. data: The output's value.
  47. run: The run to update.
  48. testing: True in unit tests.
  49. """
  50. if testing:
  51. return "BaseEstimator"
  52. assert run
  53. artifact = wandb.Artifact(name, type="model")
  54. with artifact.new_file(f"{name}.pkl", "wb") as f:
  55. pickle.dump(data, f)
  56. wandb.termlog(f"Logging artifact: {name} (sklearn BaseEstimator)")
  57. run.log_artifact(artifact)
  58. return None