| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891 |
- import dataclasses
- import importlib
- import io
- import itertools
- import pickle
- from abc import abstractmethod
- from collections.abc import Callable
- from typing import Any, NewType, Optional, TypeVar, Union
- from typing_extensions import override, Self
- from torch.utils._import_utils import import_dill
- dill = import_dill()
- if dill is not None:
- pickle = dill # noqa: F811
- import torch
- import torch.utils._pytree as pytree
- from torch._guards import TracingContext
- from torch._inductor.standalone_compile import AOTCompiledArtifact
- from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, Tensor
- from torch._subclasses.meta_utils import (
- MetaConverter,
- MetaTensorDesc,
- MetaTensorDescriber,
- )
- from torch.fx.experimental.sym_node import SymNode
- from torch.fx.experimental.symbolic_shapes import ShapeEnv
- from torch.utils._mode_utils import no_dispatch
- _SymNodeT = TypeVar("_SymNodeT", torch.SymInt, torch.SymFloat)
- def _ops_filter_safe(name: str) -> bool:
- """
- An ops filter which allows pickle-safe ops. Pickle-safe ops are built-in
- ones where it will be possible to unpickle on any machine which has PyTorch.
- """
- # TODO: This list is pretty pessimistic right now. What's the full list?
- return name.startswith(
- (
- "torch.ops.aten",
- "torch.ops.fbgemm",
- )
- )
- def _node_metadata_key_filter_safe(key: str) -> bool:
- """
- A metadata filter which allows pickle-safe node metadata. These often times contain
- stacks with pointers to unserializable objects, so we clear them out.
- """
- return key not in ["source_fn_stack", "nn_module_stack", "fwd_source_fn_stack"]
@dataclasses.dataclass
class Options:
    """
    Filters that control what GraphPickler is willing to serialize.
    """

    # Predicate over op names. An op whose name it rejects causes the pickler
    # to raise a BypassFxGraphCache exception. If None then all ops are
    # allowed.
    ops_filter: Optional[Callable[[str], bool]] = _ops_filter_safe

    # Predicate over node.meta keys. Only entries whose key it accepts are
    # kept when a node is pickled. If None then all entries are kept.
    node_metadata_key_filter: Optional[Callable[[str], bool]] = (
        _node_metadata_key_filter_safe
    )
