metadata.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. # mypy: allow-untyped-defs
  2. import os
  3. from collections.abc import Sequence
  4. from dataclasses import dataclass, field
  5. from enum import Enum
  6. from typing import Any, Union
  7. import torch
  8. from torch.distributed.checkpoint.stateful import StatefulT
  9. __all__ = [
  10. "ChunkStorageMetadata",
  11. "TensorStorageMetadata",
  12. "BytesStorageMetadata",
  13. "Metadata",
  14. "MetadataIndex",
  15. "TensorProperties",
  16. "StorageMeta",
  17. ]
  18. @dataclass
  19. class ChunkStorageMetadata:
  20. """
  21. Each chunk is expected to have the same properties of the TensorStorageMetadata
  22. that includes it.
  23. """
  24. offsets: torch.Size
  25. sizes: torch.Size
  26. class _MEM_FORMAT_ENCODING(Enum):
  27. """Describe the memory format of a tensor."""
  28. TORCH_CONTIGUOUS_FORMAT = 0
  29. TORCH_CHANNELS_LAST = 1
  30. TORCH_PRESERVE_FORMAT = 2
  31. @dataclass
  32. class TensorProperties:
  33. """Properties used to create :class:`Tensor`"""
  34. # Regular tensor fields
  35. dtype: torch.dtype = field(default_factory=torch.get_default_dtype)
  36. # This field is deprecated.
  37. layout: torch.layout = field(default=torch.strided)
  38. # This field is deprecated.
  39. requires_grad: bool = False
  40. # This field is deprecated.
  41. memory_format: torch.memory_format = field(default=torch.contiguous_format)
  42. # This field is deprecated.
  43. pin_memory: bool = False
  44. def __getstate__(self):
  45. # Since torch.memory_format cannot be pickled!
  46. memory_format = self.memory_format
  47. if memory_format == torch.contiguous_format:
  48. mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT
  49. elif memory_format == torch.channels_last:
  50. mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST
  51. elif memory_format == torch.preserve_format:
  52. mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT
  53. else:
  54. raise RuntimeError(f"Invalid torch.memory_format: {memory_format}")
  55. return (
  56. self.dtype,
  57. self.layout,
  58. self.requires_grad,
  59. mem_format_encoding,
  60. self.pin_memory,
  61. )
  62. def __setstate__(
  63. self,
  64. state,
  65. ):
  66. (
  67. self.dtype,
  68. self.layout,
  69. self.requires_grad,
  70. mem_format_encoding,
  71. self.pin_memory,
  72. ) = state
  73. if mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT:
  74. memory_format = torch.contiguous_format
  75. elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST:
  76. memory_format = torch.channels_last
  77. elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT:
  78. memory_format = torch.preserve_format
  79. else:
  80. raise RuntimeError(
  81. f"Invalid torch.memory_format encoding: {mem_format_encoding}"
  82. )
  83. self.memory_format = memory_format
  84. @staticmethod
  85. def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties":
  86. return TensorProperties(
  87. dtype=tensor.dtype,
  88. layout=tensor.layout,
  89. requires_grad=tensor.requires_grad,
  90. memory_format=torch.contiguous_format,
  91. pin_memory=tensor.is_pinned(),
  92. )
  93. @dataclass
  94. class TensorStorageMetadata:
  95. properties: TensorProperties
  96. size: torch.Size
  97. chunks: list[ChunkStorageMetadata]
  98. @dataclass
  99. class BytesStorageMetadata:
  100. pass
  101. STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]
  102. STATE_DICT_TYPE = dict[str, StatefulT | Any]
  103. @dataclass
  104. class StorageMeta:
  105. checkpoint_id: str | os.PathLike | None = None
  106. save_id: str | None = None
  107. load_id: str | None = None
  108. modules: list[str] = field(default_factory=list)
  109. @dataclass
  110. class Metadata:
  111. """This class represents the metadata of the checkpoint."""
  112. # Keys are the same from the `state_dict` used.
  113. state_dict_metadata: dict[str, STORAGE_TYPES]
  114. # It is the responsibility of the planner and storage plugins to ensure
  115. # backward compatibility of the planner_data and storage_data. DCP will
  116. # also ensure the backward compatibility of the metadata in this file and
  117. # the metadata of the built-in planner and storage plugins.
  118. planner_data: Any = None
  119. storage_data: Any = None
  120. storage_meta: StorageMeta | None = None
  121. version: str | None = None
  122. @dataclass(frozen=True)
  123. class MetadataIndex:
  124. """This class represents a lookup key for items in a state dict or Metadata."""
  125. fqn: str
  126. """Fully Qualified Name of the object"""
  127. offset: torch.Size | None = None
  128. """If the object is a tensor, offset into the tensor we're looking for"""
  129. index: int | None = field(hash=False, compare=False, default=None)
  130. """
  131. Index hint when searching for tensor chunk to speedup lookups (optional)
  132. A common representation of a sharded tensor is as a list of chunks so to
  133. find the index in such a list you need to linear search it.
  134. When constructing an instance of MetadataIndex that points to that list,
  135. one can provide the index as a hint and it will be probed first before
  136. the linear search and thus making it significantly faster.
  137. """
  138. def __init__(
  139. self,
  140. fqn: str,
  141. offset: Sequence[int] | None = None,
  142. index: int | None = None,
  143. ):
  144. # We must use object.__setattr__ due to frozen=True
  145. object.__setattr__(self, "fqn", fqn)
  146. object.__setattr__(self, "index", index)
  147. if offset is not None:
  148. object.__setattr__(self, "offset", torch.Size(offset))