| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758 |
- """
- Python polyfills for torch.utils.pytree
- """
- from __future__ import annotations
- from collections import deque
- from dataclasses import dataclass, field
- from typing import Any, TYPE_CHECKING, TypeVar
- import optree
- import optree._C
- import optree.utils
- from optree import (
- is_namedtuple,
- is_namedtuple_class,
- is_namedtuple_instance,
- is_structseq,
- is_structseq_class,
- is_structseq_instance,
- namedtuple_fields,
- structseq_fields,
- )
- import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
- import torch.utils._pytree as python_pytree
- from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
- from ..decorators import substitute_in_graph
- if TYPE_CHECKING:
- import builtins
- from collections.abc import Callable, Iterable, Mapping
- from typing_extensions import Self, TypeIs
- from torch.utils._cxx_pytree import PyTree
# Public names exported by this polyfill module.
__all__ = [
    # namedtuple / structseq helpers re-exported from optree
    "is_namedtuple",
    "is_namedtuple_class",
    "is_namedtuple_instance",
    "is_structseq",
    "is_structseq_class",
    "is_structseq_instance",
    "namedtuple_fields",
    "structseq_fields",
    # treespec constructors
    "treespec_leaf",
    "treespec_tuple",
    "treespec_dict",
    # tree traversal and (un)flattening
    "tree_is_leaf",
    "tree_iter",
    "tree_leaves",
    "tree_flatten",
    "tree_flatten_with_path",
    "tree_structure",
    "tree_unflatten",
]


# Generic type variables used by the annotations in this module.
_T = TypeVar("_T")
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")
@substitute_in_graph(
    optree._C.is_dict_insertion_ordered,
    can_constant_fold_through=True,
)
def _(*args: Any, **kwargs: Any) -> bool:
    """Placeholder substitute for ``optree._C.is_dict_insertion_ordered``.

    The substitution is registered with ``can_constant_fold_through=True``;
    as the error message states, the original function is expected to be
    called in the constant fold path, so this body only raises.

    Raises:
        ValueError: Always, whenever this substitute itself is invoked.
    """
    # In namespace 'torch', the dictionary is always traversed in insertion order,
    # i.e. the original function returns True.
    message = (
        "Should not be called directly "
        "because the original function will be called in the constant fold path."
    )
    raise ValueError(message)
# Register the pure-Python implementation of each optree namedtuple/structseq
# helper as its in-graph substitute and publish the result under the same name
# in this module's globals.
__name = ""
for __name, __func in {
    "is_namedtuple": is_namedtuple,
    "is_namedtuple_class": is_namedtuple_class,
    "is_namedtuple_instance": is_namedtuple_instance,
    "is_structseq": is_structseq,
    "is_structseq_class": is_structseq_class,
    "is_structseq_instance": is_structseq_instance,
    "namedtuple_fields": namedtuple_fields,
    "structseq_fields": structseq_fields,
}.items():
    __decorate = substitute_in_graph(
        __func,  # type: ignore[arg-type]
        can_constant_fold_through=True,
    )
    globals()[__name] = __decorate(
        __func.__python_implementation__  # type: ignore[attr-defined]
    )

