compiler.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. from abc import ABCMeta, abstractmethod
  2. from dataclasses import dataclass
  3. from enum import Enum
  4. from typing import Dict, Union
  5. from types import ModuleType
  6. @dataclass(frozen=True)
  7. class GPUTarget(object):
  8. # Target backend, e.g., cuda, hip
  9. backend: str
  10. # Target architecture, e.g., 90 (for cuda compute capability), gfx940 (for hip)
  11. arch: Union[int, str]
  12. warp_size: int
  13. class Language(Enum):
  14. """The input language being compiled by the backend."""
  15. TRITON = 0
  16. GLUON = 1
  17. class BaseBackend(metaclass=ABCMeta):
  18. supports_native_tensor_specialization = True
  19. def __init__(self, target: GPUTarget) -> None:
  20. self.target = target
  21. assert self.supports_target(target)
  22. @staticmethod
  23. @abstractmethod
  24. def supports_target(target: GPUTarget):
  25. raise NotImplementedError
  26. @abstractmethod
  27. def hash(self) -> str:
  28. """Returns a unique identifier for this backend"""
  29. raise NotImplementedError
  30. @abstractmethod
  31. def parse_options(self, options: dict) -> object:
  32. """
  33. Converts an `options` dictionary into an arbitrary object and returns it.
  34. This function may contain target-specific heuristics and check the legality of the provided options
  35. """
  36. raise NotImplementedError
  37. @abstractmethod
  38. def add_stages(self, stages: dict, options: object) -> None:
  39. """
  40. Populates `stages` dictionary with entries of the form:
  41. ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes]
  42. The value of each entry may populate a `metadata` dictionary.
  43. Stages will be run sequentially (in inseriton order) and can communicate using `metadata`.
  44. All stages are expected to return a `str` object, except for the last stage which returns
  45. a `bytes` object for execution by the launcher.
  46. """
  47. raise NotImplementedError
  48. @abstractmethod
  49. def load_dialects(self, context):
  50. """
  51. Load additional MLIR dialects into the provided `context`
  52. """
  53. raise NotImplementedError
  54. @abstractmethod
  55. def get_module_map(self) -> Dict[str, ModuleType]:
  56. """
  57. Return a map of interface modules to their device-specific implementations
  58. """
  59. raise NotImplementedError
  60. @staticmethod
  61. def parse_attr(desc):
  62. assert isinstance(desc, str)
  63. ret = []
  64. if "D" in desc:
  65. ret += [["tt.divisibility", 16]]
  66. return ret
  67. @staticmethod
  68. def get_int_specialization(arg, **kwargs):
  69. if arg % 16 == 0 and kwargs.get("align", False):
  70. return "D"
  71. return ""
  72. @staticmethod
  73. def get_tensor_specialization(arg, **kwargs):
  74. if arg.data_ptr() % 16 == 0 and kwargs.get("align", False):
  75. return "D"
  76. return ""