metaflow.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. from __future__ import annotations
  2. import inspect
  3. import pickle
  4. from functools import wraps
  5. from pathlib import Path
  6. import wandb
  7. from wandb.sdk.lib import telemetry as wb_telemetry
  8. from . import errors
  9. try:
  10. from metaflow import current
  11. except ImportError as e:
  12. raise Exception(
  13. "Error: `metaflow` not installed >> This integration requires metaflow!"
  14. " To fix, please `pip install -Uqq metaflow`"
  15. ) from e
  16. try:
  17. from . import data_pandas
  18. except errors.MissingDependencyError as e:
  19. e.warn()
  20. data_pandas = None
  21. try:
  22. from . import data_pytorch
  23. except errors.MissingDependencyError as e:
  24. e.warn()
  25. data_pytorch = None
  26. try:
  27. from . import data_sklearn
  28. except errors.MissingDependencyError as e:
  29. e.warn()
  30. data_sklearn = None
  31. class ArtifactProxy:
  32. def __init__(self, flow):
  33. # do this to avoid recursion problem with __setattr__
  34. self.__dict__.update(
  35. {
  36. "flow": flow,
  37. "inputs": {},
  38. "outputs": {},
  39. "base": set(dir(flow)),
  40. "params": {p: getattr(flow, p) for p in current.parameter_names},
  41. }
  42. )
  43. def __setattr__(self, key, val):
  44. self.outputs[key] = val
  45. return setattr(self.flow, key, val)
  46. def __getattr__(self, key):
  47. if key not in self.base and key not in self.outputs:
  48. self.inputs[key] = getattr(self.flow, key)
  49. return getattr(self.flow, key)
  50. def _track_scalar(
  51. name: str,
  52. data: dict | list | set | str | int | float | bool,
  53. run,
  54. testing: bool = False,
  55. ) -> str | None:
  56. if testing:
  57. return "scalar"
  58. run.log({name: data})
  59. return None
  60. def _track_path(
  61. name: str,
  62. data: Path,
  63. run,
  64. testing: bool = False,
  65. ) -> str | None:
  66. if testing:
  67. return "Path"
  68. artifact = wandb.Artifact(name, type="dataset")
  69. if data.is_dir():
  70. artifact.add_dir(data)
  71. elif data.is_file():
  72. artifact.add_file(data)
  73. run.log_artifact(artifact)
  74. wandb.termlog(f"Logging artifact: {name} ({type(data)})")
  75. return None
  76. def _track_generic(
  77. name: str,
  78. data,
  79. run,
  80. testing: bool = False,
  81. ) -> str | None:
  82. if testing:
  83. return "generic"
  84. artifact = wandb.Artifact(name, type="other")
  85. with artifact.new_file(f"{name}.pkl", "wb") as f:
  86. pickle.dump(data, f)
  87. run.log_artifact(artifact)
  88. wandb.termlog(f"Logging artifact: {name} ({type(data)})")
  89. return None
  90. def wandb_track(
  91. name: str,
  92. data,
  93. datasets: bool = False,
  94. models: bool = False,
  95. others: bool = False,
  96. run: wandb.Run | None = None,
  97. testing: bool = False,
  98. ) -> str | None:
  99. """Track data as wandb artifacts based on type and flags."""
  100. # Check for pandas DataFrame
  101. if data_pandas and data_pandas.is_dataframe(data) and datasets:
  102. return data_pandas.track_dataframe(name, data, run, testing)
  103. # Check for PyTorch Module
  104. if data_pytorch and data_pytorch.is_nn_module(data) and models:
  105. return data_pytorch.track_nn_module(name, data, run, testing)
  106. # Check for scikit-learn BaseEstimator
  107. if data_sklearn and data_sklearn.is_estimator(data) and models:
  108. return data_sklearn.track_estimator(name, data, run, testing)
  109. # Check for Path objects
  110. if isinstance(data, Path) and datasets:
  111. return _track_path(name, data, run, testing)
  112. # Check for scalar types
  113. if isinstance(data, (dict, list, set, str, int, float, bool)):
  114. return _track_scalar(name, data, run, testing)
  115. # Generic fallback
  116. if others:
  117. return _track_generic(name, data, run, testing)
  118. # No action taken
  119. return None
  120. def wandb_use(
  121. name: str,
  122. data,
  123. datasets: bool = False,
  124. models: bool = False,
  125. others: bool = False,
  126. run=None,
  127. testing: bool = False,
  128. ) -> str | None:
  129. """Use wandb artifacts based on data type and flags."""
  130. # Skip scalar types - nothing to use
  131. if isinstance(data, (dict, list, set, str, int, float, bool)):
  132. return None
  133. try:
  134. # Check for pandas DataFrame
  135. if data_pandas and data_pandas.is_dataframe(data) and datasets:
  136. return data_pandas.use_dataframe(name, run, testing)
  137. # Check for PyTorch Module
  138. elif data_pytorch and data_pytorch.is_nn_module(data) and models:
  139. return data_pytorch.use_nn_module(name, run, testing)
  140. # Check for scikit-learn BaseEstimator
  141. elif data_sklearn and data_sklearn.is_estimator(data) and models:
  142. return data_sklearn.use_estimator(name, run, testing)
  143. # Check for Path objects
  144. elif isinstance(data, Path) and datasets:
  145. return _use_path(name, data, run, testing)
  146. # Generic fallback
  147. elif others:
  148. return _use_generic(name, data, run, testing)
  149. else:
  150. return None
  151. except wandb.CommError:
  152. wandb.termwarn(
  153. f"This artifact ({name}, {type(data)}) does not exist in the wandb datastore!"
  154. " If you created an instance inline (e.g. sklearn.ensemble.RandomForestClassifier),"
  155. " then you can safely ignore this. Otherwise you may want to check your internet connection!"
  156. )
  157. return None
  158. def _use_path(
  159. name: str,
  160. data: Path,
  161. run,
  162. testing: bool = False,
  163. ) -> str | None:
  164. if testing:
  165. return "datasets"
  166. run.use_artifact(f"{name}:latest")
  167. wandb.termlog(f"Using artifact: {name} ({type(data)})")
  168. return None
  169. def _use_generic(
  170. name: str,
  171. data,
  172. run,
  173. testing: bool = False,
  174. ) -> str | None:
  175. if testing:
  176. return "others"
  177. run.use_artifact(f"{name}:latest")
  178. wandb.termlog(f"Using artifact: {name} ({type(data)})")
  179. return None
  180. def coalesce(*arg):
  181. return next((a for a in arg if a is not None), None)
  182. def wandb_log(
  183. func=None,
  184. /,
  185. datasets: bool = False,
  186. models: bool = False,
  187. others: bool = False,
  188. settings: wandb.Settings | None = None,
  189. ):
  190. """Automatically log parameters and artifacts to W&B.
  191. This decorator can be applied to a flow, step, or both:
  192. - Decorating a step enables or disables logging within that step
  193. - Decorating a flow is equivalent to decorating all steps
  194. - Decorating a step after decorating its flow overwrites the flow decoration
  195. Args:
  196. func: The step method or flow class to decorate.
  197. datasets: Whether to log `pd.DataFrame` and `pathlib.Path`
  198. types. Defaults to False.
  199. models: Whether to log `nn.Module` and `sklearn.base.BaseEstimator`
  200. types. Defaults to False.
  201. others: If `True`, log anything pickle-able. Defaults to False.
  202. settings: Custom settings to pass to `wandb.init`.
  203. If `run_group` is `None`, it is set to `{flow_name}/{run_id}`.
  204. If `run_job_type` is `None`, it is set to `{run_job_type}/{step_name}`.
  205. """
  206. @wraps(func)
  207. def decorator(func):
  208. # If you decorate a class, apply the decoration to all methods in that class
  209. if inspect.isclass(func):
  210. cls = func
  211. for attr in cls.__dict__:
  212. if callable(getattr(cls, attr)) and not hasattr(attr, "_base_func"):
  213. setattr(cls, attr, decorator(getattr(cls, attr)))
  214. return cls
  215. # prefer the earliest decoration (i.e. method decoration overrides class decoration)
  216. if hasattr(func, "_base_func"):
  217. return func
  218. @wraps(func)
  219. def wrapper(self, *args, settings=settings, **kwargs):
  220. if not isinstance(settings, wandb.sdk.wandb_settings.Settings):
  221. settings = wandb.Settings()
  222. settings.update_from_dict(
  223. {
  224. "run_group": coalesce(
  225. settings.run_group, f"{current.flow_name}/{current.run_id}"
  226. ),
  227. "run_job_type": coalesce(settings.run_job_type, current.step_name),
  228. }
  229. )
  230. with wandb.init(settings=settings) as run:
  231. with wb_telemetry.context(run=run) as tel:
  232. tel.feature.metaflow = True
  233. proxy = ArtifactProxy(self)
  234. run.config.update(proxy.params)
  235. func(proxy, *args, **kwargs)
  236. for name, data in proxy.inputs.items():
  237. wandb_use(
  238. name,
  239. data,
  240. datasets=datasets,
  241. models=models,
  242. others=others,
  243. run=run,
  244. )
  245. for name, data in proxy.outputs.items():
  246. wandb_track(
  247. name,
  248. data,
  249. datasets=datasets,
  250. models=models,
  251. others=others,
  252. run=run,
  253. )
  254. wrapper._base_func = func
  255. # Add for testing visibility
  256. wrapper._kwargs = {
  257. "datasets": datasets,
  258. "models": models,
  259. "others": others,
  260. "settings": settings,
  261. }
  262. return wrapper
  263. if func is None:
  264. return decorator
  265. else:
  266. return decorator(func)