sgd.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. """Utils for minibatch SGD across multiple RLlib policies."""
  2. import logging
  3. import random
  4. import numpy as np
  5. from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
  6. from ray.rllib.utils.annotations import OldAPIStack
  7. from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
  8. logger = logging.getLogger(__name__)
  9. @OldAPIStack
  10. def standardized(array: np.ndarray):
  11. """Normalize the values in an array.
  12. Args:
  13. array (np.ndarray): Array of values to normalize.
  14. Returns:
  15. array with zero mean and unit standard deviation.
  16. """
  17. return (array - array.mean()) / max(1e-4, array.std())
  18. @OldAPIStack
  19. def minibatches(samples: SampleBatch, sgd_minibatch_size: int, shuffle: bool = True):
  20. """Return a generator yielding minibatches from a sample batch.
  21. Args:
  22. samples: SampleBatch to split up.
  23. sgd_minibatch_size: Size of minibatches to return.
  24. shuffle: Whether to shuffle the order of the generated minibatches.
  25. Note that in case of a non-recurrent policy, the incoming batch
  26. is globally shuffled first regardless of this setting, before
  27. the minibatches are generated from it!
  28. Yields:
  29. SampleBatch: Each of size `sgd_minibatch_size`.
  30. """
  31. if not sgd_minibatch_size:
  32. yield samples
  33. return
  34. if isinstance(samples, MultiAgentBatch):
  35. raise NotImplementedError(
  36. "Minibatching not implemented for multi-agent in simple mode"
  37. )
  38. if "state_in_0" not in samples and "state_out_0" not in samples:
  39. samples.shuffle()
  40. all_slices = samples._get_slice_indices(sgd_minibatch_size)
  41. data_slices, state_slices = all_slices
  42. if len(state_slices) == 0:
  43. if shuffle:
  44. random.shuffle(data_slices)
  45. for i, j in data_slices:
  46. yield samples[i:j]
  47. else:
  48. all_slices = list(zip(data_slices, state_slices))
  49. if shuffle:
  50. # Make sure to shuffle data and states while linked together.
  51. random.shuffle(all_slices)
  52. for (i, j), (si, sj) in all_slices:
  53. yield samples.slice(i, j, si, sj)
  54. @OldAPIStack
  55. def do_minibatch_sgd(
  56. samples,
  57. policies,
  58. local_worker,
  59. num_sgd_iter,
  60. sgd_minibatch_size,
  61. standardize_fields,
  62. ):
  63. """Execute minibatch SGD.
  64. Args:
  65. samples: Batch of samples to optimize.
  66. policies: Dictionary of policies to optimize.
  67. local_worker: Master rollout worker instance.
  68. num_sgd_iter: Number of epochs of optimization to take.
  69. sgd_minibatch_size: Size of minibatches to use for optimization.
  70. standardize_fields: List of sample field names that should be
  71. normalized prior to optimization.
  72. Returns:
  73. averaged info fetches over the last SGD epoch taken.
  74. """
  75. # Handle everything as if multi-agent.
  76. samples = samples.as_multi_agent()
  77. # Use LearnerInfoBuilder as a unified way to build the final
  78. # results dict from `learn_on_loaded_batch` call(s).
  79. # This makes sure results dicts always have the same structure
  80. # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
  81. # tf vs torch).
  82. learner_info_builder = LearnerInfoBuilder(num_devices=1)
  83. for policy_id, policy in policies.items():
  84. if policy_id not in samples.policy_batches:
  85. continue
  86. batch = samples.policy_batches[policy_id]
  87. for field in standardize_fields:
  88. batch[field] = standardized(batch[field])
  89. # Check to make sure that the sgd_minibatch_size is not smaller
  90. # than max_seq_len otherwise this will cause indexing errors while
  91. # performing sgd when using a RNN or Attention model
  92. if (
  93. policy.is_recurrent()
  94. and policy.config["model"]["max_seq_len"] > sgd_minibatch_size
  95. ):
  96. raise ValueError(
  97. "`sgd_minibatch_size` ({}) cannot be smaller than"
  98. "`max_seq_len` ({}).".format(
  99. sgd_minibatch_size, policy.config["model"]["max_seq_len"]
  100. )
  101. )
  102. for i in range(num_sgd_iter):
  103. for minibatch in minibatches(batch, sgd_minibatch_size):
  104. results = (
  105. local_worker.learn_on_batch(
  106. MultiAgentBatch({policy_id: minibatch}, minibatch.count)
  107. )
  108. )[policy_id]
  109. learner_info_builder.add_learn_on_batch_results(results, policy_id)
  110. learner_info = learner_info_builder.finalize()
  111. return learner_info