| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
- from abc import ABC, abstractmethod
- from typing import Any, Dict, Generic, TypeVar
- from ray.experimental.channel.accelerator_context import AcceleratorContext
- from ray.util.annotations import DeveloperAPI
T = TypeVar("T")


@DeveloperAPI
class DAGOperationFuture(ABC, Generic[T]):
    """
    Abstract handle to the eventual result of a DAG operation.

    Futures of this kind stay inside the actor that created them; they are
    never handed back to whoever invoked the DAG.
    """

    @abstractmethod
    def wait(self):
        """
        Resolve the future and hand back the operation's result.

        Concrete subclasses must override this; whether it blocks is up to
        the subclass.
        """
        raise NotImplementedError()
@DeveloperAPI
class ResolvedFuture(DAGOperationFuture):
    """
    A future whose result is already available when it is constructed.

    `wait()` never blocks: it just hands back the value stored at
    construction time.
    """

    def __init__(self, result):
        """
        Store a result that is already available.

        Args:
            result: The value that `wait()` will return.
        """
        self._result = result

    def wait(self):
        """
        Return the stored result right away, without blocking.
        """
        stored = self._result
        return stored
@DeveloperAPI
class GPUFuture(DAGOperationFuture[Any]):
    """
    A future for a GPU event on an accelerator (e.g. CUDA) stream.

    This future wraps a buffer and records an event on the given stream when
    it is created. When the future is waited on, it makes the current stream
    wait on the event, then returns the buffer.

    The buffer must be a GPU tensor produced by an earlier operation launched
    on the given stream, or CPU data. The future then guarantees that when
    `wait()` returns, the buffer is ready on the current stream.
    `wait()` does not block the CPU.
    """

    # Cache of pending GPU futures. Caching lets the event held by each future
    # be released explicitly via `destroy_event()` instead of whenever the
    # future object happens to be garbage collected.
    #
    # The key is the future ID, which is the task idx of the DAG operation that
    # produced the future.
    # - A future is added to the cache as soon as it is created.
    # - A future is removed (and its event released) once it has been waited on.
    # - When a future is added under an ID that is already cached, the old
    #   future is evicted and its event released. This happens when an
    #   exception was thrown in a previous execution of the DAG, so the old
    #   future was never waited on.
    # - Upon DAG teardown, all pending futures produced by the DAG are removed
    #   (by code outside this class).
    gpu_futures: Dict[int, "GPUFuture"] = {}

    @staticmethod
    def add_gpu_future(fut_id: int, fut: "GPUFuture") -> None:
        """
        Cache the GPU future, evicting any previous future with the same ID.

        Args:
            fut_id: GPU future ID.
            fut: GPU future to be cached.
        """
        old = GPUFuture.gpu_futures.pop(fut_id, None)
        if old is not None and old is not fut:
            # The old future was not waited on because of an execution
            # exception, so its event has to be released here.
            old.destroy_event()
        GPUFuture.gpu_futures[fut_id] = fut

    @staticmethod
    def remove_gpu_future(fut_id: int) -> None:
        """
        Remove the cached GPU future, if any, and release its event.

        Args:
            fut_id: GPU future ID. IDs that are not cached are ignored.
        """
        fut = GPUFuture.gpu_futures.pop(fut_id, None)
        if fut is not None:
            fut.destroy_event()

    def __init__(self, buf: Any, fut_id: int, stream: Any = None):
        """
        Initialize a GPU future on the given stream.

        Args:
            buf: The buffer to return when the future is resolved.
            fut_id: The future ID used to cache the future.
            stream: The stream to record the event on; this event is waited
                on when the future is resolved. If None, the current stream
                is used.
        """
        if stream is None:
            stream = AcceleratorContext.get().current_stream()
        self._buf = buf
        self._event = AcceleratorContext.get().create_event()
        self._event.record(stream)
        self._fut_id = fut_id
        # Set only after the current stream has successfully been made to
        # wait on the event.
        self._waited: bool = False
        # Cache the GPU future so that its event is released explicitly.
        GPUFuture.add_gpu_future(fut_id, self)

    def wait(self) -> Any:
        """
        Wait for the future on the current stream and return the buffer.

        The first successful call makes the current stream wait on the
        recorded event and then releases the event; later calls return the
        buffer immediately. This operation does not block the CPU.

        Returns:
            The buffer passed to the constructor.
        """
        if not self._waited:
            current_stream = AcceleratorContext.get().current_stream()
            current_stream.wait_event(self._event)
            # Mark as waited only after wait_event succeeded: if it raised, a
            # retry must not silently return a buffer that was never
            # synchronized with the producing stream.
            self._waited = True
            if GPUFuture.gpu_futures.get(self._fut_id) is self:
                # Release the event now that the stream is ordered after it.
                GPUFuture.remove_gpu_future(self._fut_id)
            else:
                # A newer future with the same ID has replaced this one in the
                # cache; evicting by ID would destroy that future's event.
                self.destroy_event()
        return self._buf

    def destroy_event(self) -> None:
        """
        Release the event associated with this future.

        This drops the future's reference to the event and is safe to call
        more than once.
        NOTE(review): no backend-level destroy call is made here, so the
        underlying event is released only when the event object itself is
        collected; confirm whether AcceleratorContext events need explicit
        destruction.
        """
        if self._event is None:
            return
        self._event = None
|