| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
- from __future__ import annotations
- import os
- from functools import wraps
- from inspect import signature
- from typing import Any, Callable
- import kfp.dsl
- from kfp.dsl.types.type_annotations import (
- InputPath,
- OutputPath,
- is_artifact_wrapped_in_Input,
- is_artifact_wrapped_in_Output,
- )
- import wandb
- from wandb.sdk.lib import telemetry as wb_telemetry
- def _is_namedtuple(x: Any) -> bool:
- """Return True if `x` is an instance of a NamedTuple.
- Python does not provide a common base class for named tuples created
- via `collections.namedtuple` or `typing.NamedTuple`, so there is
- no way to use `isinstance`. Instead we check that the type is a
- `tuple` subclass whose `_fields` attribute is a tuple of strings,
- following the documented NamedTuple API:
- https://docs.python.org/3/library/collections.html#collections.somenamedtuple._fields
- KFP uses NamedTuples for multi-output components. The decorator sees
- the actual return value at runtime and unpacks its fields for logging.
- KFP's own executor processes type annotations separately for
- serialization, so runtime value detection is the correct approach here.
- Args:
- x: The value to check.
- Returns:
- True if `x` is a NamedTuple instance.
- """
- t = type(x)
- if not issubclass(t, tuple):
- return False
- fields = getattr(t, "_fields", None)
- if not isinstance(fields, tuple):
- return False
- return all(isinstance(n, str) for n in fields)
def _is_output_annotation(ann: Any) -> bool:
    """Tell whether `ann` marks a KFP output (`Output[...]` or `OutputPath`).

    Args:
        ann: A raw annotation taken from the component's `__annotations__`.

    Returns:
        True for output annotations.
    """
    wrapped = is_artifact_wrapped_in_Output(ann)
    return wrapped if wrapped else isinstance(ann, OutputPath)
def _is_input_annotation(ann: Any) -> bool:
    """Tell whether `ann` marks a KFP input (`Input[...]` or `InputPath`).

    Args:
        ann: A raw annotation taken from the component's `__annotations__`.

    Returns:
        True for input annotations.
    """
    wrapped = is_artifact_wrapped_in_Input(ann)
    return wrapped if wrapped else isinstance(ann, InputPath)
def _get_artifact_path(value: Any) -> str | None:
    """Resolve `value` to a local filesystem path that exists.

    Args:
        value: Either a KFP artifact object, whose `.path` attribute is used,
            or a plain string path (presumably what KFP passes for
            `InputPath` / `OutputPath` parameters).

    Returns:
        The path when it exists on disk. None when it does not exist, or
        when `value` is neither an artifact nor a string.
    """
    if isinstance(value, kfp.dsl.Artifact):
        candidate = value.path
    elif isinstance(value, str):
        candidate = value
    else:
        return None
    return candidate if os.path.exists(candidate) else None
def _log_artifact(
    run: wandb.Run,
    name: str,
    value: Any,
    *,
    use: bool = False,
) -> bool:
    """Log or use a single KFP artifact as a W&B Artifact.

    A KFP artifact path may point at a single file or at a directory (for
    example, a model saved as a directory). Files are added with
    `Artifact.add_file` and directories with `Artifact.add_dir`, because
    `add_file` rejects anything that is not a regular file.

    Args:
        run: The active W&B run.
        name: Artifact name.
        value: A KFP artifact or string path.
        use: If True, call `run.use_artifact` (for inputs); otherwise
            call `run.log_artifact` (for outputs).

    Returns:
        True on success, False if the artifact path is missing.
    """
    path = _get_artifact_path(value)
    if path is None:
        return False
    artifact = wandb.Artifact(name, type="kfp_artifact")
    # `add_file` raises for directories, so choose the method by path kind.
    if os.path.isdir(path):
        artifact.add_dir(path)
    else:
        artifact.add_file(path)
    if use:
        run.use_artifact(artifact)
        wandb.termlog(f"Using artifact: {name}")
    else:
        run.log_artifact(artifact)
        wandb.termlog(f"Logging artifact: {name}")
    return True
