precompile_context.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import copy
  2. import json
  3. import logging
  4. from abc import abstractmethod
  5. from collections import defaultdict
  6. from collections.abc import Callable
  7. from dataclasses import dataclass
  8. from typing import Any, Generic, Optional, TypeVar
  9. import torch
  10. from torch._dynamo.package import (
  11. _BackendId,
  12. _DynamoCacheEntry,
  13. DynamoCache,
  14. PrecompileCacheEntry,
  15. )
  16. """
  17. Classes and implementations related to precompile
  18. """
  19. T = TypeVar("T")
  20. logger = logging.getLogger(__name__)
  21. @dataclass
  22. class BackendCacheArtifact(Generic[T]):
  23. """
  24. Represents a single serializable backend artifact from a dynamo backend.
  25. Each BackendCacheArtifact has a key associated with it along with some
  26. serializable content.
  27. Example implementation:
  28. class MyPrecompileCacheArtifact(PrecompileCacheArtifact[MySerializableType]):
  29. my_field: int
  30. def after_deserialization(self) -> MySerializableType:
  31. result = pickle.loads(self.content)
  32. # Do some extra work post deserialization
  33. result.my_post_deserialization_function(self.my_field)
  34. return result
  35. """
  36. key: str
  37. content: Any
  38. @abstractmethod
  39. def after_deserialization(self) -> T:
  40. """
  41. Code to be run after reading raw byte contents from disk.
  42. Generally converts self.content from raw bytes back into its original form.
  43. """
  44. ...
  45. def edit_contents(self, edit_fn: Callable[..., Any]) -> None:
  46. """
  47. Edit the contents of the artifact.
  48. """
  49. self.content = edit_fn(self.content)
  50. class EagerCacheArtifact(BackendCacheArtifact[Any]):
  51. def after_deserialization(self) -> Any:
  52. return self.content
  53. class BypassDynamoCacheEntry(Exception):
  54. pass
  55. class PrecompileContext:
  56. """
  57. PrecompileContext is a special CacheArtifactManager for handling precompilation
  58. It uses the same interface as CacheArtifactManager, but handles deserialization differently: instead
  59. of placing each artifact into respective caches, it will stitch all the cache artifacts for a single key
  60. together and place it into a global Precompile Cache.
  61. PrecompileContext has two main portions: dynamo_cache_entries and backend_cache_artifacts.
  62. When saving, PrecompileContext.serialize() will serialize all dynamo cache entries along with any PrecompileCacheArtifacts that
  63. are needed to save those dynamo cache entries.
  64. The following artifact types are supported by PrecompileContext:
  65. - BundledAOTAutogradCacheArtifact
  66. """
  67. # Protected by the compile_lock
  68. # _backend_artifacts_by_key organizes results by the key of each artifact.
  69. # Each object here must be serializable
  70. _backend_artifacts_by_key: dict[_BackendId, BackendCacheArtifact[Any]] = {}
  71. # On call to `serialize()`, all cache artifacts in _dynamo_cache_entries are converted
  72. # into DynamoCacheArtifacts and added to _new_cache_artifacts for serialization
  73. _dynamo_cache_entries: dict[str, _DynamoCacheEntry] = {}
  74. @classmethod
  75. def clear(cls) -> None:
  76. cls._backend_artifacts_by_key.clear()
  77. cls._dynamo_cache_entries.clear()
  78. @classmethod
  79. def record_artifact(
  80. cls,
  81. artifact: BackendCacheArtifact[Any],
  82. ) -> None:
  83. """
  84. Records a backend artifact to be used with dynamo cache entries
  85. """
  86. # Temporarily disable all dispatch modes (including FakeTensorMode) during
  87. # deepcopy to avoid issues with cloning fake tensors (e.g., device mesh
  88. # with meta tensors that fail when cloning due to device mismatches)
  89. from torch.utils._mode_utils import no_dispatch
  90. with no_dispatch():
  91. cls._backend_artifacts_by_key[_BackendId(artifact.key)] = copy.deepcopy(
  92. artifact
  93. )
  94. @classmethod
  95. def record_dynamo_cache_entry(
  96. cls, cache_entry: _DynamoCacheEntry, key: str
  97. ) -> None:
  98. cls._dynamo_cache_entries[key] = cache_entry
  99. @classmethod
  100. def edit_artifact(cls, key: str, edit_fn: Callable[..., Any]) -> None:
  101. """
  102. Edit the content of an existing artifact
  103. """
  104. assert key in cls._backend_artifacts_by_key, f"Key {key} not found in artifacts"
  105. artifact = cls._backend_artifacts_by_key[_BackendId(key)]
  106. artifact.edit_contents(edit_fn)
  107. @classmethod
  108. def serialize_artifact_by_key(cls, key: str) -> Optional[BackendCacheArtifact[Any]]:
  109. """
  110. Return the backend cache artifact with the associated key
  111. """
  112. return cls._backend_artifacts_by_key.get(_BackendId(key), None)
  113. @staticmethod
  114. def dump_debug_info(
  115. dynamo_entries: dict[str, _DynamoCacheEntry],
  116. backend_artifacts: dict[_BackendId, BackendCacheArtifact[Any]],
  117. ) -> dict[str, Any]:
  118. """
  119. Return a JSON serializable debug dump of all entries in the precompile context
  120. Called in serialize before serialization, and in populate_caches after deserialization
  121. """
  122. # Print debug information
  123. debug_info: defaultdict[str, list[Any]] = defaultdict(list)
  124. for key, cache_entry in dynamo_entries.items():
  125. info = cache_entry.debug_info()
  126. info["key"] = key
  127. debug_info["dynamo"].append(info)
  128. for artifact in backend_artifacts.values():
  129. debug_info["backends"].append(artifact.key)
  130. return debug_info
  131. @classmethod
  132. def save_to_dynamo_cache(cls) -> dict[str, Any]:
  133. precompile_cache_entries, debug_info = cls.create_cache_entries()
  134. for key, entry in precompile_cache_entries.items():
  135. DynamoCache.write(entry, key)
  136. return debug_info
  137. @classmethod
  138. def create_cache_entries(
  139. cls,
  140. ) -> tuple[dict[str, PrecompileCacheEntry], dict[str, Any]]:
  141. """
  142. Grabs all the cache entries in the precompile context and
  143. stitches them together into full PrecompileCacheEntries.
  144. """
  145. dynamo_entries = cls._dynamo_cache_entries
  146. backend_artifacts = cls._backend_artifacts_by_key
  147. num_artifacts = len(dynamo_entries)
  148. debug_info = PrecompileContext.dump_debug_info(
  149. dynamo_entries, backend_artifacts
  150. )
  151. debug_str = json.dumps(
  152. {
  153. "num_entries": num_artifacts,
  154. "artifacts": debug_info,
  155. },
  156. )
  157. torch._logging.trace_structured(
  158. "artifact",
  159. metadata_fn=lambda: {
  160. "name": "dynamo_cache_entries",
  161. "encoding": "json",
  162. },
  163. payload_fn=lambda: debug_str,
  164. expect_trace_id=False,
  165. )
  166. precompile_cache_entries = {}
  167. for key, cache_entry in dynamo_entries.items():
  168. try:
  169. result = PrecompileCacheEntry.from_cache_entry(
  170. cache_entry, backend_artifacts
  171. )
  172. if result is not None:
  173. precompile_cache_entries[key] = result
  174. except Exception as e:
  175. logger.warning("Failed to create cache entry %s", key, exc_info=True)
  176. error = e
  177. data = json.dumps(
  178. {
  179. "key": key,
  180. "error": str(error),
  181. }
  182. )
  183. torch._logging.trace_structured(
  184. "artifact",
  185. metadata_fn=lambda: {
  186. "name": "dynamo_cache_exception",
  187. "encoding": "json",
  188. },
  189. payload_fn=lambda: data,
  190. )
  191. continue
  192. return precompile_cache_entries, debug_info