registry.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. """
  2. This module implements TorchDynamo's backend registry system for managing compiler backends.
  3. The registry provides a centralized way to register, discover and manage different compiler
  4. backends that can be used with torch.compile(). It handles:
  5. - Backend registration and discovery through decorators and entry points
  6. - Lazy loading of backend implementations
  7. - Lookup and validation of backend names
  8. - Categorization of backends using tags (debug, experimental, etc.)
  9. Key components:
  10. - CompilerFn: Type for backend compiler functions that transform FX graphs
  11. - _BACKENDS: Registry of every known backend name, mapping each name to its entry point (or None when no entry point is recorded for it)
  12. - _COMPILER_FNS: Registry mapping backend names to loaded compiler functions
  13. Example usage:
  14. @register_backend
  15. def my_compiler(fx_graph, example_inputs):
  16. # Transform FX graph into optimized implementation
  17. return compiled_fn
  18. # Use registered backend
  19. torch.compile(model, backend="my_compiler")
  20. The registry also supports discovering backends through setuptools entry points
  21. in the "torch_dynamo_backends" group. Example:
  22. ```
  23. setup.py
  24. ---
  25. from setuptools import setup
  26. setup(
  27. name='my_torch_backend',
  28. version='0.1',
  29. packages=['my_torch_backend'],
  30. entry_points={
  31. 'torch_dynamo_backends': [
  32. # name = path to entry point of backend implementation
  33. 'my_compiler = my_torch_backend.compiler:my_compiler_function',
  34. ],
  35. },
  36. )
  37. ```
  38. ```
  39. my_torch_backend/compiler.py
  40. ---
  41. def my_compiler_function(fx_graph, example_inputs):
  42. # Transform FX graph into optimized implementation
  43. return compiled_fn
  44. ```
  45. Using `my_compiler` backend:
  46. ```
  47. import torch
  48. model = ... # Your PyTorch model
  49. optimized_model = torch.compile(model, backend="my_compiler")
  50. ```
  51. """
  52. import functools
  53. import logging
  54. from collections.abc import Callable, Sequence
  55. from importlib.metadata import EntryPoint
  56. from typing import Any, Optional, Protocol, Union
  57. import torch
  58. from torch import fx
  59. log = logging.getLogger(__name__)
  60. class CompiledFn(Protocol):
  61. def __call__(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: ...
  62. CompilerFn = Callable[[fx.GraphModule, list[torch.Tensor]], CompiledFn]
  63. _BACKENDS: dict[str, Optional[EntryPoint]] = {}
  64. _COMPILER_FNS: dict[str, CompilerFn] = {}
  65. def register_backend(
  66. compiler_fn: Optional[CompilerFn] = None,
  67. name: Optional[str] = None,
  68. tags: Sequence[str] = (),
  69. ) -> Callable[..., Any]:
  70. """
  71. Decorator to add a given compiler to the registry to allow calling
  72. `torch.compile` with string shorthand. Note: for projects not
  73. imported by default, it might be easier to pass a function directly
  74. as a backend and not use a string.
  75. Args:
  76. compiler_fn: Callable taking a FX graph and fake tensor inputs
  77. name: Optional name, defaults to `compiler_fn.__name__`
  78. tags: Optional set of string tags to categorize backend with
  79. """
  80. if compiler_fn is None:
  81. # @register_backend(name="") syntax
  82. return functools.partial(register_backend, name=name, tags=tags) # type: ignore[return-value]
  83. assert callable(compiler_fn)
  84. name = name or compiler_fn.__name__
  85. assert name not in _COMPILER_FNS, f"duplicate name: {name}"
  86. if compiler_fn not in _BACKENDS:
  87. _BACKENDS[name] = None
  88. _COMPILER_FNS[name] = compiler_fn
  89. compiler_fn._tags = tuple(tags) # type: ignore[attr-defined]
  90. return compiler_fn
  91. register_debug_backend = functools.partial(register_backend, tags=("debug",))
  92. register_experimental_backend = functools.partial(
  93. register_backend, tags=("experimental",)
  94. )
  95. def lookup_backend(compiler_fn: Union[str, CompilerFn]) -> CompilerFn:
  96. """Expand backend strings to functions"""
  97. if isinstance(compiler_fn, str):
  98. if compiler_fn not in _BACKENDS:
  99. _lazy_import()
  100. if compiler_fn not in _BACKENDS:
  101. from ..exc import InvalidBackend
  102. raise InvalidBackend(name=compiler_fn)
  103. if compiler_fn not in _COMPILER_FNS:
  104. entry_point = _BACKENDS[compiler_fn]
  105. if entry_point is not None:
  106. register_backend(compiler_fn=entry_point.load(), name=compiler_fn)
  107. compiler_fn = _COMPILER_FNS[compiler_fn]
  108. return compiler_fn
  109. # NOTE: can't type this due to public api mismatch; follow up with dev team
  110. def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: # type: ignore[no-untyped-def]
  111. """
  112. Return valid strings that can be passed to:
  113. torch.compile(..., backend="name")
  114. """
  115. _lazy_import()
  116. exclude_tags_set = set(exclude_tags or ())
  117. backends = [
  118. name
  119. for name in _BACKENDS
  120. if name not in _COMPILER_FNS
  121. or not exclude_tags_set.intersection(_COMPILER_FNS[name]._tags) # type: ignore[attr-defined]
  122. ]
  123. return sorted(backends)
@functools.cache
def _lazy_import() -> None:
    """
    Populate the backend registry on first use.

    Imports the in-tree ``backends`` package (via ``import_submodule``) and the
    dynamo minifier module, then discovers entry-point backends. Wrapped in
    ``functools.cache``, so after one successful call, further calls do
    nothing.
    """
    from .. import backends
    from ..utils import import_submodule
    # Presumably imports every submodule of ``backends`` so that their
    # registration decorators run; confirm against import_submodule.
    import_submodule(backends)
    # Imported for its side effect (presumably registering the minifier
    # backend); the assert only keeps the imported name "used".
    from ..repro.after_dynamo import dynamo_minifier_backend
    assert dynamo_minifier_backend is not None
    # Runs last: entry-point discovery overwrites any _BACKENDS entry with the
    # same name.
    _discover_entrypoint_backends()
  132. @functools.cache
  133. def _discover_entrypoint_backends() -> None:
  134. # importing here so it will pick up the mocked version in test_backends.py
  135. from importlib.metadata import entry_points
  136. group_name = "torch_dynamo_backends"
  137. eps = entry_points(group=group_name)
  138. # pyrefly: ignore [bad-index]
  139. eps_dict = {name: eps[name] for name in eps.names}
  140. for backend_name in eps_dict:
  141. _BACKENDS[backend_name] = eps_dict[backend_name]
  142. def _is_registered_backend(compiler_fn: CompilerFn) -> bool:
  143. """
  144. Check if the given compiler function is a registered backend.
  145. Custom backends (user-provided callables not in the registry) return False.
  146. """
  147. # Ensure backends are loaded
  148. _lazy_import()
  149. # Check if it's directly a registered backend function
  150. if compiler_fn in _COMPILER_FNS.values():
  151. return True
  152. # Check for _TorchCompileInductorWrapper or _TorchCompileWrapper
  153. # These have a compiler_name attribute that identifies the backend
  154. if hasattr(compiler_fn, "compiler_name"):
  155. compiler_name = compiler_fn.compiler_name
  156. if compiler_name in _BACKENDS or compiler_name in _COMPILER_FNS:
  157. return True
  158. # Check if the wrapper has a compiler_fn attribute (e.g., _TorchCompileWrapper)
  159. if hasattr(compiler_fn, "compiler_fn"):
  160. return compiler_fn.compiler_fn in _COMPILER_FNS.values()
  161. return False