| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438 |
- import glob
- import json
- import logging
- import math
- import os
- import random
- import re
- import zipfile
- from pathlib import Path
- from typing import TYPE_CHECKING, List, Optional, Union
- from urllib.parse import urlparse
- import numpy as np
- import tree # pip install dm_tree
- try:
- from smart_open import smart_open
- except ImportError:
- smart_open = None
- from ray.rllib.offline.input_reader import InputReader
- from ray.rllib.offline.io_context import IOContext
- from ray.rllib.policy.policy import Policy
- from ray.rllib.policy.sample_batch import (
- DEFAULT_POLICY_ID,
- MultiAgentBatch,
- SampleBatch,
- concat_samples,
- convert_ma_batch_to_sample_batch,
- )
- from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI, override
- from ray.rllib.utils.compression import unpack_if_needed
- from ray.rllib.utils.spaces.space_utils import clip_action, normalize_action
- from ray.rllib.utils.typing import Any, FileType, SampleBatchType
- if TYPE_CHECKING:
- from ray.rllib.evaluation import RolloutWorker
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Lower-case drive letters "c" through "z". The readers below compare
# `urlparse(path).scheme` against this list (plus "") to tell local paths,
# including Windows paths such as `C:\data\x.json`, apart from real URIs.
WINDOWS_DRIVES = list("cdefghijklmnopqrstuvwxyz")
def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
    """Converts nested action/observation columns into np.ndarray leaves.

    Nested lists/dicts coming out of the json are mapped, leaf by leaf, onto
    the structure of the policy's (nested) action or observation space, so
    that a SampleBatch constructor is not confused by raw nested lists.

    Action columns are only converted if the policy config sets
    `_disable_action_flattening`; observation columns only if it sets
    `_disable_preprocessor_api`. All other columns are left untouched.

    Args:
        json_data: Mapping from column name to column data. Modified in place.
        policy: Policy providing `view_requirements`, `config` and the
            action-/observation-space structs.

    Returns:
        The same `json_data` dict, with the matching columns replaced.
    """
    action_cols = (SampleBatch.ACTIONS, SampleBatch.PREV_ACTIONS)
    obs_cols = (SampleBatch.OBS, SampleBatch.NEXT_OBS)

    # Only values of existing keys are reassigned, so iterating the dict
    # while updating it is safe.
    for key in json_data:
        # A column may be an alias of another one via its view requirement.
        source_col = ""
        if key in policy.view_requirements:
            source_col = policy.view_requirements[key].data_col

        is_action_col = key in action_cols or source_col in action_cols
        is_obs_col = key in obs_cols or source_col in obs_cols

        # No action flattening -> Process nested (leaf) action(s).
        if policy.config.get("_disable_action_flattening") and is_action_col:
            space_struct = policy.action_space_struct
        # No preprocessing -> Process nested (leaf) observation(s).
        elif policy.config.get("_disable_preprocessor_api") and is_obs_col:
            space_struct = policy.observation_space_struct
        else:
            continue

        json_data[key] = tree.map_structure_up_to(
            space_struct,
            np.array,
            json_data[key],
            check_types=False,
        )

    return json_data
@DeveloperAPI
def _adjust_dones(json_data: dict) -> dict:
    """Returns a copy of `json_data` with the DONES key renamed to TERMINATEDS.

    Every other key/value pair is carried over unchanged and in the same
    order. The input dict itself is not modified.
    """
    return {
        (SampleBatch.TERMINATEDS if key == SampleBatch.DONES else key): value
        for key, value in json_data.items()
    }
def _single_agent_policy(worker: "RolloutWorker") -> Policy:
    """Returns the policy that applies to a single-agent SampleBatch.

    Uses the policy registered under DEFAULT_POLICY_ID if there is one;
    otherwise the policy map must hold exactly one policy, which is returned.
    """
    policy = worker.policy_map.get(DEFAULT_POLICY_ID)
    if policy is None:
        assert len(worker.policy_map) == 1
        policy = next(iter(worker.policy_map.values()))
    return policy


