replay_ops.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. import random
  2. from typing import Optional
  3. from ray.rllib.utils.annotations import OldAPIStack
  4. from ray.rllib.utils.replay_buffers.replay_buffer import warn_replay_capacity
  5. from ray.rllib.utils.typing import SampleBatchType
  6. @OldAPIStack
  7. class SimpleReplayBuffer:
  8. """Simple replay buffer that operates over batches."""
  9. def __init__(self, num_slots: int, replay_proportion: Optional[float] = None):
  10. """Initialize SimpleReplayBuffer.
  11. Args:
  12. num_slots: Number of batches to store in total.
  13. """
  14. self.num_slots = num_slots
  15. self.replay_batches = []
  16. self.replay_index = 0
  17. def add_batch(self, sample_batch: SampleBatchType) -> None:
  18. warn_replay_capacity(item=sample_batch, num_items=self.num_slots)
  19. if self.num_slots > 0:
  20. if len(self.replay_batches) < self.num_slots:
  21. self.replay_batches.append(sample_batch)
  22. else:
  23. self.replay_batches[self.replay_index] = sample_batch
  24. self.replay_index += 1
  25. self.replay_index %= self.num_slots
  26. def replay(self) -> SampleBatchType:
  27. return random.choice(self.replay_batches)
  28. def __len__(self):
  29. return len(self.replay_batches)