_symbolic_trace.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387
  1. # mypy: allow-untyped-defs
  2. import builtins
  3. import collections
  4. import contextlib
  5. import copy
  6. import functools
  7. import inspect
  8. import logging
  9. import math
  10. import os
  11. import warnings
  12. from collections.abc import Callable
  13. from itertools import chain
  14. from types import CodeType, FunctionType, ModuleType
  15. from typing import Any, get_args, NamedTuple, Optional, TypeAlias, Union
  16. import torch
  17. import torch.utils._pytree as pytree
  18. from torch._C import ScriptObject # type: ignore[attr-defined]
  19. from torch._library.fake_class_registry import FakeScriptObject
  20. from torch._library.opaque_object import is_opaque_reference_type, is_opaque_type
  21. from ._compatibility import compatibility
  22. from ._lazy_graph_module import _make_graph_module
  23. from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
  24. from .graph_module import GraphModule
  25. from .node import Argument, base_types, map_aggregate
  26. from .proxy import ParameterProxy, Proxy, Scope, ScopeContextManager, TracerBase
# Module-level logger used by the helpers in this file.
log = logging.getLogger(__name__)

# Bitmask that matches code objects declaring ``*args`` and/or ``**kwargs``.
HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS

# These need to run in global scope to handle nested calls correctly
_orig_module_call: Callable = torch.nn.Module.__call__
_orig_module_getattr: Callable = torch.nn.Module.__getattr__

# Registry of classes created through ``ProxyableClassMeta``. Only the keys
# matter (the dict is used like a set); values are always ``None``.
_proxyable_classes: dict[type, None] = {}

# Set to True by ``Tracer.trace``; read by the ``is_fx_tracing*`` helpers.
_is_fx_tracing_flag = False

# Types of values that ``Tracer.create_arg`` stores on the root module and
# references through a ``get_attr`` node.
_ConstantAttributeType: TypeAlias = Union[
    torch.Tensor, torch.ScriptObject, FakeScriptObject, pytree.TreeSpec
]

