d4rl_reader.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import logging
  2. from typing import Dict
  3. import gymnasium as gym
  4. from ray.rllib.offline.input_reader import InputReader
  5. from ray.rllib.offline.io_context import IOContext
  6. from ray.rllib.policy.sample_batch import SampleBatch
  7. from ray.rllib.utils.annotations import PublicAPI, override
  8. from ray.rllib.utils.typing import SampleBatchType
# Module-level logger for this reader; not referenced by the code visible here.
logger = logging.getLogger(__name__)
  10. @PublicAPI
  11. class D4RLReader(InputReader):
  12. """Reader object that loads the dataset from the D4RL dataset."""
  13. @PublicAPI
  14. def __init__(self, inputs: str, ioctx: IOContext = None):
  15. """Initializes a D4RLReader instance.
  16. Args:
  17. inputs: String corresponding to the D4RL environment name.
  18. ioctx: Current IO context object.
  19. """
  20. import d4rl
  21. self.env = gym.make(inputs)
  22. self.dataset = _convert_to_batch(d4rl.qlearning_dataset(self.env))
  23. assert self.dataset.count >= 1
  24. self.counter = 0
  25. @override(InputReader)
  26. def next(self) -> SampleBatchType:
  27. if self.counter >= self.dataset.count:
  28. self.counter = 0
  29. self.counter += 1
  30. return self.dataset.slice(start=self.counter, end=self.counter + 1)
  31. def _convert_to_batch(dataset: Dict) -> SampleBatchType:
  32. # Converts D4RL dataset to SampleBatch
  33. d = {}
  34. d[SampleBatch.OBS] = dataset["observations"]
  35. d[SampleBatch.ACTIONS] = dataset["actions"]
  36. d[SampleBatch.NEXT_OBS] = dataset["next_observations"]
  37. d[SampleBatch.REWARDS] = dataset["rewards"]
  38. d[SampleBatch.TERMINATEDS] = dataset["terminals"]
  39. return SampleBatch(d)