_patch_utils.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. from __future__ import annotations
  2. from typing import Callable
  3. import wandb
  4. def full_path_exists(full_func: str) -> bool:
  5. """Return True if every component in a dotted path exists as a module attribute.
  6. Args:
  7. full_func: A dotted path such as `kfp.dsl.component_factory.create_component_from_func`.
  8. Returns:
  9. True if all intermediate modules and the final attribute exist.
  10. """
  11. components = full_func.split(".")
  12. for i in range(1, len(components)):
  13. parent = ".".join(components[:i])
  14. child = components[i]
  15. module = wandb.util.get_module(parent)
  16. if not module or not hasattr(module, child) or getattr(module, child) is None:
  17. return False
  18. return True
  19. def patch(module_name: str, func: Callable) -> bool:
  20. """Monkey-patch `func` onto `module_name`, keeping a backup for `unpatch`.
  21. Args:
  22. module_name: Dotted module path (e.g. `kfp.dsl.component_factory`).
  23. func: Replacement function. Its `__name__` must match the target
  24. attribute on the module.
  25. Returns:
  26. True if the patch was applied successfully.
  27. """
  28. module = wandb.util.get_module(module_name)
  29. success = False
  30. full_func = f"{module_name}.{func.__name__}"
  31. if not full_path_exists(full_func):
  32. wandb.termerror(
  33. f"Failed to patch {module_name}.{func.__name__}! "
  34. "Please check if this package/module is installed!"
  35. )
  36. else:
  37. wandb.patched.setdefault(module.__name__, [])
  38. if [module, func.__name__] not in wandb.patched[module.__name__]:
  39. setattr(module, f"orig_{func.__name__}", getattr(module, func.__name__))
  40. setattr(module, func.__name__, func)
  41. wandb.patched[module.__name__].append([module, func.__name__])
  42. success = True
  43. return success
  44. def unpatch(module_name: str) -> None:
  45. """Restore original functions previously replaced by `patch`.
  46. Args:
  47. module_name: Dotted module path that was previously patched.
  48. """
  49. if module_name in wandb.patched:
  50. for module, func in wandb.patched[module_name]:
  51. setattr(module, func, getattr(module, f"orig_{func}"))
  52. wandb.patched[module_name] = []