common.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. """
  2. This module provides common utilities and base classes for TorchDynamo backends.
  3. Key components:
  4. - AotAutograd: Base class for implementing AOT (Ahead-of-Time) autograd backends
  5. - Backend utilities for handling:
  6. - Fake tensor conversion
  7. - Device/dtype detection from inputs
  8. - Memory efficient fusion
  9. - Graph flattening
  10. - Common compiler configurations
  11. The utilities here are used by various backend implementations to handle
  12. common operations and provide consistent behavior across different backends.
  13. AOT autograd functionality is particularly important as it enables ahead-of-time
  14. optimization of both forward and backward passes.
  15. """
  16. import contextlib
  17. import functools
  18. import logging
  19. from collections.abc import Callable, Iterable
  20. from typing import Any
  21. from typing_extensions import ParamSpec, TypeVar
  22. from unittest.mock import patch
  23. import torch
  24. from torch._dynamo import disable
  25. from torch._dynamo.exc import TensorifyScalarRestartAnalysis
  26. from torch._dynamo.utils import counters, defake, flatten_graph_inputs
  27. from torch._functorch.aot_autograd import (
  28. aot_module_simplified,
  29. SerializableAOTDispatchCompiler,
  30. )
  31. from torch.utils._python_dispatch import _disable_current_modes
  32. log = logging.getLogger(__name__)
  33. P = ParamSpec("P")
  34. R = TypeVar("R")
  35. class AotAutograd:
  36. def __init__(self, **kwargs: Any) -> None:
  37. self.__name__ = "compiler_fn"
  38. self.kwargs = kwargs
  39. def __call__(
  40. self, gm: torch.fx.GraphModule, example_inputs: Iterable[Any], **kwargs: Any
  41. ) -> Callable[..., Any]:
  42. if kwargs:
  43. log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs)
  44. if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
  45. return flatten_graph_inputs(
  46. gm,
  47. example_inputs,
  48. self,
  49. )
  50. # Hack to get around circular import problems with aot_eager_decomp_partition
  51. if callable(self.kwargs.get("decompositions")):
  52. self.kwargs["decompositions"] = self.kwargs["decompositions"]()
  53. # NB: dont delete counter increment
  54. counters["aot_autograd"]["total"] += 1
  55. use_fallback = False
  56. if use_fallback:
  57. log.debug("Unable to use AOT Autograd because graph has mutation")
  58. counters["aot_autograd"]["not_ok"] += 1
  59. # pyrefly: ignore [bad-return]
  60. return gm
  61. def wrap_bw_compiler(bw_compiler_fn: Callable[P, R]) -> Callable[..., R]:
  62. def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R:
  63. # Note [Wrapping bw_compiler in disable]
  64. # The two disables here:
  65. # - stop TorchDynamo from trying to compile the bw_compiler function itself
  66. # - stop TorchDynamo from trying to compile our the generated backwards pass bw_compiler produces
  67. return disable(
  68. disable(
  69. bw_compiler_fn, reason="do not trace backward compiler function"
  70. )(*args, **kwargs), # type: ignore[misc]
  71. reason="do not trace generated backwards pass",
  72. )
  73. _wrapped_bw_compiler._is_wrapped_bw_compiler = ( # pyrefly: ignore [missing-attribute]
  74. True
  75. )
  76. return _wrapped_bw_compiler
  77. bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
  78. if isinstance(bw_compiler, SerializableAOTDispatchCompiler):
  79. bw_compiler.compiler_fn = wrap_bw_compiler(bw_compiler.compiler_fn)
  80. elif getattr(bw_compiler, "_is_wrapped_bw_compiler", False):
  81. bw_compiler.compiler_fn = bw_compiler
  82. else:
  83. bw_compiler = wrap_bw_compiler(bw_compiler)
  84. self.kwargs["bw_compiler"] = bw_compiler
  85. self.kwargs["inference_compiler"] = (
  86. self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
  87. )
  88. from functorch.compile import nop
  89. from torch._inductor.debug import enable_aot_logging
  90. # debug asserts slow down compile time noticeably,
  91. # So only default them on when the aot_eager backend is used.
  92. if self.kwargs.get("fw_compiler", None) is nop:
  93. patch_config: contextlib.AbstractContextManager[Any] = patch(
  94. "functorch.compile.config.debug_assert", True
  95. )
  96. else:
  97. patch_config = contextlib.nullcontext()
  98. try:
  99. # NB: NOT cloned!
  100. with enable_aot_logging(), patch_config:
  101. cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  102. counters["aot_autograd"]["ok"] += 1
  103. return disable(cg, reason="do not trace AOT-compiled graph")
  104. except TensorifyScalarRestartAnalysis:
  105. raise
  106. except Exception:
  107. counters["aot_autograd"]["not_ok"] += 1
  108. raise
  109. def aot_autograd(**kwargs: Any) -> AotAutograd:
  110. return AotAutograd(**kwargs)
  111. def mem_efficient_fusion_kwargs(use_decomps: bool) -> dict[str, Any]:
  112. from functorch.compile import (
  113. default_decompositions,
  114. min_cut_rematerialization_partition,
  115. ts_compile,
  116. )
  117. kwargs = {
  118. # these are taken from memory_efficient_fusion()
  119. "fw_compiler": ts_compile,
  120. "bw_compiler": ts_compile,
  121. "partition_fn": min_cut_rematerialization_partition,
  122. }
  123. if use_decomps:
  124. # pyrefly: ignore [bad-typed-dict-key, unsupported-operation]
  125. kwargs["decompositions"] = default_decompositions
  126. return kwargs
  127. def fake_tensor_unsupported(fn: Callable[[Any, list[Any], Any], R]) -> Any:
  128. """
  129. Decorator for backends that need real inputs. We swap out fake
  130. tensors for zero tensors.
  131. """
  132. @functools.wraps(fn)
  133. def wrapper(model: Any, inputs: Any, **kwargs: Any) -> Any:
  134. with _disable_current_modes():
  135. inputs = list(map(defake, inputs))
  136. return fn(model, inputs, **kwargs) # type: ignore[call-arg]
  137. return wrapper
  138. def device_from_inputs(example_inputs: Iterable[Any]) -> torch.device:
  139. for x in example_inputs:
  140. if hasattr(x, "device"):
  141. return x.device
  142. return torch.device("cpu") # Default fallback
  143. def dtype_from_inputs(example_inputs: Iterable[Any]) -> torch.dtype:
  144. for x in example_inputs:
  145. if hasattr(x, "dtype"):
  146. return x.dtype
  147. return torch.float32 # Default fallback