# kfp_patch.py (8.9 KB)
  1. from __future__ import annotations
  2. import inspect
  3. import itertools
  4. import textwrap
  5. from collections.abc import Mapping
  6. from typing import Callable
  7. import wandb
  8. from ._patch_utils import patch, unpatch
# Detect once, at import time, whether the installed kfp is a 2.x release.
try:
    from kfp import __version__ as kfp_version
    from packaging.version import parse

    _KFP_V2 = parse(kfp_version) >= parse("2.0.0")
except (ImportError, ValueError):
    # kfp or packaging could not be imported, or the version comparison
    # raised a ValueError: fall back to the v1 code path.
    _KFP_V2 = False
# Build _wandb_logging_extras: the decorator source injected into KFP
# container scripts at compile time. Both v1 and v2 follow the same
# pattern: an import preamble + the serialized decorator code.
#
# The three names below are defaults; they are only overwritten when the
# matching KFP internals import successfully.
_log_module = None
_import_preamble = ""
_component_factory = None

if _KFP_V2:
    try:
        from kfp.dsl import component_factory as _component_factory
    except ImportError:
        # _component_factory stays None here, so _patch_kfp_v2 returns early
        # and _wandb_logging_extras ends up empty.
        wandb.termerror(
            "kfp>=2.0.0 detected but failed to import kfp internals. "
            "Please ensure kfp is installed correctly."
        )
    else:
        from . import wandb_log_v2 as _log_module

        # Imports placed ahead of the serialized v2 decorator source.
        _import_preamble = """\
import os
import typing
from typing import Any, NamedTuple
import wandb"""
else:
    try:
        # These names are used by the v1 replacement functions defined
        # further down in this module.
        from kfp import __version__ as kfp_version
        from kfp.components import structures
        from kfp.components._components import _create_task_factory_from_component_spec
        from kfp.components._python_op import _func_to_component_spec
        from packaging.version import parse

        MIN_KFP_VERSION = "1.6.1"

        # Older kfp releases only trigger a warning, not a failure.
        if parse(kfp_version) < parse(MIN_KFP_VERSION):
            wandb.termwarn(
                f"Your version of kfp {kfp_version} may not work. "
                f"This integration requires kfp>={MIN_KFP_VERSION}"
            )
    except ImportError:
        wandb.termerror("kfp not found! Please `pip install kfp`")

    # NOTE(review): the v1 log module is imported even when the kfp import
    # above failed and only logged an error - confirm this is intended.
    from . import wandb_log_v1 as _log_module

    # Imports placed ahead of the serialized v1 decorator source.
    _import_preamble = """\
import typing
from typing import NamedTuple
import collections
from collections import namedtuple
import kfp
from kfp import components
from kfp.components import InputPath, OutputPath
import wandb"""

if _log_module:
    # Serialize the decorator's own source text so it can be embedded
    # verbatim after the import preamble.
    _decorator_code = inspect.getsource(_log_module.wandb_log)
    _wandb_logging_extras = f"{_import_preamble}\n\n{_decorator_code}\n"
