| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321 |
- import json
- import os
- import uuid
- from abc import ABC, abstractmethod
- from typing import Dict, List, Optional
- import base64
- import hashlib
- import functools
- import sysconfig
- from triton import __version__, knobs
- from triton.windows_utils import normalize_path
class CacheManager(ABC):
    """Abstract interface that every cache manager implements.

    Concrete managers are constructed with a cache ``key`` and two mode
    flags (``override`` and ``dump``) and must provide the four abstract
    methods below.
    """

    def __init__(self, key, override=False, dump=False):
        """Accept the common constructor arguments; the base class keeps no state."""

    @abstractmethod
    def get_file(self, filename) -> Optional[str]:
        """Look up a single cached entry named ``filename``.

        Returns a string locating the entry, or ``None``.
        """

    @abstractmethod
    def put(self, data, filename, binary=True) -> str:
        """Store ``data`` under ``filename`` and return a string locating it."""

    @abstractmethod
    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Look up the group recorded under ``filename``.

        Returns a ``str -> str`` mapping, or ``None``.
        """

    @abstractmethod
    def put_group(self, filename: str, group: Dict[str, str]):
        """Record ``group`` (a ``str -> str`` mapping) under ``filename``."""
class FileCacheManager(CacheManager):
    """Cache manager that keeps each entry as a file in a per-key directory.

    The directory is ``<base>/<key>`` (passed through ``normalize_path``),
    where ``<base>`` is ``knobs.cache.dump_dir`` when ``dump`` is set,
    ``knobs.cache.override_dir`` when ``override`` is set, and
    ``knobs.cache.dir`` otherwise.
    """

    def __init__(self, key, override=False, dump=False):
        """Resolve the cache directory for ``key``.

        In dump mode and in the default mode the directory is created and
        ``lock_path`` is set to ``<cache_dir>/lock``. In override mode the
        directory is only computed: it is not created and ``lock_path``
        stays ``None``.

        Raises:
            RuntimeError: in the default mode when ``knobs.cache.dir`` is
                empty/unset.
        """
        self.key = key
        self.lock_path = None
        if dump:
            self.cache_dir = knobs.cache.dump_dir
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.cache_dir = normalize_path(self.cache_dir)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        elif override:
            self.cache_dir = knobs.cache.override_dir
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.cache_dir = normalize_path(self.cache_dir)
        else:
            # create cache directory if it doesn't exist
            self.cache_dir = knobs.cache.dir
            if self.cache_dir:
                self.cache_dir = os.path.join(self.cache_dir, self.key)
                self.cache_dir = normalize_path(self.cache_dir)
                self.lock_path = os.path.join(self.cache_dir, "lock")
                os.makedirs(self.cache_dir, exist_ok=True)
            else:
                raise RuntimeError("Could not create or locate cache dir")

    def _make_path(self, filename) -> str:
        """Return the path of ``filename`` inside this manager's cache directory."""
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename) -> bool:
        """Return whether ``filename`` currently exists in the cache directory.

        Raises:
            RuntimeError: if no cache directory is configured.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        return os.path.exists(self._make_path(filename))

    def get_file(self, filename) -> Optional[str]:
        """Return the path of ``filename`` if it is cached, else ``None``."""
        if self.has_file(filename):
            return self._make_path(filename)
        return None

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Load the group stored as ``__grp__<filename>``.

        Returns ``None`` when the group file is missing or has no
        ``"child_paths"`` entry. Otherwise returns the recorded
        ``name -> path`` mapping restricted to paths that still exist.
        """
        grp_filename = f"__grp__{filename}"
        if not self.has_file(grp_filename):
            return None
        grp_filepath = self._make_path(grp_filename)
        with open(grp_filepath) as f:
            grp_data = json.load(f)
        child_paths = grp_data.get("child_paths", None)
        # Invalid group data.
        if child_paths is None:
            return None
        # Drop children whose files have disappeared since the group was written.
        return {c: p for c, p in child_paths.items() if os.path.exists(p)}

    # Note a group of pushed files as being part of a group
    def put_group(self, filename: str, group: Dict[str, str]) -> str:
        """Store ``group`` as JSON in ``__grp__<filename>`` and return its path.

        Raises:
            RuntimeError: if no cache directory is configured.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        grp_contents = json.dumps({"child_paths": group})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename, binary=False)

    @staticmethod
    def _discard_temp(temp_path, temp_dir):
        """Best-effort removal of a temp file and its directory after a failed put."""
        for remove, target in ((os.remove, temp_path), (os.rmdir, temp_dir)):
            try:
                remove(target)
            except OSError:
                # Cleanup must never mask the error that triggered it; the
                # target may simply not exist (e.g. the file was never created).
                pass

    def put(self, data, filename, binary=True) -> str:
        """Write ``data`` to ``filename`` in the cache directory and return its path.

        The payload is first written to a uniquely named temporary directory
        inside the cache directory and then moved into place with
        ``os.replace``. If writing or moving fails, the temporary file and
        directory are removed before the error propagates.

        Args:
            data: ``bytes`` are written as-is in binary mode; anything else is
                converted with ``str()`` and written in text mode.
            filename: name of the entry inside the cache directory.
            binary: accepted for interface compatibility; the write mode is
                derived from the type of ``data``, not from this flag.

        Raises:
            RuntimeError: if no cache directory is configured.
            AssertionError: if ``lock_path`` is unset (override-mode managers).
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        # Random ID to avoid any collisions
        rnd_id = str(uuid.uuid4())
        # we use the PID in case a bunch of these around so we can see what PID made it
        pid = os.getpid()
        # use temp dir to be robust against program interruptions
        temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, filename)
        mode = "wb" if binary else "w"
        try:
            with open(temp_path, mode) as f:
                f.write(data)
            # Replace is guaranteed to be atomic on POSIX systems if it succeeds
            # so filepath cannot see a partial write
            try:
                os.replace(temp_path, filepath)
            except PermissionError:
                # Ignore PermissionError on Windows because it happens when another process already
                # put a file into the cache and locked it by opening it.
                if os.name != "nt":
                    raise
                os.remove(temp_path)
        except BaseException:
            # Do not leave tmp.pid_* directories behind when the write or the
            # move fails (or is interrupted); then re-raise unchanged.
            self._discard_temp(temp_path, temp_dir)
            raise
        os.removedirs(temp_dir)
        return filepath
