_pytree.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. from collections import namedtuple
  2. from collections.abc import Callable
  3. from typing import Any, Optional, TypeVar
  4. from typing_extensions import NamedTuple
  5. import torch.return_types
  6. from torch.utils._pytree import PyTree, tree_flatten, TreeSpec
  7. FlattenFuncSpec = Callable[[PyTree, TreeSpec], list]
  8. FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]
  9. SUPPORTED_NODES: dict[type[Any], FlattenFuncSpec] = {}
  10. SUPPORTED_NODES_EXACT_MATCH: dict[type[Any], Optional[FlattenFuncExactMatchSpec]] = {}
  11. _T = TypeVar("_T")
  12. _K = TypeVar("_K")
  13. _V = TypeVar("_V")
  14. def register_pytree_flatten_spec(
  15. cls: type[Any],
  16. flatten_fn_spec: FlattenFuncSpec,
  17. flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
  18. ) -> None:
  19. SUPPORTED_NODES[cls] = flatten_fn_spec
  20. SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec
  21. def _deregister_pytree_flatten_spec(
  22. cls: type[Any],
  23. ) -> None:
  24. del SUPPORTED_NODES[cls]
  25. del SUPPORTED_NODES_EXACT_MATCH[cls]
  26. def tree_flatten_spec(
  27. pytree: PyTree,
  28. spec: TreeSpec,
  29. ) -> list[Any]:
  30. if spec.is_leaf():
  31. return [pytree]
  32. # I guess these exist for BC, FC reasons.
  33. # In general, we should be able to directly
  34. # use pytree tree flattener to flatten them,
  35. # as export serializes the pytree separately.
  36. # Will remove it in follow up PR.
  37. if spec.type in SUPPORTED_NODES:
  38. flatten_fn_spec = SUPPORTED_NODES[spec.type]
  39. child_pytrees = flatten_fn_spec(pytree, spec)
  40. result = []
  41. for child, child_spec in zip(child_pytrees, spec.children()):
  42. flat = tree_flatten_spec(child, child_spec)
  43. result += flat
  44. return result
  45. flat_result, real_spec = tree_flatten(pytree)
  46. if spec != real_spec:
  47. raise RuntimeError(
  48. f"Real spec {real_spec} of object {pytree} is different from expected spec {spec}. "
  49. f"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml"
  50. )
  51. return flat_result
  52. def _dict_flatten_spec(d: dict[_K, _V], spec: TreeSpec) -> list[_V]:
  53. return [d[k] for k in spec.context]
  54. def _list_flatten_spec(d: list[_T], spec: TreeSpec) -> list[_T]:
  55. return [d[i] for i in range(spec.num_children)]
  56. def _tuple_flatten_spec(d: tuple[_T, ...], spec: TreeSpec) -> list[_T]:
  57. return [d[i] for i in range(spec.num_children)]
  58. def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> list[Any]:
  59. return [d[i] for i in range(spec.num_children)]
  60. def _dict_flatten_spec_exact_match(d: dict[_K, _V], spec: TreeSpec) -> bool:
  61. return len(d) == spec.num_children
  62. def _list_flatten_spec_exact_match(d: list[_T], spec: TreeSpec) -> bool:
  63. return len(d) == spec.num_children
  64. def _tuple_flatten_spec_exact_match(d: tuple[_T, ...], spec: TreeSpec) -> bool:
  65. return len(d) == spec.num_children
  66. def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool:
  67. return len(d) == spec.num_children
  68. register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match)
  69. register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match)
  70. register_pytree_flatten_spec(
  71. tuple,
  72. _tuple_flatten_spec,
  73. _tuple_flatten_spec_exact_match,
  74. )
  75. for return_type in torch.return_types.all_return_types:
  76. register_pytree_flatten_spec(
  77. return_type,
  78. _tuple_flatten_spec,
  79. _tuple_flatten_spec_exact_match,
  80. )
  81. register_pytree_flatten_spec(
  82. namedtuple, # type: ignore[arg-type]
  83. _namedtuple_flatten_spec,
  84. _namedtuple_flatten_spec_exact_match,
  85. )