# pyrefly: ignore [invalid-inheritance]
class GraphPickler(pickle.Pickler):
    """
    GraphPickler is a Pickler which helps pickling fx graph - in particular
    GraphModule.

    A handful of types (see ``reducer_override``) are replaced by
    ``*PickleData`` helper objects while pickling; everything else falls back
    to the default pickle behavior. ``dumps`` and ``loads`` are the intended
    entry points; ``debug_dumps`` helps locate an object that fails to pickle.
    """

    def __init__(self, file: io.BytesIO, options: Optional[Options] = None) -> None:
        """
        Args:
            file: Stream the pickle is written to.
            options: Filters applied while pickling; defaults to ``Options()``.
        """
        if dill is not None:
            # When dill is available the module-level `pickle` name is rebound
            # to it, so the base class here is dill's Pickler.
            super().__init__(file, byref=True)
        else:
            super().__init__(file)

        self.options = options or Options()

        # This abomination is so we can pass external decoding state to the
        # unpickler functions. We serialize _unpickle_state as a persistent
        # external item and when we deserialize it we return the common state
        # object.
        self._unpickle_state = _UnpickleStateToken(object())

        # This is used to describe tensors. It needs to be common across the
        # pickle so that duplicates and views are properly handled.
        self._meta_tensor_describer = MetaTensorDescriber(copy_data=False)

    @override
    # pyrefly: ignore [bad-override]
    def reducer_override(
        self, obj: object
    ) -> tuple[Callable[..., Any], tuple[Any, ...]]:
        """
        Pick a custom reduction for `obj`, or defer to default pickling.

        Returns either NotImplemented (meaning to do the default pickle
        behavior) or a pair of (unpickle callable, data to pass to unpickle).

        Raises:
            AssertionError: If a raw ``torch.fx.Node`` is encountered.
        """
        # We could instead teach individual classes how to pickle themselves but
        # that has a few problems:
        #
        #   1. If we have some special needs (maybe for this use-case we don't
        #      want to fully serialize every field) then we're adding private
        #      details to a public interface.
        #
        #   2. If we need to have some common shared data (such as a
        #      FakeTensorMode) which is passed to each value it's harder to
        #      support.

        # These are the types that need special handling. See the individual
        # *PickleData classes for details on pickling that particular type.
        if isinstance(obj, FakeTensor):
            return _TensorPickleData.reduce_helper(self, obj)
        elif isinstance(obj, torch.fx.GraphModule):
            return _GraphModulePickleData.reduce_helper(self, obj)
        elif isinstance(obj, (torch._ops.OperatorBase, torch._ops.OpOverloadPacket)):
            return _OpPickleData.reduce_helper(self, obj)
        elif isinstance(obj, ShapeEnv):
            return _ShapeEnvPickleData.reduce_helper(self, obj)
        elif isinstance(obj, torch.SymInt):
            return _SymNodePickleData.reduce_helper(self, obj)
        elif isinstance(obj, torch._guards.TracingContext):
            return _TracingContextPickleData.reduce_helper(self, obj)
        else:
            # We should never get a raw Node!
            if isinstance(obj, torch.fx.Node):
                raise AssertionError("Unexpected raw Node during pickling")
            if reduce := _TorchNumpyPickleData.reduce_helper(self, obj):
                return reduce

            # returning `NotImplemented` causes pickle to revert to the default
            # behavior for this object.
            return NotImplemented

    @override
    # pyrefly: ignore [bad-override]
    def persistent_id(self, obj: object) -> Optional[str]:
        """
        Emit the fixed persistent id for this pickler's state token.

        Every other object returns None, i.e. is pickled inline.
        """
        if obj is self._unpickle_state:
            return "unpickle_state"
        else:
            return None

    @classmethod
    def dumps(cls, obj: object, options: Optional[Options] = None) -> bytes:
        """
        Pickle an object.

        Args:
            obj: The object to pickle.
            options: Optional filters; defaults to ``Options()``.

        Returns:
            The pickled bytes.
        """
        with io.BytesIO() as stream:
            pickler = cls(stream, options)
            pickler.dump(obj)
            return stream.getvalue()

    @staticmethod
    def loads(data: bytes, fake_mode: FakeTensorMode) -> object:
        """
        Unpickle an object.

        Args:
            data: Bytes produced by ``dumps``.
            fake_mode: FakeTensorMode handed to the unpickle callbacks through
                the shared unpickle state.

        Returns:
            The reconstructed object.
        """
        state = _UnpickleState(fake_mode)
        with io.BytesIO(data) as stream:
            unpickler = _GraphUnpickler(stream, state)
            return unpickler.load()

    @classmethod
    def debug_dumps(
        cls,
        obj: object,
        options: "Options | None" = None,
        *,
        max_depth: int = 80,
        max_iter_items: int = 50,
        verbose: bool = True,
    ) -> Optional[str]:
        """
        Find the first leaf that GraphPickler.dumps cannot serialize and return its path.

        This is GraphPickler-aware and avoids infinite loops by:
          - Traversing builtin containers directly (dict/list/tuple/set) instead of
            exploring their __reduce_ex__ tuples.
          - Only using __reduce_ex__ / __reduce__ for "opaque" objects.
          - Bounding recursion depth and iterator expansion.
          - Visiting each object at most once.

        Note that iterator objects encountered during the walk are consumed
        (up to ``max_iter_items + 1`` items) in order to inspect their contents.

        Args:
            obj: The object to attempt to pickle and debug.
            options: Optional Options instance for the GraphPickler.
            max_depth: Maximum recursion depth before stopping traversal.
            max_iter_items: Maximum number of items to materialize from iterators.
            verbose: If True, prints detailed traversal information.

        Returns:
            A string representing the path to the first unpicklable leaf,
            or None if the object is fully picklable.
        """
        options = options or Options()
        # Only used to probe reducer_override; nothing is written through it.
        pickler = cls(io.BytesIO(), options)
        # Maps id(o) -> o. The object is stored alongside its id so it stays
        # alive for the whole traversal: ids are only unique among live
        # objects, and temporaries created here (__getstate__() results,
        # reduce tuples, materialized iterator prefixes) could otherwise be
        # freed and have their id reused by a later, unrelated object, which
        # would then be skipped as "already visited".
        visited: dict[int, Any] = {}

        def log(msg: str) -> None:
            if verbose:
                print(msg)

        def fail_exc(o: Any) -> Optional[BaseException]:
            # Returns the exception raised when pickling `o` on its own, or
            # None if it pickles cleanly.
            try:
                cls.dumps(o, options)
                return None
            except Exception as e:
                return e

        def walk(o: Any, path: str, depth: int) -> Optional[str]:
            if depth > max_depth:
                log(f"{' ' * depth}Depth limit at {path} ({type(o)})")
                return path + " (depth_limit)"

            key = id(o)
            if key in visited:
                return None
            visited[key] = o

            indent = " " * depth
            log(f"{indent}Walking: {path} ({type(o)})")

            e = fail_exc(o)
            if e is None:
                log(f"{indent}✓ Pickles fine alone")
                return None

            log(f"{indent}[FAIL pickle] {type(o)} -> {e}")

            # 1) Builtin containers: walk contents directly (do NOT call __reduce_ex__)
            if isinstance(o, dict):
                for k, v in o.items():
                    bad = walk(v, f"{path}[{k!r}]", depth + 1)
                    if bad:
                        return bad
                return path

            if isinstance(o, (list, tuple)):
                for i, v in enumerate(o):
                    bad = walk(v, f"{path}[{i}]", depth + 1)
                    if bad:
                        return bad
                return path

            if isinstance(o, (set, frozenset)):
                for i, v in enumerate(o):
                    bad = walk(v, f"{path}[{i}]", depth + 1)
                    if bad:
                        return bad
                return path

            # 2) Iterator types: materialize a bounded prefix
            if hasattr(o, "__iter__") and type(o).__name__.endswith("iterator"):
                try:
                    prefix = list(itertools.islice(iter(o), max_iter_items + 1))
                except Exception:
                    prefix = None
                if prefix is not None:
                    if len(prefix) > max_iter_items:
                        log(
                            f"{indent}⚠ Iterator has more than {max_iter_items} items, "
                            f"only checking first {max_iter_items}"
                        )
                        prefix = prefix[:max_iter_items]
                    for i, v in enumerate(prefix):
                        bad = walk(v, f"{path}[{i}]", depth + 1)
                        if bad:
                            return bad
                    return path

            # 3) GraphPickler reducer_override
            try:
                red = pickler.reducer_override(o)
                log(f"{indent}reducer_override -> {type(red)}")
            except Exception as e2:
                log(f"{indent}💥 reducer_override crashed: {e2}")
                return path

            if red is not NotImplemented:
                _, args = red
                log(f"{indent}Using custom reduce, args={len(args)}")
                for i, a in enumerate(args):
                    bad = walk(a, f"{path}.reduce_args[{i}]", depth + 1)
                    if bad:
                        return bad

            # 4) Dataclasses
            if dataclasses.is_dataclass(o):
                for f in dataclasses.fields(o):
                    try:
                        v = getattr(o, f.name)
                    except Exception:
                        return f"{path}.{f.name}"
                    bad = walk(v, f"{path}.{f.name}", depth + 1)
                    if bad:
                        return bad
                return path

            # 5) __getstate__ and __dict__/__slots__
            getstate = getattr(o, "__getstate__", None)
            if callable(getstate):
                try:
                    state = getstate()
                    log(f"{indent}__getstate__ -> {type(state)}")
                except Exception as e3:
                    log(f"{indent}💥 __getstate__ failed: {e3}")
                    return path + ".__getstate__()"
                bad = walk(state, path + ".__getstate__()", depth + 1)
                if bad:
                    return bad

            if hasattr(o, "__dict__"):
                for name, v in vars(o).items():
                    bad = walk(v, f"{path}.{name}", depth + 1)
                    if bad:
                        return bad
                return path

            if hasattr(o, "__slots__"):
                slots = o.__slots__
                # __slots__ may be declared as a single string naming one
                # slot; iterating it directly would yield characters.
                if isinstance(slots, str):
                    slots = (slots,)
                for slot in slots:
                    if hasattr(o, slot):
                        bad = walk(getattr(o, slot), f"{path}.{slot}", depth + 1)
                        if bad:
                            return bad
                return path

            # 6) Last resort: reduce protocol for non-container / opaque objects
            reduce_tuple = None
            try:
                if hasattr(o, "__reduce_ex__"):
                    reduce_tuple = o.__reduce_ex__(pickle.HIGHEST_PROTOCOL)
                    log(f"{indent}__reduce_ex__ -> {type(reduce_tuple)}")
                elif hasattr(o, "__reduce__"):
                    reduce_tuple = o.__reduce__()
                    log(f"{indent}__reduce__ -> {type(reduce_tuple)}")
            except Exception as e4:
                log(f"{indent}💥 reduce protocol failed: {e4}")
                return path

            if isinstance(reduce_tuple, tuple):
                for i, part in enumerate(reduce_tuple):
                    if part is None:
                        continue
                    bad = walk(part, f"{path}.__reduce__[{i}]", depth + 1)
                    if bad:
                        return bad

            # Nothing below this object could be blamed, so blame the object.
            return path

        bad = walk(obj, "root", 0)
        return bad
