utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. import logging
  2. from typing import Any, Dict, Optional
  3. import numpy as np
  4. from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning
  5. from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples
  6. from ray.rllib.utils.annotations import OldAPIStack
  7. from ray.rllib.utils.from_config import from_config
  8. from ray.rllib.utils.metrics import ALL_MODULES, TD_ERROR_KEY
  9. from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
  10. from ray.rllib.utils.replay_buffers import (
  11. EpisodeReplayBuffer,
  12. MultiAgentPrioritizedReplayBuffer,
  13. MultiAgentReplayBuffer,
  14. PrioritizedEpisodeReplayBuffer,
  15. ReplayBuffer,
  16. )
  17. from ray.rllib.utils.typing import (
  18. AlgorithmConfigDict,
  19. ModuleID,
  20. ResultDict,
  21. SampleBatchType,
  22. TensorType,
  23. )
  24. from ray.util import log_once
  25. from ray.util.annotations import DeveloperAPI
  26. import psutil
  27. logger = logging.getLogger(__name__)
  28. @DeveloperAPI
  29. def update_priorities_in_episode_replay_buffer(
  30. *,
  31. replay_buffer: EpisodeReplayBuffer,
  32. td_errors: Dict[ModuleID, TensorType],
  33. ) -> None:
  34. # Only update priorities, if the buffer supports them.
  35. if isinstance(replay_buffer, PrioritizedEpisodeReplayBuffer):
  36. # The `ResultDict` will be multi-agent.
  37. for module_id, td_error in td_errors.items():
  38. # Skip the `"__all__"` keys.
  39. if module_id in ["__all__", ALL_MODULES]:
  40. continue
  41. # Warn once, if we have no TD-errors to update priorities.
  42. if TD_ERROR_KEY not in td_error or td_error[TD_ERROR_KEY] is None:
  43. if log_once(
  44. "no_td_error_in_train_results_from_module_{}".format(module_id)
  45. ):
  46. logger.warning(
  47. "Trying to update priorities for module with ID "
  48. f"`{module_id}` in prioritized episode replay buffer without "
  49. "providing `td_errors` in train_results. Priority update for "
  50. "this policy is being skipped."
  51. )
  52. continue
  53. # TODO (simon): Implement multi-agent version. Remove, happens in buffer.
  54. # assert len(td_error[TD_ERROR_KEY]) == len(
  55. # replay_buffer._last_sampled_indices
  56. # )
  57. # TODO (simon): Implement for stateful modules.
  58. replay_buffer.update_priorities(td_error[TD_ERROR_KEY], module_id)
  59. @OldAPIStack
  60. def update_priorities_in_replay_buffer(
  61. replay_buffer: ReplayBuffer,
  62. config: AlgorithmConfigDict,
  63. train_batch: SampleBatchType,
  64. train_results: ResultDict,
  65. ) -> None:
  66. """Updates the priorities in a prioritized replay buffer, given training results.
  67. The `abs(TD-error)` from the loss (inside `train_results`) is used as new
  68. priorities for the row-indices that were sampled for the train batch.
  69. Don't do anything if the given buffer does not support prioritized replay.
  70. Args:
  71. replay_buffer: The replay buffer, whose priority values to update. This may also
  72. be a buffer that does not support priorities.
  73. config: The Algorithm's config dict.
  74. train_batch: The batch used for the training update.
  75. train_results: A train results dict, generated by e.g. the `train_one_step()`
  76. utility.
  77. """
  78. # Only update priorities if buffer supports them.
  79. if isinstance(replay_buffer, MultiAgentPrioritizedReplayBuffer):
  80. # Go through training results for the different policies (maybe multi-agent).
  81. prio_dict = {}
  82. for policy_id, info in train_results.items():
  83. # TODO(sven): This is currently structured differently for
  84. # torch/tf. Clean up these results/info dicts across
  85. # policies (note: fixing this in torch_policy.py will
  86. # break e.g. DDPPO!).
  87. td_error = info.get("td_error", info[LEARNER_STATS_KEY].get("td_error"))
  88. policy_batch = train_batch.policy_batches[policy_id]
  89. # Set the get_interceptor to None in order to be able to access the numpy
  90. # arrays directly (instead of e.g. a torch array).
  91. policy_batch.set_get_interceptor(None)
  92. # Get the replay buffer row indices that make up the `train_batch`.
  93. batch_indices = policy_batch.get("batch_indexes")
  94. if SampleBatch.SEQ_LENS in policy_batch:
  95. # Batch_indices are represented per column, in order to update
  96. # priorities, we need one index per td_error
  97. _batch_indices = []
  98. # Sequenced batches have been zero padded to max_seq_len.
  99. # Depending on how batches are split during learning, not all
  100. # sequences have an associated td_error (trailing ones missing).
  101. if policy_batch.zero_padded:
  102. seq_lens = len(td_error) * [policy_batch.max_seq_len]
  103. else:
  104. seq_lens = policy_batch[SampleBatch.SEQ_LENS][: len(td_error)]
  105. # Go through all indices by sequence that they represent and shrink
  106. # them to one index per sequences
  107. sequence_sum = 0
  108. for seq_len in seq_lens:
  109. _batch_indices.append(batch_indices[sequence_sum])
  110. sequence_sum += seq_len
  111. batch_indices = np.array(_batch_indices)
  112. if td_error is None:
  113. if log_once(
  114. "no_td_error_in_train_results_from_policy_{}".format(policy_id)
  115. ):
  116. logger.warning(
  117. "Trying to update priorities for policy with id `{}` in "
  118. "prioritized replay buffer without providing td_errors in "
  119. "train_results. Priority update for this policy is being "
  120. "skipped.".format(policy_id)
  121. )
  122. continue
  123. if batch_indices is None:
  124. if log_once(
  125. "no_batch_indices_in_train_result_for_policy_{}".format(policy_id)
  126. ):
  127. logger.warning(
  128. "Trying to update priorities for policy with id `{}` in "
  129. "prioritized replay buffer without providing batch_indices in "
  130. "train_batch. Priority update for this policy is being "
  131. "skipped.".format(policy_id)
  132. )
  133. continue
  134. # Try to transform batch_indices to td_error dimensions
  135. if len(batch_indices) != len(td_error):
  136. T = replay_buffer.replay_sequence_length
  137. assert (
  138. len(batch_indices) > len(td_error) and len(batch_indices) % T == 0
  139. )
  140. batch_indices = batch_indices.reshape([-1, T])[:, 0]
  141. assert len(batch_indices) == len(td_error)
  142. prio_dict[policy_id] = (batch_indices, td_error)
  143. # Make the actual buffer API call to update the priority weights on all
  144. # policies.
  145. replay_buffer.update_priorities(prio_dict)
  146. @DeveloperAPI
  147. def sample_min_n_steps_from_buffer(
  148. replay_buffer: ReplayBuffer, min_steps: int, count_by_agent_steps: bool
  149. ) -> Optional[SampleBatchType]:
  150. """Samples a minimum of n timesteps from a given replay buffer.
  151. This utility method is primarily used by the QMIX algorithm and helps with
  152. sampling a given number of time steps which has stored samples in units
  153. of sequences or complete episodes. Samples n batches from replay buffer
  154. until the total number of timesteps reaches `train_batch_size`.
  155. Args:
  156. replay_buffer: The replay buffer to sample from
  157. num_timesteps: The number of timesteps to sample
  158. count_by_agent_steps: Whether to count agent steps or env steps
  159. Returns:
  160. A concatenated SampleBatch or MultiAgentBatch with samples from the
  161. buffer.
  162. """
  163. train_batch_size = 0
  164. train_batches = []
  165. while train_batch_size < min_steps:
  166. batch = replay_buffer.sample(num_items=1)
  167. batch_len = batch.agent_steps() if count_by_agent_steps else batch.env_steps()
  168. if batch_len == 0:
  169. # Replay has not started, so we can't accumulate timesteps here
  170. return batch
  171. train_batches.append(batch)
  172. train_batch_size += batch_len
  173. # All batch types are the same type, hence we can use any concat_samples()
  174. train_batch = concat_samples(train_batches)
  175. return train_batch
  176. @DeveloperAPI
  177. def validate_buffer_config(config: dict) -> None:
  178. """Checks and fixes values in the replay buffer config.
  179. Checks the replay buffer config for common misconfigurations, warns or raises
  180. error in case validation fails. The type "key" is changed into the inferred
  181. replay buffer class.
  182. Args:
  183. config: The replay buffer config to be validated.
  184. Raises:
  185. ValueError: When detecting severe misconfiguration.
  186. """
  187. if config.get("replay_buffer_config", None) is None:
  188. config["replay_buffer_config"] = {}
  189. if config.get("worker_side_prioritization", DEPRECATED_VALUE) != DEPRECATED_VALUE:
  190. deprecation_warning(
  191. old="config['worker_side_prioritization']",
  192. new="config['replay_buffer_config']['worker_side_prioritization']",
  193. error=True,
  194. )
  195. prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
  196. if prioritized_replay != DEPRECATED_VALUE:
  197. deprecation_warning(
  198. old="config['prioritized_replay'] or config['replay_buffer_config']["
  199. "'prioritized_replay']",
  200. help="Replay prioritization specified by config key. RLlib's new replay "
  201. "buffer API requires setting `config["
  202. "'replay_buffer_config']['type']`, e.g. `config["
  203. "'replay_buffer_config']['type'] = "
  204. "'MultiAgentPrioritizedReplayBuffer'` to change the default "
  205. "behaviour.",
  206. error=True,
  207. )
  208. capacity = config.get("buffer_size", DEPRECATED_VALUE)
  209. if capacity == DEPRECATED_VALUE:
  210. capacity = config["replay_buffer_config"].get("buffer_size", DEPRECATED_VALUE)
  211. if capacity != DEPRECATED_VALUE:
  212. deprecation_warning(
  213. old="config['buffer_size'] or config['replay_buffer_config']["
  214. "'buffer_size']",
  215. new="config['replay_buffer_config']['capacity']",
  216. error=True,
  217. )
  218. replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
  219. if replay_burn_in != DEPRECATED_VALUE:
  220. config["replay_buffer_config"]["replay_burn_in"] = replay_burn_in
  221. deprecation_warning(
  222. old="config['burn_in']",
  223. help="config['replay_buffer_config']['replay_burn_in']",
  224. )
  225. replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
  226. if replay_batch_size == DEPRECATED_VALUE:
  227. replay_batch_size = config["replay_buffer_config"].get(
  228. "replay_batch_size", DEPRECATED_VALUE
  229. )
  230. if replay_batch_size != DEPRECATED_VALUE:
  231. deprecation_warning(
  232. old="config['replay_batch_size'] or config['replay_buffer_config']["
  233. "'replay_batch_size']",
  234. help="Specification of replay_batch_size is not supported anymore but is "
  235. "derived from `train_batch_size`. Specify the number of "
  236. "items you want to replay upon calling the sample() method of replay "
  237. "buffers if this does not work for you.",
  238. error=True,
  239. )
  240. # Deprecation of old-style replay buffer args
  241. # Warnings before checking of we need local buffer so that algorithms
  242. # Without local buffer also get warned
  243. keys_with_deprecated_positions = [
  244. "prioritized_replay_alpha",
  245. "prioritized_replay_beta",
  246. "prioritized_replay_eps",
  247. "no_local_replay_buffer",
  248. "replay_zero_init_states",
  249. "replay_buffer_shards_colocated_with_driver",
  250. ]
  251. for k in keys_with_deprecated_positions:
  252. if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
  253. deprecation_warning(
  254. old="config['{}']".format(k),
  255. help="config['replay_buffer_config']['{}']" "".format(k),
  256. error=False,
  257. )
  258. # Copy values over to new location in config to support new
  259. # and old configuration style.
  260. if config.get("replay_buffer_config") is not None:
  261. config["replay_buffer_config"][k] = config[k]
  262. learning_starts = config.get(
  263. "learning_starts",
  264. config.get("replay_buffer_config", {}).get("learning_starts", DEPRECATED_VALUE),
  265. )
  266. if learning_starts != DEPRECATED_VALUE:
  267. deprecation_warning(
  268. old="config['learning_starts'] or"
  269. "config['replay_buffer_config']['learning_starts']",
  270. help="config['num_steps_sampled_before_learning_starts']",
  271. error=True,
  272. )
  273. config["num_steps_sampled_before_learning_starts"] = learning_starts
  274. # Can't use DEPRECATED_VALUE here because this is also a deliberate
  275. # value set for some algorithms
  276. # TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
  277. replay_sequence_length = config.get("replay_sequence_length", None)
  278. if replay_sequence_length is not None:
  279. config["replay_buffer_config"][
  280. "replay_sequence_length"
  281. ] = replay_sequence_length
  282. deprecation_warning(
  283. old="config['replay_sequence_length']",
  284. help="Replay sequence length specified at new "
  285. "location config['replay_buffer_config']["
  286. "'replay_sequence_length'] will be overwritten.",
  287. error=True,
  288. )
  289. replay_buffer_config = config["replay_buffer_config"]
  290. assert (
  291. "type" in replay_buffer_config
  292. ), "Can not instantiate ReplayBuffer from config without 'type' key."
  293. # Check if old replay buffer should be instantiated
  294. buffer_type = config["replay_buffer_config"]["type"]
  295. if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
  296. # Create valid full [module].[class] string for from_config
  297. config["replay_buffer_config"]["type"] = (
  298. "ray.rllib.utils.replay_buffers." + buffer_type
  299. )
  300. # Instantiate a dummy buffer to fail early on misconfiguration and find out about
  301. # inferred buffer class
  302. dummy_buffer = from_config(buffer_type, config["replay_buffer_config"])
  303. config["replay_buffer_config"]["type"] = type(dummy_buffer)
  304. if hasattr(dummy_buffer, "update_priorities"):
  305. if (
  306. config["replay_buffer_config"].get("replay_mode", "independent")
  307. == "lockstep"
  308. ):
  309. raise ValueError(
  310. "Prioritized replay is not supported when replay_mode=lockstep."
  311. )
  312. elif config["replay_buffer_config"].get("replay_sequence_length", 0) > 1:
  313. raise ValueError(
  314. "Prioritized replay is not supported when "
  315. "replay_sequence_length > 1."
  316. )
  317. else:
  318. if config["replay_buffer_config"].get("worker_side_prioritization"):
  319. raise ValueError(
  320. "Worker side prioritization is not supported when "
  321. "prioritized_replay=False."
  322. )
  323. @DeveloperAPI
  324. def warn_replay_buffer_capacity(*, item: SampleBatchType, capacity: int) -> None:
  325. """Warn if the configured replay buffer capacity is too large for machine's memory.
  326. Args:
  327. item: A (example) item that's supposed to be added to the buffer.
  328. This is used to compute the overall memory footprint estimate for the
  329. buffer.
  330. capacity: The capacity value of the buffer. This is interpreted as the
  331. number of items (such as given `item`) that will eventually be stored in
  332. the buffer.
  333. Raises:
  334. ValueError: If computed memory footprint for the buffer exceeds the machine's
  335. RAM.
  336. """
  337. if log_once("warn_replay_buffer_capacity"):
  338. item_size = item.size_bytes()
  339. psutil_mem = psutil.virtual_memory()
  340. total_gb = psutil_mem.total / 1e9
  341. mem_size = capacity * item_size / 1e9
  342. msg = (
  343. "Estimated max memory usage for replay buffer is {} GB "
  344. "({} batches of size {}, {} bytes each), "
  345. "available system memory is {} GB".format(
  346. mem_size, capacity, item.count, item_size, total_gb
  347. )
  348. )
  349. if mem_size > total_gb:
  350. raise ValueError(msg)
  351. elif mem_size > 0.2 * total_gb:
  352. logger.warning(msg)
  353. else:
  354. logger.info(msg)
  355. def patch_buffer_with_fake_sampling_method(
  356. buffer: ReplayBuffer, fake_sample_output: SampleBatchType
  357. ) -> None:
  358. """Patch a ReplayBuffer such that we always sample fake_sample_output.
  359. Transforms fake_sample_output into a MultiAgentBatch if it is not a
  360. MultiAgentBatch and the buffer is a MultiAgentBuffer. This is useful for testing
  361. purposes if we need deterministic sampling.
  362. Args:
  363. buffer: The buffer to be patched
  364. fake_sample_output: The output to be sampled
  365. """
  366. if isinstance(buffer, MultiAgentReplayBuffer) and not isinstance(
  367. fake_sample_output, MultiAgentBatch
  368. ):
  369. fake_sample_output = SampleBatch(fake_sample_output).as_multi_agent()
  370. def fake_sample(_: Any = None, **kwargs) -> Optional[SampleBatchType]:
  371. """Always returns a predefined batch.
  372. Args:
  373. _: dummy arg to match signature of sample() method
  374. __: dummy arg to match signature of sample() method
  375. ``**kwargs``: dummy args to match signature of sample() method
  376. Returns:
  377. Predefined MultiAgentBatch fake_sample_output
  378. """
  379. return fake_sample_output
  380. buffer.sample = fake_sample