| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440 |
- import logging
- from typing import Any, Dict, Optional
- import numpy as np
- from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning
- from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples
- from ray.rllib.utils.annotations import OldAPIStack
- from ray.rllib.utils.from_config import from_config
- from ray.rllib.utils.metrics import ALL_MODULES, TD_ERROR_KEY
- from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
- from ray.rllib.utils.replay_buffers import (
- EpisodeReplayBuffer,
- MultiAgentPrioritizedReplayBuffer,
- MultiAgentReplayBuffer,
- PrioritizedEpisodeReplayBuffer,
- ReplayBuffer,
- )
- from ray.rllib.utils.typing import (
- AlgorithmConfigDict,
- ModuleID,
- ResultDict,
- SampleBatchType,
- TensorType,
- )
- from ray.util import log_once
- from ray.util.annotations import DeveloperAPI
- import psutil
logger = logging.getLogger(__name__)


@DeveloperAPI
def update_priorities_in_episode_replay_buffer(
    *,
    replay_buffer: EpisodeReplayBuffer,
    td_errors: Dict[ModuleID, TensorType],
) -> None:
    """Pushes per-module TD-errors into a prioritized episode replay buffer.

    Buffers that are not `PrioritizedEpisodeReplayBuffer` instances are left
    untouched (the function returns immediately).

    Args:
        replay_buffer: The episode replay buffer whose priorities to update.
        td_errors: A (multi-agent) mapping from module ID to that module's
            results. Each value is looked up with `TD_ERROR_KEY`; entries keyed
            by `"__all__"` or `ALL_MODULES` are ignored. Modules without a
            (non-None) `TD_ERROR_KEY` entry are skipped with a one-time warning.
    """
    # Guard clause: only prioritized buffers carry priorities.
    if not isinstance(replay_buffer, PrioritizedEpisodeReplayBuffer):
        return

    ignored_ids = ("__all__", ALL_MODULES)
    for mid, module_result in td_errors.items():
        # Aggregate entries are not real modules.
        if mid in ignored_ids:
            continue

        has_td_errors = (
            TD_ERROR_KEY in module_result and module_result[TD_ERROR_KEY] is not None
        )
        if has_td_errors:
            # TODO (simon): Implement multi-agent version (the length check is
            #  expected to happen inside the buffer) and support stateful modules.
            replay_buffer.update_priorities(module_result[TD_ERROR_KEY], mid)
            continue

        # No TD-errors for this module: warn only once per module ID.
        if log_once(f"no_td_error_in_train_results_from_module_{mid}"):
            logger.warning(
                "Trying to update priorities for module with ID "
                f"`{mid}` in prioritized episode replay buffer without "
                "providing `td_errors` in train_results. Priority update for "
                "this policy is being skipped."
            )
def _get_td_error_from_train_info(info: Dict[str, Any]) -> Any:
    """Returns the TD-error stored in one policy's train-results `info` dict.

    Looks at the top-level `"td_error"` key first and only then falls back to
    `info[LEARNER_STATS_KEY]["td_error"]`. Returns None if neither exists.
    """
    # TODO(sven): This is currently structured differently for
    #  torch/tf. Clean up these results/info dicts across
    #  policies (note: fixing this in torch_policy.py will
    #  break e.g. DDPPO!).
    if "td_error" in info:
        return info["td_error"]
    # Only touch the learner stats when needed; they may be missing or None.
    # (A `dict.get(key, default)` call would evaluate the fallback eagerly and
    # raise KeyError even when the top-level key exists.)
    learner_stats = info.get(LEARNER_STATS_KEY) or {}
    return learner_stats.get("td_error")


def _first_replay_index_per_sequence(
    policy_batch: SampleBatch, batch_indices: Any, num_sequences: int
) -> np.ndarray:
    """Reduces per-timestep replay indices to one index (the first) per sequence.

    Args:
        policy_batch: The sequenced policy batch (contains `SampleBatch.SEQ_LENS`).
        batch_indices: The replay-buffer row indices, one per timestep.
        num_sequences: How many sequences to keep (one per TD-error value).
            Depending on how batches are split during learning, trailing
            sequences may have no associated TD-error and are dropped.

    Returns:
        A numpy array with one replay index per kept sequence.
    """
    # Sequenced batches may have been zero padded to max_seq_len.
    if policy_batch.zero_padded:
        seq_lens = num_sequences * [policy_batch.max_seq_len]
    else:
        seq_lens = policy_batch[SampleBatch.SEQ_LENS][:num_sequences]
    first_indices = []
    offset = 0
    for seq_len in seq_lens:
        first_indices.append(batch_indices[offset])
        offset += seq_len
    return np.array(first_indices)


