| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387 |
- # mypy: allow-untyped-defs
- import builtins
- import collections
- import contextlib
- import copy
- import functools
- import inspect
- import logging
- import math
- import os
- import warnings
- from collections.abc import Callable
- from itertools import chain
- from types import CodeType, FunctionType, ModuleType
- from typing import Any, get_args, NamedTuple, Optional, TypeAlias, Union
- import torch
- import torch.utils._pytree as pytree
- from torch._C import ScriptObject # type: ignore[attr-defined]
- from torch._library.fake_class_registry import FakeScriptObject
- from torch._library.opaque_object import is_opaque_reference_type, is_opaque_type
- from ._compatibility import compatibility
- from ._lazy_graph_module import _make_graph_module
- from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
- from .graph_module import GraphModule
- from .node import Argument, base_types, map_aggregate
- from .proxy import ParameterProxy, Proxy, Scope, ScopeContextManager, TracerBase
- log = logging.getLogger(__name__)
- HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
- # These need to run in global scope to handle nested calls correctly
- _orig_module_call: Callable = torch.nn.Module.__call__
- _orig_module_getattr: Callable = torch.nn.Module.__getattr__
- _proxyable_classes: dict[type, None] = {}
- _is_fx_tracing_flag = False
- _ConstantAttributeType: TypeAlias = Union[
- torch.Tensor, torch.ScriptObject, FakeScriptObject, pytree.TreeSpec
- ]
- _constant_attribute_types = get_args(_ConstantAttributeType)
# We only want to print this once to avoid flooding logs
@functools.lru_cache
def is_fx_tracing_warning():
    """
    Log a warning that ``is_fx_tracing`` is ambiguous.

    The function takes no arguments and is memoized, so only the first call
    reaches the logger.
    """
    message = (
        "is_fx_tracing will return true for both fx.symbolic_trace and "
        "torch.export. Please use "
        "is_fx_tracing_symbolic_tracing() for specifically fx.symbolic_trace "
        "or torch.compiler.is_compiling() for specifically torch.export/compile."
    )
    log.warning(message)
def is_fx_tracing():
    """
    Return the module-level tracing flag.

    Calls ``is_fx_tracing_warning`` first, so the ambiguity warning is logged
    the first time this is used.
    """
    is_fx_tracing_warning()
    currently_tracing = _is_fx_tracing_flag
    return currently_tracing
def is_fx_symbolic_tracing():
    """
    Return True only when the tracing flag is set and
    ``torch.compiler.is_compiling()`` reports False.
    """
    if not _is_fx_tracing_flag:
        return _is_fx_tracing_flag
    return not torch.compiler.is_compiling()
@compatibility(is_backward_compatible=True)
class ProxyableClassMeta(type):
    """
    ProxyableClassMeta allows you to make construction of a given Python class
    symbolically traceable. For example::

        import torch
        import torch.fx


        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)


        def use_tensor_pair_ctor(x: TensorPair, y: torch.Tensor):
            s = x.add(TensorPair(y, y))
            return s.mul(x)


        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
        y = torch.randn(5, 3)
        ref_out = use_tensor_pair_ctor(x, y)

        traced = torch.fx.symbolic_trace(use_tensor_pair_ctor)
        print(traced.code)
        '''
        def forward(self, x : __main___TensorPair, y : torch.Tensor):
            tensor_pair = __main___TensorPair(y, y);  y = None
            add = x.add(tensor_pair);  tensor_pair = None
            mul = add.mul(x);  add = x = None
            return mul
        '''

    From this example, we can see that construction of a class (``TensorPair``)
    defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic
    tracing.
    """

    def __init__(cls, name, bases, attrs):
        # Register the new class so ``Tracer.create_arg`` can recognise its
        # instances.
        _proxyable_classes.setdefault(cls)
        super().__init__(name, bases, attrs)

    def __call__(cls, *args, **kwargs):
        # The bare instance is allocated up front; ``__init__`` only runs when
        # the construction is not recorded as a graph node.
        instance = cls.__new__(cls)  # type: ignore[call-overload]

        if is_fx_tracing():
            proxies_seen = []

            def remember_proxy(value):
                if isinstance(value, Proxy):
                    proxies_seen.append(value)

            map_aggregate(args, remember_proxy)
            map_aggregate(kwargs, remember_proxy)

            if proxies_seen:
                # At least one argument is symbolic: emit a ``call_function``
                # node targeting the class, using the first proxy's tracer.
                owning_tracer = proxies_seen[0].tracer
                return owning_tracer.create_proxy("call_function", cls, args, kwargs)

        # Not tracing, or no Proxy among the arguments: construct eagerly.
        cls.__init__(instance, *args, **kwargs)  # type: ignore[misc]
        return instance
def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
    """
    Return a copy of ``fn`` that takes ``nargs`` ordinary positional
    parameters and no ``*args`` / ``**kwargs``.

    We need to insert placeholder nodes for ``*args`` and ``**kwargs``. We
    can't call this function normally, otherwise it would try to unpack them.
    Instead, let's make Python think that args and kwargs are normal
    variables.

    Args:
        fn (FunctionType): Function whose code object is used as the template.
        nargs (int): Positional-argument count for the patched code object.

    Returns:
        A new function built from the patched code object together with
        ``fn``'s globals, name, defaults and closure.
    """
    co = fn.__code__
    # ``replace`` copies every field that is not named here, so there is no
    # per-Python-version positional layout of the ``CodeType`` constructor to
    # keep in sync.
    new_code = co.replace(
        co_argcount=nargs,
        co_posonlyargcount=0,
        co_kwonlyargcount=0,
        # Clear the var-positional / var-keyword flags.
        co_flags=co.co_flags & ~HAS_VARSTUFF,
    )
    return FunctionType(
        new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__
    )
@compatibility(is_backward_compatible=False)
class PHBase:
    """
    Object representing an input placeholder to `concrete_args`.
    """

    def __repr__(self):
        # Every instance prints the same short marker.
        return "PH"


# Module-level placeholder instance; ``create_args_for_root`` compares
# placeholders against it.
PH = PHBase()
@compatibility(is_backward_compatible=False)
class PHWithMeta(PHBase):
    """
    Object representing an input placeholder to `concrete_args` that also
    carries an optional user-chosen key.
    """

    def __init__(self, ph_key: Optional[str] = None):
        super().__init__()

        # Provide a key for user to identify placeholder node during analysis
        self.ph_key = ph_key