def _iter_batches_with_policies(batch: SampleBatchType, worker: "RolloutWorker"):
    """Lazily yields (sample_batch, policy) pairs for a single/multi-agent batch."""
    if isinstance(batch, SampleBatch):
        yield batch, _single_agent_policy(worker)
    else:
        for pid, sub_batch in batch.policy_batches.items():
            yield sub_batch, worker.policy_map[pid]


@DeveloperAPI
def postprocess_actions(batch: SampleBatchType, ioctx: IOContext) -> SampleBatchType:
    """Clips and/or normalizes the actions of an offline batch, per config.

    The ACTIONS column(s) of `batch` are updated in place:
    - If `clip_actions` is set in `ioctx.config`, actions are clipped using
      the respective policy's action space struct.
    - If `actions_in_input_normalized` is False and `normalize_actions` is
      True, actions are normalized using the respective policy's action space
      struct.

    Args:
        batch: A SampleBatch or a multi-agent batch (anything exposing
            `policy_batches`).
        ioctx: IO context providing the config and the worker (policy map).

    Returns:
        The same `batch` object, after the in-place updates.

    Raises:
        ValueError: If clipping/normalizing is requested but `ioctx.worker` is
            None, or if normalization is requested for a tuple/dict action
            space struct without `_disable_action_flattening` set on the
            policy.
    """
    cfg = ioctx.config

    # Clip actions (from any values into env's bounds), if necessary.
    # TODO(jungong): We should not clip_action in input reader.
    #  Use connector to handle this.
    if cfg.get("clip_actions"):
        if ioctx.worker is None:
            raise ValueError(
                "clip_actions is True but cannot clip actions since no workers exist"
            )
        for sub_batch, policy in _iter_batches_with_policies(batch, ioctx.worker):
            sub_batch[SampleBatch.ACTIONS] = clip_action(
                sub_batch[SampleBatch.ACTIONS], policy.action_space_struct
            )

    # Re-normalize actions (from env's bounds to zero-centered), if
    # necessary.
    if (
        cfg.get("actions_in_input_normalized") is False
        and cfg.get("normalize_actions") is True
    ):
        if ioctx.worker is None:
            # Note the trailing space after "but": the two literals are
            # concatenated into one message.
            raise ValueError(
                "actions_in_input_normalized is False but "
                "cannot normalize actions since no workers exist"
            )

        # If we have a complex action space and actions were flattened
        # and we have to normalize -> Error.
        error_msg = (
            "Normalization of offline actions that are flattened is not "
            "supported! Make sure that you record actions into offline "
            "file with the `_disable_action_flattening=True` flag OR "
            "as already normalized (between -1.0 and 1.0) values. "
            "Also, when reading already normalized action values from "
            "offline files, make sure to set "
            "`actions_in_input_normalized=True` so that RLlib will not "
            "perform normalization on top."
        )

        for sub_batch, policy in _iter_batches_with_policies(batch, ioctx.worker):
            if isinstance(
                policy.action_space_struct, (tuple, dict)
            ) and not policy.config.get("_disable_action_flattening"):
                raise ValueError(error_msg)
            sub_batch[SampleBatch.ACTIONS] = normalize_action(
                sub_batch[SampleBatch.ACTIONS], policy.action_space_struct
            )

    return batch
