object_cache.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. from collections import Counter, defaultdict
  2. from typing import Dict, Generator, List, Optional, TypeVar
  3. # Grouping key - must be hashable
  4. T = TypeVar("T")
  5. # Objects to cache
  6. U = TypeVar("U")
  7. class _ObjectCache:
  8. """Cache up to some maximum count given a grouping key.
  9. This object cache can e.g. be used to cache Ray Tune trainable actors
  10. given their resource requirements (reuse_actors=True).
  11. If the max number of cached objects for a grouping key is reached,
  12. no more objects for this group will be cached.
  13. However, if `may_keep_one=True`, one object (globally across all grouping
  14. keys) may be cached, even if the max number of objects is 0. This is to
  15. allow to cache an object if the max number of objects of this key
  16. will increase shortly after (as is the case e.g. in the Ray Tune control
  17. loop).
  18. Args:
  19. may_keep_one: If True, one object (globally) may be cached if no desired
  20. maximum objects are defined.
  21. """
  22. def __init__(self, may_keep_one: bool = True):
  23. self._num_cached_objects: int = 0
  24. self._cached_objects: Dict[T, List[U]] = defaultdict(list)
  25. self._max_num_objects: Counter[T] = Counter()
  26. self._may_keep_one = may_keep_one
  27. @property
  28. def num_cached_objects(self):
  29. return self._num_cached_objects
  30. @property
  31. def total_max_objects(self):
  32. # Counter.total() is only available for python 3.10+
  33. return sum(self._max_num_objects.values())
  34. def increase_max(self, key: T, by: int = 1) -> None:
  35. """Increase number of max objects for this key.
  36. Args:
  37. key: Group key.
  38. by: Decrease by this amount.
  39. """
  40. self._max_num_objects[key] += by
  41. def decrease_max(self, key: T, by: int = 1) -> None:
  42. """Decrease number of max objects for this key.
  43. Args:
  44. key: Group key.
  45. by: Decrease by this amount.
  46. """
  47. self._max_num_objects[key] -= by
  48. def has_cached_object(self, key: T) -> bool:
  49. """Return True if at least one cached object exists for this key.
  50. Args:
  51. key: Group key.
  52. Returns:
  53. True if at least one cached object exists for this key.
  54. """
  55. return bool(self._cached_objects[key])
  56. def cache_object(self, key: T, obj: U) -> bool:
  57. """Cache object for a given key.
  58. This will put the object into a cache, assuming the number
  59. of cached objects for this key is less than the number of
  60. max objects for this key.
  61. An exception is made if `max_keep_one=True` and no other
  62. objects are cached globally. In that case, the object can
  63. still be cached.
  64. Args:
  65. key: Group key.
  66. obj: Object to cache.
  67. Returns:
  68. True if the object has been cached. False otherwise.
  69. """
  70. # If we have more objects cached already than we desire
  71. if len(self._cached_objects[key]) >= self._max_num_objects[key]:
  72. # If may_keep_one is False, never cache
  73. if not self._may_keep_one:
  74. return False
  75. # If we have more than one other cached object, don't cache
  76. if self._num_cached_objects > 0:
  77. return False
  78. # If any other objects are expected to be cached, don't cache
  79. if any(v for v in self._max_num_objects.values()):
  80. return False
  81. # Otherwise, cache (for now).
  82. self._cached_objects[key].append(obj)
  83. self._num_cached_objects += 1
  84. return True
  85. def pop_cached_object(self, key: T) -> Optional[U]:
  86. """Get one cached object for a key.
  87. This will remove the object from the cache.
  88. Args:
  89. key: Group key.
  90. Returns:
  91. Cached object.
  92. """
  93. if not self.has_cached_object(key):
  94. return None
  95. self._num_cached_objects -= 1
  96. return self._cached_objects[key].pop(0)
  97. def flush_cached_objects(self, force_all: bool = False) -> Generator[U, None, None]:
  98. """Return a generator over cached objects evicted from the cache.
  99. This method yields all cached objects that should be evicted from the
  100. cache for cleanup by the caller.
  101. If the number of max objects is lower than the number of
  102. cached objects for a given key, objects are evicted until
  103. the numbers are equal.
  104. If `max_keep_one=True` (and ``force_all=False``), one cached object
  105. may be retained.
  106. Objects are evicted FIFO.
  107. If ``force_all=True``, all objects are evicted.
  108. Args:
  109. force_all: If True, all objects are flushed. This takes precedence
  110. over ``keep_one``.
  111. Yields:
  112. Evicted objects to be cleaned up by caller.
  113. """
  114. # If force_all=True, don't keep one.
  115. keep_one = self._may_keep_one and not force_all
  116. for key, objs in self._cached_objects.items():
  117. max_cached = self._max_num_objects[key] if not force_all else 0
  118. if (
  119. self._num_cached_objects == 1
  120. and keep_one
  121. # Only keep this object if we don't expect a different one
  122. and not any(v for v in self._max_num_objects.values())
  123. ):
  124. break
  125. while len(objs) > max_cached:
  126. self._num_cached_objects -= 1
  127. yield objs.pop(0)