dataset_reader.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. import logging
  2. import math
  3. import re
  4. import zipfile
  5. from pathlib import Path
  6. from typing import TYPE_CHECKING, List, Optional, Tuple
  7. import numpy as np
  8. import ray.data
  9. from ray.rllib.offline.input_reader import InputReader
  10. from ray.rllib.offline.io_context import IOContext
  11. from ray.rllib.offline.json_reader import from_json_data, postprocess_actions
  12. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, concat_samples
  13. from ray.rllib.utils.annotations import PublicAPI, override
  14. from ray.rllib.utils.typing import SampleBatchType
  15. if TYPE_CHECKING:
  16. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  17. DEFAULT_NUM_CPUS_PER_TASK = 0.5
  18. logger = logging.getLogger(__name__)
  19. def _unzip_this_path(fpath: Path, extract_path: str):
  20. with zipfile.ZipFile(str(fpath), "r") as zip_ref:
  21. zip_ref.extractall(extract_path)
  22. def _unzip_if_needed(paths: List[str], format: str):
  23. """If a path in paths is a zip file, unzip it and use path of the unzipped file"""
  24. ret_paths = []
  25. for path in paths:
  26. if re.search("\\.zip$", str(path)):
  27. # TODO: We need to add unzip support for s3
  28. if str(path).startswith("s3://"):
  29. raise ValueError(
  30. "unzip_if_needed currently does not support remote paths from s3"
  31. )
  32. extract_path = "./"
  33. try:
  34. _unzip_this_path(str(path), extract_path)
  35. except FileNotFoundError:
  36. # intrepreted as a relative path to rllib folder
  37. try:
  38. # TODO: remove this later when we replace all tests with s3 paths
  39. _unzip_this_path(Path(__file__).parent.parent / path, extract_path)
  40. except FileNotFoundError:
  41. raise FileNotFoundError(f"File not found: {path}")
  42. unzipped_path = str(
  43. Path(extract_path).absolute() / f"{Path(path).stem}.{format}"
  44. )
  45. ret_paths.append(unzipped_path)
  46. else:
  47. # TODO: We can get rid of this logic when we replace all tests with s3 paths
  48. if str(path).startswith("s3://"):
  49. ret_paths.append(path)
  50. else:
  51. if not Path(path).exists():
  52. relative_path = str(Path(__file__).parent.parent / path)
  53. if not Path(relative_path).exists():
  54. raise FileNotFoundError(f"File not found: {relative_path}")
  55. path = relative_path
  56. ret_paths.append(path)
  57. return ret_paths
  58. @PublicAPI
  59. def get_dataset_and_shards(
  60. config: "AlgorithmConfig", num_workers: int = 0
  61. ) -> Tuple[ray.data.Dataset, List[ray.data.Dataset]]:
  62. """Returns a dataset and a list of shards.
  63. This function uses algorithm configs to create a dataset and a list of shards.
  64. The following config keys are used to create the dataset:
  65. input: The input type should be "dataset".
  66. input_config: A dict containing the following key and values:
  67. `format`: str, speciifies the format of the input data. This will be the
  68. format that ray dataset supports. See ray.data.Dataset for
  69. supported formats. Only "parquet" or "json" are supported for now.
  70. `paths`: str, a single string or a list of strings. Each string is a path
  71. to a file or a directory holding the dataset. It can be either a local path
  72. or a remote path (e.g. to an s3 bucket).
  73. `loader_fn`: Callable[None, ray.data.Dataset], Instead of
  74. specifying paths and format, you can specify a function to load the dataset.
  75. `parallelism`: int, The number of tasks to use for loading the dataset.
  76. If not specified, it will be set to the number of workers.
  77. `num_cpus_per_read_task`: float, The number of CPUs to use for each read
  78. task. If not specified, it will be set to 0.5.
  79. Args:
  80. config: The config dict for the algorithm.
  81. num_workers: The number of shards to create for remote workers.
  82. Returns:
  83. dataset: The dataset object.
  84. shards: A list of dataset shards. For num_workers > 0 the first returned
  85. shared would be a dummy None shard for local_worker.
  86. """
  87. # check input and input config keys
  88. assert config.input_ == "dataset", (
  89. f"Must specify config.input_ as 'dataset' if"
  90. f" calling `get_dataset_and_shards`. Got {config.input_}"
  91. )
  92. # check input config format
  93. input_config = config.input_config
  94. format = input_config.get("format")
  95. supported_fmts = ["json", "parquet"]
  96. if format is not None and format not in supported_fmts:
  97. raise ValueError(
  98. f"Unsupported format {format}. Supported formats are {supported_fmts}"
  99. )
  100. # check paths and loader_fn since only one of them is required.
  101. paths = input_config.get("paths")
  102. loader_fn = input_config.get("loader_fn")
  103. if loader_fn and (format or paths):
  104. raise ValueError(
  105. "When using a `loader_fn`, you cannot specify a `format` or `path`."
  106. )
  107. # check if at least loader_fn or format + path is specified.
  108. if not (format and paths) and not loader_fn:
  109. raise ValueError(
  110. "Must specify either a `loader_fn` or a `format` and `path` in "
  111. "`input_config`."
  112. )
  113. # check paths to be a str or list[str] if not None
  114. if paths is not None:
  115. if isinstance(paths, str):
  116. paths = [paths]
  117. elif isinstance(paths, list):
  118. assert isinstance(paths[0], str), "Paths must be a list of path strings."
  119. else:
  120. raise ValueError("Paths must be a path string or a list of path strings.")
  121. paths = _unzip_if_needed(paths, format)
  122. # TODO (Kourosh): num_workers is not necessary since we can use parallelism for
  123. # everything. Having two parameters is confusing here. Remove num_workers later.
  124. parallelism = input_config.get("parallelism", num_workers or 1)
  125. cpus_per_task = input_config.get(
  126. "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
  127. )
  128. if loader_fn:
  129. dataset = loader_fn()
  130. elif format == "json":
  131. dataset = ray.data.read_json(
  132. paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
  133. )
  134. elif format == "parquet":
  135. dataset = ray.data.read_parquet(
  136. paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
  137. )
  138. else:
  139. raise ValueError("Un-supported Ray dataset format: ", format)
  140. # Local worker will be responsible for sampling.
  141. if num_workers == 0:
  142. # Dataset is the only shard we need.
  143. return dataset, [dataset]
  144. # Remote workers are responsible for sampling:
  145. else:
  146. # Each remote worker gets 1 shard.
  147. remote_shards = dataset.repartition(
  148. num_blocks=num_workers, shuffle=False
  149. ).split(num_workers)
  150. # The first None shard is for the local worker, which
  151. # shouldn't be doing rollout work anyways.
  152. return dataset, [None] + remote_shards
  153. @PublicAPI
  154. class DatasetReader(InputReader):
  155. """Reader object that loads data from Ray Dataset.
  156. Examples:
  157. config = {
  158. "input": "dataset",
  159. "input_config": {
  160. "format": "json",
  161. # A single data file, a directory, or anything
  162. # that ray.data.dataset recognizes.
  163. "paths": "/tmp/sample_batches/",
  164. # By default, parallelism=num_workers.
  165. "parallelism": 3,
  166. # Dataset allocates 0.5 CPU for each reader by default.
  167. # Adjust this value based on the size of your offline dataset.
  168. "num_cpus_per_read_task": 0.5,
  169. }
  170. }
  171. """
  172. @PublicAPI
  173. def __init__(self, ds: ray.data.Dataset, ioctx: Optional[IOContext] = None):
  174. """Initializes a DatasetReader instance.
  175. Args:
  176. ds: Ray dataset to sample from.
  177. """
  178. self._ioctx = ioctx or IOContext()
  179. self._default_policy = self.policy_map = None
  180. self.preprocessor = None
  181. self._dataset = ds
  182. self.count = None if not self._dataset else self._dataset.count()
  183. # do this to disable the ray data stdout logging
  184. ray.data.DataContext.get_current().enable_progress_bars = False
  185. # the number of steps to return per call to next()
  186. self.batch_size = self._ioctx.config.get("train_batch_size", 1)
  187. num_workers = self._ioctx.config.get("num_env_runners", 0)
  188. seed = self._ioctx.config.get("seed", None)
  189. if num_workers:
  190. self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
  191. # We allow the creation of a non-functioning None DatasetReader.
  192. # It's useful for example for a non-rollout local worker.
  193. if ds:
  194. if self._ioctx.worker is not None:
  195. self._policy_map = self._ioctx.worker.policy_map
  196. self._default_policy = self._policy_map.get(DEFAULT_POLICY_ID)
  197. self.preprocessor = (
  198. self._ioctx.worker.preprocessors.get(DEFAULT_POLICY_ID)
  199. if not self._ioctx.config.get("_disable_preprocessors", False)
  200. else None
  201. )
  202. print(
  203. f"DatasetReader {self._ioctx.worker_index} has {ds.count()}, samples."
  204. )
  205. def iterator():
  206. while True:
  207. ds = self._dataset.random_shuffle(seed=seed)
  208. yield from ds.iter_rows()
  209. self._iter = iterator()
  210. else:
  211. self._iter = None
  212. @override(InputReader)
  213. def next(self) -> SampleBatchType:
  214. # next() should not get called on None DatasetReader.
  215. assert self._iter is not None
  216. ret = []
  217. count = 0
  218. while count < self.batch_size:
  219. d = next(self._iter)
  220. # Columns like obs are compressed when written by DatasetWriter.
  221. d = from_json_data(d, self._ioctx.worker)
  222. count += d.count
  223. d = self._preprocess_if_needed(d)
  224. d = postprocess_actions(d, self._ioctx)
  225. d = self._postprocess_if_needed(d)
  226. ret.append(d)
  227. ret = concat_samples(ret)
  228. return ret
  229. def _preprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
  230. # TODO: @kourosh, preprocessor is only supported for single agent case.
  231. if self.preprocessor:
  232. for key in (SampleBatch.CUR_OBS, SampleBatch.NEXT_OBS):
  233. if key in batch:
  234. batch[key] = np.stack(
  235. [self.preprocessor.transform(s) for s in batch[key]]
  236. )
  237. return batch
  238. def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
  239. if not self._ioctx.config.get("postprocess_inputs"):
  240. return batch
  241. if isinstance(batch, SampleBatch):
  242. out = []
  243. for sub_batch in batch.split_by_episode():
  244. if self._default_policy is not None:
  245. out.append(self._default_policy.postprocess_trajectory(sub_batch))
  246. else:
  247. out.append(sub_batch)
  248. return concat_samples(out)
  249. else:
  250. # TODO(ekl) this is trickier since the alignments between agent
  251. # trajectories in the episode are not available any more.
  252. raise NotImplementedError(
  253. "Postprocessing of multi-agent data not implemented yet."
  254. )