| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
def wandb_log(  # noqa: C901
    func=None,
    # /,  # py38 only
    log_component_file=True,
):
    """Wrap a kfp v1 python functional component and log to W&B.

    Requires kfp<2.0.0. Deprecated -- please upgrade to kfp>=2.0.0.

    Usable both as a bare decorator (``@wandb_log``) and as a decorator
    factory (``@wandb_log(log_component_file=False)``).

    The returned wrapper exposes the component's signature plus one extra
    ``mlpipeline_ui_metadata_path`` parameter annotated with ``OutputPath()``.
    On each call it:

    1. opens a W&B run and records the kfp v1 deprecation,
    2. writes an iframe pointing at the run to ``mlpipeline_ui_metadata_path``,
    3. optionally logs the component file as an artifact,
    4. logs annotated inputs (plain annotations to ``run.config``, kfp input
       annotations as used artifacts),
    5. calls the component,
    6. logs the return value and any kfp output-annotated arguments.

    Args:
        func: The component function to wrap, or ``None`` when the decorator
            is invoked with options only.
        log_component_file: If true, write ``<func.__name__>.yml`` with kfp
            and log it as a ``kubeflow_component_file`` artifact.

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator
        that accepts the function.
    """
    import json
    import os
    from functools import wraps
    from inspect import Parameter, signature

    from kfp import components
    from kfp.components import (
        InputArtifact,
        InputBinaryFile,
        InputPath,
        InputTextFile,
        OutputArtifact,
        OutputBinaryFile,
        OutputPath,
        OutputTextFile,
    )

    import wandb
    from wandb.proto.wandb_telemetry_pb2 import Deprecated
    from wandb.sdk.lib import telemetry as wb_telemetry
    from wandb.sdk.lib.deprecation import warn_and_record_deprecation

    output_types = (OutputArtifact, OutputBinaryFile, OutputPath, OutputTextFile)
    input_types = (InputArtifact, InputBinaryFile, InputPath, InputTextFile)

    def isinstance_namedtuple(x):
        """Return True if ``x`` is an instance of a namedtuple-shaped class.

        The class must derive directly from ``tuple`` and carry a ``_fields``
        tuple made up only of strings.
        """
        t = type(x)
        b = t.__bases__
        if len(b) != 1 or b[0] is not tuple:
            return False
        f = getattr(t, "_fields", None)
        if not isinstance(f, tuple):
            return False
        return all(isinstance(n, str) for n in f)

    def get_iframe_html(run):
        """Return the iframe markup that embeds ``run.url``."""
        return (
            f'<iframe src="{run.url}?kfp=true" '
            'style="border:none;width:100%;height:100%;'
            'min-width:900px;min-height:600px;"></iframe>'
        )

    def get_link_back_to_kubeflow():
        """Build the run-details link from the ``WANDB_KUBEFLOW_URL`` env var."""
        wandb_kubeflow_url = os.getenv("WANDB_KUBEFLOW_URL")
        return f"{wandb_kubeflow_url}/#/runs/details/{{workflow.uid}}"

    def log_input_scalar(name, data, run):
        """Store an input value in the run config under ``name``."""
        run.config[name] = data
        wandb.termlog(f"Setting config: {name} to {data}")

    def log_input_artifact(name, data, artifact_type, run):
        """Mark the file at ``data`` as an artifact used by ``run``."""
        artifact = wandb.Artifact(name, type=artifact_type)
        artifact.add_file(data)
        run.use_artifact(artifact)
        wandb.termlog(f"Using artifact: {name}")

    def log_output_scalar(name, data, component_name, run):
        """Log a return value; namedtuple fields are logged one by one.

        ``component_name`` is passed in explicitly instead of being read from
        the enclosing ``func``, which is ``None`` in decorator-factory usage.
        """
        if isinstance_namedtuple(data):
            for k, v in zip(data._fields, data):
                run.log({f"{component_name}.{k}": v})
        else:
            run.log({name: data})

    def log_output_artifact(name, data, artifact_type, run):
        """Log the file at ``data`` as an output artifact of ``run``."""
        artifact = wandb.Artifact(name, type=artifact_type)
        artifact.add_file(data)
        run.log_artifact(artifact)
        wandb.termlog(f"Logging artifact: {name}")

    def _log_component_file(func, run):
        """Write the kfp component file for ``func`` and log it as an artifact."""
        name = func.__name__
        output_component_file = f"{name}.yml"
        components._python_op.func_to_component_file(func, output_component_file)
        artifact = wandb.Artifact(name, type="kubeflow_component_file")
        artifact.add_file(output_component_file)
        run.log_artifact(artifact)
        wandb.termlog(f"Logging component file: {output_component_file}")

    def add_ui_metadata_param(func):
        """Return ``(signature, annotations)`` for ``func`` plus the UI metadata path.

        Add `mlpipeline_ui_metadata_path` to signature to show W&B run in
        "ML Visualizations tab".
        """
        sig = signature(func)
        no_default = []
        has_default = []
        for param in sig.parameters.values():
            if param.default is param.empty:
                no_default.append(param)
            else:
                has_default.append(param)

        # The new parameter has no default, so it is placed after the other
        # required parameters and before the ones that carry defaults.
        new_params = (
            *no_default,
            Parameter(
                "mlpipeline_ui_metadata_path",
                annotation=OutputPath(),
                kind=Parameter.POSITIONAL_OR_KEYWORD,
            ),
            *has_default,
        )
        new_sig = sig.replace(parameters=new_params)
        new_anns = {param.name: param.annotation for param in new_params}
        if "return" in func.__annotations__:
            new_anns["return"] = func.__annotations__["return"]
        return new_sig, new_anns

    def decorator(func):
        # Inspect the function here, not in the enclosing scope: with
        # `@wandb_log(log_component_file=...)` the outer `func` is None and
        # `signature(None)` raises TypeError before a decorator is returned.
        new_sig, new_anns = add_ui_metadata_param(func)

        input_scalars = {}
        input_artifacts = {}
        output_scalars = {}
        output_artifacts = {}

        # Classify every annotated name by the kind of annotation it carries.
        for name, ann in func.__annotations__.items():
            if name == "return":
                output_scalars[name] = ann
            elif isinstance(ann, output_types):
                output_artifacts[name] = ann
            elif isinstance(ann, input_types):
                input_artifacts[name] = ann
            else:
                input_scalars[name] = ann

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound = new_sig.bind(*args, **kwargs)
            bound.apply_defaults()

            # The metadata path belongs to the wrapper only; remove it so it
            # is not forwarded to the wrapped component.
            mlpipeline_ui_metadata_path = bound.arguments.pop(
                "mlpipeline_ui_metadata_path"
            )

            # Read values from the bound arguments rather than raw `kwargs`
            # so positional arguments and applied defaults are found too.
            call_args = bound.arguments

            with wandb.init(
                job_type=func.__name__,
                group="{{workflow.annotations.pipelines.kubeflow.org/run_name}}",
            ) as run:
                warn_and_record_deprecation(
                    feature=Deprecated(kfp_v1_wandb_log=True),
                    message=(
                        "KFP v1 (kfp<2.0.0) support for @wandb_log is deprecated "
                        "and will be removed in a future release. "
                        "Please upgrade to kfp>=2.0.0."
                    ),
                    run=run,
                )

                kubeflow_url = get_link_back_to_kubeflow()
                run.notes = kubeflow_url
                run.config["LINK_TO_KUBEFLOW_RUN"] = kubeflow_url

                iframe_html = get_iframe_html(run)
                metadata = {
                    "outputs": [
                        {
                            "type": "markdown",
                            "storage": "inline",
                            "source": iframe_html,
                        }
                    ]
                }
                with open(mlpipeline_ui_metadata_path, "w") as metadata_file:
                    json.dump(metadata, metadata_file)

                if log_component_file:
                    _log_component_file(func, run=run)

                for name in input_scalars:
                    log_input_scalar(name, call_args[name], run)

                for name, ann in input_artifacts.items():
                    log_input_artifact(name, call_args[name], ann.type, run)

                with wb_telemetry.context(run=run) as tel:
                    tel.feature.kfp_wandb_log = True

                result = func(*bound.args, **bound.kwargs)

                for name in output_scalars:
                    log_output_scalar(name, result, func.__name__, run)

                for name, ann in output_artifacts.items():
                    log_output_artifact(name, call_args[name], ann.type, run)

            return result

        wrapper.__signature__ = new_sig
        wrapper.__annotations__ = new_anns
        return wrapper

    if func is None:
        return decorator
    return decorator(func)
|