class _UnpickleState:
    """
    State shared by every unpickle callback during a single load.

    An instance is substituted wherever the pickler wrote its state token.
    """

    def __init__(self, fake_mode: FakeTensorMode) -> None:
        # FakeTensorMode supplied by the caller of the load.
        self.fake_mode = fake_mode
        # One converter per load, used when rebuilding tensors.
        converter: MetaConverter[FakeTensor] = MetaConverter()
        self.meta_converter = converter
- # This token is passed when pickling to indicate that we want to use the
- # unpickler's _UnpickleState as a parameter in that position.
- _UnpickleStateToken = NewType("_UnpickleStateToken", object)
# pyrefly: ignore [invalid-inheritance]
class _GraphUnpickler(pickle.Unpickler):
    """
    Unpickler counterpart of GraphPickler.

    Resolves the "unpickle_state" persistent id to the shared _UnpickleState
    passed at construction.
    """

    def __init__(self, stream: io.BytesIO, unpickle_state: _UnpickleState) -> None:
        super().__init__(stream)
        self._unpickle_state = unpickle_state

    @override
    # pyrefly: ignore [bad-override]
    def persistent_load(self, pid: object) -> object:
        """
        Map a persistent id back to an object.

        Raises:
            pickle.UnpicklingError: For any id other than "unpickle_state".
        """
        if pid != "unpickle_state":
            raise pickle.UnpicklingError("Invalid persistent ID")
        return self._unpickle_state
