| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- from collections import Counter, defaultdict
- from typing import Dict, Generator, List, Optional, TypeVar
- # Grouping key - must be hashable
- T = TypeVar("T")
- # Objects to cache
- U = TypeVar("U")
- class _ObjectCache:
- """Cache up to some maximum count given a grouping key.
- This object cache can e.g. be used to cache Ray Tune trainable actors
- given their resource requirements (reuse_actors=True).
- If the max number of cached objects for a grouping key is reached,
- no more objects for this group will be cached.
- However, if `may_keep_one=True`, one object (globally across all grouping
- keys) may be cached, even if the max number of objects is 0. This is to
- allow to cache an object if the max number of objects of this key
- will increase shortly after (as is the case e.g. in the Ray Tune control
- loop).
- Args:
- may_keep_one: If True, one object (globally) may be cached if no desired
- maximum objects are defined.
- """
- def __init__(self, may_keep_one: bool = True):
- self._num_cached_objects: int = 0
- self._cached_objects: Dict[T, List[U]] = defaultdict(list)
- self._max_num_objects: Counter[T] = Counter()
- self._may_keep_one = may_keep_one
- @property
- def num_cached_objects(self):
- return self._num_cached_objects
- @property
- def total_max_objects(self):
- # Counter.total() is only available for python 3.10+
- return sum(self._max_num_objects.values())
- def increase_max(self, key: T, by: int = 1) -> None:
- """Increase number of max objects for this key.
- Args:
- key: Group key.
- by: Decrease by this amount.
- """
- self._max_num_objects[key] += by
- def decrease_max(self, key: T, by: int = 1) -> None:
- """Decrease number of max objects for this key.
- Args:
- key: Group key.
- by: Decrease by this amount.
- """
- self._max_num_objects[key] -= by
- def has_cached_object(self, key: T) -> bool:
- """Return True if at least one cached object exists for this key.
- Args:
- key: Group key.
- Returns:
- True if at least one cached object exists for this key.
- """
- return bool(self._cached_objects[key])
- def cache_object(self, key: T, obj: U) -> bool:
- """Cache object for a given key.
- This will put the object into a cache, assuming the number
- of cached objects for this key is less than the number of
- max objects for this key.
- An exception is made if `max_keep_one=True` and no other
- objects are cached globally. In that case, the object can
- still be cached.
- Args:
- key: Group key.
- obj: Object to cache.
- Returns:
- True if the object has been cached. False otherwise.
- """
- # If we have more objects cached already than we desire
- if len(self._cached_objects[key]) >= self._max_num_objects[key]:
- # If may_keep_one is False, never cache
- if not self._may_keep_one:
- return False
- # If we have more than one other cached object, don't cache
- if self._num_cached_objects > 0:
- return False
- # If any other objects are expected to be cached, don't cache
- if any(v for v in self._max_num_objects.values()):
- return False
- # Otherwise, cache (for now).
- self._cached_objects[key].append(obj)
- self._num_cached_objects += 1
- return True
- def pop_cached_object(self, key: T) -> Optional[U]:
- """Get one cached object for a key.
- This will remove the object from the cache.
- Args:
- key: Group key.
- Returns:
- Cached object.
- """
- if not self.has_cached_object(key):
- return None
- self._num_cached_objects -= 1
- return self._cached_objects[key].pop(0)
- def flush_cached_objects(self, force_all: bool = False) -> Generator[U, None, None]:
- """Return a generator over cached objects evicted from the cache.
- This method yields all cached objects that should be evicted from the
- cache for cleanup by the caller.
- If the number of max objects is lower than the number of
- cached objects for a given key, objects are evicted until
- the numbers are equal.
- If `max_keep_one=True` (and ``force_all=False``), one cached object
- may be retained.
- Objects are evicted FIFO.
- If ``force_all=True``, all objects are evicted.
- Args:
- force_all: If True, all objects are flushed. This takes precedence
- over ``keep_one``.
- Yields:
- Evicted objects to be cleaned up by caller.
- """
- # If force_all=True, don't keep one.
- keep_one = self._may_keep_one and not force_all
- for key, objs in self._cached_objects.items():
- max_cached = self._max_num_objects[key] if not force_all else 0
- if (
- self._num_cached_objects == 1
- and keep_one
- # Only keep this object if we don't expect a different one
- and not any(v for v in self._max_num_objects.values())
- ):
- break
- while len(objs) > max_cached:
- self._num_cached_objects -= 1
- yield objs.pop(0)
|