| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231 |
- import copy
- import json
- import logging
- from abc import abstractmethod
- from collections import defaultdict
- from collections.abc import Callable
- from dataclasses import dataclass
- from typing import Any, Generic, Optional, TypeVar
- import torch
- from torch._dynamo.package import (
- _BackendId,
- _DynamoCacheEntry,
- DynamoCache,
- PrecompileCacheEntry,
- )
- """
- Classes and implementations related to precompile
- """
- T = TypeVar("T")
- logger = logging.getLogger(__name__)
- @dataclass
- class BackendCacheArtifact(Generic[T]):
- """
- Represents a single serializable backend artifact from a dynamo backend.
- Each BackendCacheArtifact has a key associated with it along with some
- serializable content.
- Example implementation:
- class MyPrecompileCacheArtifact(PrecompileCacheArtifact[MySerializableType]):
- my_field: int
- def after_deserialization(self) -> MySerializableType:
- result = pickle.loads(self.content)
- # Do some extra work post deserialization
- result.my_post_deserialization_function(self.my_field)
- return result
- """
- key: str
- content: Any
- @abstractmethod
- def after_deserialization(self) -> T:
- """
- Code to be run after reading raw byte contents from disk.
- Generally converts self.content from raw bytes back into its original form.
- """
- ...
- def edit_contents(self, edit_fn: Callable[..., Any]) -> None:
- """
- Edit the contents of the artifact.
- """
- self.content = edit_fn(self.content)
- class EagerCacheArtifact(BackendCacheArtifact[Any]):
- def after_deserialization(self) -> Any:
- return self.content
- class BypassDynamoCacheEntry(Exception):
- pass
- class PrecompileContext:
- """
- PrecompileContext is a special CacheArtifactManager for handling precompilation
- It uses the same interface as CacheArtifactManager, but handles deserialization differently: instead
- of placing each artifact into respective caches, it will stitch all the cache artifacts for a single key
- together and place it into a global Precompile Cache.
- PrecompileContext has two main portions: dynamo_cache_entries and backend_cache_artifacts.
- When saving, PrecompileContext.serialize() will serialize all dynamo cache entries along with any PrecompileCacheArtifacts that
- are needed to save those dynamo cache entries.
- The following artifact types are supported by PrecompileContext:
- - BundledAOTAutogradCacheArtifact
- """
- # Protected by the compile_lock
- # _backend_artifacts_by_key organizes results by the key of each artifact.
- # Each object here must be serializable
- _backend_artifacts_by_key: dict[_BackendId, BackendCacheArtifact[Any]] = {}
- # On call to `serialize()`, all cache artifacts in _dynamo_cache_entries are converted
- # into DynamoCacheArtifacts and added to _new_cache_artifacts for serialization
- _dynamo_cache_entries: dict[str, _DynamoCacheEntry] = {}
- @classmethod
- def clear(cls) -> None:
- cls._backend_artifacts_by_key.clear()
- cls._dynamo_cache_entries.clear()
- @classmethod
- def record_artifact(
- cls,
- artifact: BackendCacheArtifact[Any],
- ) -> None:
- """
- Records a backend artifact to be used with dynamo cache entries
- """
- # Temporarily disable all dispatch modes (including FakeTensorMode) during
- # deepcopy to avoid issues with cloning fake tensors (e.g., device mesh
- # with meta tensors that fail when cloning due to device mismatches)
- from torch.utils._mode_utils import no_dispatch
- with no_dispatch():
- cls._backend_artifacts_by_key[_BackendId(artifact.key)] = copy.deepcopy(
- artifact
- )
- @classmethod
- def record_dynamo_cache_entry(
- cls, cache_entry: _DynamoCacheEntry, key: str
- ) -> None:
- cls._dynamo_cache_entries[key] = cache_entry
- @classmethod
- def edit_artifact(cls, key: str, edit_fn: Callable[..., Any]) -> None:
- """
- Edit the content of an existing artifact
- """
- assert key in cls._backend_artifacts_by_key, f"Key {key} not found in artifacts"
- artifact = cls._backend_artifacts_by_key[_BackendId(key)]
- artifact.edit_contents(edit_fn)
- @classmethod
- def serialize_artifact_by_key(cls, key: str) -> Optional[BackendCacheArtifact[Any]]:
- """
- Return the backend cache artifact with the associated key
- """
- return cls._backend_artifacts_by_key.get(_BackendId(key), None)
- @staticmethod
- def dump_debug_info(
- dynamo_entries: dict[str, _DynamoCacheEntry],
- backend_artifacts: dict[_BackendId, BackendCacheArtifact[Any]],
- ) -> dict[str, Any]:
- """
- Return a JSON serializable debug dump of all entries in the precompile context
- Called in serialize before serialization, and in populate_caches after deserialization
- """
- # Print debug information
- debug_info: defaultdict[str, list[Any]] = defaultdict(list)
- for key, cache_entry in dynamo_entries.items():
- info = cache_entry.debug_info()
- info["key"] = key
- debug_info["dynamo"].append(info)
- for artifact in backend_artifacts.values():
- debug_info["backends"].append(artifact.key)
- return debug_info
- @classmethod
- def save_to_dynamo_cache(cls) -> dict[str, Any]:
- precompile_cache_entries, debug_info = cls.create_cache_entries()
- for key, entry in precompile_cache_entries.items():
- DynamoCache.write(entry, key)
- return debug_info
- @classmethod
- def create_cache_entries(
- cls,
- ) -> tuple[dict[str, PrecompileCacheEntry], dict[str, Any]]:
- """
- Grabs all the cache entries in the precompile context and
- stitches them together into full PrecompileCacheEntries.
- """
- dynamo_entries = cls._dynamo_cache_entries
- backend_artifacts = cls._backend_artifacts_by_key
- num_artifacts = len(dynamo_entries)
- debug_info = PrecompileContext.dump_debug_info(
- dynamo_entries, backend_artifacts
- )
- debug_str = json.dumps(
- {
- "num_entries": num_artifacts,
- "artifacts": debug_info,
- },
- )
- torch._logging.trace_structured(
- "artifact",
- metadata_fn=lambda: {
- "name": "dynamo_cache_entries",
- "encoding": "json",
- },
- payload_fn=lambda: debug_str,
- expect_trace_id=False,
- )
- precompile_cache_entries = {}
- for key, cache_entry in dynamo_entries.items():
- try:
- result = PrecompileCacheEntry.from_cache_entry(
- cache_entry, backend_artifacts
- )
- if result is not None:
- precompile_cache_entries[key] = result
- except Exception as e:
- logger.warning("Failed to create cache entry %s", key, exc_info=True)
- error = e
- data = json.dumps(
- {
- "key": key,
- "error": str(error),
- }
- )
- torch._logging.trace_structured(
- "artifact",
- metadata_fn=lambda: {
- "name": "dynamo_cache_exception",
- "encoding": "json",
- },
- payload_fn=lambda: data,
- )
- continue
- return precompile_cache_entries, debug_info
|