class _ShapeEnvPickleData:
    """
    Pickle payload for a ShapeEnv: a snapshot of its attributes, minus the
    ones that are dropped in ``__init__``.
    """

    data: dict[str, object]

    @classmethod
    def reduce_helper(
        cls, pickler: GraphPickler, obj: ShapeEnv
    ) -> tuple[
        Callable[[Self, _UnpickleState], ShapeEnv], tuple[Self, _UnpickleStateToken]
    ]:
        """Build the (unpickle callable, args) pair for `obj`."""
        payload = cls(obj)
        return cls.unpickle, (payload, pickler._unpickle_state)

    def __init__(self, env: ShapeEnv) -> None:
        # In theory pickle should recognize that a given ShapeEnv was already
        # pickled and reuse the resulting _ShapeEnvPickleData (so two objects
        # pointing at the same ShapeEnv get the same ShapeEnv out).
        if env._translation_validation_enabled:
            raise AssertionError("Translation validation must be disabled for pickling")

        snapshot = env.__dict__.copy()
        # These two attributes are deliberately left out of the pickle.
        for excluded in ("tracked_fakes", "fake_tensor_cache"):
            del snapshot[excluded]
        self.data = snapshot

    def unpickle(self, unpickle_state: _UnpickleState) -> ShapeEnv:
        """
        Copy the saved attributes onto the fake mode's existing ShapeEnv.

        Raises:
            AssertionError: If the unpickle state has no fake mode, or the
                fake mode has no shape_env.
        """
        # Fill in the existing ShapeEnv rather than creating a new one
        fake_mode = unpickle_state.fake_mode
        if not fake_mode:
            raise AssertionError("unpickle_state.fake_mode is not set")
        shape_env = fake_mode.shape_env
        if not shape_env:
            raise AssertionError("unpickle_state.fake_mode.shape_env is not set")

        for attr_name, attr_value in self.data.items():
            setattr(shape_env, attr_name, attr_value)

        return shape_env
class _SymNodePickleData:
    """
    Pickle payload for a symbolic value: the fields of its SymNode needed to
    construct an equivalent SymNode on load.
    """

    @classmethod
    def reduce_helper(
        cls,
        pickler: GraphPickler,
        obj: _SymNodeT,
    ) -> tuple[
        Callable[[Self, _UnpickleState], _SymNodeT], tuple[Self, _UnpickleStateToken]
    ]:
        """
        Build the (unpickle callable, args) pair for `obj`.

        Raises:
            NotImplementedError: If `obj` is not a torch.SymInt.
        """
        args = (cls(obj.node), pickler._unpickle_state)
        if not isinstance(obj, torch.SymInt):
            raise NotImplementedError(f"Unhandled SymNode type {type(obj)}")
        # pyrefly: ignore [bad-return]
        return _SymNodePickleData.unpickle_sym_int, args

    def __init__(self, node: SymNode) -> None:
        self.expr = node._expr
        self.shape_env = node.shape_env
        self.pytype = node.pytype
        self.hint = node._hint

    def _to_sym_node(self) -> SymNode:
        """Rebuild a SymNode from the saved fields."""
        env = self.shape_env
        if env is None:
            raise AssertionError("shape_env is None")
        return SymNode(self.expr, env, self.pytype, self.hint)

    def unpickle_sym_int(self, unpickle_state: _UnpickleState) -> torch.SymInt:
        """Return a torch.SymInt wrapping the rebuilt SymNode."""
        sym_node = self._to_sym_node()
        return torch.SymInt(sym_node)
