policy_map.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. import logging
  2. import threading
  3. from collections import deque
  4. from typing import Dict, Set
  5. import ray
  6. from ray._common.deprecation import deprecation_warning
  7. from ray.rllib.policy.policy import Policy
  8. from ray.rllib.utils.annotations import OldAPIStack, override
  9. from ray.rllib.utils.framework import try_import_tf
  10. from ray.rllib.utils.threading import with_lock
  11. from ray.rllib.utils.typing import PolicyID
# Optional TensorFlow handles returned by `try_import_tf()`; presumably each is
# None when TensorFlow is not installed (TODO confirm against the helper).
# NOTE(review): none of the three names is referenced in the visible part of
# this module.
tf1, tf, tfv = try_import_tf()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  14. @OldAPIStack
  15. class PolicyMap(dict):
  16. """Maps policy IDs to Policy objects.
  17. Thereby, keeps n policies in memory and - when capacity is reached -
  18. writes the least recently used to disk. This allows adding 100s of
  19. policies to a Algorithm for league-based setups w/o running out of memory.
  20. """
  21. def __init__(
  22. self,
  23. *,
  24. capacity: int = 100,
  25. policy_states_are_swappable: bool = False,
  26. # Deprecated args.
  27. worker_index=None,
  28. num_workers=None,
  29. policy_config=None,
  30. session_creator=None,
  31. seed=None,
  32. ):
  33. """Initializes a PolicyMap instance.
  34. Args:
  35. capacity: The size of the Policy object cache. This is the maximum number
  36. of policies that are held in RAM memory. When reaching this capacity,
  37. the least recently used Policy's state will be stored in the Ray object
  38. store and recovered from there when being accessed again.
  39. policy_states_are_swappable: Whether all Policy objects in this map can be
  40. "swapped out" via a simple `state = A.get_state(); B.set_state(state)`,
  41. where `A` and `B` are policy instances in this map. You should set
  42. this to True for significantly speeding up the PolicyMap's cache lookup
  43. times, iff your policies all share the same neural network
  44. architecture and optimizer types. If True, the PolicyMap will not
  45. have to garbage collect old, least recently used policies, but instead
  46. keep them in memory and simply override their state with the state of
  47. the most recently accessed one.
  48. For example, in a league-based training setup, you might have 100s of
  49. the same policies in your map (playing against each other in various
  50. combinations), but all of them share the same state structure
  51. (are "swappable").
  52. """
  53. if policy_config is not None:
  54. deprecation_warning(
  55. old="PolicyMap(policy_config=..)",
  56. error=True,
  57. )
  58. super().__init__()
  59. self.capacity = capacity
  60. if any(
  61. i is not None
  62. for i in [policy_config, worker_index, num_workers, session_creator, seed]
  63. ):
  64. deprecation_warning(
  65. old="PolicyMap([deprecated args]...)",
  66. new="PolicyMap(capacity=..., policy_states_are_swappable=...)",
  67. error=False,
  68. )
  69. self.policy_states_are_swappable = policy_states_are_swappable
  70. # The actual cache with the in-memory policy objects.
  71. self.cache: Dict[str, Policy] = {}
  72. # Set of keys that may be looked up (cached or not).
  73. self._valid_keys: Set[str] = set()
  74. # The doubly-linked list holding the currently in-memory objects.
  75. self._deque = deque()
  76. # Ray object store references to the stashed Policy states.
  77. self._policy_state_refs = {}
  78. # Lock used for locking some methods on the object-level.
  79. # This prevents possible race conditions when accessing the map
  80. # and the underlying structures, like self._deque and others.
  81. self._lock = threading.RLock()
  82. @with_lock
  83. @override(dict)
  84. def __getitem__(self, item: PolicyID):
  85. # Never seen this key -> Error.
  86. if item not in self._valid_keys:
  87. raise KeyError(
  88. f"PolicyID '{item}' not found in this PolicyMap! "
  89. f"IDs stored in this map: {self._valid_keys}."
  90. )
  91. # Item already in cache -> Rearrange deque (promote `item` to
  92. # "most recently used") and return it.
  93. if item in self.cache:
  94. self._deque.remove(item)
  95. self._deque.append(item)
  96. return self.cache[item]
  97. # Item not currently in cache -> Get from stash and - if at capacity -
  98. # remove leftmost one.
  99. if item not in self._policy_state_refs:
  100. raise AssertionError(
  101. f"PolicyID {item} not found in internal Ray object store cache!"
  102. )
  103. policy_state = ray.get(self._policy_state_refs[item])
  104. policy = None
  105. # We are at capacity: Remove the oldest policy from deque as well as the
  106. # cache and return it.
  107. if len(self._deque) == self.capacity:
  108. policy = self._stash_least_used_policy()
  109. # All our policies have same NN-architecture (are "swappable").
  110. # -> Load new policy's state into the one that just got removed from the cache.
  111. # This way, we save the costly re-creation step.
  112. if policy is not None and self.policy_states_are_swappable:
  113. logger.debug(f"restoring policy: {item}")
  114. policy.set_state(policy_state)
  115. else:
  116. logger.debug(f"creating new policy: {item}")
  117. policy = Policy.from_state(policy_state)
  118. self.cache[item] = policy
  119. # Promote the item to most recently one.
  120. self._deque.append(item)
  121. return policy
  122. @with_lock
  123. @override(dict)
  124. def __setitem__(self, key: PolicyID, value: Policy):
  125. # Item already in cache -> Rearrange deque.
  126. if key in self.cache:
  127. self._deque.remove(key)
  128. # Item not currently in cache -> store new value and - if at capacity -
  129. # remove leftmost one.
  130. else:
  131. # Cache at capacity -> Drop leftmost item.
  132. if len(self._deque) == self.capacity:
  133. self._stash_least_used_policy()
  134. # Promote `key` to "most recently used".
  135. self._deque.append(key)
  136. # Update our cache.
  137. self.cache[key] = value
  138. self._valid_keys.add(key)
  139. @with_lock
  140. @override(dict)
  141. def __delitem__(self, key: PolicyID):
  142. # Make key invalid.
  143. self._valid_keys.remove(key)
  144. # Remove policy from deque if contained
  145. if key in self._deque:
  146. self._deque.remove(key)
  147. # Remove policy from memory if currently cached.
  148. if key in self.cache:
  149. policy = self.cache[key]
  150. self._close_session(policy)
  151. del self.cache[key]
  152. # Remove Ray object store reference (if this ID has already been stored
  153. # there), so the item gets garbage collected.
  154. if key in self._policy_state_refs:
  155. del self._policy_state_refs[key]
  156. @override(dict)
  157. def __iter__(self):
  158. return iter(self.keys())
  159. @override(dict)
  160. def items(self):
  161. """Iterates over all policies, even the stashed ones."""
  162. def gen():
  163. for key in self._valid_keys:
  164. yield (key, self[key])
  165. return gen()
  166. @override(dict)
  167. def keys(self):
  168. """Returns all valid keys, even the stashed ones."""
  169. self._lock.acquire()
  170. ks = list(self._valid_keys)
  171. self._lock.release()
  172. def gen():
  173. for key in ks:
  174. yield key
  175. return gen()
  176. @override(dict)
  177. def values(self):
  178. """Returns all valid values, even the stashed ones."""
  179. self._lock.acquire()
  180. vs = [self[k] for k in self._valid_keys]
  181. self._lock.release()
  182. def gen():
  183. for value in vs:
  184. yield value
  185. return gen()
  186. @with_lock
  187. @override(dict)
  188. def update(self, __m, **kwargs):
  189. """Updates the map with the given dict and/or kwargs."""
  190. for k, v in __m.items():
  191. self[k] = v
  192. for k, v in kwargs.items():
  193. self[k] = v
  194. @with_lock
  195. @override(dict)
  196. def get(self, key: PolicyID):
  197. """Returns the value for the given key or None if not found."""
  198. if key not in self._valid_keys:
  199. return None
  200. return self[key]
  201. @with_lock
  202. @override(dict)
  203. def __len__(self) -> int:
  204. """Returns number of all policies, including the stashed-to-disk ones."""
  205. return len(self._valid_keys)
  206. @with_lock
  207. @override(dict)
  208. def __contains__(self, item: PolicyID):
  209. return item in self._valid_keys
  210. @override(dict)
  211. def __str__(self) -> str:
  212. # Only print out our keys (policy IDs), not values as this could trigger
  213. # the LRU caching.
  214. return (
  215. f"<PolicyMap lru-caching-capacity={self.capacity} policy-IDs="
  216. f"{list(self.keys())}>"
  217. )
  218. def _stash_least_used_policy(self) -> Policy:
  219. """Writes the least-recently used policy's state to the Ray object store.
  220. Also closes the session - if applicable - of the stashed policy.
  221. Returns:
  222. The least-recently used policy, that just got removed from the cache.
  223. """
  224. # Get policy's state for writing to object store.
  225. dropped_policy_id = self._deque.popleft()
  226. assert dropped_policy_id in self.cache
  227. policy = self.cache[dropped_policy_id]
  228. policy_state = policy.get_state()
  229. # If we don't simply swap out vs an existing policy:
  230. # Close the tf session, if any.
  231. if not self.policy_states_are_swappable:
  232. self._close_session(policy)
  233. # Remove from memory. This will clear the tf Graph as well.
  234. del self.cache[dropped_policy_id]
  235. # Store state in Ray object store.
  236. self._policy_state_refs[dropped_policy_id] = ray.put(policy_state)
  237. # Return the just removed policy, in case it's needed by the caller.
  238. return policy
  239. @staticmethod
  240. def _close_session(policy: Policy):
  241. sess = policy.get_session()
  242. # Closes the tf session, if any.
  243. if sess is not None:
  244. sess.close()