_unlift.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import inspect
  4. import math
  5. import warnings
  6. from collections.abc import Sequence
  7. from itertools import chain
  8. from typing import Any
  9. import sympy
  10. import torch
  11. import torch.utils._pytree as pytree
  12. from torch._export.non_strict_utils import (
  13. _enter_enable_graph_inputs_of_type_nn_module,
  14. _exit_enable_graph_inputs_of_type_nn_module,
  15. _get_graph_inputs_of_type_nn_module,
  16. )
  17. from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
  18. _convert_range_to_int,
  19. )
  20. from torch._export.utils import _check_input_constraints_for_graph
  21. from torch.export.unflatten import _assign_attr, _AttrKind
  22. from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
  23. from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
  24. from torch.fx.traceback import NodeSource, NodeSourceAction
  25. from torch.utils._sympy.solve import try_solve
  26. from torch.utils._sympy.value_ranges import ValueRanges
  27. from ._remove_effect_tokens_pass import _remove_effect_tokens
  28. from ._tree_utils import reorder_kwargs
  29. from .exported_program import (
  30. ExportedProgram,
  31. ExportGraphSignature,
  32. InputKind,
  33. OutputKind,
  34. )
  35. def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool:
  36. """
  37. Refinement of TreeSpec.__eq__ where, e.g., torch.Size(...) matches tuple(...).
  38. See _pytree_subclasses_that_lose_info in proxy_tensor.py for more details.
  39. """
  40. def _normalize_type(t):
  41. return str(_pytree_subclasses_that_lose_info.get(t, t))
  42. def _match_normalized_structure(a, b):
  43. if a is b:
  44. return True
  45. if _normalize_type(a.type) != _normalize_type(b.type):
  46. return False
  47. if a.type is dict and b.type is dict:
  48. # in the case of dict, the context is list of keys and we allow the keys to be in any order
  49. if set(a.context) != set(b.context):
  50. return False
  51. elif a.context != b.context:
  52. return False
  53. if a.num_children != b.num_children:
  54. return False
  55. return all(
  56. _match_normalized_structure(a, b)
  57. for a, b in zip(a.children(), b.children())
  58. )
  59. return _match_normalized_structure(self, other)
  60. def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> list:
  61. reordered_kwargs = reorder_kwargs(kwargs, in_spec)
  62. flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
  63. (args, reordered_kwargs)
  64. )
  65. if not eq_spec(received_spec, in_spec):
  66. raise ValueError( # noqa: B904
  67. "Trying to flatten user inputs with exported input tree spec: \n"
  68. f"{in_spec}\n"
  69. "but actually got inputs with tree spec of: \n"
  70. f"{received_spec}.\n"
  71. "Please check that the inputs have the same number and type of "
  72. "args and kwargs as the ones you used when tracing."
  73. )
  74. return flat_args_with_path
  75. def _force_ep_signature_match(ep_guards_code: list[str], input_paths):
  76. # TODO (tmanlaibaatar)
  77. # This is band-aid solution to export new tracer replacing
  78. # shape env sources to flat_args. The real fix should be replacing
  79. # shape env sources to original user sources but this is quite
  80. # involved because you need to carefully construct new sources using
  81. # dynamo and replace all instances of it inside shape env. But it is
  82. # lot easier to manipulate after we turn them into strings and only
  83. # time we use these guards is during retracing or running exported program,
  84. # so it is probably ok to have "not useful" guards on ep for now.
  85. name_mapping = {}
  86. for idx, path in enumerate(input_paths):
  87. name_mapping[f"L['flat_args'][{idx}]"] = f"L{pytree.keystr(path)}"
  88. new_guards_code = []
  89. for guard in ep_guards_code:
  90. for old_name, new_name in name_mapping.items():
  91. guard = guard.replace(old_name, new_name)
  92. new_guards_code.append(guard)
  93. return new_guards_code
  94. def _force_gm_signature_match(ep_guards_code: list[str], signature):
  95. """
  96. The signature of the originally exported module may not match
  97. the signature of the unlifted graph module extracted from the
  98. exported program. The guards code extracted from the exported
  99. program is based on the former, but the generated guards fn is
  100. based on the latter; thus we need to reconcile any such diff.
  101. """
  102. import re
  103. # Handle case where signatures may differ in var args.
  104. orig_arg_names = set()
  105. for g in ep_guards_code:
  106. # match substrings of the form L['<name>'][<number>]
  107. orig_arg_names.update(re.findall(r"L\[\'([^\']+)\'\]\[([0-9]+)\]", g))
  108. sig_arg_names = set()
  109. for n in signature.parameters:
  110. # match substrings of the form <name>_<number>
  111. sig_arg_names.update(re.findall(r"(.+)_([0-9]+)", n))
  112. # replace L['<name>'][<number>] with L['<name>_<number>']
  113. new_guards_code = ep_guards_code
  114. for match in orig_arg_names:
  115. if match in sig_arg_names:
  116. base, idx = match
  117. new_guards_code = [
  118. g.replace(f"L['{base}'][{idx}]", f"L['{base}_{idx}']")
  119. for g in new_guards_code
  120. ]
  121. return new_guards_code
  122. def _convert_guards_code_to_fn(
  123. guards_code: list[str],
  124. paths_of_placeholders: list[pytree.KeyPath],
  125. ):
  126. """
  127. Generates Python code given guards code and paths of placeholders.
  128. We assume that, based on source information,
  129. - the tracer generates the guards code
  130. - the input spec generates the paths of placeholders.
  131. Example:
  132. Suppose we are given the guards code "L['z']['k'].size()[1] == 3"
  133. and we are given that ['z']['k'] is the path of placeholder #2.
  134. Then we will generate:
  135. ```
  136. torch._assert(
  137. args[2].size()[0] == 3,
  138. "Guard failed: z['k'].size()[0] == 3",
  139. )
  140. ```
  141. FAQ: Why do we generate code based on (flattened) args instead of
  142. the original (unflattened) inputs? Because this would require
  143. inserting an additional pytree.unflatten call in our graph.
  144. FAQ: Why do we not emit RuntimeError on guard failure as we used to?
  145. Because it is inconvenient :/, get used to AssertionError instead.
  146. """
  147. import ast
  148. from torch.fx.experimental.symbolic_shapes import SYMPY_INTERP
  149. actual_guards_code = []
  150. shadow_guards_code = []
  151. for c in guards_code:
  152. a, s = c, c
  153. for idx, path in enumerate(paths_of_placeholders):
  154. # e.g., replace L['z']['k'] with args[2] for Python code (actual)
  155. a = a.replace("L" + pytree.keystr(path), f"args[{idx}]")
  156. # e.g., replace L['z']['k'] with z['k'] for error message (shadow)
  157. s = s.replace(
  158. "L" + pytree.keystr(path),
  159. path[0].key + pytree.keystr(path[1:]), # type: ignore[attr-defined]
  160. )
  161. actual_guards_code.append(a)
  162. shadow_guards_code.append(s.replace("\n", ""))
  163. # generate function code as str
  164. code_str = "\ndef _(*args):\n"
  165. for actual, shadow in zip(actual_guards_code, shadow_guards_code):
  166. # printing guards code may potentially introduce redundant parens;
  167. # we can normalize them out for readability by parsing/unparsing
  168. # NOTE: this is not necessary for correctness, just deemed desirable
  169. _shadow = ast.unparse(ast.parse(shadow, mode="eval"))
  170. # actual code and shadow error message
  171. code_str += f' torch._assert({actual}, "Guard failed: {_shadow}")\n'
  172. code_str += " return\n"
  173. # populate namespace with sympy globals, materialize function (named `_`)
  174. namespace = {**SYMPY_INTERP}
  175. exec(code_str, namespace)
  176. # create and return a module whose forward is the materialized function
  177. # NOTE: we want Dynamo to trace through this module, to repopulate guards:
  178. # otherwise we would lose them when retracing
  179. # NOTE: calling this module will be a side effect (no users): so it must
  180. # be marked impure to avoid being not cleaned up by DCE
  181. guards_fn = GuardsFn()
  182. guards_fn.forward = torch._dynamo.dont_skip_tracing(namespace["_"]) # type: ignore[call-overload, method-assign]
  183. guards_fn._is_impure = True # type: ignore[assignment]
  184. return guards_fn
  185. @torch._dynamo.disable
  186. def _check_input_constraints_for_module(self, args, kwargs):
  187. flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
  188. _check_input_constraints_for_graph(
  189. self.graph.find_nodes(op="placeholder"),
  190. flat_args_with_path,
  191. self.range_constraints,
  192. )
  193. def _check_input_constraints_pre_hook(self, args, kwargs):
  194. # preserve current behavior for clients that do not want any validation
  195. if not self.validate_inputs:
  196. return
  197. # when a guards function exists, assume that the graph does calls it!
  198. # so we do not need to check input constraints...but we still want
  199. # to check inputs match, otherwise we'd get obscure pytree errors
  200. if hasattr(self, "_guards_fn"):
  201. _check_inputs_match(args, kwargs, self._in_spec)
  202. return
  203. # NOTE: for some reason, Dynamo is tracing into this, we should see why and
  204. # put compile at the right place. Until then, we can skip the input
  205. # constraint checks.
  206. if not torch.compiler.is_dynamo_compiling():
  207. _check_input_constraints_for_module(self, args, kwargs)