class _TensorPickleData:
    """
    Pickle payload for a FakeTensor: its MetaTensorDesc with the fake mode
    stripped out. The fake mode is re-attached from the unpickle state on load.
    """

    metadata: MetaTensorDesc[FakeTensor]

    @classmethod
    def reduce_helper(
        cls, pickler: GraphPickler, obj: FakeTensor
    ) -> tuple[
        Callable[[Self, _UnpickleState], FakeTensor], tuple[Self, _UnpickleStateToken]
    ]:
        """Build the (unpickle callable, args) pair for `obj`."""
        payload = cls(pickler._meta_tensor_describer, obj)
        return cls.unpickle, (payload, pickler._unpickle_state)

    def __init__(self, describer: MetaTensorDescriber, t: Tensor) -> None:
        # THINGS TO WORRY ABOUT:
        # 1. Need to make sure that two tensors with the same id end up with the
        #    same id on the other side of the wire.

        described = describer.describe_tensor(t)

        # view_func is fine if it's either None or a _FakeTensorViewFunc. A
        # custom one (which is basically a lambda) can't be serialized.
        view_func = described.view_func
        if view_func and not isinstance(
            view_func, torch._subclasses.meta_utils._FakeTensorViewFunc
        ):
            raise AssertionError(
                f"view_func must be None or _FakeTensorViewFunc, got "
                f"{type(described.view_func)}"
            )

        self.metadata = dataclasses.replace(described, fake_mode=None)

        # Some debugging/verification: apart from the two fields handled
        # above, every field listed as unserializable must be unset.
        for field_name in MetaTensorDesc._UNSERIALIZABLE:
            if field_name in ("fake_mode", "view_func"):
                continue
            if getattr(self.metadata, field_name) is not None:
                raise AssertionError(
                    f"not None: {field_name}: {getattr(self.metadata, field_name)}"
                )

    def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor:
        """Rebuild the FakeTensor under the unpickle state's fake mode."""
        # TODO: make common w/ _output_from_cache_entry() in fake_tensor.py?
        fake_mode = unpickle_state.fake_mode
        metadata = dataclasses.replace(self.metadata, fake_mode=fake_mode)

        # also need to set the fake_mode on the base of a tensor if it's a view
        if metadata.is_view and metadata.base is not None:
            metadata = dataclasses.replace(
                metadata,
                base=dataclasses.replace(metadata.base, fake_mode=fake_mode),
            )

        def with_fake(
            make_meta_t: Callable[[], torch.Tensor], device: Union[torch.device, str]
        ) -> FakeTensor:
            with no_dispatch():
                return FakeTensor(
                    fake_mode,
                    make_meta_t(),
                    # pyrefly: ignore [bad-argument-type]
                    device,
                )

        converter = unpickle_state.meta_converter
        return converter.meta_tensor(
            metadata,
            fake_mode.shape_env,
            with_fake,
            None,
            None,
        )
class _TorchNumpyPickleData:
    """
    Pickle payload for a callable found in dynamo's tnp -> np map.

    Stores the module and attribute name of the numpy counterpart; on load
    that numpy object is re-imported and mapped back through the np -> tnp map.
    """

    @classmethod
    def reduce_helper(
        cls, pickler: GraphPickler, obj: object
    ) -> Optional[
        tuple[
            Callable[[Self, _UnpickleState], object], tuple[Self, _UnpickleStateToken]
        ]
    ]:
        """
        Build the (unpickle callable, args) pair for `obj`, or return None
        when `obj` is not handled by this class.
        """
        data = cls.from_object(obj)
        if not data:
            return None
        return (cls.unpickle, (data, pickler._unpickle_state))

    def __init__(self, mod: str, name: str) -> None:
        self.mod = mod
        self.name = name

    def unpickle(self, unpickle_state: _UnpickleState) -> Callable[..., object]:
        """Re-import the numpy object and return its entry in the np -> tnp map."""
        module = importlib.import_module(self.mod)
        np = getattr(module, self.name)
        return torch._dynamo.variables.misc.get_np_to_tnp_map()[np]

    @classmethod
    def from_object(cls, tnp: object) -> Optional[Self]:
        """
        Describe `tnp` by the import path of its numpy counterpart.

        Returns:
            A payload, or None if `tnp` is not callable, has no entry in the
            tnp -> np map, or its counterpart has no ``__name__``.

        Raises:
            AssertionError: If importing ``mod.name`` does not give back the
                same numpy object.
        """
        if not callable(tnp):
            return None

        tnp_to_np = torch._dynamo.variables.misc.get_tnp_to_np_map()
        try:
            np = tnp_to_np.get(tnp)
            if not np:
                return None
        except TypeError:
            # NOTE(review): presumably raised for unhashable callables - confirm.
            return None

        # Fall back to "numpy" when the object does not report a module.
        mod = getattr(np, "__module__", None)
        if not mod:
            mod = "numpy"

        name = getattr(np, "__name__", None)
        if not name:
            return None

        # pyrefly: ignore [unbound-name]
        if np != getattr(importlib.import_module(mod), name):
            raise AssertionError(
                f"Numpy object mismatch for {mod}.{name}"  # pyrefly: ignore [unbound-name]
            )
        # pyrefly: ignore [unbound-name]
        return cls(mod, name)