class RemoteCacheBackend:
    """Interface of a backend used to reach a remote/distributed cache.

    NOTE(review): ``get`` and ``put`` are marked ``@abstractmethod``, but this
    class does not derive from ``ABC``, so the marker alone does not prevent
    instantiation or force subclasses to override them.
    """

    def __init__(self, key: str):
        """Accept the cache ``key``; the base class does not store it."""

    @abstractmethod
    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch the entries named in ``filenames``; return a ``name -> bytes`` mapping."""

    @abstractmethod
    def put(self, filename: str, data: bytes):
        """Store ``data`` under ``filename``."""
class RedisRemoteCacheBackend(RemoteCacheBackend):
    """Remote cache backend that stores entries in Redis.

    Connection settings (``host``, ``port``) and the key template
    (``key_format``) are read from ``knobs.cache.redis``.
    """

    def __init__(self, key):
        """Create the Redis client for cache ``key``.

        ``redis`` is imported here so the dependency is only required when
        this backend is actually instantiated.
        """
        import redis
        self._key = key
        self._key_fmt = knobs.cache.redis.key_format
        self._redis = redis.Redis(
            host=knobs.cache.redis.host,
            port=knobs.cache.redis.port,
        )

    def _get_key(self, filename: str) -> str:
        """Build the Redis key for ``filename`` from the configured key format."""
        return self._key_fmt.format(key=self._key, filename=filename)

    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch several entries with a single ``MGET``.

        Returns a mapping from each requested filename to the value Redis
        holds for it; filenames with no stored value are omitted.
        """
        if not filenames:
            # MGET needs at least one key; nothing requested means nothing found.
            return {}
        results = self._redis.mget([self._get_key(f) for f in filenames])
        return {filename: result for filename, result in zip(filenames, results) if result is not None}

    def put(self, filename: str, data: bytes) -> None:
        """Store ``data`` under the Redis key derived from ``filename``."""
        self._redis.set(self._get_key(filename), data)
