data_pandas.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. """Support for Pandas datatypes.
  2. May raise MissingDependencyError on import.
  3. """
  4. from __future__ import annotations
  5. from typing_extensions import Any, TypeIs
  6. import wandb
  7. from . import errors
  8. try:
  9. import pandas as pd
  10. except ImportError as e:
  11. warning = (
  12. "`pandas` not installed >>"
  13. " @wandb_log(datasets=True) may not auto log your dataset!"
  14. )
  15. raise errors.MissingDependencyError(warning=warning) from e
  16. def is_dataframe(data: Any) -> TypeIs[pd.DataFrame]:
  17. """Returns whether the data is a Pandas DataFrame."""
  18. return isinstance(data, pd.DataFrame)
  19. def use_dataframe(
  20. name: str,
  21. run: wandb.Run | None,
  22. testing: bool = False,
  23. ) -> str | None:
  24. """Log a dependency on a DataFrame input.
  25. Args:
  26. name: Name of the input.
  27. run: The run to update.
  28. testing: True in unit tests.
  29. """
  30. if testing:
  31. return "datasets"
  32. assert run
  33. wandb.termlog(f"Using artifact: {name} (Pandas DataFrame)")
  34. run.use_artifact(f"{name}:latest")
  35. return None
  36. def track_dataframe(
  37. name: str,
  38. data: pd.DataFrame,
  39. run: wandb.Run | None,
  40. testing: bool = False,
  41. ) -> str | None:
  42. """Log a DataFrame output as an artifact.
  43. Args:
  44. name: The output's name.
  45. data: The output's value.
  46. run: The run to update.
  47. testing: True in unit tests.
  48. """
  49. if testing:
  50. return "pd.DataFrame"
  51. assert run
  52. artifact = wandb.Artifact(name, type="dataset")
  53. with artifact.new_file(f"{name}.parquet", "wb") as f:
  54. data.to_parquet(f, engine="pyarrow")
  55. wandb.termlog(f"Logging artifact: {name} (Pandas DataFrame)")
  56. run.log_artifact(artifact)
  57. return None