wandb_require_helpers.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import os
  2. from functools import wraps
  3. from typing import Any, Callable, TypeVar, cast
  4. FuncT = TypeVar("FuncT", bound=Callable[..., Any])
  5. requirement_env_var_mapping: dict[str, str] = {
  6. "report-editing:v0": "WANDB_REQUIRE_REPORT_EDITING_V0"
  7. }
  8. def requires(requirement: str) -> FuncT: # type: ignore
  9. """Decorate functions to gate features with wandb.require."""
  10. env_var = requirement_env_var_mapping[requirement]
  11. def deco(func: FuncT) -> FuncT:
  12. @wraps(func)
  13. def wrapper(*args: Any, **kwargs: Any) -> Any:
  14. if not os.getenv(env_var):
  15. raise Exception(
  16. f"You need to enable this feature with `wandb.require({requirement!r})`"
  17. )
  18. return func(*args, **kwargs)
  19. return cast(FuncT, wrapper)
  20. return cast(FuncT, deco)
  21. class RequiresMixin:
  22. requirement = ""
  23. def __init__(self) -> None:
  24. self._check_if_requirements_met()
  25. def __post_init__(self) -> None:
  26. self._check_if_requirements_met()
  27. def _check_if_requirements_met(self) -> None:
  28. env_var = requirement_env_var_mapping[self.requirement]
  29. if not os.getenv(env_var):
  30. raise Exception(
  31. f'You must explicitly enable this feature with `wandb.require("{self.requirement})"'
  32. )