green_contexts.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import torch
  2. _GreenContext = object
  3. SUPPORTED = False
  4. if hasattr(torch._C, "_CUDAGreenContext"):
  5. _GreenContext = torch._C._CUDAGreenContext # type: ignore[misc]
  6. SUPPORTED = True
  7. # Python shim helps Sphinx process docstrings more reliably.
  8. # pyrefly: ignore [invalid-inheritance]
  9. class GreenContext(_GreenContext):
  10. r"""Wrapper around a CUDA green context.
  11. .. warning::
  12. This API is in beta and may change in future releases.
  13. """
  14. @staticmethod
  15. def create(num_sms: int, device_id: int = 0) -> _GreenContext:
  16. r"""Create a CUDA green context.
  17. Arguments:
  18. num_sms (int): The number of SMs to use in the green context.
  19. device_id (int, optional): The device index of green context.
  20. """
  21. if not SUPPORTED:
  22. raise RuntimeError("PyTorch was not built with Green Context support!")
  23. return _GreenContext.create(num_sms, device_id) # type: ignore[attr-defined]
  24. # Note that these functions are bypassed by we define them here
  25. # for Sphinx documentation purposes
  26. def set_context(self) -> None: # pylint: disable=useless-parent-delegation
  27. r"""Make the green context the current context."""
  28. return super().set_context() # type: ignore[misc]
  29. def pop_context(self) -> None: # pylint: disable=useless-parent-delegation
  30. r"""Assuming the green context is the current context, pop it from the
  31. context stack and restore the previous context.
  32. """
  33. return super().pop_context() # type: ignore[misc]
  34. def Stream(self) -> torch.Stream:
  35. r"""Return the CUDA Stream used by the green context."""
  36. return super().Stream()