# The members of the union above as a tuple, usable with ``isinstance``.
_constant_attribute_types = get_args(_ConstantAttributeType)
  38. # We only want to print this once to avoid flooding logs
  39. @functools.lru_cache
  40. def is_fx_tracing_warning():
  41. log.warning(
  42. "is_fx_tracing will return true for both fx.symbolic_trace and "
  43. "torch.export. Please use "
  44. "is_fx_tracing_symbolic_tracing() for specifically fx.symbolic_trace "
  45. "or torch.compiler.is_compiling() for specifically torch.export/compile."
  46. )
  47. def is_fx_tracing():
  48. is_fx_tracing_warning()
  49. return _is_fx_tracing_flag
  50. def is_fx_symbolic_tracing():
  51. return _is_fx_tracing_flag and not torch.compiler.is_compiling()
  52. @compatibility(is_backward_compatible=True)
  53. class ProxyableClassMeta(type):
  54. """
  55. ProxyableClassMeta allows you to make construction of a given Python class
  56. symbolically traceable. For example::
  57. import torch
  58. import torch.fx
  59. class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
  60. def __init__(self, left, right):
  61. self.left, self.right = left, right
  62. def add(self, other):
  63. l = self.left + other.left
  64. r = self.right + other.right
  65. return TensorPair(l, r)
  66. def mul(self, other):
  67. l = self.left * other.left
  68. r = self.right * other.right
  69. return TensorPair(l, r)
  70. def use_tensor_pair_ctor(x: TensorPair, y: torch.Tensor):
  71. s = x.add(TensorPair(y, y))
  72. return s.mul(x)
  73. x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
  74. y = torch.randn(5, 3)
  75. ref_out = use_tensor_pair_ctor(x, y)
  76. traced = torch.fx.symbolic_trace(use_tensor_pair_ctor)
  77. print(traced.code)
  78. '''
  79. def forward(self, x : __main___TensorPair, y : torch.Tensor):
  80. tensor_pair = __main___TensorPair(y, y); y = None
  81. add = x.add(tensor_pair); tensor_pair = None
  82. mul = add.mul(x); add = x = None
  83. return mul
  84. '''
  85. From this example, we can see that construction of a class (``TensorPair``)
  86. defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic
  87. tracing.
  88. """
  89. def __init__(cls, name, bases, attrs):
  90. _proxyable_classes.setdefault(cls)
  91. super().__init__(name, bases, attrs)
  92. def __call__(cls, *args, **kwargs):
  93. instance = cls.__new__(cls) # type: ignore[call-overload]
  94. if not is_fx_tracing():
  95. cls.__init__(instance, *args, **kwargs) # type: ignore[misc]
  96. return instance
  97. found_proxies = []
  98. def check_proxy(a):
  99. if isinstance(a, Proxy):
  100. found_proxies.append(a)
  101. map_aggregate(args, check_proxy)
  102. map_aggregate(kwargs, check_proxy)
  103. if len(found_proxies) != 0:
  104. tracer = found_proxies[0].tracer
  105. return tracer.create_proxy("call_function", cls, args, kwargs)
  106. else:
  107. cls.__init__(instance, *args, **kwargs) # type: ignore[misc]
  108. return instance
  109. def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
  110. co = fn.__code__
  111. co_flags = co.co_flags & ~HAS_VARSTUFF
  112. co_args: tuple
  113. if hasattr(co, "co_qualname"):
  114. # Python-3.11+ code signature
  115. co_args = (
  116. nargs,
  117. 0,
  118. 0,
  119. co.co_nlocals,
  120. co.co_stacksize,
  121. co_flags,
  122. co.co_code,
  123. co.co_consts,
  124. co.co_names,
  125. co.co_varnames,
  126. co.co_filename,
  127. co.co_name,
  128. co.co_qualname, # type: ignore[attr-defined]
  129. co.co_firstlineno,
  130. co.co_linetable,
  131. co.co_exceptiontable, # type: ignore[attr-defined]
  132. co.co_freevars,
  133. co.co_cellvars,
  134. )
  135. elif hasattr(co, "co_posonlyargcount"):
  136. co_args = (
  137. nargs,
  138. 0,
  139. 0,
  140. co.co_nlocals,
  141. co.co_stacksize,
  142. co_flags,
  143. co.co_code,
  144. co.co_consts,
  145. co.co_names,
  146. co.co_varnames,
  147. co.co_filename,
  148. co.co_name,
  149. co.co_firstlineno,
  150. co.co_lnotab,
  151. co.co_freevars,
  152. co.co_cellvars,
  153. )
  154. else:
  155. co_args = (
  156. nargs,
  157. 0,
  158. co.co_nlocals,
  159. co.co_stacksize,
  160. co_flags,
  161. co.co_code,
  162. co.co_consts,
  163. co.co_names,
  164. co.co_varnames,
  165. co.co_filename,
  166. co.co_name,
  167. co.co_firstlineno,
  168. co.co_lnotab,
  169. co.co_freevars,
  170. co.co_cellvars,
  171. )
  172. new_code = CodeType(*co_args) # type: ignore[arg-type]
  173. return FunctionType(
  174. new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__
  175. )
  176. # we need to insert placeholder nodes for *args and **kwargs
  177. # we can't call this function normally, otherwise it would try to unpack them
  178. # instead, let's make python think that args and kwargs are normal variables
  179. @compatibility(is_backward_compatible=False)
  180. class PHBase:
  181. """
  182. Object representing an input placeholder to `concrete_args`
  183. """
  184. def __repr__(self):
  185. return "PH"
  186. PH = PHBase()
  187. @compatibility(is_backward_compatible=False)
  188. class PHWithMeta(PHBase):
  189. """
  190. Object representing an input placeholder to `concrete_args`
  191. """
  192. def __init__(self, ph_key: Optional[str] = None):
  193. super().__init__()
  194. # Provide a hey for user to identify placeholder node during analysis
  195. self.ph_key = ph_key
  196. def _transfer_attrs(fr, to):
  197. for attr_name in dir(fr):
  198. attr_val = getattr(fr, attr_name)
  199. if (
  200. not callable(attr_val)
  201. and not attr_name.startswith("__")
  202. and not hasattr(to, attr_name)
  203. ):
  204. setattr(to, attr_name, attr_val)
  205. @compatibility(is_backward_compatible=True)
  206. class Tracer(TracerBase):
  207. # Reference: https://github.com/pytorch/pytorch/issues/54354
  208. # The first line of this docstring overrides the one Sphinx generates for the
  209. # documentation. We need it so that Sphinx doesn't leak `math`s path from the
  210. # build environment (e.g. `<module 'math' from '/leaked/path').
  211. """Tracer(autowrap_modules=(math,), autowrap_functions=())
  212. ``Tracer`` is the class that implements the symbolic tracing functionality
  213. of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent
  214. to ``Tracer().trace(m)``.
  215. Tracer can be subclassed to override various behaviors of the tracing
  216. process. The different behaviors that can be overridden are described
  217. in the docstrings of the methods on this class.
  218. """
  219. # Not checking BC on this API because the default value for `autowrap_modules`
  220. # includes the local filepath to the `math` module, which would jitter
  221. # across machines.
  222. @compatibility(is_backward_compatible=True)
  223. def __init__(
  224. self,
  225. autowrap_modules: tuple[ModuleType] = (math,),
  226. autowrap_functions: tuple[Callable, ...] = (),
  227. param_shapes_constant: bool = False,
  228. ) -> None:
  229. # This method's signature is overridden by the first line of this class'
  230. # docstring. If this method's signature is modified, the signature that
  231. # overrides it also should be modified accordingly.
  232. """
  233. Construct a Tracer object.
  234. Args:
  235. autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
  236. Python modules whose functions should be wrapped automatically
  237. without needing to use fx.wrap(). Backward-compatibility for
  238. this parameter is guaranteed.
  239. autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
  240. Python functions that should be wrapped automatically without
  241. needing to use fx.wrap(). Backward compatibility for this
  242. parameter is guaranteed.
  243. param_shapes_constant (bool): When this flag is set, calls to shape,
  244. size and a few other shape like attributes of a module's parameter
  245. will be evaluated directly, rather than returning a new Proxy value
  246. for an attribute access. Backward compatibility for this parameter
  247. is guaranteed.
  248. """
  249. super().__init__()
  250. # Functions we will eagerly wrap when we see them while tracing
  251. # this captures both `math.sqrt()` and `from math import sqrt` automatically
  252. self._autowrap_function_ids: set[int] = {
  253. id(value)
  254. for name, value in chain.from_iterable(
  255. m.__dict__.items() for m in autowrap_modules
  256. )
  257. if not name.startswith("_") and callable(value)
  258. }
  259. self._autowrap_function_ids.update({id(f) for f in autowrap_functions})
  260. # Python modules to apply autowrap to at the start, in addition to
  261. # modules we see while tracing
  262. self._autowrap_search: list[ModuleType] = list(autowrap_modules)
  263. self.param_shapes_constant = param_shapes_constant
  264. self.submodule_paths: Optional[dict[torch.nn.Module, str]] = None
  265. self.root_module_name: str = ""
  266. # Maps the containing module's name to the operator name
  267. self.scope = Scope("", None)
  268. # Records the module call stack
  269. self.module_stack = collections.OrderedDict()
  270. self.num_calls: dict[str, int] = {}
  271. # Mapping of node name to module scope
  272. self.node_name_to_scope: dict[str, tuple[str, type]] = {}
  273. _qualname_counter: dict[str, int] = collections.defaultdict(int)
  274. @compatibility(is_backward_compatible=True)
  275. def get_fresh_qualname(self, prefix: str) -> str:
  276. """
  277. Gets a fresh name for a prefix and returns it. This function ensures
  278. that it will not clash with an existing attribute on the graph.
  279. """
  280. # The idea here is that if the module doesn't have this prefix at all we
  281. # should reset the counter to start from the beginning
  282. # It's a ... little bit hacky (doesn't cover all cases) but the precise
  283. # naming of the prefixes isn't a correctness issue, just a niceness
  284. # issue
  285. qualname = f"{prefix}0"
  286. if not hasattr(self.root, qualname):
  287. self._qualname_counter[prefix] = 0
  288. return qualname
  289. i = self._qualname_counter[prefix]
  290. while True:
  291. qualname = f"{prefix}{i}"
  292. i += 1
  293. if not hasattr(self.root, qualname):
  294. break
  295. self._qualname_counter[prefix] = i
  296. return qualname
  297. @compatibility(is_backward_compatible=True)
  298. def create_arg(self, a: Any) -> "Argument":
  299. """
  300. A method to specify the behavior of tracing when preparing values to
  301. be used as arguments to nodes in the ``Graph``.
  302. By default, the behavior includes:
  303. #. Iterate through collection types (e.g. tuple, list, dict) and recursively
  304. call ``create_args`` on the elements.
  305. #. Given a Proxy object, return a reference to the underlying IR ``Node``
  306. #. Given a non-Proxy Tensor object, emit IR for various cases:
  307. * For a Parameter, emit a ``get_attr`` node referring to that Parameter
  308. * For a non-Parameter Tensor, store the Tensor away in a special
  309. attribute referring to that attribute.
  310. This method can be overridden to support more types.
  311. Args:
  312. a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.
  313. Returns:
  314. The value ``a`` converted into the appropriate ``Argument``
  315. """
  316. # The base tracer is used to construct Graphs when there is no associated
  317. # module hierarchy, so it can never create parameter references.
  318. # The default tracer adds the ability to refer to parameters when
  319. # tracing modules.
  320. if isinstance(a, torch.nn.Parameter):
  321. for n, p in self.root.named_parameters():
  322. if a is p:
  323. return self.create_node("get_attr", n, (), {})
  324. raise NameError("parameter is not a member of this module")
  325. elif isinstance(a, torch.Tensor):
  326. for n_, p_ in self.root.named_buffers():
  327. if a is p_:
  328. return self.create_node("get_attr", n_, (), {})
  329. elif isinstance(a, torch.nn.Module):
  330. for n_, p_ in self.root.named_modules():
  331. if a is p_:
  332. return self.create_node("get_attr", n_, (), {})
  333. # For NamedTuple instances that appear literally as args, we emit
  334. # a node to construct the NamedTuple and use that Node as the argument.
  335. if isinstance(a, tuple) and hasattr(a, "_fields"):
  336. args = tuple(self.create_arg(elem) for elem in a)
  337. return self.create_node("call_function", a.__class__, args, {})
  338. # Tensors do not have a reliable string repr() from which they can be
  339. # constructed (and we probably don't want to rely on that, either), so
  340. # for any constant Tensor values we encounter, first search for if they
  341. # are an attribute of some module in the module hierarchy. If so, emit
  342. # a get_attr to retrieve that tensor. Otherwise, we'll store away the
  343. # tensor value into a special attribute on the Module s.t. we can
  344. # retrieve it with a get_attr.
  345. if isinstance(a, _constant_attribute_types) or (
  346. is_opaque_reference_type(type(a))
  347. ):
  348. qualname: Optional[str] = self.tensor_attrs.get(a)
  349. # Tensor was not found in the Module hierarchy, stow it away in a
  350. # special attribute and set the qualname to refer to that
  351. if not qualname:
  352. if isinstance(a, torch.Tensor):
  353. base_name = "_tensor_constant"
  354. elif isinstance(a, (FakeScriptObject, ScriptObject)):
  355. base_name = "_torchbind_obj"
  356. elif isinstance(a, pytree.TreeSpec):
  357. base_name = "_tree_spec_constant"
  358. elif is_opaque_type(type(a)):
  359. base_name = "_opaque_obj"
  360. else:
  361. raise RuntimeError(
  362. f"cannot create constant arg for {a} of type {type(a)}."
  363. )
  364. qualname = self.get_fresh_qualname(base_name)
  365. if not isinstance(qualname, str):
  366. raise AssertionError(
  367. f"Expected qualname to be str, got {type(qualname)}"
  368. )
  369. self.tensor_attrs[a] = qualname
  370. setattr(self.root, qualname, a)
  371. return self.create_node("get_attr", qualname, (), {})
  372. if type(a) in _proxyable_classes:
  373. # This is an instance of a proxyable class for which we did not
  374. # witness its construction. Intern this as a constant attribute
  375. # TODO: binary search
  376. qualname = self.get_fresh_qualname(f"_{a.__class__.__name__}_constant_")
  377. if not isinstance(qualname, str):
  378. raise AssertionError(
  379. f"Expected qualname to be str, got {type(qualname)}"
  380. )
  381. setattr(self.root, qualname, a)
  382. return self.create_node("get_attr", qualname, (), {})
  383. return super().create_arg(a)
  384. @compatibility(is_backward_compatible=True)
  385. def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
  386. """
  387. A method to specify whether a given ``nn.Module`` is a "leaf" module.
  388. Leaf modules are the atomic units that appear in
  389. the IR, referenced by ``call_module`` calls. By default,
  390. Modules in the PyTorch standard library namespace (torch.nn)
  391. are leaf modules. All other modules are traced through and
  392. their constituent ops are recorded, unless specified otherwise
  393. via this parameter.
  394. Args:
  395. m (Module): The module being queried about
  396. module_qualified_name (str): The path to root of this module. For example,
  397. if you have a module hierarchy where submodule ``foo`` contains
  398. submodule ``bar``, which contains submodule ``baz``, that module will
  399. appear with the qualified name ``foo.bar.baz`` here.
  400. """
  401. return (
  402. m.__module__.startswith("torch.nn")
  403. or m.__module__.startswith("torch.ao.nn")
  404. ) and not isinstance(m, torch.nn.Sequential)
  405. @compatibility(is_backward_compatible=True)
  406. def path_of_module(self, mod: torch.nn.Module) -> str:
  407. """
  408. Helper method to find the qualified name of ``mod`` in the Module hierarchy
  409. of ``root``. For example, if ``root`` has a submodule named ``foo``, which has
  410. a submodule named ``bar``, passing ``bar`` into this function will return
  411. the string "foo.bar".
  412. Args:
  413. mod (str): The ``Module`` to retrieve the qualified name for.
  414. """
  415. # Prefer the O(1) algorithm
  416. if self.submodule_paths:
  417. path = self.submodule_paths.get(mod)
  418. if path is None:
  419. raise NameError("module is not installed as a submodule")
  420. if not isinstance(path, str):
  421. raise AssertionError(f"Expected path to be str, got {type(path)}")
  422. return path
  423. # O(N^2) fallback in the case that we didn't store the submodule
  424. # paths.
  425. else:
  426. for n, p in self.root.named_modules():
  427. if mod is p:
  428. return n
  429. raise NameError("module is not installed as a submodule")
  430. @compatibility(is_backward_compatible=True)
  431. def call_module(
  432. self,
  433. m: torch.nn.Module,
  434. forward: Callable[..., Any],
  435. args: tuple[Any, ...],
  436. kwargs: dict[str, Any],
  437. ) -> Any:
  438. """
  439. Method that specifies the behavior of this ``Tracer`` when it encounters
  440. a call to an ``nn.Module`` instance.
  441. By default, the behavior is to check if the called module is a leaf module
  442. via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
  443. ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
  444. the operations in its ``forward`` function.
  445. This method can be overridden to--for example--create nested traced
  446. GraphModules, or any other behavior you would want while tracing across
  447. ``Module`` boundaries.
  448. Args:
  449. m (Module): The module for which a call is being emitted
  450. forward (Callable): The forward() method of the ``Module`` to be invoked
  451. args (Tuple): args of the module callsite
  452. kwargs (Dict): kwargs of the module callsite
  453. Return:
  454. The return value from the Module call. In the case that a ``call_module``
  455. node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
  456. value was returned from the ``Module`` invocation.
  457. """
  458. module_qualified_name = self.path_of_module(m)
  459. with ScopeContextManager(
  460. self.scope, Scope(module_qualified_name, type(m))
  461. ) as _scope:
  462. # module_stack is an ordered dict so writing then deleting the
  463. # entry is equivalent to push/pop on a list
  464. num_calls = self.num_calls.get(module_qualified_name, 0)
  465. module_key = (
  466. f"{_scope.module_path}@{num_calls}"
  467. if num_calls > 0
  468. else _scope.module_path
  469. )
  470. self.module_stack[module_key] = (module_qualified_name, _scope.module_type)
  471. self.num_calls[module_qualified_name] = num_calls + 1
  472. if not self.is_leaf_module(m, module_qualified_name):
  473. ret_val = forward(*args, **kwargs)
  474. else:
  475. ret_val = self.create_proxy(
  476. "call_module", module_qualified_name, args, kwargs
  477. )
  478. key, _ = self.module_stack.popitem(last=True)
  479. if key != module_key:
  480. raise AssertionError(f"Unexpected key {key}, expected {module_key}")
  481. return ret_val
  482. @compatibility(is_backward_compatible=False)
  483. def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]):
  484. """
  485. Method that specifies the behavior of this ``Tracer`` when we call getattr
  486. on a call to an ``nn.Module`` instance.
  487. By default, the behavior is to return a proxy value for the attribute. It
  488. also stores the proxy value in the ``parameter_proxy_cache``, so that future
  489. calls will reuse the proxy rather than creating a new one.
  490. This method can be overridden to --for example-- not return proxies when
  491. querying parameters.
  492. Args:
  493. attr (str): The name of the attribute being queried
  494. attr_val (Any): The value of the attribute
  495. parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies
  496. Return:
  497. The return value from the getattr call.
  498. """
  499. def maybe_get_proxy_for_attr(
  500. attr_val, collection_to_search, parameter_proxy_cache
  501. ):
  502. for n, p in collection_to_search:
  503. if attr_val is p:
  504. if n not in parameter_proxy_cache:
  505. kwargs = {}
  506. if (
  507. "proxy_factory_fn"
  508. in inspect.signature(self.create_proxy).parameters
  509. ):
  510. kwargs["proxy_factory_fn"] = (
  511. # pyrefly: ignore [unsupported-operation]
  512. None
  513. if not self.param_shapes_constant
  514. else lambda node: ParameterProxy(
  515. self, node, n, attr_val
  516. )
  517. )
  518. val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
  519. parameter_proxy_cache[n] = val_proxy
  520. return parameter_proxy_cache[n]
  521. return None
  522. if isinstance(attr_val, torch.nn.Parameter):
  523. maybe_parameter_proxy = maybe_get_proxy_for_attr(
  524. attr_val, self.root.named_parameters(), parameter_proxy_cache
  525. )
  526. if maybe_parameter_proxy is not None:
  527. return maybe_parameter_proxy
  528. if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
  529. maybe_buffer_proxy = maybe_get_proxy_for_attr(
  530. attr_val, self.root.named_buffers(), parameter_proxy_cache
  531. )
  532. if maybe_buffer_proxy is not None:
  533. return maybe_buffer_proxy
  534. return attr_val
  535. # This method will be refactored
  536. @compatibility(is_backward_compatible=False)
  537. def create_args_for_root(self, root_fn, is_module, concrete_args=None):
  538. """
  539. Create ``placeholder`` nodes corresponding to the signature of the ``root``
  540. Module. This method introspects root's signature and emits those
  541. nodes accordingly, also supporting ``*args`` and ``**kwargs``.
  542. """
  543. # In some cases, a function or method has been decorated with a wrapper
  544. # defined via ``functools.wraps``. In this case, the outer code object
  545. # will likely not contain the actual parameters we care about, so unwrap
  546. # the function to get to the innermost callable.
  547. fn_for_analysis = inspect.unwrap(root_fn)
  548. co = fn_for_analysis.__code__
  549. total_args = co.co_argcount + co.co_kwonlyargcount
  550. orig_args = list(co.co_varnames)
  551. names_iter = iter(co.co_varnames)
  552. args: list[Any] = []
  553. skip_arg_idx = 0
  554. if is_module:
  555. if total_args == 0:
  556. raise RuntimeError(
  557. "``self`` argument cannot be part of *args expansion!"
  558. )
  559. skip_arg_idx = 1
  560. next(names_iter) # skip self
  561. args.append(self.root)
  562. sig = inspect.signature(fn_for_analysis)
  563. # This covers the very specific case where we are passing in flat
  564. # concrete_args as a tuple, but our traced fn takes (*args, **kwargs).
  565. # In this case, just take the concrete_args and pass them through.
  566. name_idx = 0
  567. if (
  568. isinstance(concrete_args, tuple)
  569. and len(concrete_args) > 0
  570. and (co.co_flags & HAS_VARSTUFF)
  571. and total_args == 1
  572. ):
  573. for concrete_arg in concrete_args:
  574. out = self.create_proxy("placeholder", f"input_{name_idx}", (), {})
  575. if isinstance(concrete_arg, PHBase):
  576. if concrete_arg != PH:
  577. # Transfer attrs in the case where you're using a placeholder other
  578. # than the singleton PH (PH has no attributes to transfer).
  579. # Proxies were created out of the placeholders.
  580. # Transfer any metadata (put on the placeholders in the form of
  581. # attributes set by the user) from the placeholder to the
  582. # underlying nodes (the proxy is unwrapped by the user, but
  583. # the metadata should hold).
  584. _transfer_attrs(fr=concrete_arg, to=out.node)
  585. args.append(out)
  586. name_idx += 1
  587. return root_fn, args
  588. arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
  589. if isinstance(concrete_args, tuple):
  590. if len(arg_names) != len(concrete_args):
  591. raise RuntimeError(
  592. f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments"
  593. )
  594. concrete_args = dict(zip(arg_names, concrete_args))
  595. def proxy_placeholder(name):
  596. return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis)
  597. args.extend(proxy_placeholder(names) for names in arg_names)
  598. if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
  599. # TODO: type annotations for *args and **kwargs
  600. if co.co_flags & inspect.CO_VARARGS:
  601. args.append(proxy_placeholder("*" + next(names_iter)))
  602. if co.co_flags & inspect.CO_VARKEYWORDS:
  603. args.append(proxy_placeholder("**" + next(names_iter)))
  604. root_fn = _patch_function(root_fn, len(args))
  605. flat_args, in_spec = pytree.tree_flatten(tuple(args))
  606. if not all(child.is_leaf() for child in in_spec.children()):
  607. # In the case that we have pytree-flattened inputs in
  608. # `concrete_args`, generate a flattening wrapper around the
  609. # original root function and return that.
  610. self.graph._codegen = _PyTreeCodeGen( # type: ignore[has-type]
  611. _PyTreeInfo(orig_args[:total_args], in_spec, None)
  612. )
  613. def flatten_fn(*args):
  614. tree_args = pytree.tree_unflatten(list(args), in_spec)
  615. tree_out = root_fn(*tree_args)
  616. out_args, out_spec = pytree.tree_flatten(tree_out)
  617. if not isinstance(self.graph._codegen, _PyTreeCodeGen): # type: ignore[has-type]
  618. raise AssertionError(
  619. f"Expected _codegen to be _PyTreeCodeGen, got "
  620. f"{type(self.graph._codegen)}"
  621. )
  622. self.graph._codegen.pytree_info = (
  623. self.graph._codegen.pytree_info._replace(out_spec=out_spec)
  624. )
  625. return out_args
  626. return flatten_fn, flat_args
  627. return root_fn, args
    @compatibility(is_backward_compatible=True)
    def trace(
        self,
        root: Union[torch.nn.Module, Callable[..., Any]],
        concrete_args: Optional[dict[str, Any]] = None,
    ) -> Graph:
        """
        Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
        can either be an ``nn.Module`` instance or a Python callable.

        Note that after this call, ``self.root`` may be different from the ``root`` passed
        in here. For example, when a free function is passed to ``trace()``, we will
        create an ``nn.Module`` instance to use as the root and add embedded constants
        to.

        Args:

            root (Union[Module, Callable]): Either a ``Module`` or a function to be
                traced through. Backwards-compatibility for this parameter is
                guaranteed.
            concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
                not be treated as Proxies. This parameter is experimental and
                its backwards-compatibility is *NOT* guaranteed.

        Returns:

            A ``Graph`` representing the semantics of the passed-in ``root``.

        Raises:

            AssertionError: If ``root`` is a ``Module`` whose type has no attribute
                named ``self.traced_func_name``, or if the callable to be traced is
                not a plain ``FunctionType``.
            RuntimeError: Any ``RuntimeError`` raised while tracing is re-raised. When
                its first argument is a string containing ``"data-dependent"``, the
                source of the partially built graph is attached to the exception as
                ``partial_fx_graph`` first.
        """
        # Mark the process as "inside FX tracing"; the previous value is restored in
        # the ``finally`` clause so nested traces unwind correctly.
        global _is_fx_tracing_flag
        old_is_fx_tracing_flag = _is_fx_tracing_flag
        _is_fx_tracing_flag = True
        try:
            if isinstance(root, torch.nn.Module):
                # do real recompilation for _LazyGraphModule before retracing since the trace
                # method can not trace the _lazy_forward method. Got error:
                # https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
                # without this.
                from torch.fx._lazy_graph_module import _LazyGraphModule

                _LazyGraphModule.force_recompile(root)

                self.root = root

                if not hasattr(type(root), self.traced_func_name):
                    raise AssertionError(
                        f"traced_func_name={self.traced_func_name} doesn't exist in "
                        f"{type(root).__name__}"
                    )

                # Look the function up on the class (not the instance) so that we get
                # the plain function rather than a bound method.
                fn = getattr(type(root), self.traced_func_name)
                self.root_module_name = root._get_name()
                # Reverse mapping: submodule object -> qualified name.
                self.submodule_paths = {mod: name for name, mod in root.named_modules()}
            else:
                # Free function: use a fresh, empty Module as the root.
                self.root = torch.nn.Module()
                fn = root

            tracer_cls: Optional[type[Tracer]] = getattr(self, "__class__", None)
            self.graph = Graph(tracer_cls=tracer_cls)
            if hasattr(fn, "__code__"):
                # Record where the traced function was defined.
                code = fn.__code__
                self.graph._co_fields = {
                    "co_name": code.co_name,
                    "co_filename": code.co_filename,
                    "co_firstlineno": code.co_firstlineno,
                }

            # When we encounter a Tensor value that's not a parameter, we look if it
            # is some other attribute on the model. Construct a dict mapping Tensor
            # values to the qualified name here for efficiency. This is used downstream
            # in create_arg
            self.tensor_attrs: dict[
                _ConstantAttributeType,
                str,
            ] = {}

            def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: list[str]):
                # Record constant-typed values found directly in ``m.__dict__``, then
                # recurse into the child modules with an extended name prefix.
                for k, v in m.__dict__.items():
                    if isinstance(v, _constant_attribute_types):
                        self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
                for k, v in m.named_children():
                    collect_tensor_attrs(v, prefix_atoms + [k])

            collect_tensor_attrs(self.root, [])

            if not isinstance(fn, FunctionType):
                raise AssertionError(f"Expected FunctionType, got {type(fn)}")

            fn_globals = fn.__globals__  # run before it gets patched
            fn, args = self.create_args_for_root(
                fn, isinstance(root, torch.nn.Module), concrete_args
            )

            parameter_proxy_cache: dict[
                str, Proxy
            ] = {}  # Reduce number of get_attr calls

            # Method dispatch on parameters is not recorded unless it's directly used.
            # Thus, we need to insert a proxy when __getattr__ requests a parameter.
            @functools.wraps(_orig_module_getattr)
            def module_getattr_wrapper(mod, attr):
                attr_val = _orig_module_getattr(mod, attr)
                return self.getattr(attr, attr_val, parameter_proxy_cache)

            @functools.wraps(_orig_module_call)
            def module_call_wrapper(mod, *args, **kwargs):
                def forward(*args, **kwargs):
                    return _orig_module_call(mod, *args, **kwargs)

                # ``patcher`` is bound by the ``with`` statement below; this wrapper is
                # only installed (and therefore only called) inside that block.
                _autowrap_check(
                    patcher,  # type: ignore[has-type]
                    getattr(getattr(mod, "forward", mod), "__globals__", {}),
                    self._autowrap_function_ids,
                )
                return self.call_module(mod, forward, args, kwargs)

            with _new_patcher() as patcher:
                # allow duplicate patches to support the case of nested calls
                patcher.patch_method(
                    torch.nn.Module,
                    "__getattr__",
                    module_getattr_wrapper,
                    deduplicate=False,
                )
                patcher.patch_method(
                    torch.nn.Module,
                    "__call__",
                    module_call_wrapper,
                    deduplicate=False,
                )
                _patch_wrapped_functions(patcher)
                _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
                for module in self._autowrap_search:
                    _autowrap_check(
                        patcher, module.__dict__, self._autowrap_function_ids
                    )
                ann = inspect.get_annotations(inspect.unwrap(fn))
                # Calling ``fn(*args)`` here is what actually runs the function being
                # traced; its result becomes the argument of the graph's output node.
                self.create_node(
                    "output",
                    "output",
                    (self.create_arg(fn(*args)),),
                    {},
                    type_expr=ann.get("return", None),
                )

            # Only reached when tracing succeeded.
            self.submodule_paths = None
        except RuntimeError as e:
            if e.args and isinstance(e.args[0], str) and "data-dependent" in e.args[0]:
                # Attach the source of the graph built so far to the exception
                # before re-raising it.
                partial_fx_graph = self.graph.python_code(
                    root_module="self",
                    verbose=True,
                ).src
                e.partial_fx_graph = partial_fx_graph  # type: ignore[attr-defined]
                raise

            raise
        finally:
            _is_fx_tracing_flag = old_is_fx_tracing_flag
        return self.graph
  764. def __deepcopy__(self, memo):
  765. # _autowrap_search contains modules, which cannot be deepcopied.
  766. new_tracer = Tracer.__new__(Tracer)
  767. for k, v in self.__dict__.items():
  768. if k == "_autowrap_search":
  769. new_obj = copy.copy(v)
  770. else:
  771. new_obj = copy.deepcopy(v, memo)
  772. new_tracer.__dict__[k] = new_obj
  773. return new_tracer
    def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis):
        """
        Create the graph input(s) for the root-function argument ``name``.

        If ``name`` is a key of ``concrete_args``, the concrete value is mapped
        leaf-by-leaf with ``pytree.tree_map``: each leaf gets its own placeholder
        named ``f"{name}_{cnt}"`` and is then either replaced by that proxy (when
        the leaf is a ``PHBase`` instance) or kept as the concrete value, with an
        assertion on the placeholder added where possible.

        Otherwise a single ``placeholder`` proxy named ``name`` is created. Its
        default comes from ``sig.parameters[name]`` (no default for names that
        start with ``"*"``) and its type from ``fn_for_analysis.__annotations__``.

        Args:
            name: Argument name; names starting with ``"*"`` skip the signature lookup.
            concrete_args: Mapping from argument name to concrete value, or ``None``.
            sig: Object whose ``parameters[name].default`` supplies the default value.
            fn_for_analysis: Object whose ``__annotations__`` supplies ``type_expr``.

        Returns:
            The result of ``pytree.tree_map`` over the concrete value, or the
            created placeholder proxy.
        """
        if concrete_args is not None and name in concrete_args:
            # Running count of leaves visited; used to give every placeholder
            # created for this argument a distinct name.
            cnt = 0

            def replace_ph(x):
                nonlocal cnt
                cnt += 1
                param = sig.parameters[name]
                default: tuple[Any, ...] = (
                    () if param.default is inspect.Parameter.empty else (param.default,)
                )
                # A placeholder is created for every leaf, including leaves that
                # end up specialized to their concrete value below.
                out = self.create_proxy(
                    "placeholder", f"{name}_{str(cnt)}", default, {}
                )
                if isinstance(x, PHBase):
                    if x != PH:
                        # Transfer attrs in the case where you're using a placeholder other
                        # than the singleton PH (PH has no attributes to transfer).
                        # Proxies were created out of the placeholders.
                        # Transfer any metadata (put on the placeholders in the form of
                        # attributes set by the user) from the placeholder to the
                        # underlying nodes (the proxy is unwrapped by the user, but
                        # the metadata should hold).
                        _transfer_attrs(fr=x, to=out.node)

                    return out
                # Union[int, bool] == bool in Python <= 3.6
                # NOTE: ``and`` binds tighter than ``or``, so this reads as
                # ``bool or (in base_types and not Tensor)``.
                if (
                    type(x) is bool
                    or type(x) in base_types
                    and type(x) is not torch.Tensor
                ):
                    # ``out`` is the placeholder proxy, so the comparison result is
                    # handed to ``torch._assert`` rather than evaluated eagerly here.
                    torch._assert(
                        out == x,
                        f"{name} has been specialized to have value {x} but got another value",
                    )
                elif x is None:
                    args = (
                        out,
                        f"{name} has been specialized to have value None but got another value",
                    )
                    self.create_proxy("call_function", _assert_is_none, args, {})
                else:
                    warnings.warn(
                        f"Was not able to add assertion to guarantee correct input {name} to "
                        f"specialized function. It is up to the user to make sure that your inputs match the "
                        f"inputs you specialized the function with."
                    )

                # Non-placeholder leaves stay concrete in the traced function.
                return x

            return pytree.tree_map(replace_ph, concrete_args[name])
        if name[0] == "*":
            # Names starting with "*" are not looked up in ``sig``; no default.
            default: tuple[Any, ...] = ()
        else:
            param = sig.parameters[name]
            default = (  # type: ignore[assignment]
                () if param.default is inspect.Parameter.empty else (param.default,)
            )
        return self.create_proxy(
            "placeholder",
            name,
            default,
            {},
            type_expr=fn_for_analysis.__annotations__.get(name, None),
        )
  836. # Dictionary of (id(globals dict), function name) => globals_dict to patch for
  837. # the purposes of the wrap() API.
  838. # We key by the globals dict id and function name to ensure we're wrapping a given
  839. # function only once.
  840. _wrapped_fns_to_patch: dict[tuple[int, str], dict] = {}
  841. # List of methods on classes to wrap (class type, function name)
  842. # this currently only works for Tensor.* methods that aren't traced properly
  843. _wrapped_methods_to_patch: list[tuple[type, str]] = []
  844. if os.environ.get("FX_PATCH_GETITEM") == "1":
  845. # This change is needed to trace models like PositionalEmbedding from BERT:
  846. # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py
  847. # but causes issues in quantization documented here:
  848. # https://github.com/pytorch/pytorch/issues/50710
  849. # once that is fixed we can make this the default behavior.
  850. _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
  851. def _find_proxy(*objects_to_search):
  852. """
  853. Recursively search a data structure for a Proxy() and return it,
  854. return None if not found.
  855. """
  856. proxy = None
  857. def find_proxy(x):
  858. nonlocal proxy
  859. if isinstance(x, Proxy):
  860. proxy = x
  861. map_aggregate(objects_to_search, find_proxy)
  862. return proxy
  863. def _create_wrapped_func(orig_fn):
  864. @functools.wraps(orig_fn)
  865. def wrapped(*args, **kwargs):
  866. """
  867. Given an closed-over ``orig_function`` to invoke, search the args and kwargs for
  868. a Proxy object. If there is one, emit a ``call_function`` node to preserve the
  869. call to this leaf function directly. Otherwise, just return the results of
  870. this function call, as this function is not being traced.
  871. """
  872. proxy = _find_proxy(args, kwargs)
  873. if proxy is not None:
  874. return_proxy = proxy.tracer.create_proxy(
  875. "call_function", orig_fn, args, kwargs
  876. )
  877. return_proxy.node.meta["is_wrapped"] = True
  878. return return_proxy
  879. return orig_fn(*args, **kwargs)
  880. return wrapped
  881. def _create_wrapped_method(cls, name):
  882. orig_fn = getattr(cls, name)
  883. @functools.wraps(orig_fn)
  884. def wrapped(*args, **kwargs):
  885. """
  886. Search the args and kwargs for a Proxy object. If there is one,
  887. emit a ``call_method`` node to preserve the call to this method
  888. directly. Otherwise, just return the results of this function
  889. call, as this function is not being traced.
  890. """
  891. proxy = _find_proxy(args, kwargs)
  892. if proxy is not None:
  893. return proxy.tracer.create_proxy("call_method", name, args, kwargs)
  894. return orig_fn(*args, **kwargs)
  895. return wrapped
  896. class _PatchedFn(NamedTuple):
  897. frame_dict: Any
  898. fn_name: str
  899. orig_fn: Any
  900. new_fn: Any
  901. def revert(self):
  902. raise NotImplementedError
  903. def patch(self):
  904. raise NotImplementedError
  905. class _PatchedFnSetItem(_PatchedFn):
  906. def revert(self):
  907. self.frame_dict[self.fn_name] = self.orig_fn
  908. def patch(self):
  909. self.frame_dict[self.fn_name] = self.new_fn
  910. class _PatchedFnDel(_PatchedFn):
  911. def revert(self):
  912. del self.frame_dict[self.fn_name]
  913. def patch(self):
  914. self.frame_dict[self.fn_name] = self.new_fn
  915. class _PatchedFnSetAttr(_PatchedFn):
  916. def revert(self):
  917. setattr(self.frame_dict, self.fn_name, self.orig_fn)
  918. def patch(self):
  919. setattr(self.frame_dict, self.fn_name, self.new_fn)
  920. class _Patcher:
  921. def __init__(self) -> None:
  922. super().__init__()
  923. self.patches_made: list[_PatchedFn] = []
  924. self.visited: set[int] = set()
  925. def patch(
  926. self,
  927. frame_dict: dict[str, Any],
  928. name: str,
  929. new_fn: Callable,
  930. deduplicate: bool = True,
  931. ):
  932. """
  933. Replace frame_dict[name] with new_fn until we exit the context manager.
  934. """
  935. new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined]
  936. if name not in frame_dict and hasattr(builtins, name):
  937. self.patches_made.append(_PatchedFnDel(frame_dict, name, None, new_fn))
  938. self.patches_made[-1].patch()
  939. elif getattr(frame_dict[name], "__fx_already_patched", False):
  940. return # already patched, no need to do it again
  941. else:
  942. self.patches_made.append(
  943. _PatchedFnSetItem(frame_dict, name, frame_dict[name], new_fn)
  944. )
  945. self.patches_made[-1].patch()
  946. def patch_method(
  947. self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True
  948. ):
  949. """
  950. Replace object_or_dict.name with new_fn until we exit the context manager.
  951. """
  952. new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined]
  953. orig_fn = getattr(cls, name)
  954. if getattr(orig_fn, "__fx_already_patched", False):
  955. return # already patched, no need to do it again
  956. self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn, new_fn))
  957. self.patches_made[-1].patch()
  958. def visit_once(self, thing: Any):
  959. """Return True on the first call to with thing, otherwise false"""
  960. idx = id(thing)
  961. if idx in self.visited:
  962. return False
  963. self.visited.add(idx)
  964. return True
  965. def revert_all_patches(self):
  966. """
  967. Remove all the stored patcheds. It doesn't modify patches_made.
  968. """
  969. for patch in self.patches_made:
  970. patch.revert()
  971. return self.patches_made
  972. def reapply_all_patches(self):
  973. """
  974. Patch all the stored patcheds. It doesn't modify patches_made.
  975. """
  976. for patch in self.patches_made:
  977. patch.patch()
  978. return self.patches_made
  979. def __enter__(self):
  980. return self
  981. def __exit__(self, exc_type, exc_val, exc_tb):
  982. """
  983. Undo all the changes made via self.patch() and self.patch_method()
  984. """
  985. while self.patches_made:
  986. # unpatch in reverse order to handle duplicates correctly
  987. self.patches_made.pop().revert()
  988. self.visited.clear()
