_pytree.py 75 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219
  1. """
  2. Contains utility functions for working with nested python data structures.
  3. A *pytree* is a Python nested data structure. It is a tree in the sense that
  4. nodes are Python collections (e.g., list, tuple, dict) and the leaves are
  5. Python values. Furthermore, a pytree should not contain reference cycles.
  6. pytrees are useful for working with nested collections of Tensors. For example,
  7. one can use `tree_map` to map a function over all Tensors inside some nested
  8. collection of Tensors and `tree_leaves` to get a flat list of all Tensors
  9. inside some nested collection. pytrees are helpful for implementing nested
  10. collection support for PyTorch APIs.
  11. This pytree implementation is not very performant due to Python overhead.
  12. To improve the performance we can move parts of the implementation to C++.
  13. """
  14. import dataclasses
  15. import functools
  16. import importlib
  17. import importlib.metadata
  18. import json
  19. import sys
  20. import threading
  21. import types
  22. import warnings
  23. from collections import defaultdict, deque, namedtuple, OrderedDict
  24. from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
  25. from enum import Enum
  26. from typing import (
  27. Any,
  28. cast,
  29. ClassVar,
  30. Final,
  31. Generic,
  32. NoReturn,
  33. overload,
  34. Protocol,
  35. TYPE_CHECKING,
  36. TypeAlias,
  37. TypeVar,
  38. Union,
  39. )
  40. from typing_extensions import deprecated, NamedTuple, Self, TypeIs
  41. from torch.torch_version import TorchVersion as _TorchVersion
  42. if TYPE_CHECKING:
  43. import torch.utils._cxx_pytree as cxx_pytree
# Public names exported by this module when it is star-imported.
__all__ = [
    "PyTree",
    "Context",
    "FlattenFunc",
    "UnflattenFunc",
    "DumpableContext",
    "ToDumpableContextFn",
    "FromDumpableContextFn",
    "PyTreeSpec",
    "TreeSpec",
    "LeafSpec",
    "keystr",
    "key_get",
    "register_pytree_node",
    "tree_is_leaf",
    "tree_flatten",
    "tree_flatten_with_path",
    "tree_unflatten",
    "tree_iter",
    "tree_leaves",
    "tree_leaves_with_path",
    "tree_structure",
    "tree_map",
    "tree_map_with_path",
    "tree_map_",
    "tree_map_only",
    "tree_map_only_",
    "tree_all",
    "tree_any",
    "tree_all_only",
    "tree_any_only",
    "treespec_dumps",
    "treespec_loads",
    "treespec_pprint",
    "is_namedtuple",
    "is_namedtuple_class",
    "is_namedtuple_instance",
    "is_structseq",
    "is_structseq_class",
    "is_structseq_instance",
]
# Generic type variables used in annotations in this module.
T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
R = TypeVar("R")

# NOTE(review): presumably the protocol version written when a treespec is
# serialized -- its use is outside this excerpt, confirm against the dump code.
DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1
# Placeholder stored as the serialized type name when a node type is registered
# without an explicit ``serialized_type_name``.
NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND"
class KeyEntry(Protocol):
    """Structural interface for a single entry of a key path.

    An implementation must be hashable, comparable for equality, convertible
    to a string, and provide ``get(parent)`` to look up the child it refers
    to inside ``parent``.
    """

    def __hash__(self) -> int: ...

    def __eq__(self, other: object) -> bool: ...

    def __str__(self) -> str: ...

    def get(self, parent: Any) -> Any: ...
  96. class EnumEncoder(json.JSONEncoder):
  97. def default(self, obj: object) -> str | dict[str, Any]:
  98. if isinstance(obj, Enum):
  99. return {
  100. "__enum__": True,
  101. "fqn": f"{obj.__class__.__module__}:{obj.__class__.__qualname__}",
  102. "name": obj.name,
  103. }
  104. return cast(str, super().default(obj))
# Auxiliary data returned by a flatten function and passed back to the
# matching unflatten function.
Context = Any
# Any Python object; nodes are registered collection types, everything else is a leaf.
PyTree = Any
# Takes a node and returns (list of children, context).
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
# Takes (children, context) and rebuilds the node.
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
DumpableContext = Any  # Any json dumpable text
# Converts a context into its json dumpable form, and back again.
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
# Signatures of the deprecated ``to_str_fn`` / ``maybe_from_str_fn`` arguments
# accepted by ``_register_pytree_node``.
ToStrFunc = Callable[["TreeSpec", list[str]], str]
MaybeFromStrFunc = Callable[[str], tuple[Any, Context, str] | None]
# A path from the root of a pytree to one of its descendants.
KeyPath = tuple[KeyEntry, ...]
# Like FlattenFunc, but each child is paired with the key entry that reaches it.
FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]]
# A NodeDef holds the registered type together with three callables:
# - flatten_fn should take the collection and return a flat list of values.
#   It can also return some context that is used in reconstructing the
#   collection.
# - unflatten_fn should take a flat list of values and some context
#   (returned by flatten_fn). It returns the collection by reconstructing
#   it from the list and the context.
# - flatten_with_keys_fn (may be None), which is a callable that takes a
#   pytree and returns a list of (key entry, value) pairs and a context.
class NodeDef(NamedTuple):
    type: type[Any]
    flatten_fn: FlattenFunc
    unflatten_fn: UnflattenFunc
    flatten_with_keys_fn: FlattenWithKeysFunc | None
# Re-entrant lock taken by the register/deregister functions in this module
# while they read or modify the registries.
_NODE_REGISTRY_LOCK = threading.RLock()
# Registry of pytree node types: maps each registered type to its NodeDef.
SUPPORTED_NODES: dict[type[Any], NodeDef] = {}
# _SerializeNodeDef holds the following:
# - typ: the type of the node (e.g., "Dict", "List", etc)
# - serialized_type_name: the fully qualified name of the type, e.g. "collections.OrderedDict"
# - to_dumpable_context (may be None): takes the context of a node and returns
#   a json dumpable representation of it (see ToDumpableContextFn)
# - from_dumpable_context (may be None): takes the json dumpable representation
#   of the context and returns the deserialized context (see FromDumpableContextFn)
class _SerializeNodeDef(NamedTuple):
    typ: type[Any]
    serialized_type_name: str
    to_dumpable_context: ToDumpableContextFn | None
    from_dumpable_context: FromDumpableContextFn | None


# Maps each registered type to its serialization record.
SUPPORTED_SERIALIZED_TYPES: dict[type[Any], _SerializeNodeDef] = {}
# Reverse lookup: serialized type name -> registered Python type.
SERIALIZED_TYPE_TO_PYTHON_TYPE: dict[str, type[Any]] = {}
# NB: we try really hard to not import _cxx_pytree (which depends on optree)
# as much as possible. This is for isolation: a user who is not using C++ pytree
# shouldn't pay for it, and it helps make things like cpython upgrades easier.
_optree_minimum_version = _TorchVersion("0.13.0")
try:
    # Only the installed distribution metadata is queried here; optree itself
    # is not imported.
    _optree_version = importlib.metadata.version("optree")
except importlib.metadata.PackageNotFoundError:
    # No optree package found
    _cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False
    _optree_version = _TorchVersion("0.0.0a0")
else:
    _optree_version = _TorchVersion(_optree_version)
    if _optree_version < _optree_minimum_version:
        # optree package less than our required minimum version.
        # Pretend the optree package doesn't exist.
        # NB: We will raise ImportError if the user directly tries to
        # `import torch.utils._cxx_pytree` (look in that file for the check).
        _cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False
    else:
        _cxx_pytree_dynamo_traceable = _cxx_pytree_exists = True

