json_reader.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. import glob
  2. import json
  3. import logging
  4. import math
  5. import os
  6. import random
  7. import re
  8. import zipfile
  9. from pathlib import Path
  10. from typing import TYPE_CHECKING, List, Optional, Union
  11. from urllib.parse import urlparse
  12. import numpy as np
  13. import tree # pip install dm_tree
  14. try:
  15. from smart_open import smart_open
  16. except ImportError:
  17. smart_open = None
  18. from ray.rllib.offline.input_reader import InputReader
  19. from ray.rllib.offline.io_context import IOContext
  20. from ray.rllib.policy.policy import Policy
  21. from ray.rllib.policy.sample_batch import (
  22. DEFAULT_POLICY_ID,
  23. MultiAgentBatch,
  24. SampleBatch,
  25. concat_samples,
  26. convert_ma_batch_to_sample_batch,
  27. )
  28. from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI, override
  29. from ray.rllib.utils.compression import unpack_if_needed
  30. from ray.rllib.utils.spaces.space_utils import clip_action, normalize_action
  31. from ray.rllib.utils.typing import Any, FileType, SampleBatchType
  32. if TYPE_CHECKING:
  33. from ray.rllib.evaluation import RolloutWorker
  34. logger = logging.getLogger(__name__)
  35. WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)]
  36. def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
  37. """Handle nested action/observation spaces for policies.
  38. Translates nested lists/dicts from the json into proper
  39. np.ndarrays, according to the (nested) observation- and action-
  40. spaces of the given policy.
  41. Providing nested lists w/o this preprocessing step would
  42. confuse a SampleBatch constructor.
  43. """
  44. for k, v in json_data.items():
  45. data_col = (
  46. policy.view_requirements[k].data_col
  47. if k in policy.view_requirements
  48. else ""
  49. )
  50. # No action flattening -> Process nested (leaf) action(s).
  51. if policy.config.get("_disable_action_flattening") and (
  52. k == SampleBatch.ACTIONS
  53. or data_col == SampleBatch.ACTIONS
  54. or k == SampleBatch.PREV_ACTIONS
  55. or data_col == SampleBatch.PREV_ACTIONS
  56. ):
  57. json_data[k] = tree.map_structure_up_to(
  58. policy.action_space_struct,
  59. lambda comp: np.array(comp),
  60. json_data[k],
  61. check_types=False,
  62. )
  63. # No preprocessing -> Process nested (leaf) observation(s).
  64. elif policy.config.get("_disable_preprocessor_api") and (
  65. k == SampleBatch.OBS
  66. or data_col == SampleBatch.OBS
  67. or k == SampleBatch.NEXT_OBS
  68. or data_col == SampleBatch.NEXT_OBS
  69. ):
  70. json_data[k] = tree.map_structure_up_to(
  71. policy.observation_space_struct,
  72. lambda comp: np.array(comp),
  73. json_data[k],
  74. check_types=False,
  75. )
  76. return json_data
  77. @DeveloperAPI
  78. def _adjust_dones(json_data: dict) -> dict:
  79. """Make sure DONES in json data is properly translated into TERMINATEDS."""
  80. new_json_data = {}
  81. for k, v in json_data.items():
  82. # Translate DONES into TERMINATEDS.
  83. if k == SampleBatch.DONES:
  84. new_json_data[SampleBatch.TERMINATEDS] = v
  85. # Leave everything else as-is.
  86. else:
  87. new_json_data[k] = v
  88. return new_json_data
  89. @DeveloperAPI
  90. def postprocess_actions(batch: SampleBatchType, ioctx: IOContext) -> SampleBatchType:
  91. # Clip actions (from any values into env's bounds), if necessary.
  92. cfg = ioctx.config
  93. # TODO(jungong): We should not clip_action in input reader.
  94. # Use connector to handle this.
  95. if cfg.get("clip_actions"):
  96. if ioctx.worker is None:
  97. raise ValueError(
  98. "clip_actions is True but cannot clip actions since no workers exist"
  99. )
  100. if isinstance(batch, SampleBatch):
  101. policy = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
  102. if policy is None:
  103. assert len(ioctx.worker.policy_map) == 1
  104. policy = next(iter(ioctx.worker.policy_map.values()))
  105. batch[SampleBatch.ACTIONS] = clip_action(
  106. batch[SampleBatch.ACTIONS], policy.action_space_struct
  107. )
  108. else:
  109. for pid, b in batch.policy_batches.items():
  110. b[SampleBatch.ACTIONS] = clip_action(
  111. b[SampleBatch.ACTIONS],
  112. ioctx.worker.policy_map[pid].action_space_struct,
  113. )
  114. # Re-normalize actions (from env's bounds to zero-centered), if
  115. # necessary.
  116. if (
  117. cfg.get("actions_in_input_normalized") is False
  118. and cfg.get("normalize_actions") is True
  119. ):
  120. if ioctx.worker is None:
  121. raise ValueError(
  122. "actions_in_input_normalized is False but"
  123. "cannot normalize actions since no workers exist"
  124. )
  125. # If we have a complex action space and actions were flattened
  126. # and we have to normalize -> Error.
  127. error_msg = (
  128. "Normalization of offline actions that are flattened is not "
  129. "supported! Make sure that you record actions into offline "
  130. "file with the `_disable_action_flattening=True` flag OR "
  131. "as already normalized (between -1.0 and 1.0) values. "
  132. "Also, when reading already normalized action values from "
  133. "offline files, make sure to set "
  134. "`actions_in_input_normalized=True` so that RLlib will not "
  135. "perform normalization on top."
  136. )
  137. if isinstance(batch, SampleBatch):
  138. policy = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
  139. if policy is None:
  140. assert len(ioctx.worker.policy_map) == 1
  141. policy = next(iter(ioctx.worker.policy_map.values()))
  142. if isinstance(
  143. policy.action_space_struct, (tuple, dict)
  144. ) and not policy.config.get("_disable_action_flattening"):
  145. raise ValueError(error_msg)
  146. batch[SampleBatch.ACTIONS] = normalize_action(
  147. batch[SampleBatch.ACTIONS], policy.action_space_struct
  148. )
  149. else:
  150. for pid, b in batch.policy_batches.items():
  151. policy = ioctx.worker.policy_map[pid]
  152. if isinstance(
  153. policy.action_space_struct, (tuple, dict)
  154. ) and not policy.config.get("_disable_action_flattening"):
  155. raise ValueError(error_msg)
  156. b[SampleBatch.ACTIONS] = normalize_action(
  157. b[SampleBatch.ACTIONS],
  158. ioctx.worker.policy_map[pid].action_space_struct,
  159. )
  160. return batch
  161. @DeveloperAPI
  162. def from_json_data(json_data: Any, worker: Optional["RolloutWorker"]):
  163. # Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch).
  164. if "type" in json_data:
  165. data_type = json_data.pop("type")
  166. else:
  167. raise ValueError("JSON record missing 'type' field")
  168. if data_type == "SampleBatch":
  169. if worker is not None and len(worker.policy_map) != 1:
  170. raise ValueError(
  171. "Found single-agent SampleBatch in input file, but our "
  172. "PolicyMap contains more than 1 policy!"
  173. )
  174. for k, v in json_data.items():
  175. json_data[k] = unpack_if_needed(v)
  176. if worker is not None:
  177. policy = next(iter(worker.policy_map.values()))
  178. json_data = _adjust_obs_actions_for_policy(json_data, policy)
  179. json_data = _adjust_dones(json_data)
  180. return SampleBatch(json_data)
  181. elif data_type == "MultiAgentBatch":
  182. policy_batches = {}
  183. for policy_id, policy_batch in json_data["policy_batches"].items():
  184. inner = {}
  185. for k, v in policy_batch.items():
  186. # Translate DONES into TERMINATEDS.
  187. if k == SampleBatch.DONES:
  188. k = SampleBatch.TERMINATEDS
  189. inner[k] = unpack_if_needed(v)
  190. if worker is not None:
  191. policy = worker.policy_map[policy_id]
  192. inner = _adjust_obs_actions_for_policy(inner, policy)
  193. inner = _adjust_dones(inner)
  194. policy_batches[policy_id] = SampleBatch(inner)
  195. return MultiAgentBatch(policy_batches, json_data["count"])
  196. else:
  197. raise ValueError(
  198. "Type field must be one of ['SampleBatch', 'MultiAgentBatch']", data_type
  199. )
  200. # TODO(jungong) : use DatasetReader to back JsonReader, so we reduce
  201. # codebase complexity without losing existing functionality.
  202. @PublicAPI
  203. class JsonReader(InputReader):
  204. """Reader object that loads experiences from JSON file chunks.
  205. The input files will be read from in random order.
  206. """
  207. @PublicAPI
  208. def __init__(
  209. self, inputs: Union[str, List[str]], ioctx: Optional[IOContext] = None
  210. ):
  211. """Initializes a JsonReader instance.
  212. Args:
  213. inputs: Either a glob expression for files, e.g. `/tmp/**/*.json`,
  214. or a list of single file paths or URIs, e.g.,
  215. ["s3://bucket/file.json", "s3://bucket/file2.json"].
  216. ioctx: Current IO context object or None.
  217. """
  218. logger.info(
  219. "You are using JSONReader. It is recommended to use "
  220. + "DatasetReader instead for better sharding support."
  221. )
  222. self.ioctx = ioctx or IOContext()
  223. self.default_policy = self.policy_map = None
  224. self.batch_size = 1
  225. if self.ioctx:
  226. self.batch_size = self.ioctx.config.get("train_batch_size", 1)
  227. num_workers = self.ioctx.config.get("num_env_runners", 0)
  228. if num_workers:
  229. self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
  230. if self.ioctx.worker is not None:
  231. self.policy_map = self.ioctx.worker.policy_map
  232. self.default_policy = self.policy_map.get(DEFAULT_POLICY_ID)
  233. if self.default_policy is None:
  234. assert len(self.policy_map) == 1
  235. self.default_policy = next(iter(self.policy_map.values()))
  236. if isinstance(inputs, str):
  237. inputs = os.path.abspath(os.path.expanduser(inputs))
  238. if os.path.isdir(inputs):
  239. inputs = [os.path.join(inputs, "*.json"), os.path.join(inputs, "*.zip")]
  240. logger.warning(f"Treating input directory as glob patterns: {inputs}")
  241. else:
  242. inputs = [inputs]
  243. if any(urlparse(i).scheme not in [""] + WINDOWS_DRIVES for i in inputs):
  244. raise ValueError(
  245. "Don't know how to glob over `{}`, ".format(inputs)
  246. + "please specify a list of files to read instead."
  247. )
  248. else:
  249. self.files = []
  250. for i in inputs:
  251. self.files.extend(glob.glob(i))
  252. elif isinstance(inputs, (list, tuple)):
  253. self.files = list(inputs)
  254. else:
  255. raise ValueError(
  256. "type of inputs must be list or str, not {}".format(inputs)
  257. )
  258. if self.files:
  259. logger.info("Found {} input files.".format(len(self.files)))
  260. else:
  261. raise ValueError("No files found matching {}".format(inputs))
  262. self.cur_file = None
  263. @override(InputReader)
  264. def next(self) -> SampleBatchType:
  265. ret = []
  266. count = 0
  267. while count < self.batch_size:
  268. batch = self._try_parse(self._next_line())
  269. tries = 0
  270. while not batch and tries < 100:
  271. tries += 1
  272. logger.debug("Skipping empty line in {}".format(self.cur_file))
  273. batch = self._try_parse(self._next_line())
  274. if not batch:
  275. raise ValueError(
  276. "Failed to read valid experience batch from file: {}".format(
  277. self.cur_file
  278. )
  279. )
  280. batch = self._postprocess_if_needed(batch)
  281. count += batch.count
  282. ret.append(batch)
  283. ret = concat_samples(ret)
  284. return ret
  285. def read_all_files(self) -> SampleBatchType:
  286. """Reads through all files and yields one SampleBatchType per line.
  287. When reaching the end of the last file, will start from the beginning
  288. again.
  289. Yields:
  290. One SampleBatch or MultiAgentBatch per line in all input files.
  291. """
  292. for path in self.files:
  293. file = self._try_open_file(path)
  294. while True:
  295. line = file.readline()
  296. if not line:
  297. break
  298. batch = self._try_parse(line)
  299. if batch is None:
  300. break
  301. yield batch
  302. def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
  303. if not self.ioctx.config.get("postprocess_inputs"):
  304. return batch
  305. batch = convert_ma_batch_to_sample_batch(batch)
  306. if isinstance(batch, SampleBatch):
  307. out = []
  308. for sub_batch in batch.split_by_episode():
  309. out.append(self.default_policy.postprocess_trajectory(sub_batch))
  310. return concat_samples(out)
  311. else:
  312. # TODO(ekl) this is trickier since the alignments between agent
  313. # trajectories in the episode are not available any more.
  314. raise NotImplementedError(
  315. "Postprocessing of multi-agent data not implemented yet."
  316. )
  317. def _try_open_file(self, path):
  318. if urlparse(path).scheme not in [""] + WINDOWS_DRIVES:
  319. if smart_open is None:
  320. raise ValueError(
  321. "You must install the `smart_open` module to read "
  322. "from URIs like {}".format(path)
  323. )
  324. ctx = smart_open
  325. else:
  326. # Allow shortcut for home directory ("~/" -> env[HOME]).
  327. if path.startswith("~/"):
  328. path = os.path.join(os.environ.get("HOME", ""), path[2:])
  329. # If path doesn't exist, try to interpret is as relative to the
  330. # rllib directory (located ../../ from this very module).
  331. path_orig = path
  332. if not os.path.exists(path):
  333. path = os.path.join(Path(__file__).parent.parent, path)
  334. if not os.path.exists(path):
  335. raise FileNotFoundError(f"Offline file {path_orig} not found!")
  336. # Unzip files, if necessary and re-point to extracted json file.
  337. if re.search("\\.zip$", path):
  338. with zipfile.ZipFile(path, "r") as zip_ref:
  339. zip_ref.extractall(Path(path).parent)
  340. path = re.sub("\\.zip$", ".json", path)
  341. assert os.path.exists(path)
  342. ctx = open
  343. file = ctx(path, "r")
  344. return file
  345. def _try_parse(self, line: str) -> Optional[SampleBatchType]:
  346. line = line.strip()
  347. if not line:
  348. return None
  349. try:
  350. batch = self._from_json(line)
  351. except Exception:
  352. logger.exception(
  353. "Ignoring corrupt json record in {}: {}".format(self.cur_file, line)
  354. )
  355. return None
  356. batch = postprocess_actions(batch, self.ioctx)
  357. return batch
  358. def _next_line(self) -> str:
  359. if not self.cur_file:
  360. self.cur_file = self._next_file()
  361. line = self.cur_file.readline()
  362. tries = 0
  363. while not line and tries < 100:
  364. tries += 1
  365. if hasattr(self.cur_file, "close"): # legacy smart_open impls
  366. self.cur_file.close()
  367. self.cur_file = self._next_file()
  368. line = self.cur_file.readline()
  369. if not line:
  370. logger.debug("Ignoring empty file {}".format(self.cur_file))
  371. if not line:
  372. raise ValueError(
  373. "Failed to read next line from files: {}".format(self.files)
  374. )
  375. return line
  376. def _next_file(self) -> FileType:
  377. # If this is the first time, we open a file, make sure all workers
  378. # start with a different one if possible.
  379. if self.cur_file is None and self.ioctx.worker is not None:
  380. idx = self.ioctx.worker.worker_index
  381. total = self.ioctx.worker.num_workers or 1
  382. path = self.files[round((len(self.files) - 1) * (idx / total))]
  383. # After the first file, pick all others randomly.
  384. else:
  385. path = random.choice(self.files)
  386. return self._try_open_file(path)
  387. def _from_json(self, data: str) -> SampleBatchType:
  388. if isinstance(data, bytes): # smart_open S3 doesn't respect "r"
  389. data = data.decode("utf-8")
  390. json_data = json.loads(data)
  391. return from_json_data(json_data, self.ioctx.worker)