dataset_writer.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import logging
  2. import os
  3. import time
  4. from typing import Dict, List
  5. from ray import data
  6. from ray.rllib.offline.io_context import IOContext
  7. from ray.rllib.offline.json_writer import _to_json_dict
  8. from ray.rllib.offline.output_writer import OutputWriter
  9. from ray.rllib.utils.annotations import PublicAPI, override
  10. from ray.rllib.utils.typing import SampleBatchType
  11. logger = logging.getLogger(__name__)
  12. @PublicAPI
  13. class DatasetWriter(OutputWriter):
  14. """Writer object that saves experiences using Datasets."""
  15. @PublicAPI
  16. def __init__(
  17. self,
  18. ioctx: IOContext = None,
  19. compress_columns: List[str] = frozenset(["obs", "new_obs"]),
  20. ):
  21. """Initializes a DatasetWriter instance.
  22. Examples:
  23. config = {
  24. "output": "dataset",
  25. "output_config": {
  26. "format": "json",
  27. "path": "/tmp/test_samples/",
  28. "max_num_samples_per_file": 100000,
  29. }
  30. }
  31. Args:
  32. ioctx: current IO context object.
  33. compress_columns: list of sample batch columns to compress.
  34. """
  35. self.ioctx = ioctx or IOContext()
  36. output_config: Dict = ioctx.output_config
  37. assert (
  38. "format" in output_config
  39. ), "output_config.format must be specified when using Dataset output."
  40. assert (
  41. "path" in output_config
  42. ), "output_config.path must be specified when using Dataset output."
  43. self.format = output_config["format"]
  44. self.path = os.path.abspath(os.path.expanduser(output_config["path"]))
  45. self.max_num_samples_per_file = (
  46. output_config["max_num_samples_per_file"]
  47. if "max_num_samples_per_file" in output_config
  48. else 100000
  49. )
  50. self.compress_columns = compress_columns
  51. self.samples = []
  52. @override(OutputWriter)
  53. def write(self, sample_batch: SampleBatchType):
  54. start = time.time()
  55. # Make sure columns like obs are compressed and writable.
  56. d = _to_json_dict(sample_batch, self.compress_columns)
  57. self.samples.append(d)
  58. # Todo: We should flush at the end of sampling even if this
  59. # condition was not reached.
  60. if len(self.samples) >= self.max_num_samples_per_file:
  61. ds = data.from_items(self.samples).repartition(num_blocks=1, shuffle=False)
  62. if self.format == "json":
  63. ds.write_json(self.path, try_create_dir=True)
  64. elif self.format == "parquet":
  65. ds.write_parquet(self.path, try_create_dir=True)
  66. else:
  67. raise ValueError("Unknown output type: ", self.format)
  68. self.samples = []
  69. logger.debug("Wrote dataset in {}s".format(time.time() - start))