- def _transfer_attrs(fr, to):
- for attr_name in dir(fr):
- attr_val = getattr(fr, attr_name)
- if (
- not callable(attr_val)
- and not attr_name.startswith("__")
- and not hasattr(to, attr_name)
- ):
- setattr(to, attr_name, attr_val)
- @compatibility(is_backward_compatible=True)
- class Tracer(TracerBase):
- # Reference: https://github.com/pytorch/pytorch/issues/54354
- # The first line of this docstring overrides the one Sphinx generates for the
- # documentation. We need it so that Sphinx doesn't leak `math`s path from the
- # build environment (e.g. `<module 'math' from '/leaked/path').
- """Tracer(autowrap_modules=(math,), autowrap_functions=())
- ``Tracer`` is the class that implements the symbolic tracing functionality
- of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent
- to ``Tracer().trace(m)``.
- Tracer can be subclassed to override various behaviors of the tracing
- process. The different behaviors that can be overridden are described
- in the docstrings of the methods on this class.
- """
- # Not checking BC on this API because the default value for `autowrap_modules`
- # includes the local filepath to the `math` module, which would jitter
- # across machines.
@compatibility(is_backward_compatible=True)
def __init__(
    self,
    autowrap_modules: tuple[ModuleType] = (math,),
    autowrap_functions: tuple[Callable, ...] = (),
    param_shapes_constant: bool = False,
) -> None:
    # This method's signature is overridden by the first line of this class'
    # docstring. If this method's signature is modified, the signature that
    # overrides it also should be modified accordingly.

    """
    Construct a Tracer object.

    Args:

        autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
            Python modules whose functions should be wrapped automatically
            without needing to use fx.wrap(). Backward-compatibility for
            this parameter is guaranteed.

        autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
            Python functions that should be wrapped automatically without
            needing to use fx.wrap(). Backward compatibility for this
            parameter is guaranteed.

        param_shapes_constant (bool): When this flag is set, calls to shape,
            size and a few other shape like attributes of a module's parameter
            will be evaluated directly, rather than returning a new Proxy value
            for an attribute access. Backward compatibility for this parameter
            is guaranteed.
    """

    super().__init__()

    # Functions we will eagerly wrap when we see them while tracing
    # this captures both `math.sqrt()` and `from math import sqrt` automatically
    wrapped_ids: set[int] = set()
    for autowrap_module in autowrap_modules:
        for name, value in autowrap_module.__dict__.items():
            # Only public, callable module members are auto-wrapped.
            if not name.startswith("_") and callable(value):
                wrapped_ids.add(id(value))
    wrapped_ids.update(id(f) for f in autowrap_functions)
    self._autowrap_function_ids: set[int] = wrapped_ids

    # Python modules to apply autowrap to at the start, in addition to
    # modules we see while tracing
    self._autowrap_search: list[ModuleType] = list(autowrap_modules)
    self.param_shapes_constant = param_shapes_constant

    # Module -> qualified name lookup; populated by ``trace`` for module roots.
    self.submodule_paths: Optional[dict[torch.nn.Module, str]] = None
    self.root_module_name: str = ""
    # Maps the containing module's name to the operator name
    self.scope = Scope("", None)
    # Records the module call stack
    self.module_stack = collections.OrderedDict()
    # Number of times each module path has been entered by ``call_module``.
    self.num_calls: dict[str, int] = {}
    # Mapping of node name to module scope
    self.node_name_to_scope: dict[str, tuple[str, type]] = {}
# Next suffix to try for each prefix. Declared outside ``__init__``, so the
# mapping is shared rather than created per instance.
_qualname_counter: dict[str, int] = collections.defaultdict(int)


@compatibility(is_backward_compatible=True)
def get_fresh_qualname(self, prefix: str) -> str:
    """
    Gets a fresh name for a prefix and returns it. This function ensures
    that it will not clash with an existing attribute on the graph.

    Args:
        prefix (str): Stem of the attribute name; a numeric suffix is appended.

    Returns:
        ``prefix`` followed by the first suffix for which ``self.root`` has no
        attribute of that name.
    """
    # The idea here is that if the module doesn't have this prefix at all we
    # should reset the counter to start from the beginning
    # It's a ... little bit hacky (doesn't cover all cases) but the precise
    # naming of the prefixes isn't a correctness issue, just a niceness
    # issue
    first_candidate = f"{prefix}0"
    if not hasattr(self.root, first_candidate):
        self._qualname_counter[prefix] = 0
        return first_candidate

    # Resume from the remembered suffix and walk forward to the first free name.
    suffix = self._qualname_counter[prefix]
    candidate = f"{prefix}{suffix}"
    while hasattr(self.root, candidate):
        suffix += 1
        candidate = f"{prefix}{suffix}"

    # Remember the suffix after the one just handed out.
    self._qualname_counter[prefix] = suffix + 1
    return candidate
