| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290 |
- import logging
- import math
- import re
- import zipfile
- from pathlib import Path
- from typing import TYPE_CHECKING, List, Optional, Tuple
- import numpy as np
- import ray.data
- from ray.rllib.offline.input_reader import InputReader
- from ray.rllib.offline.io_context import IOContext
- from ray.rllib.offline.json_reader import from_json_data, postprocess_actions
- from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, concat_samples
- from ray.rllib.utils.annotations import PublicAPI, override
- from ray.rllib.utils.typing import SampleBatchType
- if TYPE_CHECKING:
- from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
# CPU share requested for each Ray Data read task when the user does not set
# `input_config["num_cpus_per_read_task"]`.
DEFAULT_NUM_CPUS_PER_TASK: float = 0.5

# Logger named after this module.
logger = logging.getLogger(name=__name__)
def _unzip_this_path(fpath: Path, extract_path: str):
    """Extract every member of the zip archive at ``fpath`` into ``extract_path``.

    Args:
        fpath: Location of the archive; it is converted with ``str()`` before
            opening, so both ``str`` and ``pathlib.Path`` values work.
        extract_path: Directory that receives the extracted members.
    """
    archive_location = str(fpath)
    with zipfile.ZipFile(archive_location, mode="r") as archive:
        archive.extractall(path=extract_path)
def _unzip_if_needed(paths: List[str], format: str):
    """Resolve each input path, extracting ``.zip`` archives along the way.

    For every entry in ``paths``:

    * A path ending in ``.zip`` is extracted into the current working
      directory. Its result path is ``<cwd>/<archive stem>.<format>``, so the
      archive is expected to contain a file with that name. ``s3://``
      archives are rejected.
    * A non-zip ``s3://`` path is passed through unchanged.
    * Any other path is kept as-is if it exists. Otherwise it is resolved
      relative to the parent of this module's directory.

    Args:
        paths: Local paths, ``s3://`` URIs, or zip archives.
        format: Extension (e.g. ``"json"``) of the file extracted from each
            zip archive.

    Returns:
        A list of resolved paths, one per input, in input order.

    Raises:
        ValueError: If an ``s3://`` path points to a zip archive.
        FileNotFoundError: If a local path or archive cannot be found, either
            as given or relative to the package root.
    """
    # Fallback root for relative paths (used by the test suite).
    # TODO: remove this fallback once all tests use s3 paths.
    package_root = Path(__file__).parent.parent
    ret_paths = []
    for path in paths:
        path_str = str(path)
        if path_str.endswith(".zip"):
            # TODO: We need to add unzip support for s3
            if path_str.startswith("s3://"):
                raise ValueError(
                    "unzip_if_needed currently does not support remote paths from s3"
                )
            extract_path = "./"
            try:
                _unzip_this_path(path_str, extract_path)
            except FileNotFoundError:
                # Interpreted as a path relative to the rllib folder.
                try:
                    _unzip_this_path(package_root / path, extract_path)
                except FileNotFoundError as err:
                    raise FileNotFoundError(f"File not found: {path}") from err
            unzipped_path = str(
                Path(extract_path).absolute() / f"{Path(path).stem}.{format}"
            )
            ret_paths.append(unzipped_path)
        elif path_str.startswith("s3://"):
            # Remote non-zip paths are handed to Ray Data untouched.
            ret_paths.append(path)
        else:
            if not Path(path).exists():
                relative_path = str(package_root / path)
                if not Path(relative_path).exists():
                    raise FileNotFoundError(f"File not found: {relative_path}")
                path = relative_path
            ret_paths.append(path)
    return ret_paths
