| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- from abc import ABCMeta, abstractmethod
- from dataclasses import dataclass
- from enum import Enum
- from typing import Dict, Union
- from types import ModuleType
@dataclass(frozen=True)
class GPUTarget:
    """Immutable description of the device that code is compiled for.

    Because the dataclass is frozen, instances cannot be mutated after
    construction and are hashable and compared by field values.
    """

    # Backend name, e.g. "cuda" or "hip".
    backend: str
    # Architecture identifier: an int such as 90 (CUDA compute capability)
    # or a string such as "gfx940" (HIP).
    arch: Union[int, str]
    # Warp size of the target device.
    warp_size: int
class Language(Enum):
    """Identifies which front-end language a backend is compiling."""

    # Kernels written in Triton.
    TRITON = 0
    # Kernels written in Gluon.
    GLUON = 1
class BaseBackend(metaclass=ABCMeta):
    """Abstract base class for compiler backends.

    A backend is bound to a single ``GPUTarget`` at construction time and must
    implement target-support checks, hashing, option parsing, compilation
    stage registration, MLIR dialect loading, and interface-module mapping.
    The static ``parse_attr`` / ``get_*_specialization`` helpers provide a
    default divisibility-by-16 specialization (descriptor marker ``"D"``)
    that backends may reuse or override.
    """

    # NOTE(review): presumably read by the runtime to decide whether tensor
    # arguments go through `get_tensor_specialization`; confirm against callers.
    supports_native_tensor_specialization = True

    def __init__(self, target: GPUTarget) -> None:
        """Bind this backend to ``target``.

        Raises AssertionError if ``supports_target(target)`` is falsy. As with
        any ``assert``, the check is skipped when Python runs with ``-O``.
        """
        self.target = target
        assert self.supports_target(target), (
            f"{type(self).__name__} does not support target {target!r}"
        )

    @staticmethod
    @abstractmethod
    def supports_target(target: GPUTarget):
        """Return a truthy value if this backend can compile for ``target``."""
        raise NotImplementedError

    @abstractmethod
    def hash(self) -> str:
        """Return a unique identifier for this backend."""
        raise NotImplementedError

    @abstractmethod
    def parse_options(self, options: dict) -> object:
        """
        Convert an ``options`` dictionary into an arbitrary object and return it.

        This function may contain target-specific heuristics and check the
        legality of the provided options.
        """
        raise NotImplementedError

    @abstractmethod
    def add_stages(self, stages: dict, options: object) -> None:
        """
        Populate the ``stages`` dictionary with entries of the form:

            ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes]

        Each stage function may populate the ``metadata`` dictionary it receives.
        Stages are run sequentially (in insertion order) and can communicate
        through ``metadata``. All stages are expected to return a ``str``
        object, except for the last stage, which returns a ``bytes`` object
        for execution by the launcher.
        """
        raise NotImplementedError

    @abstractmethod
    def load_dialects(self, context):
        """
        Load additional MLIR dialects into the provided ``context``.
        """
        raise NotImplementedError

    @abstractmethod
    def get_module_map(self) -> Dict[str, ModuleType]:
        """
        Return a map of interface modules to their device-specific implementations.
        """
        raise NotImplementedError

    @staticmethod
    def parse_attr(desc):
        """Translate a specialization descriptor string into IR attributes.

        ``desc`` must be a ``str``; the presence of ``"D"`` (as produced by the
        ``get_*_specialization`` helpers) yields a divisibility-by-16 attribute.
        Returns a fresh list of ``[attribute_name, value]`` pairs:
        ``[["tt.divisibility", 16]]`` when ``"D"`` is present, otherwise ``[]``.
        """
        assert isinstance(desc, str), (
            f"specialization descriptor must be a str, got {type(desc).__name__}"
        )
        return [["tt.divisibility", 16]] if "D" in desc else []

    @staticmethod
    def get_int_specialization(arg, **kwargs):
        """Return ``"D"`` if integer ``arg`` should be specialized as divisible by 16.

        Specialization is opt-in via ``align=True``; when the flag is absent or
        falsy, or ``arg`` is not a multiple of 16, ``""`` is returned. Other
        keyword arguments are ignored here.
        """
        # Check the opt-in flag first so the modulo is only evaluated when
        # specialization is actually requested. The 16 must agree with the
        # value emitted by `parse_attr`.
        if kwargs.get("align", False) and arg % 16 == 0:
            return "D"
        return ""

    @staticmethod
    def get_tensor_specialization(arg, **kwargs):
        """Return ``"D"`` if tensor ``arg`` should be specialized as 16-aligned.

        ``arg`` must expose ``data_ptr()`` (presumably the address of its
        storage — confirm for non-torch tensor types). Specialization is opt-in
        via ``align=True``; otherwise, or when the pointer is not a multiple of
        16, ``""`` is returned. Other keyword arguments are ignored here.
        """
        # Check the opt-in flag first so `data_ptr()` is not called at all
        # when specialization is disabled. The 16 must agree with `parse_attr`.
        if kwargs.get("align", False) and arg.data_ptr() % 16 == 0:
            return "D"
        return ""
|