def _unlift_inputs_as_getattr(
    gm: torch.fx.GraphModule,
    lifted_inputs: Sequence[str | None],
) -> tuple[dict[str, torch.fx.Node], dict[str, torch.fx.Node]]:
    """
    Unlift inputs referring to params/buffers/constants as getattr nodes in the
    graph.

    Args:
        gm: graph module whose placeholder nodes are rewritten in place.
        lifted_inputs: one entry per placeholder, in placeholder order. A
            string is the attribute fqn the placeholder stands for. None marks
            a user input, whose placeholder is left untouched.

    Returns:
        A pair ``(unlifted_name_to_node, input_name_to_node)``. The first maps
        each lifted fqn to its new ``get_attr`` node. The second maps each
        remaining placeholder's name to its placeholder node.

    Raises:
        AssertionError: if ``lifted_inputs`` and the placeholders differ in
            length.
    """
    unlifted_name_to_node = {}
    input_name_to_node = {}

    # Materialized up front, so erasing placeholders inside the loop below does
    # not disturb iteration.
    placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
    if len(lifted_inputs) != len(placeholder_nodes):
        raise AssertionError(
            f"Number of lifted inputs ({len(lifted_inputs)}) does not match "
            f"placeholder nodes ({len(placeholder_nodes)})"
        )
    for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs):
        if lifted_node is None:
            # User input: keep the placeholder as-is.
            input_name_to_node[input_node.name] = input_node

        else:
            # Lifted param/buffer/constant: replace the placeholder with a
            # get_attr node inserted right after it.
            with gm.graph.inserting_after(input_node):
                # It is fine to ignore this warning because
                # it is guaranteed that we will populate this
                # attr later.
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    getattr_node = gm.graph.get_attr(lifted_node)
                input_node.replace_all_uses_with(getattr_node)
                # The get_attr node takes over the placeholder's meta dict
                # (the same object, not a copy) once the placeholder is erased.
                metadata = input_node.meta
                gm.graph.erase_node(input_node)
                getattr_node.meta = metadata
                # Record provenance: this node replaces the original placeholder.
                getattr_node.meta["from_node"] = [
                    NodeSource(
                        input_node,
                        "ExportedProgram.module().unlift()",
                        [NodeSourceAction.CREATE, NodeSourceAction.REPLACE],
                    )
                ]
                unlifted_name_to_node[lifted_node] = getattr_node

    return unlifted_name_to_node, input_name_to_node