@DeveloperAPI
def from_json_data(json_data: Any, worker: Optional["RolloutWorker"]):
    """Builds a SampleBatch or MultiAgentBatch from a decoded json record.

    The record's "type" field selects the batch class. Column values are
    passed through `unpack_if_needed`; if a worker is given, nested
    action/observation columns are adjusted to the matching policy's spaces;
    and a DONES column is translated into TERMINATEDS.

    Note: `json_data` is modified in place (the "type" key is popped and, for
    single-agent records, column values are replaced by their unpacked form).

    Args:
        json_data: The decoded json record (a dict with a "type" key).
        worker: Optional worker whose policy map is used to adjust nested
            observations/actions. If None, that adjustment is skipped.

    Returns:
        A SampleBatch for type "SampleBatch", a MultiAgentBatch for type
        "MultiAgentBatch".

    Raises:
        ValueError: If the "type" field is missing or unknown, or if a
            single-agent record is read while the worker's policy map does
            not contain exactly one policy.
    """
    # Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch).
    if "type" not in json_data:
        raise ValueError("JSON record missing 'type' field")
    data_type = json_data.pop("type")

    if data_type == "SampleBatch":
        if worker is not None and len(worker.policy_map) != 1:
            raise ValueError(
                "Found single-agent SampleBatch in input file, but our "
                "PolicyMap contains more than 1 policy!"
            )
        # Only values of existing keys are replaced while iterating.
        for k, v in json_data.items():
            json_data[k] = unpack_if_needed(v)
        if worker is not None:
            policy = next(iter(worker.policy_map.values()))
            json_data = _adjust_obs_actions_for_policy(json_data, policy)
        json_data = _adjust_dones(json_data)
        return SampleBatch(json_data)

    if data_type == "MultiAgentBatch":
        policy_batches = {}
        for policy_id, policy_batch in json_data["policy_batches"].items():
            inner = {k: unpack_if_needed(v) for k, v in policy_batch.items()}
            if worker is not None:
                policy = worker.policy_map[policy_id]
                inner = _adjust_obs_actions_for_policy(inner, policy)
            # Translate DONES into TERMINATEDS (same helper as the
            # single-agent branch, no separate inline renaming needed).
            inner = _adjust_dones(inner)
            policy_batches[policy_id] = SampleBatch(inner)
        return MultiAgentBatch(policy_batches, json_data["count"])

    raise ValueError(
        "Type field must be one of ['SampleBatch', 'MultiAgentBatch'], "
        f"got {data_type!r}"
    )
