offline_evaluation_runner.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. import types
  2. from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union
  3. import ray
  4. from ray.data.iterator import DataIterator
  5. from ray.rllib.core import (
  6. ALL_MODULES,
  7. COMPONENT_RL_MODULE,
  8. )
  9. from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI
  10. from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
  11. from ray.rllib.policy.sample_batch import MultiAgentBatch
  12. from ray.rllib.utils.annotations import override
  13. from ray.rllib.utils.checkpoints import Checkpointable
  14. from ray.rllib.utils.framework import get_device, try_import_torch
  15. from ray.rllib.utils.metrics import (
  16. DATASET_NUM_ITERS_EVALUATED,
  17. DATASET_NUM_ITERS_EVALUATED_LIFETIME,
  18. MODULE_SAMPLE_BATCH_SIZE_MEAN,
  19. NUM_ENV_STEPS_SAMPLED,
  20. NUM_ENV_STEPS_SAMPLED_LIFETIME,
  21. NUM_MODULE_STEPS_SAMPLED,
  22. NUM_MODULE_STEPS_SAMPLED_LIFETIME,
  23. OFFLINE_SAMPLING_TIMER,
  24. WEIGHTS_SEQ_NO,
  25. )
  26. from ray.rllib.utils.minibatch_utils import MiniBatchRayDataIterator
  27. from ray.rllib.utils.numpy import convert_to_numpy
  28. from ray.rllib.utils.runners.runner import Runner
  29. from ray.rllib.utils.torch_utils import convert_to_torch_tensor
  30. from ray.rllib.utils.typing import DeviceType, ModuleID, StateDict, TensorType
  31. if TYPE_CHECKING:
  32. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  33. torch, _ = try_import_torch()
  34. TOTAL_EVAL_LOSS_KEY = "total_eval_loss"
  35. class OfflineEvaluationRunner(Runner, Checkpointable):
  36. def __init__(
  37. self,
  38. config: "AlgorithmConfig",
  39. module_spec: Optional[MultiRLModuleSpec] = None,
  40. **kwargs,
  41. ):
  42. # This needs to be defined before we call the `Runner.__init__`
  43. # b/c the latter calls the `make_module` and then needs the spec.
  44. # TODO (simon): Check, if we make this a generic attribute.
  45. self.__module_spec: MultiRLModuleSpec = module_spec
  46. self.__dataset_iterator = None
  47. self.__batch_iterator = None
  48. Runner.__init__(self, config=config, **kwargs)
  49. Checkpointable.__init__(self)
  50. # This has to be defined after we have a `self.config`.
  51. self._loss_for_module_fn = types.MethodType(self.get_loss_for_module_fn(), self)
  52. @override(Runner)
  53. def run(
  54. self,
  55. explore: bool = False,
  56. train: bool = True,
  57. **kwargs,
  58. ) -> None:
  59. if self.__dataset_iterator is None:
  60. raise ValueError(
  61. f"{self} doesn't have a data iterator. Can't call `run` on "
  62. "`OfflineEvaluationRunner`."
  63. )
  64. if not self._batch_iterator:
  65. self.__batch_iterator = self._create_batch_iterator(
  66. **self.config.iter_batches_kwargs
  67. )
  68. # Log current weight seq no.
  69. self.metrics.log_value(
  70. key=WEIGHTS_SEQ_NO,
  71. value=self._weights_seq_no,
  72. window=1,
  73. )
  74. with self.metrics.log_time(OFFLINE_SAMPLING_TIMER):
  75. if explore is None:
  76. explore = self.config.explore
  77. # Evaluate on offline data.
  78. return self._evaluate(
  79. explore=explore,
  80. train=train,
  81. )
  82. def _create_batch_iterator(self, **kwargs) -> Iterable:
  83. # Return a minibatch iterator.
  84. return MiniBatchRayDataIterator(
  85. iterator=self._dataset_iterator,
  86. device=self._device,
  87. minibatch_size=self.config.offline_eval_batch_size_per_runner,
  88. num_iters=self.config.dataset_num_iters_per_eval_runner,
  89. **kwargs,
  90. )
  91. def _evaluate(
  92. self,
  93. explore: bool,
  94. train: bool,
  95. ) -> None:
  96. for iteration, tensor_minibatch in enumerate(self._batch_iterator):
  97. # Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs
  98. # found in this batch. If not, throw an error.
  99. unknown_module_ids = set(tensor_minibatch.policy_batches.keys()) - set(
  100. self.module.keys()
  101. )
  102. if unknown_module_ids:
  103. raise ValueError(
  104. f"Batch contains one or more ModuleIDs ({unknown_module_ids}) that "
  105. f"are not in this Learner!"
  106. )
  107. if explore:
  108. fwd_out = self.module.forward_exploration(
  109. tensor_minibatch.policy_batches
  110. )
  111. elif train:
  112. fwd_out = self.module.forward_train(tensor_minibatch.policy_batches)
  113. else:
  114. fwd_out = self.module.forward_inference(tensor_minibatch.policy_batches)
  115. eval_loss_per_module = self.compute_eval_losses(
  116. fwd_out=fwd_out, batch=tensor_minibatch.policy_batches
  117. )
  118. self._log_steps_evaluated_metrics(tensor_minibatch)
  119. # Record the number of batches pulled from the dataset.
  120. self.metrics.log_value(
  121. # TODO (simon): Create extra eval metrics.
  122. (ALL_MODULES, DATASET_NUM_ITERS_EVALUATED),
  123. iteration + 1,
  124. reduce="sum",
  125. )
  126. self.metrics.log_value(
  127. (ALL_MODULES, DATASET_NUM_ITERS_EVALUATED_LIFETIME),
  128. iteration + 1,
  129. reduce="lifetime_sum",
  130. )
  131. # Log all individual RLModules' loss terms
  132. # Note: We do this only once for the last of the minibatch updates, b/c the
  133. # window is only 1 anyways.
  134. for mid, loss in convert_to_numpy(eval_loss_per_module).items():
  135. self.metrics.log_value(
  136. key=(mid, TOTAL_EVAL_LOSS_KEY),
  137. value=loss,
  138. window=1,
  139. )
  140. return self.metrics.reduce()
  141. @override(Checkpointable)
  142. def get_ctor_args_and_kwargs(self):
  143. return (
  144. (), # *args
  145. {"config": self.config}, # **kwargs
  146. )
  147. @override(Checkpointable)
  148. def get_state(
  149. self,
  150. components: Optional[Union[str, Collection[str]]] = None,
  151. *,
  152. not_components: Optional[Union[str, Collection[str]]] = None,
  153. **kwargs,
  154. ) -> StateDict:
  155. state = {}
  156. if self._check_component(COMPONENT_RL_MODULE, components, not_components):
  157. state[COMPONENT_RL_MODULE] = self.module.get_state(
  158. components=self._get_subcomponents(COMPONENT_RL_MODULE, components),
  159. not_components=self._get_subcomponents(
  160. COMPONENT_RL_MODULE, not_components
  161. ),
  162. **kwargs,
  163. )
  164. state[WEIGHTS_SEQ_NO] = self._weights_seq_no
  165. return state
  166. def _convert_to_tensor(self, struct) -> TensorType:
  167. """Converts structs to a framework-specific tensor."""
  168. return convert_to_torch_tensor(struct)
  169. @override(Runner)
  170. def stop(self) -> None:
  171. """Releases all resources used by this EnvRunner.
  172. For example, when using a gym.Env in this EnvRunner, you should make sure
  173. that its `close()` method is called.
  174. """
  175. pass
  176. @override(Runner)
  177. def __del__(self) -> None:
  178. """If this Actor is deleted, clears all resources used by it."""
  179. pass
  180. @override(Runner)
  181. def assert_healthy(self):
  182. """Checks that self.__init__() has been completed properly.
  183. Ensures that the instances has a `MultiRLModule` and an
  184. environment defined.
  185. Raises:
  186. AssertionError: If the EnvRunner Actor has NOT been properly initialized.
  187. """
  188. # Make sure, we have built our RLModule properly and assigned a dataset iterator.
  189. assert self._dataset_iterator and hasattr(self, "module")
  190. @override(Runner)
  191. def get_metrics(self):
  192. return self.metrics.reduce()
  193. def _convert_batch_type(
  194. self,
  195. batch: MultiAgentBatch,
  196. to_device: bool = True,
  197. pin_memory: bool = False,
  198. use_stream: bool = False,
  199. ) -> MultiAgentBatch:
  200. batch = convert_to_torch_tensor(
  201. batch.policy_batches,
  202. device=self._device if to_device else None,
  203. pin_memory=pin_memory,
  204. use_stream=use_stream,
  205. )
  206. # TODO (sven): This computation of `env_steps` is not accurate!
  207. length = max(len(b) for b in batch.values())
  208. batch = MultiAgentBatch(batch, env_steps=length)
  209. return batch
  210. def compute_eval_losses(
  211. self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any]
  212. ) -> Dict[str, Any]:
  213. loss_per_module = {}
  214. for module_id in fwd_out:
  215. module_batch = batch[module_id]
  216. module_fwd_out = fwd_out[module_id]
  217. module = self.module[module_id].unwrapped()
  218. if isinstance(module, SelfSupervisedLossAPI):
  219. loss = module.compute_self_supervised_loss(
  220. learner=self,
  221. module_id=module_id,
  222. config=self.config.get_config_for_module(module_id),
  223. batch=module_batch,
  224. fwd_out=module_fwd_out,
  225. )
  226. else:
  227. loss = self.compute_eval_loss_for_module(
  228. module_id=module_id,
  229. config=self.config.get_config_for_module(module_id),
  230. batch=module_batch,
  231. fwd_out=module_fwd_out,
  232. )
  233. loss_per_module[module_id] = loss
  234. return loss_per_module
  235. def compute_eval_loss_for_module(
  236. self,
  237. *,
  238. module_id: ModuleID,
  239. config: "AlgorithmConfig",
  240. batch: Dict[str, Any],
  241. fwd_out: Dict[str, TensorType],
  242. ) -> TensorType:
  243. return self._loss_for_module_fn(
  244. module_id=module_id,
  245. config=config,
  246. batch=batch,
  247. fwd_out=fwd_out,
  248. )
  249. @override(Checkpointable)
  250. def set_state(self, state: StateDict) -> None:
  251. # Update the RLModule state.
  252. if COMPONENT_RL_MODULE in state:
  253. # A missing value for WEIGHTS_SEQ_NO or a value of 0 means: Force the
  254. # update.
  255. weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)
  256. # Only update the weigths, if this is the first synchronization or
  257. # if the weights of this `EnvRunner` lacks behind the actual ones.
  258. if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
  259. rl_module_state = state[COMPONENT_RL_MODULE]
  260. if isinstance(rl_module_state, ray.ObjectRef):
  261. rl_module_state = ray.get(rl_module_state)
  262. self.module.set_state(rl_module_state)
  263. # Update our weights_seq_no, if the new one is > 0.
  264. if weights_seq_no > 0:
  265. self._weights_seq_no = weights_seq_no
  266. def _log_steps_evaluated_metrics(self, batch: MultiAgentBatch) -> None:
  267. for mid, module_batch in batch.policy_batches.items():
  268. # Log weights seq no for this batch.
  269. self.metrics.log_value(
  270. (mid, WEIGHTS_SEQ_NO),
  271. self._weights_seq_no,
  272. window=1,
  273. )
  274. module_batch_size = len(module_batch)
  275. # Log average batch size (for each module).
  276. self.metrics.log_value(
  277. key=(mid, MODULE_SAMPLE_BATCH_SIZE_MEAN),
  278. value=module_batch_size,
  279. )
  280. # Log module steps (for each module).
  281. self.metrics.log_value(
  282. key=(mid, NUM_MODULE_STEPS_SAMPLED),
  283. value=module_batch_size,
  284. reduce="sum",
  285. )
  286. self.metrics.log_value(
  287. key=(mid, NUM_MODULE_STEPS_SAMPLED_LIFETIME),
  288. value=module_batch_size,
  289. reduce="lifetime_sum",
  290. )
  291. # Log module steps (sum of all modules).
  292. self.metrics.log_value(
  293. key=(ALL_MODULES, NUM_MODULE_STEPS_SAMPLED),
  294. value=module_batch_size,
  295. reduce="sum",
  296. )
  297. self.metrics.log_value(
  298. key=(ALL_MODULES, NUM_MODULE_STEPS_SAMPLED_LIFETIME),
  299. value=module_batch_size,
  300. reduce="lifetime_sum",
  301. )
  302. # Log env steps (all modules).
  303. self.metrics.log_value(
  304. (ALL_MODULES, NUM_ENV_STEPS_SAMPLED),
  305. batch.env_steps(),
  306. reduce="sum",
  307. )
  308. self.metrics.log_value(
  309. (ALL_MODULES, NUM_ENV_STEPS_SAMPLED_LIFETIME),
  310. batch.env_steps(),
  311. reduce="lifetime_sum",
  312. with_throughput=True,
  313. )
  314. @override(Runner)
  315. def set_device(self):
  316. try:
  317. self.__device = get_device(
  318. self.config,
  319. (
  320. 0
  321. if not self.worker_index
  322. else self.config.num_gpus_per_offline_eval_runner
  323. ),
  324. )
  325. except NotImplementedError:
  326. self.__device = None
  327. @override(Runner)
  328. def make_module(self):
  329. try:
  330. from ray.rllib.env import INPUT_ENV_SPACES
  331. if not self._module_spec:
  332. self.__module_spec = self.config.get_multi_rl_module_spec(
  333. # Note, usually we have no environemnt in case of offline evaluation.
  334. env=self.config.env,
  335. spaces={
  336. INPUT_ENV_SPACES: (
  337. self.config.observation_space,
  338. self.config.action_space,
  339. )
  340. },
  341. inference_only=self.config.offline_eval_rl_module_inference_only,
  342. )
  343. # Build the module from its spec.
  344. self.module = self._module_spec.build()
  345. # TODO (simon): Implement GPU inference.
  346. # Move the RLModule to our device.
  347. # TODO (sven): In order to make this framework-agnostic, we should maybe
  348. # make the MultiRLModule.build() method accept a device OR create an
  349. # additional `(Multi)RLModule.to()` override.
  350. self.module.foreach_module(
  351. lambda mid, mod: (
  352. mod.to(self._device) if isinstance(mod, torch.nn.Module) else mod
  353. )
  354. )
  355. # If `AlgorithmConfig.get_multi_rl_module_spec()` is not implemented, this env runner
  356. # will not have an RLModule, but might still be usable with random actions.
  357. except NotImplementedError:
  358. self.module = None
  359. def get_loss_for_module_fn(self):
  360. # Either the user has provided a loss-for-module function, or we take
  361. # the loss function from the default `Learner` class.
  362. return (
  363. self.config.offline_loss_for_module_fn
  364. or self.config.get_default_learner_class().__dict__[
  365. "compute_loss_for_module"
  366. ]
  367. )
  368. @property
  369. def _dataset_iterator(self) -> DataIterator:
  370. """Returns the dataset iterator."""
  371. return self.__dataset_iterator
  372. def set_dataset_iterator(self, iterator):
  373. """Sets the dataset iterator."""
  374. self.__dataset_iterator = iterator
  375. @property
  376. def _batch_iterator(self) -> MiniBatchRayDataIterator:
  377. return self.__batch_iterator
  378. @property
  379. def _device(self) -> Union[DeviceType, None]:
  380. return self.__device
  381. @property
  382. def _module_spec(self) -> MultiRLModuleSpec:
  383. """Returns the `MultiRLModuleSpec` of this `Runner`."""
  384. return self.__module_spec