minibatch_buffer.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import queue
  2. from typing import Any, Tuple
  3. from ray.rllib.utils.annotations import OldAPIStack
  4. @OldAPIStack
  5. class MinibatchBuffer:
  6. """Ring buffer of recent data batches for minibatch SGD.
  7. This is for use with AsyncSamplesOptimizer.
  8. """
  9. def __init__(
  10. self,
  11. inqueue: queue.Queue,
  12. size: int,
  13. timeout: float,
  14. num_passes: int,
  15. init_num_passes: int = 1,
  16. ):
  17. """Initialize a minibatch buffer.
  18. Args:
  19. inqueue (queue.Queue): Queue to populate the internal ring buffer
  20. from.
  21. size: Max number of data items to buffer.
  22. timeout: Queue timeout
  23. num_passes: Max num times each data item should be emitted.
  24. init_num_passes: Initial passes for each data item.
  25. Maxiumum number of passes per item are increased to num_passes over
  26. time.
  27. """
  28. self.inqueue = inqueue
  29. self.size = size
  30. self.timeout = timeout
  31. self.max_initial_ttl = num_passes
  32. self.cur_initial_ttl = init_num_passes
  33. self.buffers = [None] * size
  34. self.ttl = [0] * size
  35. self.idx = 0
  36. def get(self) -> Tuple[Any, bool]:
  37. """Get a new batch from the internal ring buffer.
  38. Returns:
  39. buf: Data item saved from inqueue.
  40. released: True if the item is now removed from the ring buffer.
  41. """
  42. if self.ttl[self.idx] <= 0:
  43. self.buffers[self.idx] = self.inqueue.get(timeout=self.timeout)
  44. self.ttl[self.idx] = self.cur_initial_ttl
  45. if self.cur_initial_ttl < self.max_initial_ttl:
  46. self.cur_initial_ttl += 1
  47. buf = self.buffers[self.idx]
  48. self.ttl[self.idx] -= 1
  49. released = self.ttl[self.idx] <= 0
  50. if released:
  51. self.buffers[self.idx] = None
  52. self.idx = (self.idx + 1) % len(self.buffers)
  53. return buf, released