@compatibility(is_backward_compatible=True)
def create_arg(self, a: Any) -> "Argument":
    """
    A method to specify the behavior of tracing when preparing values to
    be used as arguments to nodes in the ``Graph``.

    By default, the behavior includes:

    #. Iterate through collection types (e.g. tuple, list, dict) and recursively
       call ``create_arg`` on the elements.
    #. Given a Proxy object, return a reference to the underlying IR ``Node``
    #. Given a non-Proxy Tensor object, emit IR for various cases:

        * For a Parameter, emit a ``get_attr`` node referring to that Parameter
        * For a non-Parameter Tensor, store the Tensor away in a special
          attribute referring to that attribute.

    This method can be overridden to support more types.

    Args:

        a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.

    Returns:

        The value ``a`` converted into the appropriate ``Argument``

    Raises:

        NameError: ``a`` is a Parameter that is not registered on ``self.root``.
        RuntimeError: ``a`` matched the constant-attribute check but none of
            the known base names applies to its type.
    """
    # The base tracer is used to construct Graphs when there is no associated
    # module hierarchy, so it can never create parameter references.
    # The default tracer adds the ability to refer to parameters when
    # tracing modules.
    if isinstance(a, torch.nn.Parameter):
        for name, param in self.root.named_parameters():
            if param is a:
                return self.create_node("get_attr", name, (), {})
        raise NameError("parameter is not a member of this module")

    # Buffers and submodules registered on the root are referenced by their
    # qualified name. A miss falls through to the generic handling below.
    if isinstance(a, torch.Tensor):
        registered = self.root.named_buffers()
    elif isinstance(a, torch.nn.Module):
        registered = self.root.named_modules()
    else:
        registered = ()
    for name, member in registered:
        if member is a:
            return self.create_node("get_attr", name, (), {})

    # For NamedTuple instances that appear literally as args, we emit
    # a node to construct the NamedTuple and use that Node as the argument.
    if isinstance(a, tuple) and hasattr(a, "_fields"):
        field_args = tuple(self.create_arg(elem) for elem in a)
        return self.create_node("call_function", a.__class__, field_args, {})

    # Tensors do not have a reliable string repr() from which they can be
    # constructed (and we probably don't want to rely on that, either), so
    # for any constant Tensor values we encounter, first search for if they
    # are an attribute of some module in the module hierarchy. If so, emit
    # a get_attr to retrieve that tensor. Otherwise, we'll store away the
    # tensor value into a special attribute on the Module s.t. we can
    # retrieve it with a get_attr.
    if isinstance(a, _constant_attribute_types) or is_opaque_reference_type(
        type(a)
    ):
        qualname: Optional[str] = self.tensor_attrs.get(a)

        # Tensor was not found in the Module hierarchy, stow it away in a
        # special attribute and set the qualname to refer to that
        if not qualname:
            for kinds, stem in (
                (torch.Tensor, "_tensor_constant"),
                ((FakeScriptObject, ScriptObject), "_torchbind_obj"),
                (pytree.TreeSpec, "_tree_spec_constant"),
            ):
                if isinstance(a, kinds):
                    base_name = stem
                    break
            else:
                if is_opaque_type(type(a)):
                    base_name = "_opaque_obj"
                else:
                    raise RuntimeError(
                        f"cannot create constant arg for {a} of type {type(a)}."
                    )
            qualname = self.get_fresh_qualname(base_name)
            if not isinstance(qualname, str):
                raise AssertionError(
                    f"Expected qualname to be str, got {type(qualname)}"
                )
            self.tensor_attrs[a] = qualname
            setattr(self.root, qualname, a)

        return self.create_node("get_attr", qualname, (), {})

    if type(a) in _proxyable_classes:
        # This is an instance of a proxyable class for which we did not
        # witness its construction. Intern this as a constant attribute

        # TODO: binary search
        qualname = self.get_fresh_qualname(f"_{a.__class__.__name__}_constant_")
        if not isinstance(qualname, str):
            raise AssertionError(
                f"Expected qualname to be str, got {type(qualname)}"
            )
        setattr(self.root, qualname, a)

        return self.create_node("get_attr", qualname, (), {})

    return super().create_arg(a)
@compatibility(is_backward_compatible=True)
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
    """
    A method to specify whether a given ``nn.Module`` is a "leaf" module.

    Leaf modules are the atomic units that appear in
    the IR, referenced by ``call_module`` calls. By default,
    Modules in the PyTorch standard library namespace (torch.nn)
    are leaf modules. All other modules are traced through and
    their constituent ops are recorded, unless specified otherwise
    via this parameter.

    Args:

        m (Module): The module being queried about
        module_qualified_name (str): The path to root of this module. For example,
            if you have a module hierarchy where submodule ``foo`` contains
            submodule ``bar``, which contains submodule ``baz``, that module will
            appear with the qualified name ``foo.bar.baz`` here.

    Returns:

        True when ``m``'s defining module name starts with ``torch.nn`` or
        ``torch.ao.nn`` and ``m`` is not an ``nn.Sequential``.
    """
    defined_in_torch_nn = m.__module__.startswith(("torch.nn", "torch.ao.nn"))
    # ``Sequential`` is always traced through.
    return defined_in_torch_nn and not isinstance(m, torch.nn.Sequential)
@compatibility(is_backward_compatible=True)
def path_of_module(self, mod: torch.nn.Module) -> str:
    """
    Helper method to find the qualified name of ``mod`` in the Module hierarchy
    of ``root``. For example, if ``root`` has a submodule named ``foo``, which has
    a submodule named ``bar``, passing ``bar`` into this function will return
    the string "foo.bar".

    Args:

        mod (Module): The ``Module`` to retrieve the qualified name for.

    Raises:

        NameError: ``mod`` is not found among the root's modules.
    """
    known_paths = self.submodule_paths

    # Prefer the O(1) algorithm
    if known_paths:
        path = known_paths.get(mod)
        if path is None:
            raise NameError("module is not installed as a submodule")
        if not isinstance(path, str):
            raise AssertionError(f"Expected path to be str, got {type(path)}")
        return path

    # Fallback in the case that we didn't store the submodule paths: scan
    # every module under the root for an identity match.
    for name, candidate in self.root.named_modules():
        if candidate is mod:
            return name
    raise NameError("module is not installed as a submodule")
