utils.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. """
  2. [1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks.
  3. Luo et al. 2020
  4. https://arxiv.org/pdf/1912.00167
  5. """
  6. import threading
  7. import time
  8. from collections import deque
  9. from typing import Any, Optional
  10. import numpy as np
  11. from ray.rllib.models.catalog import ModelCatalog
  12. from ray.rllib.models.modelv2 import ModelV2
  13. from ray.rllib.utils.annotations import OldAPIStack
  14. from ray.rllib.utils.metrics.ray_metrics import (
  15. DEFAULT_HISTOGRAM_BOUNDARIES_SHORT_EVENTS,
  16. TimerAndPrometheusLogger,
  17. )
  18. from ray.util.metrics import Counter, Histogram
# Model names passed as `name=` to `ModelCatalog.get_model_v2` by
# `make_appo_models`: one for the main policy model, one for its target model.
POLICY_SCOPE = "func"
TARGET_POLICY_SCOPE = "target_func"
  21. class CircularBuffer:
  22. """A circular batch-wise buffer with Queue-like interface.
  23. The buffer holds at most N batches, which are sampled at random (uniformly).
  24. If full and a new batch is added, the oldest batch is discarded. Each batch
  25. can be sampled at most K times (after which it is also discarded).
  26. This version implements Queue-like put/get methods with blocking support.
  27. """
  28. def __init__(self, num_batches: int, iterations_per_batch: int):
  29. """
  30. Args:
  31. num_batches: N from the paper (queue buffer size).
  32. iterations_per_batch: K ("replay coefficient") from the paper. Defines
  33. how often a single batch can sampled before being discarded. If a
  34. new batch is added when the buffer is full, the oldest batch is
  35. discarded entirely (regardless of how often it has been sampled).
  36. """
  37. self.num_batches = num_batches
  38. self.iterations_per_batch = iterations_per_batch
  39. self._NxK = self.num_batches * self.iterations_per_batch
  40. self._num_added = 0
  41. self._buffer = deque([None for _ in range(self._NxK)], maxlen=self._NxK)
  42. self._indices = set()
  43. self._offset = self._NxK
  44. self._lock = threading.Lock()
  45. # Semaphore tracks the number of *available* samples.
  46. self._items_available = threading.Semaphore(0)
  47. self._rng = np.random.default_rng()
  48. # Statistics
  49. self._total_puts = 0
  50. self._total_gets = 0
  51. self._total_dropped = 0
  52. # Ray metrics
  53. self._metrics_circular_buffer_put_time = Histogram(
  54. name="rllib_utils_circular_buffer_put_time",
  55. description="Time spent in CircularBuffer.put()",
  56. boundaries=DEFAULT_HISTOGRAM_BOUNDARIES_SHORT_EVENTS,
  57. tag_keys=("rllib",),
  58. )
  59. self._metrics_circular_buffer_put_time.set_default_tags(
  60. {"rllib": self.__class__.__name__}
  61. )
  62. self._metrics_circular_buffer_put_ts_dropped = Counter(
  63. name="rllib_utils_circular_buffer_put_ts_dropped_counter",
  64. description="Total number of env steps dropped by the CircularBuffer.",
  65. tag_keys=("rllib",),
  66. )
  67. self._metrics_circular_buffer_put_ts_dropped.set_default_tags(
  68. {"rllib": self.__class__.__name__}
  69. )
  70. self._metrics_circular_buffer_get_time = Histogram(
  71. name="rllib_utils_circular_buffer_get_time",
  72. description="Time spent in CircularBuffer.get()",
  73. boundaries=DEFAULT_HISTOGRAM_BOUNDARIES_SHORT_EVENTS,
  74. tag_keys=("rllib",),
  75. )
  76. self._metrics_circular_buffer_get_time.set_default_tags(
  77. {"rllib": self.__class__.__name__}
  78. )
  79. def put(
  80. self, item: Any, block: bool = True, timeout: Optional[float] = None
  81. ) -> int:
  82. """Add a new batch to the buffer.
  83. The batch is added K times (iterations_per_batch) to allow for K samples.
  84. If full, the oldest batch entries are dropped.
  85. Args:
  86. item: The batch to add
  87. block: Not used (always non-blocking for puts)
  88. timeout: Not used
  89. Returns:
  90. Number of dropped entries (0 or iterations_per_batch)
  91. """
  92. with TimerAndPrometheusLogger(self._metrics_circular_buffer_put_time):
  93. with self._lock:
  94. self._total_puts += 1
  95. # Check if we'll drop old entries
  96. dropped_entry = self._buffer[0]
  97. # Add buffer K times with new indices
  98. for _ in range(self.iterations_per_batch):
  99. self._buffer.append(item)
  100. self._indices.add(self._offset)
  101. self._indices.discard(self._offset - self._NxK)
  102. self._offset += 1
  103. # Release semaphore for each available sample
  104. self._items_available.release()
  105. self._num_added += 1
  106. # A valid entry (w/ a batch whose k has not been reach K yet) was dropped.
  107. dropped_ts = 0
  108. if dropped_entry is not None:
  109. dropped_ts = (
  110. dropped_entry[0].env_steps()
  111. if isinstance(dropped_entry, tuple)
  112. else dropped_entry.env_steps()
  113. )
  114. if dropped_ts > 0:
  115. self._metrics_circular_buffer_put_ts_dropped.inc(
  116. value=dropped_ts
  117. )
  118. return dropped_ts
  119. def put_nowait(self, item: Any) -> int:
  120. """Equivalent to self.put(block=False)."""
  121. return self.put(item, block=False)
  122. def get(self, block: bool = True, timeout: Optional[float] = None) -> Any:
  123. """Sample a random batch from the buffer.
  124. The sampled entry is removed and won't be sampled again.
  125. Blocks if the buffer is empty (when block=True).
  126. Args:
  127. block: If True, block until an item is available
  128. timeout: Maximum time to wait (only used when block=True)
  129. Returns:
  130. A randomly sampled batch
  131. Raises:
  132. TimeoutError: If timeout expires while blocking
  133. IndexError: If buffer is empty and block=False
  134. """
  135. # Only initially, the buffer may be empty -> Just wait for some time.
  136. with TimerAndPrometheusLogger(self._metrics_circular_buffer_get_time):
  137. while len(self) == 0:
  138. time.sleep(0.0001)
  139. # Sample a random buffer index.
  140. with self._lock:
  141. idx = self._rng.choice(list(self._indices))
  142. actual_buffer_idx = idx - self._offset + self._NxK
  143. batch = self._buffer[actual_buffer_idx]
  144. assert batch is not None, (
  145. idx,
  146. actual_buffer_idx,
  147. self._offset,
  148. self._indices,
  149. [b is None for b in self._buffer],
  150. )
  151. self._buffer[actual_buffer_idx] = None
  152. self._indices.discard(idx)
  153. # Return the sampled batch.
  154. return batch
  155. def get_nowait(self) -> Any:
  156. """Equivalent to self.get(block=False)."""
  157. return self.get(block=False)
  158. @property
  159. def filled(self) -> bool:
  160. """Whether the buffer has been filled once with at least `self.num_batches`."""
  161. with self._lock:
  162. return self._num_added >= self.num_batches
  163. def qsize(self) -> int:
  164. """Returns the number of actually valid (non-expired) batches in the buffer."""
  165. with self._lock:
  166. return len(self._indices)
  167. def __len__(self) -> int:
  168. return self.qsize()
  169. def task_done(self):
  170. """No-op for Queue compatibility."""
  171. pass
  172. def get_stats(self) -> dict:
  173. """Get buffer statistics for monitoring."""
  174. with self._lock:
  175. return {
  176. "size": len(self._indices),
  177. "capacity": self._NxK,
  178. "num_batches": self.num_batches,
  179. "iterations_per_batch": self.iterations_per_batch,
  180. "total_puts": self._total_puts,
  181. "total_gets": self._total_gets,
  182. "total_dropped": self._total_dropped,
  183. "filled": self._num_added >= self.num_batches,
  184. }
  185. @OldAPIStack
  186. def make_appo_models(policy) -> ModelV2:
  187. """Builds model and target model for APPO.
  188. Returns:
  189. ModelV2: The Model for the Policy to use.
  190. Note: The target model will not be returned, just assigned to
  191. `policy.target_model`.
  192. """
  193. # Get the num_outputs for the following model construction calls.
  194. _, logit_dim = ModelCatalog.get_action_dist(
  195. policy.action_space, policy.config["model"]
  196. )
  197. # Construct the (main) model.
  198. policy.model = ModelCatalog.get_model_v2(
  199. policy.observation_space,
  200. policy.action_space,
  201. logit_dim,
  202. policy.config["model"],
  203. name=POLICY_SCOPE,
  204. framework=policy.framework,
  205. )
  206. policy.model_variables = policy.model.variables()
  207. # Construct the target model.
  208. policy.target_model = ModelCatalog.get_model_v2(
  209. policy.observation_space,
  210. policy.action_space,
  211. logit_dim,
  212. policy.config["model"],
  213. name=TARGET_POLICY_SCOPE,
  214. framework=policy.framework,
  215. )
  216. policy.target_model_variables = policy.target_model.variables()
  217. # Return only the model (not the target model).
  218. return policy.model