class RemoteCacheManager(CacheManager):
    """Cache manager backed by a remote store, materialized through local files.

    Entries are read from and written to the backend class given by
    ``knobs.cache.remote_manager_class``. Every entry that is fetched or
    stored is also written through an internal ``FileCacheManager`` so that
    callers receive a local path. In dump or override mode all operations
    are delegated to that file manager and the backend is not consulted.
    """

    def __init__(self, key, override=False, dump=False):
        """Instantiate the backend and the local file manager for ``key``.

        Raises:
            RuntimeError: if ``knobs.cache.remote_manager_class`` is unset.
        """
        # Resolve the backend class selected through `TRITON_REMOTE_CACHE_BACKEND`.
        backend_cls = knobs.cache.remote_manager_class
        if not backend_cls:
            raise RuntimeError(
                "Unable to instantiate RemoteCacheManager, TRITON_REMOTE_CACHE_BACKEND doesn't point to a valid class")
        self._backend = backend_cls(key)

        self._override = override
        self._dump = dump

        # Local file cache used to turn remote payloads into on-disk paths.
        self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

    def _local_only(self):
        """Return a truthy value when dump/override mode bypasses the backend."""
        return self._dump or self._override

    def _materialize(self, filename: str, data: bytes):
        """Write ``data`` into the local file cache and return the resulting path."""
        return self._file_cache_manager.put(data, filename, binary=True)

    def get_file(self, filename: str) -> Optional[str]:
        """Return a local path for ``filename``, or ``None`` if the backend has no entry."""
        # Dump/override modes are served purely from the local file cache.
        if self._local_only():
            return self._file_cache_manager.get_file(filename)

        # The backend is queried on every call, even when the local file cache
        # already holds the item, so that its LRU accounting sees the access.
        fetched = self._backend.get([filename])
        if len(fetched) == 0:
            return None
        ((_, payload),) = fetched.items()
        return self._materialize(filename, payload)

    def put(self, data, filename: str, binary=True) -> str:
        """Store ``data`` in the backend and locally; return the local path.

        Non-``bytes`` data is converted with ``str()`` and encoded as UTF-8
        before it is sent to the backend.
        """
        # Dump/override modes are served purely from the local file cache.
        if self._local_only():
            return self._file_cache_manager.put(data, filename, binary=binary)

        payload = data if isinstance(data, bytes) else str(data).encode("utf-8")
        self._backend.put(filename, payload)
        return self._materialize(filename, payload)

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Fetch the group stored as ``__grp__<filename>``.

        Returns ``None`` when the group entry is missing or carries no
        ``"child_paths"``. Otherwise returns a mapping from each child name
        the backend returned to its materialized local path.
        """
        # Dump/override modes are served purely from the local file cache.
        if self._local_only():
            return self._file_cache_manager.get_group(filename)

        index_path = self.get_file(f"__grp__{filename}")
        if index_path is None:
            return None

        with open(index_path) as f:
            index = json.load(f)
        child_names = index.get("child_paths", None)
        if child_names is None:
            return None

        # Pull every child from the backend and materialize it locally.
        materialized = {}
        for child_name, payload in self._backend.get(child_names).items():
            materialized[child_name] = self._materialize(child_name, payload)
        return materialized

    def put_group(self, filename: str, group: Dict[str, str]):
        """Record the sorted names of ``group`` as JSON under ``__grp__<filename>``."""
        # Dump/override modes are served purely from the local file cache.
        if self._local_only():
            return self._file_cache_manager.put_group(filename, group)

        index = json.dumps({"child_paths": sorted(group.keys())})
        return self.put(index, f"__grp__{filename}")
def _base32(key):
    """Re-encode a hex string as unpadded base32 text.

    ``key`` is assumed to be a hex string; ``bytes.fromhex`` raises
    ``ValueError`` otherwise.
    """
    raw = bytes.fromhex(key)
    encoded = base64.b32encode(raw).decode("utf-8")
    # Drop the trailing '=' padding.
    return encoded.rstrip("=")
def get_cache_manager(key) -> CacheManager:
    """Build the cache manager for hex ``key`` in the default mode.

    Uses ``knobs.cache.manager_class`` when set, otherwise
    ``FileCacheManager``; the manager receives the base32 form of ``key``.
    """
    manager_cls = knobs.cache.manager_class
    if not manager_cls:
        manager_cls = FileCacheManager
    return manager_cls(_base32(key))
def get_override_manager(key) -> CacheManager:
    """Build the cache manager for hex ``key`` with ``override=True``.

    Uses ``knobs.cache.manager_class`` when set, otherwise
    ``FileCacheManager``; the manager receives the base32 form of ``key``.
    """
    manager_cls = knobs.cache.manager_class
    if not manager_cls:
        manager_cls = FileCacheManager
    return manager_cls(_base32(key), override=True)
def get_dump_manager(key) -> CacheManager:
    """Build the cache manager for hex ``key`` with ``dump=True``.

    Uses ``knobs.cache.manager_class`` when set, otherwise
    ``FileCacheManager``; the manager receives the base32 form of ``key``.
    """
    manager_cls = knobs.cache.manager_class
    if not manager_cls:
        manager_cls = FileCacheManager
    return manager_cls(_base32(key), dump=True)
def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
    """Return a base32 cache key identifying one piece of compiled code.

    The key is the SHA-256 of a ``-``-joined string built from
    ``version_hash``, the concatenated signature types (any type starting
    with ``*`` is collapsed to ``ptr``), ``constants``, ``ids`` and every
    keyword-argument value in the order it was passed.
    """
    # Pointer types all hash alike, whatever they point to.
    type_names = {name: ('ptr' if ty[0] == '*' else ty) for name, ty in signature.items()}
    pieces = [f"{version_hash}-{''.join(type_names.values())}-{constants}-{ids}"]
    pieces.extend(f"{value}" for value in kwargs.values())
    digest = hashlib.sha256("-".join(pieces).encode("utf-8")).hexdigest()
    return _base32(digest)
@functools.lru_cache()
def triton_key():
    """Return a string fingerprinting the installed Triton sources and binary.

    The result is ``__version__`` immediately followed by the ``-``-joined
    SHA-256 hex digests of, in order: this file, every module found under
    ``compiler`` and ``backends``, the ``_C/libtriton`` extension library,
    and every module found under ``language``. The value is memoized by
    ``lru_cache``.
    """
    import pkgutil

    def file_digest(path):
        # SHA-256 of a file, read in 1 MiB chunks.
        sha = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1024**2), b""):
                sha.update(chunk)
        return sha.hexdigest()

    def package_digests(path, prefix):
        # Digests of every module pkgutil discovers below `path`.
        return [
            file_digest(lib.module_finder.find_spec(lib.name).origin)
            for lib in pkgutil.walk_packages([path], prefix=prefix)
        ]

    triton_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # frontend
    digests = [file_digest(__file__)]

    # compiler
    for subdir, prefix in (("compiler", "triton.compiler."), ("backends", "triton.backends.")):
        digests.extend(package_digests(os.path.join(triton_root, subdir), prefix))

    # backend
    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
    digests.append(file_digest(os.path.join(triton_root, "_C", f"libtriton.{ext}")))

    # language
    digests.extend(package_digests(os.path.join(triton_root, 'language'), "triton.language."))

    return f'{__version__}' + '-'.join(digests)
def get_cache_key(src, backend, backend_options, env_vars):
    """Compose the cache key for one compilation.

    Joins with ``-``: the Triton fingerprint from ``triton_key()``, the
    ``hash()`` of ``src``, ``backend`` and ``backend_options``, and the
    string form of the sorted ``env_vars`` items.
    """
    toolchain = triton_key()
    src_hash = src.hash()
    backend_hash = backend.hash()
    options_hash = backend_options.hash()
    # Sort so the key does not depend on the mapping's iteration order.
    env_items = str(sorted(env_vars.items()))
    return f"{toolchain}-{src_hash}-{backend_hash}-{options_hash}-{env_items}"
|