dag_operation_future.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. from abc import ABC, abstractmethod
  2. from typing import Any, Dict, Generic, TypeVar
  3. from ray.experimental.channel.accelerator_context import AcceleratorContext
  4. from ray.util.annotations import DeveloperAPI
  5. T = TypeVar("T")
  6. @DeveloperAPI
  7. class DAGOperationFuture(ABC, Generic[T]):
  8. """
  9. A future representing the result of a DAG operation.
  10. This is an abstraction that is internal to each actor,
  11. and is not exposed to the DAG caller.
  12. """
  13. @abstractmethod
  14. def wait(self):
  15. """
  16. Wait for the future and return the result of the operation.
  17. """
  18. raise NotImplementedError
  19. @DeveloperAPI
  20. class ResolvedFuture(DAGOperationFuture):
  21. """
  22. A future that is already resolved. Calling `wait()` on this will
  23. immediately return the result without blocking.
  24. """
  25. def __init__(self, result):
  26. """
  27. Initialize a resolved future.
  28. Args:
  29. result: The result of the future.
  30. """
  31. self._result = result
  32. def wait(self):
  33. """
  34. Wait and immediately return the result. This operation will not block.
  35. """
  36. return self._result
  37. @DeveloperAPI
  38. class GPUFuture(DAGOperationFuture[Any]):
  39. """
  40. A future for a GPU event on a CUDA stream.
  41. This future wraps a buffer, and records an event on the given stream
  42. when it is created. When the future is waited on, it makes the current
  43. CUDA stream wait on the event, then returns the buffer.
  44. The buffer must be a GPU tensor produced by an earlier operation launched
  45. on the given stream, or it could be CPU data. Then the future guarantees
  46. that when the wait() returns, the buffer is ready on the current stream.
  47. The `wait()` does not block CPU.
  48. """
  49. # Caching GPU futures ensures CUDA events associated with futures are properly
  50. # destroyed instead of relying on garbage collection. The CUDA event contained
  51. # in a GPU future is destroyed right before removing the future from the cache.
  52. # The dictionary key is the future ID, which is the task idx of the dag operation
  53. # that produced the future. When a future is created, it is immediately added to
  54. # the cache. When a future has been waited on, it is removed from the cache.
  55. # When adding a future, if its ID is already a key in the cache, the old future
  56. # is removed. This can happen when an exception is thrown in a previous execution
  57. # of the dag, in which case the old future is never waited on.
  58. # Upon dag teardown, all pending futures produced by the dag are removed.
  59. gpu_futures: Dict[int, "GPUFuture"] = {}
  60. @staticmethod
  61. def add_gpu_future(fut_id: int, fut: "GPUFuture") -> None:
  62. """
  63. Cache the GPU future.
  64. Args:
  65. fut_id: GPU future ID.
  66. fut: GPU future to be cached.
  67. """
  68. if fut_id in GPUFuture.gpu_futures:
  69. # The old future was not waited on because of an execution exception.
  70. GPUFuture.gpu_futures.pop(fut_id).destroy_event()
  71. GPUFuture.gpu_futures[fut_id] = fut
  72. @staticmethod
  73. def remove_gpu_future(fut_id: int) -> None:
  74. """
  75. Remove the cached GPU future and destroy its CUDA event.
  76. Args:
  77. fut_id: GPU future ID.
  78. """
  79. if fut_id in GPUFuture.gpu_futures:
  80. GPUFuture.gpu_futures.pop(fut_id).destroy_event()
  81. def __init__(self, buf: Any, fut_id: int, stream: Any = None):
  82. """
  83. Initialize a GPU future on the given stream.
  84. Args:
  85. buf: The buffer to return when the future is resolved.
  86. fut_id: The future ID to cache the future.
  87. stream: The torch stream to record the event on, this event is waited
  88. on when the future is resolved. If None, the current stream is used.
  89. """
  90. if stream is None:
  91. stream = AcceleratorContext.get().current_stream()
  92. self._buf = buf
  93. self._event = AcceleratorContext.get().create_event()
  94. self._event.record(stream)
  95. self._fut_id = fut_id
  96. self._waited: bool = False
  97. # Cache the GPU future such that its CUDA event is properly destroyed.
  98. GPUFuture.add_gpu_future(fut_id, self)
  99. def wait(self) -> Any:
  100. """
  101. Wait for the future on the current CUDA stream and return the result from
  102. the GPU operation. This operation does not block CPU.
  103. """
  104. current_stream = AcceleratorContext.get().current_stream()
  105. if not self._waited:
  106. self._waited = True
  107. current_stream.wait_event(self._event)
  108. # Destroy the CUDA event after it is waited on.
  109. GPUFuture.remove_gpu_future(self._fut_id)
  110. return self._buf
  111. def destroy_event(self) -> None:
  112. """
  113. Destroy the CUDA event associated with this future.
  114. """
  115. if self._event is None:
  116. return
  117. self._event = None