# Consulted by register_pytree_node to decide between registering with the
# C++ pytree right away and queueing the registration.
# NOTE(review): presumably set to True by torch.utils._cxx_pytree when it is
# imported -- that code is not in this file excerpt.
_cxx_pytree_imported = False
# Queue of (args, kwargs) registrations recorded while the C++ pytree has not
# been imported yet.
_cxx_pytree_pending_imports: list[Any] = []
def register_pytree_node(
    cls: type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: str | None = None,
    to_dumpable_context: ToDumpableContextFn | None = None,
    from_dumpable_context: FromDumpableContextFn | None = None,
    flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
    """Register a container-like type as pytree node.

    Note:
        :func:`register_dataclass` is a simpler way of registering a container-like
        type as a pytree node.

    Args:
        cls: the type to register
        flatten_fn: A callable that takes a pytree and returns a flattened
            representation of the pytree and additional context to represent the
            flattened pytree.
        unflatten_fn: A callable that takes a flattened version of the pytree,
            additional context, and returns an unflattened pytree.
        serialized_type_name: A keyword argument used to specify the fully qualified
            name used when serializing the tree spec.
        to_dumpable_context: An optional keyword argument to custom specify how
            to convert the context of the pytree to a custom json dumpable
            representation. This is used for json serialization, which is being
            used in torch.export right now.
        from_dumpable_context: An optional keyword argument to custom specify how
            to convert the custom json dumpable representation of the context
            back to the original context. This is used for json deserialization,
            which is being used in torch.export right now.
        flatten_with_keys_fn: An optional keyword argument to specify how to
            access each pytree leaf's keypath when flattening and tree-mapping.
            Like ``flatten_fn``, but in place of a List[leaf], it should return
            a List[(keypath, leaf)].

    Raises:
        ValueError: if ``cls`` is already registered as a pytree node.
    """
    # Unlike _private_register_pytree_node (which only warns and overwrites),
    # the public entry point rejects duplicate registrations.
    with _NODE_REGISTRY_LOCK:
        if cls in SUPPORTED_NODES:
            raise ValueError(f"{cls} is already registered as pytree node.")

    _private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
        flatten_with_keys_fn=flatten_with_keys_fn,
    )

    # Nothing more to do when a usable optree installation was not detected.
    if not _cxx_pytree_exists:
        return

    if _cxx_pytree_imported:
        # The C++ pytree is already loaded: mirror the registration there now.
        # Note that flatten_with_keys_fn is not forwarded.
        import torch.utils._cxx_pytree as cxx_pytree

        cxx_pytree._private_register_pytree_node(
            cls,
            flatten_fn,
            unflatten_fn,
            serialized_type_name=serialized_type_name,
            to_dumpable_context=to_dumpable_context,
            from_dumpable_context=from_dumpable_context,
        )
    else:
        # Avoid importing the C++ pytree just for this: queue the registration
        # in _cxx_pytree_pending_imports instead.
        args = (cls, flatten_fn, unflatten_fn)
        kwargs = {
            "serialized_type_name": serialized_type_name,
            "to_dumpable_context": to_dumpable_context,
            "from_dumpable_context": from_dumpable_context,
        }
        _cxx_pytree_pending_imports.append((args, kwargs))
  236. def register_dataclass(
  237. cls: type[Any],
  238. *,
  239. field_names: list[str] | None = None,
  240. drop_field_names: list[str] | None = None,
  241. serialized_type_name: str | None = None,
  242. ) -> None:
  243. """
  244. Registers a type that has the semantics of a ``dataclasses.dataclass`` type
  245. as a pytree node.
  246. This is a simpler API than :func:`register_pytree_node` for registering
  247. a dataclass or a custom class with the semantics of a dataclass.
  248. Args:
  249. cls: The python type to register. The class must have the semantics of a
  250. dataclass; in particular, it must be constructed by passing the fields
  251. in.
  252. field_names (Optional[List[str]]): A list of field names that correspond
  253. to the **non-constant data** in this class. This list must contain
  254. all the fields that are used to initialize the class. This argument
  255. is optional if ``cls`` is a dataclass, in which case the fields will
  256. be taken from ``dataclasses.fields()``.
  257. drop_field_names (Optional[List[str]]): A list of field names that
  258. should not be included in the pytree.
  259. serialized_type_name: A keyword argument used to specify the fully
  260. qualified name used when serializing the tree spec. This is only
  261. needed for serializing the treespec in torch.export.
  262. Example:
  263. >>> from torch import Tensor
  264. >>> from dataclasses import dataclass
  265. >>> import torch.utils._pytree as pytree
  266. >>>
  267. >>> @dataclass
  268. >>> class Point:
  269. >>> x: Tensor
  270. >>> y: Tensor
  271. >>>
  272. >>> pytree.register_dataclass(Point)
  273. >>>
  274. >>> point = Point(torch.tensor(0), torch.tensor(1))
  275. >>> point = pytree.tree_map(lambda x: x + 1, point)
  276. >>> assert torch.allclose(point.x, torch.tensor(1))
  277. >>> assert torch.allclose(point.y, torch.tensor(2))
  278. """
  279. drop_field_names = drop_field_names or []
  280. if not dataclasses.is_dataclass(cls):
  281. if field_names is None:
  282. raise ValueError(
  283. "field_names must be specified with a list of all fields used to "
  284. f"initialize {cls}, as it is not a dataclass."
  285. )
  286. elif field_names is None:
  287. field_names = [f.name for f in dataclasses.fields(cls) if f.init]
  288. else:
  289. dataclass_init_fields = {f.name for f in dataclasses.fields(cls) if f.init}
  290. dataclass_init_fields.difference_update(drop_field_names)
  291. if dataclass_init_fields != set(field_names):
  292. error_msg = "field_names does not include all dataclass fields.\n"
  293. if missing := dataclass_init_fields - set(field_names):
  294. error_msg += (
  295. f"Missing fields in `field_names`: {missing}. If you want "
  296. "to include these fields in the pytree, please add them "
  297. "to `field_names`, otherwise please add them to "
  298. "`drop_field_names`.\n"
  299. )
  300. if unexpected := set(field_names) - dataclass_init_fields:
  301. error_msg += (
  302. f"Unexpected fields in `field_names`: {unexpected}. "
  303. "Please remove these fields, or add them to `drop_field_names`.\n"
  304. )
  305. raise ValueError(error_msg)
  306. def _flatten_fn(obj: Any) -> tuple[list[Any], Context]:
  307. flattened = []
  308. flat_names = []
  309. none_names = []
  310. for name in field_names:
  311. val = getattr(obj, name)
  312. if val is not None:
  313. flattened.append(val)
  314. flat_names.append(name)
  315. else:
  316. none_names.append(name)
  317. return flattened, [flat_names, none_names]
  318. def _unflatten_fn(values: Iterable[Any], context: Context) -> Any:
  319. flat_names, none_names = context
  320. return cls(
  321. **dict(zip(flat_names, values, strict=True)), **dict.fromkeys(none_names)
  322. )
  323. def _flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
  324. flattened, (flat_names, _none_names) = _flatten_fn(obj) # type: ignore[misc]
  325. return [
  326. (GetAttrKey(k), v) for k, v in zip(flat_names, flattened, strict=True)
  327. ], flat_names
  328. _private_register_pytree_node(
  329. cls,
  330. _flatten_fn,
  331. _unflatten_fn,
  332. serialized_type_name=serialized_type_name,
  333. flatten_with_keys_fn=_flatten_fn_with_keys,
  334. )
# Types that have been registered through register_constant().
CONSTANT_NODES: set[type] = set()
  336. def register_constant(cls: type[Any]) -> None:
  337. """Registers a type as a pytree node with no leaves.
  338. In a :func:`torch.compile` region, if instances of these types get passed to
  339. :func:`torch._dynamo.nonstrict_trace`-ed function, they treated as a
  340. constant (sometimes referred to as "static"):
  341. 1. if the instance object existed before the :func:`torch.compile` region,
  342. we _assume_ no mutation will happen to it inside the :func:`torch.compile`
  343. region, require that it has non-default `__eq__` and `__hash__` methods, and
  344. we guard on the instance based on its `__eq__` method, i.e., if a new
  345. instance fails to match any instances from the previous compilations,
  346. :func:`torch.compile` will recompile the function using the new instance.
  347. 2. else if the instance object is created inside the :func:`torch.compile`
  348. region, we currently don't support using it in a
  349. :func:`torch._dynamo.nonstrict_trace`-ed function.
  350. In general, if your class holds Tensors or dynamic int/float/bool (values that
  351. may change from run-to-run of a function being compiled), then you probably
  352. do not want to register it as a constant.
  353. Otherwise if you want to pass instance of a class to a
  354. :func:`torch._dynamo.nonstrict_trace`-ed function, but you either can't use
  355. :func:`register_pytree_node` on the class, or the class is "constant" enough
  356. that you don't want to bother using :func:`register_pytree_node`, you should
  357. consider using this function.
  358. Args:
  359. cls: the type to register as a constant. This type must be hashable.
  360. Example:
  361. >>> from dataclasses import dataclass
  362. >>> import torch.utils._pytree as pytree
  363. >>>
  364. >>> @dataclass(frozen=True)
  365. >>> class Config:
  366. >>> norm: str
  367. >>>
  368. >>> pytree.register_constant(Config)
  369. >>>
  370. >>> config = Config("l2")
  371. >>> values, spec = pytree.tree_flatten(config)
  372. >>> assert len(values) == 0
  373. """
  374. if cls.__eq__ is object.__eq__: # type: ignore[comparison-overlap]
  375. raise TypeError(
  376. "register_constant(cls) expects `cls` to have a non-default `__eq__` implementation."
  377. )
  378. # Class with a custom `__eq__` without `__hash__` won't inherit the default
  379. # `__hash__` from object; see https://stackoverflow.com/a/1608907.
  380. if cls.__hash__ is None: # type: ignore[comparison-overlap]
  381. raise TypeError(
  382. "register_constant(cls) expects `cls` to have a non-default `__hash__` implementation."
  383. )
  384. def _flatten(x): # type: ignore[no-untyped-def]
  385. return [], ConstantNode(x)
  386. def _unflatten(_, context): # type: ignore[no-untyped-def]
  387. return context.value
  388. def _flatten_with_keys(x): # type: ignore[no-untyped-def]
  389. return [], ConstantNode(x)
  390. with _NODE_REGISTRY_LOCK:
  391. _private_register_pytree_node(
  392. cls,
  393. _flatten,
  394. _unflatten,
  395. flatten_with_keys_fn=_flatten_with_keys,
  396. )
  397. CONSTANT_NODES.add(cls)
  398. def is_constant_class(cls: type[Any]) -> bool:
  399. return isinstance(cls, type) and cls in CONSTANT_NODES
@dataclasses.dataclass(frozen=True, slots=True)
class ConstantNode(Generic[T]):
    """Immutable wrapper used as the pytree context for types registered with
    ``register_constant``; ``value`` is the wrapped instance."""

    value: T
  403. def _is_constant_holder(spec: "TreeSpec") -> bool:
  404. """Checks if the spec is from a pytree registered with register_constant"""
  405. return isinstance(spec._context, ConstantNode)
  406. def _retrieve_constant(spec: "TreeSpec") -> Any:
  407. """Given a spec from a pytree registered with register_constant, retrieves the constant"""
  408. if not _is_constant_holder(spec):
  409. raise AssertionError("spec does not correspond to a registered constant pytree")
  410. return tree_unflatten([], spec)