# Do not leak the loop temporaries into the module namespace.
del __decorate
del __func
del __name
@substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True)  # type: ignore[arg-type]
def tree_is_leaf(
    tree: PyTree,
    /,
    is_leaf: Callable[[PyTree], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> bool:
    """Return whether ``tree`` is treated as a leaf.

    Args:
        tree: The object to test.
        is_leaf: Optional predicate; a truthy result marks ``tree`` as a leaf.
        none_is_leaf: When true, ``None`` is a leaf.
        namespace: Namespace used for the pytree node registry lookup.

    Returns:
        ``True`` if ``tree`` is ``None`` with ``none_is_leaf`` set, if
        ``is_leaf`` accepts it, or if no pytree node handler is found for
        ``type(tree)`` in ``namespace``; ``False`` otherwise.
    """
    if tree is None and none_is_leaf:
        return True
    if is_leaf is not None and is_leaf(tree):
        return True
    # A type without a registered handler cannot be flattened further.
    handler = optree.register_pytree_node.get(type(tree), namespace=namespace)
    return handler is None
@substitute_in_graph(optree.tree_iter, can_constant_fold_through=False)  # type: ignore[arg-type]
def tree_iter(
    tree: PyTree,
    /,
    is_leaf: Callable[[PyTree], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> Iterable[Any]:
    """Lazily yield the leaves of ``tree`` in depth-first, left-to-right order.

    Args:
        tree: The pytree to traverse.
        is_leaf: Optional predicate forwarded to the leaf test and to the
            one-level flatten call.
        none_is_leaf: Forwarded to the leaf test and the one-level flatten call.
        namespace: Forwarded to the leaf test and the one-level flatten call.

    Yields:
        Each node for which :func:`tree_is_leaf` returns ``True``.
    """
    pending = [tree]
    while pending:
        current = pending.pop()
        if not tree_is_leaf(
            current,
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        ):
            subtrees, *_ = optree.tree_flatten_one_level(
                current,
                is_leaf=is_leaf,
                none_is_leaf=none_is_leaf,
                namespace=namespace,
            )
            # Push in reverse so the first child is popped (visited) first.
            pending.extend(reversed(subtrees))
        else:
            yield current
@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True)  # type: ignore[arg-type]
def tree_leaves(
    tree: PyTree,
    /,
    is_leaf: Callable[[PyTree], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> list[Any]:
    """Return the leaves of ``tree`` as a list.

    This materializes :func:`tree_iter` called with the same arguments.
    """
    leaf_iterator = tree_iter(
        tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    return list(leaf_iterator)
class _Asterisk(str):
    """A ``str`` equal to ``"*"`` whose ``repr`` is the bare character ``*``."""

    __slots__ = ()

    def __new__(cls) -> Self:
        # Takes no arguments: the value is always the single character "*".
        return str.__new__(cls, "*")

    def __repr__(self) -> str:
        # Unquoted, so containers holding this object print as e.g. ``(*, *)``.
        return "*"  # no quotes


# Module-level singleton; the class itself is removed from the namespace.
_asterisk = _Asterisk()
del _Asterisk
- @dataclass(frozen=True, slots=True)
- class PyTreeSpec:
- """Analog for :class:`optree.PyTreeSpec` in Python."""
- _children: tuple[PyTreeSpec, ...]
- _type: builtins.type | None
- _metadata: Any
- _entries: tuple[Any, ...]
- _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
- none_is_leaf: bool
- namespace: str
- num_nodes: int = field(init=False)
- num_leaves: int = field(init=False)
- num_children: int = field(init=False)
- def __post_init__(self, /) -> None:
- if self._type is None:
- assert len(self._children) == 0
- assert self._metadata is None
- assert self._entries == ()
- assert self._unflatten_func is None
- num_nodes = 1
- num_leaves = 1
- num_children = 0
- else:
- assert callable(self._unflatten_func)
- num_nodes = 1
- num_leaves = 0
- for child in self._children:
- num_nodes += child.num_nodes
- num_leaves += child.num_leaves
- num_children = len(self._children)
- object.__setattr__(self, "num_nodes", num_nodes)
- object.__setattr__(self, "num_leaves", num_leaves)
- object.__setattr__(self, "num_children", num_children)
- def __repr__(self, /) -> str:
- def helper(treespec: PyTreeSpec) -> str:
- if treespec.is_leaf():
- assert treespec.type is None
- return _asterisk
- assert treespec.type is not None
- assert callable(treespec._unflatten_func)
- children_representations = [
- helper(subspec) for subspec in treespec._children
- ]
- if (
- treespec.type in BUILTIN_TYPES
- or (treespec.type is type(None) and not self.none_is_leaf)
- or optree.is_namedtuple_class(treespec.type)
- or optree.is_structseq_class(treespec.type)
- ):
- # pyrefly: ignore [bad-return]
- return treespec._unflatten_func(
- treespec._metadata,
- children_representations,
- )
- return (
- f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], "
- f"[{', '.join(children_representations)}])"
- )
- inner = [
- str(helper(self)),
- *(["NoneIsLeaf"] if self.none_is_leaf else []),
- f"namespace={self.namespace!r}",
- ]
- return f"PyTreeSpec({', '.join(inner)})"
- def __len__(self, /) -> int:
- return self.num_leaves
- @property
- def type(self, /) -> builtins.type | None:
- return self._type
- def is_leaf(self, /) -> bool:
- return self.num_nodes == 1 and self.num_leaves == 1
- def paths(self, /) -> list[tuple[Any, ...]]:
- def helper(treespec: PyTreeSpec, path_prefix: list[Any]) -> None:
- if treespec.is_leaf():
- paths.append(path_prefix)
- return
- for entry, subspec in zip(
- treespec._entries,
- treespec._children,
- strict=True,
- ):
- helper(subspec, path_prefix + [entry])
- paths: list[list[Any]] = []
- helper(self, [])
- return [tuple(path) for path in paths]
- def accessors(self, /) -> list[optree.PyTreeAccessor]:
- def helper(
- treespec: PyTreeSpec,
- entry_path_prefix: list[optree.PyTreeEntry],
- ) -> None:
- if treespec.is_leaf():
- entry_paths.append(entry_path_prefix)
- return
- node_type = treespec.type
- assert node_type is not None
- handler = optree.register_pytree_node.get(
- node_type, namespace=treespec.namespace
- )
- assert handler is not None
- kind: optree.PyTreeKind = handler.kind
- path_entry_type: type[optree.PyTreeEntry] = handler.path_entry_type
- for entry, subspec in zip(
- treespec._entries,
- treespec._children,
- strict=True,
- ):
- helper(
- subspec,
- entry_path_prefix + [path_entry_type(entry, node_type, kind)],
- )
- entry_paths: list[list[optree.PyTreeEntry]] = []
- helper(self, [])
- return [optree.PyTreeAccessor(path) for path in entry_paths]
- def children(self, /) -> list[PyTreeSpec]:
- return list(self._children)
- def child(self, index: int, /) -> PyTreeSpec:
- return self._children[index]
- def entries(self, /) -> list[Any]:
- return list(self._entries)
- def entry(self, index: int, /) -> Any:
- return self._entries[index]
- def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]:
- def helper(
- treespec: PyTreeSpec,
- node: PyTree,
- subtrees: list[PyTree],
- ) -> None:
- if treespec.is_leaf():
- subtrees.append(node)
- return
- node_type = type(node)
- if treespec.type not in BUILTIN_TYPES:
- # Always require custom node types to match exactly
- if node_type != treespec.type:
- raise ValueError(
- f"Type mismatch; "
- f"expected {treespec.type!r}, but got {node_type!r}.",
- )
- children, metadata, *_ = optree.tree_flatten_one_level(
- node,
- none_is_leaf=self.none_is_leaf,
- namespace=self.namespace,
- )
- if len(children) != treespec.num_children:
- raise ValueError(
- f"Node arity mismatch; "
- f"expected {treespec.num_children}, but got {len(children)}.",
- )
- if metadata != treespec._metadata:
- raise ValueError(
- f"Node context mismatch for custom node type {treespec.type!r}.",
- )
- else:
- # For builtin dictionary types, we allow some flexibility
- # Otherwise, we require exact matches
- both_standard_dict = (
- treespec.type in STANDARD_DICT_TYPES
- and node_type in STANDARD_DICT_TYPES
- )
- if not both_standard_dict and node_type != treespec.type:
- raise ValueError(
- f"Node type mismatch; "
- f"expected {treespec.type!r}, but got {node_type!r}.",
- )
- if len(node) != treespec.num_children:
- raise ValueError(
- f"Node arity mismatch; "
- f"expected {treespec.num_children}, but got {len(node)}.",
- )
- if both_standard_dict:
- # dictionary types are compatible with each other
- expected_keys = treespec.entries()
- got_key_set = set(node)
- expected_key_set = set(expected_keys)
- if got_key_set != expected_key_set:
- missing_keys = expected_key_set.difference(got_key_set)
- extra_keys = got_key_set.difference(expected_key_set)
- message = ""
- if missing_keys:
- message += f"; missing key(s): {missing_keys}"
- if extra_keys:
- message += f"; extra key(s): {extra_keys}"
- raise ValueError(f"Node keys mismatch{message}.")
- children = [node[key] for key in expected_keys]
- else:
- # node_type is treespec.type
- children, metadata, *_ = optree.tree_flatten_one_level(
- node,
- none_is_leaf=self.none_is_leaf,
- namespace=self.namespace,
- )
- if (
- node_type is not deque # ignore mismatch of `maxlen` for deque
- ) and metadata != treespec._metadata:
- raise ValueError(
- f"Node metadata mismatch for node type {treespec.type!r}; "
- f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch
- )
- for subtree, subspec in zip(children, treespec._children, strict=True):
- helper(subspec, subtree, subtrees)
- subtrees: list[PyTree] = []
- helper(self, tree, subtrees)
- return subtrees
- def unflatten(self, leaves: Iterable[Any], /) -> PyTree:
- if not isinstance(leaves, (list, tuple)):
- leaves = list(leaves)
- if len(leaves) != self.num_leaves:
- raise ValueError(
- f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
- f"but the spec refers to a pytree that holds {self.num_leaves} "
- f"items ({self}).",
- )
- if self.is_leaf():
- return leaves[0]
- # Recursively unflatten the children
- start = 0
- end = 0
- subtrees = []
- for subspec in self._children:
- end += subspec.num_leaves
- subtrees.append(subspec.unflatten(leaves[start:end]))
- start = end
- assert callable(self._unflatten_func)
- return self._unflatten_func(self._metadata, subtrees)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec | python_pytree.TreeSpec]:
    """Return whether ``obj`` is a ``PyTreeSpec`` or a ``python_pytree.TreeSpec``."""
    accepted_spec_types = (PyTreeSpec, python_pytree.TreeSpec)
    return isinstance(obj, accepted_spec_types)
@substitute_in_graph(  # type: ignore[arg-type]
    optree.treespec_leaf,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def treespec_leaf(
    *,
    none_is_leaf: bool = False,
    namespace: str = "",  # unused
) -> PyTreeSpec:
    """Create a leaf ``PyTreeSpec``.

    Args:
        none_is_leaf: Stored on the returned spec.
        namespace: Accepted but unused; the returned spec always has
            ``namespace=""``.
    """
    return PyTreeSpec(
        _children=(),
        _type=None,
        _metadata=None,
        _entries=(),
        _unflatten_func=None,
        none_is_leaf=none_is_leaf,
        namespace="",
    )
@substitute_in_graph(  # type: ignore[arg-type]
    optree.treespec_tuple,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def treespec_tuple(
    iterable: Iterable[PyTreeSpec] = (),
    /,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> PyTreeSpec:
    """Create a tuple-node ``PyTreeSpec`` from child specs.

    Args:
        iterable: Child specs; consumed once into a tuple.
        none_is_leaf: Must equal every child's ``none_is_leaf``.
        namespace: Every child's ``namespace`` must be this value or ``""``.

    Raises:
        ValueError: If a child is not a treespec, or its ``none_is_leaf`` or
            ``namespace`` is incompatible with the parent's.
    """
    children = tuple(iterable)

    # Each rule is checked over all children before the next rule is
    # considered, so the first violated rule decides which error is raised.
    for child in children:
        if not _is_pytreespec_instance(child):
            raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.")
    for child in children:
        if child.none_is_leaf != none_is_leaf:
            raise ValueError(
                "All children PyTreeSpecs must have the same `none_is_leaf` value "
                f"as the parent; expected {none_is_leaf}, got: {children!r}.",
            )
    for child in children:
        if child.namespace not in (namespace, ""):
            raise ValueError(
                "All children PyTreeSpecs must have the same `namespace` value "
                f"as the parent; expected {namespace!r}, got: {children!r}.",
            )

    tuple_handler = optree.register_pytree_node.get(tuple, namespace=namespace)
    assert tuple_handler is not None
    # Entries are the positional indices of the children.
    indices = tuple(range(len(children)))
    return PyTreeSpec(
        children,
        tuple,
        None,
        indices,
        tuple_handler.unflatten_func,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
@substitute_in_graph(  # type: ignore[arg-type]
    optree.treespec_dict,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def treespec_dict(
    mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
    /,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
    **kwargs: PyTreeSpec,
) -> PyTreeSpec:
    """Create a dict-node ``PyTreeSpec`` from a mapping of child specs.

    Args:
        mapping: Mapping or iterable of ``(key, spec)`` pairs.
        none_is_leaf: Must equal every child's ``none_is_leaf``.
        namespace: Every child's ``namespace`` must be this value or ``""``.
        **kwargs: Additional ``key=spec`` items merged into ``mapping``.

    Raises:
        ValueError: If a value is not a treespec, or its ``none_is_leaf`` or
            ``namespace`` is incompatible with the parent's.
    """
    dct = dict(mapping, **kwargs)

    # Each rule is checked over all values before the next rule is
    # considered, so the first violated rule decides which error is raised.
    for child in dct.values():
        if not _is_pytreespec_instance(child):
            raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.")
    for child in dct.values():
        if child.none_is_leaf != none_is_leaf:
            raise ValueError(
                "All children PyTreeSpecs must have the same `none_is_leaf` value "
                f"as the parent; expected {none_is_leaf}, got: {dct!r}.",
            )
    for child in dct.values():
        if child.namespace not in (namespace, ""):
            raise ValueError(
                "All children PyTreeSpecs must have the same `namespace` value "
                f"as the parent; expected {namespace!r}, got: {dct!r}.",
            )

    # Flatten the dict of specs one level to obtain the children, metadata,
    # entries and unflatten function for the dict node.
    one_level = optree.tree_flatten_one_level(  # type: ignore[assignment,var-annotated]
        dct,  # type: ignore[arg-type]
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    children, metadata, entries, unflatten_func = one_level
    return PyTreeSpec(
        tuple(children),  # type: ignore[arg-type]
        dict,
        metadata,
        entries,
        unflatten_func,  # type: ignore[arg-type]
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
@substitute_in_graph(  # type: ignore[arg-type]
    optree.tree_flatten,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_flatten(
    tree: PyTree,
    /,
    is_leaf: Callable[[PyTree], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> tuple[list[Any], PyTreeSpec]:
    """Flatten ``tree`` into its leaves and a ``PyTreeSpec`` describing it.

    Args:
        tree: The pytree to flatten.
        is_leaf: Optional predicate forwarded to the leaf test and to the
            one-level flatten call.
        none_is_leaf: Forwarded to both calls and stored on every spec.
        namespace: Forwarded to both calls and stored on every spec.

    Returns:
        ``(leaves, treespec)`` with the leaves in depth-first, left-to-right
        order.
    """
    leaves: list[Any] = []

    def build_spec(node: PyTree) -> PyTreeSpec:
        # Leaf: record the value and return a leaf spec.
        if tree_is_leaf(
            node,
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        ):
            leaves.append(node)
            return PyTreeSpec(
                (),
                None,
                None,
                (),
                None,
                none_is_leaf=none_is_leaf,
                namespace=namespace,
            )

        one_level = optree.tree_flatten_one_level(
            node,
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )
        children, metadata, entries, unflatten_func = one_level

        # Recursively flatten the children
        child_specs = []
        for child in children:
            child_specs.append(build_spec(child))
        return PyTreeSpec(
            tuple(child_specs),
            type(node),
            metadata,
            entries,
            unflatten_func,  # type: ignore[arg-type]
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )  # type: ignore[arg-type]

    treespec = build_spec(tree)
    return leaves, treespec
@substitute_in_graph(  # type: ignore[arg-type]
    optree._C.flatten,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def _C_flatten(
    tree: PyTree,
    /,
    leaf_predicate: Callable[[PyTree], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> tuple[list[Any], PyTreeSpec]:
    """Substitute for ``optree._C.flatten``; delegates to :func:`tree_flatten`.

    ``leaf_predicate`` is passed on as ``is_leaf``; the other arguments are
    forwarded unchanged.
    """
    leaves_and_spec = tree_flatten(  # type: ignore[return-value]
        tree,
        is_leaf=leaf_predicate,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    return leaves_and_spec
@substitute_in_graph(  # type: ignore[arg-type]
    optree.tree_flatten_with_path,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_flatten_with_path(
    tree: PyTree,
    /,
    is_leaf: Callable[[PyTree], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]:
    """Flatten ``tree`` and also return the path to every leaf.

    Returns:
        ``(paths, leaves, treespec)`` where ``leaves`` and ``treespec`` come
        from :func:`tree_flatten` and ``paths`` is ``treespec.paths()``.
    """
    flattened = tree_flatten(
        tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    leaves, treespec = flattened
    leaf_paths = treespec.paths()
    return leaf_paths, leaves, treespec  # type: ignore[return-value]
@substitute_in_graph(  # type: ignore[arg-type]
    optree._C.flatten_with_path,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def _C_flatten_with_path(
    tree: PyTree,
    /,
    leaf_predicate: Callable[[PyTree], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]:
    """Substitute for ``optree._C.flatten_with_path``.

    Delegates to :func:`tree_flatten_with_path`, passing ``leaf_predicate`` as
    ``is_leaf`` and forwarding the other arguments unchanged.
    """
    paths_leaves_spec = tree_flatten_with_path(  # type: ignore[return-value]
        tree,
        is_leaf=leaf_predicate,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    return paths_leaves_spec
@substitute_in_graph(  # type: ignore[arg-type]
    optree.tree_structure,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_structure(
    tree: PyTree,
    /,
    is_leaf: Callable[[PyTree], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> PyTreeSpec:
    """Return only the ``PyTreeSpec`` produced by :func:`tree_flatten` for ``tree``."""
    _leaves, treespec = tree_flatten(  # type: ignore[return-value]
        tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    return treespec
@substitute_in_graph(  # type: ignore[arg-type]
    optree.tree_unflatten,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
    """Rebuild a pytree from ``treespec`` and ``leaves``.

    Raises:
        TypeError: If ``treespec`` is not a treespec instance.
    """
    if _is_pytreespec_instance(treespec):
        return treespec.unflatten(leaves)
    raise TypeError(
        f"Expected `treespec` to be an instance of "
        f"PyTreeSpec but got item of type {type(treespec)}."
    )
# Handler registered for ``NoneType``; its unflatten function is substituted below.
_none_registration = optree.register_pytree_node.get(type(None))
assert _none_registration is not None


@substitute_in_graph(  # type: ignore[arg-type]
    _none_registration.unflatten_func,
    can_constant_fold_through=True,
    skip_signature_check=True,
)
def none_unflatten(_: None, children: Iterable[_T], /) -> None:
    """Unflatten a ``None`` node.

    Args:
        _: Metadata; ignored.
        children: Must be empty.

    Raises:
        ValueError: If ``children`` contains any item.
    """
    leftover = list(children)
    if leftover:
        raise ValueError("Expected no children.")
    return None
# The dict handler is looked up, and its flatten/unflatten functions are
# substituted, while insertion-ordered traversal is switched off for the
# "torch" namespace.
with optree.dict_insertion_ordered(False, namespace="torch"):
    _dict_registration = optree.register_pytree_node.get(dict)
    assert _dict_registration is not None

    @substitute_in_graph(  # type: ignore[arg-type]
        _dict_registration.flatten_func,
        can_constant_fold_through=True,
        skip_signature_check=True,
    )
    def dict_flatten(
        dct: dict[_KT, _VT], /
    ) -> tuple[list[_VT], tuple[list[_KT], list[_KT]], tuple[_KT, ...]]:
        """Flatten ``dct`` with its values in sorted-key order.

        Returns:
            ``(values, (original_keys, sorted_keys), entries)`` where
            ``values`` follows ``sorted_keys``, ``original_keys`` preserves
            the iteration order of ``dct``, and ``entries`` is
            ``tuple(sorted_keys)``.
        """
        keys_in_dict_order = list(dct)
        keys_sorted = optree.utils.total_order_sorted(dct)
        values_sorted = [dct[k] for k in keys_sorted]
        metadata = (keys_in_dict_order, keys_sorted)
        return values_sorted, metadata, tuple(keys_sorted)

    @substitute_in_graph(  # type: ignore[arg-type]
        _dict_registration.unflatten_func,
        can_constant_fold_through=True,
        skip_signature_check=True,
    )
    def dict_unflatten(
        metadata: tuple[list[_KT], list[_KT]],
        values: Iterable[_VT],
        /,
    ) -> dict[_KT, _VT]:
        """Rebuild a dict from ``(original_keys, sorted_keys)`` and ``values``.

        The result keeps the order of ``original_keys``; ``values`` are
        paired with ``sorted_keys``.

        Raises:
            ValueError: If ``values`` and ``sorted_keys`` differ in length
                (raised by ``zip(..., strict=True)``).
        """
        original_keys, sorted_keys = metadata
        # Seed every key first so the original key order is fixed.
        rebuilt = dict.fromkeys(original_keys)
        for key, value in zip(sorted_keys, values, strict=True):
            rebuilt[key] = value
        return rebuilt  # type: ignore[return-value]
|