| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- from __future__ import annotations
- from typing import Callable
- import wandb
- def full_path_exists(full_func: str) -> bool:
- """Return True if every component in a dotted path exists as a module attribute.
- Args:
- full_func: A dotted path such as `kfp.dsl.component_factory.create_component_from_func`.
- Returns:
- True if all intermediate modules and the final attribute exist.
- """
- components = full_func.split(".")
- for i in range(1, len(components)):
- parent = ".".join(components[:i])
- child = components[i]
- module = wandb.util.get_module(parent)
- if not module or not hasattr(module, child) or getattr(module, child) is None:
- return False
- return True
- def patch(module_name: str, func: Callable) -> bool:
- """Monkey-patch `func` onto `module_name`, keeping a backup for `unpatch`.
- Args:
- module_name: Dotted module path (e.g. `kfp.dsl.component_factory`).
- func: Replacement function. Its `__name__` must match the target
- attribute on the module.
- Returns:
- True if the patch was applied successfully.
- """
- module = wandb.util.get_module(module_name)
- success = False
- full_func = f"{module_name}.{func.__name__}"
- if not full_path_exists(full_func):
- wandb.termerror(
- f"Failed to patch {module_name}.{func.__name__}! "
- "Please check if this package/module is installed!"
- )
- else:
- wandb.patched.setdefault(module.__name__, [])
- if [module, func.__name__] not in wandb.patched[module.__name__]:
- setattr(module, f"orig_{func.__name__}", getattr(module, func.__name__))
- setattr(module, func.__name__, func)
- wandb.patched[module.__name__].append([module, func.__name__])
- success = True
- return success
- def unpatch(module_name: str) -> None:
- """Restore original functions previously replaced by `patch`.
- Args:
- module_name: Dotted module path that was previously patched.
- """
- if module_name in wandb.patched:
- for module, func in wandb.patched[module_name]:
- setattr(module, func, getattr(module, f"orig_{func}"))
- wandb.patched[module_name] = []
|