# __init__.py
  1. from __future__ import annotations
  2. __all__ = ["wandb_log", "unpatch_kfp"]
  3. from typing import TYPE_CHECKING, Any, Callable
  4. from .kfp_patch import patch_kfp, unpatch_kfp
  5. if TYPE_CHECKING:
  6. from typing import ParamSpec, TypeVar, overload
  7. _P = ParamSpec("_P")
  8. _T = TypeVar("_T")
  9. @overload
  10. def wandb_log(func: Callable[_P, _T]) -> Callable[_P, _T]: ...
  11. @overload
  12. def wandb_log(**kwargs: Any) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ...
  13. try:
  14. from kfp import __version__ as _kfp_version
  15. from packaging.version import parse
  16. _KFP_V2 = parse(_kfp_version) >= parse("2.0.0")
  17. except (ImportError, ValueError):
  18. _KFP_V2 = False
  19. def wandb_log(
  20. func: Callable | None = None,
  21. **kwargs: Any,
  22. ) -> Callable:
  23. """Decorator that wraps a KFP component function and logs to W&B.
  24. Automatically detects the installed KFP version and delegates to the
  25. appropriate implementation:
  26. - kfp >= 2.0.0: logs input parameters to `wandb.config`, output
  27. scalars via `wandb.log`, and Input/Output artifacts as W&B
  28. Artifacts.
  29. - kfp < 2.0.0 (deprecated): legacy v1 logging behaviour.
  30. Example:
  31. ```python
  32. from kfp import dsl
  33. from wandb.integration.kfp import wandb_log
  34. @dsl.component
  35. @wandb_log
  36. def add(a: float, b: float) -> float:
  37. return a + b
  38. ```
  39. """
  40. if _KFP_V2:
  41. from .wandb_log_v2 import wandb_log
  42. else:
  43. from .wandb_log_v1 import wandb_log
  44. return wandb_log(func, **kwargs)
# Import-time side effect: importing this package calls `patch_kfp()`.
# NOTE(review): `unpatch_kfp` (exported via __all__) presumably reverts this
# patch; confirm against `.kfp_patch`.
patch_kfp()