executor.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. from collections.abc import Callable
  2. from typing import Any, TypeVar
  3. from typing_extensions import ParamSpec, TypeVarTuple, Unpack
  4. from torch._prims.context import TorchRefsMode
  5. from torch.fx import GraphModule
  6. from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx
# Return type of the callable wrapped by ``make_traced``.
T = TypeVar("T")
# Parameter specification of the callable wrapped by ``make_traced``; lets the
# returned wrapper advertise the same call signature as the original callable.
P = ParamSpec("P")
# Variadic type of the positional arguments ``execute`` forwards to ``gm.forward``.
Ts = TypeVarTuple("Ts")
  10. def execute(
  11. gm: GraphModule,
  12. *args: Unpack[Ts],
  13. executor: str = "aten",
  14. executor_parameters: dict | None = None,
  15. ) -> Any:
  16. """
  17. Prototype ATen executor.
  18. Just executes the context's graph.
  19. """
  20. if executor == "aten":
  21. return gm.forward(*args)
  22. msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: aten."
  23. raise ValueError(msg)
  24. def make_traced(fn: Callable[P, T]) -> Callable[P, T]:
  25. """
  26. Returns a function that, when called, will
  27. trace its torch operations to prims and then
  28. execute those prims on the requested trace executor
  29. (possibly lowering them to that trace executor first).
  30. Only supports the torch operations defined in _torch_to_reference_map
  31. in context.py and operations with positional args. All args must
  32. be tensors.
  33. In the near future all these restrictions will be lifted.
  34. Example usage:
  35. def foo(a, b):
  36. return torch.add(a, b)
  37. traced_foo = make_traced(foo)
  38. a = torch.randn((1, 2, 3, 4, 5), device='cuda')
  39. b = torch.randn((1, 2, 3, 4, 5), device='cuda')
  40. result = traced_foo(a, b, executor='aten')
  41. """
  42. def _traced(*args: P.args, **kwargs: P.kwargs) -> T:
  43. executor = str(kwargs.pop("executor", "aten"))
  44. # TODO: caching
  45. wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs)
  46. with TorchRefsMode():
  47. gm = make_fx(wrapped)(all_args)
  48. return execute(gm, all_args, executor=executor)
  49. return _traced # type: ignore[return-value]