def _register_namedtuple(
    cls: type[Any],
    *,
    serialized_type_name: str,
) -> None:
    """
    Registers a namedtuple as a valid pytree node. By default namedtuples are
    valid pytree nodes, but they are not serializable. This API provides the
    argument `serialized_type_name` which allows these namedtuples to be
    serialized.

    Args:
        cls: the namedtuple type to register
        serialized_type_name: The serialized name for the namedtuple. This is
            required if you want to serialize the pytree TreeSpec containing this
            namedtuple.
    """
    # Uses the module's namedtuple flatten/unflatten/(de)serialize helpers and
    # registers with the Python pytree only.
    _private_register_pytree_node(
        cls,
        _namedtuple_flatten,
        _namedtuple_unflatten,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=_namedtuple_serialize,
        from_dumpable_context=_namedtuple_deserialize,
        flatten_with_keys_fn=_namedtuple_flatten_with_keys,
    )
  436. @deprecated(
  437. "`torch.utils._pytree._register_pytree_node` is deprecated. "
  438. "Please use `torch.utils._pytree.register_pytree_node` instead.",
  439. category=FutureWarning,
  440. )
  441. def _register_pytree_node(
  442. cls: type[Any],
  443. flatten_fn: FlattenFunc,
  444. unflatten_fn: UnflattenFunc,
  445. to_str_fn: ToStrFunc | None = None, # deprecated
  446. maybe_from_str_fn: MaybeFromStrFunc | None = None, # deprecated
  447. *,
  448. serialized_type_name: str | None = None,
  449. to_dumpable_context: ToDumpableContextFn | None = None,
  450. from_dumpable_context: FromDumpableContextFn | None = None,
  451. flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
  452. ) -> None:
  453. """Register a container-like type as pytree node for the Python pytree only.
  454. Args:
  455. cls: the type to register
  456. flatten_fn: A callable that takes a pytree and returns a flattened
  457. representation of the pytree and additional context to represent the
  458. flattened pytree.
  459. unflatten_fn: A callable that takes a flattened version of the pytree,
  460. additional context, and returns an unflattened pytree.
  461. serialized_type_name: A keyword argument used to specify the fully qualified
  462. name used when serializing the tree spec.
  463. to_dumpable_context: An optional keyword argument to custom specify how
  464. to convert the context of the pytree to a custom json dumpable
  465. representation. This is used for json serialization, which is being
  466. used in torch.export right now.
  467. from_dumpable_context: An optional keyword argument to custom specify how
  468. to convert the custom json dumpable representation of the context
  469. back to the original context. This is used for json deserialization,
  470. which is being used in torch.export right now.
  471. flatten_with_keys_fn: An optional keyword argument to specify how to
  472. access each pytree leaf's keypath when flattening and tree-mapping.
  473. Like ``flatten_fn``, but in place of a List[leaf], it should return
  474. a List[(keypath, leaf)].
  475. """
  476. if to_str_fn is not None or maybe_from_str_fn is not None:
  477. warnings.warn(
  478. "`to_str_fn` and `maybe_from_str_fn` is deprecated. "
  479. "Please use `to_dumpable_context` and `from_dumpable_context` instead.",
  480. FutureWarning,
  481. stacklevel=2,
  482. )
  483. _private_register_pytree_node(
  484. cls,
  485. flatten_fn,
  486. unflatten_fn,
  487. serialized_type_name=serialized_type_name,
  488. to_dumpable_context=to_dumpable_context,
  489. from_dumpable_context=from_dumpable_context,
  490. flatten_with_keys_fn=flatten_with_keys_fn,
  491. )
  492. def _deregister_pytree_node(
  493. cls: type[Any],
  494. ) -> None:
  495. """This is an internal function that is used to deregister a pytree node type
  496. for the Python pytree only. This should be only used inside PyTorch.
  497. """
  498. with _NODE_REGISTRY_LOCK:
  499. del SUPPORTED_NODES[cls]
  500. node_def = SUPPORTED_SERIALIZED_TYPES[cls]
  501. del SERIALIZED_TYPE_TO_PYTHON_TYPE[node_def.serialized_type_name]
  502. del SUPPORTED_SERIALIZED_TYPES[cls]
  503. CONSTANT_NODES.discard(cls)
  504. def _private_register_pytree_node(
  505. cls: type[Any],
  506. flatten_fn: FlattenFunc,
  507. unflatten_fn: UnflattenFunc,
  508. *,
  509. serialized_type_name: str | None = None,
  510. to_dumpable_context: ToDumpableContextFn | None = None,
  511. from_dumpable_context: FromDumpableContextFn | None = None,
  512. flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
  513. ) -> None:
  514. """This is an internal function that is used to register a pytree node type
  515. for the Python pytree only. End-users should use :func:`register_pytree_node`
  516. instead.
  517. """
  518. from torch._library.opaque_object import is_opaque_type
  519. if is_opaque_type(cls):
  520. raise ValueError(
  521. f"{cls} cannot be registered as a pytree as it has been "
  522. "registered as an opaque object. Opaque objects must be pytree leaves."
  523. )
  524. with _NODE_REGISTRY_LOCK:
  525. if cls in SUPPORTED_NODES:
  526. # TODO: change this warning to an error after OSS/internal stabilize
  527. warnings.warn(
  528. f"{cls} is already registered as pytree node. "
  529. "Overwriting the previous registration.",
  530. stacklevel=2,
  531. )
  532. node_def = NodeDef(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn)
  533. SUPPORTED_NODES[cls] = node_def
  534. if (to_dumpable_context is None) ^ (from_dumpable_context is None):
  535. raise ValueError(
  536. f"Both to_dumpable_context and from_dumpable_context for {cls} must "
  537. "be None or registered."
  538. )
  539. if serialized_type_name is None:
  540. serialized_type_name = NO_SERIALIZED_TYPE_NAME_FOUND
  541. serialize_node_def = _SerializeNodeDef(
  542. cls,
  543. serialized_type_name,
  544. to_dumpable_context,
  545. from_dumpable_context,
  546. )
  547. SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
  548. SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls
  549. @dataclasses.dataclass(frozen=True, slots=True)
  550. class SequenceKey(Generic[T]):
  551. idx: int
  552. def __str__(self) -> str:
  553. return f"[{self.idx!r}]"
  554. def get(self, sequence: Sequence[T]) -> T:
  555. return sequence[self.idx]
  556. K = TypeVar("K", bound=Hashable)
  557. @dataclasses.dataclass(frozen=True, slots=True)
  558. class MappingKey(Generic[K, T]):
  559. key: K
  560. def __str__(self) -> str:
  561. return f"[{self.key!r}]"
  562. def get(self, mapping: Mapping[K, T]) -> T:
  563. return mapping[self.key]
  564. @dataclasses.dataclass(frozen=True, slots=True)
  565. class GetAttrKey:
  566. name: str
  567. def __str__(self) -> str:
  568. return f".{self.name}"
  569. def get(self, obj: Any) -> Any:
  570. return getattr(obj, self.name)
  571. # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
  572. def is_namedtuple(obj: object | type) -> bool:
  573. """Return whether the object is an instance of namedtuple or a subclass of namedtuple."""
  574. cls = obj if isinstance(obj, type) else type(obj)
  575. return is_namedtuple_class(cls)
  576. # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
  577. def is_namedtuple_class(cls: type) -> bool:
  578. """Return whether the class is a subclass of namedtuple."""
  579. return (
  580. isinstance(cls, type)
  581. and issubclass(cls, tuple)
  582. and isinstance(getattr(cls, "_fields", None), tuple)
  583. and all(type(field) is str for field in cls._fields) # type: ignore[attr-defined]
  584. and callable(getattr(cls, "_make", None))
  585. and callable(getattr(cls, "_asdict", None))
  586. )
  587. # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
  588. def is_namedtuple_instance(obj: object) -> bool:
  589. """Return whether the object is an instance of namedtuple."""
  590. return is_namedtuple_class(type(obj))
  591. _T_co = TypeVar("_T_co", covariant=True)
  592. # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
  593. class structseq(tuple[_T_co, ...]):
  594. """A generic type stub for CPython's ``PyStructSequence`` type."""
  595. __slots__: ClassVar[tuple[()]] = ()
  596. n_fields: Final[int] # type: ignore[misc]
  597. n_sequence_fields: Final[int] # type: ignore[misc]
  598. n_unnamed_fields: Final[int] # type: ignore[misc]
  599. def __init_subclass__(cls) -> NoReturn:
  600. """Prohibit subclassing."""
  601. raise TypeError("type 'structseq' is not an acceptable base type")
  602. def __new__(
  603. cls: type[Self],
  604. sequence: Iterable[_T_co],
  605. # pyrefly: ignore [bad-function-definition]
  606. dict: dict[str, Any] = ...,
  607. ) -> Self:
  608. raise NotImplementedError
  609. # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
  610. def is_structseq(obj: object | type) -> bool:
  611. """Return whether the object is an instance of PyStructSequence or a class of PyStructSequence."""
  612. cls = obj if isinstance(obj, type) else type(obj)
  613. return is_structseq_class(cls)
  614. # Set if the type allows subclassing (see CPython's Include/object.h)
  615. Py_TPFLAGS_BASETYPE: int = 1 << 10
  616. # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
  617. def is_structseq_class(cls: type) -> bool:
  618. """Return whether the class is a class of PyStructSequence."""
  619. return (
  620. isinstance(cls, type)
  621. # Check direct inheritance from `tuple` rather than `issubclass(cls, tuple)`
  622. and cls.__bases__ == (tuple,)
  623. # Check PyStructSequence members
  624. and isinstance(getattr(cls, "n_fields", None), int)
  625. and isinstance(getattr(cls, "n_sequence_fields", None), int)
  626. and isinstance(getattr(cls, "n_unnamed_fields", None), int)
  627. # Check the type does not allow subclassing
  628. and not bool(cls.__flags__ & Py_TPFLAGS_BASETYPE) # only works for CPython
  629. )
  630. # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
  631. def is_structseq_instance(obj: object) -> bool:
  632. """Return whether the object is an instance of PyStructSequence."""
  633. return is_structseq_class(type(obj))
  634. def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]:
  635. return list(d), None
  636. def _tuple_flatten_with_keys(
  637. d: tuple[T, ...],
  638. ) -> tuple[list[tuple[KeyEntry, T]], Context]:
  639. values, context = _tuple_flatten(d)
  640. # pyrefly: ignore [bad-return]
  641. return [(SequenceKey(i), v) for i, v in enumerate(values)], context
  642. def _tuple_unflatten(values: Iterable[T], context: Context) -> tuple[T, ...]:
  643. return tuple(values)
  644. def _list_flatten(d: list[T]) -> tuple[list[T], Context]:
  645. return d, None
  646. def _list_flatten_with_keys(d: list[T]) -> tuple[list[tuple[KeyEntry, T]], Context]:
  647. values, context = _list_flatten(d)
  648. # pyrefly: ignore [bad-return]
  649. return [(SequenceKey(i), v) for i, v in enumerate(values)], context
  650. def _list_unflatten(values: Iterable[T], context: Context) -> list[T]:
  651. return list(values)
  652. def _dict_flatten(d: dict[Any, T]) -> tuple[list[T], Context]:
  653. return list(d.values()), list(d.keys())
  654. def _dict_flatten_with_keys(
  655. d: dict[Any, T],
  656. ) -> tuple[list[tuple[KeyEntry, T]], Context]:
  657. values, context = _dict_flatten(d)
  658. # pyrefly: ignore [bad-return]
  659. return [(MappingKey(k), v) for k, v in zip(context, values, strict=True)], context
  660. def _dict_unflatten(values: Iterable[T], context: Context) -> dict[Any, T]:
  661. return dict(zip(context, values, strict=True))
  662. def _namedtuple_flatten(d: NamedTuple) -> tuple[list[Any], Context]:
  663. return list(d), type(d)
  664. def _namedtuple_flatten_with_keys(
  665. d: NamedTuple,
  666. ) -> tuple[list[tuple[KeyEntry, Any]], Context]:
  667. values, context = _namedtuple_flatten(d)
  668. # pyrefly: ignore [bad-return]
  669. return (
  670. [
  671. (GetAttrKey(field), v)
  672. for field, v in zip(context._fields, values, strict=True)
  673. ],
  674. context,
  675. )
  676. def _namedtuple_unflatten(values: Iterable[T], context: Context) -> NamedTuple:
  677. return cast(NamedTuple, context(*values))
  678. def _namedtuple_serialize(context: Context) -> DumpableContext:
  679. if context not in SUPPORTED_SERIALIZED_TYPES:
  680. raise NotImplementedError(
  681. f"Can't serialize TreeSpec of namedtuple class {context} because we "
  682. "didn't register a serializated_type_name. Please register using "
  683. "`_register_namedtuple`."
  684. )
  685. serialize_node_def = SUPPORTED_SERIALIZED_TYPES[context]
  686. serialized_type_name = serialize_node_def.serialized_type_name
  687. if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND:
  688. raise NotImplementedError(
  689. f"Can't serialize TreeSpec of namedtuple class {context} because we "
  690. "couldn't find a serializated_type_name. Please register using "
  691. "`_register_namedtuple`."
  692. )
  693. return serialized_type_name
  694. def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context:
  695. if dumpable_context not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
  696. raise NotImplementedError(
  697. f"Can't deserialize TreeSpec of namedtuple class {dumpable_context} "
  698. "because we couldn't find a serializated name."
  699. )
  700. typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[dumpable_context]
  701. return typ
  702. def _ordereddict_flatten(d: OrderedDict[Any, T]) -> tuple[list[T], Context]:
  703. return list(d.values()), list(d.keys())
  704. def _ordereddict_flatten_with_keys(
  705. d: OrderedDict[Any, T],
  706. ) -> tuple[list[tuple[KeyEntry, T]], Context]:
  707. values, context = _ordereddict_flatten(d)
  708. # pyrefly: ignore [bad-return]
  709. return [(MappingKey(k), v) for k, v in zip(context, values, strict=True)], context
  710. def _ordereddict_unflatten(
  711. values: Iterable[T],
  712. context: Context,
  713. ) -> OrderedDict[Any, T]:
  714. return OrderedDict((key, value) for key, value in zip(context, values, strict=True))
