wandb_log_v2.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. from __future__ import annotations
  2. import os
  3. from functools import wraps
  4. from inspect import signature
  5. from typing import Any, Callable
  6. import kfp.dsl
  7. from kfp.dsl.types.type_annotations import (
  8. InputPath,
  9. OutputPath,
  10. is_artifact_wrapped_in_Input,
  11. is_artifact_wrapped_in_Output,
  12. )
  13. import wandb
  14. from wandb.sdk.lib import telemetry as wb_telemetry
  15. def _is_namedtuple(x: Any) -> bool:
  16. """Return True if `x` is an instance of a NamedTuple.
  17. Python does not provide a common base class for named tuples created
  18. via `collections.namedtuple` or `typing.NamedTuple`, so there is
  19. no way to use `isinstance`. Instead we check that the type is a
  20. `tuple` subclass whose `_fields` attribute is a tuple of strings,
  21. following the documented NamedTuple API:
  22. https://docs.python.org/3/library/collections.html#collections.somenamedtuple._fields
  23. KFP uses NamedTuples for multi-output components. The decorator sees
  24. the actual return value at runtime and unpacks its fields for logging.
  25. KFP's own executor processes type annotations separately for
  26. serialization, so runtime value detection is the correct approach here.
  27. Args:
  28. x: The value to check.
  29. Returns:
  30. True if `x` is a NamedTuple instance.
  31. """
  32. t = type(x)
  33. if not issubclass(t, tuple):
  34. return False
  35. fields = getattr(t, "_fields", None)
  36. if not isinstance(fields, tuple):
  37. return False
  38. return all(isinstance(n, str) for n in fields)
  39. def _is_output_annotation(ann: Any) -> bool:
  40. """Return True if `ann` is a KFP Output or OutputPath annotation."""
  41. return is_artifact_wrapped_in_Output(ann) or isinstance(ann, OutputPath)
  42. def _is_input_annotation(ann: Any) -> bool:
  43. """Return True if `ann` is a KFP Input or InputPath annotation."""
  44. return is_artifact_wrapped_in_Input(ann) or isinstance(ann, InputPath)
  45. def _get_artifact_path(value: Any) -> str | None:
  46. """Return the local file path for a KFP artifact value, or None.
  47. Args:
  48. value: A KFP artifact instance or a string file path.
  49. Returns:
  50. The local path if the artifact/file exists on disk, otherwise None.
  51. """
  52. if isinstance(value, kfp.dsl.Artifact):
  53. return value.path if os.path.exists(value.path) else None
  54. if isinstance(value, str) and os.path.exists(value):
  55. return value
  56. return None
  57. def _log_artifact(
  58. run: wandb.Run,
  59. name: str,
  60. value: Any,
  61. *,
  62. use: bool = False,
  63. ) -> bool:
  64. """Log or use a single artifact.
  65. Args:
  66. run: The active W&B run.
  67. name: Artifact name.
  68. value: A KFP artifact or string path.
  69. use: If True, call `run.use_artifact` (for inputs); otherwise
  70. call `run.log_artifact` (for outputs).
  71. Returns:
  72. True on success, False if the artifact path is missing.
  73. """
  74. path = _get_artifact_path(value)
  75. if path is None:
  76. return False
  77. artifact = wandb.Artifact(name, type="kfp_artifact")
  78. artifact.add_file(path)
  79. if use:
  80. run.use_artifact(artifact)
  81. wandb.termlog(f"Using artifact: {name}")
  82. else:
  83. run.log_artifact(artifact)
  84. wandb.termlog(f"Logging artifact: {name}")
  85. return True
  86. class _KfpWandbLogger:
  87. """Classifies a KFP component's annotations and logs I/O to W&B.
  88. Inspects the function's type annotations at decoration time to
  89. partition parameters into scalar inputs, artifact inputs, and
  90. artifact outputs. Only parameter names are stored (annotation
  91. values are not needed after classification).
  92. Args:
  93. func: The KFP component function to classify.
  94. """
  95. def __init__(self, func: Callable) -> None:
  96. self._scalars_in: set[str] = set()
  97. self._artifacts_in: set[str] = set()
  98. self._artifacts_out: set[str] = set()
  99. for name, ann in func.__annotations__.items():
  100. if name == "return":
  101. continue
  102. elif _is_output_annotation(ann):
  103. self._artifacts_out.add(name)
  104. elif _is_input_annotation(ann):
  105. self._artifacts_in.add(name)
  106. else:
  107. self._scalars_in.add(name)
  108. def log_inputs(self, run: wandb.Run, bound_args: dict[str, Any]) -> None:
  109. """Log scalar configs and input artifacts for a component invocation.
  110. Args:
  111. run: The active W&B run.
  112. bound_args: Bound arguments from `inspect.Signature.bind`.
  113. """
  114. for name in self._scalars_in:
  115. if name in bound_args:
  116. value = bound_args[name]
  117. run.config[name] = value
  118. wandb.termlog(f"Setting config: {name} to {value}")
  119. for name in self._artifacts_in:
  120. if name in bound_args:
  121. try:
  122. _log_artifact(run, name, bound_args[name], use=True)
  123. except Exception as e:
  124. wandb.termwarn(f"Failed to log input artifact '{name}': {e}")
  125. def log_outputs(
  126. self,
  127. run: wandb.Run,
  128. func_name: str,
  129. result: Any,
  130. bound_args: dict[str, Any],
  131. ) -> None:
  132. """Log scalar results and output artifacts for a component invocation.
  133. Args:
  134. run: The active W&B run.
  135. func_name: The component function's name (used as log key prefix).
  136. result: The return value of the component function.
  137. bound_args: Bound arguments from `inspect.Signature.bind`.
  138. """
  139. if result is not None and not run._is_finished:
  140. if _is_namedtuple(result):
  141. run.log({f"{func_name}.{k}": v for k, v in zip(result._fields, result)})
  142. else:
  143. run.log({func_name: result})
  144. for name in self._artifacts_out:
  145. if name in bound_args:
  146. try:
  147. _log_artifact(run, name, bound_args[name], use=False)
  148. except Exception as e:
  149. wandb.termwarn(f"Failed to log output artifact '{name}': {e}")
  150. def wandb_log(
  151. func: Callable | None = None,
  152. ) -> Callable:
  153. """Wrap a KFP v2 component function and log to W&B.
  154. Compatible with `kfp>=2.0.0`. Automatically logs input parameters
  155. to `wandb.config` and output scalars via `wandb.log`. Artifacts
  156. annotated with KFP's `Input` / `Output` types are logged as W&B
  157. Artifacts.
  158. Example:
  159. ```python
  160. from kfp import dsl
  161. from wandb.integration.kfp import wandb_log
  162. @dsl.component
  163. @wandb_log
  164. def add(a: float, b: float) -> float:
  165. return a + b
  166. ```
  167. """
  168. def decorator(func: Callable) -> Callable:
  169. logger = _KfpWandbLogger(func)
  170. func_sig = signature(func)
  171. @wraps(func)
  172. def wrapper(*args: Any, **kwargs: Any) -> Any:
  173. bound = func_sig.bind(*args, **kwargs)
  174. bound.apply_defaults()
  175. # WANDB_RUN_GROUP: standard W&B env var for grouping runs.
  176. # KFP_RUN_NAME: set by the KFP orchestrator at container runtime.
  177. # ARGO_WORKFLOW_NAME: set by Argo Workflows (KFP's execution backend).
  178. wandb_group = (
  179. os.getenv("WANDB_RUN_GROUP")
  180. or os.getenv("KFP_RUN_NAME")
  181. or os.getenv("ARGO_WORKFLOW_NAME")
  182. )
  183. with wandb.init(
  184. job_type=func.__name__,
  185. group=wandb_group,
  186. ) as run:
  187. kubeflow_url = os.getenv("WANDB_KUBEFLOW_URL")
  188. if kubeflow_url:
  189. run.config["LINK_TO_KUBEFLOW"] = kubeflow_url
  190. logger.log_inputs(run, bound.arguments)
  191. with wb_telemetry.context(run=run) as tel:
  192. tel.feature.kfp_wandb_log = True
  193. result = func(*bound.args, **bound.kwargs)
  194. logger.log_outputs(run, func.__name__, result, bound.arguments)
  195. return result
  196. # Checked by kfp_patch.py to detect decorated functions for wandb
  197. # package injection and decorator source serialization.
  198. wrapper._wandb_logged = True
  199. # KFP's executor calls inspect.getfullargspec() to discover component
  200. # parameters. Without this, the executor sees (*args, **kwargs) from
  201. # the wrapper instead of the real function signature.
  202. wrapper.__signature__ = func_sig
  203. return wrapper
  204. if func is None:
  205. return decorator
  206. else:
  207. return decorator(func)