@compatibility(is_backward_compatible=True)
def call_module(
    self,
    m: torch.nn.Module,
    forward: Callable[..., Any],
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Any:
    """
    Method that specifies the behavior of this ``Tracer`` when it encounters
    a call to an ``nn.Module`` instance.

    By default, the behavior is to check if the called module is a leaf module
    via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
    ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
    the operations in its ``forward`` function.

    This method can be overridden to--for example--create nested traced
    GraphModules, or any other behavior you would want while tracing across
    ``Module`` boundaries.

    Args:

        m (Module): The module for which a call is being emitted
        forward (Callable): The forward() method of the ``Module`` to be invoked
        args (Tuple): args of the module callsite
        kwargs (Dict): kwargs of the module callsite

    Return:

        The return value from the Module call. In the case that a ``call_module``
        node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
        value was returned from the ``Module`` invocation.

    Raises:

        AssertionError: the entry popped from ``module_stack`` after a
            successful call is not the one this call pushed.
    """
    module_qualified_name = self.path_of_module(m)
    with ScopeContextManager(
        self.scope, Scope(module_qualified_name, type(m))
    ) as _scope:
        # module_stack is an ordered dict so writing then deleting the
        # entry is equivalent to push/pop on a list
        num_calls = self.num_calls.get(module_qualified_name, 0)
        # Repeat calls of the same module get an ``@<n>`` suffix so that
        # their stack keys stay distinct.
        module_key = (
            f"{_scope.module_path}@{num_calls}"
            if num_calls > 0
            else _scope.module_path
        )
        self.module_stack[module_key] = (module_qualified_name, _scope.module_type)
        self.num_calls[module_qualified_name] = num_calls + 1
        try:
            if not self.is_leaf_module(m, module_qualified_name):
                ret_val = forward(*args, **kwargs)
            else:
                ret_val = self.create_proxy(
                    "call_module", module_qualified_name, args, kwargs
                )
        except BaseException:
            # Unwind our own frame before propagating. Otherwise a caller
            # that handles the error keeps tracing with a stale entry on the
            # stack, and the consistency check below fails for the enclosing
            # module.
            self.module_stack.pop(module_key, None)
            raise
        key, _ = self.module_stack.popitem(last=True)
        if key != module_key:
            raise AssertionError(f"Unexpected key {key}, expected {module_key}")

    return ret_val
@compatibility(is_backward_compatible=False)
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]):
    """
    Method that specifies the behavior of this ``Tracer`` when we call getattr
    on a call to an ``nn.Module`` instance.

    By default, the behavior is to return a proxy value for the attribute. It
    also stores the proxy value in the ``parameter_proxy_cache``, so that future
    calls will reuse the proxy rather than creating a new one.

    This method can be overridden to --for example-- not return proxies when
    querying parameters.

    Args:

        attr (str): The name of the attribute being queried
        attr_val (Any): The value of the attribute
        parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies

    Return:

        The return value from the getattr call.
    """

    def proxy_from(named_tensors):
        # Find ``attr_val`` by identity among ``(name, tensor)`` pairs and
        # return its proxy, creating and caching it on first use.
        for name, tensor in named_tensors:
            if tensor is not attr_val:
                continue
            if name not in parameter_proxy_cache:
                extra_kwargs = {}
                # Only pass ``proxy_factory_fn`` when this tracer's
                # ``create_proxy`` declares that parameter.
                accepts_factory = (
                    "proxy_factory_fn"
                    in inspect.signature(self.create_proxy).parameters
                )
                if accepts_factory:
                    if self.param_shapes_constant:
                        extra_kwargs["proxy_factory_fn"] = (
                            # pyrefly: ignore [unsupported-operation]
                            lambda node: ParameterProxy(self, node, name, attr_val)
                        )
                    else:
                        extra_kwargs["proxy_factory_fn"] = None
                parameter_proxy_cache[name] = self.create_proxy(
                    "get_attr", name, (), {}, **extra_kwargs
                )  # type: ignore[arg-type]
            return parameter_proxy_cache[name]
        return None

    if isinstance(attr_val, torch.nn.Parameter):
        parameter_proxy = proxy_from(self.root.named_parameters())
        if parameter_proxy is not None:
            return parameter_proxy

    if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
        buffer_proxy = proxy_from(self.root.named_buffers())
        if buffer_proxy is not None:
            return buffer_proxy

    # Anything that is not a registered parameter/buffer is returned as-is.
    return attr_val
# This method will be refactored
@compatibility(is_backward_compatible=False)
def create_args_for_root(self, root_fn, is_module, concrete_args=None):
    """
    Create ``placeholder`` nodes corresponding to the signature of the ``root``
    Module. This method introspects root's signature and emits those
    nodes accordingly, also supporting ``*args`` and ``**kwargs``.

    Returns:
        A ``(callable, args)`` pair: the function to invoke (``root_fn``, a
        patched copy of it, or a pytree-flattening wrapper) and the arguments
        to invoke it with.
    """
    # In some cases, a function or method has been decorated with a wrapper
    # defined via ``functools.wraps``. In this case, the outer code object
    # will likely not contain the actual parameters we care about, so unwrap
    # the function to get to the innermost callable.
    fn_for_analysis = inspect.unwrap(root_fn)
    co = fn_for_analysis.__code__
    total_args = co.co_argcount + co.co_kwonlyargcount
    orig_args = list(co.co_varnames)
    names_iter = iter(co.co_varnames)
    args: list[Any] = []
    named_arg_count = total_args
    if is_module:
        if total_args == 0:
            raise RuntimeError(
                "``self`` argument cannot be part of *args expansion!"
            )
        named_arg_count = total_args - 1
        next(names_iter)  # skip self
        args.append(self.root)

    sig = inspect.signature(fn_for_analysis)
    has_var_params = co.co_flags & HAS_VARSTUFF

    # This covers the very specific case where we are passing in flat
    # concrete_args as a tuple, but our traced fn takes (*args, **kwargs).
    # In this case, just take the concrete_args and pass them through.
    if (
        isinstance(concrete_args, tuple)
        and len(concrete_args) > 0
        and has_var_params
        and total_args == 1
    ):
        for position, concrete_arg in enumerate(concrete_args):
            placeholder = self.create_proxy(
                "placeholder", f"input_{position}", (), {}
            )
            # Transfer attrs in the case where you're using a placeholder other
            # than the singleton PH (PH has no attributes to transfer).
            # Proxies were created out of the placeholders.
            # Transfer any metadata (put on the placeholders in the form of
            # attributes set by the user) from the placeholder to the
            # underlying nodes (the proxy is unwrapped by the user, but
            # the metadata should hold).
            if isinstance(concrete_arg, PHBase) and concrete_arg != PH:
                _transfer_attrs(fr=concrete_arg, to=placeholder.node)
            args.append(placeholder)
        return root_fn, args

    # Consume the remaining named parameters; ``names_iter`` is left pointing
    # at the ``*args`` / ``**kwargs`` names, if any.
    arg_names = [next(names_iter) for _ in range(named_arg_count)]
    if isinstance(concrete_args, tuple):
        if len(arg_names) != len(concrete_args):
            raise RuntimeError(
                f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments"
            )
        concrete_args = dict(zip(arg_names, concrete_args))

    def proxy_placeholder(name):
        return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis)

    for arg_name in arg_names:
        args.append(proxy_placeholder(arg_name))

    if co.co_kwonlyargcount > 0 or has_var_params:
        # TODO: type annotations for *args and **kwargs
        if co.co_flags & inspect.CO_VARARGS:
            args.append(proxy_placeholder("*" + next(names_iter)))
        if co.co_flags & inspect.CO_VARKEYWORDS:
            args.append(proxy_placeholder("**" + next(names_iter)))
        # Make every collected argument an ordinary positional parameter.
        root_fn = _patch_function(root_fn, len(args))

    flat_args, in_spec = pytree.tree_flatten(tuple(args))
    if all(child.is_leaf() for child in in_spec.children()):
        return root_fn, args

    # In the case that we have pytree-flattened inputs in
    # `concrete_args`, generate a flattening wrapper around the
    # original root function and return that.
    self.graph._codegen = _PyTreeCodeGen(  # type: ignore[has-type]
        _PyTreeInfo(orig_args[:total_args], in_spec, None)
    )

    def flatten_fn(*args):
        tree_args = pytree.tree_unflatten(list(args), in_spec)
        tree_out = root_fn(*tree_args)
        out_args, out_spec = pytree.tree_flatten(tree_out)
        codegen = self.graph._codegen  # type: ignore[has-type]
        if not isinstance(codegen, _PyTreeCodeGen):
            raise AssertionError(
                f"Expected _codegen to be _PyTreeCodeGen, got "
                f"{type(self.graph._codegen)}"
            )
        # Record the output structure now that the function has been run.
        codegen.pytree_info = codegen.pytree_info._replace(out_spec=out_spec)
        return out_args

    return flatten_fn, flat_args
