# wandb_log_v1.py
  1. def wandb_log( # noqa: C901
  2. func=None,
  3. # /, # py38 only
  4. log_component_file=True,
  5. ):
  6. """Wrap a kfp v1 python functional component and log to W&B.
  7. Requires kfp<2.0.0. Deprecated -- please upgrade to kfp>=2.0.0.
  8. """
  9. import json
  10. import os
  11. from functools import wraps
  12. from inspect import Parameter, signature
  13. from kfp import components
  14. from kfp.components import (
  15. InputArtifact,
  16. InputBinaryFile,
  17. InputPath,
  18. InputTextFile,
  19. OutputArtifact,
  20. OutputBinaryFile,
  21. OutputPath,
  22. OutputTextFile,
  23. )
  24. import wandb
  25. from wandb.proto.wandb_telemetry_pb2 import Deprecated
  26. from wandb.sdk.lib import telemetry as wb_telemetry
  27. from wandb.sdk.lib.deprecation import warn_and_record_deprecation
  28. output_types = (OutputArtifact, OutputBinaryFile, OutputPath, OutputTextFile)
  29. input_types = (InputArtifact, InputBinaryFile, InputPath, InputTextFile)
  30. def isinstance_namedtuple(x):
  31. t = type(x)
  32. b = t.__bases__
  33. if len(b) != 1 or b[0] is not tuple:
  34. return False
  35. f = getattr(t, "_fields", None)
  36. if not isinstance(f, tuple):
  37. return False
  38. return all(isinstance(n, str) for n in f)
  39. def get_iframe_html(run):
  40. return f'<iframe src="{run.url}?kfp=true" style="border:none;width:100%;height:100%;min-width:900px;min-height:600px;"></iframe>'
  41. def get_link_back_to_kubeflow():
  42. wandb_kubeflow_url = os.getenv("WANDB_KUBEFLOW_URL")
  43. return f"{wandb_kubeflow_url}/#/runs/details/{{workflow.uid}}"
  44. def log_input_scalar(name, data, run=None):
  45. run.config[name] = data
  46. wandb.termlog(f"Setting config: {name} to {data}")
  47. def log_input_artifact(name, data, type, run=None):
  48. artifact = wandb.Artifact(name, type=type)
  49. artifact.add_file(data)
  50. run.use_artifact(artifact)
  51. wandb.termlog(f"Using artifact: {name}")
  52. def log_output_scalar(name, data, run=None):
  53. if isinstance_namedtuple(data):
  54. for k, v in zip(data._fields, data):
  55. run.log({f"{func.__name__}.{k}": v})
  56. else:
  57. run.log({name: data})
  58. def log_output_artifact(name, data, type, run=None):
  59. artifact = wandb.Artifact(name, type=type)
  60. artifact.add_file(data)
  61. run.log_artifact(artifact)
  62. wandb.termlog(f"Logging artifact: {name}")
  63. def _log_component_file(func, run=None):
  64. name = func.__name__
  65. output_component_file = f"{name}.yml"
  66. components._python_op.func_to_component_file(func, output_component_file)
  67. artifact = wandb.Artifact(name, type="kubeflow_component_file")
  68. artifact.add_file(output_component_file)
  69. run.log_artifact(artifact)
  70. wandb.termlog(f"Logging component file: {output_component_file}")
  71. # Add `mlpipeline_ui_metadata_path` to signature to show W&B run in "ML Visualizations tab"
  72. sig = signature(func)
  73. no_default = []
  74. has_default = []
  75. for param in sig.parameters.values():
  76. if param.default is param.empty:
  77. no_default.append(param)
  78. else:
  79. has_default.append(param)
  80. new_params = tuple(
  81. (
  82. *no_default,
  83. Parameter(
  84. "mlpipeline_ui_metadata_path",
  85. annotation=OutputPath(),
  86. kind=Parameter.POSITIONAL_OR_KEYWORD,
  87. ),
  88. *has_default,
  89. )
  90. )
  91. new_sig = sig.replace(parameters=new_params)
  92. new_anns = {param.name: param.annotation for param in new_params}
  93. if "return" in func.__annotations__:
  94. new_anns["return"] = func.__annotations__["return"]
  95. def decorator(func):
  96. input_scalars = {}
  97. input_artifacts = {}
  98. output_scalars = {}
  99. output_artifacts = {}
  100. for name, ann in func.__annotations__.items():
  101. if name == "return":
  102. output_scalars[name] = ann
  103. elif isinstance(ann, output_types):
  104. output_artifacts[name] = ann
  105. elif isinstance(ann, input_types):
  106. input_artifacts[name] = ann
  107. else:
  108. input_scalars[name] = ann
  109. @wraps(func)
  110. def wrapper(*args, **kwargs):
  111. bound = new_sig.bind(*args, **kwargs)
  112. bound.apply_defaults()
  113. mlpipeline_ui_metadata_path = bound.arguments["mlpipeline_ui_metadata_path"]
  114. del bound.arguments["mlpipeline_ui_metadata_path"]
  115. with wandb.init(
  116. job_type=func.__name__,
  117. group="{{workflow.annotations.pipelines.kubeflow.org/run_name}}",
  118. ) as run:
  119. warn_and_record_deprecation(
  120. feature=Deprecated(kfp_v1_wandb_log=True),
  121. message=(
  122. "KFP v1 (kfp<2.0.0) support for @wandb_log is deprecated "
  123. "and will be removed in a future release. "
  124. "Please upgrade to kfp>=2.0.0."
  125. ),
  126. run=run,
  127. )
  128. kubeflow_url = get_link_back_to_kubeflow()
  129. run.notes = kubeflow_url
  130. run.config["LINK_TO_KUBEFLOW_RUN"] = kubeflow_url
  131. iframe_html = get_iframe_html(run)
  132. metadata = {
  133. "outputs": [
  134. {
  135. "type": "markdown",
  136. "storage": "inline",
  137. "source": iframe_html,
  138. }
  139. ]
  140. }
  141. with open(mlpipeline_ui_metadata_path, "w") as metadata_file:
  142. json.dump(metadata, metadata_file)
  143. if log_component_file:
  144. _log_component_file(func, run=run)
  145. for name, _ in input_scalars.items():
  146. log_input_scalar(name, kwargs[name], run)
  147. for name, ann in input_artifacts.items():
  148. log_input_artifact(name, kwargs[name], ann.type, run)
  149. with wb_telemetry.context(run=run) as tel:
  150. tel.feature.kfp_wandb_log = True
  151. result = func(*bound.args, **bound.kwargs)
  152. for name, _ in output_scalars.items():
  153. log_output_scalar(name, result, run)
  154. for name, ann in output_artifacts.items():
  155. log_output_artifact(name, kwargs[name], ann.type, run)
  156. return result
  157. wrapper.__signature__ = new_sig
  158. wrapper.__annotations__ = new_anns
  159. return wrapper
  160. if func is None:
  161. return decorator
  162. else:
  163. return decorator(func)