class _KfpWandbLogger:
    """Classifies a KFP component's annotations and logs its I/O to W&B.

    At decoration time, the function's type annotations are inspected to
    partition its parameters into scalar inputs, artifact inputs, and
    artifact outputs. Only parameter names are kept; annotation values are
    not needed after classification.

    Names are stored as tuples in parameter-declaration order (the order of
    `func.__annotations__`). Config writes, log messages and artifact
    uploads therefore happen in a stable, reproducible order. A `set` of
    strings would iterate in an order that varies with per-process hash
    randomization.

    Args:
        func: The KFP component function to classify.
    """

    def __init__(self, func: Callable) -> None:
        scalars_in: list[str] = []
        artifacts_in: list[str] = []
        artifacts_out: list[str] = []
        for name, ann in func.__annotations__.items():
            if name == "return":
                # Not a parameter; return values are logged from the runtime
                # result in `log_outputs`.
                continue
            if _is_output_annotation(ann):
                artifacts_out.append(name)
            elif _is_input_annotation(ann):
                artifacts_in.append(name)
            else:
                # Anything not marked as a KFP input/output artifact is
                # recorded in the run config.
                scalars_in.append(name)
        self._scalars_in: tuple[str, ...] = tuple(scalars_in)
        self._artifacts_in: tuple[str, ...] = tuple(artifacts_in)
        self._artifacts_out: tuple[str, ...] = tuple(artifacts_out)

    def log_inputs(self, run: wandb.Run, bound_args: dict[str, Any]) -> None:
        """Log scalar configs and input artifacts for a component invocation.

        Args:
            run: The active W&B run.
            bound_args: Bound arguments from `inspect.Signature.bind`.
        """
        for name in self._scalars_in:
            if name in bound_args:
                value = bound_args[name]
                run.config[name] = value
                wandb.termlog(f"Setting config: {name} to {value}")
        for name in self._artifacts_in:
            if name in bound_args:
                # Artifact logging is best-effort: warn and continue rather
                # than fail the component.
                try:
                    _log_artifact(run, name, bound_args[name], use=True)
                except Exception as e:
                    wandb.termwarn(f"Failed to log input artifact '{name}': {e}")

    def log_outputs(
        self,
        run: wandb.Run,
        func_name: str,
        result: Any,
        bound_args: dict[str, Any],
    ) -> None:
        """Log scalar results and output artifacts for a component invocation.

        A NamedTuple result is logged field by field under
        `"<func_name>.<field>"`; any other non-None result is logged under
        `func_name`.

        Args:
            run: The active W&B run.
            func_name: The component function's name (used as log key prefix).
            result: The return value of the component function.
            bound_args: Bound arguments from `inspect.Signature.bind`.
        """
        # Don't log scalars into a run that has already been finished.
        if result is not None and not run._is_finished:
            if _is_namedtuple(result):
                run.log({f"{func_name}.{k}": v for k, v in zip(result._fields, result)})
            else:
                run.log({func_name: result})
        for name in self._artifacts_out:
            if name in bound_args:
                # Best-effort, same as input artifacts.
                try:
                    _log_artifact(run, name, bound_args[name], use=False)
                except Exception as e:
                    wandb.termwarn(f"Failed to log output artifact '{name}': {e}")
def wandb_log(
    func: Callable | None = None,
) -> Callable:
    """Decorate a KFP v2 component function so each call logs to W&B.

    Works with `kfp>=2.0.0`. Every invocation opens a W&B run whose job type
    is the function's name. Scalar inputs are recorded in `wandb.config`,
    returned value(s) are logged with `wandb.log`, and parameters annotated
    with KFP's `Input` / `Output` types are stored as W&B Artifacts.

    The decorator can be applied bare (`@wandb_log`) or called first
    (`@wandb_log()`).

    Example:
        ```python
        from kfp import dsl
        from wandb.integration.kfp import wandb_log

        @dsl.component
        @wandb_log
        def add(a: float, b: float) -> float:
            return a + b
        ```

    Args:
        func: The component function when used bare; None when the
            decorator is called with no arguments.

    Returns:
        The wrapped function, or the decorator itself when `func` is None.
    """

    def decorator(func: Callable) -> Callable:
        logger = _KfpWandbLogger(func)
        func_sig = signature(func)

        @wraps(func)
        def instrumented(*args: Any, **kwargs: Any) -> Any:
            call = func_sig.bind(*args, **kwargs)
            call.apply_defaults()

            # Pick a run group, checked in priority order:
            #   WANDB_RUN_GROUP    - standard W&B env var for grouping runs.
            #   KFP_RUN_NAME       - set by the KFP orchestrator at container runtime.
            #   ARGO_WORKFLOW_NAME - set by Argo Workflows (KFP's execution backend).
            # Stop at the first non-empty value; if none is set, `group` keeps
            # the last lookup's result (None or "").
            group = None
            for env_var in ("WANDB_RUN_GROUP", "KFP_RUN_NAME", "ARGO_WORKFLOW_NAME"):
                group = os.getenv(env_var)
                if group:
                    break

            with wandb.init(job_type=func.__name__, group=group) as run:
                kubeflow_url = os.getenv("WANDB_KUBEFLOW_URL")
                if kubeflow_url:
                    run.config["LINK_TO_KUBEFLOW"] = kubeflow_url
                logger.log_inputs(run, call.arguments)
                with wb_telemetry.context(run=run) as tel:
                    tel.feature.kfp_wandb_log = True
                outcome = func(*call.args, **call.kwargs)
                logger.log_outputs(run, func.__name__, outcome, call.arguments)
            return outcome

        # Marker read by kfp_patch.py to recognise decorated functions for
        # wandb package injection and decorator source serialization.
        instrumented._wandb_logged = True
        # KFP's executor discovers component parameters with
        # inspect.getfullargspec(). Without an explicit __signature__, it would
        # see (*args, **kwargs) instead of the real function signature.
        instrumented.__signature__ = func_sig
        return instrumented

    return decorator if func is None else decorator(func)
|