driver.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from abc import ABCMeta, abstractmethod
  2. from typing import Callable, List, Protocol, Sequence
  3. class Benchmarker(Protocol):
  4. def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:
  5. pass
  6. class DriverBase(metaclass=ABCMeta):
  7. @classmethod
  8. @abstractmethod
  9. def is_active(self):
  10. pass
  11. @abstractmethod
  12. def map_python_to_cpp_type(self, ty: str) -> str:
  13. """
  14. Converts a Triton type string to its corresponding C++ type string for this backend.
  15. Args:
  16. ty (str): The Triton type string. e.g., 'i32', '*fp16', 'fp32'.
  17. Returns:
  18. str: The C++ type string.
  19. """
  20. pass
  21. @abstractmethod
  22. def get_current_target(self):
  23. pass
  24. @abstractmethod
  25. def get_active_torch_device(self):
  26. pass
  27. @abstractmethod
  28. def get_benchmarker(self) -> Benchmarker:
  29. """
  30. Return the benchmarking function that this backend should use by default.
  31. """
  32. raise NotImplementedError
  33. def __init__(self) -> None:
  34. pass
  35. class GPUDriver(DriverBase):
  36. def __init__(self):
  37. # TODO: support other frameworks than torch
  38. import torch
  39. self.get_device_capability = torch.cuda.get_device_capability
  40. try:
  41. from torch._C import _cuda_getCurrentRawStream
  42. self.get_current_stream = _cuda_getCurrentRawStream
  43. except ImportError:
  44. self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream
  45. self.get_current_device = torch.cuda.current_device
  46. self.set_current_device = torch.cuda.set_device
  47. # TODO: remove once TMA is cleaned up
  48. def assemble_tensormap_to_arg(self, tensormaps_info, args):
  49. return args