class _GraphModulePickleData:
    """
    Pickle payload for a GraphModule: its state dict without ``_graph``, plus
    a _GraphPickleData describing the graph separately.
    """

    @classmethod
    def reduce_helper(
        cls, pickler: GraphPickler, obj: torch.fx.GraphModule
    ) -> tuple[
        Callable[[Self, _UnpickleState], torch.fx.GraphModule],
        tuple[Self, _UnpickleStateToken],
    ]:
        """Build the (unpickle callable, args) pair for `obj`."""
        payload = cls(obj, pickler.options)
        return cls.unpickle, (payload, pickler._unpickle_state)

    def __init__(self, gm: torch.fx.GraphModule, options: Options) -> None:
        # Need to do this to ensure the code is created for later pickling.
        # The returned code object itself is not needed here.
        if isinstance(gm, torch.fx._lazy_graph_module._LazyGraphModule):
            gm._real_recompile()
        else:
            gm.recompile()

        state = gm.__getstate__() if hasattr(gm, "__getstate__") else gm.__dict__.copy()
        # The graph travels separately, as _GraphPickleData.
        del state["_graph"]
        self.gm_dict = state
        self.graph = _GraphPickleData(gm._graph, options)

    def unpickle(self, unpickle_state: _UnpickleState) -> torch.fx.GraphModule:
        """Create a GraphModule without calling __init__ and restore its state and graph."""
        module = torch.fx.GraphModule.__new__(torch.fx.GraphModule)
        module.__dict__ = self.gm_dict
        module._graph = self.graph.unpickle(module, unpickle_state)
        return module
class _NodePickleData:
    """
    Pickle payload for a single fx Node.

    References to other Nodes inside args/kwargs are replaced by the
    _NodePickleData of those nodes, and mapped back to real Nodes on load.
    """

    def __init__(
        self,
        node: torch.fx.Node,
        mapping: dict[torch.fx.Node, "_NodePickleData"],
        options: Options,
    ) -> None:
        """
        Args:
            node: The node to describe.
            mapping: Payloads of the nodes described so far; every Node
                referenced by `node` must already be a key.
            options: Supplies the op filter and the metadata key filter.
        """

        def to_payload(n: torch.fx.Node) -> "_NodePickleData":
            return mapping[n]

        self.args = pytree.tree_map_only(torch.fx.Node, to_payload, node.args)
        self.kwargs = pytree.tree_map_only(torch.fx.Node, to_payload, node.kwargs)
        self.name = node.name
        self.op = node.op
        self.target = _OpPickleData.pickle(node.target, options)
        self.type = node.type
        # Not captured from the node: graph, _input_nodes, users, _sort_key,
        # _repr_fn.

        # Keep only the metadata entries accepted by the key filter (all of
        # them when no filter is configured).
        key_filter = options.node_metadata_key_filter
        if key_filter:
            self.meta = {k: v for k, v in node.meta.items() if key_filter(k)}
        else:
            self.meta = dict(node.meta.items())

    def unpickle(
        self,
        graph: torch.fx.Graph,
        mapping: dict["_NodePickleData", torch.fx.Node],
        unpickle_state: _UnpickleState,
    ) -> torch.fx.Node:
        """
        Create the described node inside `graph`.

        Args:
            graph: Graph receiving the new node.
            mapping: Nodes already created, keyed by their payload.
            unpickle_state: Shared state forwarded to the target's unpickle.

        Raises:
            AssertionError: If the restored target is neither callable nor a str.
        """

        def to_node(payload: "_NodePickleData") -> torch.fx.Node:
            return mapping[payload]

        args = pytree.tree_map_only(_NodePickleData, to_node, self.args)
        kwargs = pytree.tree_map_only(_NodePickleData, to_node, self.kwargs)

        target = self.target.unpickle(unpickle_state)
        if not callable(target) and not isinstance(target, str):
            raise AssertionError(f"target must be callable or str, got {type(target)}")

        node = graph.create_node(self.op, target, args, kwargs, self.name, self.type)
        node.meta = self.meta
        return node
