_graph_pickler.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891
  1. import dataclasses
  2. import importlib
  3. import io
  4. import itertools
  5. import pickle
  6. from abc import abstractmethod
  7. from collections.abc import Callable
  8. from typing import Any, NewType, Optional, TypeVar, Union
  9. from typing_extensions import override, Self
  10. from torch.utils._import_utils import import_dill
dill = import_dill()
# When import_dill() returns a module (non-None), rebind `pickle` to it so the
# GraphPickler / _GraphUnpickler classes defined in this file subclass dill's
# Pickler / Unpickler instead of the stdlib ones.
if dill is not None:
    pickle = dill  # noqa: F811
  14. import torch
  15. import torch.utils._pytree as pytree
  16. from torch._guards import TracingContext
  17. from torch._inductor.standalone_compile import AOTCompiledArtifact
  18. from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, Tensor
  19. from torch._subclasses.meta_utils import (
  20. MetaConverter,
  21. MetaTensorDesc,
  22. MetaTensorDescriber,
  23. )
  24. from torch.fx.experimental.sym_node import SymNode
  25. from torch.fx.experimental.symbolic_shapes import ShapeEnv
  26. from torch.utils._mode_utils import no_dispatch
# Constrained type variable for the symbolic scalar types accepted by
# _SymNodePickleData.reduce_helper (only SymInt is actually handled there).
_SymNodeT = TypeVar("_SymNodeT", torch.SymInt, torch.SymFloat)
  28. def _ops_filter_safe(name: str) -> bool:
  29. """
  30. An ops filter which allows pickle-safe ops. Pickle-safe ops are built-in
  31. ones where it will be possible to unpickle on any machine which has PyTorch.
  32. """
  33. # TODO: This list is pretty pessimistic right now. What's the full list?
  34. return name.startswith(
  35. (
  36. "torch.ops.aten",
  37. "torch.ops.fbgemm",
  38. )
  39. )
  40. def _node_metadata_key_filter_safe(key: str) -> bool:
  41. """
  42. A metadata filter which allows pickle-safe node metadata. These often times contain
  43. stacks with pointers to unserializable objects, so we clear them out.
  44. """
  45. return key not in ["source_fn_stack", "nn_module_stack", "fwd_source_fn_stack"]
  46. @dataclasses.dataclass
  47. class Options:
  48. # A filter for which ops will cause the pickler to raise a
  49. # BypassFxGraphCache exception. If None then all ops are allowed.
  50. ops_filter: Optional[Callable[[str], bool]] = _ops_filter_safe
  51. node_metadata_key_filter: Optional[Callable[[str], bool]] = (
  52. _node_metadata_key_filter_safe
  53. )
  54. # pyrefly: ignore [invalid-inheritance]
  55. class GraphPickler(pickle.Pickler):
  56. """
  57. GraphPickler is a Pickler which helps pickling fx graph - in particular
  58. GraphModule.
  59. """
  60. def __init__(self, file: io.BytesIO, options: Optional[Options] = None) -> None:
  61. if dill is not None:
  62. super().__init__(file, byref=True)
  63. else:
  64. super().__init__(file)
  65. self.options = options or Options()
  66. # This abomination is so we can pass external decoding state to the
  67. # unpickler functions. We serialize _unpickle_state as a persistent
  68. # external item and when we deserialize it we return the common state
  69. # object.
  70. self._unpickle_state = _UnpickleStateToken(object())
  71. # This is used to describe tensors. It needs to be common across the
  72. # pickle so that duplicates and views are properly handled.
  73. self._meta_tensor_describer = MetaTensorDescriber(copy_data=False)
  74. @override
  75. # pyrefly: ignore [bad-override]
  76. def reducer_override(
  77. self, obj: object
  78. ) -> tuple[Callable[..., Any], tuple[Any, ...]]:
  79. # This function is supposed to return either NotImplemented (meaning to
  80. # do the default pickle behavior) or a pair of (unpickle callable, data
  81. # to pass to unpickle).
  82. # We could instead teach individual classes how to pickle themselves but
  83. # that has a few problems:
  84. #
  85. # 1. If we have some special needs (maybe for this use-case we don't
  86. # want to fully serialize every field) then we're adding private
  87. # details to a public interface.
  88. #
  89. # 2. If we need to have some common shared data (such as a
  90. # FakeTensorMode) which is passed to each value it's harder to
  91. # support.
  92. # These are the types that need special handling. See the individual
  93. # *PickleData classes for details on pickling that particular type.
  94. if isinstance(obj, FakeTensor):
  95. return _TensorPickleData.reduce_helper(self, obj)
  96. elif isinstance(obj, torch.fx.GraphModule):
  97. return _GraphModulePickleData.reduce_helper(self, obj)
  98. elif isinstance(obj, (torch._ops.OperatorBase, torch._ops.OpOverloadPacket)):
  99. return _OpPickleData.reduce_helper(self, obj)
  100. elif isinstance(obj, ShapeEnv):
  101. return _ShapeEnvPickleData.reduce_helper(self, obj)
  102. elif isinstance(obj, torch.SymInt):
  103. return _SymNodePickleData.reduce_helper(self, obj)
  104. elif isinstance(obj, torch._guards.TracingContext):
  105. return _TracingContextPickleData.reduce_helper(self, obj)
  106. else:
  107. # We should never get a raw Node!
  108. if isinstance(obj, torch.fx.Node):
  109. raise AssertionError("Unexpected raw Node during pickling")
  110. if reduce := _TorchNumpyPickleData.reduce_helper(self, obj):
  111. return reduce
  112. # returning `NotImplemented` causes pickle to revert to the default
  113. # behavior for this object.
  114. return NotImplemented
  115. @override
  116. # pyrefly: ignore [bad-override]
  117. def persistent_id(self, obj: object) -> Optional[str]:
  118. if obj is self._unpickle_state:
  119. return "unpickle_state"
  120. else:
  121. return None
  122. @classmethod
  123. def dumps(cls, obj: object, options: Optional[Options] = None) -> bytes:
  124. """
  125. Pickle an object.
  126. """
  127. with io.BytesIO() as stream:
  128. pickler = cls(stream, options)
  129. pickler.dump(obj)
  130. return stream.getvalue()
  131. @staticmethod
  132. def loads(data: bytes, fake_mode: FakeTensorMode) -> object:
  133. """
  134. Unpickle an object.
  135. """
  136. state = _UnpickleState(fake_mode)
  137. with io.BytesIO(data) as stream:
  138. unpickler = _GraphUnpickler(stream, state)
  139. return unpickler.load()
  140. @classmethod
  141. def debug_dumps(
  142. cls,
  143. obj: object,
  144. options: "Options | None" = None,
  145. *,
  146. max_depth: int = 80,
  147. max_iter_items: int = 50,
  148. verbose: bool = True,
  149. ) -> Optional[str]:
  150. """
  151. Find the first leaf that GraphPickler.dumps cannot serialize and return its path.
  152. This is GraphPickler-aware and avoids infinite loops by:
  153. - Traversing builtin containers directly (dict/list/tuple/set) instead of
  154. exploring their __reduce_ex__ tuples.
  155. - Only using __reduce_ex__ / __reduce__ for "opaque" objects.
  156. - Bounding recursion depth and iterator expansion.
  157. Args:
  158. obj: The object to attempt to pickle and debug.
  159. options: Optional Options instance for the GraphPickler.
  160. max_depth: Maximum recursion depth before stopping traversal.
  161. max_iter_items: Maximum number of items to materialize from iterators.
  162. verbose: If True, prints detailed traversal information.
  163. Returns:
  164. A string representing the path to the first unpicklable leaf,
  165. or None if the object is fully picklable.
  166. """
  167. options = options or Options()
  168. pickler = cls(io.BytesIO(), options)
  169. visited: set[int] = set()
  170. def log(msg: str) -> None:
  171. if verbose:
  172. print(msg)
  173. def fail_exc(o: Any) -> Optional[BaseException]:
  174. try:
  175. cls.dumps(o, options)
  176. return None
  177. except Exception as e:
  178. return e
  179. def walk(o: Any, path: str, depth: int) -> Optional[str]:
  180. if depth > max_depth:
  181. log(f"{' ' * depth}Depth limit at {path} ({type(o)})")
  182. return path + " (depth_limit)"
  183. key = id(o)
  184. if key in visited:
  185. return None
  186. visited.add(key)
  187. indent = " " * depth
  188. log(f"{indent}Walking: {path} ({type(o)})")
  189. e = fail_exc(o)
  190. if e is None:
  191. log(f"{indent}✓ Pickles fine alone")
  192. return None
  193. log(f"{indent}[FAIL pickle] {type(o)} -> {e}")
  194. # 1) Builtin containers: walk contents directly (do NOT call __reduce_ex__)
  195. if isinstance(o, dict):
  196. for k, v in o.items():
  197. bad = walk(v, f"{path}[{k!r}]", depth + 1)
  198. if bad:
  199. return bad
  200. return path
  201. if isinstance(o, (list, tuple)):
  202. for i, v in enumerate(o):
  203. bad = walk(v, f"{path}[{i}]", depth + 1)
  204. if bad:
  205. return bad
  206. return path
  207. if isinstance(o, (set, frozenset)):
  208. for i, v in enumerate(o):
  209. bad = walk(v, f"{path}[{i}]", depth + 1)
  210. if bad:
  211. return bad
  212. return path
  213. # 2) Iterator types: materialize a bounded prefix
  214. if hasattr(o, "__iter__") and type(o).__name__.endswith("iterator"):
  215. try:
  216. prefix = list(itertools.islice(iter(o), max_iter_items + 1))
  217. except Exception:
  218. prefix = None
  219. if prefix is not None:
  220. if len(prefix) > max_iter_items:
  221. log(
  222. f"{indent}⚠ Iterator has more than {max_iter_items} items, "
  223. f"only checking first {max_iter_items}"
  224. )
  225. prefix = prefix[:max_iter_items]
  226. for i, v in enumerate(prefix):
  227. bad = walk(v, f"{path}[{i}]", depth + 1)
  228. if bad:
  229. return bad
  230. return path
  231. # 3) GraphPickler reducer_override
  232. try:
  233. red = pickler.reducer_override(o)
  234. log(f"{indent}reducer_override -> {type(red)}")
  235. except Exception as e2:
  236. log(f"{indent}💥 reducer_override crashed: {e2}")
  237. return path
  238. if red is not NotImplemented:
  239. _, args = red
  240. log(f"{indent}Using custom reduce, args={len(args)}")
  241. for i, a in enumerate(args):
  242. bad = walk(a, f"{path}.reduce_args[{i}]", depth + 1)
  243. if bad:
  244. return bad
  245. # 4) Dataclasses
  246. if dataclasses.is_dataclass(o):
  247. for f in dataclasses.fields(o):
  248. try:
  249. v = getattr(o, f.name)
  250. except Exception:
  251. return f"{path}.{f.name}"
  252. bad = walk(v, f"{path}.{f.name}", depth + 1)
  253. if bad:
  254. return bad
  255. return path
  256. # 5) __getstate__ and __dict__/__slots__
  257. getstate = getattr(o, "__getstate__", None)
  258. if callable(getstate):
  259. try:
  260. state = getstate()
  261. log(f"{indent}__getstate__ -> {type(state)}")
  262. except Exception as e3:
  263. log(f"{indent}💥 __getstate__ failed: {e3}")
  264. return path + ".__getstate__()"
  265. bad = walk(state, path + ".__getstate__()", depth + 1)
  266. if bad:
  267. return bad
  268. if hasattr(o, "__dict__"):
  269. for name, v in vars(o).items():
  270. bad = walk(v, f"{path}.{name}", depth + 1)
  271. if bad:
  272. return bad
  273. return path
  274. if hasattr(o, "__slots__"):
  275. for slot in o.__slots__:
  276. if hasattr(o, slot):
  277. bad = walk(getattr(o, slot), f"{path}.{slot}", depth + 1)
  278. if bad:
  279. return bad
  280. return path
  281. # 6) Last resort: reduce protocol for non-container / opaque objects
  282. reduce_tuple = None
  283. try:
  284. if hasattr(o, "__reduce_ex__"):
  285. reduce_tuple = o.__reduce_ex__(pickle.HIGHEST_PROTOCOL)
  286. log(f"{indent}__reduce_ex__ -> {type(reduce_tuple)}")
  287. elif hasattr(o, "__reduce__"):
  288. reduce_tuple = o.__reduce__()
  289. log(f"{indent}__reduce__ -> {type(reduce_tuple)}")
  290. except Exception as e4:
  291. log(f"{indent}💥 reduce protocol failed: {e4}")
  292. return path
  293. if isinstance(reduce_tuple, tuple):
  294. for i, part in enumerate(reduce_tuple):
  295. if part is None:
  296. continue
  297. bad = walk(part, f"{path}.__reduce__[{i}]", depth + 1)
  298. if bad:
  299. return bad
  300. return path
  301. bad = walk(obj, "root", 0)
  302. return bad
  303. class _UnpickleState:
  304. def __init__(self, fake_mode: FakeTensorMode) -> None:
  305. self.fake_mode = fake_mode
  306. self.meta_converter: MetaConverter[FakeTensor] = MetaConverter()
