| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- import importlib
- import os
- import inspect
- import sys
- from dataclasses import dataclass
- from typing import Type, TypeVar, Union
- from types import ModuleType
- from .driver import DriverBase
- from .compiler import BaseBackend
- if sys.version_info >= (3, 10):
- from importlib.metadata import entry_points
- else:
- from importlib_metadata import entry_points
# Type variable bound to ``Union[BaseBackend, DriverBase]``; it lets
# ``_find_concrete_subclasses`` declare that the class it returns has the same
# base type as the ``base_class`` it was given.
T = TypeVar("T", bound=Union[BaseBackend, DriverBase])
def _find_concrete_subclasses(module: ModuleType, base_class: Type[T]) -> Type[T]:
    """Return the single concrete subclass of ``base_class`` exposed by ``module``.

    Every name listed by ``dir(module)`` is inspected. An attribute counts as a
    match when it is a class, is a subclass of ``base_class`` (``base_class``
    itself included) and is not abstract according to ``inspect.isabstract``.
    A class reachable under several attribute names (for example
    ``Alias = Impl``) is counted only once.

    Args:
        module: Module whose attributes are scanned.
        base_class: Class the result must derive from.

    Returns:
        The one matching class.

    Raises:
        RuntimeError: If no matching class, or more than one distinct matching
            class, is found in ``module``.
    """
    ret: list[Type[T]] = []
    for attr_name in dir(module):
        attr = getattr(module, attr_name)
        if not (isinstance(attr, type) and issubclass(attr, base_class)):
            continue
        if inspect.isabstract(attr):
            continue
        # Compare by identity so that one class bound to two names is not
        # reported as two different candidates.
        if any(attr is seen for seen in ret):
            continue
        ret.append(attr)
    if not ret:
        raise RuntimeError(f"Found 0 concrete subclasses of {base_class} in {module}: {ret}")
    if len(ret) > 1:
        raise RuntimeError(f"Found >1 concrete subclasses of {base_class} in {module}: {ret}")
    return ret[0]
@dataclass(frozen=True)
class Backend:
    """Immutable pairing of the compiler class and driver class of one backend."""

    # Compiler class; annotated as a subclass of ``BaseBackend``.
    compiler: Type[BaseBackend]
    # Driver class; annotated as a subclass of ``DriverBase``.
    driver: Type[DriverBase]
def _discover_backends() -> dict[str, Backend]:
    """Build the mapping from backend name to its ``Backend`` record.

    When the environment variable ``TRITON_BACKENDS_IN_TREE`` is exactly
    ``"1"``, entry-point discovery is skipped: every sub-directory next to
    this file whose name does not start with ``__`` is loaded as the package
    ``triton.backends.<name>`` and stored under ``<name>``.  Otherwise every
    entry point in the ``triton.backends`` group is loaded, using the entry
    point's value as the package path and its name as the key.

    For each package the ``compiler`` and ``driver`` submodules are imported
    and searched for their single concrete ``BaseBackend`` / ``DriverBase``
    subclass.  No exception is caught here, so import failures and errors
    raised by ``_find_concrete_subclasses`` propagate to the caller.

    Returns:
        Dictionary keyed by backend name.
    """

    def load(package: str) -> Backend:
        # Import both submodules first, then inspect them, compiler before
        # driver in each step.
        compiler_module = importlib.import_module(f"{package}.compiler")
        driver_module = importlib.import_module(f"{package}.driver")
        return Backend(
            _find_concrete_subclasses(compiler_module, BaseBackend),  # type: ignore
            _find_concrete_subclasses(driver_module, DriverBase),  # type: ignore
        )

    discovered: dict[str, Backend] = {}

    # Fast path: optionally skip entry point discovery (which can be slow) and
    # look only at the backend packages that live beside this file.
    if os.environ.get("TRITON_BACKENDS_IN_TREE", "") == "1":
        package_dir = os.path.dirname(__file__)
        for entry in os.listdir(package_dir):
            is_directory = os.path.isdir(os.path.join(package_dir, entry))
            # Names starting with a double underscore are skipped.
            if is_directory and not entry.startswith('__'):
                discovered[entry] = load(f"triton.backends.{entry}")
        return discovered

    # Default path: discover via entry points for out-of-tree/downstream plugins.
    for ep in entry_points().select(group="triton.backends"):
        discovered[ep.name] = load(ep.value)
    return discovered
# Module-level registry of discovered backends, keyed by backend name. It is
# filled by a single call to ``_discover_backends()`` when this module is
# imported.
backends: dict[str, Backend] = _discover_backends()
|