def _insert_copy_for_mutations(
    gm: torch.fx.GraphModule,
    mutated_outputs: Sequence[str | None],
    unlifted_name_to_node: dict[str, torch.fx.Node],
    input_name_to_node: dict[str, torch.fx.Node],
) -> None:
    """
    Find all the buffers and inputs that were mutated and insert copy_
    operators to reflect mutations.

    Some graph outputs have a non-None ``mutated_outputs`` entry. That entry
    names either an unlifted attribute (looked up in ``unlifted_name_to_node``)
    or a user input (looked up in ``input_name_to_node``). For each such output,
    an ``aten.copy_.default(target, output)`` node is inserted before the
    output node.

    The output node is then rebuilt to return only the user outputs, i.e. those
    whose entry is None. If a user output is also the value copied into a
    mutated target, the corresponding copy_ node is returned in its place.

    Raises:
        AssertionError: if the number of flattened outputs differs from
            ``len(mutated_outputs)``.
        RuntimeError: if a mutated name is found in neither mapping.
    """
    output_node = gm.graph.output_node()
    outputs = pytree.tree_flatten(output_node.args)[0]
    if len(outputs) != len(mutated_outputs):
        raise AssertionError(
            f"Number of outputs ({len(outputs)}) does not match "
            f"mutated outputs ({len(mutated_outputs)})"
        )

    user_output_nodes = []
    # maps each mutation-producing return node to the copy_ node consuming it
    return_nodes_to_copy = {}
    for return_node, mutated_node_name in zip(outputs, mutated_outputs):
        if mutated_node_name is None:
            user_output_nodes.append(return_node)
            continue

        # Unlifted attributes take precedence over user inputs.
        if mutated_node_name in unlifted_name_to_node:
            mutated_node = unlifted_name_to_node[mutated_node_name]
        elif mutated_node_name in input_name_to_node:
            mutated_node = input_name_to_node[mutated_node_name]
        else:
            raise RuntimeError(
                f"Could not find {mutated_node_name} in either buffer or input nodes"
            )

        with gm.graph.inserting_before(output_node):
            copy_node = gm.graph.call_function(
                torch.ops.aten.copy_.default, (mutated_node, return_node)
            )
            return_nodes_to_copy[return_node] = copy_node

    output_args = tuple(
        return_nodes_to_copy.get(node, node) for node in user_output_nodes
    )
    with gm.graph.inserting_before(output_node):
        # Only return user outputs
        new_output = gm.graph.output(output_args)
        output_node.replace_all_uses_with(new_output)
        gm.graph.erase_node(output_node)
        # Preserve the old output node's name and meta, and record provenance.
        new_output.name = output_node.name
        new_output.meta.update(output_node.meta)
        new_output.meta["from_node"] = [
            NodeSource(
                output_node,
                "ExportedProgram.module().unlift()",
                [NodeSourceAction.CREATE, NodeSourceAction.REPLACE],
            )
        ]
  301. def _get_codegen(
  302. in_spec: pytree.TreeSpec,
  303. out_spec: pytree.TreeSpec | None,
  304. forward_arg_names: list[str] | None = None,
  305. ) -> _PyTreeCodeGen:
  306. """
  307. Create the codegen for the graph module based on the in/out specs
  308. """
  309. if forward_arg_names:
  310. names = forward_arg_names
  311. elif (
  312. in_spec.type is tuple
  313. and in_spec.num_children == 2
  314. and in_spec.child(0).type is tuple
  315. and in_spec.child(1).type is dict
  316. ):
  317. # if in_spec contains the args (tuple) and kwargs (dict)
  318. names = [f"arg_{i}" for i in range(in_spec.child(0).num_children)]
  319. # add kwarg names
  320. names.extend(in_spec.child(1).context)
  321. else:
  322. names = [f"arg_{i}" for i in range(in_spec.num_children)]
  323. return _PyTreeCodeGen(
  324. _PyTreeInfo(
  325. names,
  326. in_spec,
  327. out_spec,
  328. )
  329. )
  330. def _unlift(
  331. gm: torch.fx.GraphModule,
  332. lifted_inputs: Sequence[str | None],
  333. mutated_outputs: Sequence[str | None],
  334. in_spec: pytree.TreeSpec,
  335. out_spec: pytree.TreeSpec | None,
  336. forward_arg_names: list[str] | None = None,
  337. ):
  338. """
  339. Args:
  340. lifted_inputs: A list matching the graph module's input nodes. For
  341. an input node that is referring to a lifted parameter/buffer, this
  342. list will contain the fqn the corresponding attribute. Otherwise, this
  343. list will contain None. This is used to unlift the lifted parameters as
  344. get_attr nodes.
  345. mutated_outputs: A list matching the graph module's output nodes. For
  346. an output node that is referring to a mutated buffer or user input, this
  347. list will contain the name of the corresponding buffer or user input
  348. that needs to be mutated. Otherwise, this list will contain None. This
  349. is used to re-insert an inplace copy_ operator to copy the mutated
  350. values back to the original node.
  351. """
  352. unlifted_name_to_node, input_name_to_node = _unlift_inputs_as_getattr(
  353. gm, lifted_inputs
  354. )
  355. _insert_copy_for_mutations(
  356. gm, mutated_outputs, unlifted_name_to_node, input_name_to_node
  357. )
  358. gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names)
  359. gm.graph.lint()
  360. gm.recompile()
  361. return gm
  362. def _register_attrs_to_new_gm(
  363. new_gm: torch.fx.GraphModule,
  364. graph_signature: ExportGraphSignature,
  365. state_dict: dict[str, Any],
  366. constants: dict[str, Any],
  367. ) -> None:
  368. non_persistent_buffers = set(graph_signature.non_persistent_buffers)
  369. for name in graph_signature.buffers:
  370. if name in non_persistent_buffers:
  371. persistent = False
  372. value = constants[name]
  373. else:
  374. persistent = True
  375. value = state_dict[name]
  376. _assign_attr(
  377. value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent
  378. )
  379. for name in graph_signature.parameters:
  380. value = state_dict[name]
  381. _assign_attr(
  382. value,
  383. new_gm,
  384. name,
  385. attr_kind=_AttrKind.PARAMETER,
  386. )
  387. # Technically this doesn't account for the aliased multiple constants but
  388. # it is ok because we have a separate pass later in the stack that populates
  389. # the final gm.
  390. for name in chain(
  391. graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants
  392. ):
  393. value = constants[name]
  394. _assign_attr(
  395. value,
  396. new_gm,
  397. name,
  398. attr_kind=_AttrKind.CONSTANT,
  399. )
  400. class _StatefulGraphModuleFactory(type):
  401. """
  402. Metaclass that ensures a private constructor for _StatefulGraphModule
  403. """
  404. def __call__(cls, *args, **kwargs):
  405. raise TypeError(
  406. f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
  407. )
  408. def _create(cls, root, graph, range_constraints=None):
  409. return super().__call__(
  410. root,
  411. graph,
  412. range_constraints=range_constraints,
  413. )
  414. class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory):
  415. def __init__(self, root, graph, range_constraints=None):
  416. super().__init__(root, graph)
  417. # Need to fix up non-persistent buffers.
  418. self.range_constraints = range_constraints or []
  419. self.validate_inputs = True