# This token is passed when pickling to indicate that we want to use the
# unpickler's _UnpickleState as a parameter in that position.
# GraphPickler.persistent_id() writes the persistent ID "unpickle_state" for the
# token, and _GraphUnpickler.persistent_load() maps that ID back to its
# _UnpickleState, so each reducer's unpickle callback receives the shared state.
_UnpickleStateToken = NewType("_UnpickleStateToken", object)
  310. # pyrefly: ignore [invalid-inheritance]
  311. class _GraphUnpickler(pickle.Unpickler):
  312. def __init__(self, stream: io.BytesIO, unpickle_state: _UnpickleState) -> None:
  313. super().__init__(stream)
  314. self._unpickle_state = unpickle_state
  315. @override
  316. # pyrefly: ignore [bad-override]
  317. def persistent_load(self, pid: object) -> object:
  318. if pid == "unpickle_state":
  319. return self._unpickle_state
  320. else:
  321. raise pickle.UnpicklingError("Invalid persistent ID")
  322. class _ShapeEnvPickleData:
  323. data: dict[str, object]
  324. @classmethod
  325. def reduce_helper(
  326. cls, pickler: GraphPickler, obj: ShapeEnv
  327. ) -> tuple[
  328. Callable[[Self, _UnpickleState], ShapeEnv], tuple[Self, _UnpickleStateToken]
  329. ]:
  330. return cls.unpickle, (cls(obj), pickler._unpickle_state)
  331. def __init__(self, env: ShapeEnv) -> None:
  332. # In theory pickle should recognize that a given ShapeEnv was already
  333. # pickled and reuse the resulting _ShapeEnvPickleData (so two objects
  334. # pointing at the same ShapeEnv get the same ShapeEnv out).
  335. if env._translation_validation_enabled:
  336. raise AssertionError("Translation validation must be disabled for pickling")
  337. self.data = env.__dict__.copy()
  338. del self.data["tracked_fakes"]
  339. del self.data["fake_tensor_cache"]
  340. def unpickle(self, unpickle_state: _UnpickleState) -> ShapeEnv:
  341. # Fill in the existing ShapeEnv rather than creating a new one
  342. if not unpickle_state.fake_mode:
  343. raise AssertionError("unpickle_state.fake_mode is not set")
  344. if not unpickle_state.fake_mode.shape_env:
  345. raise AssertionError("unpickle_state.fake_mode.shape_env is not set")
  346. for k, v in self.data.items():
  347. setattr(unpickle_state.fake_mode.shape_env, k, v)
  348. return unpickle_state.fake_mode.shape_env
  349. class _SymNodePickleData:
  350. @classmethod
  351. def reduce_helper(
  352. cls,
  353. pickler: GraphPickler,
  354. obj: _SymNodeT,
  355. ) -> tuple[
  356. Callable[[Self, _UnpickleState], _SymNodeT], tuple[Self, _UnpickleStateToken]
  357. ]:
  358. args = (cls(obj.node), pickler._unpickle_state)
  359. if isinstance(obj, torch.SymInt):
  360. # pyrefly: ignore [bad-return]
  361. return _SymNodePickleData.unpickle_sym_int, args
  362. else:
  363. raise NotImplementedError(f"Unhandled SymNode type {type(obj)}")
  364. def __init__(self, node: SymNode) -> None:
  365. self.expr = node._expr
  366. self.shape_env = node.shape_env
  367. self.pytype = node.pytype
  368. self.hint = node._hint
  369. def _to_sym_node(self) -> SymNode:
  370. if self.shape_env is None:
  371. raise AssertionError("shape_env is None")
  372. return SymNode(self.expr, self.shape_env, self.pytype, self.hint)
  373. def unpickle_sym_int(self, unpickle_state: _UnpickleState) -> torch.SymInt:
  374. return torch.SymInt(self._to_sym_node())
  375. class _TensorPickleData:
  376. metadata: MetaTensorDesc[FakeTensor]
  377. @classmethod
  378. def reduce_helper(
  379. cls, pickler: GraphPickler, obj: FakeTensor
  380. ) -> tuple[
  381. Callable[[Self, _UnpickleState], FakeTensor], tuple[Self, _UnpickleStateToken]
  382. ]:
  383. return cls.unpickle, (
  384. cls(pickler._meta_tensor_describer, obj),
  385. pickler._unpickle_state,
  386. )
  387. def __init__(self, describer: MetaTensorDescriber, t: Tensor) -> None:
  388. # THINGS TO WORRY ABOUT:
  389. # 1. Need to make sure that two tensors with the same id end up with the
  390. # same id on the other side of the wire.
  391. metadata = describer.describe_tensor(t)
  392. # view_func is fine if it's either None or a _FakeTensorViewFunc. A
  393. # custom one (which is basically a lambda) can't be serialized.
  394. if metadata.view_func and not isinstance(
  395. metadata.view_func, torch._subclasses.meta_utils._FakeTensorViewFunc
  396. ):
  397. raise AssertionError(
  398. f"view_func must be None or _FakeTensorViewFunc, got "
  399. f"{type(metadata.view_func)}"
  400. )
  401. self.metadata = dataclasses.replace(metadata, fake_mode=None)
  402. # Some debugging/verification
  403. for k in MetaTensorDesc._UNSERIALIZABLE:
  404. if k in ("fake_mode", "view_func"):
  405. continue
  406. if getattr(self.metadata, k) is not None:
  407. raise AssertionError(f"not None: {k}: {getattr(self.metadata, k)}")
  408. def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor:
  409. # TODO: make common w/ _output_from_cache_entry() in fake_tensor.py?
  410. metadata = dataclasses.replace(
  411. self.metadata,
  412. fake_mode=unpickle_state.fake_mode,
  413. )
  414. # also need to set the fake_mode on the base of a tensor if it's a view
  415. if metadata.is_view and metadata.base is not None:
  416. new_base = dataclasses.replace(
  417. metadata.base,
  418. fake_mode=unpickle_state.fake_mode,
  419. )
  420. metadata = dataclasses.replace(metadata, base=new_base)
  421. def with_fake(
  422. make_meta_t: Callable[[], torch.Tensor], device: Union[torch.device, str]
  423. ) -> FakeTensor:
  424. with no_dispatch():
  425. return FakeTensor(
  426. unpickle_state.fake_mode,
  427. make_meta_t(),
  428. # pyrefly: ignore [bad-argument-type]
  429. device,
  430. )
  431. return unpickle_state.meta_converter.meta_tensor(
  432. metadata,
  433. unpickle_state.fake_mode.shape_env,
  434. with_fake,
  435. None,
  436. None,
  437. )
  438. class _TorchNumpyPickleData:
  439. @classmethod
  440. def reduce_helper(
  441. cls, pickler: GraphPickler, obj: object
  442. ) -> Optional[
  443. tuple[
  444. Callable[[Self, _UnpickleState], object], tuple[Self, _UnpickleStateToken]
  445. ]
  446. ]:
  447. if data := cls.from_object(obj):
  448. return (cls.unpickle, (data, pickler._unpickle_state))
  449. else:
  450. return None
  451. def __init__(self, mod: str, name: str) -> None:
  452. self.mod = mod
  453. self.name = name
  454. def unpickle(self, unpickle_state: _UnpickleState) -> Callable[..., object]:
  455. np = getattr(importlib.import_module(self.mod), self.name)
  456. return torch._dynamo.variables.misc.get_np_to_tnp_map()[np]
  457. @classmethod
  458. def from_object(cls, tnp: object) -> Optional[Self]:
  459. if not callable(tnp):
  460. return None
  461. tnp_to_np = torch._dynamo.variables.misc.get_tnp_to_np_map()
  462. try:
  463. if not (np := tnp_to_np.get(tnp)):
  464. return None
  465. except TypeError:
  466. return None
  467. if not (mod := getattr(np, "__module__", None)):
  468. mod = "numpy"
  469. if not (name := getattr(np, "__name__", None)):
  470. return None
  471. # pyrefly: ignore [unbound-name]
  472. if np != getattr(importlib.import_module(mod), name):
  473. raise AssertionError(
  474. f"Numpy object mismatch for {mod}.{name}" # pyrefly: ignore [unbound-name]
  475. )
  476. # pyrefly: ignore [unbound-name]
  477. return cls(mod, name)
  478. class _GraphModulePickleData:
  479. @classmethod
  480. def reduce_helper(
  481. cls, pickler: GraphPickler, obj: torch.fx.GraphModule
  482. ) -> tuple[
  483. Callable[[Self, _UnpickleState], torch.fx.GraphModule],
  484. tuple[Self, _UnpickleStateToken],
  485. ]:
  486. return cls.unpickle, (
  487. cls(obj, pickler.options),
  488. pickler._unpickle_state,
  489. )
  490. def __init__(self, gm: torch.fx.GraphModule, options: Options) -> None:
  491. # Need to do this to ensure the code is created for later pickling.
  492. if isinstance(gm, torch.fx._lazy_graph_module._LazyGraphModule):
  493. _python_code = gm._real_recompile()
  494. else:
  495. _python_code = gm.recompile()
  496. if hasattr(gm, "__getstate__"):
  497. self.gm_dict = gm.__getstate__()
  498. else:
  499. self.gm_dict = gm.__dict__.copy()
  500. del self.gm_dict["_graph"]
  501. self.graph = _GraphPickleData(gm._graph, options)
  502. def unpickle(self, unpickle_state: _UnpickleState) -> torch.fx.GraphModule:
  503. gm = torch.fx.GraphModule.__new__(torch.fx.GraphModule)
  504. gm.__dict__ = self.gm_dict
  505. gm._graph = self.graph.unpickle(gm, unpickle_state)
  506. return gm
  507. class _NodePickleData:
  508. def __init__(
  509. self,
  510. node: torch.fx.Node,
  511. mapping: dict[torch.fx.Node, "_NodePickleData"],
  512. options: Options,
  513. ) -> None:
  514. self.args = pytree.tree_map_only(torch.fx.Node, lambda n: mapping[n], node.args)
  515. self.kwargs = pytree.tree_map_only(
  516. torch.fx.Node, lambda n: mapping[n], node.kwargs
  517. )
  518. # -- self.graph = node.graph
  519. self.name = node.name
  520. self.op = node.op
  521. self.target = _OpPickleData.pickle(node.target, options)
  522. # self.input_nodes = node._input_nodes
  523. # self.users = node.users
  524. self.type = node.type
  525. # self.sort_key = node._sort_key
  526. # self.repr_fn = node._repr_fn
  527. # self.meta = node.meta
  528. self.meta = {
  529. k: v
  530. for k, v in node.meta.items()
  531. if (
  532. not options.node_metadata_key_filter
  533. or options.node_metadata_key_filter(k)
  534. )
  535. }
  536. def unpickle(
  537. self,
  538. graph: torch.fx.Graph,
  539. mapping: dict["_NodePickleData", torch.fx.Node],
  540. unpickle_state: _UnpickleState,
  541. ) -> torch.fx.Node:
  542. args = pytree.tree_map_only(_NodePickleData, lambda n: mapping[n], self.args)
  543. kwargs = pytree.tree_map_only(
  544. _NodePickleData, lambda n: mapping[n], self.kwargs
  545. )
  546. target = self.target.unpickle(unpickle_state)
  547. if not (callable(target) or isinstance(target, str)):
  548. raise AssertionError(f"target must be callable or str, got {type(target)}")
  549. node = graph.create_node(self.op, target, args, kwargs, self.name, self.type)
  550. node.meta = self.meta
  551. return node
  552. class _OpPickleData:
  553. @classmethod
  554. def reduce_helper(
  555. cls, pickler: GraphPickler, op: object
  556. ) -> tuple[Callable[[_UnpickleState], object], tuple[_UnpickleStateToken]]:
  557. result = cls.pickle(op, pickler.options)
  558. return (result.unpickle, (pickler._unpickle_state,))
  559. @classmethod
  560. def pickle(cls, op: object, options: Options) -> "_OpPickleData":
  561. if isinstance(op, str):
  562. return _OpStrPickleData(op)
  563. if isinstance(getattr(op, "__wrapped__", None), AOTCompiledArtifact):
  564. if not hasattr(op, "__wrapped__"):
  565. raise AssertionError("op missing __wrapped__ attribute")
  566. artifact = op.__wrapped__
  567. if not isinstance(artifact, AOTCompiledArtifact):
  568. raise AssertionError(
  569. f"Expected AOTCompiledArtifact, got {type(artifact)}"
  570. )
  571. return _OpPrecompiledPickleData(artifact)
  572. name = torch.fx.Node._pretty_print_target(op)
  573. if isinstance(op, torch._ops.OpOverload):
  574. return cls._pickle_op(name, _OpOverloadPickleData, options)
  575. elif isinstance(op, torch._ops.OpOverloadPacket):
  576. return cls._pickle_op(name, _OpOverloadPacketPickleData, options)
  577. elif name.startswith(_OpFunctionPickleData.SUPPORTED_ROOTS):
  578. root, detail = name.split(".", 1)
  579. return _OpFunctionPickleData(root, detail)
  580. else:
  581. # TODO: raise a BypassFxGraphCache so we will just bypass this one...
  582. raise NotImplementedError(f"TARGET: {type(op)} {op} {name}")
  583. @staticmethod
  584. def _pickle_op(
  585. name: str,
  586. datacls: Union[
  587. type["_OpOverloadPickleData"], type["_OpOverloadPacketPickleData"]
  588. ],
  589. options: Options,
  590. ) -> "_OpPickleData":
  591. if (ops_filter := options.ops_filter) and not ops_filter(name):
  592. from torch._inductor.codecache import BypassFxGraphCache
  593. raise BypassFxGraphCache(f"Unable to pickle non-standard op: {name}")
  594. return datacls(name)
  595. @abstractmethod
  596. def unpickle(self, unpickle_state: _UnpickleState) -> object:
  597. pass
  598. @classmethod
  599. def _lookup_global_by_name(cls, name: str) -> object:
  600. """
  601. Like `globals()[name]` but supports dotted names.
  602. """
  603. if "." in name:
  604. mod, rest = name.split(".", 1)
  605. root = globals()[mod]
  606. return cls._getattr_by_name(root, rest)
  607. else:
  608. return globals()[name]
  609. @staticmethod
  610. def _getattr_by_name(root: object, name: str) -> object:
  611. """
  612. Like `getattr(root, name)` but supports dotted names.
  613. """
  614. while "." in name:
  615. mod, name = name.split(".", 1)
  616. root = getattr(root, mod)
  617. return getattr(root, name)
  618. class _OpStrPickleData(_OpPickleData):
  619. def __init__(self, name: str) -> None:
  620. self.name = name
  621. def unpickle(self, unpickle_state: _UnpickleState) -> str:
  622. return self.name
  623. class _OpOverloadPickleData(_OpPickleData):
  624. def __init__(self, name: str) -> None:
  625. self.name = name
  626. def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverload:
  627. obj = self._lookup_global_by_name(self.name)
  628. if not isinstance(obj, torch._ops.OpOverload):
  629. raise AssertionError(f"Expected OpOverload, got {type(obj)}")
  630. return obj
  631. class _OpOverloadPacketPickleData(_OpPickleData):
  632. def __init__(self, name: str) -> None:
  633. self.name = name
  634. def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverloadPacket:
  635. obj = self._lookup_global_by_name(self.name)
  636. if not isinstance(obj, torch._ops.OpOverloadPacket):
  637. raise AssertionError(f"Expected OpOverloadPacket, got {type(obj)}")
  638. return obj
  639. class _OpPrecompiledPickleData(_OpPickleData):
  640. def __init__(self, artifact: AOTCompiledArtifact) -> None:
  641. self.contents = artifact.serialize()
  642. def unpickle(self, unpickle_state: _UnpickleState) -> object:
  643. precompiled_artifact = AOTCompiledArtifact.deserialize(self.contents)
  644. import functools
  645. @functools.wraps(precompiled_artifact)
  646. def wrapped(*args: Any) -> Any:
  647. return precompiled_artifact(*args)
  648. return wrapped
  649. class _OpFunctionPickleData(_OpPickleData):
  650. """
  651. Supports pickling a set of standard/common functions
  652. These must be prefixed with the full namespace in order to properly
  653. be pickled (i.e `einops.rearrange` and not `from einops import rearrange`)
  654. """
  655. # Static variable listing supported root names
  656. SUPPORTED_ROOTS = ("builtins.", "math.", "torch.", "operator.", "einops.")
  657. def __init__(self, root: str, name: str) -> None:
  658. self.root = root
  659. self.name = name
  660. def unpickle(self, unpickle_state: _UnpickleState) -> object:
  661. if self.root == "builtins":
  662. return __builtins__.get(self.name) # type: ignore[attr-defined]
  663. elif self.root == "math":
  664. import math
  665. return self._getattr_by_name(math, self.name)
  666. elif self.root == "torch":
  667. return self._getattr_by_name(torch, self.name)
  668. elif self.root == "operator":
  669. import operator
  670. return self._getattr_by_name(operator, self.name)
  671. elif self.root == "einops":
  672. import einops
  673. return self._getattr_by_name(einops, self.name)
  674. else:
  675. raise NotImplementedError
  676. class _GraphPickleData:
  677. def __init__(self, graph: torch.fx.Graph, options: Options) -> None:
  678. self.tracer_cls = graph._tracer_cls
  679. self.tracer_extras = graph._tracer_extras
  680. nodes: dict[torch.fx.Node, _NodePickleData] = {}
  681. for node in graph.nodes:
  682. nodes[node] = _NodePickleData(node, nodes, options)
  683. self.nodes = tuple(nodes.values())
  684. self._codegen = graph._codegen
  685. # Unpickled variables:
  686. # self._used_names = graph._used_names
  687. # -- self._insert = self._root.prepend
  688. # self._len = graph._len
  689. # self._graph_namespace = graph._graph_namespace
  690. # self._owning_module = graph._owning_module
  691. # self._co_fields: Dict[str, Any] = graph._co_fields
  692. # -- self._find_nodes_lookup_table = _FindNodesLookupTable()
  693. def unpickle(
  694. self, gm: torch.fx.GraphModule, unpickle_state: _UnpickleState
  695. ) -> torch.fx.Graph:
  696. graph = torch.fx.Graph(gm, self.tracer_cls, self.tracer_extras)
  697. nodes: dict[_NodePickleData, torch.fx.Node] = {}
  698. for nd in self.nodes:
  699. nodes[nd] = nd.unpickle(graph, nodes, unpickle_state)
  700. if hasattr(self, "_codegen"):
  701. graph._codegen = self._codegen
  702. return graph
  703. class _TracingContextPickleData:
  704. @classmethod
  705. def reduce_helper(
  706. cls, pickler: GraphPickler, obj: torch._guards.TracingContext
  707. ) -> tuple[
  708. Callable[[Self, _UnpickleState], torch._guards.TracingContext],
  709. tuple[Self, _UnpickleStateToken],
  710. ]:
  711. return (
  712. cls.unpickle,
  713. (
  714. cls(obj),
  715. pickler._unpickle_state,
  716. ),
  717. )
  718. def __init__(self, context: TracingContext) -> None:
  719. # TODO: Do we really need all of this?
  720. self.module_context = context.module_context
  721. self.frame_summary_stack = context.frame_summary_stack
  722. self.loc_in_frame = context.loc_in_frame
  723. self.aot_graph_name = context.aot_graph_name
  724. self.params_flat = context.params_flat
  725. self.params_flat_unwrap_subclasses = context.params_flat_unwrap_subclasses
  726. self.params_unwrapped_to_flat_index = context.params_unwrapped_to_flat_index
  727. self.output_strides = context.output_strides
  728. self.force_unspec_int_unbacked_size_like = (
  729. context.force_unspec_int_unbacked_size_like
  730. )
  731. # Not saved (because it's difficult and maybe not needed?):
  732. # self.fw_metadata = context.fw_metadata
  733. # self.guards_context = None
  734. # self.global_context = None
  735. # self.fake_mode = None
  736. # self.fakify_first_call = None
  737. # self.hop_dispatch_set_cache = None
  738. # self.tensor_to_context = context.tensor_to_context
  739. def unpickle(self, unpickle_state: _UnpickleState) -> TracingContext:
  740. context = TracingContext(unpickle_state.fake_mode)
  741. context.module_context = self.module_context
  742. context.frame_summary_stack = self.frame_summary_stack
  743. context.loc_in_frame = self.loc_in_frame
  744. context.aot_graph_name = self.aot_graph_name
  745. context.params_flat = self.params_flat
  746. context.params_flat_unwrap_subclasses = self.params_flat_unwrap_subclasses
  747. context.params_unwrapped_to_flat_index = self.params_unwrapped_to_flat_index
  748. context.output_strides = self.output_strides
  749. context.force_unspec_int_unbacked_size_like = (
  750. self.force_unspec_int_unbacked_size_like
  751. )
  752. return context