metrics.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. import collections
  2. import logging
  3. from typing import TYPE_CHECKING, List, Optional
  4. import numpy as np
  5. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
  6. from ray.rllib.utils.annotations import OldAPIStack
  7. from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
  8. from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict
  9. if TYPE_CHECKING:
  10. from ray.rllib.env.env_runner_group import EnvRunnerGroup
  11. logger = logging.getLogger(__name__)
  12. RolloutMetrics = OldAPIStack(
  13. collections.namedtuple(
  14. "RolloutMetrics",
  15. [
  16. "episode_length",
  17. "episode_reward",
  18. "agent_rewards",
  19. "custom_metrics",
  20. "perf_stats",
  21. "hist_data",
  22. "media",
  23. "episode_faulty",
  24. "connector_metrics",
  25. ],
  26. )
  27. )
  28. RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {}, {}, False, {})
  29. @OldAPIStack
  30. def get_learner_stats(grad_info: GradInfoDict) -> LearnerStatsDict:
  31. """Return optimization stats reported from the policy.
  32. .. testcode::
  33. :skipif: True
  34. grad_info = worker.learn_on_batch(samples)
  35. # {"td_error": [...], "learner_stats": {"vf_loss": ..., ...}}
  36. print(get_stats(grad_info))
  37. .. testoutput::
  38. {"vf_loss": ..., "policy_loss": ...}
  39. """
  40. if LEARNER_STATS_KEY in grad_info:
  41. return grad_info[LEARNER_STATS_KEY]
  42. multiagent_stats = {}
  43. for k, v in grad_info.items():
  44. if type(v) is dict:
  45. if LEARNER_STATS_KEY in v:
  46. multiagent_stats[k] = v[LEARNER_STATS_KEY]
  47. return multiagent_stats
  48. @OldAPIStack
  49. def collect_metrics(
  50. workers: "EnvRunnerGroup",
  51. remote_worker_ids: Optional[List[int]] = None,
  52. timeout_seconds: int = 180,
  53. keep_custom_metrics: bool = False,
  54. ) -> ResultDict:
  55. """Gathers episode metrics from rollout worker set.
  56. Args:
  57. workers: EnvRunnerGroup.
  58. remote_worker_ids: Optional list of IDs of remote workers to collect
  59. metrics from.
  60. timeout_seconds: Timeout in seconds for collecting metrics from remote workers.
  61. keep_custom_metrics: Whether to keep custom metrics in the result dict as
  62. they are (True) or to aggregate them (False).
  63. Returns:
  64. A result dict of metrics.
  65. """
  66. episodes = collect_episodes(
  67. workers, remote_worker_ids, timeout_seconds=timeout_seconds
  68. )
  69. metrics = summarize_episodes(
  70. episodes, episodes, keep_custom_metrics=keep_custom_metrics
  71. )
  72. return metrics
  73. @OldAPIStack
  74. def collect_episodes(
  75. workers: "EnvRunnerGroup",
  76. remote_worker_ids: Optional[List[int]] = None,
  77. timeout_seconds: int = 180,
  78. ) -> List[RolloutMetrics]:
  79. """Gathers new episodes metrics tuples from the given RolloutWorkers.
  80. Args:
  81. workers: EnvRunnerGroup.
  82. remote_worker_ids: Optional list of IDs of remote workers to collect
  83. metrics from.
  84. timeout_seconds: Timeout in seconds for collecting metrics from remote workers.
  85. Returns:
  86. List of RolloutMetrics.
  87. """
  88. # This will drop get_metrics() calls that are too slow.
  89. # We can potentially make this an asynchronous call if this turns
  90. # out to be a problem.
  91. metric_lists = workers.foreach_env_runner(
  92. lambda w: w.get_metrics(),
  93. local_env_runner=True,
  94. remote_worker_ids=remote_worker_ids,
  95. timeout_seconds=timeout_seconds,
  96. )
  97. if len(metric_lists) == 0:
  98. logger.warning("WARNING: collected no metrics.")
  99. episodes = []
  100. for metrics in metric_lists:
  101. episodes.extend(metrics)
  102. return episodes
  103. @OldAPIStack
  104. def summarize_episodes(
  105. episodes: List[RolloutMetrics],
  106. new_episodes: List[RolloutMetrics] = None,
  107. keep_custom_metrics: bool = False,
  108. ) -> ResultDict:
  109. """Summarizes a set of episode metrics tuples.
  110. Args:
  111. episodes: List of most recent n episodes. This may include historical ones
  112. (not newly collected in this iteration) in order to achieve the size of
  113. the smoothing window.
  114. new_episodes: All the episodes that were completed in this iteration.
  115. keep_custom_metrics: Whether to keep custom metrics in the result dict as
  116. they are (True) or to aggregate them (False).
  117. Returns:
  118. A result dict of metrics.
  119. """
  120. if new_episodes is None:
  121. new_episodes = episodes
  122. episode_rewards = []
  123. episode_lengths = []
  124. policy_rewards = collections.defaultdict(list)
  125. custom_metrics = collections.defaultdict(list)
  126. perf_stats = collections.defaultdict(list)
  127. hist_stats = collections.defaultdict(list)
  128. episode_media = collections.defaultdict(list)
  129. connector_metrics = collections.defaultdict(list)
  130. num_faulty_episodes = 0
  131. for episode in episodes:
  132. # Faulty episodes may still carry perf_stats data.
  133. for k, v in episode.perf_stats.items():
  134. perf_stats[k].append(v)
  135. # Continue if this is a faulty episode.
  136. # There should be other meaningful stats to be collected.
  137. if episode.episode_faulty:
  138. num_faulty_episodes += 1
  139. continue
  140. episode_lengths.append(episode.episode_length)
  141. episode_rewards.append(episode.episode_reward)
  142. for k, v in episode.custom_metrics.items():
  143. custom_metrics[k].append(v)
  144. is_multi_agent = (
  145. len(episode.agent_rewards) > 1
  146. or DEFAULT_POLICY_ID not in episode.agent_rewards
  147. )
  148. if is_multi_agent:
  149. for (_, policy_id), reward in episode.agent_rewards.items():
  150. policy_rewards[policy_id].append(reward)
  151. for k, v in episode.hist_data.items():
  152. hist_stats[k] += v
  153. for k, v in episode.media.items():
  154. episode_media[k].append(v)
  155. if hasattr(episode, "connector_metrics"):
  156. # Group connector metrics by connector_metric name for all policies
  157. for per_pipeline_metrics in episode.connector_metrics.values():
  158. for per_connector_metrics in per_pipeline_metrics.values():
  159. for connector_metric_name, val in per_connector_metrics.items():
  160. connector_metrics[connector_metric_name].append(val)
  161. if episode_rewards:
  162. min_reward = min(episode_rewards)
  163. max_reward = max(episode_rewards)
  164. avg_reward = np.mean(episode_rewards)
  165. else:
  166. min_reward = float("nan")
  167. max_reward = float("nan")
  168. avg_reward = float("nan")
  169. if episode_lengths:
  170. avg_length = np.mean(episode_lengths)
  171. else:
  172. avg_length = float("nan")
  173. # Show as histogram distributions.
  174. hist_stats["episode_reward"] = episode_rewards
  175. hist_stats["episode_lengths"] = episode_lengths
  176. policy_reward_min = {}
  177. policy_reward_mean = {}
  178. policy_reward_max = {}
  179. for policy_id, rewards in policy_rewards.copy().items():
  180. policy_reward_min[policy_id] = np.min(rewards)
  181. policy_reward_mean[policy_id] = np.mean(rewards)
  182. policy_reward_max[policy_id] = np.max(rewards)
  183. # Show as histogram distributions.
  184. hist_stats["policy_{}_reward".format(policy_id)] = rewards
  185. for k, v_list in custom_metrics.copy().items():
  186. filt = [v for v in v_list if not np.any(np.isnan(v))]
  187. if keep_custom_metrics:
  188. custom_metrics[k] = filt
  189. else:
  190. custom_metrics[k + "_mean"] = np.mean(filt)
  191. if filt:
  192. custom_metrics[k + "_min"] = np.min(filt)
  193. custom_metrics[k + "_max"] = np.max(filt)
  194. else:
  195. custom_metrics[k + "_min"] = float("nan")
  196. custom_metrics[k + "_max"] = float("nan")
  197. del custom_metrics[k]
  198. for k, v_list in perf_stats.copy().items():
  199. perf_stats[k] = np.mean(v_list)
  200. mean_connector_metrics = dict()
  201. for k, v_list in connector_metrics.items():
  202. mean_connector_metrics[k] = np.mean(v_list)
  203. return dict(
  204. episode_reward_max=max_reward,
  205. episode_reward_min=min_reward,
  206. episode_reward_mean=avg_reward,
  207. episode_len_mean=avg_length,
  208. episode_media=dict(episode_media),
  209. episodes_timesteps_total=sum(episode_lengths),
  210. policy_reward_min=policy_reward_min,
  211. policy_reward_max=policy_reward_max,
  212. policy_reward_mean=policy_reward_mean,
  213. custom_metrics=dict(custom_metrics),
  214. hist_stats=dict(hist_stats),
  215. sampler_perf=dict(perf_stats),
  216. num_faulty_episodes=num_faulty_episodes,
  217. connector_metrics=mean_connector_metrics,
  218. # Added these (duplicate) values here for forward compatibility with the new API
  219. # stack's metrics structure. This allows us to unify our test cases and keeping
  220. # the new API stack clean of backward-compatible keys.
  221. num_episodes=len(new_episodes),
  222. episode_return_max=max_reward,
  223. episode_return_min=min_reward,
  224. episode_return_mean=avg_reward,
  225. episodes_this_iter=len(new_episodes), # deprecate in favor of `num_epsodes_...`
  226. )