| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327 |
- from __future__ import annotations
- import inspect
- import pickle
- from functools import wraps
- from pathlib import Path
- import wandb
- from wandb.sdk.lib import telemetry as wb_telemetry
- from . import errors
- try:
- from metaflow import current
- except ImportError as e:
- raise Exception(
- "Error: `metaflow` not installed >> This integration requires metaflow!"
- " To fix, please `pip install -Uqq metaflow`"
- ) from e
- try:
- from . import data_pandas
- except errors.MissingDependencyError as e:
- e.warn()
- data_pandas = None
- try:
- from . import data_pytorch
- except errors.MissingDependencyError as e:
- e.warn()
- data_pytorch = None
- try:
- from . import data_sklearn
- except errors.MissingDependencyError as e:
- e.warn()
- data_sklearn = None
class ArtifactProxy:
    """Stand-in for a Metaflow flow object that records artifact reads/writes.

    A decorated step receives this proxy in place of ``self``:

    * attribute assignments are forwarded to the real flow and remembered in
      ``outputs``;
    * attribute reads of names that were neither visible on the flow when the
      proxy was built (``dir(flow)``) nor assigned through the proxy are
      remembered in ``inputs``.
    """

    def __init__(self, flow):
        # Write straight into __dict__: a normal ``self.x = ...`` would go
        # through the overridden __setattr__ and be forwarded to the flow.
        self.__dict__.update(
            {
                "flow": flow,
                "inputs": {},  # name -> value read from the flow
                "outputs": {},  # name -> value assigned through the proxy
                "base": set(dir(flow)),  # names visible on the flow up front
                # Metaflow parameter values, keyed by parameter name.
                "params": {p: getattr(flow, p) for p in current.parameter_names},
            }
        )

    def __setattr__(self, key, val):
        """Forward the assignment to the flow and record it as an output."""
        self.outputs[key] = val
        return setattr(self.flow, key, val)

    def __getattr__(self, key):
        """Forward the read to the flow, recording first-time reads as inputs.

        Raises:
            AttributeError: if the flow has no such attribute, or if this proxy
                was created without running ``__init__`` (e.g. while being
                rebuilt by ``copy``/``pickle``).
        """
        # __getattr__ only runs when normal lookup fails. Reading the
        # bookkeeping fields via ``self.flow`` etc. on an uninitialised proxy
        # would re-enter __getattr__ forever, so go through __dict__ directly.
        state = self.__dict__
        try:
            flow = state["flow"]
            base = state["base"]
            outputs = state["outputs"]
            inputs = state["inputs"]
        except KeyError:
            raise AttributeError(key) from None
        value = getattr(flow, key)
        if key not in base and key not in outputs:
            inputs[key] = value
        return value
