proxy.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955
  1. # mypy: ignore-errors
  2. import collections
  3. import copy
  4. import dis
  5. import enum
  6. import inspect
  7. import logging
  8. import operator
  9. import sys
  10. import traceback
  11. from collections import OrderedDict
  12. from collections.abc import Callable, Iterator
  13. from dataclasses import fields, is_dataclass
  14. from typing import Any, Optional, TypeVar
  15. import torch
  16. import torch.fx.traceback as fx_traceback
  17. from torch._C import _fx_map_arg as map_arg
  18. from torch._library.opaque_object import is_opaque_value_type
  19. from torch._logging import getArtifactLogger
  20. from torch.utils._pytree import tree_map_
  21. from torch.utils._traceback import CapturedTraceback
  22. from ._compatibility import compatibility
  23. from .graph import Graph, magic_methods, reflectable_magic_methods
  24. from .immutable_collections import immutable_dict, immutable_list
  25. from .node import Argument, base_types, Node, Target
  26. from .operator_schemas import check_for_mutable_operation
# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "TracerBase",
    "GraphAppendingTracer",
    "TraceError",
    "Proxy",
    "MetaProxy",
    "Attribute",
    "ParameterProxy",
    "Scope",
    "ScopeContextManager",
]

# General-purpose module logger (e.g. ``create_node`` debug output).
log = logging.getLogger(__name__)
# Artifact logger for the "annotation" artifact; used for seq_nr bookkeeping
# messages emitted while creating nodes.
annotation_log = getArtifactLogger(__name__, "annotation")
  40. def _is_arbitrary_callable(obj: object) -> bool:
  41. """
  42. Returns True if obj is an arbitrary callable (function, lambda, method, etc.)
  43. that requires special tracing to handle. These cannot be symbolically traced
  44. using the standard Proxy mechanism.
  45. """
  46. import functools
  47. import types
  48. return isinstance(
  49. obj,
  50. (
  51. types.FunctionType,
  52. types.MethodType,
  53. types.BuiltinFunctionType,
  54. types.BuiltinMethodType,
  55. functools.partial,
  56. ),
  57. )
  58. def _find_arbitrary_callable(
  59. args: tuple[object, ...], kwargs: dict[str, object]
  60. ) -> object:
  61. """
  62. Recursively searches args and kwargs for any arbitrary callable.
  63. Returns the first arbitrary callable found, or None if none exist.
  64. """
  65. found = None
  66. _T = TypeVar("_T")
  67. def check(obj: _T) -> _T:
  68. nonlocal found
  69. if found is not None:
  70. return obj
  71. if _is_arbitrary_callable(obj):
  72. found = obj
  73. return obj
  74. tree_map_(check, args)
  75. tree_map_(check, kwargs)
  76. return found
  77. @compatibility(is_backward_compatible=False)
  78. class Scope:
  79. """Scope object that records the module path and the module type
  80. of a module. Scope is used to track the information of the module
  81. that contains a Node in a Graph of GraphModule. For example::
  82. class Sub(torch.nn.Module):
  83. def forward(self, x):
  84. # This will be a call_method Node in GraphModule,
  85. # scope for this would be (module_path="sub", module_type=Sub)
  86. return x.transpose(1, 2)
  87. class M(torch.nn.Module):
  88. def __init__(self) -> None:
  89. self.sub = Sub()
  90. def forward(self, x):
  91. # This will be a call_method Node as well,
  92. # scope for this would be (module_path="", None)
  93. x = x.transpose(1, 2)
  94. x = self.sub(x)
  95. return x
  96. """
  97. def __init__(self, module_path: str, module_type: Any):
  98. super().__init__()
  99. self.module_path = module_path
  100. self.module_type = module_type
  101. @compatibility(is_backward_compatible=False)
  102. class ScopeContextManager:
  103. """A context manager to track the Scope of Node during symbolic tracing.
  104. When entering a forward function of a Module, we'll update the scope information of
  105. the current module, and when we exit, we'll restore the previous scope information.
  106. """
  107. def __init__(
  108. self,
  109. scope: Scope,
  110. current_scope: Scope,
  111. ):
  112. super().__init__()
  113. # Keep a copy of prev scope to restore on exit
  114. self._prev_scope = copy.copy(scope)
  115. # Update scope to current scope
  116. scope.module_path = current_scope.module_path
  117. scope.module_type = current_scope.module_type
  118. # Save a reference so we can restore it
  119. self._scope = scope
  120. def __enter__(self):
  121. return self._scope
  122. def __exit__(self, *args):
  123. self._scope.module_path = self._prev_scope.module_path
  124. self._scope.module_type = self._prev_scope.module_type
  125. return