@compatibility(is_backward_compatible=True)
def trace(
    self,
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[dict[str, Any]] = None,
) -> Graph:
    """
    Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
    can either be an ``nn.Module`` instance or a Python callable.

    Note that after this call, ``self.root`` may be different from the ``root`` passed
    in here. For example, when a free function is passed to ``trace()``, we will
    create an ``nn.Module`` instance to use as the root and add embedded constants
    to.


    Args:

        root (Union[Module, Callable]): Either a ``Module`` or a function to be
            traced through. Backwards-compatibility for this parameter is
            guaranteed.
        concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
            not be treated as Proxies. This parameter is experimental and
            its backwards-compatibility is *NOT* guaranteed.

    Returns:

        A ``Graph`` representing the semantics of the passed-in ``root``.

    Raises:

        AssertionError: the root module's type lacks ``traced_func_name``, or
            the function to trace is not a plain ``FunctionType``.
        RuntimeError: re-raised from tracing; when the message mentions
            "data-dependent" the exception gains a ``partial_fx_graph``
            attribute holding the source of the graph built so far.
    """
    global _is_fx_tracing_flag
    old_is_fx_tracing_flag = _is_fx_tracing_flag
    _is_fx_tracing_flag = True
    try:
        if isinstance(root, torch.nn.Module):
            # do real recompilation for _LazyGraphModule before retracing since the trace
            # method can not trace the _lazy_forward method. Got error:
            #   https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
            # without this.
            from torch.fx._lazy_graph_module import _LazyGraphModule

            _LazyGraphModule.force_recompile(root)

            self.root = root

            if not hasattr(type(root), self.traced_func_name):
                raise AssertionError(
                    f"traced_func_name={self.traced_func_name} doesn't exist in "
                    f"{type(root).__name__}"
                )

            fn = getattr(type(root), self.traced_func_name)
            self.root_module_name = root._get_name()
            self.submodule_paths = {mod: name for name, mod in root.named_modules()}
        else:
            self.root = torch.nn.Module()
            fn = root

        tracer_cls: Optional[type[Tracer]] = getattr(self, "__class__", None)
        self.graph = Graph(tracer_cls=tracer_cls)
        if hasattr(fn, "__code__"):
            code = fn.__code__
            self.graph._co_fields = {
                "co_name": code.co_name,
                "co_filename": code.co_filename,
                "co_firstlineno": code.co_firstlineno,
            }

        # When we encounter a Tensor value that's not a parameter, we look if it
        # is some other attribute on the model. Construct a dict mapping Tensor
        # values to the qualified name here for efficiency. This is used downstream
        # in create_arg
        self.tensor_attrs: dict[
            _ConstantAttributeType,
            str,
        ] = {}

        def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: list[str]):
            for k, v in m.__dict__.items():
                if isinstance(v, _constant_attribute_types):
                    self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
            for k, v in m.named_children():
                collect_tensor_attrs(v, prefix_atoms + [k])

        collect_tensor_attrs(self.root, [])

        if not isinstance(fn, FunctionType):
            raise AssertionError(f"Expected FunctionType, got {type(fn)}")

        fn_globals = fn.__globals__  # run before it gets patched
        fn, args = self.create_args_for_root(
            fn, isinstance(root, torch.nn.Module), concrete_args
        )

        parameter_proxy_cache: dict[
            str, Proxy
        ] = {}  # Reduce number of get_attr calls

        # Method dispatch on parameters is not recorded unless it's directly used.
        # Thus, we need to insert a proxy when __getattr__ requests a parameter.
        @functools.wraps(_orig_module_getattr)
        def module_getattr_wrapper(mod, attr):
            attr_val = _orig_module_getattr(mod, attr)
            return self.getattr(attr, attr_val, parameter_proxy_cache)

        @functools.wraps(_orig_module_call)
        def module_call_wrapper(mod, *args, **kwargs):
            def forward(*args, **kwargs):
                return _orig_module_call(mod, *args, **kwargs)

            # ``patcher`` is bound by the ``with`` statement below; this
            # wrapper only runs while that patcher is active.
            _autowrap_check(
                patcher,  # type: ignore[has-type]
                getattr(getattr(mod, "forward", mod), "__globals__", {}),
                self._autowrap_function_ids,
            )
            return self.call_module(mod, forward, args, kwargs)

        with _new_patcher() as patcher:
            # allow duplicate patches to support the case of nested calls
            patcher.patch_method(
                torch.nn.Module,
                "__getattr__",
                module_getattr_wrapper,
                deduplicate=False,
            )
            patcher.patch_method(
                torch.nn.Module,
                "__call__",
                module_call_wrapper,
                deduplicate=False,
            )
            _patch_wrapped_functions(patcher)
            _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
            for module in self._autowrap_search:
                _autowrap_check(
                    patcher, module.__dict__, self._autowrap_function_ids
                )
            ann = inspect.get_annotations(inspect.unwrap(fn))
            # Running ``fn`` on the placeholder arguments is what records the
            # graph; its result becomes the ``output`` node.
            self.create_node(
                "output",
                "output",
                (self.create_arg(fn(*args)),),
                {},
                type_expr=ann.get("return", None),
            )
    except RuntimeError as e:
        if e.args and isinstance(e.args[0], str) and "data-dependent" in e.args[0]:
            # Attach the partially built graph to help debug the failure.
            e.partial_fx_graph = self.graph.python_code(  # type: ignore[attr-defined]
                root_module="self",
                verbose=True,
            ).src
        raise
    finally:
        # Drop the module -> path table on failure as well as on success, so
        # a failed trace does not leave stale module references behind.
        self.submodule_paths = None
        _is_fx_tracing_flag = old_is_fx_tracing_flag
    return self.graph
