_cache.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. import copy
  2. import dataclasses
  3. import logging
  4. from abc import ABC, abstractmethod
  5. from collections import defaultdict
  6. from collections.abc import Generator
  7. from contextlib import contextmanager
  8. from itertools import chain
  9. from typing import Any, Optional
  10. from torch.utils._appending_byte_serializer import (
  11. AppendingByteSerializer,
  12. BytesReader,
  13. BytesWriter,
  14. )
  15. from torch.utils._ordered_set import OrderedSet
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
  17. @dataclasses.dataclass(frozen=True)
  18. class CacheArtifact(ABC):
  19. """
  20. Data for each cache artifact that will be serialized and deserialized
  21. """
  22. key: str
  23. content: bytes = dataclasses.field(repr=False) # Do not display potential binary
  24. @staticmethod
  25. def serialize(writer: BytesWriter, cls: "CacheArtifact") -> None:
  26. writer.write_str(cls.key)
  27. writer.write_bytes(cls.content)
  28. @staticmethod
  29. def deserialize(artifact_type: str, reader: BytesReader) -> "CacheArtifact":
  30. key = reader.read_str()
  31. content = reader.read_bytes()
  32. return CacheArtifactFactory.create(artifact_type, key, content)
  33. @staticmethod
  34. def encode(content: Any) -> bytes:
  35. if not isinstance(content, bytes):
  36. raise AssertionError(f"Expected bytes, got {type(content)}")
  37. return content
  38. @abstractmethod
  39. def populate_cache(self) -> None:
  40. pass
  41. @staticmethod
  42. def type() -> str:
  43. """
  44. Returns the type of the artifact. Must be unique across all CacheArtifact classes.
  45. CacheArtifactFactory.register will add property method to CacheInfo based on this (def {type}_artifacts)
  46. that returns all artifacts for specific cache.
  47. """
  48. raise RuntimeError("CacheArtifact is an abstract class, please use a subclass")
  49. class CacheArtifactFactory:
  50. """
  51. Factory for creating CacheArtifact objects based on their type
  52. """
  53. _artifact_types: dict[str, type[CacheArtifact]] = {}
  54. @classmethod
  55. def register(cls, artifact_cls: type[CacheArtifact]) -> type[CacheArtifact]:
  56. artifact_type_key = artifact_cls.type()
  57. if artifact_cls.type() in cls._artifact_types:
  58. raise AssertionError(
  59. f"Artifact of type={artifact_type_key} already registered in mega-cache artifact factory"
  60. )
  61. cls._artifact_types[artifact_type_key] = artifact_cls
  62. setattr(
  63. CacheInfo,
  64. f"{artifact_type_key}_artifacts",
  65. property(lambda self: self.artifacts[artifact_type_key]),
  66. )
  67. return artifact_cls
  68. @classmethod
  69. def _get_artifact_type(cls, artifact_type_key: str) -> type[CacheArtifact]:
  70. if artifact_type_key not in cls._artifact_types:
  71. raise AssertionError(
  72. f"Artifact of type={artifact_type_key} not registered in mega-cache artifact factory"
  73. )
  74. return cls._artifact_types[artifact_type_key]
  75. @classmethod
  76. def create(cls, artifact_type_key: str, key: str, content: bytes) -> CacheArtifact:
  77. artifact_cls = cls._get_artifact_type(artifact_type_key)
  78. # pyrefly: ignore [bad-instantiation]
  79. return artifact_cls(key, content)
  80. @classmethod
  81. def encode_create(
  82. cls, artifact_type_key: str, key: str, content: Any
  83. ) -> CacheArtifact:
  84. artifact_cls = cls._get_artifact_type(artifact_type_key)
  85. # pyrefly: ignore [bad-instantiation]
  86. return artifact_cls(key, artifact_cls.encode(content))
  87. @dataclasses.dataclass
  88. class CacheInfo:
  89. """
  90. Return value of serialization and deserialization for the purpose of
  91. instrumentation
  92. """
  93. artifacts: defaultdict[str, list[str]] = dataclasses.field(
  94. default_factory=lambda: defaultdict(list)
  95. )
  96. # Methods set by CacheArtifactFactory.register based on CacheArtifact.type()
  97. @property
  98. def inductor_artifacts(self) -> list[str]: # type: ignore[empty-body]
  99. ...
  100. @property
  101. def autotune_artifacts(self) -> list[str]: # type: ignore[empty-body]
  102. ...
  103. @property
  104. def aot_autograd_artifacts(self) -> list[str]: # type: ignore[empty-body]
  105. ...
  106. @property
  107. def pgo_artifacts(self) -> list[str]: # type: ignore[empty-body]
  108. ...
  109. @property
  110. def precompile_artifacts(self) -> list[str]: # type: ignore[empty-body]
  111. ...
  112. def add(self, artifact: CacheArtifact) -> None:
  113. self.artifacts[artifact.type()].append(artifact.key)
  114. def clear(self) -> None:
  115. self.artifacts.clear()
  116. def empty(self) -> bool:
  117. return not self.artifacts
  118. def _serialize_single_cache(
  119. writer: BytesWriter, cls: "tuple[str, list[CacheArtifact]]"
  120. ) -> None:
  121. writer.write_str(cls[0])
  122. writer.write_uint64(len(cls[1]))
  123. for artifact in cls[1]:
  124. CacheArtifact.serialize(writer, artifact)
  125. def _deserialize_single_cache(
  126. reader: BytesReader,
  127. ) -> "tuple[str, list[CacheArtifact]]":
  128. artifacts = []
  129. artifact_type_key = reader.read_str()
  130. num_artifacts = reader.read_uint64()
  131. for _ in range(num_artifacts):
  132. artifacts.append(CacheArtifact.deserialize(artifact_type_key, reader))
  133. return artifact_type_key, artifacts