# Allow-list of meta entries that TracerBase.create_node copies (shallowly, via
# copy.copy) from the current preserved node meta onto every newly created node.
# "stack_trace" is handled separately in create_node and is not listed here.
# Entries are visited in list order.
_COPY_META_FIELDS = [
    "nn_module_stack",
    "torch_fn",
    "source_fn_stack",
    "original_aten",
    "recompute",
    "ac_graph_id",
    "has_backward_hook",
    "from_node",
    "quantization_tag",  # TODO deprecated
    "_numeric_debug_handle",  # TODO deprecated
    "custom",
    "partitioner_tag",
]
  140. @compatibility(is_backward_compatible=True)
  141. class TracerBase:
  142. graph: Graph
  143. record_stack_traces: bool = False
  144. # When record_stack_traces is True, only reocrd stack traces
  145. # with forward function names.
  146. # This helps when we want stack trace back to model code
  147. _record_forward_stack_traces_only: bool = False
  148. # Feature flag for mutable schema checking
  149. # Enableby default in 1.12
  150. check_mutable_operations: bool = False
  151. # Feature flag for assert tracing
  152. trace_asserts: bool = False
  153. # Feature flag for proxying accesses to buffer values
  154. proxy_buffer_attributes: bool = False
  155. # Name of the function to be traced. It will only be used when
  156. # ``root`` is an instance of ``nn.Module``
  157. traced_func_name: str = "forward"
  158. # Maps the containing module's name to the operator name
  159. scope: Scope
  160. # Records the module call stack
  161. module_stack: OrderedDict[str, tuple[str, Any]]
  162. # Mapping of node name to module scope
  163. node_name_to_scope: dict[str, tuple[str, type]]
  164. @compatibility(is_backward_compatible=True)
  165. def create_node(
  166. self,
  167. kind: str,
  168. target: Target,
  169. args: tuple[Argument, ...],
  170. kwargs: dict[str, Argument],
  171. name: Optional[str] = None,
  172. type_expr: Optional[Any] = None,
  173. ) -> Node:
  174. """
  175. Inserts a graph node given target, args, kwargs, and name.
  176. This method can be overridden to do extra checking, validation, or
  177. modification of values used in node creation. For example, one might
  178. want to disallow in-place operations from being recorded.
  179. """
  180. if kind == "call_function" and self.check_mutable_operations:
  181. check_for_mutable_operation(target, args, kwargs)
  182. node = self.graph.create_node(kind, target, args, kwargs, name, type_expr)
  183. # TODO node_name_to_scope will be depreciated in favor of
  184. # node.meta['nn_module_stack']
  185. self.node_name_to_scope[node.name] = (
  186. self.scope.module_path,
  187. self.scope.module_type,
  188. )
  189. # Optionally set stack trace on the created Node for debugging purposes
  190. if fx_traceback.has_preserved_node_meta():
  191. current_meta: dict[str, Any] = fx_traceback.get_current_meta()
  192. stack_trace = current_meta.get("stack_trace")
  193. if stack_trace:
  194. node.stack_trace = stack_trace
  195. if fx_traceback.GRADIENT_ACC_SPECIAL_STACK in stack_trace:
  196. node.meta["is_gradient_acc"] = True
  197. # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta
  198. # If other meta fields are needed, they can be added here
  199. for field in _COPY_META_FIELDS:
  200. if field in current_meta:
  201. node.meta[field] = copy.copy(current_meta[field])
  202. new_seq_nr = _get_seq_nr(node.name)
  203. if new_seq_nr is not None:
  204. annotation_log.debug(
  205. "Assigning new_seq_nr %s to %s", new_seq_nr, node.name
  206. )
  207. node.meta["seq_nr"] = new_seq_nr
  208. # See Note [Functionalization View Replay Annotation]
  209. # Overriding some node meta with the original node meta of the
  210. # regenerated node.
  211. replay_node: Node = fx_traceback.get_current_replay_node()
  212. if replay_node is not None:
  213. node.meta["is_functional_regenerated"] = True
  214. if "custom" in replay_node.meta:
  215. node.meta["custom"] = replay_node.meta.get("custom")
  216. if "stack_trace" in replay_node.meta:
  217. node.stack_trace = replay_node.meta.get("stack_trace")
  218. elif self.module_stack:
  219. node.meta["nn_module_stack"] = copy.copy(self.module_stack)
  220. if self.record_stack_traces and not node.stack_trace:
  221. user_stack_summary = CapturedTraceback.extract().summary()
  222. if user_stack_summary:
  223. user_stack_summary = self._filter_traceback_frames(user_stack_summary)
  224. if user_stack_summary:
  225. node.stack_trace = "".join(user_stack_summary.format()).strip()
  226. log.debug("create_node %s", node)
  227. return node
  228. def _filter_traceback_frames(
  229. self, user_stack_summary: traceback.StackSummary
  230. ) -> traceback.StackSummary:
  231. # This method can be overridden to customize the frame filtering logic
  232. # for the recorded stack trace
  233. user_frames = []
  234. if self._record_forward_stack_traces_only:
  235. user_frames = [
  236. frame
  237. for frame in user_stack_summary
  238. if (
  239. frame.name == "forward"
  240. or frame.filename.endswith("torch/__init__.py")
  241. )
  242. ]
  243. else:
  244. first_forward = -1
  245. for i, frame in enumerate(user_stack_summary):
  246. if frame.name == "forward":
  247. user_frames = user_stack_summary[i:]
  248. first_forward = i
  249. break
  250. # Not having a "forward" call in the stacktrace implies the
  251. # stacktrace will probably be irrelevant
  252. if first_forward == -1:
  253. user_frames = []
  254. from torch.fx.experimental.symbolic_shapes import uninteresting_files
  255. user_frames = [
  256. frame
  257. for frame in user_frames
  258. if frame.filename not in uninteresting_files()
  259. ]
  260. return traceback.StackSummary.from_list(user_frames)
  261. @compatibility(is_backward_compatible=True)
  262. def proxy(self, node: Node) -> "Proxy":
  263. return Proxy(node, self)
  264. @compatibility(is_backward_compatible=True)
  265. def create_proxy(
  266. self,
  267. kind: str,
  268. target: Target,
  269. args: tuple[Any, ...],
  270. kwargs: dict[str, Any],
  271. name: Optional[str] = None,
  272. type_expr: Optional[Any] = None,
  273. # fix noqa when updating bc tests
  274. proxy_factory_fn: Callable[[Node], "Proxy"] = None, # noqa: RUF013
  275. ):
  276. """
  277. Create a Node from the given arguments, then return the Node
  278. wrapped in a Proxy object.
  279. If kind = 'placeholder', then we're creating a Node that
  280. represents the parameter of a function. If we need to encode
  281. a default parameter, we use the ``args`` tuple. ``args`` is
  282. otherwise empty for ``placeholder`` Nodes.
  283. """
  284. args_ = self.create_arg(args)
  285. kwargs_ = self.create_arg(kwargs)
  286. if not isinstance(args_, tuple):
  287. raise AssertionError(f"Expected args_ to be tuple, got {type(args_)}")
  288. if not isinstance(kwargs_, dict):
  289. raise AssertionError(f"Expected kwargs_ to be dict, got {type(kwargs_)}")
  290. node = self.create_node(kind, target, args_, kwargs_, name, type_expr)
  291. if not proxy_factory_fn:
  292. proxy = self.proxy(node)
  293. else:
  294. proxy = proxy_factory_fn(node)
  295. return proxy
  296. def _find_user_frame(self):
  297. """
  298. Find the Python stack frame executing the user code during
  299. symbolic tracing.
  300. """
  301. # We have to do a little dance here. Basically, walk up the callstack and
  302. # record the first frame not in the pytorch source. This is the frame executing
  303. # the user code during tracing.
  304. frame = inspect.currentframe()
  305. pt_files = [
  306. "torch/fx/proxy.py",
  307. "torch/fx/_symbolic_trace.py",
  308. "torch/fx/experimental/proxy_tensor.py",
  309. "torch/_ops.py",
  310. "torch/_tensor.py",
  311. "torch/utils/_python_dispatch.py",
  312. "torch/_prims_common/wrappers.py",
  313. "torch/_refs/__init__.py",
  314. "torch/_refs/nn/functional/__init__.py",
  315. "torch/utils/_stats.py",
  316. ]
  317. while frame:
  318. frame = frame.f_back
  319. if frame and all(
  320. not frame.f_code.co_filename.endswith(file) for file in pt_files
  321. ):
  322. break
  323. if not frame:
  324. return None
  325. return frame
  326. @compatibility(is_backward_compatible=True)
  327. def create_arg(self, a: Any) -> Argument:
  328. """
  329. A method that lowers the objects seen as arguments during symbolic evaluation
  330. into Argument types that can be stored in IR.
  331. Can be override to support more trace-specific types.
  332. """
  333. # IMPORTANT: Are you here because you are trying to proxy a new type into
  334. # the graph? Please Please Please contact someone on the PyTorch Compiler team;
  335. # the considerations are subtle.
  336. #
  337. # 1) When you add a new type, all of the downstream consumers and pass writers
  338. # need to handle the new type. torch.fx is intended to be easy to write
  339. # passes for, so we will push back against new types.
  340. # 2) In torch.compile's IR, there are only specific operations that go
  341. # into the graph. In particular, Tensor operations should go into the graph,
  342. # but non-Tensor operations shouldn't. What that means is that constructors
  343. # for new types *SHOULD NOT* become nodes in the FX graph.
  344. handler = _create_arg_bypass.get(type(a))
  345. if handler is not None:
  346. # this is just a performance optimization and can be removed if needed
  347. # for common types, we have a fast path to avoid isinstance() overhead
  348. # this doesn't remove the checks below since we need to handle subclasses
  349. return handler(self, a)
  350. if isinstance(a, Proxy):
  351. return a.node # most common arg type goes first
  352. elif hasattr(a, "__fx_create_arg__"):
  353. return a.__fx_create_arg__(self)
  354. # aggregates
  355. elif isinstance(a, tuple):
  356. if hasattr(a, "_fields"):
  357. # NamedTuple constructors don't seem to like getting a generator
  358. # expression as an argument to their constructor, so build this
  359. # intermediate tuple and unpack it into the NamedTuple constructor
  360. args = [self.create_arg(elem) for elem in a]
  361. return type(a)(*args) # type: ignore[arg-type]
  362. return type(a)([self.create_arg(elem) for elem in a])
  363. elif isinstance(a, list):
  364. return [self.create_arg(elem) for elem in a]
  365. elif isinstance(a, dict):
  366. return _create_arg_dict(self, a)
  367. elif isinstance(a, slice):
  368. return slice(
  369. self.create_arg(a.start),
  370. self.create_arg(a.stop),
  371. self.create_arg(a.step),
  372. )
  373. elif isinstance(a, range):
  374. return range(
  375. self.create_arg(a.start),
  376. self.create_arg(a.stop),
  377. self.create_arg(a.step),
  378. )
  379. elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
  380. return a
  381. elif is_opaque_value_type(type(a)):
  382. return a
  383. elif is_dataclass(a):
  384. kwargs = {
  385. field.name: self.create_arg(getattr(a, field.name))
  386. for field in fields(a)
  387. }
  388. return self.create_node("call_function", a.__class__, (), kwargs)
  389. elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...:
  390. return a
  391. raise NotImplementedError(f"argument of type: {type(a)}")
  392. @compatibility(is_backward_compatible=True)
  393. def to_bool(self, obj: "Proxy") -> bool:
  394. """Called when a proxy object is being converted to a boolean, such as
  395. when used in control flow. Normally we don't know what to do because
  396. we don't know the value of the proxy, but a custom tracer can attach more
  397. information to the graph node using create_node and can choose to return a value.
  398. """
  399. raise TraceError(
  400. "symbolically traced variables cannot be used as inputs to control flow"
  401. )
  402. @compatibility(is_backward_compatible=True)
  403. def iter(self, obj: "Proxy") -> Iterator:
  404. """Called when a proxy object is being iterated over, such as
  405. when used in control flow. Normally we don't know what to do because
  406. we don't know the value of the proxy, but a custom tracer can attach more
  407. information to the graph node using create_node and can choose to return an iterator.
  408. """
  409. raise TraceError(
  410. "Proxy object cannot be iterated. This can be "
  411. "attempted when the Proxy is used in a loop or"
  412. " as a *args or **kwargs function argument. "
  413. "See the torch.fx docs on pytorch.org for a "
  414. "more detailed explanation of what types of "
  415. "control flow can be traced, and check out the"
  416. " Proxy docstring for help troubleshooting "
  417. "Proxy iteration errors"
  418. )
  419. @compatibility(is_backward_compatible=True)
  420. def keys(self, obj: "Proxy") -> Any:
  421. """Called when a proxy object is has the keys() method called.
  422. This is what happens when ** is called on a proxy. This should return an
  423. iterator it ** is suppose to work in your custom tracer.
  424. """
  425. return Attribute(obj, "keys")()