class _OpPickleData:
    """
    Base class for pickle payloads describing a node target / operator.

    ``pickle`` selects the concrete subclass for a given target, and each
    subclass implements ``unpickle`` to resolve the target again on load.
    """

    @classmethod
    def reduce_helper(
        cls, pickler: GraphPickler, op: object
    ) -> tuple[Callable[[_UnpickleState], object], tuple[_UnpickleStateToken]]:
        """Build the (unpickle callable, args) pair for `op`."""
        result = cls.pickle(op, pickler.options)
        return (result.unpickle, (pickler._unpickle_state,))

    @classmethod
    def pickle(cls, op: object, options: Options) -> "_OpPickleData":
        """
        Choose and build the payload for `op`.

        Args:
            op: A node target: a string, a callable wrapping an
                AOTCompiledArtifact, an OpOverload / OpOverloadPacket, or a
                function under one of _OpFunctionPickleData.SUPPORTED_ROOTS.
            options: Supplies the ops filter applied to OpOverload(Packet)s.

        Raises:
            BypassFxGraphCache: If the ops filter rejects an
                OpOverload(Packet).
            NotImplementedError: If `op` matches none of the supported kinds.
        """
        if isinstance(op, str):
            return _OpStrPickleData(op)

        # Read __wrapped__ once and test the value itself; no further checks
        # are needed once the isinstance succeeds.
        artifact = getattr(op, "__wrapped__", None)
        if isinstance(artifact, AOTCompiledArtifact):
            return _OpPrecompiledPickleData(artifact)

        name = torch.fx.Node._pretty_print_target(op)
        if isinstance(op, torch._ops.OpOverload):
            return cls._pickle_op(name, _OpOverloadPickleData, options)
        elif isinstance(op, torch._ops.OpOverloadPacket):
            return cls._pickle_op(name, _OpOverloadPacketPickleData, options)
        elif name.startswith(_OpFunctionPickleData.SUPPORTED_ROOTS):
            root, detail = name.split(".", 1)
            return _OpFunctionPickleData(root, detail)
        else:
            # TODO: raise a BypassFxGraphCache so we will just bypass this one...
            raise NotImplementedError(f"TARGET: {type(op)} {op} {name}")

    @staticmethod
    def _pickle_op(
        name: str,
        datacls: Union[
            type["_OpOverloadPickleData"], type["_OpOverloadPacketPickleData"]
        ],
        options: Options,
    ) -> "_OpPickleData":
        """
        Build `datacls(name)` after checking `name` against the ops filter.

        Raises:
            BypassFxGraphCache: If a filter is configured and rejects `name`.
        """
        if (ops_filter := options.ops_filter) and not ops_filter(name):
            # Imported here rather than at module level.
            from torch._inductor.codecache import BypassFxGraphCache

            raise BypassFxGraphCache(f"Unable to pickle non-standard op: {name}")
        return datacls(name)

    @abstractmethod
    def unpickle(self, unpickle_state: _UnpickleState) -> object:
        """Resolve the described target; implemented by each subclass."""
        pass

    @classmethod
    def _lookup_global_by_name(cls, name: str) -> object:
        """
        Like `globals()[name]` but supports dotted names.

        The first component is looked up in this module's globals and the
        remainder is resolved attribute by attribute.
        """
        if "." in name:
            mod, rest = name.split(".", 1)
            root = globals()[mod]
            return cls._getattr_by_name(root, rest)
        else:
            return globals()[name]

    @staticmethod
    def _getattr_by_name(root: object, name: str) -> object:
        """
        Like `getattr(root, name)` but supports dotted names.
        """
        for part in name.split("."):
            root = getattr(root, part)
        return root
class _OpStrPickleData(_OpPickleData):
    """Payload for a node target that is a plain string."""

    def __init__(self, name: str) -> None:
        # The target string itself.
        self.name = name

    def unpickle(self, unpickle_state: _UnpickleState) -> str:
        """Return the stored string unchanged; `unpickle_state` is unused."""
        return self.name
class _OpOverloadPickleData(_OpPickleData):
    """Payload for an OpOverload, stored as its dotted name."""

    def __init__(self, name: str) -> None:
        self.name = name

    def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverload:
        """
        Resolve the stored name back to an OpOverload.

        Raises:
            AssertionError: If the name resolves to something else.
        """
        resolved = self._lookup_global_by_name(self.name)
        if isinstance(resolved, torch._ops.OpOverload):
            return resolved
        raise AssertionError(f"Expected OpOverload, got {type(resolved)}")
class _OpOverloadPacketPickleData(_OpPickleData):
    """Payload for an OpOverloadPacket, stored as its dotted name."""

    def __init__(self, name: str) -> None:
        self.name = name

    def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverloadPacket:
        """
        Resolve the stored name back to an OpOverloadPacket.

        Raises:
            AssertionError: If the name resolves to something else.
        """
        resolved = self._lookup_global_by_name(self.name)
        if isinstance(resolved, torch._ops.OpOverloadPacket):
            return resolved
        raise AssertionError(f"Expected OpOverloadPacket, got {type(resolved)}")
class _OpPrecompiledPickleData(_OpPickleData):
    """Payload for a target wrapping an AOTCompiledArtifact, stored serialized."""

    def __init__(self, artifact: AOTCompiledArtifact) -> None:
        # Serialized form of the artifact, as produced by artifact.serialize().
        self.contents = artifact.serialize()

    def unpickle(self, unpickle_state: _UnpickleState) -> object:
        """
        Deserialize the artifact and return a function wrapping it.

        Because of ``functools.wraps``, the returned function exposes the
        artifact as ``__wrapped__``, which is the attribute
        ``_OpPickleData.pickle`` inspects.
        """
        artifact = AOTCompiledArtifact.deserialize(self.contents)

        import functools

        @functools.wraps(artifact)
        def wrapped(*args: Any) -> Any:
            return artifact(*args)

        return wrapped
