storage.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. import abc
  2. import os
  3. from dataclasses import dataclass
  4. from typing import Any
  5. from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
  6. from torch.distributed.checkpoint.planner import (
  7. LoadPlan,
  8. LoadPlanner,
  9. SavePlan,
  10. SavePlanner,
  11. )
  12. from torch.futures import Future
  13. __all__ = ["WriteResult", "StorageWriter", "StorageReader"]
  14. @dataclass(frozen=True)
  15. class WriteResult:
  16. index: MetadataIndex
  17. size_in_bytes: int
  18. storage_data: Any
  19. class StorageWriter(abc.ABC):
  20. """
  21. Interface used by ``save_state_dict`` to write to storage.
  22. One StorageWriter instance acts as both the coordinator and the follower
  23. in a distributed checkpoint. As part of initialization, each instance
  24. is told its role.
  25. A subclass should expect the following sequence of calls.
  26. 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id.
  27. 1) (all ranks) set_up_storage_writer()
  28. 2) (all ranks) prepare_local_plan()
  29. 3) (coordinator) prepare_global_plan()
  30. 4) (all ranks) write_data()
  31. 5) (coordinator) finish()
  32. """
  33. @abc.abstractmethod
  34. def reset(self, checkpoint_id: str | os.PathLike | None = None) -> None:
  35. """
  36. Calls to indicates a brand new checkpoint write is going to happen.
  37. A checkpoint_id may be present if users set the checkpoint_id for
  38. this checkpoint write. The meaning of the checkpiont_id is
  39. storage-dependent. It can be a path to a folder/file or a key for
  40. a key-value storage.
  41. Args:
  42. checkpoint_id (Union[str, os.PathLike, None]):
  43. The ID of this checkpoint instance. The meaning of the checkpoint_id
  44. depends on the storage. It can be a path to a folder or to a file.
  45. It can also be a key if the storage is a key-value store.
  46. (Default: ``None``)
  47. """
  48. ...
  49. @abc.abstractmethod
  50. def set_up_storage_writer(
  51. self, is_coordinator: bool, *args: Any, **kwargs: Any
  52. ) -> None:
  53. """
  54. Initialize this instance.
  55. Args:
  56. is_coordinator (bool): Whether this instance is responsible for coordinating
  57. the checkpoint.
  58. """
  59. @abc.abstractmethod
  60. def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
  61. """
  62. Perform storage-specific local planning.
  63. While this method can produce a completely different plan, the recommended
  64. way is to store storage specific data in SavePlan::storage_data.
  65. Args:
  66. plan (SavePlan): The local plan from the ``SavePlanner`` in use.
  67. Returns:
  68. A transformed ``SavePlan`` after storage local planning
  69. """
  70. @abc.abstractmethod
  71. def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
  72. """
  73. Perform centralized planning of storage.
  74. This method is only called on the coordinator instance.
  75. While this method can produce a completely different plan, the preferred
  76. way is to store storage specific data in SavePlan::storage_data.
  77. Args:
  78. plans: A list of ``SavePlan`` instances, one for each rank.
  79. Returns:
  80. A list of transformed ``SavePlan`` after storage global planning
  81. """
  82. @abc.abstractmethod
  83. def write_data(
  84. self, plan: SavePlan, planner: SavePlanner
  85. ) -> Future[list[WriteResult]]:
  86. """
  87. Write all items from ``plan`` using ``planner`` to resolve the data.
  88. A subclass should call ``SavePlanner::resolve_data`` on each item
  89. from the plan to get access to the underlying object to write.
  90. Subclasses should lazily call `resolve_data` as it can allocate memory.
  91. In case of tensors, make following assumptions:
  92. - They might be on any device, including not matching the one on ``WriteItem::tensor_data``
  93. - They might be views or not contiguous. Only the projection needs to be saved.
  94. Args:
  95. plan (SavePlan): The save plan to execute.
  96. planner (SavePlanner): Planner object to be used to resolve items to data.
  97. Returns:
  98. A future that completes to a list of WriteResult
  99. """
  100. @abc.abstractmethod
  101. def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
  102. """
  103. Write the metadata and marks the current checkpoint as successful.
  104. The actual format/schema used for serializing `metadata` is an
  105. implementation detail. The only requirement is that it's recoverable
  106. in to the same object graph.
  107. Args:
  108. metadata (Metadata): metadata for the new checkpoint
  109. results: A list of WriteResults from all ranks.
  110. Returns:
  111. None
  112. """
  113. @classmethod
  114. @abc.abstractmethod
  115. def validate_checkpoint_id(cls, checkpoint_id: str | os.PathLike) -> bool:
  116. """
  117. Check if the given checkpoint_id is supported by the storage. This allow
  118. us to enable automatic storage selection.
  119. """
  120. ...
  121. def storage_meta(self) -> StorageMeta | None:
  122. """
  123. Return the storage-specific metadata. This is used to store additional information
  124. in a checkpoint that can be useful for providing request-level observability. StorageMeta
  125. is passed to the ``SavePlanner`` during save calls. Returns None by default.
  126. TODO: provide an example
  127. """
  128. return None
  129. class StorageReader(abc.ABC):
  130. """
  131. Interface used by ``load_state_dict`` to read from storage.
  132. One StorageReader instance acts as both the coordinator and the follower
  133. in a distributed checkpoint. As part of initialization, each instance
  134. is told its role.
  135. A subclass should expected the following sequence of calls by ``load_state_dict``:
  136. 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id.
  137. 1) (all ranks) read_metadata()
  138. 2) (all ranks) set_up_storage_reader()
  139. 3) (all ranks) prepare_local_plan()
  140. 4) (coordinator) prepare_global_plan()
  141. 5) (all ranks) read_data()
  142. """
  143. @abc.abstractmethod
  144. def reset(self, checkpoint_id: str | os.PathLike | None = None) -> None:
  145. """
  146. Calls to indicates a brand new checkpoint read is going to happen.
  147. A checkpoint_id may be present if users set the checkpoint_id for
  148. this checkpoint read. The meaning of the checkpiont_id is
  149. storage-dependent. It can be a path to a folder/file or a key for
  150. a key-value storage.
  151. Args:
  152. checkpoint_id (Union[str, os.PathLike, None]):
  153. The ID of this checkpoint instance. The meaning of the checkpoint_id
  154. depends on the storage. It can be a path to a folder or to a file.
  155. It can also be a key if the storage is more like a key-value store.
  156. (Default: ``None``)
  157. """
  158. ...
  159. @abc.abstractmethod
  160. def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata:
  161. """
  162. Read the checkpoint metadata.
  163. Returns:
  164. The metadata object associated with the checkpoint being loaded.
  165. """
  166. @abc.abstractmethod
  167. def set_up_storage_reader(
  168. self, metadata: Metadata, is_coordinator: bool, *args: Any, **kwargs: Any
  169. ) -> None:
  170. """
  171. Initialize this instance.
  172. Args:
  173. metadata (Metadata): The metadata schema to use.
  174. is_coordinator (bool): Whether this instance is responsible for coordinating
  175. the checkpoint.
  176. """
  177. @abc.abstractmethod
  178. def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
  179. """
  180. Perform storage-specific local planning.
  181. While this method can produce a completely different plan, the recommended
  182. way is to store storage specific data in LoadPlan::storage_data.
  183. Args:
  184. plan (LoadPlan): The local plan from the ``LoadPlan`` in use.
  185. Returns:
  186. A transformed ``LoadPlan`` after storage local planning
  187. """
  188. @abc.abstractmethod
  189. def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
  190. """
  191. Perform centralized planning of storage loading.
  192. This method is only called on the coordinator instance.
  193. While this method can produce a completely different plan, the preferred
  194. way is to store storage specific data in LoadPlan::storage_data.
  195. Args:
  196. plans: A list of ``LoadPlan`` instances, one for each rank.
  197. Returns:
  198. A list of transformed ``LoadPlan`` after storage global planning
  199. """
  200. @abc.abstractmethod
  201. def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
  202. """
  203. Read all items from ``plan`` using ``planner`` to resolve the data.
  204. A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO
  205. object into the right place.
  206. A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the
  207. tensors that in should load data into.
  208. It's the StorageLayer responsibility to properly schedule any cross device copies
  209. required.
  210. Args:
  211. plan (LoadPlan): The local plan to execute on
  212. planner (LoadPlanner): The planner object to use to resolve items.
  213. Returns:
  214. A future that completes once all reads are finished.
  215. """
  216. @classmethod
  217. @abc.abstractmethod
  218. def validate_checkpoint_id(cls, checkpoint_id: str | os.PathLike) -> bool:
  219. """
  220. Check if the given checkpoint_id is supported by the storage. This allow
  221. us to enable automatic storage selection.
  222. """
  223. ...