@OldAPIStack
def update_priorities_in_replay_buffer(
    replay_buffer: ReplayBuffer,
    config: AlgorithmConfigDict,
    train_batch: SampleBatchType,
    train_results: ResultDict,
) -> None:
    """Updates the priorities in a prioritized replay buffer, given training results.

    The TD-errors from the loss (inside `train_results`) are handed to the
    buffer's `update_priorities()` together with the row indices that were
    sampled for the train batch. Nothing happens if the given buffer is not a
    `MultiAgentPrioritizedReplayBuffer`. Policies for which no TD-errors or no
    `"batch_indexes"` column is available are skipped with a one-time warning.

    Args:
        replay_buffer: The replay buffer, whose priority values to update. This may
            also be a buffer that does not support priorities.
        config: The Algorithm's config dict (not used by this function).
        train_batch: The batch used for the training update. Its per-policy
            batches are expected to carry a `"batch_indexes"` column.
        train_results: A train results dict, generated by e.g. the
            `train_one_step()` utility, keyed by policy ID.
    """
    # Only update priorities if buffer supports them.
    if not isinstance(replay_buffer, MultiAgentPrioritizedReplayBuffer):
        return

    prio_dict = {}
    # Go through training results for the different policies (maybe multi-agent).
    for policy_id, info in train_results.items():
        td_error = _get_td_error_from_train_info(info)

        policy_batch = train_batch.policy_batches[policy_id]
        # Set the get_interceptor to None in order to be able to access the numpy
        # arrays directly (instead of e.g. a torch array).
        policy_batch.set_get_interceptor(None)
        # Get the replay buffer row indices that make up the `train_batch`.
        batch_indices = policy_batch.get("batch_indexes")

        # Validate inputs BEFORE using them: the sequence handling below calls
        # `len(td_error)` and indexes `batch_indices`, which would crash on None.
        if td_error is None:
            if log_once(
                "no_td_error_in_train_results_from_policy_{}".format(policy_id)
            ):
                logger.warning(
                    "Trying to update priorities for policy with id `{}` in "
                    "prioritized replay buffer without providing td_errors in "
                    "train_results. Priority update for this policy is being "
                    "skipped.".format(policy_id)
                )
            continue

        if batch_indices is None:
            if log_once(
                "no_batch_indices_in_train_result_for_policy_{}".format(policy_id)
            ):
                logger.warning(
                    "Trying to update priorities for policy with id `{}` in "
                    "prioritized replay buffer without providing batch_indices in "
                    "train_batch. Priority update for this policy is being "
                    "skipped.".format(policy_id)
                )
            continue

        if SampleBatch.SEQ_LENS in policy_batch:
            # Batch indices are per timestep; priorities need one index per
            # TD-error (i.e. per sequence).
            batch_indices = _first_replay_index_per_sequence(
                policy_batch, batch_indices, len(td_error)
            )

        # Try to transform batch_indices to td_error dimensions.
        if len(batch_indices) != len(td_error):
            T = replay_buffer.replay_sequence_length
            assert (
                len(batch_indices) > len(td_error) and len(batch_indices) % T == 0
            )
            batch_indices = batch_indices.reshape([-1, T])[:, 0]
            assert len(batch_indices) == len(td_error)
        prio_dict[policy_id] = (batch_indices, td_error)

    # Make the actual buffer API call to update the priority weights on all
    # policies.
    replay_buffer.update_priorities(prio_dict)
@DeveloperAPI
def sample_min_n_steps_from_buffer(
    replay_buffer: ReplayBuffer, min_steps: int, count_by_agent_steps: bool
) -> Optional[SampleBatchType]:
    """Samples a minimum of `min_steps` timesteps from a given replay buffer.

    This utility method is primarily used by the QMIX algorithm and helps with
    sampling a given number of timesteps from buffers that store samples in
    units of sequences or complete episodes. It samples one item at a time from
    the replay buffer until the total number of collected timesteps reaches (or
    exceeds) `min_steps`.

    Args:
        replay_buffer: The replay buffer to sample from.
        min_steps: The minimum number of timesteps to sample.
        count_by_agent_steps: Whether to count agent steps (True) or env steps
            (False).

    Returns:
        A concatenated SampleBatch or MultiAgentBatch with samples from the
        buffer. If the buffer yields an empty batch (replay has not started yet),
        that empty batch is returned as-is; if it yields None, None is returned.
    """
    train_batch_size = 0
    train_batches = []
    while train_batch_size < min_steps:
        batch = replay_buffer.sample(num_items=1)
        # Honor the `Optional` return type instead of crashing on `None.env_steps()`.
        if batch is None:
            return None
        batch_len = batch.agent_steps() if count_by_agent_steps else batch.env_steps()
        if batch_len == 0:
            # Replay has not started, so we can't accumulate timesteps here.
            return batch
        train_batches.append(batch)
        train_batch_size += batch_len
    # All batch types are the same type, hence we can use any concat_samples().
    return concat_samples(train_batches)
