cache.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. from __future__ import annotations
  2. import pickle
  3. from abc import ABC, abstractmethod
  4. from ast import literal_eval
  5. from functools import cached_property
  6. from hashlib import sha256
  7. from os import getenv
  8. from pathlib import Path
  9. from tempfile import gettempdir
  10. from threading import Lock
  11. from typing import Any, Generic, TYPE_CHECKING, TypeVar
  12. from typing_extensions import assert_never, override, Self
  13. from torch.utils._filelock import FileLock
  14. if TYPE_CHECKING:
  15. from concurrent.futures import Future, ThreadPoolExecutor
  # TypeVars can't be recursive, so generic types that fall within
  # Key or Value can't be bound properly; for example, Key should
  # only take tuples of other Key types: tuple[Key, ...]. This is
  # a known shortcoming of torch's typing, which is why the container
  # element types below are spelled as Any.
  #
  # Key: constrained to str, int, or a tuple with arbitrary elements.
  Key = TypeVar("Key", str, int, tuple[Any, ...])
  # Value: constrained to str, int, bytes, or a tuple/dict/list with
  # arbitrary elements.
  Value = TypeVar("Value", str, int, tuple[Any, ...], bytes, dict[Any, Any], list[Any])
  class CacheError(ValueError):
      """
      Error type raised when a cache operation fails.

      Subclasses ValueError, so handlers written for ValueError also
      catch CacheError.
      """
  class Cache(ABC, Generic[Key, Value]):
      """
      Interface shared by every cache implementation.

      Concrete subclasses must implement both ``get`` and ``insert``;
      the class cannot be instantiated until they do.
      """

      @abstractmethod
      def get(self: Self, key: Key) -> Value | None:
          """
          Look up the value stored under ``key``.

          Args:
              key (Key): The key to look up.

          Returns:
              Value | None: The cached value if present, else None.
          """

      @abstractmethod
      def insert(self: Self, key: Key, value: Value) -> bool:
          """
          Store ``value`` under ``key``.

          Args:
              key (Key): The key to insert.
              value (Value): The value to associate with the key.

          Returns:
              bool: True if the value was inserted, False if the key already exists.
          """
  class InMemoryCache(Cache[Key, Value]):
      """
      In-memory cache implementation using a dictionary and thread lock.

      Entries are never overwritten: ``insert`` refuses keys that are already
      present. ``get`` and ``insert`` access the dictionary only while holding
      ``self._lock``.
      """

      # Exceptions the pickle documentation lists as possible when loading
      # corrupt or truncated data. All of them are reported as CacheError so
      # callers only have to handle one exception type.
      _UNPICKLE_ERRORS: tuple[type[Exception], ...] = (
          pickle.UnpicklingError,
          AttributeError,
          EOFError,
          ImportError,
          IndexError,
      )

      def __init__(self: Self) -> None:
          """
          Initialize an empty in-memory cache.
          """
          self._cache: dict[Key, Value] = {}
          self._lock: Lock = Lock()

      @override
      def get(self: Self, key: Key) -> Value | None:
          """
          Retrieve a value from the cache.

          Args:
              key (Key): The key to look up.

          Returns:
              Value | None: The cached value if present, else None.
          """
          with self._lock:
              # dict.get already yields None for a missing key
              return self._cache.get(key)

      @override
      def insert(self: Self, key: Key, value: Value) -> bool:
          """
          Insert a value into the cache.

          Args:
              key (Key): The key to insert.
              value (Value): The value to associate with the key.

          Returns:
              bool: True if the value was inserted, False if the key already exists.
          """
          with self._lock:
              if key in self._cache:
                  # no overwrites for insert!
                  return False
              self._cache[key] = value
              return True

      @staticmethod
      def _bytes_from_repr(bytes_repr: str, label: str, kv_pair: str) -> bytes:
          """
          Evaluate the textual repr of a bytes literal and return the bytes.

          Args:
              bytes_repr (str): Text expected to be a Python bytes literal.
              label (str): Name used for this field in the error message.
              kv_pair (str): The full key/value entry, used in the error message.

          Returns:
              bytes: The decoded bytes object.

          Raises:
              CacheError: If the text is not a valid literal, or is a literal
                  of some type other than bytes.
          """
          try:
              decoded = literal_eval(bytes_repr)
          except (
              ValueError,
              TypeError,
              SyntaxError,
              MemoryError,
              RecursionError,
          ) as err:
              raise CacheError(
                  f"Malformed {label} {bytes_repr!r} in kv_pair {kv_pair!r}, encoding is invalid."
              ) from err
          if not isinstance(decoded, bytes):
              # e.g. "5" or "'text'": a valid literal, but pickle.loads would
              # fail on it with a bare TypeError
              raise CacheError(
                  f"Malformed {label} {bytes_repr!r} in kv_pair {kv_pair!r}, encoding is invalid."
              )
          return decoded

      @classmethod
      def _unpickle_entry(
          cls, data: bytes, bytes_repr: str, label: str, kv_pair: str
      ) -> Any:
          """
          Unpickle one key or value taken from an environment variable entry.

          Args:
              data (bytes): The pickled payload.
              bytes_repr (str): Original textual form, used in the error message.
              label (str): Name used for this field in the error message.
              kv_pair (str): The full key/value entry, used in the error message.

          Returns:
              Any: The unpickled object.

          Raises:
              CacheError: If the payload cannot be unpickled.
          """
          try:
              # SECURITY: unpickling can execute arbitrary code embedded in the
              # payload; the environment variable must come from a trusted source.
              return pickle.loads(data)
          except cls._UNPICKLE_ERRORS as err:
              raise CacheError(
                  f"Malformed {label} {bytes_repr!r} in kv_pair {kv_pair!r}, not un-pickle-able."
              ) from err

      @classmethod
      def from_env_var(cls, env_var: str) -> Self:
          """
          Create an in-memory cache from an environment variable.

          The variable holds ``;``-separated entries. Each entry is
          ``<key>,<value>``, where both sides are the repr of a bytes literal
          containing a pickled object. Whitespace around entries and around
          each side is ignored, as are empty entries.

          NOTE: the separators are matched textually (split on ``;``, then on
          the first ``,``), so an entry whose key literal itself contains one
          of those characters will be split in the wrong place.

          Args:
              env_var (str): Name of the environment variable containing cache data.

          Returns:
              InMemoryCache: An instance populated from the environment variable.

          Raises:
              CacheError: If the environment variable is malformed or contains invalid data.
          """
          cache = cls()

          if (env_val := getenv(env_var)) is None:
              # env_var doesn't exist = empty cache
              return cache

          for kv_pair in env_val.split(";"):
              # ignore whitespace prefix/suffix
              kv_pair = kv_pair.strip()

              if not kv_pair:
                  # kv_pair could be '' if env_val is '' or has ; suffix
                  continue

              try:
                  # keys and values should be comma separated
                  key_bytes_repr, value_bytes_repr = kv_pair.split(",", 1)
              except ValueError as err:
                  raise CacheError(
                      f"Malformed kv_pair {kv_pair!r} from env_var {env_var!r}, likely missing comma separator."
                  ) from err

              # ignore whitespace prefix/suffix, again
              key_bytes_repr = key_bytes_repr.strip()
              value_bytes_repr = value_bytes_repr.strip()

              # validate both encodings before unpickling either side, so an
              # invalid encoding is always reported ahead of an unpickling failure
              key_bytes = cls._bytes_from_repr(
                  key_bytes_repr, "key_bytes_repr", kv_pair
              )
              value_bytes = cls._bytes_from_repr(
                  value_bytes_repr, "value_bytes_repr", kv_pair
              )

              key = cls._unpickle_entry(
                  key_bytes, key_bytes_repr, "key_bytes_repr", kv_pair
              )
              value = cls._unpickle_entry(
                  value_bytes, value_bytes_repr, "value_bytes_repr", kv_pair
              )

              try:
                  inserted = cache.insert(key, value)
              except TypeError as err:
                  # the unpickled key cannot be used as a dictionary key
                  raise CacheError(
                      f"Malformed key {key!r} in kv_pair {kv_pair!r}, key is not hashable."
                  ) from err

              # true duplicates, i.e. multiple occurrences of the same key => value
              # mapping are ok and treated as a no-op; key duplicates with differing
              # values, i.e. key => value_1 and key => value_2 where value_1 != value_2,
              # are not okay since we don't allow overwriting cached values (it's bad regardless)
              if not inserted and (existing := cache.get(key)) != value:
                  raise CacheError(
                      f"Multiple values for key {key!r} found, got {existing!r} and {value!r}."
                  )

          return cache

      @classmethod
      def from_file_path(cls, fpath: Path) -> Self:
          """
          Create an in-memory cache from a file path.

          Args:
              fpath (Path): Path to the file containing pickled cache data.

          Returns:
              InMemoryCache: An instance populated from the file, or an empty
                  instance if ``fpath`` is not an existing file.

          Raises:
              CacheError: If the file is not a valid pickled dictionary.
          """
          cache = cls()

          if not fpath.is_file():
              # fpath doesn't exist = empty cache
              return cache

          try:
              with open(fpath, "rb") as fp:
                  # SECURITY: unpickling can execute arbitrary code embedded in
                  # the file; only load files written by a trusted process.
                  loaded = pickle.load(fp)
          except cls._UNPICKLE_ERRORS as err:
              raise CacheError(
                  f"Failed to create cache from file path {fpath}, file contents are un-pickle-able."
              ) from err

          if not isinstance(loaded, dict):
              raise CacheError(
                  f"Failed to create cache from file path {fpath}, file contents not pickled dict[Key, Value]."
              )

          # only adopt the payload once it is known to be a dictionary
          cache._cache = loaded
          return cache
  class AsyncCache(Cache[Key, Value]):
      """
      Cache that can also run its operations on a ThreadPoolExecutor.

      Adds ``get_async`` and ``insert_async``, which submit the regular
      ``get`` and ``insert`` methods to a caller-supplied executor. This
      class does not implement ``get`` or ``insert`` itself.
      """

      def get_async(
          self: Self, key: Key, executor: ThreadPoolExecutor
      ) -> Future[Value | None]:
          """
          Submit a ``get`` for ``key`` to ``executor``.

          Args:
              key (Key): The key to look up.
              executor (ThreadPoolExecutor): Executor for async execution.

          Returns:
              Future[Value | None]: Future for the cached value or None.
          """
          lookup = self.get
          pending: Future[Value | None] = executor.submit(lookup, key)
          return pending

      def insert_async(
          self: Self, key: Key, value: Value, executor: ThreadPoolExecutor
      ) -> Future[bool]:
          """
          Submit an ``insert`` of ``key`` and ``value`` to ``executor``.

          Args:
              key (Key): The key to insert.
              value (Value): The value to associate with the key.
              executor (ThreadPoolExecutor): Executor for async execution.

          Returns:
              Future[bool]: Future for the result of insertion.
          """
          store = self.insert
          pending: Future[bool] = executor.submit(store, key, value)
          return pending
  class OnDiskCache(AsyncCache[Key, Value]):
      """
      On-disk cache implementation using files and file locks.

      Each key is stored in its own file under ``base_dir``. The file name is
      derived from a hash of the pickled key, and the file content is a short
      version prefix followed by the pickled value. Files are created with
      exclusive-creation mode, so an existing entry is never overwritten.
      Supports custom cache directory names.

      Attributes:
          version (int): The version used for cache versioning.
          name (str): The name of the cache directory.
      """

      version: int = 0

      # Exceptions raised when an object cannot be pickled. Besides
      # PicklingError, pickle raises AttributeError or TypeError for some
      # unpicklable objects (e.g. locally defined callables, locks).
      _PICKLE_ERRORS: tuple[type[Exception], ...] = (
          AttributeError,
          TypeError,
          pickle.PicklingError,
      )

      # Exceptions the pickle documentation lists as possible when loading
      # corrupt or truncated data.
      _UNPICKLE_ERRORS: tuple[type[Exception], ...] = (
          pickle.UnpicklingError,
          AttributeError,
          EOFError,
          ImportError,
          IndexError,
      )

      def __init__(self: Self, name: str | None = None) -> None:
          """
          Initialize an on-disk cache instance.

          Args:
              name (str | None, optional): The name of the cache directory. If None
                  (or empty), defaults to "on_disk_cache".
          """
          self.name = name or "on_disk_cache"

      @cached_property
      def base_dir(self: Self) -> Path:
          """
          Get the base directory for the cache.

          Returns:
              Path: ``<system temp dir>/cache/<name>``, computed once per instance.
          """
          return Path(gettempdir()) / "cache" / self.name

      def _fpath_from_key(self: Self, key: Key) -> Path:
          """
          Get the file path for a given key.

          Args:
              key (Key): The key to convert to a file path.

          Returns:
              Path: ``base_dir`` joined with the first 32 hex characters of the
                  SHA-256 digest of the pickled key.

          Raises:
              CacheError: If the key is not pickle-able.
          """
          try:
              key_bytes = pickle.dumps(key)
          except self._PICKLE_ERRORS as err:
              raise CacheError(
                  f"Failed to get fpath for key {key!r}, key is not pickle-able."
              ) from err
          return self.base_dir / sha256(key_bytes).hexdigest()[:32]

      def _flock_from_fpath(self: Self, fpath: Path) -> FileLock:
          """
          Get a file lock for a given file path.

          Args:
              fpath (Path): The file path.

          Returns:
              FileLock: Lock at ``<fpath.parent>/locks/<first 4 chars of name>.lock``.
          """
          # fpath.name is a hex digest, meaning there are 16^4 potential values
          # for fpath.name[:4]; this is more than enough unique locks to not
          # cause additional overhead from shared locks and it also saves our
          # cache dir from becoming 50 percent locks
          # pyrefly: ignore [bad-return]
          return FileLock(str(fpath.parent / "locks" / fpath.name[:4]) + ".lock")

      @property
      def version_prefix(self: Self) -> bytes:
          """
          Get the version prefix for the cache.

          Returns:
              bytes: First 4 bytes of the SHA-256 digest of ``str(OnDiskCache.version)``.
          """
          # reads OnDiskCache.version explicitly, so a ``version`` attribute
          # overridden on a subclass does not change the prefix
          return sha256(str(OnDiskCache.version).encode()).digest()[:4]

      @override
      def get(self: Self, key: Key) -> Value | None:
          """
          Retrieve a value from the cache.

          Args:
              key (Key): The key to look up.

          Returns:
              Value | None: The cached value if present and version matches, else None.

          Raises:
              CacheError: If the key is not pickle-able, or the stored value is
                  corrupted or cannot be unpickled.

          Side Effects:
              Removes stale cache files if the version prefix does not match.
          """
          fpath = self._fpath_from_key(key)
          flock = self._flock_from_fpath(fpath)
          version_prefix = self.version_prefix

          with flock:
              if not fpath.is_file():
                  return None

              value_bytes = None
              with open(fpath, "rb") as fp:
                  if fp.read(len(version_prefix)) == version_prefix:
                      value_bytes = fp.read()

              if value_bytes is None:
                  # version_prefix did not match, so we can't read the stale
                  # cached value; we should also remove the stale cached value,
                  # so that key can be re-cached by the newer version
                  fpath.unlink()
                  return None

              try:
                  # SECURITY: unpickling can execute arbitrary code embedded in
                  # the file; the cache directory must only be writable by
                  # trusted processes.
                  value = pickle.loads(value_bytes)
              except self._UNPICKLE_ERRORS as err:
                  # includes EOFError, which a truncated or prefix-only file raises
                  raise CacheError(
                      f"Failed to get key {key!r}, value is potentially corrupted (value is not un-pickle-able)."
                  ) from err

              return value

      @override
      def insert(self: Self, key: Key, value: Value) -> bool:
          """
          Insert a value into the cache.

          The value is pickled before the cache file is created, so an
          unpicklable value never leaves a file on disk.

          Args:
              key (Key): The key to insert.
              value (Value): The value to associate with the key.

          Returns:
              bool: True if the value was inserted, False if the key already exists.

          Raises:
              CacheError: If the key or the value is not pickle-able.

          Side Effects:
              Creates the cache directory if it does not exist.
          """
          fpath = self._fpath_from_key(key)
          flock = self._flock_from_fpath(fpath)

          try:
              # build the full file content up front; pickling straight into the
              # open file would leave a prefix-only file behind if pickling failed,
              # which would block every later insert for this key
              payload = self.version_prefix + pickle.dumps(value)
          except self._PICKLE_ERRORS as err:
              raise CacheError(
                  f"Failed to insert key {key!r} with value {value!r}, value is not pickle-able."
              ) from err

          fpath.parent.mkdir(parents=True, exist_ok=True)

          try:
              # "x" mode is exclusive creation, meaning the file will be created
              # iff the file does not already exist (no overwrite); the flock
              # keeps readers from observing the file while it is being written
              with flock, open(fpath, "xb") as fp:
                  fp.write(payload)
          except FileExistsError:
              return False

          return True
  class InductorOnDiskCache(OnDiskCache[Key, Value]):
      """
      On-disk cache specialised for Inductor.

      Differs from OnDiskCache only in its fixed directory name and in the
      base directory, which is taken from ``default_cache_dir()``.
      """

      def __init__(self: Self) -> None:
          """
          Create the cache with the directory name "inductor_on_disk_cache".
          """
          super().__init__(name="inductor_on_disk_cache")

      @cached_property
      def base_dir(self: Self) -> Path:
          """
          Get the base directory for the Inductor cache.

          Returns:
              Path: ``<default_cache_dir()>/cache/<name>``, computed once per instance.
          """
          # imported at call time, as in the original, rather than at module import
          from torch._inductor.runtime.runtime_utils import default_cache_dir

          cache_root = Path(default_cache_dir())
          return cache_root / "cache" / self.name