@PublicAPI
def get_dataset_and_shards(
    config: "AlgorithmConfig", num_workers: int = 0
) -> Tuple[ray.data.Dataset, List[ray.data.Dataset]]:
    """Returns a dataset and a list of shards.

    This function uses algorithm configs to create a dataset and a list of shards.
    The following config keys are used to create the dataset:

    input: The input type should be "dataset".
    input_config: A dict containing the following keys and values:
        `format`: str, specifies the format of the input data. This will be the
        format that ray dataset supports. See ray.data.Dataset for
        supported formats. Only "parquet" or "json" are supported for now.
        `paths`: str, a single string or a list of strings. Each string is a path
        to a file or a directory holding the dataset. It can be either a local path
        or a remote path (e.g. to an s3 bucket).
        `loader_fn`: Callable[[], ray.data.Dataset], Instead of
        specifying paths and format, you can specify a function to load the dataset.
        `parallelism`: int, The number of tasks to use for loading the dataset.
        If not specified, it will be set to `num_workers` (or 1 when
        `num_workers` is 0).
        `num_cpus_per_read_task`: float, The number of CPUs to use for each read
        task. If not specified, it will be set to 0.5.

    Args:
        config: The config dict for the algorithm.
        num_workers: The number of shards to create for remote workers.

    Returns:
        dataset: The dataset object.
        shards: A list of dataset shards. For num_workers > 0 the first returned
        shard is a dummy None shard for the local worker.

    Raises:
        AssertionError: If `config.input_` is not "dataset", or `paths` is a
            list containing non-string entries.
        ValueError: If the format is unsupported, if `loader_fn` is combined
            with `format`/`paths`, if neither a `loader_fn` nor `format` +
            `paths` is given, or if `paths` is neither a string nor a list.
    """
    # check input and input config keys
    assert config.input_ == "dataset", (
        f"Must specify config.input_ as 'dataset' if"
        f" calling `get_dataset_and_shards`. Got {config.input_}"
    )

    # check input config format
    input_config = config.input_config
    format = input_config.get("format")

    supported_fmts = ["json", "parquet"]
    if format is not None and format not in supported_fmts:
        raise ValueError(
            f"Unsupported format {format}. Supported formats are {supported_fmts}"
        )

    # check paths and loader_fn since only one of them is required.
    paths = input_config.get("paths")
    loader_fn = input_config.get("loader_fn")
    if loader_fn and (format or paths):
        raise ValueError(
            "When using a `loader_fn`, you cannot specify a `format` or `path`."
        )

    # check if at least loader_fn or format + path is specified.
    if not (format and paths) and not loader_fn:
        raise ValueError(
            "Must specify either a `loader_fn` or a `format` and `path` in "
            "`input_config`."
        )

    # check paths to be a str or list[str] if not None
    if paths is not None:
        if isinstance(paths, str):
            paths = [paths]
        elif isinstance(paths, list):
            # Validate every entry; indexing paths[0] would raise IndexError
            # for an empty list (reachable when a loader_fn is supplied).
            assert all(
                isinstance(p, str) for p in paths
            ), "Paths must be a list of path strings."
        else:
            raise ValueError("Paths must be a path string or a list of path strings.")
        paths = _unzip_if_needed(paths, format)

    # TODO (Kourosh): num_workers is not necessary since we can use parallelism for
    # everything. Having two parameters is confusing here. Remove num_workers later.
    parallelism = input_config.get("parallelism", num_workers or 1)
    cpus_per_task = input_config.get(
        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
    )

    if loader_fn:
        dataset = loader_fn()
    elif format == "json":
        dataset = ray.data.read_json(
            paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
        )
    elif format == "parquet":
        dataset = ray.data.read_parquet(
            paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
        )
    else:
        # Previously passed `format` as a second exception arg, which rendered
        # the message as a tuple.
        raise ValueError(f"Un-supported Ray dataset format: {format}")

    # Local worker will be responsible for sampling.
    if num_workers == 0:
        # Dataset is the only shard we need.
        return dataset, [dataset]

    # Remote workers are responsible for sampling:
    # Each remote worker gets 1 shard.
    remote_shards = dataset.repartition(
        num_blocks=num_workers, shuffle=False
    ).split(num_workers)

    # The first None shard is for the local worker, which
    # shouldn't be doing rollout work anyways.
    return dataset, [None] + remote_shards