- def _track_scalar(
- name: str,
- data: dict | list | set | str | int | float | bool,
- run,
- testing: bool = False,
- ) -> str | None:
- if testing:
- return "scalar"
- run.log({name: data})
- return None
- def _track_path(
- name: str,
- data: Path,
- run,
- testing: bool = False,
- ) -> str | None:
- if testing:
- return "Path"
- artifact = wandb.Artifact(name, type="dataset")
- if data.is_dir():
- artifact.add_dir(data)
- elif data.is_file():
- artifact.add_file(data)
- run.log_artifact(artifact)
- wandb.termlog(f"Logging artifact: {name} ({type(data)})")
- return None
- def _track_generic(
- name: str,
- data,
- run,
- testing: bool = False,
- ) -> str | None:
- if testing:
- return "generic"
- artifact = wandb.Artifact(name, type="other")
- with artifact.new_file(f"{name}.pkl", "wb") as f:
- pickle.dump(data, f)
- run.log_artifact(artifact)
- wandb.termlog(f"Logging artifact: {name} ({type(data)})")
- return None
def wandb_track(
    name: str,
    data,
    datasets: bool = False,
    models: bool = False,
    others: bool = False,
    run: wandb.Run | None = None,
    testing: bool = False,
) -> str | None:
    """Log ``data`` to W&B as a metric or an artifact, depending on its type.

    Rules are tried in this order and the first match wins:

    1. pandas DataFrame, if ``datasets`` (and the pandas helper imported)
    2. PyTorch module, if ``models`` (and the PyTorch helper imported)
    3. scikit-learn estimator, if ``models`` (and the sklearn helper imported)
    4. ``pathlib.Path``, if ``datasets``
    5. dict/list/set/str/int/float/bool, always -> ``run.log``
    6. fallback for anything else, if ``others`` -> pickled artifact

    Returns:
        The matching handler's return value, or None if no rule matched.
    """
    # (predicate, handler) pairs; both are lambdas so that checks run lazily
    # and in priority order, exactly like a chain of ``if`` statements.
    rules = (
        (
            lambda: data_pandas and data_pandas.is_dataframe(data) and datasets,
            lambda: data_pandas.track_dataframe(name, data, run, testing),
        ),
        (
            lambda: data_pytorch and data_pytorch.is_nn_module(data) and models,
            lambda: data_pytorch.track_nn_module(name, data, run, testing),
        ),
        (
            lambda: data_sklearn and data_sklearn.is_estimator(data) and models,
            lambda: data_sklearn.track_estimator(name, data, run, testing),
        ),
        (
            lambda: isinstance(data, Path) and datasets,
            lambda: _track_path(name, data, run, testing),
        ),
        (
            lambda: isinstance(data, (dict, list, set, str, int, float, bool)),
            lambda: _track_scalar(name, data, run, testing),
        ),
        (
            lambda: others,
            lambda: _track_generic(name, data, run, testing),
        ),
    )
    for applies, handle in rules:
        if applies():
            return handle()
    return None
def wandb_use(
    name: str,
    data,
    datasets: bool = False,
    models: bool = False,
    others: bool = False,
    run=None,
    testing: bool = False,
) -> str | None:
    """Mark the artifact behind ``data`` as an input of ``run``.

    dict/list/set/str/int/float/bool values are skipped, because they are not
    stored as artifacts (``wandb_track`` logs them with ``run.log``). Other
    values are matched in the same priority order and under the same flags as
    ``wandb_track``.

    A missing artifact (``wandb.CommError``) is reported with a warning
    rather than raised.

    Returns:
        The matching handler's return value, or None.
    """
    if isinstance(data, (dict, list, set, str, int, float, bool)):
        return None

    try:
        if data_pandas and data_pandas.is_dataframe(data) and datasets:
            return data_pandas.use_dataframe(name, run, testing)
        if data_pytorch and data_pytorch.is_nn_module(data) and models:
            return data_pytorch.use_nn_module(name, run, testing)
        if data_sklearn and data_sklearn.is_estimator(data) and models:
            return data_sklearn.use_estimator(name, run, testing)
        if isinstance(data, Path) and datasets:
            return _use_path(name, data, run, testing)
        if others:
            return _use_generic(name, data, run, testing)
    except wandb.CommError:
        wandb.termwarn(
            f"This artifact ({name}, {type(data)}) does not exist in the wandb datastore!"
            " If you created an instance inline (e.g. sklearn.ensemble.RandomForestClassifier),"
            " then you can safely ignore this. Otherwise you may want to check your internet connection!"
        )
    return None
- def _use_path(
- name: str,
- data: Path,
- run,
- testing: bool = False,
- ) -> str | None:
- if testing:
- return "datasets"
- run.use_artifact(f"{name}:latest")
- wandb.termlog(f"Using artifact: {name} ({type(data)})")
- return None
- def _use_generic(
- name: str,
- data,
- run,
- testing: bool = False,
- ) -> str | None:
- if testing:
- return "others"
- run.use_artifact(f"{name}:latest")
- wandb.termlog(f"Using artifact: {name} ({type(data)})")
- return None
def coalesce(*arg):
    """Return the first argument that is not None (None if there is none)."""
    for candidate in arg:
        if candidate is not None:
            return candidate
    return None
