train_ops.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. import logging
  2. import math
  3. from typing import Dict
  4. import numpy as np
  5. from ray._common.deprecation import deprecation_warning
  6. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
  7. from ray.rllib.utils.annotations import OldAPIStack
  8. from ray.rllib.utils.framework import try_import_tf
  9. from ray.rllib.utils.metrics import (
  10. LEARN_ON_BATCH_TIMER,
  11. LOAD_BATCH_TIMER,
  12. NUM_AGENT_STEPS_TRAINED,
  13. NUM_ENV_STEPS_TRAINED,
  14. )
  15. from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
  16. from ray.rllib.utils.sgd import do_minibatch_sgd
  17. from ray.util import log_once
  18. tf1, tf, tfv = try_import_tf()
  19. logger = logging.getLogger(__name__)
  @OldAPIStack
  def train_one_step(algorithm, train_batch, policies_to_train=None) -> Dict:
      """Run one training update on the local worker using `train_batch`.

      The update is either a single `learn_on_batch` call on the local worker
      or, if more than one epoch or a positive minibatch size is configured,
      a call to `do_minibatch_sgd`.

      .. testcode::
          :skipif: True

          from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
          algo = [...]
          train_batch = synchronous_parallel_sample(algo.env_runner_group)
          # This trains the policy on one batch.
          print(train_one_step(algo, train_batch))

      .. testoutput::

          {"default_policy": ...}

      Updates the NUM_ENV_STEPS_TRAINED and NUM_AGENT_STEPS_TRAINED counters
      as well as the LEARN_ON_BATCH_TIMER timer of the `algorithm` object.

      Args:
          algorithm: Object providing `config`, `env_runner_group`, `_timers`,
              `_counters` and `reward_estimators`.
          train_batch: The batch to learn on. Its `count` and `agent_steps()`
              are added to the trained-steps counters.
          policies_to_train: Optional IDs of the policies to update. Only used
              on the minibatch-SGD path; if falsy, the local worker's
              `get_policies_to_train(train_batch)` is used instead.

      Returns:
          The info dict returned by the learn call. If `algorithm` has reward
          estimators, their training results are stored under
          `info[DEFAULT_POLICY_ID]["off_policy_estimation"]`.
      """
      cfg = algorithm.config
      runner_group = algorithm.env_runner_group
      env_runner = runner_group.local_env_runner

      # `num_epochs`/`minibatch_size` take precedence over the older
      # `num_sgd_iter`/`sgd_minibatch_size` keys.
      epochs = cfg.get("num_epochs", cfg.get("num_sgd_iter", 1))
      mb_size = cfg.get("minibatch_size")
      if mb_size is None:
          mb_size = cfg.get("sgd_minibatch_size", 0)

      timer = algorithm._timers[LEARN_ON_BATCH_TIMER]
      with timer:
          if not (epochs > 1 or mb_size > 0):
              # Single update step using the whole train batch.
              info = env_runner.learn_on_batch(train_batch)
          else:
              # Subsample minibatches (size=`mb_size`) from the train batch
              # and loop through the train batch `epochs` times.
              policy_ids = policies_to_train or env_runner.get_policies_to_train(
                  train_batch
              )
              policies = {pid: env_runner.get_policy(pid) for pid in policy_ids}
              info = do_minibatch_sgd(
                  train_batch, policies, env_runner, epochs, mb_size, []
              )

      env_steps = train_batch.count
      timer.push_units_processed(env_steps)
      algorithm._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
      algorithm._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

      if algorithm.reward_estimators:
          # Register the (initially empty) results dict first, then fill it
          # with one entry per estimator.
          estimates = {}
          info[DEFAULT_POLICY_ID]["off_policy_estimation"] = estimates
          for est_name, est in algorithm.reward_estimators.items():
              estimates[est_name] = est.train(train_batch)

      return info
  @OldAPIStack
  def multi_gpu_train_one_step(algorithm, train_batch) -> Dict:
      """Multi-GPU version of train_one_step.

      Uses the policies' `load_batch_into_buffer` and `learn_on_loaded_batch`
      methods to be more efficient wrt CPU/GPU data transfers. For example,
      when doing multiple passes through a train batch (e.g. for PPO) using
      `config.num_epochs` (or the older `config.num_sgd_iter`), the actual
      train batch is only split once and loaded once into the GPU(s).

      .. testcode::
          :skipif: True

          from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
          algo = [...]
          train_batch = synchronous_parallel_sample(algo.env_runner_group)
          # This trains the policy on one batch.
          print(multi_gpu_train_one_step(algo, train_batch))

      .. testoutput::

          {"default_policy": ...}

      Updates the NUM_ENV_STEPS_TRAINED and NUM_AGENT_STEPS_TRAINED counters
      as well as the LOAD_BATCH_TIMER and LEARN_ON_BATCH_TIMER timers of the
      Algorithm instance.

      Args:
          algorithm: Object providing `config`, `env_runner_group`, `_timers`,
              `_counters` and `reward_estimators`.
          train_batch: The batch to learn on. It is converted with
              `as_multi_agent()` before being loaded.

      Returns:
          The learner info dict built by `LearnerInfoBuilder.finalize()`. If
          `algorithm` has reward estimators, their training results are stored
          under `learner_info[DEFAULT_POLICY_ID]["off_policy_estimation"]`.

      Raises:
          ValueError: If the configured minibatch size is smaller than the
              number of devices, i.e. the per-device batch size would be 0.
      """
      if log_once("mulit_gpu_train_one_step_deprecation_warning"):
          deprecation_warning(
              old=("ray.rllib.execution.train_ops.multi_gpu_train_one_step")
          )

      config = algorithm.config
      workers = algorithm.env_runner_group
      local_worker = workers.local_env_runner
      num_sgd_iter = config.get("num_epochs", config.get("num_sgd_iter", 1))
      minibatch_size = config.get("minibatch_size")
      if minibatch_size is None:
          minibatch_size = config["train_batch_size"]

      # Determine the number of devices (GPUs or 1 CPU) we use.
      num_devices = int(math.ceil(config["num_gpus"] or 1))

      # Batch size per tower. Any remainder of `minibatch_size` that does not
      # divide evenly across the devices is dropped.
      per_device_batch_size = minibatch_size // num_devices
      # Validate explicitly instead of via `assert`: asserts are stripped under
      # `python -O`, which would surface later as a ZeroDivisionError when
      # computing the number of minibatches.
      if per_device_batch_size < 1:
          raise ValueError(
              "Batch size too small! The minibatch size "
              f"({minibatch_size}) must be at least the number of devices "
              f"({num_devices})."
          )

      # Handle everything as if multi-agent.
      train_batch = train_batch.as_multi_agent()

      # Load data into GPUs.
      load_timer = algorithm._timers[LOAD_BATCH_TIMER]
      with load_timer:
          num_loaded_samples = {}
          for policy_id, batch in train_batch.policy_batches.items():
              # Not a policy-to-train.
              if (
                  local_worker.is_policy_to_train is not None
                  and not local_worker.is_policy_to_train(policy_id, train_batch)
              ):
                  continue
              # Decompress SampleBatch, in case some columns are compressed.
              batch.decompress_if_needed()

              # Load the entire train batch into the Policy's only buffer
              # (idx=0). Policies only have >1 buffers, if we are training
              # asynchronously.
              num_loaded_samples[policy_id] = local_worker.policy_map[
                  policy_id
              ].load_batch_into_buffer(batch, buffer_index=0)

      # Execute minibatch SGD on loaded data.
      learn_timer = algorithm._timers[LEARN_ON_BATCH_TIMER]
      with learn_timer:
          # Use LearnerInfoBuilder as a unified way to build the final
          # results dict from `learn_on_loaded_batch` call(s).
          # This makes sure results dicts always have the same structure
          # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
          # tf vs torch).
          learner_info_builder = LearnerInfoBuilder(num_devices=num_devices)

          for policy_id, samples_per_device in num_loaded_samples.items():
              policy = local_worker.policy_map[policy_id]
              # Always run at least one minibatch, even if fewer samples than
              # one per-device batch were loaded.
              num_batches = max(
                  1, int(samples_per_device) // int(per_device_batch_size)
              )
              logger.debug("== sgd epochs for %s ==", policy_id)
              for _ in range(num_sgd_iter):
                  # Visit the minibatches in a fresh random order each epoch.
                  for batch_index in np.random.permutation(num_batches):
                      # Learn on the pre-loaded data in the buffer.
                      # Note: For minibatch SGD, the data is an offset into
                      # the pre-loaded entire train batch.
                      results = policy.learn_on_loaded_batch(
                          batch_index * per_device_batch_size, buffer_index=0
                      )
                      learner_info_builder.add_learn_on_batch_results(
                          results, policy_id
                      )

          # Tower reduce and finalize results.
          learner_info = learner_info_builder.finalize()

      load_timer.push_units_processed(train_batch.count)
      learn_timer.push_units_processed(train_batch.count)

      # TODO: Move this into Algorithm's `training_step` method for
      # better transparency.
      algorithm._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
      algorithm._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

      if algorithm.reward_estimators:
          learner_info[DEFAULT_POLICY_ID]["off_policy_estimation"] = {}
          for name, estimator in algorithm.reward_estimators.items():
              learner_info[DEFAULT_POLICY_ID]["off_policy_estimation"][
                  name
              ] = estimator.train(train_batch)
      return learner_info