| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- import logging
- import os
- import time
- from typing import Dict, List
- from ray import data
- from ray.rllib.offline.io_context import IOContext
- from ray.rllib.offline.json_writer import _to_json_dict
- from ray.rllib.offline.output_writer import OutputWriter
- from ray.rllib.utils.annotations import PublicAPI, override
- from ray.rllib.utils.typing import SampleBatchType
- logger = logging.getLogger(__name__)
@PublicAPI
class DatasetWriter(OutputWriter):
    """Writer object that saves experiences using Datasets.

    Each call to ``write()`` converts one sample batch to a JSON-compatible
    dict and buffers it in memory. Once ``max_num_samples_per_file`` entries
    have been buffered, the buffer is written out as a single-block dataset
    in the configured format and then cleared.
    """

    @PublicAPI
    def __init__(
        self,
        ioctx: IOContext = None,
        compress_columns: List[str] = frozenset(["obs", "new_obs"]),
    ):
        """Initializes a DatasetWriter instance.

        Examples:
            config = {
                "output": "dataset",
                "output_config": {
                    "format": "json",
                    "path": "/tmp/test_samples/",
                    "max_num_samples_per_file": 100000,
                }
            }

        Args:
            ioctx: current IO context object. A fresh ``IOContext()`` is
                used when this is None (or otherwise falsy).
            compress_columns: list of sample batch columns to compress.

        Raises:
            AssertionError: if ``output_config`` lacks the "format" or
                "path" key.
        """
        self.ioctx = ioctx or IOContext()

        # Read the config from self.ioctx rather than the raw argument, so
        # that a missing ioctx falls through to the checks below instead of
        # failing with an AttributeError on None.
        output_config: Dict = self.ioctx.output_config
        assert (
            "format" in output_config
        ), "output_config.format must be specified when using Dataset output."
        assert (
            "path" in output_config
        ), "output_config.path must be specified when using Dataset output."

        self.format = output_config["format"]
        # Expand "~" and normalize to an absolute path.
        self.path = os.path.abspath(os.path.expanduser(output_config["path"]))
        # Flush threshold; compared against the number of buffered entries
        # (one entry per write() call).
        self.max_num_samples_per_file = output_config.get(
            "max_num_samples_per_file", 100000
        )
        self.compress_columns = compress_columns

        # In-memory buffer of converted sample batches awaiting a flush.
        self.samples = []

    @override(OutputWriter)
    def write(self, sample_batch: SampleBatchType):
        """Buffers a sample batch, flushing the buffer once it is full.

        Args:
            sample_batch: the batch of experiences to record.

        Raises:
            ValueError: if a flush is due and ``self.format`` is neither
                "json" nor "parquet".
        """
        start = time.time()

        # Make sure columns like obs are compressed and writable.
        d = _to_json_dict(sample_batch, self.compress_columns)
        self.samples.append(d)

        # Todo: We should flush at the end of sampling even if this
        # condition was not reached.
        if len(self.samples) >= self.max_num_samples_per_file:
            ds = data.from_items(self.samples).repartition(
                num_blocks=1, shuffle=False
            )
            if self.format == "json":
                ds.write_json(self.path, try_create_dir=True)
            elif self.format == "parquet":
                ds.write_parquet(self.path, try_create_dir=True)
            else:
                # Build one readable message; passing the format as a second
                # positional argument would render the error as a tuple.
                raise ValueError(
                    "Unknown output type: {}".format(self.format)
                )
            self.samples = []
            logger.debug("Wrote dataset in %ss", time.time() - start)
|