_compile.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. """
  2. APIs related to torch.compile which lazily import torch._dynamo to avoid
  3. circular dependencies.
  4. """
  5. import functools
  6. from collections.abc import Callable
  7. from typing import overload, TypeVar
  8. from typing_extensions import ParamSpec
  9. _T = TypeVar("_T")
  10. _P = ParamSpec("_P")
  11. @overload
  12. def _disable_dynamo(
  13. fn: Callable[_P, _T], recursive: bool = True
  14. ) -> Callable[_P, _T]: ...
  15. @overload
  16. def _disable_dynamo(
  17. fn: None = None, recursive: bool = True
  18. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ...
  19. def _disable_dynamo(
  20. fn: Callable[_P, _T] | None = None, recursive: bool = True
  21. ) -> Callable[_P, _T] | Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  22. """
  23. This API should be only used inside torch, external users should still use
  24. torch._dynamo.disable. The main goal of this API is to avoid circular
  25. imports issues that is common while using _dynamo.disable inside torch
  26. itself.
  27. This API avoids it by lazily importing torch._dynamo from the import time to
  28. the invocation of the decorated function.
  29. """
  30. if fn is not None:
  31. @functools.wraps(fn)
  32. def inner(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  33. # cache this on the first invocation to avoid adding too much overhead.
  34. disable_fn = getattr(fn, "__dynamo_disable", None)
  35. if disable_fn is None:
  36. import torch._dynamo
  37. # We can safely turn off functools.wraps here because the inner
  38. # already wraps fn in the outer scope.
  39. disable_fn = torch._dynamo.disable(fn, recursive, wrapping=False)
  40. fn.__dynamo_disable = disable_fn # type: ignore[attr-defined]
  41. return disable_fn(*args, **kwargs)
  42. return inner
  43. else:
  44. # decorator usage like @_disable_dynamo(recursive=False). The resulting
  45. # object expects the original decorated function as the arg.
  46. return functools.partial(_disable_dynamo, recursive=recursive)