_traverse.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # ruff: noqa: F821
  3. # flake8: noqa: F821
  4. from collections.abc import Callable, Collection, Mapping, MutableMapping
  5. from typing import cast, TypeVar, Union
  6. import torch
  7. from torch.distributed._shard.sharded_tensor.api import ShardedTensor
  8. from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
  9. from torch.distributed.tensor import DTensor
# One step in an object path: a mapping key (str) or a list index (int).
PATH_ITEM = Union[str, int]
# Tuple of path items leading from a state_dict root to a nested value.
OBJ_PATH = tuple[PATH_ITEM, ...]
# Generic value type for ``get_element``'s ``default_value`` and return value.
T = TypeVar("T")
# Any value stored in a state_dict (leaf or container); deliberately unconstrained.
STATE_DICT_ITEM = object
# Container type assumed when walking or building nested state dicts.
CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]
# Public API of this module.
__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"]
  16. def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
  17. return isinstance(value, torch.Tensor)
  18. # TODO: update docstring for traverse.py
  19. def traverse_state_dict(
  20. state_dict: STATE_DICT_TYPE,
  21. visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
  22. keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
  23. ) -> None:
  24. """
  25. Invoke ``visitor`` for each value recursively in ``state_dict``.
  26. Mapping will be traversed and ``visitor`` will be applied to the leaf elements.
  27. ``visitor`` will only be applied to elements in a list or a tuple, if the
  28. container contains tensors or mappings.
  29. """
  30. def _is_terminal(value: STATE_DICT_ITEM) -> bool:
  31. values: Collection[STATE_DICT_ITEM]
  32. if isinstance(value, Mapping):
  33. return False
  34. elif isinstance(value, list):
  35. values = value
  36. else:
  37. return True
  38. for entry in values:
  39. if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
  40. return False
  41. if keep_traversing is not None and keep_traversing(entry):
  42. return False
  43. return True
  44. def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
  45. if isinstance(value, Mapping):
  46. for k, v in value.items():
  47. _traverse_obj(path + (str(k),), v)
  48. elif _is_terminal(value):
  49. visitor(path, value)
  50. elif isinstance(value, (list, tuple)):
  51. for i, v in enumerate(value):
  52. _traverse_obj(path + (i,), v)
  53. for key, value in state_dict.items():
  54. _traverse_obj((str(key),), value)
  55. # release reference cycle to prevent memory leaks in async_save
  56. del _traverse_obj, _is_terminal
  57. def traverse_state_dict_v_2_3(
  58. state_dict: STATE_DICT_TYPE,
  59. visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
  60. keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
  61. ) -> None:
  62. """
  63. Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates
  64. to false for all elements.
  65. By default, all collections with at least one ``torch.Tensor`` element are traversed.
  66. Visitor takes a path argument that is a tuple of the keys used to reach it.
  67. """
  68. # a value is terminal if it has no other containers values inside it
  69. def _is_terminal(value: STATE_DICT_ITEM) -> bool:
  70. values: Collection[STATE_DICT_ITEM]
  71. if isinstance(value, Mapping):
  72. values = value.values()
  73. elif isinstance(value, list):
  74. values = value
  75. else:
  76. return True
  77. for entry in values:
  78. if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
  79. return False
  80. if keep_traversing is not None and keep_traversing(entry):
  81. return False
  82. return True
  83. def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
  84. if _is_terminal(value):
  85. visitor(path, value)
  86. elif isinstance(value, Mapping):
  87. for k, v in value.items():
  88. _traverse_obj(path + (str(k),), v)
  89. elif isinstance(value, list):
  90. for i, v in enumerate(value):
  91. _traverse_obj(path + (i,), v)
  92. for key, value in state_dict.items():
  93. _traverse_obj((str(key),), value)
  94. # release reference cycle to prevent memory leaks in async_save
  95. del _traverse_obj, _is_terminal
  96. def set_element(
  97. root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
  98. ) -> None:
  99. """Set ``value`` in ``root_dict`` along the ``path`` object path."""
  100. cur_container = cast(CONTAINER_TYPE, root_dict)
  101. def extend_list(lst: list[STATE_DICT_ITEM], idx: int) -> None:
  102. while len(lst) <= idx:
  103. lst.append(None)
  104. for i in range(1, len(path)):
  105. prev_key = path[i - 1]
  106. key = path[i]
  107. def_val = cast(STATE_DICT_ITEM, {} if type(key) is str else [])
  108. if isinstance(cur_container, Mapping):
  109. cur_container = cast(
  110. CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
  111. )
  112. else:
  113. # pyrefly: ignore [bad-argument-type]
  114. extend_list(cur_container, prev_key)
  115. if cur_container[prev_key] is None:
  116. cur_container[prev_key] = def_val
  117. cur_container = cur_container[prev_key]
  118. key = path[-1]
  119. if type(key) is int:
  120. extend_list(cast(list[STATE_DICT_ITEM], cur_container), key)
  121. cur_container[key] = value
  122. def get_element(
  123. root_dict: STATE_DICT_TYPE,
  124. path: OBJ_PATH,
  125. default_value: T | None = None,
  126. ) -> T | None:
  127. """Retrieve the value at ``path``from ``root_dict``, returning ``default_value`` if not found."""
  128. cur_value = cast(CONTAINER_TYPE, root_dict)
  129. for part in path:
  130. if type(part) is int:
  131. if not isinstance(cur_value, list) or len(cur_value) < part:
  132. return default_value
  133. elif not isinstance(cur_value, Mapping) or part not in cur_value:
  134. return default_value
  135. cur_value = cast(CONTAINER_TYPE, cur_value[part])
  136. return cast(T | None, cur_value)
  137. def _print_nested(
  138. value: STATE_DICT_ITEM,
  139. prefix: str = "",
  140. print_fun: Callable[[str], None] = print,
  141. ) -> None:
  142. if type(value) is ShardedTensor:
  143. print_fun(f"{prefix} ShardedTensor size: {value.size()}")
  144. for shard in value.local_shards():
  145. _print_nested(
  146. shard.tensor,
  147. f"{shard.metadata.shard_offsets} ",
  148. print_fun=print_fun,
  149. )
  150. elif type(value) is (DTensor):
  151. print_fun(f"{prefix} DistributedTensor size: {value.size()}")
  152. # TODO: add local offset for _local_tensor in print_nested.
  153. _print_nested(
  154. value._local_tensor,
  155. print_fun=print_fun,
  156. )
  157. elif isinstance(value, torch.Tensor):
  158. print_fun(f"{prefix} Tensor size: {value.size()}")
  159. else:
  160. print_fun(f"{prefix} Type: {type(value)}")
  161. def print_tensor(
  162. path: OBJ_PATH,
  163. value: STATE_DICT_ITEM,
  164. print_fun: Callable[[str], None] = print,
  165. ) -> None:
  166. """
  167. Use this callback with traverse_state_dict to print its content.
  168. By default the content is printed using the builtin ``print`` but this can
  169. be change by passing a different ``print_fun` callable.
  170. """
  171. _print_nested(value, prefix=str(path), print_fun=print_fun)