# Short aliases bound to the same function objects as the `_ordereddict_*`
# helpers above.
_odict_flatten = _ordereddict_flatten
_odict_unflatten = _ordereddict_unflatten
  717. def _defaultdict_flatten(d: defaultdict[Any, T]) -> tuple[list[T], Context]:
  718. values, dict_context = _dict_flatten(d)
  719. return values, [d.default_factory, dict_context]
  720. def _defaultdict_flatten_with_keys(
  721. d: defaultdict[Any, T],
  722. ) -> tuple[list[tuple[KeyEntry, T]], Context]:
  723. values, context = _defaultdict_flatten(d)
  724. _, dict_context = context
  725. # pyrefly: ignore [bad-return]
  726. return [
  727. (MappingKey(k), v) for k, v in zip(dict_context, values, strict=True)
  728. ], context
  729. def _defaultdict_unflatten(
  730. values: Iterable[T],
  731. context: Context,
  732. ) -> defaultdict[Any, T]:
  733. default_factory, dict_context = context
  734. return defaultdict(default_factory, _dict_unflatten(values, dict_context))
  735. def _defaultdict_serialize(context: Context) -> DumpableContext:
  736. default_factory, dict_context = context
  737. json_defaultdict = {
  738. "default_factory_module": default_factory.__module__,
  739. "default_factory_name": default_factory.__qualname__,
  740. "dict_context": dict_context,
  741. }
  742. return json_defaultdict
  743. def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context:
  744. if not isinstance(dumpable_context, dict):
  745. raise AssertionError("dumpable_context must be a dict")
  746. expected_keys = {
  747. "default_factory_module",
  748. "default_factory_name",
  749. "dict_context",
  750. }
  751. if set(dumpable_context) != expected_keys:
  752. raise AssertionError(
  753. f"dumpable_context keys must be {expected_keys}, got {set(dumpable_context)}"
  754. )
  755. default_factory_module = dumpable_context["default_factory_module"]
  756. default_factory_name = dumpable_context["default_factory_name"]
  757. if not isinstance(default_factory_module, str):
  758. raise AssertionError("default_factory_module must be a string")
  759. if not isinstance(default_factory_name, str):
  760. raise AssertionError("default_factory_name must be a string")
  761. module = importlib.import_module(default_factory_module)
  762. default_factory = getattr(module, default_factory_name)
  763. dict_context = dumpable_context["dict_context"]
  764. return [default_factory, dict_context]
  765. def _deque_flatten(d: deque[T]) -> tuple[list[T], Context]:
  766. return list(d), d.maxlen
  767. def _deque_flatten_with_keys(
  768. d: deque[T],
  769. ) -> tuple[list[tuple[KeyEntry, T]], Context]:
  770. values, context = _deque_flatten(d)
  771. # pyrefly: ignore [bad-return]
  772. return [(SequenceKey(i), v) for i, v in enumerate(values)], context
  773. def _deque_unflatten(values: Iterable[T], context: Context) -> deque[T]:
  774. return deque(values, maxlen=context)
