__init__.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import importlib
  2. import os
  3. import inspect
  4. import sys
  5. from dataclasses import dataclass
  6. from typing import Type, TypeVar, Union
  7. from types import ModuleType
  8. from .driver import DriverBase
  9. from .compiler import BaseBackend
  10. if sys.version_info >= (3, 10):
  11. from importlib.metadata import entry_points
  12. else:
  13. from importlib_metadata import entry_points
  14. T = TypeVar("T", bound=Union[BaseBackend, DriverBase])
  15. def _find_concrete_subclasses(module: ModuleType, base_class: Type[T]) -> Type[T]:
  16. ret: list[Type[T]] = []
  17. for attr_name in dir(module):
  18. attr = getattr(module, attr_name)
  19. if isinstance(attr, type) and issubclass(attr, base_class) and not inspect.isabstract(attr):
  20. ret.append(attr)
  21. if len(ret) == 0:
  22. raise RuntimeError(f"Found 0 concrete subclasses of {base_class} in {module}: {ret}")
  23. if len(ret) > 1:
  24. raise RuntimeError(f"Found >1 concrete subclasses of {base_class} in {module}: {ret}")
  25. return ret[0]
  26. @dataclass(frozen=True)
  27. class Backend:
  28. compiler: Type[BaseBackend]
  29. driver: Type[DriverBase]
  30. def _discover_backends() -> dict[str, Backend]:
  31. backends = dict()
  32. # Fast path: optionally skip entry point discovery (which can be slow) and
  33. # discover only in-tree backends under the `triton.backends` namespace.
  34. skip_entrypoints_env = os.environ.get("TRITON_BACKENDS_IN_TREE", "")
  35. if skip_entrypoints_env == "1":
  36. root = os.path.dirname(__file__)
  37. for name in os.listdir(root):
  38. if not os.path.isdir(os.path.join(root, name)):
  39. continue
  40. if name.startswith('__'):
  41. continue
  42. compiler = importlib.import_module(f"triton.backends.{name}.compiler")
  43. driver = importlib.import_module(f"triton.backends.{name}.driver")
  44. backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend),
  45. _find_concrete_subclasses(driver, DriverBase))
  46. return backends
  47. # Default path: discover via entry points for out-of-tree/downstream plugins.
  48. for ep in entry_points().select(group="triton.backends"):
  49. compiler = importlib.import_module(f"{ep.value}.compiler")
  50. driver = importlib.import_module(f"{ep.value}.driver")
  51. backends[ep.name] = Backend(_find_concrete_subclasses(compiler, BaseBackend), # type: ignore
  52. _find_concrete_subclasses(driver, DriverBase)) # type: ignore
  53. return backends
  54. backends: dict[str, Backend] = _discover_backends()