@DeveloperAPI
def validate_buffer_config(config: dict) -> None:
    """Checks and fixes values in the replay buffer config.

    Checks the replay buffer config for common misconfigurations and warns or
    raises an error in case validation fails. `config` is modified in place:
    values under deprecated top-level keys are copied into
    `config["replay_buffer_config"]`, and its "type" key is replaced by the
    inferred replay buffer class (a short class-name string is first expanded
    to `"ray.rllib.utils.replay_buffers.<name>"`).

    Args:
        config: The config dict whose `"replay_buffer_config"` sub-dict is
            validated. A missing or None sub-dict is replaced by an empty dict.

    Raises:
        ValueError: When detecting severe misconfiguration, e.g. prioritized
            replay combined with `replay_mode="lockstep"` or with
            `replay_sequence_length > 1`, or worker-side prioritization with a
            non-prioritized buffer.
        AssertionError: If the replay buffer config has no "type" key.
    """
    if config.get("replay_buffer_config", None) is None:
        config["replay_buffer_config"] = {}
    # From here on the sub-dict is guaranteed to exist.
    buffer_config = config["replay_buffer_config"]

    if config.get("worker_side_prioritization", DEPRECATED_VALUE) != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['worker_side_prioritization']",
            new="config['replay_buffer_config']['worker_side_prioritization']",
            error=True,
        )

    prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
    if prioritized_replay != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['prioritized_replay'] or config['replay_buffer_config']["
            "'prioritized_replay']",
            help="Replay prioritization specified by config key. RLlib's new replay "
            "buffer API requires setting `config["
            "'replay_buffer_config']['type']`, e.g. `config["
            "'replay_buffer_config']['type'] = "
            "'MultiAgentPrioritizedReplayBuffer'` to change the default "
            "behaviour.",
            error=True,
        )

    capacity = config.get("buffer_size", DEPRECATED_VALUE)
    if capacity == DEPRECATED_VALUE:
        capacity = buffer_config.get("buffer_size", DEPRECATED_VALUE)
    if capacity != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['buffer_size'] or config['replay_buffer_config']["
            "'buffer_size']",
            new="config['replay_buffer_config']['capacity']",
            error=True,
        )

    replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
    if replay_burn_in != DEPRECATED_VALUE:
        buffer_config["replay_burn_in"] = replay_burn_in
        deprecation_warning(
            old="config['burn_in']",
            help="config['replay_buffer_config']['replay_burn_in']",
        )

    replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
    if replay_batch_size == DEPRECATED_VALUE:
        replay_batch_size = buffer_config.get("replay_batch_size", DEPRECATED_VALUE)
    if replay_batch_size != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['replay_batch_size'] or config['replay_buffer_config']["
            "'replay_batch_size']",
            help="Specification of replay_batch_size is not supported anymore but is "
            "derived from `train_batch_size`. Specify the number of "
            "items you want to replay upon calling the sample() method of replay "
            "buffers if this does not work for you.",
            error=True,
        )

    # Deprecation of old-style replay buffer args. Warn before checking whether
    # we need a local buffer, so that algorithms without a local buffer also get
    # warned.
    keys_with_deprecated_positions = [
        "prioritized_replay_alpha",
        "prioritized_replay_beta",
        "prioritized_replay_eps",
        "no_local_replay_buffer",
        "replay_zero_init_states",
        "replay_buffer_shards_colocated_with_driver",
    ]
    for k in keys_with_deprecated_positions:
        if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
            deprecation_warning(
                old=f"config['{k}']",
                help=f"config['replay_buffer_config']['{k}']",
                error=False,
            )
            # Copy values over to new location in config to support new
            # and old configuration style.
            buffer_config[k] = config[k]

    learning_starts = config.get(
        "learning_starts",
        buffer_config.get("learning_starts", DEPRECATED_VALUE),
    )
    if learning_starts != DEPRECATED_VALUE:
        deprecation_warning(
            # Note the trailing space: the two literals are concatenated.
            old="config['learning_starts'] or "
            "config['replay_buffer_config']['learning_starts']",
            help="config['num_steps_sampled_before_learning_starts']",
            error=True,
        )
        config["num_steps_sampled_before_learning_starts"] = learning_starts

    # Can't use DEPRECATED_VALUE here because this is also a deliberate
    # value set for some algorithms
    # TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
    replay_sequence_length = config.get("replay_sequence_length", None)
    if replay_sequence_length is not None:
        buffer_config["replay_sequence_length"] = replay_sequence_length
        deprecation_warning(
            old="config['replay_sequence_length']",
            help="Replay sequence length specified at new "
            "location config['replay_buffer_config']["
            "'replay_sequence_length'] will be overwritten.",
            error=True,
        )

    assert (
        "type" in buffer_config
    ), "Can not instantiate ReplayBuffer from config without 'type' key."

    buffer_type = buffer_config["type"]
    if isinstance(buffer_type, str) and "." not in buffer_type:
        # Create valid full [module].[class] string for from_config, and use the
        # same resolved value both in the config and as `from_config`'s class arg.
        buffer_type = "ray.rllib.utils.replay_buffers." + buffer_type
        buffer_config["type"] = buffer_type

    # Instantiate a dummy buffer to fail early on misconfiguration and find out
    # about the inferred buffer class.
    dummy_buffer = from_config(buffer_type, buffer_config)
    buffer_config["type"] = type(dummy_buffer)

    if hasattr(dummy_buffer, "update_priorities"):
        if buffer_config.get("replay_mode", "independent") == "lockstep":
            raise ValueError(
                "Prioritized replay is not supported when replay_mode=lockstep."
            )
        # `replay_sequence_length` may be present but explicitly None.
        if (buffer_config.get("replay_sequence_length") or 0) > 1:
            raise ValueError(
                "Prioritized replay is not supported when "
                "replay_sequence_length > 1."
            )
    elif buffer_config.get("worker_side_prioritization"):
        raise ValueError(
            "Worker side prioritization is not supported when "
            "prioritized_replay=False."
        )