# Register the built-in container types as pytree nodes when this module is
# imported. Each call supplies the flatten/unflatten pair defined above, a
# serialized type name, and the keyed flatten variant.
_private_register_pytree_node(
    tuple,
    _tuple_flatten,
    _tuple_unflatten,
    serialized_type_name="builtins.tuple",
    flatten_with_keys_fn=_tuple_flatten_with_keys,
)
_private_register_pytree_node(
    list,
    _list_flatten,
    _list_unflatten,
    serialized_type_name="builtins.list",
    flatten_with_keys_fn=_list_flatten_with_keys,
)
_private_register_pytree_node(
    dict,
    _dict_flatten,
    _dict_unflatten,
    serialized_type_name="builtins.dict",
    flatten_with_keys_fn=_dict_flatten_with_keys,
)
# The `namedtuple` factory function itself is the registry key for all
# namedtuple classes (`_get_node_type` maps every namedtuple class to it).
_private_register_pytree_node(
    namedtuple,  # type: ignore[arg-type]
    _namedtuple_flatten,
    _namedtuple_unflatten,
    serialized_type_name="collections.namedtuple",
    to_dumpable_context=_namedtuple_serialize,
    from_dumpable_context=_namedtuple_deserialize,
    flatten_with_keys_fn=_namedtuple_flatten_with_keys,
)
_private_register_pytree_node(
    OrderedDict,
    _ordereddict_flatten,
    _ordereddict_unflatten,
    serialized_type_name="collections.OrderedDict",
    flatten_with_keys_fn=_ordereddict_flatten_with_keys,
)
# The defaultdict context contains the `default_factory` callable, so it needs
# dedicated to/from dumpable-context converters.
_private_register_pytree_node(
    defaultdict,
    _defaultdict_flatten,
    _defaultdict_unflatten,
    serialized_type_name="collections.defaultdict",
    to_dumpable_context=_defaultdict_serialize,
    from_dumpable_context=_defaultdict_deserialize,
    flatten_with_keys_fn=_defaultdict_flatten_with_keys,
)
_private_register_pytree_node(
    deque,
    _deque_flatten,
    _deque_unflatten,
    serialized_type_name="collections.deque",
    flatten_with_keys_fn=_deque_flatten_with_keys,
)
  828. STANDARD_DICT_TYPES: frozenset[type] = frozenset({dict, OrderedDict, defaultdict})
  829. # pyrefly: ignore [no-matching-overload]
  830. BUILTIN_TYPES: frozenset[type] = frozenset(
  831. {
  832. tuple,
  833. list,
  834. dict,
  835. namedtuple, # type: ignore[arg-type]
  836. OrderedDict,
  837. defaultdict,
  838. deque,
  839. },
  840. )
  841. @deprecated(
  842. "torch.utils._pytree._is_namedtuple_instance is private and will be removed in a future release. "
  843. "Please use torch.utils._pytree.is_namedtuple_instance instead.",
  844. category=FutureWarning,
  845. )
  846. def _is_namedtuple_instance(tree: Any) -> bool:
  847. return is_namedtuple_instance(tree)
  848. def _get_node_type(tree: Any) -> Any:
  849. node_type = type(tree)
  850. # All namedtuple types are implicitly registered as pytree nodes.
  851. # XXX: Other parts of the codebase expect namedtuple types always return
  852. # `namedtuple` instead of the actual namedtuple type. Even if the type
  853. # is explicitly registered.
  854. if is_namedtuple_class(node_type):
  855. return namedtuple
  856. return node_type
  857. # A leaf is defined as anything that is not a Node.
  858. def tree_is_leaf(
  859. tree: PyTree,
  860. is_leaf: Callable[[PyTree], bool] | None = None,
  861. ) -> bool:
  862. """Check if a pytree is a leaf.
  863. >>> tree_is_leaf(1)
  864. True
  865. >>> tree_is_leaf(None)
  866. True
  867. >>> tree_is_leaf([1, 2, 3])
  868. False
  869. >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple))
  870. True
  871. >>> tree_is_leaf({"a": 1, "b": 2, "c": 3})
  872. False
  873. >>> tree_is_leaf({"a": 1, "b": 2, "c": None})
  874. False
  875. """
  876. if is_leaf is not None and is_leaf(tree):
  877. return True
  878. return _get_node_type(tree) not in SUPPORTED_NODES
  879. @deprecated(
  880. "torch.utils._pytree._is_leaf is private and will be removed in a future release. "
  881. "Please use torch.utils._pytree.tree_is_leaf instead.",
  882. category=FutureWarning,
  883. )
  884. def _is_leaf(tree: PyTree, is_leaf: Callable[[PyTree], bool] | None = None) -> bool:
  885. return tree_is_leaf(tree, is_leaf=is_leaf)
  886. # A TreeSpec represents the structure of a pytree. It holds:
  887. # "type": the type of root Node of the pytree
  888. # context: some context that is useful in unflattening the pytree
  889. # children(): specs for each child of the root Node
  890. # num_nodes: the total number of nodes
  891. # num_leaves: the number of leaves
  892. # num_children: the number of children of the root Node (i.e., len(children()))
  893. # is_leaf(): whether the root Node is a leaf
  894. @dataclasses.dataclass(init=False, frozen=True, eq=True, repr=False, slots=True)
  895. class TreeSpec:
  896. type: Any
  897. _context: Context
  898. _children: list[Self]
  899. num_nodes: int = dataclasses.field(init=False)
  900. num_leaves: int = dataclasses.field(init=False)
  901. num_children: int = dataclasses.field(init=False)
  902. def __init__(
  903. self,
  904. type: Any,
  905. context: Context, # keep for backward compatibility
  906. children_specs: list[Self], # keep for backward compatibility
  907. ) -> None:
  908. object.__setattr__(self, "type", type)
  909. object.__setattr__(self, "_context", context)
  910. object.__setattr__(self, "_children", children_specs)
  911. self.__post_init__()
  912. def __post_init__(self) -> None:
  913. if self.type is None:
  914. if self._context is not None:
  915. raise AssertionError("leaf node should not have a context")
  916. if len(self._children) != 0:
  917. raise AssertionError("leaf node should not have children")
  918. num_nodes = 1
  919. num_leaves = 1
  920. num_children = 0
  921. else:
  922. num_nodes = 1
  923. num_leaves = 0
  924. for child in self._children:
  925. num_nodes += child.num_nodes
  926. num_leaves += child.num_leaves
  927. num_children = len(self._children)
  928. object.__setattr__(self, "num_nodes", num_nodes)
  929. object.__setattr__(self, "num_leaves", num_leaves)
  930. object.__setattr__(self, "num_children", num_children)
  931. def __repr__(self, indent: int = 0) -> str:
  932. repr_prefix: str = f"TreeSpec({self.type.__name__}, {self._context}, ["
  933. children_specs_str: str = ""
  934. if self.num_children > 0:
  935. indent += 2
  936. children_specs_str += self._children[0].__repr__(indent)
  937. children_specs_str += "," if self.num_children > 1 else ""
  938. children_specs_str += ",".join(
  939. [
  940. "\n" + " " * indent + child.__repr__(indent)
  941. for child in self._children[1:]
  942. ]
  943. )
  944. repr_suffix: str = f"{children_specs_str}])"
  945. return repr_prefix + repr_suffix
  946. def __eq__(self, other: PyTree) -> bool:
  947. if self is other:
  948. return True
  949. elif other.__class__ is self.__class__:
  950. if str(self.type) != str(other.type):
  951. return False
  952. if self._context != other._context:
  953. return False
  954. elif self._children != other._children:
  955. return False
  956. return True
  957. return NotImplemented
  958. @property
  959. def context(self) -> Context:
  960. return self._context
  961. @property
  962. @deprecated(
  963. "`treespec.children_specs` is deprecated. "
  964. "Use `treespec.child(index)` to access a single child, "
  965. "or `treespec.children()` to get all children.",
  966. category=FutureWarning,
  967. )
  968. def children_specs(self) -> list[Self]:
  969. return self._children
  970. def is_leaf(self) -> bool:
  971. return self.num_nodes == 1 and self.num_leaves == 1
  972. def children(self) -> list[Self]:
  973. return self._children.copy()
  974. def child(self, index: int) -> Self:
  975. return self._children[index]
  976. def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
  977. def helper(treespec: TreeSpec, node: PyTree, subtrees: list[PyTree]) -> None:
  978. if treespec.is_leaf():
  979. subtrees.append(node)
  980. return
  981. node_type = _get_node_type(node)
  982. if treespec.type not in BUILTIN_TYPES:
  983. # Always require custom node types to match exactly
  984. if node_type != treespec.type:
  985. raise ValueError(
  986. f"Type mismatch; "
  987. f"expected {treespec.type!r}, but got {node_type!r}.",
  988. )
  989. flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
  990. children, context = flatten_fn(node)
  991. if len(children) != treespec.num_children:
  992. raise ValueError(
  993. f"Node arity mismatch; "
  994. f"expected {treespec.num_children}, but got {len(children)}.",
  995. )
  996. if context != treespec._context:
  997. raise ValueError(
  998. f"Node context mismatch for custom node type {treespec.type!r}.",
  999. )
  1000. else:
  1001. # For builtin dictionary types, we allow some flexibility
  1002. # Otherwise, we require exact matches
  1003. both_standard_dict = (
  1004. treespec.type in STANDARD_DICT_TYPES
  1005. and node_type in STANDARD_DICT_TYPES
  1006. )
  1007. if not both_standard_dict and node_type != treespec.type:
  1008. raise ValueError(
  1009. f"Node type mismatch; "
  1010. f"expected {treespec.type!r}, but got {node_type!r}.",
  1011. )
  1012. if len(node) != treespec.num_children:
  1013. raise ValueError(
  1014. f"Node arity mismatch; "
  1015. f"expected {treespec.num_children}, but got {len(node)}.",
  1016. )
  1017. if both_standard_dict:
  1018. # dictionary types are compatible with each other
  1019. dict_context = (
  1020. treespec._context
  1021. if treespec.type is not defaultdict
  1022. # ignore mismatch of `default_factory` for defaultdict
  1023. else treespec._context[1]
  1024. )
  1025. expected_keys = dict_context
  1026. got_key_set = set(node)
  1027. expected_key_set = set(expected_keys)
  1028. if got_key_set != expected_key_set:
  1029. missing_keys = expected_key_set.difference(got_key_set)
  1030. extra_keys = got_key_set.difference(expected_key_set)
  1031. message = ""
  1032. if missing_keys:
  1033. message += f"; missing key(s): {missing_keys}"
  1034. if extra_keys:
  1035. message += f"; extra key(s): {extra_keys}"
  1036. raise ValueError(f"Node keys mismatch{message}.")
  1037. children = [node[key] for key in expected_keys]
  1038. else:
  1039. # node_type is treespec.type
  1040. flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
  1041. children, context = flatten_fn(node)
  1042. if (
  1043. node_type is not deque # ignore mismatch of `maxlen` for deque
  1044. ) and context != treespec._context:
  1045. raise ValueError(
  1046. f"Node context mismatch for node type {treespec.type!r}; "
  1047. f"expected {treespec._context!r}, but got {context!r}.", # namedtuple type mismatch
  1048. )
  1049. for subtree, subspec in zip(children, treespec._children, strict=True):
  1050. helper(subspec, subtree, subtrees)
  1051. subtrees: list[PyTree] = []
  1052. helper(self, tree, subtrees)
  1053. return subtrees
  1054. def unflatten(self, leaves: Iterable[Any]) -> PyTree:
  1055. if not isinstance(leaves, (list, tuple)):
  1056. leaves = list(leaves)
  1057. if len(leaves) != self.num_leaves:
  1058. raise ValueError(
  1059. f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
  1060. f"but the spec refers to a pytree that holds {self.num_leaves} "
  1061. f"items ({self}).",
  1062. )
  1063. if self.is_leaf():
  1064. return leaves[0]
  1065. unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn
  1066. # Recursively unflatten the children
  1067. start = 0
  1068. end = 0
  1069. child_pytrees = []
  1070. for child_spec in self._children:
  1071. end += child_spec.num_leaves
  1072. child_pytrees.append(child_spec.unflatten(leaves[start:end]))
  1073. start = end
  1074. return unflatten_fn(child_pytrees, self._context)
  1075. def __hash__(self) -> int:
  1076. node_type = self.type
  1077. if node_type is defaultdict:
  1078. default_factory, dict_context = self._context
  1079. hashable_context = (default_factory, tuple(dict_context))
  1080. elif node_type in (dict, OrderedDict):
  1081. hashable_context = tuple(self._context)
  1082. elif node_type is None or node_type in BUILTIN_TYPES:
  1083. hashable_context = self._context
  1084. elif isinstance(self._context, ConstantNode):
  1085. hashable_context = self._context.value
  1086. else:
  1087. # The context for user-defined node types might not be hashable.
  1088. # Ignore it for hashing.
  1089. # This does not break the correctness that equal objects imply the
  1090. # same hash. This might increase the hash collision rate, but we
  1091. # don't care about that.
  1092. hashable_context = None
  1093. return hash((node_type, hashable_context, tuple(self._children)))
  1094. PyTreeSpec: TypeAlias = TreeSpec
  1095. # NOTE: subclassing a dataclass is subtle. In order to enable reasoning about
  1096. # this class with `dataclasses.fields`, etc., while having a simplified
  1097. # constructor that takes no argument, we wrap with `dataclass(init=True, ...)`
  1098. # again, with fields that have `init=False`.
  1099. @deprecated(
  1100. "`isinstance(treespec, LeafSpec)` is deprecated, "
  1101. "use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.",
  1102. category=FutureWarning,
  1103. )
  1104. @dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False, slots=True)
  1105. class LeafSpec(TreeSpec):
  1106. type: Any = dataclasses.field(default=None, init=False)
  1107. _context: Context = dataclasses.field(default=None, init=False)
  1108. _children: list[Self] = dataclasses.field(default_factory=list, init=False)
  1109. def __post_init__(self) -> None:
  1110. # Override `__post_init__` for `num_leaves` derivation.
  1111. object.__setattr__(self, "num_nodes", 1)
  1112. object.__setattr__(self, "num_leaves", 1)
  1113. object.__setattr__(self, "num_children", 0)
  1114. def __repr__(self, indent: int = 0) -> str:
  1115. return "*"
  1116. # All leaves are equivalent, so represent with a single object to save on
  1117. # object construction time
  1118. with warnings.catch_warnings():
  1119. warnings.filterwarnings(
  1120. "ignore", category=FutureWarning, module=__name__, append=False
  1121. )
  1122. _LEAF_SPEC = LeafSpec()
  1123. def treespec_leaf() -> LeafSpec:
  1124. """Make a treespec representing a leaf node."""
  1125. return _LEAF_SPEC
  1126. def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:
  1127. """Make a tuple treespec from an iterable of child treespecs."""
  1128. children = list(iterable)
  1129. if any(not isinstance(child, TreeSpec) for child in children):
  1130. raise ValueError(f"Expected a tuple of TreeSpec values, got: {children!r}.")
  1131. return TreeSpec(tuple, None, children)
  1132. def treespec_dict(
  1133. mapping: Mapping[Any, TreeSpec] | Iterable[tuple[Any, TreeSpec]] = (),
  1134. /,
  1135. **kwargs: TreeSpec,
  1136. ) -> TreeSpec:
  1137. """Make a dict treespec from a dict of child treespecs."""
  1138. dct = dict(mapping, **kwargs)
  1139. if any(not isinstance(child, TreeSpec) for child in dct.values()):
  1140. raise ValueError(f"Expected a dictionary of TreeSpec values, got: {dct!r}.")
  1141. return TreeSpec(dict, list(dct.keys()), list(dct.values()))
  1142. def _is_pytreespec_instance(
  1143. obj: Any,
  1144. ) -> TypeIs[Union[TreeSpec, "cxx_pytree.PyTreeSpec"]]:
  1145. if isinstance(obj, TreeSpec):
  1146. return True
  1147. if "torch.utils._cxx_pytree" in sys.modules:
  1148. # The C++ pytree module is not always available, so we check if it is loaded.
  1149. # If the C++ pytree module is loaded, we can check if the treespec
  1150. # is an instance of the C++ TreeSpec class.
  1151. import torch.utils._cxx_pytree as cxx_pytree
  1152. if isinstance(obj, cxx_pytree.PyTreeSpec):
  1153. return True
  1154. if "torch._dynamo.polyfills.pytree" in sys.modules:
  1155. # The PyTorch Dynamo pytree module is not always available, so we check if it is loaded.
  1156. # If the PyTorch Dynamo pytree module is loaded, we can check if the treespec
  1157. # is an instance of the PyTorch Dynamo TreeSpec class.
  1158. import torch._dynamo.polyfills.pytree as dynamo_pytree
  1159. return isinstance(obj, dynamo_pytree.PyTreeSpec)
  1160. return False
  1161. def _ensure_python_treespec_instance(
  1162. treespec: Union[TreeSpec, "cxx_pytree.PyTreeSpec"],
  1163. ) -> TreeSpec:
  1164. if isinstance(treespec, TreeSpec):
  1165. return treespec
  1166. if not _is_pytreespec_instance(treespec):
  1167. raise TypeError(
  1168. f"Expected `treespec` to be an instance of "
  1169. f"PyTreeSpec but got item of type {type(treespec)}."
  1170. )
  1171. dummy_tree = treespec.unflatten([0] * treespec.num_leaves)
  1172. return tree_structure(dummy_tree)
  1173. def tree_flatten(
  1174. tree: PyTree,
  1175. is_leaf: Callable[[PyTree], bool] | None = None,
  1176. ) -> tuple[list[Any], TreeSpec]:
  1177. """Flattens a pytree into a list of values and a TreeSpec that can be used
  1178. to reconstruct the pytree.
  1179. """
  1180. def helper(node: PyTree, leaves: list[Any]) -> TreeSpec:
  1181. if tree_is_leaf(node, is_leaf=is_leaf):
  1182. leaves.append(node)
  1183. return _LEAF_SPEC
  1184. node_type = _get_node_type(node)
  1185. flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
  1186. children, context = flatten_fn(node)
  1187. # Recursively flatten the children
  1188. subspecs = [helper(child, leaves) for child in children]
  1189. return TreeSpec(node_type, context, subspecs)
  1190. leaves: list[Any] = []
  1191. treespec = helper(tree, leaves)
  1192. return leaves, treespec
  1193. def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
  1194. """Given a list of values and a TreeSpec, builds a pytree.
  1195. This is the inverse operation of `tree_flatten`.
  1196. """
  1197. if not _is_pytreespec_instance(treespec):
  1198. if not _is_pytreespec_instance(leaves):
  1199. raise TypeError(
  1200. f"Expected `treespec` to be an instance of "
  1201. f"PyTreeSpec but got item of type {type(treespec)}."
  1202. )
  1203. # Allow passing the PyTreeSpec instance as the first argument
  1204. leaves, treespec = treespec, leaves
  1205. return treespec.unflatten(leaves)
  1206. def tree_iter(
  1207. tree: PyTree,
  1208. is_leaf: Callable[[PyTree], bool] | None = None,
  1209. ) -> Iterable[Any]:
  1210. """Get an iterator over the leaves of a pytree."""
  1211. if tree_is_leaf(tree, is_leaf=is_leaf):
  1212. yield tree
  1213. else:
  1214. node_type = _get_node_type(tree)
  1215. flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
  1216. child_pytrees, _ = flatten_fn(tree)
  1217. # Recursively flatten the children
  1218. for child in child_pytrees:
  1219. yield from tree_iter(child, is_leaf=is_leaf)
  1220. def tree_leaves(
  1221. tree: PyTree,
  1222. is_leaf: Callable[[PyTree], bool] | None = None,
  1223. ) -> list[Any]:
  1224. """Get a list of leaves of a pytree."""
  1225. return list(tree_iter(tree, is_leaf=is_leaf))
  1226. def tree_structure(
  1227. tree: PyTree,
  1228. is_leaf: Callable[[PyTree], bool] | None = None,
  1229. ) -> TreeSpec:
  1230. """Get the TreeSpec for a pytree."""
  1231. return tree_flatten(tree, is_leaf=is_leaf)[1]
  1232. def tree_map(
  1233. func: Callable[..., Any],
  1234. tree: PyTree,
  1235. *rests: PyTree,
  1236. is_leaf: Callable[[PyTree], bool] | None = None,
  1237. ) -> PyTree:
  1238. """Map a multi-input function over pytree args to produce a new pytree.
  1239. See also :func:`tree_map_`.
  1240. >>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)})
  1241. {'x': 8, 'y': (43, 65)}
  1242. >>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None})
  1243. {'x': False, 'y': (False, False), 'z': True}
  1244. If multiple inputs are given, the structure of the tree is taken from the first input;
  1245. subsequent inputs need only have ``tree`` as a prefix:
  1246. >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
  1247. [[5, 7, 9], [6, 1, 2]]
  1248. Args:
  1249. func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
  1250. corresponding leaves of the pytrees.
  1251. tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
  1252. argument to function ``func``.
  1253. rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
  1254. ``tree`` or has ``tree`` as a prefix.
  1255. is_leaf (callable, optional): An extra leaf predicate function that will be called at each
  1256. flattening step. The function should have a single argument with signature
  1257. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  1258. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  1259. leaf or not. If the function is not specified, the default pytree registry will be used.
  1260. Returns:
  1261. A new pytree with the same structure as ``tree`` but with the value at each leaf given by
  1262. ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
  1263. is the tuple of values at corresponding nodes in ``rests``.
  1264. """
  1265. leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
  1266. flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
  1267. return treespec.unflatten(map(func, *flat_args))
  1268. def tree_map_(
  1269. func: Callable[..., Any],
  1270. tree: PyTree,
  1271. *rests: PyTree,
  1272. is_leaf: Callable[[PyTree], bool] | None = None,
  1273. ) -> PyTree:
  1274. """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.
  1275. See also :func:`tree_map`.
  1276. Args:
  1277. func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
  1278. corresponding leaves of the pytrees.
  1279. tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
  1280. argument to function ``func``.
  1281. rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
  1282. ``tree`` or has ``tree`` as a prefix.
  1283. is_leaf (callable, optional): An extra leaf predicate function that will be called at each
  1284. flattening step. The function should have a single argument with signature
  1285. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  1286. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  1287. leaf or not. If the function is not specified, the default pytree registry will be used.
  1288. Returns:
  1289. The original ``tree`` with the value at each leaf is given by the side-effect of function
  1290. ``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf
  1291. in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``.
  1292. """
  1293. leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
  1294. flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
  1295. deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
  1296. return tree
