pytree.py 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758
  1. """
  2. Python polyfills for torch.utils.pytree
  3. """
  4. from __future__ import annotations
  5. from collections import deque
  6. from dataclasses import dataclass, field
  7. from typing import Any, TYPE_CHECKING, TypeVar
  8. import optree
  9. import optree._C
  10. import optree.utils
  11. from optree import (
  12. is_namedtuple,
  13. is_namedtuple_class,
  14. is_namedtuple_instance,
  15. is_structseq,
  16. is_structseq_class,
  17. is_structseq_instance,
  18. namedtuple_fields,
  19. structseq_fields,
  20. )
  21. import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
  22. import torch.utils._pytree as python_pytree
  23. from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
  24. from ..decorators import substitute_in_graph
  25. if TYPE_CHECKING:
  26. import builtins
  27. from collections.abc import Callable, Iterable, Mapping
  28. from typing_extensions import Self, TypeIs
  29. from torch.utils._cxx_pytree import PyTree
# Public names of this polyfill module. The namedtuple/structseq helpers listed
# first are rebound further down in this module via `globals()`.
__all__ = [
    "is_namedtuple",
    "is_namedtuple_class",
    "is_namedtuple_instance",
    "is_structseq",
    "is_structseq_class",
    "is_structseq_instance",
    "namedtuple_fields",
    "structseq_fields",
    "treespec_leaf",
    "treespec_tuple",
    "treespec_dict",
    "tree_is_leaf",
    "tree_iter",
    "tree_leaves",
    "tree_flatten",
    "tree_flatten_with_path",
    "tree_structure",
    "tree_unflatten",
]

