offline_env_runner.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. import logging
  2. from pathlib import Path
  3. from typing import List
  4. import ray
  5. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  6. from ray.rllib.core.columns import Columns
  7. from ray.rllib.env.env_runner import EnvRunner
  8. from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
  9. from ray.rllib.env.single_agent_episode import SingleAgentEpisode
  10. from ray.rllib.utils.annotations import (
  11. OverrideToImplementCustomLogic,
  12. OverrideToImplementCustomLogic_CallToSuperRecommended,
  13. override,
  14. )
  15. from ray.rllib.utils.compression import pack_if_needed
  16. from ray.rllib.utils.typing import EpisodeType
  17. from ray.util.annotations import PublicAPI
  18. from ray.util.debug import log_once
  19. logger = logging.Logger(__file__)
  20. # TODO (simon): This class can be agnostic to the episode type as it
  21. # calls only get_state.
  22. @PublicAPI(stability="alpha")
  23. class OfflineSingleAgentEnvRunner(SingleAgentEnvRunner):
  24. """The environment runner to record the single agent case."""
  25. @override(SingleAgentEnvRunner)
  26. @OverrideToImplementCustomLogic_CallToSuperRecommended
  27. def __init__(self, *, config: AlgorithmConfig, **kwargs):
  28. # Initialize the parent.
  29. super().__init__(config=config, **kwargs)
  30. # override SingleAgentEnvRunner
  31. self.episodes_to_numpy = False
  32. # Get the data context for this `EnvRunner`.
  33. data_context = ray.data.DataContext.get_current()
  34. # Limit the resources for Ray Data to the CPUs given to this `EnvRunner`.
  35. data_context.execution_options.resource_limits = (
  36. data_context.execution_options.resource_limits.copy(
  37. cpu=config.num_cpus_per_env_runner
  38. )
  39. )
  40. # Set the output write method.
  41. self.output_write_method = self.config.output_write_method
  42. self.output_write_method_kwargs = self.config.output_write_method_kwargs
  43. # Set the filesystem.
  44. self.filesystem = self.config.output_filesystem
  45. self.filesystem_kwargs = self.config.output_filesystem_kwargs
  46. self.filesystem_object = None
  47. # Set the output base path.
  48. self.output_path = self.config.output
  49. if self.env:
  50. # Set the subdir (environment specific).
  51. self.subdir_path = self._get_subdir_path()
  52. elif not self.env and (
  53. (self.config.create_env_on_local_worker and self.worker_index == 0)
  54. or self.worker_index > 0
  55. ):
  56. raise ValueError(
  57. "To set up the output path, the environment "
  58. "`env` must be provided when creating the "
  59. "`OfflineSingleAgentEnvRunner`."
  60. )
  61. # Set the worker-specific path name. Note, this is
  62. # specifically to enable multi-threaded writing into
  63. # the same directory.
  64. self.worker_path = "run-" + f"{self.worker_index}".zfill(6)
  65. # If a specific filesystem is given, set it up. Note, this could
  66. # be `gcsfs` for GCS, `pyarrow` for S3 or `adlfs` for Azure Blob Storage.
  67. # this filesystem is specifically needed, if a session has to be created
  68. # with the cloud provider.
  69. if self.filesystem == "gcs":
  70. import gcsfs
  71. self.filesystem_object = gcsfs.GCSFileSystem(**self.filesystem_kwargs)
  72. elif self.filesystem == "s3":
  73. from pyarrow import fs
  74. self.filesystem_object = fs.S3FileSystem(**self.filesystem_kwargs)
  75. elif self.filesystem == "abs":
  76. import adlfs
  77. self.filesystem_object = adlfs.AzureBlobFileSystem(**self.filesystem_kwargs)
  78. elif self.filesystem is not None:
  79. raise ValueError(
  80. f"Unknown filesystem: {self.filesystem}. Filesystems can be "
  81. "'gcs' for GCS, 's3' for S3, or 'abs'"
  82. )
  83. # Add the filesystem object to the write method kwargs.
  84. self.output_write_method_kwargs.update(
  85. {
  86. "filesystem": self.filesystem_object,
  87. }
  88. )
  89. # If we should store `SingleAgentEpisodes` or column data.
  90. self.output_write_episodes = self.config.output_write_episodes
  91. # Which columns should be compressed in the output data.
  92. self.output_compress_columns = self.config.output_compress_columns
  93. # Buffer these many rows before writing to file.
  94. self.output_max_rows_per_file = self.config.output_max_rows_per_file
  95. # If the user defines a maximum number of rows per file, set the
  96. # event to `False` and check during sampling.
  97. if self.output_max_rows_per_file:
  98. self.write_data_this_iter = False
  99. # Otherwise the event is always `True` and we write always sampled
  100. # data immediately to disk.
  101. else:
  102. self.write_data_this_iter = True
  103. # If the remaining data should be stored. Note, this is only
  104. # relevant in case `output_max_rows_per_file` is defined.
  105. self.write_remaining_data = self.config.output_write_remaining_data
  106. # Counts how often `sample` is called to define the output path for
  107. # each file.
  108. self._sample_counter = 0
  109. # Define the buffer for experiences stored until written to disk.
  110. self._samples = []
  111. def _get_subdir_path(self) -> str:
  112. """Returns the subdir path for storing data.
  113. Returns:
  114. The subdir path as a string.
  115. """
  116. # Set the subdir (environment specific).
  117. if isinstance(self.env, str):
  118. # `env` is a string.
  119. return self.env.lower()
  120. else:
  121. # `env` is a class or callable we use its class name.
  122. if self.config.gym_env_vectorize_mode == "sync":
  123. return self.env.unwrapped.envs[0].unwrapped.__class__.__name__.lower()
  124. elif self.config.gym_env_vectorize_mode == "async":
  125. return self.env.unwrapped.get_attr("unwrapped")[
  126. 0
  127. ].__class__.__name__.lower()
  128. elif self.config.gym_env_vectorize_mode == "vector_entry_point":
  129. return self.env.unwrapped.__class__.__name__.lower()
  130. else:
  131. raise ValueError(
  132. f"Unknown `gym_env_vectorize_mode`: "
  133. f"{self.config.gym_env_vectorize_mode}"
  134. )
  135. @override(SingleAgentEnvRunner)
  136. @OverrideToImplementCustomLogic
  137. def sample(
  138. self,
  139. *,
  140. num_timesteps: int = None,
  141. num_episodes: int = None,
  142. explore: bool = None,
  143. random_actions: bool = False,
  144. force_reset: bool = False,
  145. ) -> List[SingleAgentEpisode]:
  146. """Samples from environments and writes data to disk."""
  147. # Call the super sample method.
  148. samples = super().sample(
  149. num_timesteps=num_timesteps,
  150. num_episodes=num_episodes,
  151. explore=explore,
  152. random_actions=random_actions,
  153. force_reset=force_reset,
  154. )
  155. self._sample_counter += 1
  156. # Add data to the buffers.
  157. if self.output_write_episodes:
  158. import msgpack
  159. import msgpack_numpy as mnp
  160. if log_once("msgpack"):
  161. logger.info(
  162. "Packing episodes with `msgpack` and encode array with "
  163. "`msgpack_numpy` for serialization. This is needed for "
  164. "recording episodes."
  165. )
  166. # Note, we serialize episodes with `msgpack` and `msgpack_numpy` to
  167. # ensure version compatibility.
  168. assert all(eps.is_numpy is False for eps in samples)
  169. self._samples.extend(
  170. [msgpack.packb(eps.get_state(), default=mnp.encode) for eps in samples]
  171. )
  172. else:
  173. self._map_episodes_to_data(samples)
  174. # If the user defined the maximum number of rows to write.
  175. if self.output_max_rows_per_file:
  176. # Check, if this number is reached.
  177. if len(self._samples) >= self.output_max_rows_per_file:
  178. # Start the recording of data.
  179. self.write_data_this_iter = True
  180. if self.write_data_this_iter:
  181. # If the user wants a maximum number of experiences per file,
  182. # cut the samples to write to disk from the buffer.
  183. if self.output_max_rows_per_file:
  184. # Reset the event.
  185. self.write_data_this_iter = False
  186. # Ensure that all data ready to be written is released from
  187. # the buffer. Note, this is important in case we have many
  188. # episodes sampled and a relatively small `output_max_rows_per_file`.
  189. while len(self._samples) >= self.output_max_rows_per_file:
  190. # Extract the number of samples to be written to disk this
  191. # iteration.
  192. samples_to_write = self._samples[: self.output_max_rows_per_file]
  193. # Reset the buffer to the remaining data. This only makes sense, if
  194. # `rollout_fragment_length` is smaller `output_max_rows_per_file` or
  195. # a 2 x `output_max_rows_per_file`.
  196. self._samples = self._samples[self.output_max_rows_per_file :]
  197. samples_ds = ray.data.from_items(samples_to_write)
  198. # Otherwise, write the complete data.
  199. else:
  200. samples_ds = ray.data.from_items(self._samples)
  201. try:
  202. # Setup the path for writing data. Each run will be written to
  203. # its own file. A run is a writing event. The path will look
  204. # like. 'base_path/env-name/00000<WorkerID>-00000<RunID>'.
  205. path = (
  206. Path(self.output_path)
  207. .joinpath(self.subdir_path)
  208. .joinpath(self.worker_path + f"-{self._sample_counter}".zfill(6))
  209. )
  210. getattr(samples_ds, self.output_write_method)(
  211. path.as_posix(), **self.output_write_method_kwargs
  212. )
  213. logger.info(f"Wrote samples to storage at {path}.")
  214. except Exception as e:
  215. logger.error(e)
  216. self.metrics.log_value(
  217. key="recording_buffer_size",
  218. value=len(self._samples),
  219. )
  220. # Finally return the samples as usual.
  221. return samples
  222. @override(EnvRunner)
  223. @OverrideToImplementCustomLogic
  224. def stop(self) -> None:
  225. """Writes the reamining samples to disk
  226. Note, if the user defined `max_rows_per_file` the
  227. number of rows for the remaining samples could be
  228. less than the defined maximum row number by the user.
  229. """
  230. # If there are samples left over we have to write htem to disk. them
  231. # to a dataset.
  232. if self._samples and self.write_remaining_data:
  233. # Convert them to a `ray.data.Dataset`.
  234. samples_ds = ray.data.from_items(self._samples)
  235. # Increase the sample counter for the folder/file name.
  236. self._sample_counter += 1
  237. # Try to write the dataset to disk/cloud storage.
  238. try:
  239. # Setup the path for writing data. Each run will be written to
  240. # its own file. A run is a writing event. The path will look
  241. # like. 'base_path/env-name/00000<WorkerID>-00000<RunID>'.
  242. path = (
  243. Path(self.output_path)
  244. .joinpath(self.subdir_path)
  245. .joinpath(self.worker_path + f"-{self._sample_counter}".zfill(6))
  246. )
  247. getattr(samples_ds, self.output_write_method)(
  248. path.as_posix(), **self.output_write_method_kwargs
  249. )
  250. logger.info(
  251. f"Wrote final samples to storage at {path}. Note "
  252. "Note, final samples could be smaller in size than "
  253. f"`max_rows_per_file`, if defined."
  254. )
  255. except Exception as e:
  256. logger.error(e)
  257. logger.debug(f"Experience buffer length: {len(self._samples)}")
  258. @OverrideToImplementCustomLogic
  259. def _map_episodes_to_data(self, samples: List[EpisodeType]) -> None:
  260. """Converts list of episodes to list of single dict experiences.
  261. Note, this method also appends all sampled experiences to the
  262. buffer.
  263. Args:
  264. samples: List of episodes to be converted.
  265. """
  266. # Loop through all sampled episodes.
  267. for sample in samples:
  268. # Loop through all items of the episode.
  269. for i in range(len(sample)):
  270. sample_data = {
  271. Columns.EPS_ID: sample.id_,
  272. Columns.AGENT_ID: sample.agent_id,
  273. Columns.MODULE_ID: sample.module_id,
  274. # Compress observations, if requested.
  275. Columns.OBS: pack_if_needed(sample.get_observations(i))
  276. if Columns.OBS in self.output_compress_columns
  277. else sample.get_observations(i),
  278. # Compress actions, if requested.
  279. Columns.ACTIONS: pack_if_needed(sample.get_actions(i))
  280. if Columns.ACTIONS in self.output_compress_columns
  281. else sample.get_actions(i),
  282. Columns.REWARDS: sample.get_rewards(i),
  283. # Compress next observations, if requested.
  284. Columns.NEXT_OBS: pack_if_needed(sample.get_observations(i + 1))
  285. if Columns.OBS in self.output_compress_columns
  286. else sample.get_observations(i + 1),
  287. Columns.TERMINATEDS: False
  288. if i < len(sample) - 1
  289. else sample.is_terminated,
  290. Columns.TRUNCATEDS: False
  291. if i < len(sample) - 1
  292. else sample.is_truncated,
  293. **{
  294. # Compress any extra model output, if requested.
  295. k: pack_if_needed(sample.get_extra_model_outputs(k, i))
  296. if k in self.output_compress_columns
  297. else sample.get_extra_model_outputs(k, i)
  298. for k in sample.extra_model_outputs.keys()
  299. },
  300. }
  301. # Finally append to the data buffer.
  302. self._samples.append(sample_data)