@DeveloperAPI
def warn_replay_buffer_capacity(*, item: SampleBatchType, capacity: int) -> None:
    """Warn if the configured replay buffer capacity is too large for the machine's memory.

    The check runs at most once per process (gated by `log_once`). The estimate
    is `capacity * item.size_bytes()`, compared against total system memory:
    above total memory -> error; above 20% of it -> warning; else -> info log.

    Args:
        item: An example item that's supposed to be added to the buffer. It is
            used to estimate the overall memory footprint of the buffer.
        capacity: The capacity value of the buffer, interpreted as the number of
            items (such as `item`) that will eventually be stored in the buffer.

    Raises:
        ValueError: If the estimated memory footprint for the buffer exceeds the
            machine's total RAM.
    """
    # Only ever evaluate (and report) this once.
    if not log_once("warn_replay_buffer_capacity"):
        return

    bytes_per_item = item.size_bytes()
    system_gb = psutil.virtual_memory().total / 1e9
    estimated_gb = capacity * bytes_per_item / 1e9

    report = (
        "Estimated max memory usage for replay buffer is {} GB "
        "({} batches of size {}, {} bytes each), "
        "available system memory is {} GB"
    ).format(estimated_gb, capacity, item.count, bytes_per_item, system_gb)

    if estimated_gb > system_gb:
        raise ValueError(report)

    log_fn = logger.warning if estimated_gb > 0.2 * system_gb else logger.info
    log_fn(report)
def patch_buffer_with_fake_sampling_method(
    buffer: ReplayBuffer, fake_sample_output: SampleBatchType
) -> None:
    """Patch a ReplayBuffer such that we always sample `fake_sample_output`.

    Transforms `fake_sample_output` into a MultiAgentBatch if it is not a
    MultiAgentBatch and the buffer is a MultiAgentReplayBuffer. This is useful
    for testing purposes if we need deterministic sampling.

    Only the given instance is patched (`buffer.sample` is overwritten on the
    instance); the buffer class is left untouched.

    Args:
        buffer: The buffer to be patched (modified in place).
        fake_sample_output: The output to be returned by every `sample()` call.
    """
    if isinstance(buffer, MultiAgentReplayBuffer) and not isinstance(
        fake_sample_output, MultiAgentBatch
    ):
        fake_sample_output = SampleBatch(fake_sample_output).as_multi_agent()

    def fake_sample(*_args: Any, **_kwargs: Any) -> Optional[SampleBatchType]:
        """Always returns the predefined batch.

        Args:
            *_args: Ignored. Accepts any positional arguments of the real
                `sample()` method (e.g. `num_items`, and a `policy_id`).
            **_kwargs: Ignored. Accepts any keyword arguments of the real
                `sample()` method.

        Returns:
            The predefined `fake_sample_output`.
        """
        return fake_sample_output

    buffer.sample = fake_sample
|