| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
- import logging
- import math
- from typing import Dict
- import numpy as np
- from ray._common.deprecation import deprecation_warning
- from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
- from ray.rllib.utils.annotations import OldAPIStack
- from ray.rllib.utils.framework import try_import_tf
- from ray.rllib.utils.metrics import (
- LEARN_ON_BATCH_TIMER,
- LOAD_BATCH_TIMER,
- NUM_AGENT_STEPS_TRAINED,
- NUM_ENV_STEPS_TRAINED,
- )
- from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
- from ray.rllib.utils.sgd import do_minibatch_sgd
- from ray.util import log_once
# Resolve the TensorFlow module handles once, at import time. What each handle
# holds when TensorFlow is unavailable is decided by `try_import_tf` (presumably
# None -- TODO confirm against its documentation).
tf1, tf, tfv = try_import_tf()

# Module-scoped logger named after this module.
logger = logging.getLogger(name=__name__)
@OldAPIStack
def train_one_step(algorithm, train_batch, policies_to_train=None) -> Dict:
    """Improve the policies in `train_batch` on the local worker.

    .. testcode::
        :skipif: True

        from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
        algo = [...]
        train_batch = synchronous_parallel_sample(algo.env_runner_group)
        # This trains the policy on one batch.
        print(train_one_step(algo, train_batch))

    .. testoutput::

        {"default_policy": ...}

    Updates the NUM_ENV_STEPS_TRAINED and NUM_AGENT_STEPS_TRAINED counters as well as
    the LEARN_ON_BATCH_TIMER timer of the `algorithm` object.

    Args:
        algorithm: The Algorithm whose `config`, `env_runner_group`, `_timers`,
            `_counters` and `reward_estimators` are read and updated.
        train_batch: The batch to train on.
        policies_to_train: Optional collection of policy IDs to train. Only
            consulted on the minibatch-SGD path (`num_epochs`/`num_sgd_iter` > 1
            or a positive minibatch size); when falsy, the local worker's
            `get_policies_to_train(train_batch)` is used instead. The single-step
            path calls `local_worker.learn_on_batch` and does not use it.

    Returns:
        The learner-info dict returned by the training call. If
        `algorithm.reward_estimators` is non-empty, an "off_policy_estimation"
        sub-dict (estimator name -> `estimator.train(train_batch)` result) is
        added under DEFAULT_POLICY_ID.
    """
    config = algorithm.config
    local_worker = algorithm.env_runner_group.local_env_runner
    # Prefer the new `num_epochs` key; fall back to the legacy `num_sgd_iter`.
    num_sgd_iter = config.get("num_epochs", config.get("num_sgd_iter", 1))
    minibatch_size = config.get("minibatch_size")
    if minibatch_size is None:
        # Legacy key; 0 (the default here) disables minibatching.
        minibatch_size = config.get("sgd_minibatch_size", 0)

    learn_timer = algorithm._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        if num_sgd_iter > 1 or minibatch_size > 0:
            # Subsample minibatches (size=`minibatch_size`) from the train
            # batch and loop through the train batch `num_sgd_iter` times.
            policy_ids = policies_to_train or local_worker.get_policies_to_train(
                train_batch
            )
            info = do_minibatch_sgd(
                train_batch,
                {pid: local_worker.get_policy(pid) for pid in policy_ids},
                local_worker,
                num_sgd_iter,
                minibatch_size,
                [],
            )
        else:
            # Single update step using the whole train batch.
            info = local_worker.learn_on_batch(train_batch)

    learn_timer.push_units_processed(train_batch.count)
    algorithm._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    algorithm._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    if algorithm.reward_estimators:
        # `setdefault`: do not fail with KeyError when the learner info has no
        # entry for DEFAULT_POLICY_ID (e.g. the default policy was not trained);
        # the estimators only need `train_batch`.
        ope_results = {}
        info.setdefault(DEFAULT_POLICY_ID, {})["off_policy_estimation"] = ope_results
        for name, estimator in algorithm.reward_estimators.items():
            ope_results[name] = estimator.train(train_batch)
    return info