@PublicAPI
class DatasetReader(InputReader):
    """Reader object that loads data from Ray Dataset.

    Examples:
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                # A single data file, a directory, or anything
                # that ray.data.dataset recognizes.
                "paths": "/tmp/sample_batches/",
                # By default, parallelism=num_workers.
                "parallelism": 3,
                # Dataset allocates 0.5 CPU for each reader by default.
                # Adjust this value based on the size of your offline dataset.
                "num_cpus_per_read_task": 0.5,
            }
        }
    """

    @PublicAPI
    def __init__(self, ds: ray.data.Dataset, ioctx: Optional[IOContext] = None):
        """Initializes a DatasetReader instance.

        Args:
            ds: Ray dataset to sample from. May be None, which creates a
                non-functioning reader (e.g. for a non-rollout local worker);
                `next()` must not be called on such a reader.
            ioctx: IO context providing the worker and config. A default
                `IOContext()` is used when omitted.
        """
        self._ioctx = ioctx or IOContext()
        self._default_policy = self.policy_map = None
        # `_policy_map` is the attribute populated below; initialize it so it
        # exists even when `ds` is None or no worker is attached.
        self._policy_map = None
        self.preprocessor = None
        self._dataset = ds
        # `Dataset.count()` may trigger a full pass; compute it once and reuse.
        self.count = None if ds is None else ds.count()
        # do this to disable the ray data stdout logging
        ray.data.DataContext.get_current().enable_progress_bars = False

        # The number of steps to return per call to next(). When there are env
        # runners, the train batch is split across them (rounded up, min 1).
        self.batch_size = self._ioctx.config.get("train_batch_size", 1)
        num_workers = self._ioctx.config.get("num_env_runners", 0)
        seed = self._ioctx.config.get("seed", None)
        if num_workers:
            self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)

        # We allow the creation of a non-functioning None DatasetReader.
        # It's useful for example for a non-rollout local worker.
        if ds is not None:
            if self._ioctx.worker is not None:
                self._policy_map = self._ioctx.worker.policy_map
                self._default_policy = self._policy_map.get(DEFAULT_POLICY_ID)
                self.preprocessor = (
                    self._ioctx.worker.preprocessors.get(DEFAULT_POLICY_ID)
                    if not self._ioctx.config.get("_disable_preprocessors", False)
                    else None
                )
            print(
                f"DatasetReader {self._ioctx.worker_index} has {self.count} samples."
            )

            def iterator():
                # Endless row stream: reshuffle the whole dataset before each
                # pass. NOTE(review): with a fixed `seed`, every pass uses the
                # same permutation.
                while True:
                    shuffled = self._dataset.random_shuffle(seed=seed)
                    yield from shuffled.iter_rows()

            self._iter = iterator()
        else:
            self._iter = None

    @override(InputReader)
    def next(self) -> SampleBatchType:
        """Returns one concatenated batch of at least `self.batch_size` steps.

        Rows are pulled from the endless shuffled iterator, decoded with
        `from_json_data`, then passed through preprocessing,
        `postprocess_actions`, and optional trajectory postprocessing. Rows are
        accumulated until the summed `count` of the decoded batches reaches
        `self.batch_size`, so the result may slightly exceed it.
        """
        # next() should not get called on None DatasetReader.
        assert self._iter is not None
        ret = []
        count = 0
        while count < self.batch_size:
            d = next(self._iter)
            # Columns like obs are compressed when written by DatasetWriter.
            d = from_json_data(d, self._ioctx.worker)
            count += d.count
            d = self._preprocess_if_needed(d)
            d = postprocess_actions(d, self._ioctx)
            d = self._postprocess_if_needed(d)
            ret.append(d)
        return concat_samples(ret)

    def _preprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
        """Apply the default policy's preprocessor to obs/new_obs, if one is set.

        Each observation is transformed individually and the results are
        re-stacked with `np.stack`. The batch is modified in place and returned.
        """
        # TODO: @kourosh, preprocessor is only supported for single agent case.
        if self.preprocessor:
            for key in (SampleBatch.CUR_OBS, SampleBatch.NEXT_OBS):
                if key in batch:
                    batch[key] = np.stack(
                        [self.preprocessor.transform(s) for s in batch[key]]
                    )
        return batch

    def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
        """Run the default policy's trajectory postprocessing, if enabled.

        Only active when the config sets `postprocess_inputs`. Single-agent
        `SampleBatch`es are split by episode, each episode is postprocessed by
        the default policy (or passed through when there is none), and the
        pieces are concatenated again.

        Raises:
            NotImplementedError: For non-`SampleBatch` (multi-agent) input when
                postprocessing is enabled.
        """
        if not self._ioctx.config.get("postprocess_inputs"):
            return batch

        if isinstance(batch, SampleBatch):
            out = []
            for sub_batch in batch.split_by_episode():
                if self._default_policy is not None:
                    out.append(self._default_policy.postprocess_trajectory(sub_batch))
                else:
                    out.append(sub_batch)
            return concat_samples(out)
        else:
            # TODO(ekl) this is trickier since the alignments between agent
            # trajectories in the episode are not available any more.
            raise NotImplementedError(
                "Postprocessing of multi-agent data not implemented yet."
            )
|