else:
    _wandb_logging_extras = ""
  66. # ---------------------------------------------------------------------------
  67. # v1 patch functions
  68. # ---------------------------------------------------------------------------
  69. def _unpatch_kfp_v1() -> None:
  70. """Remove v1 monkey-patches from kfp.components."""
  71. unpatch("kfp.components")
  72. unpatch("kfp.components._python_op")
  73. unpatch("wandb.integration.kfp")
  74. def _patch_kfp_v1() -> None:
  75. """Apply v1 monkey-patches to kfp.components."""
  76. to_patch = [
  77. ("kfp.components", _v1_create_component_from_func),
  78. ("kfp.components._python_op", _v1_create_component_from_func),
  79. ("kfp.components._python_op", _v1_get_function_source_definition),
  80. ("kfp.components._python_op", _v1_strip_type_hints),
  81. ]
  82. successes = []
  83. for module_name, func in to_patch:
  84. success = patch(module_name, func)
  85. successes.append(success)
  86. if not all(successes):
  87. wandb.termerror(
  88. "Failed to patch one or more kfp functions. "
  89. "Patching @wandb_log decorator to no-op."
  90. )
  91. patch("wandb.integration.kfp", _v1_wandb_log_noop)
  92. def _v1_wandb_log_noop(
  93. func: Callable | None = None,
  94. log_component_file: bool = True,
  95. ) -> Callable:
  96. """No-op fallback decorator used when v1 patching fails."""
  97. from functools import wraps
  98. def decorator(func: Callable) -> Callable:
  99. @wraps(func)
  100. def wrapper(*args, **kwargs):
  101. return func(*args, **kwargs)
  102. return wrapper
  103. if func is None:
  104. return decorator
  105. else:
  106. return decorator(func)
  107. def _v1_get_function_source_definition(func: Callable) -> str:
  108. """Get the source code of a function, preserving `@wandb_log`.
  109. Modified from KFP v1. Original source:
  110. https://github.com/kubeflow/pipelines/blob/b6406b02f45cdb195c7b99e2f6d22bf85b12268b/sdk/python/kfp/components/_python_op.py#L300-L319
  111. Args:
  112. func: The function whose source to extract.
  113. Returns:
  114. The dedented source code starting from `@wandb_log` or `def`.
  115. Raises:
  116. ValueError: If the source cannot be cleaned up.
  117. """
  118. func_code = inspect.getsource(func)
  119. func_code = textwrap.dedent(func_code)
  120. func_code_lines = func_code.split("\n")
  121. func_code_lines = itertools.dropwhile(
  122. lambda x: not (x.startswith(("def", "@wandb_log"))),
  123. func_code_lines,
  124. )
  125. if not func_code_lines:
  126. raise ValueError(
  127. f'Failed to dedent and clean up the source of function "{func.__name__}". '
  128. "It is probably not properly indented."
  129. )
  130. return "\n".join(func_code_lines)
  131. def _v1_create_component_from_func(
  132. func: Callable,
  133. output_component_file: str | None = None,
  134. base_image: str | None = None,
  135. packages_to_install: list[str] | None = None,
  136. annotations: Mapping[str, str] | None = None,
  137. ) -> Callable:
  138. """Convert a Python function to a KFP v1 component task factory.
  139. Modified from KFP v1. Original source:
  140. https://github.com/kubeflow/pipelines/blob/b6406b02f45cdb195c7b99e2f6d22bf85b12268b/sdk/python/kfp/components/_python_op.py#L998-L1110
  141. Args:
  142. func: The python function to convert.
  143. output_component_file: Write a component definition to a local file.
  144. base_image: Custom Docker container image for the component.
  145. packages_to_install: Python packages to pip install before execution.
  146. annotations: Arbitrary key-value data for the component specification.
  147. Returns:
  148. A factory function with a strongly-typed signature taken from the
  149. python function.
  150. """
  151. core_packages = ["wandb", "kfp"]
  152. if not packages_to_install:
  153. packages_to_install = core_packages
  154. else:
  155. packages_to_install += core_packages
  156. component_spec = _func_to_component_spec(
  157. func=func,
  158. extra_code=_wandb_logging_extras,
  159. base_image=base_image,
  160. packages_to_install=packages_to_install,
  161. )
  162. if annotations:
  163. component_spec.metadata = structures.MetadataSpec(
  164. annotations=annotations,
  165. )
  166. if output_component_file:
  167. component_spec.save(output_component_file)
  168. return _create_task_factory_from_component_spec(component_spec)
  169. def _v1_strip_type_hints(source_code: str) -> str:
  170. """No-op replacement that preserves type hints in component source.
  171. Modified from KFP v1. Original source:
  172. https://github.com/kubeflow/pipelines/blob/b6406b02f45cdb195c7b99e2f6d22bf85b12268b/sdk/python/kfp/components/_python_op.py#L237-L248
  173. Args:
  174. source_code: The source code string.
  175. Returns:
  176. The source code unchanged.
  177. """
  178. return source_code
# Rename the v1 replacements to the names of the kfp functions they stand in
# for. Presumably `patch` uses `func.__name__` to pick the attribute to
# replace on the target module - confirm against `_patch_utils`.
_v1_get_function_source_definition.__name__ = "_get_function_source_definition"
_v1_create_component_from_func.__name__ = "create_component_from_func"
_v1_strip_type_hints.__name__ = "strip_type_hints"
  182. # ---------------------------------------------------------------------------
  183. # v2 patch functions (delegated to _kfp_v2_patch module)
  184. # ---------------------------------------------------------------------------
  185. def _unpatch_kfp_v2() -> None:
  186. """Remove v2 monkey-patches from kfp.dsl.component_factory."""
  187. unpatch("kfp.dsl.component_factory")
  188. def _patch_kfp_v2() -> None:
  189. """Apply v2 monkey-patches to kfp.dsl.component_factory."""
  190. if _component_factory is None:
  191. return
  192. from . import _kfp_v2_patch
  193. _kfp_v2_patch._orig_create = _component_factory.create_component_from_func
  194. _kfp_v2_patch._orig_get_cmd = (
  195. _component_factory._get_command_and_args_for_lightweight_component
  196. )
  197. _kfp_v2_patch._wandb_logging_extras = _wandb_logging_extras
  198. to_patch = [
  199. ("kfp.dsl.component_factory", _kfp_v2_patch.get_function_source_definition),
  200. ("kfp.dsl.component_factory", _kfp_v2_patch.create_component_from_func),
  201. (
  202. "kfp.dsl.component_factory",
  203. _kfp_v2_patch.get_command_and_args_for_lightweight_component,
  204. ),
  205. ]
  206. successes = []
  207. for module_name, func in to_patch:
  208. success = patch(module_name, func)
  209. successes.append(success)
  210. if not all(successes):
  211. wandb.termerror(
  212. "Failed to patch one or more kfp v2 functions. "
  213. "@wandb_log may not work correctly with @dsl.component."
  214. )
  215. # ---------------------------------------------------------------------------
  216. # Public API
  217. # ---------------------------------------------------------------------------
  218. def unpatch_kfp() -> None:
  219. """Undo all KFP monkey-patches applied by `patch_kfp`."""
  220. if _KFP_V2:
  221. _unpatch_kfp_v2()
  222. else:
  223. _unpatch_kfp_v1()
  224. def patch_kfp() -> None:
  225. """Apply KFP monkey-patches for the detected KFP version."""
  226. if _KFP_V2:
  227. _patch_kfp_v2()
  228. else:
  229. _patch_kfp_v1()