base.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import platform
  2. from abc import ABCMeta, abstractmethod
  3. from typing import Any, Dict, Optional
  4. from ray.util.annotations import DeveloperAPI
  5. @DeveloperAPI
  6. class ReplayBufferInterface(metaclass=ABCMeta):
  7. """Abstract base class for all of RLlib's replay buffers.
  8. Mainly defines the `add()` and `sample()` methods that every buffer class
  9. must implement to be usable by an Algorithm.
  10. Buffers may determine on all the implementation details themselves, e.g.
  11. whether to store single timesteps, episodes, or episode fragments or whether
  12. to return fixed batch sizes or per-call defined ones.
  13. """
  14. @abstractmethod
  15. @DeveloperAPI
  16. def __len__(self) -> int:
  17. """Returns the number of items currently stored in this buffer."""
  18. @abstractmethod
  19. @DeveloperAPI
  20. def add(self, batch: Any, **kwargs) -> None:
  21. """Adds a batch of experiences or other data to this buffer.
  22. Args:
  23. batch: Batch or data to add.
  24. ``**kwargs``: Forward compatibility kwargs.
  25. """
  26. @abstractmethod
  27. @DeveloperAPI
  28. def sample(self, num_items: Optional[int] = None, **kwargs) -> Any:
  29. """Samples `num_items` items from this buffer.
  30. The exact shape of the returned data depends on the buffer's implementation.
  31. Args:
  32. num_items: Number of items to sample from this buffer.
  33. ``**kwargs``: Forward compatibility kwargs.
  34. Returns:
  35. A batch of items.
  36. """
  37. @abstractmethod
  38. @DeveloperAPI
  39. def get_state(self) -> Dict[str, Any]:
  40. """Returns all local state in a dict.
  41. Returns:
  42. The serializable local state.
  43. """
  44. @abstractmethod
  45. @DeveloperAPI
  46. def set_state(self, state: Dict[str, Any]) -> None:
  47. """Restores all local state to the provided `state`.
  48. Args:
  49. state: The new state to set this buffer. Can be obtained by calling
  50. `self.get_state()`.
  51. """
  52. @DeveloperAPI
  53. def get_host(self) -> str:
  54. """Returns the computer's network name.
  55. Returns:
  56. The computer's networks name or an empty string, if the network
  57. name could not be determined.
  58. """
  59. return platform.node()