| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- from collections import defaultdict
- from typing import Dict
- import numpy as np
- import tree # pip install dm_tree
- from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
- from ray.rllib.utils.annotations import OldAPIStack
- from ray.rllib.utils.typing import PolicyID
- # Instant metrics (keys for metrics.info).
- LEARNER_INFO = "learner"
- # By convention, metrics from optimizing the loss can be reported in the
- # `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
- LEARNER_STATS_KEY = "learner_stats"
- @OldAPIStack
- class LearnerInfoBuilder:
- def __init__(self, num_devices: int = 1):
- self.num_devices = num_devices
- self.results_all_towers = defaultdict(list)
- self.is_finalized = False
- def add_learn_on_batch_results(
- self,
- results: Dict,
- policy_id: PolicyID = DEFAULT_POLICY_ID,
- ) -> None:
- """Adds a policy.learn_on_(loaded)?_batch() result to this builder.
- Args:
- results: The results returned by Policy.learn_on_batch or
- Policy.learn_on_loaded_batch.
- policy_id: The policy's ID, whose learn_on_(loaded)_batch method
- returned `results`.
- """
- assert (
- not self.is_finalized
- ), "LearnerInfo already finalized! Cannot add more results."
- # No towers: Single CPU.
- if "tower_0" not in results:
- self.results_all_towers[policy_id].append(results)
- # Multi-GPU case:
- else:
- self.results_all_towers[policy_id].append(
- tree.map_structure_with_path(
- lambda p, *s: _all_tower_reduce(p, *s),
- *(
- results.pop("tower_{}".format(tower_num))
- for tower_num in range(self.num_devices)
- )
- )
- )
- for k, v in results.items():
- if k == LEARNER_STATS_KEY:
- for k1, v1 in results[k].items():
- self.results_all_towers[policy_id][-1][LEARNER_STATS_KEY][
- k1
- ] = v1
- else:
- self.results_all_towers[policy_id][-1][k] = v
- def add_learn_on_batch_results_multi_agent(
- self,
- all_policies_results: Dict,
- ) -> None:
- """Adds multiple policy.learn_on_(loaded)?_batch() results to this builder.
- Args:
- all_policies_results: The results returned by all Policy.learn_on_batch or
- Policy.learn_on_loaded_batch wrapped as a dict mapping policy ID to
- results.
- """
- for pid, result in all_policies_results.items():
- if pid != "batch_count":
- self.add_learn_on_batch_results(result, policy_id=pid)
- def finalize(self):
- self.is_finalized = True
- info = {}
- for policy_id, results_all_towers in self.results_all_towers.items():
- # Reduce mean across all minibatch SGD steps (axis=0 to keep
- # all shapes as-is).
- info[policy_id] = tree.map_structure_with_path(
- _all_tower_reduce, *results_all_towers
- )
- return info
- @OldAPIStack
- def _all_tower_reduce(path, *tower_data):
- """Reduces stats across towers based on their stats-dict paths."""
- # TD-errors: Need to stay per batch item in order to be able to update
- # each item's weight in a prioritized replay buffer.
- if len(path) == 1 and path[0] == "td_error":
- return np.concatenate(tower_data, axis=0)
- elif tower_data[0] is None:
- return None
- if isinstance(path[-1], str):
- # TODO(sven): We need to fix this terrible dependency on `str.starts_with`
- # for determining, how to aggregate these stats! As "num_..." might
- # be a good indicator for summing, it will fail if the stats is e.g.
- # `num_samples_per_sec" :)
- # Counter stats: Reduce sum.
- # if path[-1].startswith("num_"):
- # return np.nansum(tower_data)
- # Min stats: Reduce min.
- if path[-1].startswith("min_"):
- return np.nanmin(tower_data)
- # Max stats: Reduce max.
- elif path[-1].startswith("max_"):
- return np.nanmax(tower_data)
- if np.isnan(tower_data).all():
- return np.nan
- # Everything else: Reduce mean.
- return np.nanmean(tower_data)
|