def _get_seq_nr(node_name: str = ""):
    """
    Returns the seq_nr node meta for the current proxy node that we're creating.
    The seq_nr number in node meta is related to but not the same as the "sequence number"
    in autograd.

    We use the seq_nr to correlate forward and backward nodes in the traced FX graphs
    (e.g. in copy_fwd_metadata_to_bw_nodes).

    The corresponding forward and backward FX graph nodes should have the same seq_nr.
    The ordering of the seq_nr in the graph does not indicate when the node is executed.
    For example, the nodes in the invoke_subgraph HOP's subgraph may be called multiple
    times by different HOP nodes, but these nodes only have a single seq_nr, which may be
    smaller, the same, or larger than the calling HOP.

    `node_name` is the name of the node that we're creating. It is used for logging only.

    The value is picked in this order:

    1. Inside a backward grad_fn (``current_meta["in_grad_fn"] > 0``): the last entry
       of ``current_meta["grad_fn_seq_nr"]``.
    2. While seq_nr preservation is active (``_is_preserving_node_seq_nr()``):
       ``current_meta.get("seq_nr")``, which may be None.
    3. Otherwise: ``torch.autograd._get_sequence_nr() - 1``.

    When preservation is NOT active and a replay node carrying a ``"seq_nr"`` meta
    entry is set, that value overrides the choice above.
    """
    current_meta: dict[str, Any] = fx_traceback.get_current_meta()
    new_seq_nr = None

    # The sequence_nr increments every time a new autograd Node
    # is created. During the FWD pass we store the sequence_nr
    # corresponding to the last autograd Node created on this fx
    # node's meta.  A single aten op can create multiple autograd
    # nodes as is the case with in-place foreach ops. During the
    # BWD pass we retrieve the sequence_nr stored on the current
    # executing autograd Node. See NOTE [ Sequence Number ].
    if current_meta.get("in_grad_fn", 0) > 0:
        # This branch is used to get seq_nr for backward nodes
        annotation_log.debug("%s: seq_nr from current_meta grad_fn_seq_nr", node_name)
        new_seq_nr = current_meta["grad_fn_seq_nr"][-1]

    elif torch.fx.traceback._is_preserving_node_seq_nr():
        # Special case where we preserve seq_nr from currently tracing node
        # Used to preserve seq_nr when re-tracing subgraphs in HOP
        annotation_log.debug("%s: seq_nr from current_meta seq_nr", node_name)
        new_seq_nr = current_meta.get("seq_nr")

    else:
        # Here we decrement to account for the sequence_nr having
        # just been incremented while tracing this lowered aten op.
        # This branch is used to get seq_nr for forward nodes
        new_seq_nr = torch.autograd._get_sequence_nr() - 1

    # A replay node never overrides a preserved seq_nr.
    if not torch.fx.traceback._is_preserving_node_seq_nr():
        # See Note [Functionalization View Replay Annotation]
        # Overriding some node meta with the original node meta of the
        # regenerated node.
        replay_node: Node = fx_traceback.get_current_replay_node()
        if replay_node is not None:
            if "seq_nr" in replay_node.meta:
                annotation_log.debug("%s: seq_nr from replay_node", node_name)
                new_seq_nr = replay_node.meta["seq_nr"]
    return new_seq_nr
  473. # used in Proxy object when just appending to the graph while not tracing.
  474. @compatibility(is_backward_compatible=True)
  475. class GraphAppendingTracer(TracerBase):
  476. def __init__(self, graph: Graph):
  477. super().__init__()
  478. self.graph = graph
  479. self.scope = Scope("", None)
  480. self.module_stack = collections.OrderedDict()
  481. self.node_name_to_scope = {}
  482. @compatibility(is_backward_compatible=False)
  483. def assert_fn(x):
  484. if not x:
  485. raise AssertionError("Assertion failed")
  486. @compatibility(is_backward_compatible=True)
  487. class TraceError(ValueError):
  488. pass
  489. @compatibility(is_backward_compatible=True)
  490. class Proxy:
  491. """
  492. ``Proxy`` objects are ``Node`` wrappers that flow through the
  493. program during symbolic tracing and record all the operations
  494. (``torch`` function calls, method calls, operators) that they touch
  495. into the growing FX Graph.
  496. If you're doing graph transforms, you can wrap your own ``Proxy``
  497. method around a raw ``Node`` so that you can use the overloaded
  498. operators to add additional things to a ``Graph``.
  499. ``Proxy`` objects cannot be iterated. In other words, the symbolic
  500. tracer will throw an error if a ``Proxy`` is used in a loop or as
  501. an ``*args``/``**kwargs`` function argument.
  502. There are two main ways around this:
  503. 1. Factor out the untraceable logic into a top-level function and
  504. use ``fx.wrap`` on it.
  505. 2. If the control flow is static (i.e. the loop trip count is
  506. based on some hyperparameter), the code can be kept in its original
  507. position and refactored into something like::
  508. for i in range(self.some_hyperparameter):
  509. indexed_item = proxied_value[i]
  510. For a more detailed description into the Proxy internals, check out
  511. the "Proxy" section in `torch/fx/README.md`
  512. """
  513. @compatibility(is_backward_compatible=True)
  514. def __init__(self, node: Node, tracer: "Optional[TracerBase]" = None):
  515. if tracer is None:
  516. # This allows you to create a Proxy object around a raw Node
  517. tracer = GraphAppendingTracer(node.graph)
  518. self.tracer = tracer
  519. self.node = node
  520. def __repr__(self) -> str:
  521. return f"Proxy({self.node.name})"
  522. def __getattr__(self, k) -> "Attribute":
  523. # note: not added to the graph yet, if this is a method call
  524. # we peephole optimize to the method invocation
  525. return Attribute(self, k)
  526. def __getstate__(self) -> dict:
  527. return self.__dict__
  528. def __deepcopy__(self, memo) -> dict:
  529. # We have to explicitly override this method, because otherwise deepcopy
  530. # will go to __getattr__(self, "__deepcopy__") and return a
  531. # Attribute(__deepcopy__), and may go into an infinite loop in some cases.
  532. import copy
  533. new_dict = {}
  534. for k, v in self.__dict__.items():
  535. try:
  536. new_obj = copy.deepcopy(v, memo)
  537. except Exception:
  538. log.warning(
  539. "Shallow copy %s of Proxy because it cannot be deepcopied. "
  540. "Proxy is created for node %s",
  541. k,
  542. self.node.name,
  543. )
  544. new_obj = copy.copy(v)
  545. new_dict[k] = new_obj
  546. if "node" not in new_dict:
  547. raise AssertionError("'node' not in new_dict during proxy unpickling")
  548. if "tracer" not in new_dict:
  549. raise AssertionError("'tracer' not in new_dict during proxy unpickling")
  550. new_proxy = Proxy(new_dict["node"], new_dict["tracer"])
  551. for k, v in new_dict.items():
  552. new_proxy.__dict__[k] = v
  553. return new_proxy
  554. def __setstate__(self, d):
  555. # This is called when being unpickled/loaded.
  556. self.__dict__ = d
  557. def __call__(self, *args, **kwargs) -> "Proxy":
  558. return self.tracer.create_proxy(
  559. "call_method", "__call__", (self,) + args, kwargs
  560. )
  561. def __iter__(self) -> Iterator["Proxy"]:
  562. frame = inspect.currentframe()
  563. if frame is None:
  564. raise AssertionError("inspect.currentframe() returned None")
  565. calling_frame = frame.f_back
  566. if calling_frame is None:
  567. raise AssertionError("frame.f_back is None")
  568. inst_list = list(dis.get_instructions(calling_frame.f_code))
  569. if sys.version_info >= (3, 11):
  570. from bisect import bisect_left
  571. inst_idx = bisect_left(
  572. inst_list, calling_frame.f_lasti, key=lambda x: x.offset
  573. )
  574. else:
  575. inst_idx = calling_frame.f_lasti // 2
  576. inst = inst_list[inst_idx]
  577. if inst.opname == "UNPACK_SEQUENCE":
  578. return (self[i] for i in range(inst.argval)) # type: ignore[index]
  579. return self.tracer.iter(self)
  580. def __abs__(self):
  581. return self.tracer.create_proxy("call_function", operator.abs, (self,), {})
  582. def __bool__(self) -> bool:
  583. if self.tracer.trace_asserts:
  584. # check if this boolean is used in an assertion, bytecode pattern for assertions
  585. # is pretty stable for Python 3.7--3.9
  586. frame = inspect.currentframe()
  587. if frame is None:
  588. raise AssertionError("inspect.currentframe() returned None")
  589. calling_frame = frame.f_back
  590. if calling_frame is None:
  591. raise AssertionError("frame.f_back is None")
  592. insts = list(dis.get_instructions(calling_frame.f_code))
  593. if sys.version_info >= (3, 11):
  594. from bisect import bisect_left
  595. cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset)
  596. else:
  597. cur = calling_frame.f_lasti // 2
  598. inst = insts[cur]
  599. if inst.opname == "POP_JUMP_IF_TRUE":
  600. first = insts[cur + 1]
  601. if inst.arg is None:
  602. raise AssertionError("inst.arg is None for POP_JUMP_IF_TRUE")
  603. last = insts[inst.arg // 2 - 1]
  604. starts_with_assert = (
  605. first.opname == "LOAD_GLOBAL"
  606. and first.argval == "AssertionError"
  607. or first.opname == "LOAD_ASSERTION_ERROR"
  608. )
  609. if starts_with_assert and last.opname == "RAISE_VARARGS":
  610. self.tracer.create_proxy("call_function", assert_fn, (self,), {})
  611. return True
  612. return self.tracer.to_bool(self)
  613. @compatibility(is_backward_compatible=True)
  614. def keys(self):
  615. return self.tracer.keys(self)
  616. def __len__(self):
  617. raise RuntimeError(
  618. "'len' is not supported in symbolic tracing by default. If you want "
  619. "this call to be recorded, please call torch.fx.wrap('len') at "
  620. "module scope"
  621. )
  622. @classmethod
  623. def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
  624. args = args if args else ()
  625. kwargs = kwargs if kwargs else {}
  626. tracers: dict[Any, None] = {}
  627. def find_tracer(a):
  628. if isinstance(a, cls):
  629. tracers[a.tracer] = None
  630. tree_map_(find_tracer, args)
  631. tree_map_(find_tracer, kwargs)
  632. if len(tracers) > 1:
  633. raise RuntimeError(
  634. f"Found multiple different tracers {list(tracers.keys())} while "
  635. f"trying to trace operations {orig_method}"
  636. )
  637. tracer = next(iter(tracers.keys()))
  638. if isinstance(orig_method, torch._C.ScriptMethod):
  639. args = (orig_method.owner,) + args
  640. return tracer.create_proxy("call_method", orig_method.name, args, kwargs)
  641. if torch.overrides.is_tensor_method_or_property(orig_method):
  642. return tracer.create_proxy(
  643. "call_method", orig_method.__name__, args, kwargs
  644. )
  645. else:
  646. if isinstance(orig_method, torch._ops.HigherOrderOperator):
  647. bad_callable = _find_arbitrary_callable(args, kwargs)
  648. if bad_callable is not None:
  649. raise RuntimeError(
  650. f"Unable to symbolically trace the HigherOrderOperator "
  651. f"{orig_method._name} because it received an arbitrary "
  652. f"callable argument {bad_callable}. Use make_fx or dynamo "
  653. f"tracing instead."
  654. )
  655. return tracer.create_proxy(
  656. "call_function",
  657. orig_method,
  658. args,
  659. kwargs,
  660. name=tracer.graph._target_to_str(orig_method.__name__),
  661. )
  662. @compatibility(is_backward_compatible=False)
  663. class MetaProxy(Proxy):
  664. """
  665. A Proxy subclass that propagates metadata (meta['val']) during graph tracing.
  666. """
  667. def __init__(
  668. self, node: Node, tracer: "Optional[TracerBase]" = None, fake_mode=None
  669. ):
  670. super().__init__(node, tracer)
  671. self.fake_mode = fake_mode
  672. def __repr__(self) -> str:
  673. return f"MetaProxy({self.node.name})"
  674. @classmethod
  675. def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
  676. args = args if args else ()
  677. kwargs = kwargs if kwargs else {}
  678. meta_proxy = None
  679. for arg in args:
  680. if isinstance(arg, MetaProxy):
  681. meta_proxy = arg
  682. break
  683. if meta_proxy is None:
  684. raise AssertionError(
  685. "No MetaProxy found in arguments, but one is expected."
  686. )
  687. proxy = super().__torch_function__(orig_method, types, args, kwargs)
  688. with meta_proxy.fake_mode:
  689. proxy.node.meta["val"] = orig_method(
  690. *[a.node.meta["val"] if isinstance(a, Proxy) else a for a in args],
  691. **kwargs,
  692. )
  693. return MetaProxy(proxy.node, proxy.tracer, meta_proxy.fake_mode)
  694. @compatibility(is_backward_compatible=True)
  695. class Attribute(Proxy):
  696. @compatibility(is_backward_compatible=True)
  697. def __init__(self, root: Proxy, attr: str):
  698. self.root = root
  699. self.attr = attr
  700. self.tracer = root.tracer
  701. self._node: Optional[Node] = None
  702. @property
  703. def node(self):
  704. # the node for attributes is added lazily, since most will just be method calls
  705. # which do not rely on the getitem call
  706. if self._node is None:
  707. self._node = self.tracer.create_proxy(
  708. "call_function", getattr, (self.root, self.attr), {}
  709. ).node
  710. return self._node
  711. def __call__(self, *args, **kwargs):
  712. return self.tracer.create_proxy(
  713. "call_method", self.attr, (self.root,) + args, kwargs
  714. )
  715. @compatibility(is_backward_compatible=False)
  716. class ParameterProxy(Proxy):
  717. """
  718. A special proxy which lets "shape", "size", "dim", and a few other
  719. attribute accesses pass through to the underlying module parameter object,
  720. so that conditional tests on these attributes will not throw exception during tracing
  721. """
  722. def __init__(self, tracer: TracerBase, node: Node, name, param):
  723. super().__init__(node, tracer)
  724. if not isinstance(param, torch.nn.Parameter):
  725. raise AssertionError(f"Expected Parameter, got {type(param)}")
  726. self.param = param
  727. self.name = name
  728. def __repr__(self) -> str:
  729. return f"ParameterProxy({self.name})"
  730. @property
  731. def shape(self):
  732. return self.param.shape
  733. def size(self):
  734. return self.param.size()
  735. def dim(self):
  736. return self.param.dim()
  737. @property
  738. def ndim(self):
  739. return self.param.ndim
  740. def numel(self):
  741. return self.param.numel()
  742. def nelement(self):
  743. return self.param.nelement()
  744. for method in magic_methods:
  745. def _scope(method):
  746. def impl(*args, **kwargs):
  747. tracer = args[0].tracer
  748. target = getattr(operator, method)
  749. return tracer.create_proxy("call_function", target, args, kwargs)
  750. impl.__name__ = method
  751. as_magic = f"__{method.strip('_')}__"
  752. setattr(Proxy, as_magic, impl)
  753. _scope(method)
  754. def _define_reflectable(orig_method_name):
  755. method_name = f"__r{orig_method_name.strip('_')}__"
  756. def impl(self, rhs):
  757. target = getattr(operator, orig_method_name)
  758. return self.tracer.create_proxy("call_function", target, (rhs, self), {})
  759. impl.__name__ = method_name
  760. impl.__qualname__ = method_name
  761. setattr(Proxy, method_name, impl)
  762. for orig_method_name in reflectable_magic_methods:
  763. _define_reflectable(orig_method_name)
  764. def _no_nodes_error(arg):
  765. raise RuntimeError(
  766. "Keys for dictionaries used as an argument cannot contain a "
  767. f"Node. Got key: {arg}"
  768. )
  769. def _create_arg_dict(self, a):
  770. r = {}
  771. for k, v in a.items():
  772. if not isinstance(k, str):
  773. # Check for invalid dict keys. We do not want a Proxy to appear
  774. # anywhere within the key. Since keys can be collection types,
  775. # we iterate through the key with map_arg
  776. k = self.create_arg(k)
  777. map_arg(k, _no_nodes_error)
  778. r[k] = self.create_arg(v)
  779. return r
  780. _create_arg_bypass = {
  781. t: lambda self, a: a
  782. for t in [
  783. *base_types,
  784. type(None),
  785. type(...),
  786. torch._ops.OpOverload,
  787. torch._ops.HigherOrderOperator,
  788. ]
  789. }
  790. _create_arg_bypass[Proxy] = lambda self, a: a.node
  791. _create_arg_bypass[tuple] = lambda self, a: tuple(self.create_arg(elem) for elem in a)
  792. _create_arg_bypass[list] = lambda self, a: [self.create_arg(elem) for elem in a]
  793. _create_arg_bypass[dict] = _create_arg_dict
  794. _create_arg_bypass[immutable_list] = _create_arg_bypass[list]
  795. _create_arg_bypass[immutable_dict] = _create_arg_bypass[dict]