- """
- This module provides common utilities and base classes for TorchDynamo backends.
- Key components:
- - AotAutograd: Base class for implementing AOT (Ahead-of-Time) autograd backends
- - Backend utilities for handling:
- - Fake tensor conversion
- - Device/dtype detection from inputs
- - Memory efficient fusion
- - Graph flattening
- - Common compiler configurations
- The utilities here are used by various backend implementations to handle
- common operations and provide consistent behavior across different backends.
- AOT autograd functionality is particularly important as it enables ahead-of-time
- optimization of both forward and backward passes.
- """
- import contextlib
- import functools
- import logging
- from collections.abc import Callable, Iterable
- from typing import Any
- from typing_extensions import ParamSpec, TypeVar
- from unittest.mock import patch
- import torch
- from torch._dynamo import disable
- from torch._dynamo.exc import TensorifyScalarRestartAnalysis
- from torch._dynamo.utils import counters, defake, flatten_graph_inputs
- from torch._functorch.aot_autograd import (
- aot_module_simplified,
- SerializableAOTDispatchCompiler,
- )
- from torch.utils._python_dispatch import _disable_current_modes
- log = logging.getLogger(__name__)
- P = ParamSpec("P")
- R = TypeVar("R")
- class AotAutograd:
- def __init__(self, **kwargs: Any) -> None:
- self.__name__ = "compiler_fn"
- self.kwargs = kwargs
- def __call__(
- self, gm: torch.fx.GraphModule, example_inputs: Iterable[Any], **kwargs: Any
- ) -> Callable[..., Any]:
- if kwargs:
- log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs)
- if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
- return flatten_graph_inputs(
- gm,
- example_inputs,
- self,
- )
- # Hack to get around circular import problems with aot_eager_decomp_partition
- if callable(self.kwargs.get("decompositions")):
- self.kwargs["decompositions"] = self.kwargs["decompositions"]()
- # NB: don't delete counter increment
- counters["aot_autograd"]["total"] += 1
- use_fallback = False
- if use_fallback:
- log.debug("Unable to use AOT Autograd because graph has mutation")
- counters["aot_autograd"]["not_ok"] += 1
- # pyrefly: ignore [bad-return]
- return gm
- def wrap_bw_compiler(bw_compiler_fn: Callable[P, R]) -> Callable[..., R]:
- def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R:
- # Note [Wrapping bw_compiler in disable]
- # The two disables here:
- # - stop TorchDynamo from trying to compile the bw_compiler function itself
- # - stop TorchDynamo from trying to compile the generated backwards pass that bw_compiler produces
- return disable(
- disable(
- bw_compiler_fn, reason="do not trace backward compiler function"
- )(*args, **kwargs), # type: ignore[misc]
- reason="do not trace generated backwards pass",
- )
- _wrapped_bw_compiler._is_wrapped_bw_compiler = ( # pyrefly: ignore [missing-attribute]
- True
- )
- return _wrapped_bw_compiler
- bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
- if isinstance(bw_compiler, SerializableAOTDispatchCompiler):
- bw_compiler.compiler_fn = wrap_bw_compiler(bw_compiler.compiler_fn)
- elif getattr(bw_compiler, "_is_wrapped_bw_compiler", False):
- bw_compiler.compiler_fn = bw_compiler
- else:
- bw_compiler = wrap_bw_compiler(bw_compiler)
- self.kwargs["bw_compiler"] = bw_compiler
- self.kwargs["inference_compiler"] = (
- self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
- )
- from functorch.compile import nop
- from torch._inductor.debug import enable_aot_logging
- # Debug asserts slow down compile time noticeably,
- # so only turn them on by default when the aot_eager backend is used.
- if self.kwargs.get("fw_compiler", None) is nop:
- patch_config: contextlib.AbstractContextManager[Any] = patch(
- "functorch.compile.config.debug_assert", True
- )
- else:
- patch_config = contextlib.nullcontext()
- try:
- # NB: NOT cloned!
- with enable_aot_logging(), patch_config:
- cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
- counters["aot_autograd"]["ok"] += 1
- return disable(cg, reason="do not trace AOT-compiled graph")
- except TensorifyScalarRestartAnalysis:
- raise
- except Exception:
- counters["aot_autograd"]["not_ok"] += 1
- raise
- def aot_autograd(**kwargs: Any) -> AotAutograd:
- return AotAutograd(**kwargs)
- def mem_efficient_fusion_kwargs(use_decomps: bool) -> dict[str, Any]:
- from functorch.compile import (
- default_decompositions,
- min_cut_rematerialization_partition,
- ts_compile,
- )
- kwargs = {
- # these are taken from memory_efficient_fusion()
- "fw_compiler": ts_compile,
- "bw_compiler": ts_compile,
- "partition_fn": min_cut_rematerialization_partition,
- }
- if use_decomps:
- # pyrefly: ignore [bad-typed-dict-key, unsupported-operation]
- kwargs["decompositions"] = default_decompositions
- return kwargs
- def fake_tensor_unsupported(fn: Callable[[Any, list[Any], Any], R]) -> Any:
- """
- Decorator for backends that need real inputs. Fake tensors in the
- inputs are swapped out for real zero tensors before the backend is called.
- """
- @functools.wraps(fn)
- def wrapper(model: Any, inputs: Any, **kwargs: Any) -> Any:
- with _disable_current_modes():
- inputs = list(map(defake, inputs))
- return fn(model, inputs, **kwargs) # type: ignore[call-arg]
- return wrapper
- def device_from_inputs(example_inputs: Iterable[Any]) -> torch.device:
- for x in example_inputs:
- if hasattr(x, "device"):
- return x.device
- return torch.device("cpu") # Default fallback
- def dtype_from_inputs(example_inputs: Iterable[Any]) -> torch.dtype:
- for x in example_inputs:
- if hasattr(x, "dtype"):
- return x.dtype
- return torch.float32 # Default fallback
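The try/except at the end of `AotAutograd.__call__` keeps three counters consistent: `total` is bumped before compilation, `ok` on success, and `not_ok` on failure, while a restart request is re-raised without being counted as a failure. The following is a minimal pure-Python sketch of that bookkeeping only, not the module's implementation. The names `compile_with_counters`, `RestartAnalysis`, `broken`, and `restart` are hypothetical, and the local `counters` object is a stand-in assumed to behave like `torch._dynamo.utils.counters`.

```python
from collections import Counter, defaultdict
from typing import Any, Callable

# Stand-in for torch._dynamo.utils.counters (an assumption for this sketch).
counters = defaultdict(Counter)


class RestartAnalysis(Exception):
    """Hypothetical stand-in for TensorifyScalarRestartAnalysis."""


def compile_with_counters(compile_fn: Callable[[Any], Any], graph: Any) -> Any:
    """Run compile_fn(graph) while keeping total/ok/not_ok in sync."""
    counters["aot_autograd"]["total"] += 1
    try:
        compiled = compile_fn(graph)
        counters["aot_autograd"]["ok"] += 1
        return compiled
    except RestartAnalysis:
        # A restart request is control flow, not a failure: re-raise uncounted.
        raise
    except Exception:
        # Any other error is recorded as a failed compile, then propagated.
        counters["aot_autograd"]["not_ok"] += 1
        raise


def broken(graph: Any) -> Any:
    raise ValueError("compile failed")


def restart(graph: Any) -> Any:
    raise RestartAnalysis()


# One successful compile, one failure, one restart request.
compile_with_counters(str.upper, "graph")
for fn in (broken, restart):
    try:
        compile_with_counters(fn, "graph")
    except (ValueError, RestartAnalysis):
        pass

summary = dict(counters["aot_autograd"])
```

Because the restart exception is handled before the generic `except Exception`, `total` can exceed `ok + not_ok`; the difference is the number of compiles that were abandoned for a restart rather than failed.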