def _create_stateful_graph_module(
    plain_graph_module: torch.fx.GraphModule,
    range_constraints,
    ep: ExportedProgram,
) -> _StatefulGraphModule:
    """
    Wrap ``plain_graph_module`` in a ``_StatefulGraphModule`` and restore state
    that the plain ``torch.fx.GraphModule`` constructor does not preserve.

    Steps:
    1. Build the stateful module, sharing ``plain_graph_module.graph``.
    2. Register hooks:
       - a pre-hook calling ``_enter_enable_graph_inputs_of_type_nn_module``;
       - the input-validation pre-hook;
       - a forward hook (``always_call=True``) calling
         ``_exit_enable_graph_inputs_of_type_nn_module``.
    3. Re-register lifted tensor constants, which the constructor turned into
       buffers, as constants. Any that require grad are detached.
    4. Attach entries of ``ep.constants`` that are still missing.
    5. Re-register non-persistent buffers with ``persistent=False``.

    Args:
        plain_graph_module: the unlifted graph module.
        range_constraints: stored on the new module for input validation.
        ep: exported program supplying example inputs, the graph signature
            and constants.

    Returns:
        The new ``_StatefulGraphModule``.
    """
    stateful_gm = _StatefulGraphModule._create(
        plain_graph_module,
        plain_graph_module.graph,
        range_constraints=range_constraints,
    )

    # Pre-hook and (always-called) post-hook bracket each forward with
    # enter/exit calls for the nn.Module-typed graph inputs of the example.
    module_types = _get_graph_inputs_of_type_nn_module(ep.example_inputs)
    stateful_gm.register_forward_pre_hook(
        lambda *args, **kwargs: _enter_enable_graph_inputs_of_type_nn_module(
            module_types
        )
    )
    stateful_gm.register_forward_pre_hook(
        _check_input_constraints_pre_hook, with_kwargs=True
    )

    stateful_gm.register_forward_hook(
        lambda *args, **kwargs: _exit_enable_graph_inputs_of_type_nn_module(
            module_types
        ),
        always_call=True,
    )

    # When we have a constant that has requires_grad=True, we need to detach it
    # when we unlift as the tensors that require gradients should be registered
    # via parameters. But this is problematic when we have aliasing two constants
    # because when we call detach, they will become different tensors. This dict
    # keeps track of this logic.
    original_tensor_to_detached_tensor = {}

    # Fix up lifted tensor constants.
    # fx.GraphModule() constructor silently turns a constant attribute of plain_graph_module
    # into a buffer in stateful_gm and creates an inconsistency with graph_signature.
    # We fix this by de-registering these buffers in lifted_tensor_constants
    # and call _assign_attr(attr_kind=CONSTANT) to register them as constants.
    for constant_fqn in ep.graph_signature.lifted_tensor_constants:
        # Sometimes, the constant can require gradient, this is probably a bug in user code,
        # e.g. `self.const = torch.randn(2, 2, requires_grad=True)`.
        # We call detach on the constant_val since they're tensor constants and we don't need to
        # compute their gradients anyway.
        # Users should properly register it as parameter if they want it to require gradient.
        buffer = stateful_gm.get_buffer(constant_fqn)
        if buffer.requires_grad:
            # NOTE(review): this message has a stray period before "but"; the
            # matching warning below does not.
            warnings.warn(
                f"A model attribute `{constant_fqn}` requires gradient. "
                f"but it's not properly registered as a parameter. "
                f"torch.export will detach it and treat it as a constant tensor "
                f"but please register it as parameter instead.",
                stacklevel=2,
            )
            detached_buffer = buffer.detach()
            # Remember the detached copy so aliases of this tensor found in
            # ep.constants below reuse it instead of detaching again.
            original_tensor_to_detached_tensor[buffer] = detached_buffer
            buffer = detached_buffer
        # Split the fqn into owning-submodule path and attribute name.
        # (rsplit without maxsplit splits on every dot.)
        *prefix, field = constant_fqn.rsplit(".")
        submod = torch.fx.graph_module._get_attr_via_attr_list(stateful_gm, prefix)
        delattr(submod, field)
        _assign_attr(buffer, stateful_gm, constant_fqn, attr_kind=_AttrKind.CONSTANT)

    # Constants are not preserved well when we create a new GraphModule unlike param/buffers
    for const_name, value in ep.constants.items():
        if not torch.fx.graph_module._has_attr(stateful_gm, const_name):
            if isinstance(value, torch.Tensor):
                if value.requires_grad:
                    warnings.warn(
                        f"A model attribute `{const_name}` requires gradient "
                        f"but it's not properly registered as a parameter. "
                        f"torch.export will detach it and treat it as a constant tensor "
                        f"but please register it as parameter instead.",
                        stacklevel=2,
                    )
                    # Reuse an existing detached alias to keep aliasing intact.
                    if value in original_tensor_to_detached_tensor:
                        value = original_tensor_to_detached_tensor[value]
                    else:
                        detached_value = value.detach()
                        original_tensor_to_detached_tensor[value] = detached_value
                        value = detached_value
            _assign_attr(
                value,
                stateful_gm,
                const_name,
                attr_kind=_AttrKind.CONSTANT,
            )

    # Fix up non-persistent buffers. torch.fx does not distinguish between
    # persistent and non-persistent buffers, so we must restore that distinction
    # here.
    for buffer in ep.graph_signature.non_persistent_buffers:
        _assign_attr(
            plain_graph_module.get_buffer(buffer),
            stateful_gm,
            buffer,
            attr_kind=_AttrKind.BUFFER,
            persistent=False,
        )

    return stateful_gm
  514. def _get_input_paths(example_inputs, signature):
  515. """
  516. Generate paths of placeholders, needed for generating the guards function.
  517. NOTE: Here we make use of the example inputs used for export as well as
  518. the signature of the unlifted graph module (not preserved by export).
  519. """
  520. args, kwargs = example_inputs
  521. binded = signature.bind(*args, **kwargs)
  522. binded.apply_defaults()
  523. ctx = binded.arguments
  524. flat_example_inputs_with_paths = pytree.tree_leaves_with_path(ctx)
  525. return [path for path, _ in flat_example_inputs_with_paths]
  526. def _replace_sources(result_str: str, flat_input_paths: list[Any]):
  527. """
  528. Given user specified input paths, maybe fix up the guard string
  529. to reflect user path instead of tracer path.
  530. """
  531. name_mapping = {}
  532. for idx, path in enumerate(flat_input_paths):
  533. name_mapping[f"L['flat_args'][{idx}]"] = f"L{pytree.keystr(path)}"
  534. replace = result_str
  535. for key, val in name_mapping.items():
  536. replace = replace.replace(key, val)
  537. return replace
  538. def _get_input_guards_for_graph(
  539. placeholders: list[torch.fx.Node],
  540. range_constraints: dict[sympy.Symbol, ValueRanges],
  541. paths_for_placeholders: list[pytree.KeyPath],
  542. ):
  543. """
  544. Guards generated by the tracer include conditions observed in code, but
  545. but do not include some additional checks we typically do in export.
  546. For example, when dynamic shapes get specialized, are specified to be
  547. within a range, or are specified to be in some equational relation,
  548. corresponding input invalidation is done within a pre_hook, specifically,
  549. `_check_input_constraints_for_graph`.
  550. Here we generate guards corresponding to the checks that happen in
  551. `_check_input_constraints_for_graph`, and add them to the guards already
  552. generated by the tracer. In the future, it may be worthwhile to separate
  553. them so that we can allow clients to turn off one but not the other.
  554. (Looking at you, AOTI.)
  555. NOTE: We should eventually reconcile this logic with `build_guards` that
  556. is used by AOT Precompile.
  557. """
  558. deferred_expressions = []
  559. new_guards_code = []
  560. sources: dict[sympy.Expr, str] = {}
  561. def handle_symint(expr, src):
  562. if len(expr.free_symbols) == 1:
  563. # complex equations (e.g., involving derived dims) need to
  564. # handled later, since we may not have enough information
  565. # just as we are passing through the placeholders in order
  566. deferred_expressions.append((src, expr))
  567. if expr in sources:
  568. # expressions that appear in multiple sources should force
  569. # inputs corresponding to those sources to be equal
  570. # e.g., x.shape[0] == y.shape[1]
  571. orig_src = sources[expr]
  572. new_guards_code.append(f"{src} == {orig_src}")
  573. else:
  574. sources[expr] = src
  575. # process value ranges as elsewhere in export
  576. min_val, max_val = _convert_range_to_int(range_constraints[expr])
  577. if min_val > 2:
  578. new_guards_code.append(f"{src} >= {min_val}")
  579. if max_val < math.inf:
  580. new_guards_code.append(f"{src} <= {max_val}")
  581. for placeholder, path in zip(placeholders, paths_for_placeholders):
  582. src = "L" + pytree.keystr(path)
  583. meta = placeholder.meta["val"]
  584. # specializations
  585. if isinstance(meta, int):
  586. new_guards_code.append(f"{src} == {meta}")
  587. if isinstance(meta, float):
  588. if meta == math.inf:
  589. new_guards_code.append(f"{src} == math.inf")
  590. elif meta == -math.inf:
  591. new_guards_code.append(f"{src} == -math.inf")
  592. else:
  593. new_guards_code.append(f"{src} == {meta}")
  594. elif isinstance(meta, str):
  595. new_guards_code.append(f"{src} == '{meta}'")
  596. # range constraints and equalities
  597. elif isinstance(meta, torch.SymInt) and meta.node.expr in range_constraints:
  598. handle_symint(meta.node.expr, src)
  599. elif isinstance(meta, torch.Tensor):
  600. for i, dim in enumerate(meta.shape):
  601. src = "L" + pytree.keystr(path) + f".size()[{i}]"
  602. if isinstance(dim, int):
  603. # specializations
  604. new_guards_code.append(f"{src} == {dim}")
  605. elif (
  606. isinstance(dim, torch.SymInt) and dim.node.expr in range_constraints
  607. ):
  608. # range constraints and equalities
  609. handle_symint(dim.node.expr, src)
  610. unification_map: dict[sympy.Symbol, sympy.Expr] = {}
  611. py_printer = torch.utils._sympy.printers.PythonPrinter()
  612. # process complex equations (e.g., involving derived dims)
  613. for src, expr in deferred_expressions:
  614. # we know this is the only symbol in expr (see check above)
  615. symbol = next(iter(expr.free_symbols))
  616. if symbol in sources:
  617. # if s0 is already known to be directly sourced from inputs,
  618. # e.g., z.shape[2], we do not need to do anything further
  619. # (assume we have already processed constraints on s0 above)
  620. continue
  621. # otherwise s0 has some "hidden" source like 'dim'
  622. # example: src = y.shape[1], expr = s0 + 1
  623. if symbol in unification_map:
  624. # suppose that we already know that s0 = x.shape[0] * 2
  625. # so we can emit the guard: x.shape[0] * 2 + 1 = y.shape[1]
  626. substitution = expr.subs(unification_map)
  627. new_guards_code.append(
  628. py_printer.doprint(sympy.Eq(substitution, sympy.Symbol(src)))
  629. )
  630. else:
  631. # we do not yet know what s0 is, but given s0 + 1 = y.shape[1],
  632. # we can solve for s0...now knowing that s0 = y.shape[1] - 1
  633. solution = try_solve(sympy.Eq(expr, sympy.Symbol(src)), symbol)
  634. if solution is not None:
  635. definition = solution[1]
  636. unification_map[symbol] = definition
  637. return new_guards_code
  638. def _ok_to_generate_guards_fn():
  639. patterns = [
  640. "executorch",
  641. "modai",
  642. "on_device_ai",
  643. "torchao",
  644. ]
  645. # force check_guards=False for files matching `patterns`
  646. # because they have too many calls to .module() and
  647. # do not like any call modules in the graph
  648. # TODO: fix these files to handle guard fns
  649. frame = inspect.currentframe()
  650. while frame is not None:
  651. if any(path in frame.f_code.co_filename for path in patterns):
  652. return False
  653. frame = frame.f_back
  654. return True
  655. def _unlift_exported_program_lifted_states(
  656. ep: ExportedProgram, check_guards=True
  657. ) -> torch.fx.GraphModule:
  658. check_guards = check_guards and _ok_to_generate_guards_fn()
  659. source_node_dict = {
  660. node.name: node for node in ep.graph.nodes if node.op != "placeholder"
  661. }
  662. # placeholder node name might change after deepcopy
  663. placeholder_source_node_dict = {
  664. node.target: node for node in ep.graph.nodes if node.op == "placeholder"
  665. }
  666. new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
  667. new_gm.meta.update(ep.graph_module.meta)
  668. ep = copy.copy(ep)
  669. ep._graph_signature = ExportGraphSignature(
  670. ep._graph_signature.input_specs, ep._graph_signature.output_specs
  671. )
  672. ep._graph_module = new_gm
  673. # TODO T206340015
  674. if ep.verifiers[0].dialect != "TRAINING":
  675. ep = _remove_effect_tokens(ep)
  676. _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
  677. forward_arg_names = (
  678. sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None
  679. )
  680. lifted_inputs: list[str | None] = [
  681. (
  682. in_spec.target
  683. if in_spec.kind
  684. in (
  685. InputKind.BUFFER,
  686. InputKind.CONSTANT_TENSOR,
  687. InputKind.PARAMETER,
  688. InputKind.CUSTOM_OBJ,
  689. )
  690. else None
  691. )
  692. for in_spec in ep.graph_signature.input_specs
  693. ]
  694. mutated_outputs: list[str | None] = [
  695. (
  696. out_spec.target
  697. if out_spec.kind
  698. in (
  699. OutputKind.BUFFER_MUTATION,
  700. OutputKind.USER_INPUT_MUTATION,
  701. OutputKind.PARAMETER_MUTATION,
  702. )
  703. else None
  704. )
  705. for out_spec in ep.graph_signature.output_specs
  706. ]
  707. for node in new_gm.graph.nodes:
  708. source_node = None
  709. if node.op == "placeholder":
  710. source_node = placeholder_source_node_dict.get(node.target)
  711. else:
  712. if node.name in source_node_dict:
  713. source_node = source_node_dict.get(node.name)
  714. node.meta["from_node"] = [
  715. NodeSource(
  716. source_node,
  717. "ExportedProgram.module()",
  718. NodeSourceAction.CREATE,
  719. )
  720. ]
  721. if ep.call_spec.in_spec is None:
  722. raise AssertionError("ep.call_spec.in_spec cannot be None")
  723. new_gm = _unlift(
  724. new_gm,
  725. lifted_inputs,
  726. mutated_outputs,
  727. ep.call_spec.in_spec,
  728. ep.call_spec.out_spec,
  729. forward_arg_names=forward_arg_names,
  730. )
  731. unlift_gm = _create_stateful_graph_module(new_gm, ep.range_constraints, ep)
  732. unlift_gm.meta.update(ep.graph_module.meta)
  733. # create a _guards_fn submodule and insert a call to it after placeholders
  734. graph = unlift_gm.graph
  735. placeholders = graph.find_nodes(op="placeholder")
  736. if check_guards and placeholders and ep.example_inputs:
  737. sig = inspect.signature(unlift_gm.forward)
  738. input_paths = _get_input_paths(
  739. ep.example_inputs,
  740. sig,
  741. )
  742. # TODO (tmanlaibaatar)
  743. # This is band-aid solution to export new tracer replacing
  744. # shape env sources to flat_args. The real fix should be replacing
  745. # shape env sources to original user sources but this is quite
  746. # involved because you need to carefully construct new sources using
  747. # dynamo and replace all instances of it inside shape env. But it is
  748. # lot easier to manipulate after we turn them into strings and only
  749. # time we use these guards is during retracing or running exported program,
  750. # so it is probably ok to have "not useful" guards on ep for now.
  751. ep_guards = []
  752. for guard in ep._guards_code:
  753. ep_guards.append(_replace_sources(guard, input_paths))
  754. guards_code = _get_input_guards_for_graph(
  755. placeholders, ep.range_constraints, input_paths
  756. )
  757. ep_guards_code = _force_ep_signature_match(ep._guards_code, input_paths)
  758. ep_guards_code = _force_gm_signature_match(ep_guards_code, sig)
  759. guards_code.extend(ep_guards_code)
  760. unlift_gm._guards_fn = _convert_guards_code_to_fn(guards_code, input_paths)
  761. root_nn_module_stack = torch.fx._utils.first_call_function_nn_module_stack(
  762. graph
  763. )
  764. with graph.inserting_after(placeholders[-1]):
  765. node = graph.call_module("_guards_fn", tuple(placeholders))
  766. node.meta["nn_module_stack"] = root_nn_module_stack
  767. unlift_gm.recompile()
  768. return unlift_gm
  769. class GuardsFn(torch.nn.Module):
  770. """
  771. Module class for guard functions.
  772. """
  773. def forward(self, *args):
  774. pass