# Tuples of exactly two / three type objects, used to type the "tuple of types"
# form of the ``*_only`` helpers below.
Type2 = tuple[type[T], type[S]]
Type3 = tuple[type[T], type[S], type[U]]
# Untyped fallback: a single type, a tuple of types of any length, or an
# ``X | Y`` union object.
TypeAny = type[Any] | tuple[type[Any], ...] | types.UnionType

# Single-argument callables returning ``R``; the argument is one of two types,
# one of three types, exactly ``T``, or anything at all.
Fn2 = Callable[[T | S], R]
Fn3 = Callable[[T | S | U], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

# Shape of the decorator returned by ``map_only``: it takes a function of type
# ``T`` and returns a callable accepting any single argument.
MapOnlyFn = Callable[[T], Callable[[Any], Any]]
  1305. # These specializations help with type inference on the lambda passed to this
  1306. # function
  1307. @overload
  1308. def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ...
  1309. @overload
  1310. def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ...
  1311. @overload
  1312. def map_only(
  1313. type_or_types_or_pred: Type3[T, S, U], /
  1314. ) -> MapOnlyFn[Fn3[T, S, U, Any]]: ...
  1315. # This specialization is needed for the implementations below that call
  1316. @overload
  1317. def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ...
  1318. @overload
  1319. def map_only(
  1320. type_or_types_or_pred: Callable[[Any], bool], /
  1321. ) -> MapOnlyFn[FnAny[Any]]: ...
  1322. def map_only(
  1323. type_or_types_or_pred: TypeAny | Callable[[Any], bool], /
  1324. ) -> MapOnlyFn[FnAny[Any]]:
  1325. """
  1326. Suppose you are writing a tree_map over tensors, leaving everything
  1327. else unchanged. Ordinarily you would have to write:
  1328. def go(t):
  1329. if isinstance(t, Tensor):
  1330. return ...
  1331. else:
  1332. return t
  1333. With this function, you only need to write:
  1334. @map_only(Tensor)
  1335. def go(t):
  1336. return ...
  1337. You can also directly use 'tree_map_only'
  1338. """
  1339. if isinstance(type_or_types_or_pred, (type, tuple, types.UnionType)):
  1340. def pred(x: Any) -> bool:
  1341. return isinstance(x, type_or_types_or_pred) # type: ignore[arg-type]
  1342. elif callable(type_or_types_or_pred):
  1343. pred = type_or_types_or_pred # type: ignore[assignment]
  1344. else:
  1345. raise TypeError("Argument must be a type, a tuple of types, or a callable.")
  1346. def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
  1347. @functools.wraps(func)
  1348. def wrapped(x: T) -> Any:
  1349. if pred(x):
  1350. return func(x)
  1351. return x
  1352. return wrapped
  1353. return wrapper
  1354. @overload
  1355. def tree_map_only(
  1356. type_or_types_or_pred: type[T],
  1357. /,
  1358. func: Fn[T, Any],
  1359. tree: PyTree,
  1360. is_leaf: Callable[[PyTree], bool] | None = None,
  1361. ) -> PyTree: ...
  1362. @overload
  1363. def tree_map_only(
  1364. type_or_types_or_pred: Type2[T, S],
  1365. /,
  1366. func: Fn2[T, S, Any],
  1367. tree: PyTree,
  1368. is_leaf: Callable[[PyTree], bool] | None = None,
  1369. ) -> PyTree: ...
  1370. @overload
  1371. def tree_map_only(
  1372. type_or_types_or_pred: Type3[T, S, U],
  1373. /,
  1374. func: Fn3[T, S, U, Any],
  1375. tree: PyTree,
  1376. is_leaf: Callable[[PyTree], bool] | None = None,
  1377. ) -> PyTree: ...
  1378. @overload
  1379. def tree_map_only(
  1380. type_or_types_or_pred: TypeAny,
  1381. /,
  1382. func: FnAny[Any],
  1383. tree: PyTree,
  1384. is_leaf: Callable[[PyTree], bool] | None = None,
  1385. ) -> PyTree: ...
  1386. @overload
  1387. def tree_map_only(
  1388. type_or_types_or_pred: Callable[[Any], bool],
  1389. /,
  1390. func: FnAny[Any],
  1391. tree: PyTree,
  1392. is_leaf: Callable[[PyTree], bool] | None = None,
  1393. ) -> PyTree: ...
  1394. def tree_map_only(
  1395. type_or_types_or_pred: TypeAny | Callable[[Any], bool],
  1396. /,
  1397. func: FnAny[Any],
  1398. tree: PyTree,
  1399. is_leaf: Callable[[PyTree], bool] | None = None,
  1400. ) -> PyTree:
  1401. return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
  1402. @overload
  1403. def tree_map_only_(
  1404. type_or_types_or_pred: type[T],
  1405. /,
  1406. func: Fn[T, Any],
  1407. tree: PyTree,
  1408. is_leaf: Callable[[PyTree], bool] | None = None,
  1409. ) -> PyTree: ...
  1410. @overload
  1411. def tree_map_only_(
  1412. type_or_types_or_pred: Type2[T, S],
  1413. /,
  1414. func: Fn2[T, S, Any],
  1415. tree: PyTree,
  1416. is_leaf: Callable[[PyTree], bool] | None = None,
  1417. ) -> PyTree: ...
  1418. @overload
  1419. def tree_map_only_(
  1420. type_or_types_or_pred: Type3[T, S, U],
  1421. /,
  1422. func: Fn3[T, S, U, Any],
  1423. tree: PyTree,
  1424. is_leaf: Callable[[PyTree], bool] | None = None,
  1425. ) -> PyTree: ...
  1426. @overload
  1427. def tree_map_only_(
  1428. type_or_types_or_pred: TypeAny,
  1429. /,
  1430. func: FnAny[Any],
  1431. tree: PyTree,
  1432. is_leaf: Callable[[PyTree], bool] | None = None,
  1433. ) -> PyTree: ...
  1434. @overload
  1435. def tree_map_only_(
  1436. type_or_types_or_pred: Callable[[Any], bool],
  1437. /,
  1438. func: FnAny[Any],
  1439. tree: PyTree,
  1440. is_leaf: Callable[[PyTree], bool] | None = None,
  1441. ) -> PyTree: ...
  1442. def tree_map_only_(
  1443. type_or_types_or_pred: TypeAny | Callable[[Any], bool],
  1444. /,
  1445. func: FnAny[Any],
  1446. tree: PyTree,
  1447. is_leaf: Callable[[PyTree], bool] | None = None,
  1448. ) -> PyTree:
  1449. return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
  1450. def tree_all(
  1451. pred: Callable[[Any], bool],
  1452. tree: PyTree,
  1453. is_leaf: Callable[[PyTree], bool] | None = None,
  1454. ) -> bool:
  1455. flat_args = tree_iter(tree, is_leaf=is_leaf)
  1456. return all(map(pred, flat_args))
  1457. def tree_any(
  1458. pred: Callable[[Any], bool],
  1459. tree: PyTree,
  1460. is_leaf: Callable[[PyTree], bool] | None = None,
  1461. ) -> bool:
  1462. flat_args = tree_iter(tree, is_leaf=is_leaf)
  1463. return any(map(pred, flat_args))
  1464. @overload
  1465. def tree_all_only(
  1466. type_or_types: type[T],
  1467. /,
  1468. pred: Fn[T, bool],
  1469. tree: PyTree,
  1470. is_leaf: Callable[[PyTree], bool] | None = None,
  1471. ) -> bool: ...
  1472. @overload
  1473. def tree_all_only(
  1474. type_or_types: Type2[T, S],
  1475. /,
  1476. pred: Fn2[T, S, bool],
  1477. tree: PyTree,
  1478. is_leaf: Callable[[PyTree], bool] | None = None,
  1479. ) -> bool: ...
  1480. @overload
  1481. def tree_all_only(
  1482. type_or_types: Type3[T, S, U],
  1483. /,
  1484. pred: Fn3[T, S, U, bool],
  1485. tree: PyTree,
  1486. is_leaf: Callable[[PyTree], bool] | None = None,
  1487. ) -> bool: ...
  1488. def tree_all_only(
  1489. type_or_types: TypeAny,
  1490. /,
  1491. pred: FnAny[bool],
  1492. tree: PyTree,
  1493. is_leaf: Callable[[PyTree], bool] | None = None,
  1494. ) -> bool:
  1495. flat_args = tree_iter(tree, is_leaf=is_leaf)
  1496. return all(pred(x) for x in flat_args if isinstance(x, type_or_types))
  1497. @overload
  1498. def tree_any_only(
  1499. type_or_types: type[T],
  1500. /,
  1501. pred: Fn[T, bool],
  1502. tree: PyTree,
  1503. is_leaf: Callable[[PyTree], bool] | None = None,
  1504. ) -> bool: ...
  1505. @overload
  1506. def tree_any_only(
  1507. type_or_types: Type2[T, S],
  1508. /,
  1509. pred: Fn2[T, S, bool],
  1510. tree: PyTree,
  1511. is_leaf: Callable[[PyTree], bool] | None = None,
  1512. ) -> bool: ...
  1513. @overload
  1514. def tree_any_only(
  1515. type_or_types: Type3[T, S, U],
  1516. /,
  1517. pred: Fn3[T, S, U, bool],
  1518. tree: PyTree,
  1519. is_leaf: Callable[[PyTree], bool] | None = None,
  1520. ) -> bool: ...
  1521. def tree_any_only(
  1522. type_or_types: TypeAny,
  1523. /,
  1524. pred: FnAny[bool],
  1525. tree: PyTree,
  1526. is_leaf: Callable[[PyTree], bool] | None = None,
  1527. ) -> bool:
  1528. flat_args = tree_iter(tree, is_leaf=is_leaf)
  1529. return any(pred(x) for x in flat_args if isinstance(x, type_or_types))
  1530. # Broadcasts a pytree to the provided TreeSpec and returns the flattened
  1531. # values. If this is not possible, then this function returns None.
  1532. #
  1533. # For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
  1534. # would return [0, 0]. This is useful for part of the vmap implementation:
  1535. # a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
  1536. # broadcastable to the tree structure of `inputs` and we use
  1537. # _broadcast_to_and_flatten to check this.
  1538. def _broadcast_to_and_flatten(
  1539. tree: PyTree,
  1540. treespec: TreeSpec,
  1541. is_leaf: Callable[[PyTree], bool] | None = None,
  1542. ) -> list[Any] | None:
  1543. def broadcast_prefix(
  1544. prefix_tree: PyTree,
  1545. full_tree: PyTree,
  1546. is_leaf: Callable[[PyTree], bool] | None = None,
  1547. ) -> list[Any]:
  1548. result: list[Any] = []
  1549. def add_leaves(x: Any, subtree: PyTree) -> None:
  1550. subtreespec = tree_structure(subtree, is_leaf=is_leaf)
  1551. result.extend([x] * subtreespec.num_leaves)
  1552. tree_map_(
  1553. add_leaves,
  1554. prefix_tree,
  1555. full_tree,
  1556. is_leaf=is_leaf,
  1557. )
  1558. return result
  1559. full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
  1560. try:
  1561. return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
  1562. except ValueError:
  1563. return None
@dataclasses.dataclass
class _TreeSpecSchema:
    """
    Schema used to serialize a TreeSpec.

    It contains the following fields:
    - type: The registered string name of the node type, or None (JSON null)
      for a leaf spec.
    - context: The node context in any JSON-dumpable form; None for a leaf.
    - children_spec: The serialized specs of the children, in order; empty
      for a leaf.
    """

    type: str | None
    context: DumpableContext
    children_spec: list["_TreeSpecSchema"]
class _ProtocolFn(NamedTuple):
    """Pair of converters implementing one TreeSpec serialization protocol."""

    # Converts a TreeSpec into its dumpable representation.
    treespec_to_json: Callable[[TreeSpec], DumpableContext]
    # Rebuilds a TreeSpec from that dumpable representation.
    json_to_treespec: Callable[[DumpableContext], TreeSpec]
# Registry of serialization protocols, keyed by integer protocol version.
_SUPPORTED_PROTOCOLS: dict[int, _ProtocolFn] = {}
  1580. def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
  1581. if treespec.is_leaf():
  1582. return _TreeSpecSchema(None, None, [])
  1583. if treespec.type not in SUPPORTED_SERIALIZED_TYPES:
  1584. raise NotImplementedError(
  1585. f"Serializing {treespec.type} in pytree is not registered.",
  1586. )
  1587. serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type]
  1588. serialized_type_name = serialize_node_def.serialized_type_name
  1589. if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND:
  1590. raise NotImplementedError(
  1591. f"No registered serialization name for {treespec.type} found. "
  1592. "Please update your _register_pytree_node call with a `serialized_type_name` kwarg."
  1593. )
  1594. if serialize_node_def.to_dumpable_context is None:
  1595. try:
  1596. serialized_context = json.dumps(treespec._context, cls=EnumEncoder)
  1597. except TypeError as e:
  1598. raise TypeError(
  1599. "Unable to serialize context. "
  1600. "Please make the context json dump-able, or register a "
  1601. "custom serializer using _register_pytree_node."
  1602. ) from e
  1603. else:
  1604. serialized_context = serialize_node_def.to_dumpable_context(treespec._context)
  1605. child_schemas = [_treespec_to_json(child) for child in treespec._children]
  1606. return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)
  1607. def enum_object_hook(obj: dict[str, Any]) -> Enum | dict[str, Any]:
  1608. if "__enum__" in obj:
  1609. modname, _, classname = obj["fqn"].partition(":")
  1610. mod = importlib.import_module(modname)
  1611. enum_cls = mod
  1612. for attr in classname.split("."):
  1613. enum_cls = getattr(enum_cls, attr)
  1614. enum_cls = cast(type[Enum], enum_cls)
  1615. return enum_cls[obj["name"]]
  1616. return obj
  1617. def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
  1618. if (
  1619. json_schema["type"] is None
  1620. and json_schema["context"] is None
  1621. and len(json_schema["children_spec"]) == 0
  1622. ):
  1623. return _LEAF_SPEC
  1624. if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
  1625. raise NotImplementedError(
  1626. f"Deserializing {json_schema['type']} in pytree is not registered.",
  1627. )
  1628. typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]]
  1629. serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ]
  1630. if serialize_node_def.from_dumpable_context is None:
  1631. try:
  1632. context = json.loads(json_schema["context"], object_hook=enum_object_hook)
  1633. except TypeError as ex:
  1634. raise TypeError(
  1635. "Unable to deserialize context. "
  1636. "Please make the context json load-able, or register a "
  1637. "custom serializer using _register_pytree_node.",
  1638. ) from ex
  1639. else:
  1640. context = serialize_node_def.from_dumpable_context(json_schema["context"])
  1641. children_specs = [
  1642. _json_to_treespec(child_string) for child_string in json_schema["children_spec"]
  1643. ]
  1644. return TreeSpec(typ, context, children_specs)
