cache.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. import json
  2. import os
  3. import uuid
  4. from abc import ABC, abstractmethod
  5. from typing import Dict, List, Optional
  6. import base64
  7. import hashlib
  8. import functools
  9. import sysconfig
  10. from triton import __version__, knobs
  11. from triton.windows_utils import normalize_path
  12. class CacheManager(ABC):
  13. def __init__(self, key, override=False, dump=False):
  14. pass
  15. @abstractmethod
  16. def get_file(self, filename) -> Optional[str]:
  17. pass
  18. @abstractmethod
  19. def put(self, data, filename, binary=True) -> str:
  20. pass
  21. @abstractmethod
  22. def get_group(self, filename: str) -> Optional[Dict[str, str]]:
  23. pass
  24. @abstractmethod
  25. def put_group(self, filename: str, group: Dict[str, str]):
  26. pass
  27. class FileCacheManager(CacheManager):
  28. def __init__(self, key, override=False, dump=False):
  29. self.key = key
  30. self.lock_path = None
  31. if dump:
  32. self.cache_dir = knobs.cache.dump_dir
  33. self.cache_dir = os.path.join(self.cache_dir, self.key)
  34. self.cache_dir = normalize_path(self.cache_dir)
  35. self.lock_path = os.path.join(self.cache_dir, "lock")
  36. os.makedirs(self.cache_dir, exist_ok=True)
  37. elif override:
  38. self.cache_dir = knobs.cache.override_dir
  39. self.cache_dir = os.path.join(self.cache_dir, self.key)
  40. self.cache_dir = normalize_path(self.cache_dir)
  41. else:
  42. # create cache directory if it doesn't exist
  43. self.cache_dir = knobs.cache.dir
  44. if self.cache_dir:
  45. self.cache_dir = os.path.join(self.cache_dir, self.key)
  46. self.cache_dir = normalize_path(self.cache_dir)
  47. self.lock_path = os.path.join(self.cache_dir, "lock")
  48. os.makedirs(self.cache_dir, exist_ok=True)
  49. else:
  50. raise RuntimeError("Could not create or locate cache dir")
  51. def _make_path(self, filename) -> str:
  52. return os.path.join(self.cache_dir, filename)
  53. def has_file(self, filename) -> bool:
  54. if not self.cache_dir:
  55. raise RuntimeError("Could not create or locate cache dir")
  56. return os.path.exists(self._make_path(filename))
  57. def get_file(self, filename) -> Optional[str]:
  58. if self.has_file(filename):
  59. return self._make_path(filename)
  60. else:
  61. return None
  62. def get_group(self, filename: str) -> Optional[Dict[str, str]]:
  63. grp_filename = f"__grp__{filename}"
  64. if not self.has_file(grp_filename):
  65. return None
  66. grp_filepath = self._make_path(grp_filename)
  67. with open(grp_filepath) as f:
  68. grp_data = json.load(f)
  69. child_paths = grp_data.get("child_paths", None)
  70. # Invalid group data.
  71. if child_paths is None:
  72. return None
  73. result = {}
  74. for c, p in child_paths.items():
  75. if os.path.exists(p):
  76. result[c] = p
  77. return result
  78. # Note a group of pushed files as being part of a group
  79. def put_group(self, filename: str, group: Dict[str, str]) -> str:
  80. if not self.cache_dir:
  81. raise RuntimeError("Could not create or locate cache dir")
  82. grp_contents = json.dumps({"child_paths": group})
  83. grp_filename = f"__grp__{filename}"
  84. return self.put(grp_contents, grp_filename, binary=False)
  85. def put(self, data, filename, binary=True) -> str:
  86. if not self.cache_dir:
  87. raise RuntimeError("Could not create or locate cache dir")
  88. binary = isinstance(data, bytes)
  89. if not binary:
  90. data = str(data)
  91. assert self.lock_path is not None
  92. filepath = self._make_path(filename)
  93. # Random ID to avoid any collisions
  94. rnd_id = str(uuid.uuid4())
  95. # we use the PID in case a bunch of these around so we can see what PID made it
  96. pid = os.getpid()
  97. # use temp dir to be robust against program interruptions
  98. temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
  99. os.makedirs(temp_dir, exist_ok=True)
  100. temp_path = os.path.join(temp_dir, filename)
  101. mode = "wb" if binary else "w"
  102. with open(temp_path, mode) as f:
  103. f.write(data)
  104. # Replace is guaranteed to be atomic on POSIX systems if it succeeds
  105. # so filepath cannot see a partial write
  106. try:
  107. os.replace(temp_path, filepath)
  108. except PermissionError:
  109. # Ignore PermissionError on Windows because it happens when another process already
  110. # put a file into the cache and locked it by opening it.
  111. if os.name == "nt":
  112. os.remove(temp_path)
  113. else:
  114. raise
  115. os.removedirs(temp_dir)
  116. return filepath
  117. class RemoteCacheBackend:
  118. """
  119. A backend implementation for accessing a remote/distributed cache.
  120. """
  121. def __init__(self, key: str):
  122. pass
  123. @abstractmethod
  124. def get(self, filenames: List[str]) -> Dict[str, bytes]:
  125. pass
  126. @abstractmethod
  127. def put(self, filename: str, data: bytes):
  128. pass
  129. class RedisRemoteCacheBackend(RemoteCacheBackend):
  130. def __init__(self, key):
  131. import redis
  132. self._key = key
  133. self._key_fmt = knobs.cache.redis.key_format
  134. self._redis = redis.Redis(
  135. host=knobs.cache.redis.host,
  136. port=knobs.cache.redis.port,
  137. )
  138. def _get_key(self, filename: str) -> str:
  139. return self._key_fmt.format(key=self._key, filename=filename)
  140. def get(self, filenames: List[str]) -> Dict[str, str]:
  141. results = self._redis.mget([self._get_key(f) for f in filenames])
  142. return {filename: result for filename, result in zip(filenames, results) if result is not None}
  143. def put(self, filename: str, data: bytes) -> Dict[str, bytes]:
  144. self._redis.set(self._get_key(filename), data)
  145. class RemoteCacheManager(CacheManager):
  146. def __init__(self, key, override=False, dump=False):
  147. # Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`.
  148. remote_cache_cls = knobs.cache.remote_manager_class
  149. if not remote_cache_cls:
  150. raise RuntimeError(
  151. "Unable to instantiate RemoteCacheManager, TRITON_REMOTE_CACHE_BACKEND doesn't point to a valid class")
  152. self._backend = remote_cache_cls(key)
  153. self._override = override
  154. self._dump = dump
  155. # Use a `FileCacheManager` to materialize remote cache paths locally.
  156. self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)
  157. def _materialize(self, filename: str, data: bytes):
  158. # We use a backing `FileCacheManager` to provide the materialized data.
  159. return self._file_cache_manager.put(data, filename, binary=True)
  160. def get_file(self, filename: str) -> Optional[str]:
  161. # We don't handle the dump/override cases.
  162. if self._dump or self._override:
  163. return self._file_cache_manager.get_file(filename)
  164. # We always check the remote cache backend -- even if our internal file-
  165. # based cache has the item -- to make sure LRU accounting works as
  166. # expected.
  167. results = self._backend.get([filename])
  168. if len(results) == 0:
  169. return None
  170. (_, data), = results.items()
  171. return self._materialize(filename, data)
  172. def put(self, data, filename: str, binary=True) -> str:
  173. # We don't handle the dump/override cases.
  174. if self._dump or self._override:
  175. return self._file_cache_manager.put(data, filename, binary=binary)
  176. if not isinstance(data, bytes):
  177. data = str(data).encode("utf-8")
  178. self._backend.put(filename, data)
  179. return self._materialize(filename, data)
  180. def get_group(self, filename: str) -> Optional[Dict[str, str]]:
  181. # We don't handle the dump/override cases.
  182. if self._dump or self._override:
  183. return self._file_cache_manager.get_group(filename)
  184. grp_filename = f"__grp__{filename}"
  185. grp_filepath = self.get_file(grp_filename)
  186. if grp_filepath is None:
  187. return None
  188. with open(grp_filepath) as f:
  189. grp_data = json.load(f)
  190. child_paths = grp_data.get("child_paths", None)
  191. result = None
  192. # Found group data.
  193. if child_paths is not None:
  194. result = {}
  195. for child_path, data in self._backend.get(child_paths).items():
  196. result[child_path] = self._materialize(child_path, data)
  197. return result
  198. def put_group(self, filename: str, group: Dict[str, str]):
  199. # We don't handle the dump/override cases.
  200. if self._dump or self._override:
  201. return self._file_cache_manager.put_group(filename, group)
  202. grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
  203. grp_filename = f"__grp__{filename}"
  204. return self.put(grp_contents, grp_filename)
  205. def _base32(key):
  206. # Assume key is a hex string.
  207. return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
  208. def get_cache_manager(key) -> CacheManager:
  209. cls = knobs.cache.manager_class or FileCacheManager
  210. return cls(_base32(key))
  211. def get_override_manager(key) -> CacheManager:
  212. cls = knobs.cache.manager_class or FileCacheManager
  213. return cls(_base32(key), override=True)
  214. def get_dump_manager(key) -> CacheManager:
  215. cls = knobs.cache.manager_class or FileCacheManager
  216. return cls(_base32(key), dump=True)
  217. def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
  218. # Get unique key for the compiled code
  219. signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
  220. key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}"
  221. for kw in kwargs:
  222. key = f"{key}-{kwargs.get(kw)}"
  223. key = hashlib.sha256(key.encode("utf-8")).hexdigest()
  224. return _base32(key)
  225. @functools.lru_cache()
  226. def triton_key():
  227. import pkgutil
  228. TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  229. contents = []
  230. # frontend
  231. with open(__file__, "rb") as f:
  232. contents += [hashlib.sha256(f.read()).hexdigest()]
  233. # compiler
  234. path_prefixes = [
  235. (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."),
  236. (os.path.join(TRITON_PATH, "backends"), "triton.backends."),
  237. ]
  238. for path, prefix in path_prefixes:
  239. for lib in pkgutil.walk_packages([path], prefix=prefix):
  240. with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
  241. contents += [hashlib.sha256(f.read()).hexdigest()]
  242. # backend
  243. libtriton_hash = hashlib.sha256()
  244. ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
  245. with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f:
  246. while True:
  247. chunk = f.read(1024**2)
  248. if not chunk:
  249. break
  250. libtriton_hash.update(chunk)
  251. contents.append(libtriton_hash.hexdigest())
  252. # language
  253. language_path = os.path.join(TRITON_PATH, 'language')
  254. for lib in pkgutil.walk_packages([language_path], prefix="triton.language."):
  255. with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
  256. contents += [hashlib.sha256(f.read()).hexdigest()]
  257. return f'{__version__}' + '-'.join(contents)
  258. def get_cache_key(src, backend, backend_options, env_vars):
  259. key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{backend_options.hash()}-{str(sorted(env_vars.items()))}"
  260. return key