utils.py 1.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. import contextlib
  2. from collections.abc import Generator
  3. from typing import Any, Union
  4. import torch
  5. from torch._C._functorch import (
  6. get_single_level_autograd_function_allowed,
  7. set_single_level_autograd_function_allowed,
  8. unwrap_if_dead,
  9. )
  10. from torch.utils._exposed_in import exposed_in
# Names exported by ``from <this module> import *``. ``exposed_in`` is not
# defined here; it is imported above from ``torch.utils._exposed_in`` and
# re-exported through this list.
__all__ = [
    "exposed_in",
    "argnums_t",
    "enable_single_level_autograd_function",
    "unwrap_dead_wrappers",
]
  17. @contextlib.contextmanager
  18. def enable_single_level_autograd_function() -> Generator[None, None, None]:
  19. try:
  20. prev_state = get_single_level_autograd_function_allowed()
  21. set_single_level_autograd_function_allowed(True)
  22. yield
  23. finally:
  24. set_single_level_autograd_function_allowed(prev_state)
  25. def unwrap_dead_wrappers(args: tuple[Any, ...]) -> tuple[Any, ...]:
  26. # NB: doesn't use tree_map_only for performance reasons
  27. result = tuple(
  28. unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
  29. )
  30. return result
# Type alias: either a single int or a tuple of ints.
# NOTE(review): presumably used to annotate ``argnums``-style parameters
# elsewhere — confirm against callers.
argnums_t = Union[int, tuple[int, ...]]