offline_data.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. import logging
  2. import time
  3. import types
  4. from pathlib import Path
  5. from typing import TYPE_CHECKING, Any, Dict
  6. import numpy as np
  7. import pyarrow.fs
  8. import ray
  9. from ray.rllib.core import COMPONENT_RL_MODULE
  10. from ray.rllib.env import INPUT_ENV_SPACES
  11. from ray.rllib.offline.offline_prelearner import OfflinePreLearner
  12. from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
  13. from ray.rllib.utils import force_list, unflatten_dict
  14. from ray.rllib.utils.annotations import (
  15. OverrideToImplementCustomLogic,
  16. OverrideToImplementCustomLogic_CallToSuperRecommended,
  17. )
  18. from ray.util.annotations import PublicAPI
  19. if TYPE_CHECKING:
  20. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  21. logger = logging.getLogger(__name__)
  22. @PublicAPI(stability="alpha")
  23. class OfflineData:
  24. @OverrideToImplementCustomLogic_CallToSuperRecommended
  25. def __init__(self, config: "AlgorithmConfig"):
  26. # TODO (simon): Define self.spaces here.
  27. self.config = config
  28. self.is_multi_agent = self.config.is_multi_agent
  29. self.path = (
  30. self.config.input_
  31. if isinstance(config.input_, list)
  32. else Path(config.input_)
  33. )
  34. # Use `read_parquet` as default data read method.
  35. self.data_read_method = self.config.input_read_method
  36. # Override default arguments for the data read method.
  37. self.data_read_method_kwargs = self.config.input_read_method_kwargs
  38. # In case `EpisodeType` or `BatchType` batches are read the size
  39. # could differ from the final `train_batch_size_per_learner`.
  40. self.data_read_batch_size = self.config.input_read_batch_size
  41. # If data should be materialized.
  42. self.materialize_data = config.materialize_data
  43. # If mapped data should be materialized.
  44. self.materialize_mapped_data = config.materialize_mapped_data
  45. # Flag to identify, if data has already been mapped with the
  46. # `OfflinePreLearner`.
  47. self.data_is_mapped = False
  48. # Set the filesystem.
  49. self.filesystem = self.config.input_filesystem
  50. self.filesystem_kwargs = self.config.input_filesystem_kwargs
  51. self.filesystem_object = None
  52. # If a specific filesystem is given, set it up. Note, this could
  53. # be `gcsfs` for GCS, `pyarrow` for S3 or `adlfs` for Azure Blob Storage.
  54. # this filesystem is specifically needed, if a session has to be created
  55. # with the cloud provider.
  56. if self.filesystem == "gcs":
  57. import gcsfs
  58. self.filesystem_object = gcsfs.GCSFileSystem(**self.filesystem_kwargs)
  59. elif self.filesystem == "s3":
  60. self.filesystem_object = pyarrow.fs.S3FileSystem(**self.filesystem_kwargs)
  61. elif self.filesystem == "abs":
  62. import adlfs
  63. self.filesystem_object = adlfs.AzureBlobFileSystem(**self.filesystem_kwargs)
  64. elif isinstance(self.filesystem, pyarrow.fs.FileSystem):
  65. self.filesystem_object = self.filesystem
  66. elif self.filesystem is not None:
  67. raise ValueError(
  68. f"Unknown `config.input_filesystem` {self.filesystem}! Filesystems "
  69. "can be None for local, any instance of `pyarrow.fs.FileSystem`, "
  70. "'gcs' for GCS, 's3' for S3, or 'abs' for adlfs.AzureBlobFileSystem."
  71. )
  72. # Add the filesystem object to the write method kwargs.
  73. if self.filesystem_object:
  74. self.data_read_method_kwargs.update(
  75. {
  76. "filesystem": self.filesystem_object,
  77. }
  78. )
  79. # Load the dataset.
  80. start_time = time.perf_counter()
  81. self.data = getattr(ray.data, self.data_read_method)(
  82. self.path, **self.data_read_method_kwargs
  83. )
  84. if self.materialize_data:
  85. self.data = self.data.materialize()
  86. stop_time = time.perf_counter()
  87. logger.debug(
  88. f"Time to load offline data from {self.path}: {stop_time - start_time:.2f}s."
  89. )
  90. # Avoids reinstantiating the batch iterator each time we sample.
  91. self.batch_iterators = None
  92. self.map_batches_kwargs = (
  93. self.default_map_batches_kwargs | self.config.map_batches_kwargs
  94. )
  95. self.iter_batches_kwargs = (
  96. self.default_iter_batches_kwargs | self.config.iter_batches_kwargs
  97. )
  98. self.returned_streaming_split = False
  99. # Defines the prelearner class. Note, this could be user-defined.
  100. self.prelearner_class = self.config.prelearner_class or OfflinePreLearner
  101. # For remote learner setups.
  102. self.locality_hints = None
  103. self.learner_handles = None
  104. self.module_spec = None
  105. @OverrideToImplementCustomLogic
  106. def sample(
  107. self,
  108. num_samples: int,
  109. return_iterator: bool = False,
  110. num_shards: int = 1,
  111. module_state: Dict[str, Any] = None,
  112. ):
  113. # Materialize the mapped data, if necessary. This runs for all the
  114. # data the `OfflinePreLearner` logic and maps them to `MultiAgentBatch`es.
  115. # TODO (simon, sven): This would never update the module nor the
  116. # the connectors. If this is needed we have to check, if we give
  117. # (a) only an iterator and let the learner and OfflinePreLearner
  118. # communicate through the object storage. This only works when
  119. # not materializing.
  120. # (b) Rematerialize the data every couple of iterations. This is
  121. # is costly.
  122. if not self.data_is_mapped:
  123. if not module_state:
  124. # Get the RLModule state from learners.
  125. if num_shards >= 1:
  126. # Call here the learner to get an up-to-date module state.
  127. # TODO (simon): This is a workaround as along as learners cannot
  128. # receive any calls from another actor.
  129. module_state = ray.get(
  130. self.learner_handles[0].get_state.remote(
  131. component=COMPONENT_RL_MODULE,
  132. )
  133. )[COMPONENT_RL_MODULE]
  134. # Provide the `Learner`(s) GPU devices, if needed.
  135. # if not self.map_batches_uses_gpus(self.config) and self.config._validate_config:
  136. # devices = ray.get(self.learner_handles[0].get_device.remote())
  137. # devices = [devices] if not isinstance(devices, list) else devices
  138. # device_strings = [
  139. # f"{device.type}:{str(device.index)}"
  140. # if device.type == "cuda"
  141. # else device.type
  142. # for device in devices
  143. # ]
  144. # # Otherwise, set the GPU strings to `None`.
  145. # # TODO (simon): Check inside 'OfflinePreLearner'.
  146. # else:
  147. # device_strings = None
  148. else:
  149. # Get the module state from the `Learner`(S).
  150. module_state = self.learner_handles[0].get_state(
  151. component=COMPONENT_RL_MODULE,
  152. )[COMPONENT_RL_MODULE]
  153. # Provide the `Learner`(s) GPU devices, if needed.
  154. # if not self.map_batches_uses_gpus(self.config) and self.config._validate_config:
  155. # device = self.learner_handles[0].get_device()
  156. # device_strings = [
  157. # f"{device.type}:{str(device.index)}"
  158. # if device.type == "cuda"
  159. # else device.type
  160. # ]
  161. # else:
  162. # device_strings = None
  163. # Constructor `kwargs` for the `OfflinePreLearner`.
  164. fn_constructor_kwargs = {
  165. "config": self.config,
  166. "spaces": self.spaces[INPUT_ENV_SPACES],
  167. "module_spec": self.module_spec,
  168. "module_state": module_state,
  169. # "device_strings": self.get_devices(),
  170. }
  171. # Map the data to run the `OfflinePreLearner`s in the data pipeline
  172. # for training.
  173. self.data = self.data.map_batches(
  174. self.prelearner_class,
  175. fn_constructor_kwargs=fn_constructor_kwargs,
  176. batch_size=self.data_read_batch_size or num_samples,
  177. **self.map_batches_kwargs,
  178. )
  179. # Set the flag to `True`.
  180. self.data_is_mapped = True
  181. # If the user wants to materialize the data in memory.
  182. if self.materialize_mapped_data:
  183. self.data = self.data.materialize()
  184. # Build an iterator, if necessary. Note, in case that an iterator should be
  185. # returned now and we have already generated from the iterator, i.e.
  186. # `isinstance(self.batch_iterators, types.GeneratorType) == True`, we need
  187. # to create here a new iterator.
  188. if not self.batch_iterators or (
  189. return_iterator and isinstance(self.batch_iterators, types.GeneratorType)
  190. ):
  191. # If we have more than one learner create an iterator for each of them
  192. # by splitting the data stream.
  193. if num_shards > 1:
  194. # In case of multiple shards, we return multiple
  195. # `StreamingSplitIterator` instances.
  196. self.batch_iterators = self.data.streaming_split(
  197. n=num_shards,
  198. # Note, `equal` must be `True`, i.e. the batch size must
  199. # be the same for all batches b/c otherwise remote learners
  200. # could block each others.
  201. equal=True,
  202. locality_hints=self.locality_hints,
  203. )
  204. # Otherwise we create a simple iterator and - if necessary - initialize
  205. # it here.
  206. else:
  207. # Should an iterator be returned?
  208. if return_iterator:
  209. self.batch_iterators = self.data.iterator()
  210. # Otherwise, the user wants batches returned.
  211. else:
  212. # Define a collate (last-mile) transformation that maps batches
  213. # to RLlib's `MultiAgentBatch`.
  214. def _collate_fn(_batch: Dict[str, np.ndarray]) -> MultiAgentBatch:
  215. _batch = unflatten_dict(_batch)
  216. return MultiAgentBatch(
  217. {
  218. module_id: SampleBatch(module_data)
  219. for module_id, module_data in _batch.items()
  220. },
  221. env_steps=sum(
  222. len(next(iter(module_data.values())))
  223. for module_data in _batch.values()
  224. ),
  225. )
  226. # If no iterator should be returned, or if we want to return a single
  227. # batch iterator, we instantiate the batch iterator once, here.
  228. self.batch_iterators = self.data.iter_batches(
  229. batch_size=num_samples,
  230. _collate_fn=_collate_fn,
  231. **self.iter_batches_kwargs,
  232. )
  233. self.batch_iterators = iter(self.batch_iterators)
  234. # Do we want to return an iterator or a single batch?
  235. if return_iterator:
  236. return force_list(self.batch_iterators)
  237. else:
  238. # Return a single batch from the iterator.
  239. try:
  240. return next(self.batch_iterators)
  241. except StopIteration:
  242. # If the batch iterator is exhausted, reinitiate a new one.
  243. logger.debug("Batch iterator exhausted. Reinitiating ...")
  244. self.batch_iterators = None
  245. return self.sample(
  246. num_samples=num_samples,
  247. return_iterator=return_iterator,
  248. num_shards=num_shards,
  249. )
  250. @property
  251. def default_map_batches_kwargs(self):
  252. return {
  253. "concurrency": max(2, self.config.num_learners),
  254. "zero_copy_batch": True,
  255. }
  256. @property
  257. def default_iter_batches_kwargs(self):
  258. return {
  259. "prefetch_batches": 2,
  260. }