@OldAPIStack
def multi_gpu_train_one_step(algorithm, train_batch) -> Dict:
    """Multi-GPU version of train_one_step.

    Uses the policies' `load_batch_into_buffer` and `learn_on_loaded_batch` methods
    to be more efficient wrt CPU/GPU data transfers. For example, when doing multiple
    passes through a train batch (e.g. for PPO) using `config.num_epochs` (legacy:
    `config.num_sgd_iter`), the actual train batch is only split once and loaded
    once into the GPU(s).

    .. testcode::
        :skipif: True

        from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
        algo = [...]
        train_batch = synchronous_parallel_sample(algo.env_runner_group)
        # This trains the policy on one batch.
        print(multi_gpu_train_one_step(algo, train_batch))

    .. testoutput::

        {"default_policy": ...}

    Updates the NUM_ENV_STEPS_TRAINED and NUM_AGENT_STEPS_TRAINED counters as well as
    the LOAD_BATCH_TIMER and LEARN_ON_BATCH_TIMER timers of the Algorithm instance.

    Args:
        algorithm: The Algorithm whose `config`, `env_runner_group`, `_timers`,
            `_counters` and `reward_estimators` are read and updated.
        train_batch: The batch to train on; it is converted via
            `as_multi_agent()` before loading.

    Returns:
        The finalized learner-info dict built by `LearnerInfoBuilder`. If
        `algorithm.reward_estimators` is non-empty, an "off_policy_estimation"
        sub-dict is added under DEFAULT_POLICY_ID.

    Raises:
        ValueError: If the minibatch size is smaller than the number of devices,
            i.e. the per-device batch size would be 0.
    """
    if log_once("mulit_gpu_train_one_step_deprecation_warning"):
        deprecation_warning(old="ray.rllib.execution.train_ops.multi_gpu_train_one_step")

    config = algorithm.config
    local_worker = algorithm.env_runner_group.local_env_runner
    # Prefer the new `num_epochs` key; fall back to the legacy `num_sgd_iter`.
    num_sgd_iter = config.get("num_epochs", config.get("num_sgd_iter", 1))
    minibatch_size = config.get("minibatch_size")
    if minibatch_size is None:
        minibatch_size = config["train_batch_size"]

    # Determine the number of devices (GPUs or 1 CPU) we use. A fractional
    # `num_gpus` is rounded up to one device.
    num_devices = int(math.ceil(config["num_gpus"] or 1))

    # Batch size per tower. Floor division drops any remainder, so the total
    # batch size (per_device_batch_size * num_devices) is always divisible by
    # the number of devices.
    per_device_batch_size = minibatch_size // num_devices
    if per_device_batch_size < 1:
        # Explicit check instead of `assert` so it also fires under `python -O`;
        # otherwise the `// per_device_batch_size` below divides by zero.
        raise ValueError(
            f"Batch size too small! minibatch_size={minibatch_size} must be at "
            f"least the number of devices ({num_devices})."
        )

    # Handle everything as if multi-agent.
    train_batch = train_batch.as_multi_agent()

    # Load data into GPUs.
    load_timer = algorithm._timers[LOAD_BATCH_TIMER]
    with load_timer:
        num_loaded_samples = {}
        for policy_id, batch in train_batch.policy_batches.items():
            # Skip policies that are not to be trained.
            if (
                local_worker.is_policy_to_train is not None
                and not local_worker.is_policy_to_train(policy_id, train_batch)
            ):
                continue
            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()
            # Load the entire train batch into the Policy's only buffer
            # (idx=0). Policies only have >1 buffers, if we are training
            # asynchronously.
            policy = local_worker.policy_map[policy_id]
            num_loaded_samples[policy_id] = policy.load_batch_into_buffer(
                batch, buffer_index=0
            )

    # Execute minibatch SGD on loaded data.
    learn_timer = algorithm._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        # Use LearnerInfoBuilder as a unified way to build the final
        # results dict from `learn_on_loaded_batch` call(s).
        # This makes sure results dicts always have the same structure
        # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
        # tf vs torch).
        learner_info_builder = LearnerInfoBuilder(num_devices=num_devices)
        for policy_id, samples_per_device in num_loaded_samples.items():
            policy = local_worker.policy_map[policy_id]
            num_batches = max(1, int(samples_per_device) // int(per_device_batch_size))
            logger.debug("== sgd epochs for %s ==", policy_id)
            for _ in range(num_sgd_iter):
                # Visit the loaded minibatches in a fresh random order each epoch.
                for batch_index in np.random.permutation(num_batches):
                    # Learn on the pre-loaded data in the buffer.
                    # Note: For minibatch SGD, the data is an offset into
                    # the pre-loaded entire train batch.
                    results = policy.learn_on_loaded_batch(
                        batch_index * per_device_batch_size, buffer_index=0
                    )
                    learner_info_builder.add_learn_on_batch_results(results, policy_id)
        # Tower reduce and finalize results.
        learner_info = learner_info_builder.finalize()

    load_timer.push_units_processed(train_batch.count)
    learn_timer.push_units_processed(train_batch.count)

    # TODO: Move this into Algorithm's `training_step` method for
    # better transparency.
    algorithm._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    algorithm._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    if algorithm.reward_estimators:
        # `setdefault`: do not fail with KeyError when no results were reported
        # under DEFAULT_POLICY_ID; the estimators only need `train_batch`.
        ope_results = {}
        learner_info.setdefault(DEFAULT_POLICY_ID, {})[
            "off_policy_estimation"
        ] = ope_results
        for name, estimator in algorithm.reward_estimators.items():
            ope_results[name] = estimator.train(train_batch)
    return learner_info
|