def __deepcopy__(self, memo):
    """Return a copy of this tracer that shares its ``_autowrap_search`` modules.

    Every instance attribute is deep-copied with the supplied ``memo``, except
    ``_autowrap_search``: it contains modules, which cannot be deep-copied, so
    it only receives a shallow ``copy.copy``.  The copy is created through
    ``Tracer.__new__`` (``__init__`` is not run).
    """
    duplicate = Tracer.__new__(Tracer)
    for attr_name, attr_value in self.__dict__.items():
        if attr_name != "_autowrap_search":
            duplicate.__dict__[attr_name] = copy.deepcopy(attr_value, memo)
        else:
            duplicate.__dict__[attr_name] = copy.copy(attr_value)
    return duplicate
def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis):
    """Build the value passed to the traced function for argument ``name``.

    Args:
        name: Argument name; a leading ``*`` marks a variadic argument, which
            has no entry in ``sig.parameters``.
        concrete_args: Optional mapping of argument name to a (pytree of)
            concrete values and/or ``PHBase`` placeholders.
        sig: Signature object whose ``parameters`` supply the default values.
        fn_for_analysis: Function whose ``__annotations__`` supply the
            ``type_expr`` of a plain placeholder.

    Returns:
        * If ``name`` is not in ``concrete_args``: one ``placeholder`` proxy.
        * Otherwise: ``concrete_args[name]`` mapped leaf by leaf. Each leaf
          gets its own ``placeholder`` node named ``{name}_{k}``; ``PHBase``
          leaves become that proxy, every other leaf stays the concrete value
          (guarded by an assertion node where one can be emitted).
    """

    def declared_default() -> tuple[Any, ...]:
        # Looked up lazily so that ``sig.parameters[name]`` is only touched
        # on the code paths that touched it before.
        declared = sig.parameters[name].default
        if declared is inspect.Parameter.empty:
            return ()
        return (declared,)

    if concrete_args is None or name not in concrete_args:
        default = () if name[0] == "*" else declared_default()
        return self.create_proxy(
            "placeholder",
            name,
            default,
            {},
            type_expr=fn_for_analysis.__annotations__.get(name, None),
        )

    leaves_seen = 0

    def specialize_leaf(x):
        nonlocal leaves_seen
        leaves_seen += 1
        out = self.create_proxy(
            "placeholder", f"{name}_{leaves_seen}", declared_default(), {}
        )

        if isinstance(x, PHBase):
            if x != PH:
                # A custom placeholder (anything but the singleton PH, which
                # carries no attributes) may hold user-set metadata; move it
                # onto the node so it survives once the proxy is unwrapped.
                _transfer_attrs(fr=x, to=out.node)
            return out

        leaf_type = type(x)
        # bool is tested on its own: Union[int, bool] == bool in Python <= 3.6
        if leaf_type is bool or (
            leaf_type in base_types and leaf_type is not torch.Tensor
        ):
            torch._assert(
                out == x,
                f"{name} has been specialized to have value {x} but got another value",
            )
        elif x is None:
            none_check_args = (
                out,
                f"{name} has been specialized to have value None but got another value",
            )
            self.create_proxy("call_function", _assert_is_none, none_check_args, {})
        else:
            warnings.warn(
                f"Was not able to add assertion to guarantee correct input {name} to "
                f"specialized function. It is up to the user to make sure that your inputs match the "
                f"inputs you specialized the function with."
            )
        return x

    return pytree.tree_map(specialize_leaf, concrete_args[name])
# Registry used by the wrap() API: maps (id(globals dict), function name) to
# the globals dict that has to be patched. Keying on both the globals id and
# the function name guarantees a given function is only wrapped once.
_wrapped_fns_to_patch: dict[tuple[int, str], dict] = {}

# Methods on classes to wrap, as (class type, method name) pairs.
# This currently only works for Tensor.* methods that aren't traced properly.
_wrapped_methods_to_patch: list[tuple[type, str]] = []

if os.environ.get("FX_PATCH_GETITEM") == "1":
    # Opt-in only. Patching Tensor.__getitem__ is needed to trace models like
    # PositionalEmbedding from BERT:
    # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py
    # but it causes issues in quantization documented here:
    # https://github.com/pytorch/pytorch/issues/50710
    # Once that is fixed this can become the default behavior.
    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
def _find_proxy(*objects_to_search):
    """
    Walk ``objects_to_search`` with ``map_aggregate`` looking for ``Proxy``
    instances. Returns the last one visited, or ``None`` when there is none.
    """
    seen = []

    def remember_if_proxy(candidate):
        if isinstance(candidate, Proxy):
            seen.append(candidate)

    map_aggregate(objects_to_search, remember_if_proxy)
    return seen[-1] if seen else None
def _create_wrapped_func(orig_fn):
    """Return a stand-in for ``orig_fn`` that records a node when given a Proxy."""

    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        """
        Look through args and kwargs for a Proxy. Without one this call is not
        part of a trace, so the closed-over ``orig_fn`` simply runs. With one,
        a ``call_function`` node targeting ``orig_fn`` is emitted on that
        proxy's tracer, keeping the call as a leaf in the graph.
        """
        traced_arg = _find_proxy(args, kwargs)
        if traced_arg is None:
            return orig_fn(*args, **kwargs)

        node_proxy = traced_arg.tracer.create_proxy(
            "call_function", orig_fn, args, kwargs
        )
        node_proxy.node.meta["is_wrapped"] = True
        return node_proxy

    return wrapped
def _create_wrapped_method(cls, name):
    """Return a stand-in for ``cls.name`` that records a node when given a Proxy."""
    orig_fn = getattr(cls, name)

    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        """
        Look through args and kwargs for a Proxy. Without one this call is not
        part of a trace, so the original method simply runs. With one, a
        ``call_method`` node named ``name`` is emitted on that proxy's tracer,
        keeping the method call in the graph.
        """
        traced_arg = _find_proxy(args, kwargs)
        if traced_arg is None:
            return orig_fn(*args, **kwargs)
        return traced_arg.tracer.create_proxy("call_method", name, args, kwargs)

    return wrapped