# Mapping from artifact type key to the list of artifacts of that type.
CacheArtifactsResult = dict[str, list[CacheArtifact]]
  135. class CacheArtifactManager:
  136. """
  137. Lightweight manager class for collecting and processing cache artifacts for
  138. hot loading
  139. Intended Lifecycle:
  140. - Execute code via torch.compile, this will call
  141. CacheArtifactManager.record_artifact on each cache artifact
  142. - Call CacheArtifactManager.serialize to convert all the cache artifacts
  143. to portable format
  144. - Call CacheArtifactManager.deserialize to hot load the cache artifacts on
  145. a potentially different process
  146. NOTE: There's no FB/FC guarantees, results of cache artifacts will not be
  147. used unless code version matches.
  148. """
  149. # Protected by the compile_lock
  150. _new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
  151. # Keep a separate seen artifacts list to make avoid unnecessary duplicates
  152. # This list will not be cleared between serialize() calls
  153. _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
  154. # When serialize() is called, artifacts are transferred from _cache_artifacts to
  155. # internal data structure of the _serializer
  156. # This allows us to only pay the cost of serialization if serialize() is called
  157. _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = (
  158. AppendingByteSerializer(serialize_fn=_serialize_single_cache)
  159. )
  160. _cache_info: CacheInfo = CacheInfo()
  161. @classmethod
  162. def clear(cls) -> None:
  163. cls._new_cache_artifacts.clear()
  164. cls._seen_artifacts.clear()
  165. cls._serializer.clear()
  166. cls._cache_info.clear()
  167. @classmethod
  168. @contextmanager
  169. def with_fresh_cache(cls) -> Generator[None, None, None]:
  170. original_new_cache_artifacts = cls._new_cache_artifacts
  171. original_seen_artifacts = cls._seen_artifacts
  172. original_serializer = cls._serializer
  173. original_cache_info = cls._cache_info
  174. cls._new_cache_artifacts = defaultdict(list)
  175. cls._seen_artifacts = OrderedSet()
  176. cls._serializer = AppendingByteSerializer(serialize_fn=_serialize_single_cache)
  177. cls._cache_info = cls._cache_info.__class__()
  178. try:
  179. yield
  180. finally:
  181. cls._new_cache_artifacts = original_new_cache_artifacts
  182. cls._seen_artifacts = original_seen_artifacts
  183. cls._serializer = original_serializer
  184. cls._cache_info = original_cache_info
  185. @classmethod
  186. def record_artifact(
  187. cls,
  188. artifact_type: str,
  189. key: str,
  190. content: Any,
  191. ) -> None:
  192. """
  193. Called from each caching operation to record the artifact in this
  194. "mega" list
  195. """
  196. artifact = CacheArtifactFactory.encode_create(artifact_type, key, content)
  197. if artifact in cls._seen_artifacts:
  198. return
  199. log.debug("Recording %s", str(artifact))
  200. cls._new_cache_artifacts[artifact_type].append(artifact)
  201. cls._seen_artifacts.add(artifact)
  202. @classmethod
  203. def need_serialize(cls) -> bool:
  204. """
  205. Have we seen new artifacts since last serialize call?
  206. """
  207. return len(cls._new_cache_artifacts) != 0
  208. @classmethod
  209. def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
  210. """
  211. Converts the "mega" list into portable format
  212. """
  213. for artifact in chain(*cls._new_cache_artifacts.values()):
  214. log.debug("saving: %s", artifact)
  215. cls._cache_info.add(artifact)
  216. if cls._cache_info.empty():
  217. # If there are not artifacts, dont just return bytes with
  218. # version.
  219. return None
  220. try:
  221. # We deep copy cls._cache_info since later compilations
  222. # can keep adding to cache_info
  223. info = copy.deepcopy(cls._cache_info)
  224. cls._serializer.extend(cls._new_cache_artifacts.items())
  225. artifact_bytes = cls._serializer.to_bytes()
  226. cls._new_cache_artifacts.clear()
  227. return artifact_bytes, info
  228. except Exception:
  229. log.warning("Failed to pickle cache artifacts", exc_info=True)
  230. return None
  231. @staticmethod
  232. def deserialize(serialized_artifacts: bytes) -> Optional[CacheArtifactsResult]:
  233. """
  234. Converts the portable format back into CacheArtifacts
  235. """
  236. try:
  237. CacheArtifactManager._ensure_cache_artifacts_registered()
  238. artifacts = dict(
  239. AppendingByteSerializer.to_list(
  240. serialized_artifacts,
  241. deserialize_fn=_deserialize_single_cache,
  242. )
  243. )
  244. except Exception:
  245. log.warning("Failed to un-pickle cache artifacts", exc_info=True)
  246. return None
  247. return artifacts
  248. @staticmethod
  249. def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
  250. info = CacheInfo()
  251. for artifact in chain(*artifacts.values()):
  252. log.debug("writing: %s", artifact)
  253. info.add(artifact)
  254. artifact.populate_cache()
  255. return info
  256. @classmethod
  257. def _ensure_cache_artifacts_registered(cls) -> None:
  258. """When deserializing caches in fresh process, we need to ensure that all
  259. cache artifacts are registered in the cache registry. This is done by
  260. simply importing all the cache artifacts already wrapped with register call.
  261. """
  262. from torch._dynamo.package import PrecompileCacheArtifact # noqa: F401
  263. from torch._dynamo.pgo import PGOCacheArtifact # noqa: F401
  264. from torch._functorch._aot_autograd.autograd_cache import ( # noqa: F401
  265. AOTAutogradCacheArtifact,
  266. )
  267. from torch._inductor.codecache import InductorCacheArtifact # noqa: F401
  268. from torch._inductor.runtime.autotune_cache import ( # noqa: F401
  269. AutotuneCacheArtifact,
  270. )