class _OpFunctionPickleData(_OpPickleData):
    """
    Supports pickling a set of standard/common functions

    These must be prefixed with the full namespace in order to properly
    be pickled (i.e `einops.rearrange` and not `from einops import rearrange`)
    """

    # Static variable listing supported root names
    SUPPORTED_ROOTS = ("builtins.", "math.", "torch.", "operator.", "einops.")

    def __init__(self, root: str, name: str) -> None:
        """
        Args:
            root: Top-level module name, without the trailing dot.
            name: Attribute path below `root`.
        """
        self.root = root
        self.name = name

    def unpickle(self, unpickle_state: _UnpickleState) -> object:
        """
        Resolve ``root.name`` back to the function object.

        Returns:
            The resolved object. For the "builtins" root, None is returned
            when no builtin of that name exists.

        Raises:
            NotImplementedError: If `root` is not one of the supported roots.
        """
        if self.root == "builtins":
            # Go through the `builtins` module rather than `__builtins__`:
            # the latter is an implementation detail that is a dict in some
            # contexts and a module in others.
            import builtins

            return getattr(builtins, self.name, None)
        elif self.root == "math":
            import math

            return self._getattr_by_name(math, self.name)
        elif self.root == "torch":
            return self._getattr_by_name(torch, self.name)
        elif self.root == "operator":
            import operator

            return self._getattr_by_name(operator, self.name)
        elif self.root == "einops":
            import einops

            return self._getattr_by_name(einops, self.name)
        else:
            raise NotImplementedError(f"Unsupported function root: {self.root!r}")
class _GraphPickleData:
    """
    Pickle payload for an fx Graph: tracer info, the codegen object, and one
    _NodePickleData per node in graph order.
    """

    def __init__(self, graph: torch.fx.Graph, options: Options) -> None:
        self.tracer_cls = graph._tracer_cls
        self.tracer_extras = graph._tracer_extras

        # Built incrementally so each node can look up the payloads of the
        # nodes that precede it.
        payload_by_node: dict[torch.fx.Node, _NodePickleData] = {}
        for fx_node in graph.nodes:
            payload_by_node[fx_node] = _NodePickleData(
                fx_node, payload_by_node, options
            )
        self.nodes = tuple(payload_by_node.values())

        self._codegen = graph._codegen

        # Graph attributes that are not captured here:
        #   _used_names, _insert, _len, _graph_namespace, _owning_module,
        #   _co_fields, _find_nodes_lookup_table

    def unpickle(
        self, gm: torch.fx.GraphModule, unpickle_state: _UnpickleState
    ) -> torch.fx.Graph:
        """
        Build a new Graph owned by `gm` and recreate every node in order.
        """
        graph = torch.fx.Graph(gm, self.tracer_cls, self.tracer_extras)

        node_by_payload: dict[_NodePickleData, torch.fx.Node] = {}
        for payload in self.nodes:
            node_by_payload[payload] = payload.unpickle(
                graph, node_by_payload, unpickle_state
            )

        # Only restore the codegen if this payload carries one.
        if hasattr(self, "_codegen"):
            graph._codegen = self._codegen

        return graph
class _TracingContextPickleData:
    """
    Pickle payload for a TracingContext: a fixed set of its attributes, which
    are copied onto a freshly constructed context on load.
    """

    # TODO: Do we really need all of this?
    # Attributes copied from the context when pickling and back when loading.
    _SAVED_ATTRS = (
        "module_context",
        "frame_summary_stack",
        "loc_in_frame",
        "aot_graph_name",
        "params_flat",
        "params_flat_unwrap_subclasses",
        "params_unwrapped_to_flat_index",
        "output_strides",
        "force_unspec_int_unbacked_size_like",
    )
    # Not saved (because it's difficult and maybe not needed?):
    #   fw_metadata, guards_context, global_context, fake_mode,
    #   fakify_first_call, hop_dispatch_set_cache, tensor_to_context

    @classmethod
    def reduce_helper(
        cls, pickler: GraphPickler, obj: torch._guards.TracingContext
    ) -> tuple[
        Callable[[Self, _UnpickleState], torch._guards.TracingContext],
        tuple[Self, _UnpickleStateToken],
    ]:
        """Build the (unpickle callable, args) pair for `obj`."""
        payload = cls(obj)
        return (cls.unpickle, (payload, pickler._unpickle_state))

    def __init__(self, context: TracingContext) -> None:
        for attr in self._SAVED_ATTRS:
            setattr(self, attr, getattr(context, attr))

    def unpickle(self, unpickle_state: _UnpickleState) -> TracingContext:
        """
        Create a TracingContext for the unpickle state's fake mode and restore
        the saved attributes onto it.
        """
        context = TracingContext(unpickle_state.fake_mode)
        for attr in self._SAVED_ATTRS:
            setattr(context, attr, getattr(self, attr))
        return context
|