# Protocol version 1: the recursive JSON schema produced by `_treespec_to_json`.
_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec)
  1646. def treespec_dumps(treespec: TreeSpec, protocol: int | None = None) -> str:
  1647. treespec = _ensure_python_treespec_instance(treespec)
  1648. if protocol is None:
  1649. protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL
  1650. if protocol in _SUPPORTED_PROTOCOLS:
  1651. json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec)
  1652. else:
  1653. raise ValueError(
  1654. f"Unknown protocol {protocol}. "
  1655. f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}",
  1656. )
  1657. str_spec = json.dumps((protocol, dataclasses.asdict(json_spec)), cls=EnumEncoder)
  1658. return str_spec
  1659. @functools.lru_cache
  1660. def treespec_loads(serialized: str) -> TreeSpec:
  1661. protocol, json_schema = json.loads(serialized)
  1662. if protocol in _SUPPORTED_PROTOCOLS:
  1663. return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema)
  1664. raise ValueError(
  1665. f"Unknown protocol {protocol}. "
  1666. f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}",
  1667. )
class _DummyLeaf:
    """Placeholder leaf whose ``repr`` is ``*``; used by `treespec_pprint`."""

    def __repr__(self) -> str:
        return "*"
  1671. def treespec_pprint(treespec: TreeSpec) -> str:
  1672. dummy_tree = tree_unflatten(
  1673. [_DummyLeaf() for _ in range(treespec.num_leaves)],
  1674. treespec,
  1675. )
  1676. return repr(dummy_tree)
