learner_info.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. from collections import defaultdict
  2. from typing import Dict
  3. import numpy as np
  4. import tree # pip install dm_tree
  5. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
  6. from ray.rllib.utils.annotations import OldAPIStack
  7. from ray.rllib.utils.typing import PolicyID
  8. # Instant metrics (keys for metrics.info).
  9. LEARNER_INFO = "learner"
  10. # By convention, metrics from optimizing the loss can be reported in the
  11. # `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
  12. LEARNER_STATS_KEY = "learner_stats"
  13. @OldAPIStack
  14. class LearnerInfoBuilder:
  15. def __init__(self, num_devices: int = 1):
  16. self.num_devices = num_devices
  17. self.results_all_towers = defaultdict(list)
  18. self.is_finalized = False
  19. def add_learn_on_batch_results(
  20. self,
  21. results: Dict,
  22. policy_id: PolicyID = DEFAULT_POLICY_ID,
  23. ) -> None:
  24. """Adds a policy.learn_on_(loaded)?_batch() result to this builder.
  25. Args:
  26. results: The results returned by Policy.learn_on_batch or
  27. Policy.learn_on_loaded_batch.
  28. policy_id: The policy's ID, whose learn_on_(loaded)_batch method
  29. returned `results`.
  30. """
  31. assert (
  32. not self.is_finalized
  33. ), "LearnerInfo already finalized! Cannot add more results."
  34. # No towers: Single CPU.
  35. if "tower_0" not in results:
  36. self.results_all_towers[policy_id].append(results)
  37. # Multi-GPU case:
  38. else:
  39. self.results_all_towers[policy_id].append(
  40. tree.map_structure_with_path(
  41. lambda p, *s: _all_tower_reduce(p, *s),
  42. *(
  43. results.pop("tower_{}".format(tower_num))
  44. for tower_num in range(self.num_devices)
  45. )
  46. )
  47. )
  48. for k, v in results.items():
  49. if k == LEARNER_STATS_KEY:
  50. for k1, v1 in results[k].items():
  51. self.results_all_towers[policy_id][-1][LEARNER_STATS_KEY][
  52. k1
  53. ] = v1
  54. else:
  55. self.results_all_towers[policy_id][-1][k] = v
  56. def add_learn_on_batch_results_multi_agent(
  57. self,
  58. all_policies_results: Dict,
  59. ) -> None:
  60. """Adds multiple policy.learn_on_(loaded)?_batch() results to this builder.
  61. Args:
  62. all_policies_results: The results returned by all Policy.learn_on_batch or
  63. Policy.learn_on_loaded_batch wrapped as a dict mapping policy ID to
  64. results.
  65. """
  66. for pid, result in all_policies_results.items():
  67. if pid != "batch_count":
  68. self.add_learn_on_batch_results(result, policy_id=pid)
  69. def finalize(self):
  70. self.is_finalized = True
  71. info = {}
  72. for policy_id, results_all_towers in self.results_all_towers.items():
  73. # Reduce mean across all minibatch SGD steps (axis=0 to keep
  74. # all shapes as-is).
  75. info[policy_id] = tree.map_structure_with_path(
  76. _all_tower_reduce, *results_all_towers
  77. )
  78. return info
  79. @OldAPIStack
  80. def _all_tower_reduce(path, *tower_data):
  81. """Reduces stats across towers based on their stats-dict paths."""
  82. # TD-errors: Need to stay per batch item in order to be able to update
  83. # each item's weight in a prioritized replay buffer.
  84. if len(path) == 1 and path[0] == "td_error":
  85. return np.concatenate(tower_data, axis=0)
  86. elif tower_data[0] is None:
  87. return None
  88. if isinstance(path[-1], str):
  89. # TODO(sven): We need to fix this terrible dependency on `str.starts_with`
  90. # for determining, how to aggregate these stats! As "num_..." might
  91. # be a good indicator for summing, it will fail if the stats is e.g.
  92. # `num_samples_per_sec" :)
  93. # Counter stats: Reduce sum.
  94. # if path[-1].startswith("num_"):
  95. # return np.nansum(tower_data)
  96. # Min stats: Reduce min.
  97. if path[-1].startswith("min_"):
  98. return np.nanmin(tower_data)
  99. # Max stats: Reduce max.
  100. elif path[-1].startswith("max_"):
  101. return np.nanmax(tower_data)
  102. if np.isnan(tower_data).all():
  103. return np.nan
  104. # Everything else: Reduce mean.
  105. return np.nanmean(tower_data)