_async.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # mypy: allow-untyped-defs
  2. """Async API.
  3. This module contains the API for parallelism in TorchScript, notably:
  4. * torch.jit.fork
  5. * torch.jit.wait
  6. This is not intended to be imported directly; please use the exposed
  7. functionalities in `torch.jit`.
  8. """
  9. import warnings
  10. import torch
  11. from torch._jit_internal import Future
  12. from torch.jit._builtins import _register_builtin
  13. from torch.utils import set_module
# Associate Future with the public "torch.jit" namespace.
# NOTE(review): presumably set_module rewrites Future.__module__ so the class
# is reported as torch.jit.Future; confirm against torch.utils.set_module.
set_module(Future, "torch.jit")
  15. def fork(func, *args, **kwargs):
  16. r"""
  17. Create an asynchronous task executing `func` and a reference to the value of the result of this execution.
  18. .. deprecated:: 2.5
  19. TorchScript is deprecated, please use ``torch.compile`` instead.
  20. `fork` will return immediately, so the return value of `func` may not have been computed yet. To force completion
  21. of the task and access the return value invoke `torch.jit.wait` on the Future. `fork` invoked
  22. with a `func` which returns `T` is typed as `torch.jit.Future[T]`. `fork` calls can be arbitrarily
  23. nested, and may be invoked with positional and keyword arguments.
  24. Asynchronous execution will only occur when run in TorchScript. If run in pure python,
  25. `fork` will not execute in parallel. `fork` will also not execute in parallel when invoked
  26. while tracing, however the `fork` and `wait` calls will be captured in the exported IR Graph.
  27. .. warning::
  28. `fork` tasks will execute non-deterministically. We recommend only spawning
  29. parallel fork tasks for pure functions that do not modify their inputs,
  30. module attributes, or global state.
  31. Args:
  32. func (callable or torch.nn.Module): A Python function or `torch.nn.Module`
  33. that will be invoked. If executed in TorchScript, it will execute asynchronously,
  34. otherwise it will not. Traced invocations of fork will be captured in the IR.
  35. ``*args``, ``**kwargs``: arguments to invoke `func` with.
  36. Returns:
  37. `torch.jit.Future[T]`: a reference to the execution of `func`. The value `T`
  38. can only be accessed by forcing completion of `func` through `torch.jit.wait`.
  39. Example (fork a free function):
  40. .. code-block:: python
  41. import torch
  42. from torch import Tensor
  43. def foo(a: Tensor, b: int) -> Tensor:
  44. return a + b
  45. def bar(a):
  46. fut: torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
  47. return torch.jit.wait(fut)
  48. script_bar = torch.jit.script(bar)
  49. input = torch.tensor(2)
  50. # only the scripted version executes asynchronously
  51. assert script_bar(input) == bar(input)
  52. # trace is not run asynchronously, but fork is captured in IR
  53. graph = torch.jit.trace(bar, (input,)).graph
  54. assert "fork" in str(graph)
  55. Example (fork a module method):
  56. .. code-block:: python
  57. import torch
  58. from torch import Tensor
  59. class AddMod(torch.nn.Module):
  60. def forward(self, a: Tensor, b: int):
  61. return a + b
  62. class Mod(torch.nn.Module):
  63. def __init__(self) -> None:
  64. super(self).__init__()
  65. self.mod = AddMod()
  66. def forward(self, input):
  67. fut = torch.jit.fork(self.mod, a, b=2)
  68. return torch.jit.wait(fut)
  69. input = torch.tensor(2)
  70. mod = Mod()
  71. assert mod(input) == torch.jit.script(mod).forward(input)
  72. """
  73. warnings.warn(
  74. "`torch.jit.fork` is deprecated. Please use `torch.compile` instead.",
  75. DeprecationWarning,
  76. )
  77. return torch._C.fork(func, *args, **kwargs)
  78. def wait(future):
  79. r"""
  80. Force completion of a `torch.jit.Future[T]` asynchronous task, returning the result of the task.
  81. .. deprecated:: 2.5
  82. TorchScript is deprecated, please use ``torch.compile`` instead.
  83. See :func:`~fork` for docs and examples.
  84. Args:
  85. future (torch.jit.Future[T]): an asynchronous task reference, created through `torch.jit.fork`
  86. Returns:
  87. `T`: the return value of the completed task
  88. """
  89. warnings.warn(
  90. "`torch.jit.wait` is deprecated. Please use `torch.compile` instead.",
  91. DeprecationWarning,
  92. )
  93. return torch._C.wait(future)
# Register `wait` under the operator name "aten::wait".
# NOTE(review): presumably this lets scripted code resolve `torch.jit.wait`
# to that builtin op; confirm against torch.jit._builtins._register_builtin.
_register_builtin(wait, "aten::wait")