# Generic type variables used in annotations in this module.
_T = TypeVar("_T")  # element type (used by `none_unflatten`)
_KT = TypeVar("_KT")  # dict key type (used by the dict flatten/unflatten polyfills)
_VT = TypeVar("_VT")  # dict value type (used by the dict flatten/unflatten polyfills)
  53. @substitute_in_graph(
  54. optree._C.is_dict_insertion_ordered,
  55. can_constant_fold_through=True,
  56. )
  57. def _(*args: Any, **kwargs: Any) -> bool:
  58. # In namespace 'torch', the dictionary is always traversed in insertion order.
  59. # This function returns True.
  60. raise ValueError(
  61. "Should not be called directly "
  62. "because the original function will be called in the constant fold path."
  63. )
  64. __name = ""
  65. for __name, __func in (
  66. ("is_namedtuple", is_namedtuple),
  67. ("is_namedtuple_class", is_namedtuple_class),
  68. ("is_namedtuple_instance", is_namedtuple_instance),
  69. ("is_structseq", is_structseq),
  70. ("is_structseq_class", is_structseq_class),
  71. ("is_structseq_instance", is_structseq_instance),
  72. ("namedtuple_fields", namedtuple_fields),
  73. ("structseq_fields", structseq_fields),
  74. ):
  75. globals()[__name] = substitute_in_graph(
  76. __func, # type: ignore[arg-type]
  77. can_constant_fold_through=True,
  78. )(__func.__python_implementation__) # type: ignore[attr-defined]
  79. del __func
  80. del __name
  81. @substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True) # type: ignore[arg-type]
  82. def tree_is_leaf(
  83. tree: PyTree,
  84. /,
  85. is_leaf: Callable[[PyTree], bool] | None = None,
  86. *,
  87. none_is_leaf: bool = False,
  88. namespace: str = "",
  89. ) -> bool:
  90. if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)):
  91. return True
  92. if optree.register_pytree_node.get(type(tree), namespace=namespace) is None:
  93. return True
  94. return False
  95. @substitute_in_graph(optree.tree_iter, can_constant_fold_through=False) # type: ignore[arg-type]
  96. def tree_iter(
  97. tree: PyTree,
  98. /,
  99. is_leaf: Callable[[PyTree], bool] | None = None,
  100. *,
  101. none_is_leaf: bool = False,
  102. namespace: str = "",
  103. ) -> Iterable[Any]:
  104. stack = [tree]
  105. while stack:
  106. node = stack.pop()
  107. if tree_is_leaf(
  108. node,
  109. is_leaf=is_leaf,
  110. none_is_leaf=none_is_leaf,
  111. namespace=namespace,
  112. ):
  113. yield node
  114. continue
  115. children, *_ = optree.tree_flatten_one_level(
  116. node,
  117. is_leaf=is_leaf,
  118. none_is_leaf=none_is_leaf,
  119. namespace=namespace,
  120. )
  121. stack.extend(reversed(children))
  122. @substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) # type: ignore[arg-type]
  123. def tree_leaves(
  124. tree: PyTree,
  125. /,
  126. is_leaf: Callable[[PyTree], bool] | None = None,
  127. *,
  128. none_is_leaf: bool = False,
  129. namespace: str = "",
  130. ) -> list[Any]:
  131. return list(
  132. tree_iter(
  133. tree,
  134. is_leaf=is_leaf,
  135. none_is_leaf=none_is_leaf,
  136. namespace=namespace,
  137. )
  138. )
  139. class _Asterisk(str):
  140. __slots__ = ()
  141. def __new__(cls) -> Self:
  142. return super().__new__(cls, "*")
  143. def __repr__(self) -> str:
  144. return "*" # no quotes
  145. _asterisk = _Asterisk()
  146. del _Asterisk
  147. @dataclass(frozen=True, slots=True)
  148. class PyTreeSpec:
  149. """Analog for :class:`optree.PyTreeSpec` in Python."""
  150. _children: tuple[PyTreeSpec, ...]
  151. _type: builtins.type | None
  152. _metadata: Any
  153. _entries: tuple[Any, ...]
  154. _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
  155. none_is_leaf: bool
  156. namespace: str
  157. num_nodes: int = field(init=False)
  158. num_leaves: int = field(init=False)
  159. num_children: int = field(init=False)
  160. def __post_init__(self, /) -> None:
  161. if self._type is None:
  162. assert len(self._children) == 0
  163. assert self._metadata is None
  164. assert self._entries == ()
  165. assert self._unflatten_func is None
  166. num_nodes = 1
  167. num_leaves = 1
  168. num_children = 0
  169. else:
  170. assert callable(self._unflatten_func)
  171. num_nodes = 1
  172. num_leaves = 0
  173. for child in self._children:
  174. num_nodes += child.num_nodes
  175. num_leaves += child.num_leaves
  176. num_children = len(self._children)
  177. object.__setattr__(self, "num_nodes", num_nodes)
  178. object.__setattr__(self, "num_leaves", num_leaves)
  179. object.__setattr__(self, "num_children", num_children)
  180. def __repr__(self, /) -> str:
  181. def helper(treespec: PyTreeSpec) -> str:
  182. if treespec.is_leaf():
  183. assert treespec.type is None
  184. return _asterisk
  185. assert treespec.type is not None
  186. assert callable(treespec._unflatten_func)
  187. children_representations = [
  188. helper(subspec) for subspec in treespec._children
  189. ]
  190. if (
  191. treespec.type in BUILTIN_TYPES
  192. or (treespec.type is type(None) and not self.none_is_leaf)
  193. or optree.is_namedtuple_class(treespec.type)
  194. or optree.is_structseq_class(treespec.type)
  195. ):
  196. # pyrefly: ignore [bad-return]
  197. return treespec._unflatten_func(
  198. treespec._metadata,
  199. children_representations,
  200. )
  201. return (
  202. f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], "
  203. f"[{', '.join(children_representations)}])"
  204. )
  205. inner = [
  206. str(helper(self)),
  207. *(["NoneIsLeaf"] if self.none_is_leaf else []),
  208. f"namespace={self.namespace!r}",
  209. ]
  210. return f"PyTreeSpec({', '.join(inner)})"
  211. def __len__(self, /) -> int:
  212. return self.num_leaves
  213. @property
  214. def type(self, /) -> builtins.type | None:
  215. return self._type
  216. def is_leaf(self, /) -> bool:
  217. return self.num_nodes == 1 and self.num_leaves == 1
  218. def paths(self, /) -> list[tuple[Any, ...]]:
  219. def helper(treespec: PyTreeSpec, path_prefix: list[Any]) -> None:
  220. if treespec.is_leaf():
  221. paths.append(path_prefix)
  222. return
  223. for entry, subspec in zip(
  224. treespec._entries,
  225. treespec._children,
  226. strict=True,
  227. ):
  228. helper(subspec, path_prefix + [entry])
  229. paths: list[list[Any]] = []
  230. helper(self, [])
  231. return [tuple(path) for path in paths]
  232. def accessors(self, /) -> list[optree.PyTreeAccessor]:
  233. def helper(
  234. treespec: PyTreeSpec,
  235. entry_path_prefix: list[optree.PyTreeEntry],
  236. ) -> None:
  237. if treespec.is_leaf():
  238. entry_paths.append(entry_path_prefix)
  239. return
  240. node_type = treespec.type
  241. assert node_type is not None
  242. handler = optree.register_pytree_node.get(
  243. node_type, namespace=treespec.namespace
  244. )
  245. assert handler is not None
  246. kind: optree.PyTreeKind = handler.kind
  247. path_entry_type: type[optree.PyTreeEntry] = handler.path_entry_type
  248. for entry, subspec in zip(
  249. treespec._entries,
  250. treespec._children,
  251. strict=True,
  252. ):
  253. helper(
  254. subspec,
  255. entry_path_prefix + [path_entry_type(entry, node_type, kind)],
  256. )
  257. entry_paths: list[list[optree.PyTreeEntry]] = []
  258. helper(self, [])
  259. return [optree.PyTreeAccessor(path) for path in entry_paths]
  260. def children(self, /) -> list[PyTreeSpec]:
  261. return list(self._children)
  262. def child(self, index: int, /) -> PyTreeSpec:
  263. return self._children[index]
  264. def entries(self, /) -> list[Any]:
  265. return list(self._entries)
  266. def entry(self, index: int, /) -> Any:
  267. return self._entries[index]
  268. def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]:
  269. def helper(
  270. treespec: PyTreeSpec,
  271. node: PyTree,
  272. subtrees: list[PyTree],
  273. ) -> None:
  274. if treespec.is_leaf():
  275. subtrees.append(node)
  276. return
  277. node_type = type(node)
  278. if treespec.type not in BUILTIN_TYPES:
  279. # Always require custom node types to match exactly
  280. if node_type != treespec.type:
  281. raise ValueError(
  282. f"Type mismatch; "
  283. f"expected {treespec.type!r}, but got {node_type!r}.",
  284. )
  285. children, metadata, *_ = optree.tree_flatten_one_level(
  286. node,
  287. none_is_leaf=self.none_is_leaf,
  288. namespace=self.namespace,
  289. )
  290. if len(children) != treespec.num_children:
  291. raise ValueError(
  292. f"Node arity mismatch; "
  293. f"expected {treespec.num_children}, but got {len(children)}.",
  294. )
  295. if metadata != treespec._metadata:
  296. raise ValueError(
  297. f"Node context mismatch for custom node type {treespec.type!r}.",
  298. )
  299. else:
  300. # For builtin dictionary types, we allow some flexibility
  301. # Otherwise, we require exact matches
  302. both_standard_dict = (
  303. treespec.type in STANDARD_DICT_TYPES
  304. and node_type in STANDARD_DICT_TYPES
  305. )
  306. if not both_standard_dict and node_type != treespec.type:
  307. raise ValueError(
  308. f"Node type mismatch; "
  309. f"expected {treespec.type!r}, but got {node_type!r}.",
  310. )
  311. if len(node) != treespec.num_children:
  312. raise ValueError(
  313. f"Node arity mismatch; "
  314. f"expected {treespec.num_children}, but got {len(node)}.",
  315. )
  316. if both_standard_dict:
  317. # dictionary types are compatible with each other
  318. expected_keys = treespec.entries()
  319. got_key_set = set(node)
  320. expected_key_set = set(expected_keys)
  321. if got_key_set != expected_key_set:
  322. missing_keys = expected_key_set.difference(got_key_set)
  323. extra_keys = got_key_set.difference(expected_key_set)
  324. message = ""
  325. if missing_keys:
  326. message += f"; missing key(s): {missing_keys}"
  327. if extra_keys:
  328. message += f"; extra key(s): {extra_keys}"
  329. raise ValueError(f"Node keys mismatch{message}.")
  330. children = [node[key] for key in expected_keys]
  331. else:
  332. # node_type is treespec.type
  333. children, metadata, *_ = optree.tree_flatten_one_level(
  334. node,
  335. none_is_leaf=self.none_is_leaf,
  336. namespace=self.namespace,
  337. )
  338. if (
  339. node_type is not deque # ignore mismatch of `maxlen` for deque
  340. ) and metadata != treespec._metadata:
  341. raise ValueError(
  342. f"Node metadata mismatch for node type {treespec.type!r}; "
  343. f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch
  344. )
  345. for subtree, subspec in zip(children, treespec._children, strict=True):
  346. helper(subspec, subtree, subtrees)
  347. subtrees: list[PyTree] = []
  348. helper(self, tree, subtrees)
  349. return subtrees
  350. def unflatten(self, leaves: Iterable[Any], /) -> PyTree:
  351. if not isinstance(leaves, (list, tuple)):
  352. leaves = list(leaves)
  353. if len(leaves) != self.num_leaves:
  354. raise ValueError(
  355. f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
  356. f"but the spec refers to a pytree that holds {self.num_leaves} "
  357. f"items ({self}).",
  358. )
  359. if self.is_leaf():
  360. return leaves[0]
  361. # Recursively unflatten the children
  362. start = 0
  363. end = 0
  364. subtrees = []
  365. for subspec in self._children:
  366. end += subspec.num_leaves
  367. subtrees.append(subspec.unflatten(leaves[start:end]))
  368. start = end
  369. assert callable(self._unflatten_func)
  370. return self._unflatten_func(self._metadata, subtrees)
  371. def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec | python_pytree.TreeSpec]:
  372. return isinstance(obj, (PyTreeSpec, python_pytree.TreeSpec))
  373. @substitute_in_graph( # type: ignore[arg-type]
  374. optree.treespec_leaf,
  375. # We need to disable constant folding here because we want the function to reference the
  376. # PyTreeSpec class defined above, not the one in the C++ module.
  377. can_constant_fold_through=False,
  378. )
  379. def treespec_leaf(
  380. *,
  381. none_is_leaf: bool = False,
  382. namespace: str = "", # unused
  383. ) -> PyTreeSpec:
  384. return PyTreeSpec(
  385. (),
  386. None,
  387. None,
  388. (),
  389. None,
  390. none_is_leaf=none_is_leaf,
  391. namespace="",
  392. )
  393. @substitute_in_graph( # type: ignore[arg-type]
  394. optree.treespec_tuple,
  395. # We need to disable constant folding here because we want the function to reference the
  396. # PyTreeSpec class defined above, not the one in the C++ module.
  397. can_constant_fold_through=False,
  398. )
  399. def treespec_tuple(
  400. iterable: Iterable[PyTreeSpec] = (),
  401. /,
  402. *,
  403. none_is_leaf: bool = False,
  404. namespace: str = "",
  405. ) -> PyTreeSpec:
  406. children = tuple(iterable)
  407. if any(not _is_pytreespec_instance(child) for child in children):
  408. raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.")
  409. if any(child.none_is_leaf != none_is_leaf for child in children):
  410. raise ValueError(
  411. "All children PyTreeSpecs must have the same `none_is_leaf` value "
  412. f"as the parent; expected {none_is_leaf}, got: {children!r}.",
  413. )
  414. if any(child.namespace not in (namespace, "") for child in children):
  415. raise ValueError(
  416. "All children PyTreeSpecs must have the same `namespace` value "
  417. f"as the parent; expected {namespace!r}, got: {children!r}.",
  418. )
  419. handler = optree.register_pytree_node.get(tuple, namespace=namespace)
  420. assert handler is not None
  421. return PyTreeSpec(
  422. tuple(children),
  423. tuple,
  424. None,
  425. tuple(range(len(children))),
  426. handler.unflatten_func,
  427. none_is_leaf=none_is_leaf,
  428. namespace=namespace,
  429. )
  430. @substitute_in_graph( # type: ignore[arg-type]
  431. optree.treespec_dict,
  432. # We need to disable constant folding here because we want the function to reference the
  433. # PyTreeSpec class defined above, not the one in the C++ module.
  434. can_constant_fold_through=False,
  435. )
  436. def treespec_dict(
  437. mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
  438. /,
  439. *,
  440. none_is_leaf: bool = False,
  441. namespace: str = "",
  442. **kwargs: PyTreeSpec,
  443. ) -> PyTreeSpec:
  444. dct = dict(mapping, **kwargs)
  445. if any(not _is_pytreespec_instance(child) for child in dct.values()):
  446. raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.")
  447. if any(child.none_is_leaf != none_is_leaf for child in dct.values()):
  448. raise ValueError(
  449. "All children PyTreeSpecs must have the same `none_is_leaf` value "
  450. f"as the parent; expected {none_is_leaf}, got: {dct!r}.",
  451. )
  452. if any(child.namespace not in (namespace, "") for child in dct.values()):
  453. raise ValueError(
  454. "All children PyTreeSpecs must have the same `namespace` value "
  455. f"as the parent; expected {namespace!r}, got: {dct!r}.",
  456. )
  457. (
  458. children,
  459. metadata,
  460. entries,
  461. unflatten_func,
  462. ) = optree.tree_flatten_one_level( # type: ignore[assignment,var-annotated]
  463. dct, # type: ignore[arg-type]
  464. none_is_leaf=none_is_leaf,
  465. namespace=namespace,
  466. )
  467. return PyTreeSpec(
  468. tuple(children), # type: ignore[arg-type]
  469. dict,
  470. metadata,
  471. entries,
  472. unflatten_func, # type: ignore[arg-type]
  473. none_is_leaf=none_is_leaf,
  474. namespace=namespace,
  475. )
  476. @substitute_in_graph( # type: ignore[arg-type]
  477. optree.tree_flatten,
  478. # We need to disable constant folding here because we want the function to reference the
  479. # PyTreeSpec class defined above, not the one in the C++ module.
  480. can_constant_fold_through=False,
  481. )
  482. def tree_flatten(
  483. tree: PyTree,
  484. /,
  485. is_leaf: Callable[[PyTree], bool] | None = None,
  486. *,
  487. none_is_leaf: bool = False,
  488. namespace: str = "",
  489. ) -> tuple[list[Any], PyTreeSpec]:
  490. def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
  491. if tree_is_leaf(
  492. node,
  493. is_leaf=is_leaf,
  494. none_is_leaf=none_is_leaf,
  495. namespace=namespace,
  496. ):
  497. leaves.append(node)
  498. return PyTreeSpec(
  499. (),
  500. None,
  501. None,
  502. (),
  503. None,
  504. none_is_leaf=none_is_leaf,
  505. namespace=namespace,
  506. )
  507. (
  508. children,
  509. metadata,
  510. entries,
  511. unflatten_func,
  512. ) = optree.tree_flatten_one_level(
  513. node,
  514. is_leaf=is_leaf,
  515. none_is_leaf=none_is_leaf,
  516. namespace=namespace,
  517. )
  518. # Recursively flatten the children
  519. subspecs = tuple(helper(child, leaves) for child in children)
  520. return PyTreeSpec(
  521. subspecs,
  522. type(node),
  523. metadata,
  524. entries,
  525. unflatten_func, # type: ignore[arg-type]
  526. none_is_leaf=none_is_leaf,
  527. namespace=namespace,
  528. ) # type: ignore[arg-type]
  529. leaves: list[Any] = []
  530. treespec = helper(tree, leaves)
  531. return leaves, treespec
  532. @substitute_in_graph( # type: ignore[arg-type]
  533. optree._C.flatten,
  534. # We need to disable constant folding here because we want the function to reference the
  535. # PyTreeSpec class defined above, not the one in the C++ module.
  536. can_constant_fold_through=False,
  537. )
  538. def _C_flatten(
  539. tree: PyTree,
  540. /,
  541. leaf_predicate: Callable[[PyTree], bool] | None = None,
  542. none_is_leaf: bool = False,
  543. namespace: str = "",
  544. ) -> tuple[list[Any], PyTreeSpec]:
  545. return tree_flatten( # type: ignore[return-value]
  546. tree,
  547. is_leaf=leaf_predicate,
  548. none_is_leaf=none_is_leaf,
  549. namespace=namespace,
  550. )
  551. @substitute_in_graph( # type: ignore[arg-type]
  552. optree.tree_flatten_with_path,
  553. # We need to disable constant folding here because we want the function to reference the
  554. # PyTreeSpec class defined above, not the one in the C++ module.
  555. can_constant_fold_through=False,
  556. )
  557. def tree_flatten_with_path(
  558. tree: PyTree,
  559. /,
  560. is_leaf: Callable[[PyTree], bool] | None = None,
  561. *,
  562. none_is_leaf: bool = False,
  563. namespace: str = "",
  564. ) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]:
  565. leaves, treespec = tree_flatten(
  566. tree,
  567. is_leaf=is_leaf,
  568. none_is_leaf=none_is_leaf,
  569. namespace=namespace,
  570. )
  571. return treespec.paths(), leaves, treespec # type: ignore[return-value]
  572. @substitute_in_graph( # type: ignore[arg-type]
  573. optree._C.flatten_with_path,
  574. # We need to disable constant folding here because we want the function to reference the
  575. # PyTreeSpec class defined above, not the one in the C++ module.
  576. can_constant_fold_through=False,
  577. )
  578. def _C_flatten_with_path(
  579. tree: PyTree,
  580. /,
  581. leaf_predicate: Callable[[PyTree], bool] | None = None,
  582. none_is_leaf: bool = False,
  583. namespace: str = "",
  584. ) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]:
  585. return tree_flatten_with_path( # type: ignore[return-value]
  586. tree,
  587. is_leaf=leaf_predicate,
  588. none_is_leaf=none_is_leaf,
  589. namespace=namespace,
  590. )
  591. @substitute_in_graph( # type: ignore[arg-type]
  592. optree.tree_structure,
  593. # We need to disable constant folding here because we want the function to reference the
  594. # PyTreeSpec class defined above, not the one in the C++ module.
  595. can_constant_fold_through=False,
  596. )
  597. def tree_structure(
  598. tree: PyTree,
  599. /,
  600. is_leaf: Callable[[PyTree], bool] | None = None,
  601. *,
  602. none_is_leaf: bool = False,
  603. namespace: str = "",
  604. ) -> PyTreeSpec:
  605. return tree_flatten( # type: ignore[return-value]
  606. tree,
  607. is_leaf=is_leaf,
  608. none_is_leaf=none_is_leaf,
  609. namespace=namespace,
  610. )[1]
  611. @substitute_in_graph( # type: ignore[arg-type]
  612. optree.tree_unflatten,
  613. # We need to disable constant folding here because we want the function to reference the
  614. # PyTreeSpec class defined above, not the one in the C++ module.
  615. can_constant_fold_through=False,
  616. )
  617. def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
  618. if not _is_pytreespec_instance(treespec):
  619. raise TypeError(
  620. f"Expected `treespec` to be an instance of "
  621. f"PyTreeSpec but got item of type {type(treespec)}."
  622. )
  623. return treespec.unflatten(leaves)
  624. _none_registration = optree.register_pytree_node.get(type(None))
  625. assert _none_registration is not None
  626. @substitute_in_graph( # type: ignore[arg-type]
  627. _none_registration.unflatten_func,
  628. can_constant_fold_through=True,
  629. skip_signature_check=True,
  630. )
  631. def none_unflatten(_: None, children: Iterable[_T], /) -> None:
  632. if len(list(children)) != 0:
  633. raise ValueError("Expected no children.")
  634. return None
  635. with optree.dict_insertion_ordered(False, namespace="torch"):
  636. _dict_registration = optree.register_pytree_node.get(dict)
  637. assert _dict_registration is not None
  638. @substitute_in_graph( # type: ignore[arg-type]
  639. _dict_registration.flatten_func,
  640. can_constant_fold_through=True,
  641. skip_signature_check=True,
  642. )
  643. def dict_flatten(
  644. dct: dict[_KT, _VT], /
  645. ) -> tuple[list[_VT], tuple[list[_KT], list[_KT]], tuple[_KT, ...]]:
  646. sorted_keys = optree.utils.total_order_sorted(dct)
  647. values = [dct[key] for key in sorted_keys]
  648. original_keys = list(dct)
  649. return values, (original_keys, sorted_keys), tuple(sorted_keys)
  650. @substitute_in_graph( # type: ignore[arg-type]
  651. _dict_registration.unflatten_func,
  652. can_constant_fold_through=True,
  653. skip_signature_check=True,
  654. )
  655. def dict_unflatten(
  656. metadata: tuple[list[_KT], list[_KT]],
  657. values: Iterable[_VT],
  658. /,
  659. ) -> dict[_KT, _VT]:
  660. original_keys, sorted_keys = metadata
  661. d = dict.fromkeys(original_keys)
  662. d.update(zip(sorted_keys, values, strict=True))
  663. return d # type: ignore[return-value]