class _PatchedFn(NamedTuple):
    """Record of one replacement installed by ``_Patcher``.

    Subclasses define how ``new_fn`` is installed under ``fn_name`` and how the
    previous state is restored.
    """

    # Container holding the patched name: a dict for the item-based
    # subclasses, a class for ``_PatchedFnSetAttr``.
    frame_dict: Any
    fn_name: str
    # Value present before patching (``None`` for ``_PatchedFnDel``).
    orig_fn: Any
    new_fn: Any

    def revert(self):
        raise NotImplementedError

    def patch(self):
        raise NotImplementedError


class _PatchedFnSetItem(_PatchedFn):
    """Patch of an existing dict entry; reverting stores the old value back."""

    def revert(self):
        self.frame_dict[self.fn_name] = self.orig_fn

    def patch(self):
        self.frame_dict[self.fn_name] = self.new_fn


class _PatchedFnDel(_PatchedFn):
    """Patch that adds a new dict entry; reverting deletes the entry again."""

    def revert(self):
        del self.frame_dict[self.fn_name]

    def patch(self):
        self.frame_dict[self.fn_name] = self.new_fn


class _PatchedFnSetAttr(_PatchedFn):
    """Patch of an attribute (``frame_dict`` is the object owning it)."""

    def revert(self):
        setattr(self.frame_dict, self.fn_name, self.orig_fn)

    def patch(self):
        setattr(self.frame_dict, self.fn_name, self.new_fn)