- # TODO(jungong) : use DatasetReader to back JsonReader, so we reduce
- # codebase complexity without losing existing functionality.
@PublicAPI
class JsonReader(InputReader):
    """Reader object that loads experiences from JSON file chunks.

    The input files will be read from in random order.
    """

    @PublicAPI
    def __init__(
        self, inputs: Union[str, List[str]], ioctx: Optional[IOContext] = None
    ):
        """Initializes a JsonReader instance.

        Args:
            inputs: Either a glob expression for files, e.g. `/tmp/**/*.json`,
                or a list of single file paths or URIs, e.g.,
                ["s3://bucket/file.json", "s3://bucket/file2.json"].
            ioctx: Current IO context object or None.

        Raises:
            ValueError: If `inputs` is neither a str nor a list/tuple, if a
                glob expression carries a non-local URI scheme, or if no
                input files were found.
        """
        logger.info(
            "You are using JSONReader. It is recommended to use "
            + "DatasetReader instead for better sharding support."
        )

        self.ioctx = ioctx or IOContext()
        self.default_policy = self.policy_map = None
        self.batch_size = 1
        if self.ioctx:
            self.batch_size = self.ioctx.config.get("train_batch_size", 1)
            # Split the train batch evenly across the env runners (at least 1).
            num_workers = self.ioctx.config.get("num_env_runners", 0)
            if num_workers:
                self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)

        if self.ioctx.worker is not None:
            self.policy_map = self.ioctx.worker.policy_map
            self.default_policy = self.policy_map.get(DEFAULT_POLICY_ID)
            if self.default_policy is None:
                assert len(self.policy_map) == 1
                self.default_policy = next(iter(self.policy_map.values()))

        if isinstance(inputs, str):
            # NOTE(review): abspath() runs before the scheme check below, so a
            # URI-style glob appears to be turned into a local path first and
            # would then end up in the "No files found" error - confirm.
            inputs = os.path.abspath(os.path.expanduser(inputs))
            if os.path.isdir(inputs):
                inputs = [os.path.join(inputs, "*.json"), os.path.join(inputs, "*.zip")]
                logger.warning(f"Treating input directory as glob patterns: {inputs}")
            else:
                inputs = [inputs]

            if any(urlparse(i).scheme not in [""] + WINDOWS_DRIVES for i in inputs):
                raise ValueError(
                    "Don't know how to glob over `{}`, ".format(inputs)
                    + "please specify a list of files to read instead."
                )
            self.files = []
            for pattern in inputs:
                self.files.extend(glob.glob(pattern))
        elif isinstance(inputs, (list, tuple)):
            self.files = list(inputs)
        else:
            # Report the offending type (not the value), as the message says.
            raise ValueError(
                "type of inputs must be list or str, not {}".format(type(inputs))
            )

        if self.files:
            logger.info("Found %s input files.", len(self.files))
        else:
            raise ValueError("No files found matching {}".format(inputs))

        # The currently open file object (opened lazily by `_next_line`).
        self.cur_file = None

    @override(InputReader)
    def next(self) -> SampleBatchType:
        """Returns a batch with at least `self.batch_size` timesteps.

        Lines are read (and parsed) until the accumulated `count` reaches the
        batch size; the per-line batches are then concatenated.

        Raises:
            ValueError: If no valid batch could be parsed after 100 retries.
        """
        ret = []
        count = 0
        while count < self.batch_size:
            batch = self._try_parse(self._next_line())
            tries = 0
            while not batch and tries < 100:
                tries += 1
                logger.debug("Skipping empty line in %s", self.cur_file)
                batch = self._try_parse(self._next_line())
            if not batch:
                raise ValueError(
                    "Failed to read valid experience batch from file: {}".format(
                        self.cur_file
                    )
                )
            batch = self._postprocess_if_needed(batch)
            count += batch.count
            ret.append(batch)
        return concat_samples(ret)

    def read_all_files(self) -> SampleBatchType:
        """Reads through all files once and yields one batch per valid line.

        Files are visited in the order of `self.files`. Blank lines and
        records that cannot be parsed are skipped. Each file is closed once
        it has been read (or when the generator is closed early).

        Yields:
            One SampleBatch or MultiAgentBatch per line in all input files.
        """
        for path in self.files:
            file = self._try_open_file(path)
            try:
                while True:
                    line = file.readline()
                    if not line:
                        # End of this file.
                        break
                    batch = self._try_parse(line)
                    if batch is None:
                        # Skip this line only; keep reading the file.
                        continue
                    yield batch
            finally:
                if hasattr(file, "close"):  # legacy smart_open impls
                    file.close()

    def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
        """Runs the default policy's trajectory postprocessing, if configured.

        Does nothing unless `postprocess_inputs` is set in the IO context's
        config. Otherwise, the batch is split by episode and each episode is
        passed through `default_policy.postprocess_trajectory`.

        Raises:
            ValueError: If postprocessing is requested but no default policy
                is available (no worker in the IO context).
            NotImplementedError: If the batch is still multi-agent after
                `convert_ma_batch_to_sample_batch`.
        """
        if not self.ioctx.config.get("postprocess_inputs"):
            return batch

        batch = convert_ma_batch_to_sample_batch(batch)

        if isinstance(batch, SampleBatch):
            if self.default_policy is None:
                raise ValueError(
                    "postprocess_inputs is True but cannot postprocess "
                    "since no workers exist"
                )
            out = [
                self.default_policy.postprocess_trajectory(sub_batch)
                for sub_batch in batch.split_by_episode()
            ]
            return concat_samples(out)

        # TODO(ekl) this is trickier since the alignments between agent
        #  trajectories in the episode are not available any more.
        raise NotImplementedError(
            "Postprocessing of multi-agent data not implemented yet."
        )

    def _try_open_file(self, path: str) -> FileType:
        """Opens `path` for reading and returns the file object.

        Paths with a non-local URI scheme are opened via `smart_open`. Local
        paths get "~/" expanded, are retried relative to this module's
        grandparent directory if they do not exist, and `.zip` files are
        extracted next to the archive, after which the sibling `.json` file
        is opened instead.

        Raises:
            ValueError: If a URI is given but `smart_open` is not installed.
            FileNotFoundError: If a local file cannot be found.
        """
        if urlparse(path).scheme not in [""] + WINDOWS_DRIVES:
            if smart_open is None:
                raise ValueError(
                    "You must install the `smart_open` module to read "
                    "from URIs like {}".format(path)
                )
            ctx = smart_open
        else:
            # Allow shortcut for home directory ("~/" -> env[HOME]).
            if path.startswith("~/"):
                path = os.path.join(os.environ.get("HOME", ""), path[2:])

            # If path doesn't exist, try to interpret it as relative to the
            # directory two levels above this very module file.
            path_orig = path
            if not os.path.exists(path):
                path = os.path.join(Path(__file__).parent.parent, path)
            if not os.path.exists(path):
                raise FileNotFoundError(f"Offline file {path_orig} not found!")

            # Unzip files, if necessary and re-point to extracted json file.
            if re.search(r"\.zip$", path):
                with zipfile.ZipFile(path, "r") as zip_ref:
                    zip_ref.extractall(Path(path).parent)
                path = re.sub(r"\.zip$", ".json", path)
                assert os.path.exists(path)
            ctx = open

        return ctx(path, "r")

    def _try_parse(self, line: str) -> Optional[SampleBatchType]:
        """Parses one line into a batch; returns None for blank/corrupt lines.

        Corrupt records are logged (with traceback) and ignored. Successfully
        parsed batches are passed through `postprocess_actions`.
        """
        line = line.strip()
        if not line:
            return None
        try:
            batch = self._from_json(line)
        except Exception:
            logger.exception(
                "Ignoring corrupt json record in %s: %s", self.cur_file, line
            )
            return None

        return postprocess_actions(batch, self.ioctx)

    def _next_line(self) -> str:
        """Returns the next non-empty raw line, switching files at EOF.

        When the current file is exhausted it is closed and another file is
        opened (see `_next_file`); this is retried up to 100 times.

        Raises:
            ValueError: If no line could be read after all retries.
        """
        if not self.cur_file:
            self.cur_file = self._next_file()
        line = self.cur_file.readline()
        tries = 0
        while not line and tries < 100:
            tries += 1
            if hasattr(self.cur_file, "close"):  # legacy smart_open impls
                self.cur_file.close()
            self.cur_file = self._next_file()
            line = self.cur_file.readline()
            if not line:
                logger.debug("Ignoring empty file %s", self.cur_file)
        if not line:
            raise ValueError(
                "Failed to read next line from files: {}".format(self.files)
            )
        return line

    def _next_file(self) -> FileType:
        """Picks the next input file and returns it opened for reading."""
        # If this is the first time, we open a file, make sure all workers
        # start with a different one if possible.
        if self.cur_file is None and self.ioctx.worker is not None:
            idx = self.ioctx.worker.worker_index
            total = self.ioctx.worker.num_workers or 1
            path = self.files[round((len(self.files) - 1) * (idx / total))]
        # After the first file, pick all others randomly.
        else:
            path = random.choice(self.files)
        return self._try_open_file(path)

    def _from_json(self, data: str) -> SampleBatchType:
        """Decodes a json string (or utf-8 bytes) into a batch object."""
        if isinstance(data, bytes):  # smart_open S3 doesn't respect "r"
            data = data.decode("utf-8")
        json_data = json.loads(data)
        return from_json_data(json_data, self.ioctx.worker)
|