_extension.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. import abc
  3. import io
  4. from collections.abc import Sequence
  5. from typing import cast, IO
  6. # introduced as collections.abc.Buffer in Python 3.12
  7. from typing_extensions import Buffer
  8. from torch._utils import try_import
  9. # NOTE: everything in this file is experimental, and subject to
  10. # change. Feedback and bug fixes are always welcome.
  # Names of the optional zstd backend packages. They are also interpolated
  # into the error messages raised when no backend can be imported.
  pyzstd_module_name = "pyzstd"
  zstandard_module_name = "zstandard"

  # Optional backends: ZStandard.is_available treats a None value here as
  # "package not importable".
  pyzstd = try_import(pyzstd_module_name)
  zstandard = try_import(zstandard_module_name)

  __all__ = ["Extension", "StreamTransformExtension", "ZStandard", "ExtensionRegistry"]
  class Extension(abc.ABC):
      """
      Base class for modular features of distributed checkpointing that change
      the layout or format of the written artifacts.

      Extensions can ship with pytorch or be provided by external code. At write
      time the caller supplies a list of extension instances of the appropriate
      type; each instance reports a descriptor string (``get_descriptor``) that
      is later used to reconstruct the extension at read time.
      """

      @abc.abstractmethod
      def get_descriptor(self) -> str:
          """
          Return the descriptor to record in metadata, formatted as
          "extension_name[@local-domain][/version]".
          """

      @staticmethod
      @abc.abstractmethod
      def registry_name() -> str:
          """
          Return the key this class is looked up under, i.e. the
          "extension_name[@local-domain]" part of a descriptor.
          See ExtensionRegistry.from_descriptor_list.
          """

      @staticmethod
      @abc.abstractmethod
      def from_descriptor(version: str) -> "Extension":
          """
          Build an instance from the version portion of a stored descriptor.
          See ExtensionRegistry.from_descriptor_list.
          """
  class StreamTransformExtension(Extension):
      """
      An extension that transforms a byte stream as it passes through, for
      example compression or encryption.

      Implementations should be memory friendly and fast: rather than reading
      the whole input, transforming it, and writing it back, process the data in
      chunks wherever possible -- but not one byte at a time either.
      """

      @abc.abstractmethod
      def transform_from(self, input: IO[bytes]) -> IO[bytes]:
          """
          Wrap the readable stream ``input``. Reading from the returned stream
          reads data from ``input``, transforms it, and hands back the result.
          """

      @abc.abstractmethod
      def transform_to(self, output: IO[bytes]) -> IO[bytes]:
          """
          Wrap the writable stream ``output``. Data written to the returned
          stream is transformed and then written to ``output``.
          """
  class ZStandard(StreamTransformExtension):
      """
      Stream compression in the Zstandard format.

      The ``zstandard`` package is used when it is importable; otherwise the
      ``pyzstd`` package is used through small io adapters. The descriptor
      recorded in metadata is "stream.zstd/1".
      """

      @staticmethod
      def is_available() -> bool:
          """Return True if at least one zstd backend module was imported."""
          return zstandard is not None or pyzstd is not None

      @staticmethod
      # pyrefly: ignore [bad-override]
      def from_descriptor(version: str) -> "ZStandard":
          """
          Recreate the extension from the version part of a stored descriptor.

          Raises:
              ValueError: if the part of ``version`` before the first "." is
                  not "1", or if neither zstd backend is importable.
          """
          if version.partition(".")[0] != "1":
              raise ValueError(f"Unknown extension {version=}")
          if not ZStandard.is_available():
              raise ValueError(
                  f"Stream with ZStandard compression cannot be processed because "
                  f"no module named '{zstandard_module_name}' or '{pyzstd_module_name}'"
              )
          return ZStandard()

      @staticmethod
      def registry_name() -> str:
          return "stream.zstd"

      def __init__(self) -> None:
          super().__init__()
          if not ZStandard.is_available():
              raise ValueError(
                  f"ZStandard extension is unavailable because no module named '{zstandard_module_name}' or '{pyzstd_module_name}'"
              )

      def get_descriptor(self) -> str:
          return f"{self.registry_name()}/1"

      def transform_to(self, output: IO[bytes]) -> IO[bytes]:
          """
          Return a writable stream that zstd-compresses data into ``output``.

          With the pyzstd backend, ``flush()`` pushes buffered data out as a
          complete block without ending the zstd frame, and ``close()`` ends the
          frame, so one stream yields a single frame even if flushed mid-way.
          The pyzstd adapter does not close ``output``.
          """
          if zstandard is not None:
              compressor = zstandard.ZstdCompressor()  # type: ignore[union-attr]
              return compressor.stream_writer(output)

          class Writer(io.RawIOBase):
              def __init__(self, output: IO[bytes]) -> None:
                  super().__init__()
                  # Set once close() has written the end of the frame.
                  self._frame_ended = False
                  self.output = output
                  self.compressor = pyzstd.ZstdCompressor()  # type: ignore[union-attr]

              # Must be spelled ``writable`` to override io.RawIOBase.writable;
              # the misspelled ``writeable`` left the stream reporting False.
              def writable(self) -> bool:
                  return True

              def write(self, b: Buffer) -> int | None:
                  data = memoryview(b)
                  outdata = self.compressor.compress(data)
                  if outdata:
                      self.output.write(outdata)
                  # Report bytes consumed; len() would count items for
                  # buffers whose item size is not one byte.
                  return data.nbytes

              def _drain(self, mode: int) -> None:
                  # Emit whatever the compressor has buffered, then flush output.
                  outdata = self.compressor.flush(mode)
                  if outdata:
                      self.output.write(outdata)
                  self.output.flush()

              def flush(self) -> None:
                  # io.IOBase.close() calls flush(); once the frame has been
                  # ended there is nothing left to emit.
                  if self._frame_ended:
                      return
                  # FLUSH_BLOCK keeps the frame open. Ending it here would split
                  # the output into several frames on every flush.
                  self._drain(pyzstd.ZstdCompressor.FLUSH_BLOCK)  # type: ignore[union-attr]

              def close(self) -> None:
                  if self.closed:
                      return
                  try:
                      if not self._frame_ended:
                          self._drain(pyzstd.ZstdCompressor.FLUSH_FRAME)  # type: ignore[union-attr]
                  finally:
                      self._frame_ended = True
                      super().close()

          return cast(IO[bytes], Writer(output))

      def transform_from(self, input: IO[bytes]) -> IO[bytes]:
          """
          Return a readable stream that yields the decompressed contents of
          ``input``. The pyzstd adapter does not close ``input``.
          """
          if zstandard is not None:
              decompressor = zstandard.ZstdDecompressor()  # type: ignore[union-attr]
              return decompressor.stream_reader(input)

          class Reader(io.RawIOBase):
              # The max zstd block size is 128KB; reading that plus some
              # header overhead lets one read decode at least one block.
              _READ_SIZE = (128 + 6) * 1024

              def __init__(self, input: IO[bytes]) -> None:
                  super().__init__()
                  self.input = input
                  self.decompressor = pyzstd.EndlessZstdDecompressor()  # type: ignore[union-attr]

              def readable(self) -> bool:
                  return True

              def readinto(self, b: Buffer) -> int | None:
                  bview = memoryview(b).cast("B")
                  blen = bview.nbytes
                  if blen == 0:
                      return 0
                  # A return value of 0 means EOF to callers, so keep feeding the
                  # decompressor until it produces output or ``input`` is
                  # exhausted; a single chunk may decode to nothing (e.g. after
                  # a short read).
                  while True:
                      at_eof = False
                      if self.decompressor.needs_input:
                          indata = self.input.read(self._READ_SIZE)
                          if indata is None:
                              # Non-blocking source with no data available yet.
                              return None
                          at_eof = not indata
                      else:
                          # The previous call hit the output limit; drain the
                          # input the decompressor still holds before reading more.
                          indata = b""
                      outdata = self.decompressor.decompress(indata, blen)
                      if outdata or at_eof:
                          break
                  count = len(outdata)
                  bview[:count] = outdata
                  return count

              def seekable(self) -> bool:
                  return False

          return cast(IO[bytes], Reader(input))
  class ExtensionRegistry:
      """
      Maps registry names to Extension classes so that descriptors stored at
      save time can be turned back into Extension instances. ZStandard is
      registered by default.
      """

      def __init__(self) -> None:
          # Populate default registry contents
          self.extensions: dict[str, type[Extension]] = {
              cls.registry_name(): cls for cls in (ZStandard,)
          }

      def register(self, cls: type[Extension]) -> None:
          """
          Register ``cls`` under ``cls.registry_name()``, replacing any class
          previously registered under that name.
          """
          self.extensions[cls.registry_name()] = cls

      def from_descriptor_list(self, descriptors: Sequence[str]) -> Sequence[Extension]:
          """
          Given a sequence of descriptor strings as returned by
          Extension.get_descriptor at save time, creates a sequence of
          Extension instances. The name[@local-domain] preceding the first
          "/" is used to look up an implementation class in the registry, and
          the remainder (the version) is passed to the class's from_descriptor
          static method. A descriptor with no "/version" suffix passes the
          empty string as the version.

          Raises:
              ValueError: if the registry contains no class for a descriptor's
                  name. Exceptions raised by from_descriptor pass through to
                  the caller unchanged.
          """

          def from_descriptor(desc: str) -> Extension:
              # str.partition always returns strings, so ``version`` is "" (never
              # None) when the descriptor has no "/" separator.
              name, _, version = desc.partition("/")
              ext = self.extensions.get(name)
              if ext is None:
                  raise ValueError(f"Unknown extension {name=}")
              return ext.from_descriptor(version)

          return [from_descriptor(desc) for desc in descriptors]