context.py 954 B

12345678910111213141516171819202122232425262728293031
  1. import functools
  2. from collections.abc import Callable
  3. from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI
  4. from torchgen.context import native_function_manager
  5. from torchgen.utils import T
  6. # Like tools.api.context.with_native_function, but for
  7. # NativeFunctionWithDifferentiabilityInfo.
  8. def with_native_function_with_differentiability_info(
  9. func: Callable[[NFWDI], T],
  10. ) -> Callable[[NFWDI], T]:
  11. @functools.wraps(func)
  12. def wrapper(f: NFWDI) -> T:
  13. with native_function_manager(f.func):
  14. return func(f)
  15. return wrapper
  16. # Like the above but with an additional dispatch key string argument
  17. def with_native_function_with_differentiability_info_and_key(
  18. func: Callable[[NFWDI, str], T],
  19. ) -> Callable[[NFWDI, str], T]:
  20. @functools.wraps(func)
  21. def wrapper(f: NFWDI, key: str) -> T:
  22. with native_function_manager(f.func):
  23. return func(f, key)
  24. return wrapper