graph_module.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213
  1. # mypy: allow-untyped-defs
  2. import base64
  3. import contextlib
  4. import copy
  5. import hashlib
  6. import itertools
  7. import linecache
  8. import os
  9. import sys
  10. import traceback
  11. import warnings
  12. from collections.abc import Callable
  13. from pathlib import Path
  14. from typing import Any, Optional, Union
  15. import torch
  16. import torch.nn as nn
  17. import torch.overrides
  18. from torch.nn.modules.module import _addindent
  19. from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
  20. from ._compatibility import compatibility
  21. from .experimental import _config as fx_experimental_config
  22. from .graph import (
  23. _BoxedCodeGen,
  24. _custom_builtins,
  25. _is_from_torch,
  26. _override_sym_repr,
  27. _PyTreeCodeGen,
  28. Graph,
  29. PythonCode,
  30. )
  31. __all__ = [
  32. "reduce_graph_module",
  33. "reduce_package_graph_module",
  34. "GraphModule",
  35. ]
  36. _USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"
  37. FX_GRAPH_MODULE_FILE_PREFIX = "fx_generated_"
  38. # Normal exec loses the source code, however we can work with
  39. # the linecache module to recover it.
  40. # Using _exec_with_source will add it to our local cache
  41. # and then tools like TorchScript will be able to get source info.
  42. class _EvalCacheLoader:
  43. def __init__(self):
  44. self.eval_cache = {}
  45. self.next_id = 0
  46. def cache(self, src: str, globals: dict[str, Any], co_fields=None):
  47. """Store the source in a private cache, and add a lazy entry in linecache
  48. that allows the source to be retrieved by 'filename'.
  49. Args:
  50. src (str): The module source to cache
  51. globals (dict): The module globals
  52. Returns:
  53. str: The cache key (and dummy filename) generated for src.
  54. """
  55. key = self._get_key()
  56. if co_fields:
  57. if "co_filename" in co_fields:
  58. # If only co_filename is provided, use it directly as the key
  59. if "co_firstlineno" not in co_fields or "co_name" not in co_fields:
  60. key = co_fields["co_filename"]
  61. else:
  62. # Full co_fields with all three components
  63. key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
  64. self.eval_cache[key] = src
  65. # Don't mutate globals so that this loader is only used
  66. # to populate linecache, and doesn't interact with other modules
  67. # that might check `__loader__`
  68. globals_copy = globals.copy()
  69. globals_copy["__file__"] = key
  70. globals_copy["__name__"] = key
  71. globals_copy["__loader__"] = self
  72. linecache.lazycache(key, globals_copy)
  73. return key
  74. # Part of the loader protocol (PEP 302)
  75. # linecache will use this method when trying to find source code
  76. def get_source(self, module_name) -> Optional[str]:
  77. if module_name in self.eval_cache:
  78. return self.eval_cache[module_name]
  79. return None
  80. def _get_key(self):
  81. key = f"<eval_with_key>.{self.next_id}"
  82. self.next_id += 1
  83. return key
  84. _loader = _EvalCacheLoader()
  85. def _exec_with_source(src: str, globals: dict[str, Any], co_fields=None):
  86. key = _loader.cache(src, globals, co_fields)
  87. exec(compile(src, key, "exec"), globals)
  88. def _forward_from_src(src: str, globals: dict[str, Any], co_fields=None):
  89. return _method_from_src(
  90. method_name="forward", src=src, globals=globals, co_fields=co_fields
  91. )
  92. def _method_from_src(
  93. method_name: str, src: str, globals: dict[str, Any], co_fields=None
  94. ) -> Callable:
  95. # avoid mutating the passed in dict
  96. globals_copy = globals.copy()
  97. _exec_with_source(src, globals_copy, co_fields)
  98. fn = globals_copy[method_name]
  99. del globals_copy[method_name]
  100. return fn
  101. def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
  102. if name in _custom_builtins:
  103. return _custom_builtins[name].import_str
  104. if _is_from_torch(name):
  105. return "import torch"
  106. module_name, attr_name = importer.get_name(obj)
  107. return f"from {module_name} import {attr_name} as {name}"
  108. def _format_import_block(globals: dict[str, Any], importer: Importer):
  109. import_strs: set[str] = {
  110. _format_import_statement(name, obj, importer) for name, obj in globals.items()
  111. }
  112. # Sort the imports so we have a stable import block that allows us to
  113. # hash the graph module and get a consistent key for use in a cache.
  114. return "\n".join(sorted(import_strs))
  115. @compatibility(is_backward_compatible=True)
  116. def reduce_graph_module(body: dict[Any, Any], import_block: str) -> torch.nn.Module:
  117. # BC: attribute name was changed from `code` to `_code` to facilitate
  118. # making `code` into a property and adding a docstring to it
  119. fn_src = body.get("_code") or body["code"]
  120. forward = _forward_from_src(import_block + fn_src, {})
  121. return _deserialize_graph_module(forward, body)
  122. @compatibility(is_backward_compatible=True)
  123. def reduce_package_graph_module(
  124. importer: PackageImporter, body: dict[Any, Any], generated_module_name: str
  125. ) -> torch.nn.Module:
  126. forward = importer.import_module(generated_module_name).forward
  127. return _deserialize_graph_module(forward, body)
  128. # We create a dummy class here because symbolic_trace pulls the forward()
  129. # function off of the class, rather than the instance. This class is used
  130. # in _deserialize_graph_module() below.
  131. class _CodeOnlyModule(torch.nn.Module):
  132. def __init__(self, body):
  133. super().__init__()
  134. self.__dict__ = body
  135. def _deserialize_graph_module(
  136. forward, body: dict[Any, Any], graph_module_cls=None
  137. ) -> torch.nn.Module:
  138. """
  139. Deserialize a GraphModule given the dictionary of the original module,
  140. using the code to reconstruct the graph. We delete the actual graph before
  141. saving the dictionary so that changes to the in-memory graph format do not
  142. get serialized.
  143. """
  144. # Try to retrieve the forward source in a backward-compatible way
  145. _CodeOnlyModule.forward = forward
  146. tracer_cls = body.get("_tracer_cls")
  147. if tracer_cls is None:
  148. from ._symbolic_trace import Tracer
  149. tracer_cls = Tracer
  150. graphmodule_cls_name = body.get("_graphmodule_cls_name", "GraphModule")
  151. # This is a workaround for a mypy linter issue related to
  152. # passing base class as an argument - https://github.com/python/mypy/issues/5865.
  153. cls_tracer: Any = tracer_cls
  154. class KeepModules(cls_tracer):
  155. # we shouldn't trace into any of the submodules,
  156. # because they were not traced in the original GraphModule
  157. def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
  158. return True
  159. com = _CodeOnlyModule(body)
  160. tracer_extras = body.get("_tracer_extras", {})
  161. graph = KeepModules().trace(com, **tracer_extras)
  162. # Recover node.meta["stack_trace"] after re-tracing
  163. node_meta_stack_trace = body.get("_graphmodule_graph_node_meta_stack_trace")
  164. if node_meta_stack_trace is not None:
  165. del body["_graphmodule_graph_node_meta_stack_trace"]
  166. for node in graph.nodes:
  167. if node_meta_stack_trace.get(node.name, None) is not None:
  168. node.meta["stack_trace"] = node_meta_stack_trace[node.name]
  169. # Manually set Tracer class on the reconstructed Graph, to avoid
  170. # referencing the private local subclass KeepModules.
  171. graph._tracer_cls = tracer_cls
  172. from ._lazy_graph_module import _make_graph_module
  173. gm = _make_graph_module(
  174. com, graph, class_name=graphmodule_cls_name, graph_module_cls=graph_module_cls
  175. )
  176. # The GraphModule constructor only retains attributes referenced by the graph.
  177. # In this case, our goal is return a GraphModule as close to identical as the one
  178. # put into the package. If any additional attributes were present in body,
  179. # we should keep them.
  180. for k, v in body.items():
  181. if not hasattr(gm, k):
  182. setattr(gm, k, v)
  183. return gm
  184. # copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
  185. # This installs empty Modules where none exist yet if they are subpaths of target
  186. def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str):
  187. *prefix, field = target.split(".")
  188. for item in prefix:
  189. f = getattr(from_module, item)
  190. t = getattr(to_module, item, None)
  191. if f is t:
  192. # we have already installed one of its parents
  193. # (e.g. target = root.linear.weight, but we have already installed root.linear)
  194. # once we install a parent, we no longer need to copy the children
  195. # since all the needed properties will already be present
  196. return
  197. if t is None:
  198. t = torch.nn.Module()
  199. setattr(to_module, item, t)
  200. from_module, to_module = f, t
  201. orig = getattr(from_module, field)
  202. # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
  203. # So, we register it as a named buffer in the target module.
  204. if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter):
  205. to_module.register_buffer(field, orig)
  206. else:
  207. setattr(to_module, field, orig)
  208. # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module
  209. # This installs empty Modules where none exist yet if they are subpaths of target
  210. def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
  211. *prefix, field = target.split(".")
  212. for item in prefix:
  213. t = getattr(to_module, item, None)
  214. if t is None:
  215. t = torch.nn.Module()
  216. setattr(to_module, item, t)
  217. to_module = t
  218. # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
  219. # So, we register it as a named buffer in the target module.
  220. if isinstance(from_obj, torch.Tensor) and not isinstance(
  221. from_obj, torch.nn.Parameter
  222. ):
  223. to_module.register_buffer(field, from_obj)
  224. else:
  225. setattr(to_module, field, from_obj)
  226. # Recursively look up target from a graph module.
  227. def _get_attr(model: torch.nn.Module, attr_name: str):
  228. return _get_attr_via_attr_list(model, attr_name.split("."))
  229. def _del_attr(model: torch.nn.Module, attr_name: str):
  230. attr_names = attr_name.split(".")
  231. t = _get_attr_via_attr_list(model, attr_names[:-1])
  232. return delattr(t, attr_names[-1])
  233. def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: list[str]):
  234. if len(attr_list) == 0:
  235. return model
  236. *prefix, field = attr_list
  237. t = model
  238. for item in prefix:
  239. t = getattr(t, item, None) # type: ignore[assignment]
  240. if t is None:
  241. raise AssertionError(f"Attribute '{item}' not found in model")
  242. return getattr(t, field)
  243. def _has_attr(model: torch.nn.Module, attr_name: str):
  244. *prefix, field = attr_name.split(".")
  245. t = model
  246. for item in prefix:
  247. t = hasattr(t, item) # type: ignore[assignment]
  248. if t is False:
  249. return False
  250. return hasattr(t, field)
  251. def _print_readable(
  252. module,
  253. module_name,
  254. print_output=True,
  255. include_stride=False,
  256. include_device=False,
  257. colored=False,
  258. expanded_def=False,
  259. additional_meta=None,
  260. ):
  261. graph = module.graph
  262. if graph is None or not isinstance(graph, torch.fx.Graph):
  263. raise AssertionError("print_readable must be used on a module with a graph")
  264. verbose_python_code = graph.python_code(
  265. root_module="self",
  266. verbose=True,
  267. include_stride=include_stride,
  268. include_device=include_device,
  269. colored=colored,
  270. expanded_def=expanded_def,
  271. additional_meta=additional_meta,
  272. )
  273. module_code = verbose_python_code.src
  274. module_code = module_code.lstrip("\n")
  275. module_code = f"class {module_name}(torch.nn.Module):\n" + module_code
  276. module_code = _addindent(module_code, 4)
  277. submodule_code_list = [""]
  278. for submodule_name, submodule in module.named_children():
  279. if hasattr(submodule, "graph"):
  280. submodule_code_list.append(
  281. _print_readable(
  282. submodule,
  283. submodule_name,
  284. print_output=False,
  285. include_stride=include_stride,
  286. include_device=include_device,
  287. colored=colored,
  288. additional_meta=additional_meta,
  289. )
  290. )
  291. submodule_code = "\n".join(submodule_code_list)
  292. submodule_code = _addindent(submodule_code, 4)
  293. output = module_code + submodule_code
  294. if print_output:
  295. print(module_code + submodule_code)
  296. return output
  297. def _metadata_hash(code: str, node_metadata: dict) -> str:
  298. """
  299. Create a content-addressed hash from code and metadata.
  300. Args:
  301. code: The source code string
  302. lineno_map: Mapping from line numbers to node indices
  303. node_metadata: Metadata for each node
  304. Returns:
  305. A 51-character base32-encoded hash
  306. """
  307. import json
  308. # Create a deterministic string representation of all components
  309. # We use JSON to ensure consistent serialization
  310. hash_data = {
  311. "code": code,
  312. "node_metadata": node_metadata,
  313. }
  314. hashing_str = json.dumps(hash_data).encode("utf-8")
  315. # [:51] to strip off the "Q====" suffix common to every hash value.
  316. return (
  317. base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
  318. .decode("utf-8")
  319. .lower()
  320. )
  321. class _WrappedCall:
  322. def __init__(self, cls, cls_call):
  323. self.cls = cls
  324. self.cls_call = cls_call
  325. # Previously, if an error occurred when valid
  326. # symbolically-traced code was run with an invalid input, the
  327. # user would see the source of the error as coming from
  328. # `File "<eval_with_key_N">`, where N is some number. We use
  329. # this function to generate a more informative error message. We
  330. # return the traceback itself, a message explaining that the
  331. # error occurred in a traced Module's generated forward
  332. # function, and five lines of context surrounding the faulty
  333. # line
  334. @staticmethod
  335. def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
  336. # auxiliary variables (for readability)
  337. err_lineno = frame_summary.lineno
  338. if err_lineno is None:
  339. raise AssertionError("frame_summary.lineno is None")
  340. line = frame_summary.line
  341. if line is None:
  342. raise AssertionError("frame_summary.line is None")
  343. err_line_len = len(line)
  344. all_src_lines = linecache.getlines(frame_summary.filename)
  345. # constituent substrings of the error message
  346. tb_repr = torch._dynamo.disable(
  347. traceback.format_exc,
  348. reason="do not trace into traceback.format_exc when generating error message",
  349. )()
  350. custom_msg = (
  351. "Call using an FX-traced Module, "
  352. f"line {err_lineno} of the traced Module's "
  353. "generated forward function:"
  354. )
  355. before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
  356. marker = "~" * err_line_len + "~~~ <--- HERE"
  357. err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
  358. # joined message
  359. return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
  360. def __call__(self, obj, *args, **kwargs):
  361. try:
  362. if self.cls_call is not None:
  363. return self.cls_call(obj, *args, **kwargs)
  364. else:
  365. return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
  366. except Exception as e:
  367. if not e.__traceback__:
  368. raise AssertionError("Exception has no traceback") from e
  369. topmost_framesummary: traceback.FrameSummary = (
  370. traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]
  371. )
  372. if "eval_with_key" in topmost_framesummary.filename:
  373. print(
  374. _WrappedCall._generate_error_message(topmost_framesummary),
  375. file=sys.stderr,
  376. )
  377. raise e.with_traceback(None) # noqa: B904
  378. else:
  379. raise e
  380. @compatibility(is_backward_compatible=True)
  381. class GraphModule(torch.nn.Module):
  382. """
  383. GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a
  384. ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated
  385. from that ``graph``.
  386. .. warning::
  387. When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
  388. regenerated. However, if you edit the contents of the ``graph`` without reassigning
  389. the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
  390. code.
  391. """
  392. def __new__(cls: "type[GraphModule]", *args, **kwargs):
  393. # each instance of a graph module needs its own forward method
  394. # so create a new singleton class for each instance.
  395. # it is a subclass of the user-defined class, the only difference
  396. # is an extra layer to install the forward method
  397. # address issue described at https://github.com/pytorch/pytorch/issues/63883
  398. # in other words, traverse class hierarchy to fix the redundant class definition problem
  399. for t in cls.__mro__:
  400. c = t.__qualname__.split(".")[-1]
  401. if c != "GraphModuleImpl":
  402. cls = t
  403. break
  404. class GraphModuleImpl(cls): # type: ignore[misc, valid-type]
  405. pass
  406. return super().__new__(GraphModuleImpl)
  407. @compatibility(is_backward_compatible=True)
  408. def __init__(
  409. self,
  410. root: Union[torch.nn.Module, dict[str, Any]],
  411. graph: Graph,
  412. class_name: str = "GraphModule",
  413. ):
  414. """
  415. Construct a GraphModule.
  416. Args:
  417. root (Union[torch.nn.Module, Dict[str, Any]):
  418. ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type.
  419. In the case that ``root`` is a Module, any references to Module-based objects (via qualified
  420. name) in the Graph's Nodes' ``target`` field will be copied over from the respective place
  421. within ``root``'s Module hierarchy into the GraphModule's module hierarchy.
  422. In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be
  423. looked up directly in the dict's keys. The object mapped to by the Dict will be copied
  424. over into the appropriate place within the GraphModule's module hierarchy.
  425. graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation
  426. class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all
  427. error messages will report as originating from ``GraphModule``. It may be helpful to set this
  428. to ``root``'s original name or a name that makes sense within the context of your transform.
  429. """
  430. super().__init__()
  431. self.__class__.__name__ = class_name
  432. if isinstance(root, torch.nn.Module):
  433. if hasattr(root, "training"):
  434. self.training = root.training
  435. # When we pickle/unpickle graph module, we don't want to drop any module or attributes.
  436. if isinstance(root, _CodeOnlyModule):
  437. for k, _ in root.named_children():
  438. _copy_attr(root, self, k)
  439. for k, _ in root.named_buffers():
  440. _copy_attr(root, self, k)
  441. for k, _ in root.named_parameters():
  442. _copy_attr(root, self, k)
  443. for node in graph.nodes:
  444. if node.op in ["get_attr", "call_module"]:
  445. if not isinstance(node.target, str):
  446. raise AssertionError(
  447. f"Expected node.target to be str, got {type(node.target)}"
  448. )
  449. _copy_attr(root, self, node.target)
  450. elif isinstance(root, dict):
  451. targets_to_copy = []
  452. for node in graph.nodes:
  453. if node.op in ["get_attr", "call_module"]:
  454. if not isinstance(node.target, str):
  455. raise AssertionError(
  456. f"Expected node.target to be str, got {type(node.target)}"
  457. )
  458. if node.target not in root:
  459. raise RuntimeError(
  460. "Node "
  461. + str(node)
  462. + " referenced target "
  463. + node.target
  464. + " but that target was not provided in ``root``!"
  465. )
  466. targets_to_copy.append(node.target)
  467. # Sort targets in ascending order of the # of atoms.
  468. # This will ensure that less deeply nested attributes are assigned
  469. # before more deeply nested attributes. For example, foo.bar
  470. # will be assigned before foo.bar.baz. Otherwise, we might assign
  471. # the user-provided ``foo.bar`` and wipe out the previously-assigned
  472. # ``foo.bar.baz``
  473. targets_to_copy.sort(key=lambda t: t.count("."))
  474. for target_to_copy in targets_to_copy:
  475. _assign_attr(root[target_to_copy], self, target_to_copy)
  476. else:
  477. raise RuntimeError("Unsupported type " + str(root) + " passed for root!")
  478. self.graph = graph
  479. # Store the Tracer class responsible for creating a Graph separately as part of the
  480. # GraphModule state, except when the Tracer is defined in a local namespace.
  481. # Locally defined Tracers are not pickleable. This is needed because torch.package will
  482. # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
  483. # to re-create the Graph during deserialization.
  484. self._tracer_cls = None
  485. if (
  486. self.graph._tracer_cls
  487. and "<locals>" not in self.graph._tracer_cls.__qualname__
  488. ):
  489. # pyrefly: ignore [bad-assignment]
  490. self._tracer_cls = self.graph._tracer_cls
  491. self._tracer_extras = {}
  492. if self.graph._tracer_extras:
  493. self._tracer_extras = self.graph._tracer_extras
  494. # Dictionary to store metadata
  495. self.meta: dict[str, Any] = {}
  496. self._replace_hooks: list[Callable] = []
  497. self._create_node_hooks: list[Callable] = []
  498. self._erase_node_hooks: list[Callable] = []
  499. # Used to remove hooks from deepcopied graph modules within a context manager.
  500. self._deepcopy_hooks: list[Callable] = []
  501. self.shape_env = None # optional not always set even when dynamic shapes exist.
  502. # TorchScript breaks trying to compile the graph setter because of the
  503. # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
  504. #
  505. # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
  506. __jit_unused_properties__ = ["graph", "_boxed_call"]
  507. @property
  508. def _boxed_call(self) -> bool:
  509. return isinstance(self._graph._codegen, _BoxedCodeGen)
  510. @property
  511. def graph(self) -> Graph:
  512. """
  513. Return the ``Graph`` underlying this ``GraphModule``
  514. """
  515. return self._graph
  516. @graph.setter
  517. def graph(self, g: Graph) -> None:
  518. """
  519. Set the underlying ``Graph`` for this ``GraphModule``. This will internally
  520. recompile the ``GraphModule`` so that the generated ``forward()`` function
  521. corresponds to ``g``
  522. """
  523. if not isinstance(g, Graph):
  524. raise AssertionError(f"Expected a Graph instance, but got {type(g)}")
  525. self._graph = g
  526. g.owning_module = self
  527. self.recompile()
  528. @compatibility(is_backward_compatible=False)
  529. def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
  530. """Dumps out module to ``folder`` with ``module_name`` so that it can be
  531. imported with ``from <folder> import <module_name>``
  532. Args:
  533. folder (Union[str, os.PathLike]): The folder to write the code out to
  534. module_name (str): Top-level name to use for the ``Module`` while
  535. writing out the code
  536. """
  537. folder = Path(folder)
  538. Path(folder).mkdir(exist_ok=True)
  539. torch.save(self.state_dict(), folder / "state_dict.pt")
  540. tab = " " * 4
  541. custom_builtins = "\n".join([v.import_str for v in _custom_builtins.values()])
  542. model_str = f"""
  543. import torch
  544. {custom_builtins}
  545. from torch.nn import *
  546. class {module_name}(torch.nn.Module):
  547. def __init__(self):
  548. super().__init__()
  549. """
  550. def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
  551. safe_reprs = [
  552. nn.Linear,
  553. nn.Conv1d,
  554. nn.Conv2d,
  555. nn.Conv3d,
  556. nn.BatchNorm1d,
  557. nn.BatchNorm2d,
  558. nn.BatchNorm3d,
  559. ]
  560. if type(module) in safe_reprs:
  561. return f"{module.__repr__()}"
  562. else:
  563. return None
  564. blobified_modules = []
  565. for module_name, module in self.named_children():
  566. module_str = _gen_model_repr(module_name, module)
  567. if module_str is None:
  568. module_file = folder / f"{module_name}.pt"
  569. torch.save(module, module_file)
  570. blobified_modules.append(module_name)
  571. module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
  572. # weights_only=False as this is legacy code that saves the model
  573. module_load_str = f"torch.load(r'{module_file}', weights_only=False)"
  574. model_str += f"{tab * 2}setattr(self, '{module_name}', {module_load_str}) # {module_repr}\n"
  575. else:
  576. model_str += f"{tab * 2}setattr(self, '{module_name}', {module_str})\n"
  577. for buffer_name, buffer in self._buffers.items():
  578. if buffer is None:
  579. continue
  580. model_str += f"{tab * 2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" # noqa: B950
  581. for param_name, param in self._parameters.items():
  582. if param is None:
  583. continue
  584. model_str += f"{tab * 2}setattr(self, '{param_name}', torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype})))\n" # noqa: B950
  585. model_str += (
  586. f"{tab * 2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
  587. )
  588. model_str += f"{_addindent(self.code, 4)}\n"
  589. module_file = folder / "module.py"
  590. module_file.write_text(model_str)
  591. init_file = folder / "__init__.py"
  592. init_file.write_text("from .module import *")
  593. if len(blobified_modules) > 0:
  594. warnings.warn(
  595. "Was not able to save the following children modules as reprs -"
  596. f"saved as pickled files instead: {blobified_modules}"
  597. )
  598. @compatibility(is_backward_compatible=True)
  599. def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
  600. """
  601. Adds the given submodule to ``self``.
  602. This installs empty Modules where none exist yet if they are
  603. subpaths of ``target``.
  604. Args:
  605. target: The fully-qualified string name of the new submodule
  606. (See example in ``nn.Module.get_submodule`` for how to
  607. specify a fully-qualified string.)
  608. m: The submodule itself; the actual object we want to
  609. install in the current Module
  610. Return:
  611. bool: Whether or not the submodule could be inserted. For
  612. this method to return True, each object in the chain
  613. denoted by ``target`` must either a) not exist yet,
  614. or b) reference an ``nn.Module`` (not a parameter or
  615. other attribute)
  616. """
  617. *prefix, field = target.split(".")
  618. mod: torch.nn.Module = self
  619. for item in prefix:
  620. submod = getattr(mod, item, None)
  621. if submod is None:
  622. submod = torch.nn.Module()
  623. setattr(mod, item, submod)
  624. if not isinstance(submod, torch.nn.Module):
  625. return False
  626. mod = submod
  627. mod.add_module(field, m)
  628. return True
  @compatibility(is_backward_compatible=True)
  def delete_submodule(self, target: str) -> bool:
      """
      Deletes the given submodule from ``self``.

      The module will not be deleted if ``target`` is not a valid
      target.

      Args:
          target: The fully-qualified string name of the submodule to
              delete (See example in ``nn.Module.get_submodule`` for how
              to specify a fully-qualified string.)

      Returns:
          bool: Whether or not the target string referenced a
          submodule we want to delete. A return value of ``False``
          means that the ``target`` was not a valid reference to
          a submodule.
      """
      *path, target_submod = target.split(".")
      mod: torch.nn.Module = self
      # Walk to the parent module; every hop must exist and be a Module.
      for item in path:
          child = getattr(mod, item, None)
          if not isinstance(child, torch.nn.Module):
              return False
          mod = child
      # The leaf must exist and itself be a Module (not e.g. a Parameter).
      if not isinstance(getattr(mod, target_submod, None), torch.nn.Module):
          return False
      delattr(mod, target_submod)
      return True
  @compatibility(is_backward_compatible=True)
  def delete_all_unused_submodules(self) -> None:
      """
      Deletes all unused submodules from ``self``.

      A Module is considered "used" if any one of the following is
      true:
      1. It has children that are used
      2. Its forward is called directly via a ``call_module`` node
      3. It has a non-Module attribute that is used from a
      ``get_attr`` node

      This method can be called to clean up an ``nn.Module`` without
      manually calling ``delete_submodule`` on each unused submodule.
      """

      # Join the accumulated prefix with the next path atom. An empty atom
      # leaves the prefix unchanged.
      def join_fn(x: str, y: str) -> str:
          return ".".join([x, y] if y else [x])

      # A set keeps the final membership test O(1) per module name.
      used: set[str] = set()
      for node in self.graph.nodes:
          if node.op not in ("call_module", "get_attr"):
              continue
          # `foo.bar.baz` -> ["foo", "bar", "baz"]
          fullpath = node.target.split(".")
          # Mark every intermediate module as used: for `foo.bar.baz` that
          # is `foo`, `foo.bar`, and `foo.bar.baz`.
          used.update(itertools.accumulate(fullpath, join_fn))
          # For a `call_module` node, also register all recursive
          # submodules as used.
          if node.op == "call_module":
              try:
                  submod = self.get_submodule(node.target)
              except AttributeError:
                  # Node referenced nonexistent submodule, don't need to
                  # worry about GCing anything
                  continue
              for submod_name, _ in submod.named_modules():
                  if submod_name != "":
                      used.add(".".join([node.target, submod_name]))
      # Materialize the list first: deleting while iterating named_modules()
      # would mutate the module tree under the iterator.
      to_delete = [name for name, _ in self.named_modules() if name not in used]
      for name in to_delete:
          self.delete_submodule(name)
  @property
  def code(self) -> str:
      """
      The Python source generated from this ``GraphModule``'s underlying
      ``Graph``.

      Raises:
          RuntimeError: if no code has been generated yet (``_code`` unset).
      """
      if hasattr(self, "_code"):
          return self._code
      raise RuntimeError(
          "Code has not been generated! Please report a bug to PyTorch"
      )
  @compatibility(is_backward_compatible=True)
  def recompile(self) -> PythonCode:
      """
      Recompile this GraphModule from its ``graph`` attribute. This should be
      called after editing the contained ``graph``, otherwise the generated
      code of this ``GraphModule`` will be out of date.

      Side effects:
          * regenerates ``self._code``, ``self._lineno_map`` and
            ``self._prologue_start`` from ``self._graph``;
          * compiles that source into a new ``forward`` assigned on
            ``type(self)`` (the class object, not the instance);
          * recompiles every direct child that is itself a ``GraphModule``;
          * (re)installs ``__call__`` on the class so calls are routed
            through ``_wrapped_call``.

      Returns:
          PythonCode: the object produced by ``self._graph.python_code(...)``.
          When ``fx_experimental_config.enrich_profiler_metadata`` is enabled,
          any profiler placeholder in its ``src`` is left untouched; only
          ``self._code`` receives the substituted filename.
      """
      # Do not import anything inside recompile, it might slow down the
      # function and cause perf regression. Import outside of the method instead.
      # NOTE(review): the profiler-metadata branch below still performs a local
      # import; it only runs when that config flag is enabled.
      if isinstance(self._graph._codegen, _PyTreeCodeGen):
          # Cache the pytree input/output specs from the codegen on the module.
          self._in_spec = self._graph._codegen.pytree_info.in_spec
          self._out_spec = self._graph._codegen.pytree_info.out_spec
      python_code = self._graph.python_code(
          root_module="self",
          record_func=fx_experimental_config.enrich_profiler_metadata,
      )
      self._code = python_code.src
      self._lineno_map = python_code._lineno_map
      self._prologue_start = python_code._prologue_start
      # forward/__call__ below are installed on the class object.
      # NOTE(review): this presumably relies on each GraphModule instance
      # having its own class -- confirm in __new__.
      cls = type(self)
      # Extra code-object fields for the compiled forward, if the graph has any.
      co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
      if fx_experimental_config.enrich_profiler_metadata:
          # Generate metadata and register for profiler augmentation
          node_metadata: dict[int, dict[str, Any]] = {}
          for i, node in enumerate(self._graph.nodes):
              node_metadata[i] = {
                  "name": node.name,
                  "op": node.op,
                  "target": str(node.target),
                  "stack_trace": node.meta.get("stack_trace", None),
              }
          # Generate a content-addressed filename based on hash of code and metadata
          # This ensures the same code+metadata always generates the same filename
          hash_value = _metadata_hash(self._code, node_metadata)
          file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"
          filename = f"{file_stem}.py"
          # Only include co_filename to use it directly as the cache key
          # (this replaces, rather than extends, any graph-provided co_fields).
          co_fields = {
              "co_filename": filename,
          }
          # Store metadata in global in-memory registry
          metadata = {
              "lineno_map": python_code._lineno_map,
              "prologue_start": python_code._prologue_start,
              "node_metadata": node_metadata,
          }
          # Register metadata in the global registry
          from torch.fx.traceback import _register_fx_metadata

          _register_fx_metadata(filename, metadata)
          # Replace the placeholder in generated code with actual filename
          # The double hash ## convention is used by post-processing to find the fx markers
          self._code = self._code.replace(
              "torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')",
              f"torch._C._profiler._RecordFunctionFast('## {filename} ##')",
          )
      # Compile the (possibly placeholder-substituted) source into forward.
      cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
      # Determine whether this class explicitly defines a __call__ implementation
      # to wrap. If it does, save it in order to have wrapped_call invoke it.
      # If it does not, wrapped_call can use a dynamic call to super() instead.
      # In most cases, super().__call__ should be torch.nn.Module.__call__.
      # We do not want to hold a reference to Module.__call__ here; doing so will
      # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
      cls_call = cls.__call__ if "__call__" in vars(cls) else None
      if "_wrapped_call" not in vars(cls):
          cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]
      # Keep child GraphModules in sync as well.
      self._recompile_submodules()

      # Route every call through _wrapped_call (see the cls_call comment above).
      def call_wrapped(self, *args, **kwargs):
          return self._wrapped_call(self, *args, **kwargs)

      cls.__call__ = call_wrapped  # type: ignore[method-assign]
      return python_code
  def _recompile_submodules(self) -> list[tuple[str, PythonCode]]:
      """
      Call ``recompile()`` on every direct child that is a ``GraphModule``.

      Returns:
          ``(child_name, PythonCode)`` pairs, in the same order as
          ``named_children()``. Children that are not GraphModules are skipped.
      """
      return [
          (child_name, child.recompile())
          for child_name, child in self.named_children()
          if isinstance(child, GraphModule)
      ]
  # Passing a Tracer as an argument allows subclasses extending fx.GraphModule
  # to define their own Tracer (extending fx.Tracer).
  # NOTE(review): the method below takes no Tracer argument; this comment
  # may be stale.
  def __reduce_package__(self, exporter: PackageExporter):
      """
      ``torch.package`` reduction hook.

      Saves the regenerated source (import block + ``self.code``) through
      ``exporter.save_source_string`` under a unique
      ``fx-generated._<id>`` module name. Returns
      ``(reduce_package_graph_module, (state, module_name))``. ``state`` is a
      snapshot of ``self.__dict__`` without ``_graph``, extended with the
      class name and the per-node ``stack_trace`` metadata.
      """
      # Snapshot taken before recompile() below mutates self.__dict__.
      state = self.__dict__.copy()
      state["_graphmodule_cls_name"] = self.__class__.__name__
      del state["_graph"]
      # Keep stack traces keyed by node name so they can be restored after
      # the graph is re-traced on deserialization.
      stack_traces = {}
      for node in self.graph.nodes:
          if "stack_trace" in node.meta:
              stack_traces[node.name] = node.meta["stack_trace"]
      state["_graphmodule_graph_node_meta_stack_trace"] = stack_traces
      module_name = f"fx-generated._{exporter.get_unique_id()}"
      python_code = self.recompile()
      source = _format_import_block(python_code.globals, exporter.importer) + self.code
      exporter.save_source_string(module_name, source)
      return reduce_package_graph_module, (state, module_name)
  def __reduce__(self):
      """
      Serialization of GraphModule. We serialize only the generated code, not
      the underlying ``Graph``. This is because ``Graph`` does not have on-disk
      backward-compatibility guarantees, whereas Python source code does.
      On the deserialization side, we symbolically trace through the generated
      code to regenerate the underlying ``Graph``
      """
      # Snapshot first: recompile() below rewrites attributes in __dict__.
      state = self.__dict__.copy()
      generated = self.recompile()
      imports = _format_import_block(generated.globals, sys_importer)
      state.pop("_graph")
      return reduce_graph_module, (state, imports)
  def _deepcopy_init(self):
      # The initializer that __deepcopy__ applies to the freshly allocated copy.
      init_fn = GraphModule.__init__
      return init_fn
  # because __reduce__ is defined for serialization,
  # we need to define deepcopy otherwise it will call __reduce__
  # and cause symbolic tracing to occur every time we try to copy the object
  def __deepcopy__(self, memo):
      """
      Deep-copy this GraphModule without going through ``__reduce__``.

      Allocates an uninitialized ``type(self)`` instance and records it in
      ``memo`` before anything is copied, so references back to ``self``
      resolve to the copy. Then:

      * deep-copies ``self.__dict__`` into a ``_CodeOnlyModule``;
      * initializes the copy with the callable from ``self._deepcopy_init()``,
        passing that module and its copied ``_graph``;
      * restores the hook containers listed below, ``meta``, and any
        user-preserved attributes recorded in ``meta``;
      * runs every registered deepcopy hook on the result.

      Args:
          memo: the ``copy.deepcopy`` memo dict.

      Returns:
          The new copy.
      """
      res = type(self).__new__(type(self))
      memo[id(self)] = res
      fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
      self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["_graph"])
      # hooks are lost during `GraphModule.__init__`, so we need to copy over
      # them explicitly, note right now we are only copying state_dict related
      # hooks, to reduce bc-related issues, we can copy forward/backward related
      # hooks in the future as well if needed
      # NOTE(review): the list below also carries the FX node hooks
      # (replace/create_node/erase_node/deepcopy), not only state_dict hooks.
      extra_preserved_attrs = [
          "_state_dict_hooks",
          "_load_state_dict_pre_hooks",
          "_load_state_dict_post_hooks",
          "_replace_hooks",
          "_create_node_hooks",
          "_erase_node_hooks",
          "_deepcopy_hooks",
      ]
      for attr in extra_preserved_attrs:
          if attr in self.__dict__:
              setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
      res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
      # Re-attach attributes the user asked to preserve (stored inside meta).
      if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
          for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
              setattr(res, attr_name, attr)
      # Deepcopy hooks run last, on the fully populated copy.
      if hasattr(self, "_deepcopy_hooks"):
          for hook in self._deepcopy_hooks:
              hook(res)
      return res
  def __copy__(self):
      """
      Shallow copy: builds a new graph module from ``self`` and the same
      ``self.graph`` object. ``meta`` is assigned by reference, not copied.
      """
      from ._lazy_graph_module import _make_graph_module

      copied = _make_graph_module(self, self.graph)
      copied.meta = getattr(self, "meta", {})
      return copied
  @compatibility(is_backward_compatible=False)
  def print_readable(
      self,
      print_output=True,
      include_stride=False,
      include_device=False,
      colored=False,
      *,
      # If `fast_sympy_print` is True then we use a sympy printer which is faster
      # but may result in less-readable output.
      fast_sympy_print: bool = False,
      expanded_def: bool = False,
      additional_meta: Optional[list[str]] = None,
  ):
      """
      Return the Python code generated for current GraphModule and its children GraphModules.

      Args:
          print_output: forwarded to ``_print_readable``. Judging by its name,
              it also prints the text when True (confirm there).
          include_stride: forwarded to ``_print_readable``. Judging by its
              name, it adds tensor stride info to the output (confirm there).
          include_device: forwarded to ``_print_readable``. Judging by its
              name, it adds tensor device info to the output (confirm there).
          colored: forwarded to ``_print_readable``. Judging by its name, it
              colorizes the output (confirm there).
          fast_sympy_print: if True, the symbolic-value repr is overridden
              (via ``_override_sym_repr``) with
              ``torch._inductor.utils.sympy_str`` while formatting. This is
              faster but may be less readable.
          expanded_def: forwarded to ``_print_readable``; its semantics are
              defined there.
          additional_meta: Optional list of meta keys to include in the output.
              For each key in the list, if it exists in node.meta, its value
              will be shown in the format "key: value".
              Example: `print_readable(additional_meta=["seq_nr"])`.

      Returns:
          The value returned by ``_print_readable`` for this module.
      """
      # ExitStack lets the sym-repr override be entered only when requested.
      # It is undone on exit, after formatting has finished.
      with contextlib.ExitStack() as ctx_mgr:
          if fast_sympy_print:
              from torch._inductor.utils import sympy_str

              def fast_repr(expr: torch.types.PySymType) -> str:
                  return sympy_str(expr.node.expr)

              ctx_mgr.enter_context(_override_sym_repr(fast_repr))
          return _print_readable(
              self,
              self._get_name(),
              print_output,
              include_stride,
              include_device,
              colored,
              expanded_def,
              additional_meta,
          )
  def __str__(self) -> str:
      """
      Return the ``nn.Module`` string form, then the generated code, then a
      one-line hint about ``print_readable()``, joined by newlines.
      """
      reminder = "# To see more debug info, please use `graph_module.print_readable()`"
      return "\n".join((super().__str__(), self._code, reminder))
  def _replicate_for_data_parallel(self):
      # Shallow copy via __copy__, tagged so it can be recognized as a replica.
      replica = self.__copy__()
      setattr(replica, "_is_replica", True)
      return replica
  @contextlib.contextmanager
  def _set_replace_hook(self, f):
      """
      Context manager that registers ``f`` as a replace-node hook for the
      duration of the ``with`` block. ``f`` is unregistered on exit, even if
      the block raises.

      ``f`` is called every time a node is replaced by a new node, or a
      node's name is changed. It takes three arguments: the old node being
      changed, the NAME of the new node, and the user node that consumes the
      old node being replaced.

      Raises:
          AssertionError: if ``f`` is not callable (raised on entry).
      """
      if callable(f):
          self._register_replace_node_hook(f)
          try:
              yield
          finally:
              self._unregister_replace_node_hook(f)
      else:
          raise AssertionError("Replace hook must be a callable.")
  def _register_replace_node_hook(self, f):
      """
      Register ``f`` to be called every time a node is replaced by a new
      node, or a node's name is changed.

      ``f`` takes three arguments: the old node being changed, the NAME of
      the new node, and the user node that consumes the old node being
      replaced.

      Raises:
          AssertionError: if ``f`` is not callable.
      """
      if not callable(f):
          # Name the replace-node hook kind so the error points at this API.
          raise AssertionError("replace_node hook must be a callable.")
      self._replace_hooks.append(f)
  def _unregister_replace_node_hook(self, f):
      """
      Takes a callable which was previously registered to be called every time when we replace a node.
      This function will unregister that callable so it is no longer invoked on node replacement.

      Raises:
          AssertionError: if ``f`` is not callable.
          ValueError: if ``f`` was never registered (from ``list.remove``).
      """
      if not callable(f):
          # Name the replace-node hook kind so the error points at this API.
          raise AssertionError("replace_node hook must be a callable.")
      self._replace_hooks.remove(f)
  def _register_create_node_hook(self, f):
      """
      Register ``f`` to run after each new node is created.

      ``f`` receives the newly created node and returns None.

      Raises:
          AssertionError: if ``f`` is not callable.
      """
      if callable(f):
          self._create_node_hooks.append(f)
          return
      raise AssertionError("create_node hook must be a callable.")
  def _unregister_create_node_hook(self, f):
      """
      Remove a previously registered create-node hook so it no longer runs
      when nodes are created.

      Raises:
          AssertionError: if ``f`` is not callable.
          ValueError: if ``f`` is not registered (from ``list.remove``).
      """
      if callable(f):
          self._create_node_hooks.remove(f)
      else:
          raise AssertionError("create_node hook must be a callable.")
  def _register_erase_node_hook(self, f):
      """
      Register ``f`` to run after a node is erased.

      ``f`` receives the node being erased and returns None.

      Raises:
          AssertionError: if ``f`` is not callable.
      """
      if callable(f):
          self._erase_node_hooks.append(f)
          return
      raise AssertionError("erase_node hook must be a callable.")
  def _unregister_erase_node_hook(self, f):
      """
      Remove a previously registered erase-node hook so it no longer runs
      when nodes are erased.

      Raises:
          AssertionError: if ``f`` is not callable.
          ValueError: if ``f`` is not registered (from ``list.remove``).
      """
      if callable(f):
          self._erase_node_hooks.remove(f)
      else:
          raise AssertionError("erase_node hook must be a callable.")
  def _register_deepcopy_hook(self, f):
      """
      Register ``f`` to run when this graph module is deep-copied.

      ``f`` receives the resulting deep-copied graph module.

      Raises:
          AssertionError: if ``f`` is not callable.
      """
      if callable(f):
          self._deepcopy_hooks.append(f)
          return
      raise AssertionError("deepcopy hook must be a callable.")
  def _unregister_deepcopy_hook(self, f):
      """
      Remove a previously registered deepcopy hook so it no longer runs on
      deepcopy.

      Raises:
          AssertionError: if ``f`` is not callable.
          ValueError: if ``f`` is not registered (from ``list.remove``).
      """
      if callable(f):
          self._deepcopy_hooks.remove(f)
      else:
          raise AssertionError("deepcopy hook must be a callable.")
# Workarounds for issues in __torch_function__.
# Workaround (WAR) for __torch_function__ not handling tensor lists;
# the fix is in https://github.com/pytorch/pytorch/pull/34725
  1007. # orig_cat = torch.cat
  1008. # def patched_cat(*args, **kwargs):
  1009. # tensors = args[0]
  1010. # for t in tensors:
  1011. # if isinstance(t, Proxy):
  1012. # return t.__torch_function__(patched_cat, (), args, kwargs)
  1013. # return orig_cat(*args, **kwargs)
  1014. # patched_cat.__module__ = 'torch'
  1015. # patched_cat.__name__ = 'cat'
  1016. # torch.cat = patched_cat