# The _Patcher installed by the innermost active ``_new_patcher()`` context, or
# None when no such context is active.
CURRENT_PATCHER: Optional[_Patcher] = None
  990. @contextlib.contextmanager
  991. def _new_patcher():
  992. global CURRENT_PATCHER
  993. prior_patcher = CURRENT_PATCHER
  994. try:
  995. CURRENT_PATCHER = _Patcher()
  996. yield CURRENT_PATCHER
  997. finally:
  998. # Clear all the patches made by when using current patcher.
  999. if CURRENT_PATCHER is None:
  1000. raise AssertionError("CURRENT_PATCHER is None in finally block")
  1001. CURRENT_PATCHER.revert_all_patches()
  1002. CURRENT_PATCHER = prior_patcher
  1003. @contextlib.contextmanager
  1004. def _maybe_revert_all_patches():
  1005. current_patcher = CURRENT_PATCHER
  1006. patches_made = None
  1007. patches_removed = None
  1008. try:
  1009. if current_patcher is not None:
  1010. patches_removed = current_patcher.revert_all_patches()
  1011. yield
  1012. finally:
  1013. if current_patcher is not None:
  1014. patches_made = current_patcher.reapply_all_patches()
  1015. if patches_made != patches_removed:
  1016. raise AssertionError(
  1017. "CURRENT_PATCHER was changed during a revert_all_patches"
  1018. )
  1019. def _patch_wrapped_functions(patcher: _Patcher):
  1020. """
  1021. Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap
  1022. the listed global functions in the `_create_wrapped_func` wrapper.
  1023. """
  1024. for (_, name), frame_dict in _wrapped_fns_to_patch.copy().items():
  1025. if name not in frame_dict and hasattr(builtins, name):
  1026. orig_fn = getattr(builtins, name)
  1027. else:
  1028. orig_fn = frame_dict[name]
  1029. patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))
  1030. for cls, name in _wrapped_methods_to_patch:
  1031. patcher.patch_method(cls, name, _create_wrapped_method(cls, name))
  1032. def _autowrap_check(
  1033. patcher: _Patcher, frame_dict: dict[str, Any], function_ids: set[int]
  1034. ):
  1035. """
  1036. Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them.
  1037. This method searches a scope for them and patches them if found.
  1038. """
  1039. if patcher.visit_once(frame_dict):
  1040. for name, value in frame_dict.items():
  1041. if (
  1042. not name.startswith("_")
  1043. and callable(value)
  1044. and id(value) in function_ids
  1045. ):
  1046. patcher.patch(frame_dict, name, _create_wrapped_func(value))
  1047. @compatibility(is_backward_compatible=True)
  1048. def wrap(fn_or_name: Union[str, Callable]):
  1049. """
  1050. This function can be called at module-level scope to register fn_or_name as a "leaf function".
  1051. A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being
  1052. traced through::
  1053. # foo/bar/baz.py
  1054. def my_custom_function(x, y):
  1055. return x * x + y * y
  1056. torch.fx.wrap("my_custom_function")
  1057. def fn_to_be_traced(x, y):
  1058. # When symbolic tracing, the below call to my_custom_function will be inserted into
  1059. # the graph rather than tracing it.
  1060. return my_custom_function(x, y)
  1061. This function can also equivalently be used as a decorator::
  1062. # foo/bar/baz.py
  1063. @torch.fx.wrap
  1064. def my_custom_function(x, y):
  1065. return x * x + y * y
  1066. A wrapped function can be thought of a "leaf function", analogous to the concept of
  1067. "leaf modules", that is, they are functions that are left as calls in the FX trace
  1068. rather than traced through.
  1069. Args:
  1070. fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
  1071. graph when it's called
  1072. """
  1073. if not callable(fn_or_name) and not isinstance(fn_or_name, str):
  1074. raise RuntimeError(
  1075. "Unsupported type for global function! Must be either a callable or "
  1076. "string name"
  1077. )
  1078. if callable(fn_or_name):
  1079. if isinstance(fn_or_name, str): # to make mypy happy
  1080. raise AssertionError("Unexpected: fn_or_name is both callable and str")
  1081. fn_name = fn_or_name.__name__
  1082. else:
  1083. if not isinstance(fn_or_name, str):
  1084. raise AssertionError(
  1085. f"fn_or_name must be a global function or string name, got "
  1086. f"{type(fn_or_name)}"
  1087. )
  1088. fn_name = fn_or_name
  1089. currentframe = inspect.currentframe()
  1090. if currentframe is None:
  1091. raise AssertionError("inspect.currentframe() returned None")
  1092. f = currentframe.f_back
  1093. if f is None:
  1094. raise AssertionError("currentframe.f_back is None")
  1095. if f.f_code.co_name != "<module>":
  1096. raise NotImplementedError("wrap must be called at the top level of a module")
  1097. # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search
  1098. # semantics would be slightly different, but would add support `from x import wrapped_function`
  1099. _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals
  1100. return fn_or_name
  1101. @compatibility(is_backward_compatible=True)
  1102. def symbolic_trace(
  1103. root: Union[torch.nn.Module, Callable[..., Any]],
  1104. concrete_args: Optional[dict[str, Any]] = None,
  1105. ) -> GraphModule:
  1106. """
  1107. Symbolic tracing API
  1108. Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
  1109. constructed by recording operations seen while tracing through ``root``.
  1110. ``concrete_args`` allows you to partially specialize your function, whether it's to remove control flow or data structures.
  1111. For example::
  1112. def f(a, b):
  1113. if b == True:
  1114. return a
  1115. else:
  1116. return a * 2
  1117. FX can typically not trace through this due to the presence of control
  1118. flow. However, we can use `concrete_args` to specialize on the value of
  1119. `b` to trace through this::
  1120. f = fx.symbolic_trace(f, concrete_args={"b": False})
  1121. assert f(3, False) == 6
  1122. Note that although you can still pass in different values of `b`, they will be ignored.
  1123. We can also use `concrete_args` to eliminate data-structure handling from
  1124. our function. This will use pytrees to flatten your input. To avoid
  1125. overspecializing, pass in `fx.PH` for values that shouldn't be
  1126. specialized. For example::
  1127. def f(x):
  1128. out = 0
  1129. for v in x.values():
  1130. out += v
  1131. return out
  1132. f = fx.symbolic_trace(
  1133. f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}
  1134. )
  1135. assert f({"a": 1, "b": 2, "c": 4}) == 7
  1136. Args:
  1137. root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted
  1138. into a Graph representation.
  1139. concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized
  1140. Returns:
  1141. GraphModule: a Module created from the recorded operations from ``root``.
  1142. """
  1143. tracer = Tracer()
  1144. graph = tracer.trace(root, concrete_args)
  1145. name = (
  1146. root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
  1147. )
  1148. return _make_graph_module(tracer.root, graph, name)
  1149. @wrap
  1150. def _assert_is_none(value, msg):
  1151. if value is not None:
  1152. raise AssertionError(msg)