def wandb_log(
    func=None,
    /,
    datasets: bool = False,
    models: bool = False,
    others: bool = False,
    settings: wandb.Settings | None = None,
):
    """Automatically log parameters and artifacts to W&B.

    This decorator can be applied to a flow, step, or both:
    - Decorating a step enables or disables logging within that step
    - Decorating a flow is equivalent to decorating all steps
    - Decorating a step after decorating its flow overwrites the flow decoration

    Each call of a decorated method opens its own ``wandb.init`` run. The flow
    parameters are written to ``run.config``, and the method body receives an
    ``ArtifactProxy`` in place of ``self``. Afterwards, each attribute the body
    read (that was not already on the flow) is passed to ``wandb_use``, and
    each attribute it assigned is passed to ``wandb_track``.

    Args:
        func: The step method or flow class to decorate. Omit it to obtain a
            configured decorator, e.g. ``@wandb_log(datasets=True)``.
        datasets: Whether to log `pd.DataFrame` and `pathlib.Path`
            types. Defaults to False.
        models: Whether to log `nn.Module` and `sklearn.base.BaseEstimator`
            types. Defaults to False.
        others: If `True`, log anything pickle-able. Defaults to False.
        settings: Custom settings to pass to `wandb.init`. A copy is made for
            every run, so the object passed in is never modified.
            If `run_group` is `None`, it is set to `{flow_name}/{run_id}`.
            If `run_job_type` is `None`, it is set to `{step_name}`.

    Returns:
        The decorated method or class, or a decorator when ``func`` is omitted.
    """
    import copy  # local: only needed to copy caller-supplied Settings

    def decorator(func):
        # If you decorate a class, apply the decoration to all methods in that class
        if inspect.isclass(func):
            cls = func
            # Snapshot the namespace because setattr below rewrites entries.
            for attr, raw in list(vars(cls).items()):
                # staticmethod/classmethod members (including implicit ones
                # such as __new__ / __init_subclass__) would be broken by an
                # instance-style wrapper that expects ``self``.
                if isinstance(raw, (staticmethod, classmethod)):
                    continue
                member = getattr(cls, attr)
                # Members that carry their own @wandb_log keep that config.
                if callable(member) and not hasattr(member, "_base_func"):
                    setattr(cls, attr, decorator(member))
            return cls

        # prefer the earliest decoration (i.e. method decoration overrides class decoration)
        if hasattr(func, "_base_func"):
            return func

        @wraps(func)
        def wrapper(self, *args, settings=settings, **kwargs):
            if isinstance(settings, wandb.sdk.wandb_settings.Settings):
                # The default ``settings`` is one object shared by every
                # decorated step; mutating it would make the first step's
                # group/job type stick to all later steps, so work on a copy.
                run_settings = copy.copy(settings)
            else:
                run_settings = wandb.Settings()
            run_settings.update_from_dict(
                {
                    "run_group": coalesce(
                        run_settings.run_group,
                        f"{current.flow_name}/{current.run_id}",
                    ),
                    "run_job_type": coalesce(
                        run_settings.run_job_type, current.step_name
                    ),
                }
            )

            with wandb.init(settings=run_settings) as run:
                with wb_telemetry.context(run=run) as tel:
                    tel.feature.metaflow = True
                proxy = ArtifactProxy(self)
                run.config.update(proxy.params)
                result = func(proxy, *args, **kwargs)

                for name, data in proxy.inputs.items():
                    wandb_use(
                        name,
                        data,
                        datasets=datasets,
                        models=models,
                        others=others,
                        run=run,
                    )

                for name, data in proxy.outputs.items():
                    wandb_track(
                        name,
                        data,
                        datasets=datasets,
                        models=models,
                        others=others,
                        run=run,
                    )
            # Propagate the wrapped method's return value (matters for
            # non-step helpers wrapped by class decoration).
            return result

        wrapper._base_func = func

        # Add for testing visibility
        wrapper._kwargs = {
            "datasets": datasets,
            "models": models,
            "others": others,
            "settings": settings,
        }
        return wrapper

    if func is None:
        return decorator
    else:
        return decorator(func)
|