# TODO(angelayi): remove this function after OSS/internal stabilize
@deprecated(
    "`pytree_to_str` is deprecated. Please use `treespec_dumps` instead.",
    category=FutureWarning,
)
def pytree_to_str(treespec: TreeSpec) -> str:
    """Deprecated alias that forwards to `treespec_dumps` with the default protocol."""
    return treespec_dumps(treespec)
# TODO(angelayi): remove this function after OSS/internal stabilize
@deprecated(
    "`str_to_pytree` is deprecated. Please use `treespec_loads` instead.",
    category=FutureWarning,
)
def str_to_pytree(json: str) -> TreeSpec:
    """Deprecated alias that forwards to `treespec_loads`.

    Note that the parameter name ``json`` shadows the ``json`` module inside
    this function body.
    """
    return treespec_loads(json)
  1691. def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> list[Any]:
  1692. """Get a flat list of arguments to this function
  1693. A slightly faster version of tree_leaves((args, kwargs))
  1694. """
  1695. leaves: list[Any] = []
  1696. for a in args:
  1697. leaves.extend(tree_iter(a))
  1698. for a in kwargs.values():
  1699. leaves.extend(tree_iter(a))
  1700. return leaves
  1701. def tree_flatten_with_path(
  1702. tree: PyTree,
  1703. is_leaf: Callable[[PyTree], bool] | None = None,
  1704. ) -> tuple[list[tuple[KeyPath, Any]], TreeSpec]:
  1705. """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.
  1706. Args:
  1707. tree: a pytree to flatten. If it contains a custom type, that type must be
  1708. registered with an appropriate `tree_flatten_with_path_fn` when registered
  1709. with :func:`register_pytree_node`.
  1710. is_leaf: An extra leaf predicate function that will be called at each
  1711. flattening step. The function should have a single argument with signature
  1712. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  1713. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  1714. leaf or not. If the function is not specified, the default pytree registry will be used.
  1715. Returns:
  1716. A tuple where the first element is a list of (key path, leaf) pairs, and the
  1717. second element is a :class:`TreeSpec` representing the structure of the flattened
  1718. tree.
  1719. """
  1720. _, treespec = tree_flatten(tree, is_leaf)
  1721. return list(_generate_key_paths((), tree, is_leaf)), treespec
  1722. def tree_leaves_with_path(
  1723. tree: PyTree,
  1724. is_leaf: Callable[[PyTree], bool] | None = None,
  1725. ) -> list[tuple[KeyPath, Any]]:
  1726. """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
  1727. Args:
  1728. tree: a pytree. If it contains a custom type, that type must be
  1729. registered with an appropriate `tree_flatten_with_path_fn` when registered
  1730. with :func:`register_pytree_node`.
  1731. is_leaf: An extra leaf predicate function that will be called at each
  1732. flattening step. The function should have a single argument with signature
  1733. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  1734. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  1735. leaf or not. If the function is not specified, the default pytree registry will be used.
  1736. Returns:
  1737. A list of (key path, leaf) pairs.
  1738. """
  1739. return list(_generate_key_paths((), tree, is_leaf))
  1740. def _generate_key_paths(
  1741. key_path: KeyPath,
  1742. tree: PyTree,
  1743. is_leaf: Callable[[PyTree], bool] | None = None,
  1744. ) -> Iterable[tuple[KeyPath, Any]]:
  1745. if is_leaf and is_leaf(tree):
  1746. yield key_path, tree
  1747. return
  1748. node_type = _get_node_type(tree)
  1749. handler = SUPPORTED_NODES.get(node_type)
  1750. if not handler:
  1751. # This is a leaf
  1752. yield key_path, tree
  1753. return
  1754. flatten_with_keys = handler.flatten_with_keys_fn
  1755. if flatten_with_keys:
  1756. key_children, _ = flatten_with_keys(tree)
  1757. for k, c in key_children:
  1758. yield from _generate_key_paths((*key_path, k), c, is_leaf)
  1759. else:
  1760. # We registered this pytree but didn't add a flatten_with_keys_fn, complain.
  1761. raise ValueError(
  1762. f"Did not find a flatten_with_keys_fn for type: {node_type}. "
  1763. "Please pass a flatten_with_keys_fn argument to register_pytree_node."
  1764. )
  1765. def tree_map_with_path(
  1766. func: Callable[..., Any],
  1767. tree: PyTree,
  1768. *rests: PyTree,
  1769. is_leaf: Callable[[PyTree], bool] | None = None,
  1770. ) -> PyTree:
  1771. """Like :func:`tree_map`, but the provided callable takes an additional key path argument.
  1772. Args:
  1773. func: A function that takes ``2 + len(rests)`` arguments, to be applied at the
  1774. corresponding leaves of the pytrees. The first positional argument
  1775. to ``func`` is the key path of the leaf in question. The second
  1776. positional argument is the value of the leaf.
  1777. tree: A pytree to be mapped over, with each leaf providing the first positional
  1778. argument to function ``func``.
  1779. rests: A tuple of pytrees, each of which has the same structure as
  1780. ``tree`` or has ``tree`` as a prefix.
  1781. is_leaf: An extra leaf predicate function that will be called at each
  1782. flattening step. The function should have a single argument with signature
  1783. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  1784. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  1785. leaf or not. If the function is not specified, the default pytree registry will be used.
  1786. Returns
  1787. A new pytree with the same structure as ``tree`` but with the value at each leaf given by
  1788. ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the
  1789. corresponding leaf in ``tree``, ``x`` is the value at that leaf, and
  1790. ``xs`` is the tuple of values at corresponding nodes in ``rests``.
  1791. """
  1792. keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf)
  1793. keypath_leaves = list(zip(*keypath_leaves, strict=True))
  1794. all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
  1795. return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves, strict=True))
  1796. def keystr(kp: KeyPath) -> str:
  1797. """Given a key path, return a pretty-printed representation."""
  1798. return "".join([str(k) for k in kp])
  1799. def key_get(obj: Any, kp: KeyPath) -> Any:
  1800. """Given an object and a key path, return the value at the key path."""
  1801. for k in kp:
  1802. obj = k.get(obj)
  1803. return obj