class _Patcher:
    """Context manager that installs function patches and undoes them on exit."""

    # Marker stored on every replacement function. It is kept as a string and
    # always accessed through setattr()/getattr(): writing
    # ``new_fn.__fx_already_patched`` inside this class body would be
    # name-mangled to ``_Patcher__fx_already_patched``, so the "already
    # patched" lookups would never find it and de-duplication would silently
    # never happen.
    _ALREADY_PATCHED_ATTR = "__fx_already_patched"

    def __init__(self) -> None:
        super().__init__()
        self.patches_made: list[_PatchedFn] = []
        self.visited: set[int] = set()

    def patch(
        self,
        frame_dict: dict[str, Any],
        name: str,
        new_fn: Callable,
        deduplicate: bool = True,
    ):
        """
        Replace frame_dict[name] with new_fn until we exit the context manager.

        Args:
            frame_dict: Namespace dict to patch.
            name: Key to replace. If it is missing from ``frame_dict`` but
                exists in ``builtins``, a new entry is added (and deleted
                again on revert).
            new_fn: Replacement callable; it is stamped with the
                already-patched marker set to ``deduplicate``.
            deduplicate: When True, later ``patch`` calls that find ``new_fn``
                installed under ``name`` become no-ops.

        Raises:
            KeyError: ``name`` is neither in ``frame_dict`` nor a builtin.
        """
        setattr(new_fn, self._ALREADY_PATCHED_ATTR, deduplicate)
        record: _PatchedFn
        if name not in frame_dict and hasattr(builtins, name):
            record = _PatchedFnDel(frame_dict, name, None, new_fn)
        elif getattr(frame_dict[name], self._ALREADY_PATCHED_ATTR, False):
            return  # already patched, no need to do it again
        else:
            record = _PatchedFnSetItem(frame_dict, name, frame_dict[name], new_fn)
        self.patches_made.append(record)
        record.patch()

    def patch_method(
        self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True
    ):
        """
        Replace cls.name with new_fn until we exit the context manager.

        ``deduplicate`` has the same meaning as in ``patch``.
        """
        setattr(new_fn, self._ALREADY_PATCHED_ATTR, deduplicate)
        orig_fn = getattr(cls, name)
        if getattr(orig_fn, self._ALREADY_PATCHED_ATTR, False):
            return  # already patched, no need to do it again
        record = _PatchedFnSetAttr(cls, name, orig_fn, new_fn)
        self.patches_made.append(record)
        record.patch()

    def visit_once(self, thing: Any):
        """Return True on the first call with ``thing`` (tracked by ``id``), otherwise False."""
        idx = id(thing)
        if idx in self.visited:
            return False
        self.visited.add(idx)
        return True

    def revert_all_patches(self):
        """
        Revert all the stored patches. It doesn't modify patches_made.

        Patches are undone newest-first (as ``__exit__`` does) so that several
        patches of the same name unwind back to the true original instead of
        leaving an intermediate replacement installed.
        """
        for patch in reversed(self.patches_made):
            patch.revert()
        return self.patches_made

    def reapply_all_patches(self):
        """
        Re-install all the stored patches, oldest first. It doesn't modify
        patches_made.
        """
        for patch in self.patches_made:
            patch.patch()
        return self.patches_made

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Undo all the changes made via self.patch() and self.patch_method()
        """
        while self.patches_made:
            # unpatch in reverse order to handle duplicates correctly
            self.patches_made.pop().revert()
        self.visited.clear()
# Patcher of the innermost active ``_new_patcher()`` block, or None outside one.
CURRENT_PATCHER: Optional[_Patcher] = None


@contextlib.contextmanager
def _new_patcher():
    """Install a fresh ``_Patcher`` as ``CURRENT_PATCHER`` for the with-block.

    On exit the patches recorded by the current patcher are reverted and the
    previously active patcher (possibly ``None``) is put back.

    Raises:
        AssertionError: ``CURRENT_PATCHER`` is ``None`` when the block exits.
    """
    global CURRENT_PATCHER
    enclosing_patcher = CURRENT_PATCHER
    try:
        CURRENT_PATCHER = _Patcher()
        yield CURRENT_PATCHER
    finally:
        # Undo everything that was patched through the current patcher.
        active_patcher = CURRENT_PATCHER
        if active_patcher is None:
            raise AssertionError("CURRENT_PATCHER is None in finally block")
        active_patcher.revert_all_patches()
        CURRENT_PATCHER = enclosing_patcher
@contextlib.contextmanager
def _maybe_revert_all_patches():
    """Suspend the current patcher's patches for the duration of the with-block.

    The patcher active on entry (if any) has its patches reverted before the
    body runs and re-applied afterwards. Does nothing when no patcher is
    active.

    Raises:
        AssertionError: the patches re-applied on exit differ from the ones
            reverted on entry.
    """
    active_patcher = CURRENT_PATCHER
    reapplied = None
    reverted = None
    try:
        if active_patcher is not None:
            reverted = active_patcher.revert_all_patches()
        yield
    finally:
        if active_patcher is not None:
            reapplied = active_patcher.reapply_all_patches()
        if reapplied != reverted:
            raise AssertionError(
                "CURRENT_PATCHER was changed during a revert_all_patches"
            )
def _patch_wrapped_functions(patcher: _Patcher):
    """
    Install the ``wrap()`` registrations through ``patcher``.

    Each global function listed in ``_wrapped_fns_to_patch`` is replaced in
    its globals dict by a `_create_wrapped_func` wrapper (falling back to the
    builtin of the same name when the dict has no such entry), and each
    ``(class, name)`` pair in ``_wrapped_methods_to_patch`` is replaced by a
    `_create_wrapped_method` wrapper.
    """
    # Iterate over a snapshot of the registry rather than the live dict.
    registered = tuple(_wrapped_fns_to_patch.items())
    for (_, name), frame_dict in registered:
        use_builtin = name not in frame_dict and hasattr(builtins, name)
        orig_fn = getattr(builtins, name) if use_builtin else frame_dict[name]
        patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))

    for cls, name in _wrapped_methods_to_patch:
        patcher.patch_method(cls, name, _create_wrapped_method(cls, name))
def _autowrap_check(
    patcher: _Patcher, frame_dict: dict[str, Any], function_ids: set[int]
):
    """
    Some functions, like `math.sqrt`, are common enough that we want to wrap
    them automatically as we see them. This scans the scope ``frame_dict``
    (only the first time ``patcher`` sees it) and patches every public
    callable whose ``id`` is listed in ``function_ids``.
    """
    if not patcher.visit_once(frame_dict):
        return

    for name, value in frame_dict.items():
        if name.startswith("_"):
            continue
        if callable(value) and id(value) in function_ids:
            patcher.patch(frame_dict, name, _create_wrapped_func(value))
@compatibility(is_backward_compatible=True)
def wrap(fn_or_name: Union[str, Callable]):
    """
    Register ``fn_or_name`` as a "leaf function"; call this at module-level scope.

    A "leaf function" is kept as a CallFunction node in the FX trace instead of
    being traced through::

        # foo/bar/baz.py
        def my_custom_function(x, y):
            return x * x + y * y


        torch.fx.wrap("my_custom_function")


        def fn_to_be_traced(x, y):
            # When symbolic tracing, the below call to my_custom_function will
            # be inserted into the graph rather than tracing it.
            return my_custom_function(x, y)

    Equivalently, it can be used as a decorator::

        # foo/bar/baz.py
        @torch.fx.wrap
        def my_custom_function(x, y):
            return x * x + y * y

    Leaf functions are analogous to "leaf modules": they are left as calls in
    the FX trace rather than traced through.

    Args:

        fn_or_name (Union[str, Callable]): The function or name of the global
            function to insert into the graph when it's called

    Returns:
        ``fn_or_name`` unchanged, which is what makes decorator use possible.

    Raises:
        RuntimeError: ``fn_or_name`` is neither a callable nor a string.
        NotImplementedError: the caller is not module top-level code.
    """
    if not (callable(fn_or_name) or isinstance(fn_or_name, str)):
        raise RuntimeError(
            "Unsupported type for global function! Must be either a callable or string name"
        )

    # The isinstance checks below stay inline so mypy can narrow the Union.
    if callable(fn_or_name):
        if isinstance(fn_or_name, str):  # to make mypy happy
            raise AssertionError("Unexpected: fn_or_name is both callable and str")
        fn_name = fn_or_name.__name__
    else:
        if not isinstance(fn_or_name, str):
            raise AssertionError(
                f"fn_or_name must be a global function or string name, got {type(fn_or_name)}"
            )
        fn_name = fn_or_name

    own_frame = inspect.currentframe()
    if own_frame is None:
        raise AssertionError("inspect.currentframe() returned None")
    caller_frame = own_frame.f_back
    if caller_frame is None:
        raise AssertionError("currentframe.f_back is None")
    if caller_frame.f_code.co_name != "<module>":
        raise NotImplementedError("wrap must be called at the top level of a module")

    # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search
    # semantics would be slightly different, but would add support `from x import wrapped_function`
    caller_globals = caller_frame.f_globals
    _wrapped_fns_to_patch[(id(caller_globals), fn_name)] = caller_globals
    return fn_or_name
@compatibility(is_backward_compatible=True)
def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[dict[str, Any]] = None,
) -> GraphModule:
    """
    Symbolic tracing API

    Traces the ``nn.Module`` or function instance ``root`` and returns a
    ``GraphModule`` built from the operations recorded along the way.

    ``concrete_args`` lets you partially specialize your function, either to
    remove control flow or to remove data structures.

    For example::

        def f(a, b):
            if b == True:
                return a
            else:
                return a * 2

    FX can typically not trace through this due to the presence of control
    flow. Specializing on the value of `b` with `concrete_args` makes it
    traceable::

        f = fx.symbolic_trace(f, concrete_args={"b": False})
        assert f(3, False) == 6

    Note that although you can still pass in different values of `b`, they
    will be ignored.

    `concrete_args` can also eliminate data-structure handling from a
    function. Inputs are flattened with pytrees. To avoid overspecializing,
    pass in `fx.PH` for values that shouldn't be specialized. For example::

        def f(x):
            out = 0
            for v in x.values():
                out += v
            return out


        f = fx.symbolic_trace(
            f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}
        )
        assert f({"a": 1, "b": 2, "c": 4}) == 7


    Args:
        root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted
            into a Graph representation.
        concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized

    Returns:
        GraphModule: a Module created from the recorded operations from ``root``.
    """
    tracer = Tracer()
    graph = tracer.trace(root, concrete_args)
    # Name the result after the module's class, or after the traced function.
    if isinstance(root, torch.nn.Module):
        name = root.__class__.__name__
    else:
        name = root.__name__
    return _make_graph_module(tracer.root, graph, name)
@wrap
def _assert_is_none(value, msg):
    """Raise ``AssertionError(msg)`` unless ``value`` is ``None``.

    Registered with ``wrap`` so that a call made with a Proxy argument is
    recorded as a node instead of being executed during tracing.
    """
    if value is None:
        return
    raise AssertionError(msg)
|