utils.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788
  1. import argparse
  2. import json
  3. import logging
  4. import os
  5. import re
  6. import time
  7. from typing import (
  8. TYPE_CHECKING,
  9. Any,
  10. Dict,
  11. List,
  12. Optional,
  13. Type,
  14. Union,
  15. )
  16. import numpy as np
  17. import ray
  18. from ray import tune
  19. from ray.air.integrations.wandb import WANDB_ENV_VAR, WandbLoggerCallback
  20. from ray.rllib.utils.metrics import (
  21. ENV_RUNNER_RESULTS,
  22. EPISODE_RETURN_MEAN,
  23. EVALUATION_RESULTS,
  24. NUM_ENV_STEPS_SAMPLED_LIFETIME,
  25. )
  26. from ray.rllib.utils.serialization import convert_numpy_to_python_primitives
  27. from ray.rllib.utils.typing import ResultDict
  28. from ray.tune import CLIReporter
  29. from ray.tune.result import TRAINING_ITERATION
  30. if TYPE_CHECKING:
  31. from ray.rllib.algorithms import AlgorithmConfig
  32. logger = logging.getLogger(__name__)
  33. def add_rllib_example_script_args(
  34. parser: Optional[argparse.ArgumentParser] = None,
  35. default_reward: float = 100.0,
  36. default_iters: int = 200,
  37. default_timesteps: int = 100000,
  38. ) -> argparse.ArgumentParser:
  39. """Adds RLlib-typical (and common) examples scripts command line args to a parser.
  40. TODO (sven): This function should be used by most of our examples scripts, which
  41. already mostly have this logic in them (but written out).
  42. Args:
  43. parser: The parser to add the arguments to. If None, create a new one.
  44. default_reward: The default value for the --stop-reward option.
  45. default_iters: The default value for the --stop-iters option.
  46. default_timesteps: The default value for the --stop-timesteps option.
  47. Returns:
  48. The altered (or newly created) parser object.
  49. """
  50. if parser is None:
  51. parser = argparse.ArgumentParser()
  52. # Algo and Algo config options.
  53. parser.add_argument(
  54. "--algo", type=str, default="PPO", help="The RLlib-registered algorithm to use."
  55. )
  56. parser.add_argument(
  57. "--framework",
  58. choices=["tf", "tf2", "torch"],
  59. default="torch",
  60. help="The DL framework specifier.",
  61. )
  62. parser.add_argument(
  63. "--env",
  64. type=str,
  65. default=None,
  66. help="The gym.Env identifier to run the experiment with.",
  67. )
  68. parser.add_argument(
  69. "--num-env-runners",
  70. type=int,
  71. default=None,
  72. help="The number of (remote) EnvRunners to use for the experiment.",
  73. )
  74. parser.add_argument(
  75. "--num-envs-per-env-runner",
  76. type=int,
  77. default=None,
  78. help="The number of (vectorized) environments per EnvRunner. Note that "
  79. "this is identical to the batch size for (inference) action computations.",
  80. )
  81. parser.add_argument(
  82. "--num-agents",
  83. type=int,
  84. default=0,
  85. help="If 0 (default), will run as single-agent. If > 0, will run as "
  86. "multi-agent with the environment simply cloned n times and each agent acting "
  87. "independently at every single timestep. The overall reward for this "
  88. "experiment is then the sum over all individual agents' rewards.",
  89. )
  90. # Evaluation options.
  91. parser.add_argument(
  92. "--evaluation-num-env-runners",
  93. type=int,
  94. default=0,
  95. help="The number of evaluation (remote) EnvRunners to use for the experiment.",
  96. )
  97. parser.add_argument(
  98. "--evaluation-interval",
  99. type=int,
  100. default=0,
  101. help="Every how many iterations to run one round of evaluation. "
  102. "Use 0 (default) to disable evaluation.",
  103. )
  104. parser.add_argument(
  105. "--evaluation-duration",
  106. type=lambda v: v if v == "auto" else int(v),
  107. default=10,
  108. help="The number of evaluation units to run each evaluation round. "
  109. "Use `--evaluation-duration-unit` to count either in 'episodes' "
  110. "or 'timesteps'. If 'auto', will run as many as possible during train pass ("
  111. "`--evaluation-parallel-to-training` must be set then).",
  112. )
  113. parser.add_argument(
  114. "--evaluation-duration-unit",
  115. type=str,
  116. default="episodes",
  117. choices=["episodes", "timesteps"],
  118. help="The evaluation duration unit to count by. One of 'episodes' or "
  119. "'timesteps'. This unit will be run `--evaluation-duration` times in each "
  120. "evaluation round. If `--evaluation-duration=auto`, this setting does not "
  121. "matter.",
  122. )
  123. parser.add_argument(
  124. "--evaluation-parallel-to-training",
  125. action="store_true",
  126. help="Whether to run evaluation parallel to training. This might help speed up "
  127. "your overall iteration time. Be aware that when using this option, your "
  128. "reported evaluation results are referring to one iteration before the current "
  129. "one.",
  130. )
  131. # RLlib logging options.
  132. parser.add_argument(
  133. "--output",
  134. type=str,
  135. default=None,
  136. help="The output directory to write trajectories to, which are collected by "
  137. "the algo's EnvRunners.",
  138. )
  139. parser.add_argument(
  140. "--log-level",
  141. type=str,
  142. default=None, # None -> use default
  143. choices=["INFO", "DEBUG", "WARN", "ERROR"],
  144. help="The log-level to be used by the RLlib logger.",
  145. )
  146. # tune.Tuner options.
  147. parser.add_argument(
  148. "--no-tune",
  149. action="store_true",
  150. help="Whether to NOT use tune.Tuner(), but rather a simple for-loop calling "
  151. "`algo.train()` repeatedly until one of the stop criteria is met.",
  152. )
  153. parser.add_argument(
  154. "--num-samples",
  155. type=int,
  156. default=1,
  157. help="How many (tune.Tuner.fit()) experiments to execute - if possible in "
  158. "parallel.",
  159. )
  160. parser.add_argument(
  161. "--max-concurrent-trials",
  162. type=int,
  163. default=None,
  164. help="How many (tune.Tuner) trials to run concurrently.",
  165. )
  166. parser.add_argument(
  167. "--verbose",
  168. type=int,
  169. default=2,
  170. help="The verbosity level for the `tune.Tuner()` running the experiment.",
  171. )
  172. parser.add_argument(
  173. "--checkpoint-freq",
  174. type=int,
  175. default=0,
  176. help=(
  177. "The frequency (in training iterations) with which to create checkpoints. "
  178. "Note that if --wandb-key is provided, all checkpoints will "
  179. "automatically be uploaded to WandB."
  180. ),
  181. )
  182. parser.add_argument(
  183. "--checkpoint-at-end",
  184. action="store_true",
  185. help=(
  186. "Whether to create a checkpoint at the very end of the experiment. "
  187. "Note that if --wandb-key is provided, all checkpoints will "
  188. "automatically be uploaded to WandB."
  189. ),
  190. )
  191. parser.add_argument(
  192. "--tune-max-report-freq",
  193. type=int,
  194. default=5, # tune default to 5
  195. help="The frequency (in seconds) at which to log the training performance.",
  196. )
  197. # WandB logging options.
  198. parser.add_argument(
  199. "--wandb-key",
  200. type=str,
  201. default=None,
  202. help="The WandB API key to use for uploading results.",
  203. )
  204. parser.add_argument(
  205. "--wandb-project",
  206. type=str,
  207. default=None,
  208. help="The WandB project name to use.",
  209. )
  210. parser.add_argument(
  211. "--wandb-run-name",
  212. type=str,
  213. default=None,
  214. help="The WandB run name to use.",
  215. )
  216. # Experiment stopping and testing criteria.
  217. parser.add_argument(
  218. "--stop-reward",
  219. type=float,
  220. default=default_reward,
  221. help="Reward at which the script should stop training.",
  222. )
  223. parser.add_argument(
  224. "--stop-iters",
  225. type=int,
  226. default=default_iters,
  227. help="The number of iterations to train.",
  228. )
  229. parser.add_argument(
  230. "--stop-timesteps",
  231. type=int,
  232. default=default_timesteps,
  233. help="The number of (environment sampling) timesteps to train.",
  234. )
  235. parser.add_argument(
  236. "--as-test",
  237. action="store_true",
  238. help="Whether this script should be run as a test. If set, --stop-reward must "
  239. "be achieved within --stop-timesteps AND --stop-iters, otherwise this "
  240. "script will throw an exception at the end.",
  241. )
  242. parser.add_argument(
  243. "--as-release-test",
  244. action="store_true",
  245. help="Whether this script should be run as a release test. If set, "
  246. "all that applies to the --as-test option is true, plus, a short JSON summary "
  247. "will be written into a results file whose location is given by the ENV "
  248. "variable `TEST_OUTPUT_JSON`.",
  249. )
  250. # Learner scaling options.
  251. parser.add_argument(
  252. "--num-learners",
  253. type=int,
  254. default=None,
  255. help="The number of Learners to use. If `None`, use the algorithm's default "
  256. "value.",
  257. )
  258. parser.add_argument(
  259. "--num-cpus-per-learner",
  260. type=float,
  261. default=None,
  262. help="The number of CPUs per Learner to use. If `None`, use the algorithm's "
  263. "default value.",
  264. )
  265. parser.add_argument(
  266. "--num-gpus-per-learner",
  267. type=float,
  268. default=None,
  269. help="The number of GPUs per Learner to use. If `None` and there are enough "
  270. "GPUs for all required Learners (--num-learners), use a value of 1, "
  271. "otherwise 0.",
  272. )
  273. parser.add_argument(
  274. "--num-aggregator-actors-per-learner",
  275. type=int,
  276. default=None,
  277. help="The number of Aggregator actors to use per Learner. If `None`, use the "
  278. "algorithm's default value.",
  279. )
  280. # Ray init options.
  281. parser.add_argument("--num-cpus", type=int, default=0)
  282. parser.add_argument(
  283. "--local-mode",
  284. action="store_true",
  285. help="Init Ray in local mode for easier debugging.",
  286. )
  287. # Old API stack: config.num_gpus.
  288. parser.add_argument(
  289. "--num-gpus",
  290. type=int,
  291. default=None,
  292. help="The number of GPUs to use (only on the old API stack).",
  293. )
  294. parser.add_argument(
  295. "--old-api-stack",
  296. action="store_true",
  297. help="Run this script on the old API stack of RLlib.",
  298. )
  299. # Deprecated options. Throws error when still used. Use `--old-api-stack` for
  300. # disabling the new API stack.
  301. parser.add_argument(
  302. "--enable-new-api-stack",
  303. action="store_true",
  304. )
  305. return parser
  306. # TODO (simon): Use this function in the `run_rllib_example_experiment` when
  307. # `no_tune` is `True`.
  308. def should_stop(
  309. stop: Dict[str, Any], results: ResultDict, keep_ray_up: bool = False
  310. ) -> bool:
  311. """Checks stopping criteria on `ResultDict`
  312. Args:
  313. stop: Dictionary of stopping criteria. Each criterium is a mapping of
  314. a metric in the `ResultDict` of the algorithm to a certain criterium.
  315. results: An RLlib `ResultDict` containing all results from a training step.
  316. keep_ray_up: Optionally shutting down the runnin Ray instance.
  317. Returns: True, if any stopping criterium is fulfilled. Otherwise, False.
  318. """
  319. for key, threshold in stop.items():
  320. val = results
  321. for k in key.split("/"):
  322. k = k.strip()
  323. # If k exists in the current level, continue down;
  324. # otherwise, set val to None and break out of this inner loop.
  325. if isinstance(val, dict) and k in val:
  326. val = val[k]
  327. else:
  328. val = None
  329. break
  330. # If the key was not found, simply skip to the next criterion.
  331. if val is None:
  332. continue
  333. try:
  334. # Check that val is numeric and meets the threshold.
  335. if not np.isnan(val) and val >= threshold:
  336. print(f"Stop criterion ({key}={threshold}) fulfilled!")
  337. if not keep_ray_up:
  338. ray.shutdown()
  339. return True
  340. except TypeError:
  341. # If val isn't numeric, skip this criterion.
  342. continue
  343. # If none of the criteria are fulfilled, return False.
  344. return False
  345. # TODO (sven): Make this the de-facto, well documented, and unified utility for most of
  346. # our tests:
  347. # - CI (label: "learning_tests")
  348. # - release tests (benchmarks)
  349. # - example scripts
  350. def run_rllib_example_script_experiment(
  351. base_config: "AlgorithmConfig",
  352. args: Optional[argparse.Namespace] = None,
  353. *,
  354. stop: Optional[Dict] = None,
  355. success_metric: Optional[Dict] = None,
  356. trainable: Optional[Type] = None,
  357. tune_callbacks: Optional[List] = None,
  358. keep_config: bool = False,
  359. keep_ray_up: bool = False,
  360. scheduler=None,
  361. progress_reporter=None,
  362. ) -> Union[ResultDict, tune.result_grid.ResultGrid]:
  363. """Given an algorithm config and some command line args, runs an experiment.
  364. There are some constraints on what properties must be defined in `args`.
  365. It should ideally be generated via calling
  366. `args = add_rllib_example_script_args()`, which can be found in this very module
  367. here.
  368. The function sets up an Algorithm object from the given config (altered by the
  369. contents of `args`), then runs the Algorithm via Tune (or manually, if
  370. `args.no_tune` is set to True) using the stopping criteria in `stop`.
  371. At the end of the experiment, if `args.as_test` is True, checks, whether the
  372. Algorithm reached the `success_metric` (if None, use `env_runners/
  373. episode_return_mean` with a minimum value of `args.stop_reward`).
  374. See https://github.com/ray-project/ray/tree/master/rllib/examples for an overview
  375. of all supported command line options.
  376. Args:
  377. base_config: The AlgorithmConfig object to use for this experiment. This base
  378. config will be automatically "extended" based on some of the provided
  379. `args`. For example, `args.num_env_runners` is used to set
  380. `config.num_env_runners`, etc..
  381. args: A argparse.Namespace object, ideally returned by calling
  382. `args = add_rllib_example_script_args()`. It must have the following
  383. properties defined: `stop_iters`, `stop_reward`, `stop_timesteps`,
  384. `no_tune`, `verbose`, `checkpoint_freq`, `as_test`. Optionally, for WandB
  385. logging: `wandb_key`, `wandb_project`, `wandb_run_name`.
  386. stop: An optional dict mapping ResultDict key strings (using "/" in case of
  387. nesting, e.g. "env_runners/episode_return_mean" for referring to
  388. `result_dict['env_runners']['episode_return_mean']` to minimum
  389. values, reaching of which will stop the experiment). Default is:
  390. {
  391. "env_runners/episode_return_mean": args.stop_reward,
  392. "training_iteration": args.stop_iters,
  393. "num_env_steps_sampled_lifetime": args.stop_timesteps,
  394. }
  395. success_metric: Only relevant if `args.as_test` is True.
  396. A dict mapping a single(!) ResultDict key string (using "/" in
  397. case of nesting, e.g. "env_runners/episode_return_mean" for referring
  398. to `result_dict['env_runners']['episode_return_mean']` to a single(!)
  399. minimum value to be reached in order for the experiment to count as
  400. successful. If `args.as_test` is True AND this `success_metric` is not
  401. reached with the bounds defined by `stop`, will raise an Exception.
  402. trainable: The Trainable sub-class to run in the tune.Tuner. If None (default),
  403. use the registered RLlib Algorithm class specified by args.algo.
  404. tune_callbacks: A list of Tune callbacks to configure with the tune.Tuner.
  405. In case `args.wandb_key` is provided, appends a WandB logger to this
  406. list.
  407. keep_config: Set this to True, if you don't want this utility to change the
  408. given `base_config` in any way and leave it as-is. This is helpful
  409. for those example scripts which demonstrate how to set config settings
  410. that are otherwise taken care of automatically in this function (e.g.
  411. `num_env_runners`).
  412. Returns:
  413. The last ResultDict from a --no-tune run OR the tune.Tuner.fit()
  414. results.
  415. """
  416. if args is None:
  417. parser = add_rllib_example_script_args()
  418. args = parser.parse_args()
  419. # Deprecated args.
  420. if args.enable_new_api_stack:
  421. raise ValueError(
  422. "`--enable-new-api-stack` flag no longer supported (it's the default "
  423. "behavior now)! To switch back to the old API stack on your scripts, use "
  424. "the `--old-api-stack` flag."
  425. )
  426. # If run --as-release-test, --as-test must also be set.
  427. if args.as_release_test:
  428. args.as_test = True
  429. if args.as_test:
  430. args.verbose = 1
  431. args.tune_max_report_freq = 30
  432. # Initialize Ray.
  433. ray.init(
  434. num_cpus=args.num_cpus or None,
  435. local_mode=args.local_mode,
  436. ignore_reinit_error=True,
  437. )
  438. # Define one or more stopping criteria.
  439. if stop is None:
  440. stop = {
  441. f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward,
  442. f"{ENV_RUNNER_RESULTS}/{NUM_ENV_STEPS_SAMPLED_LIFETIME}": (
  443. args.stop_timesteps
  444. ),
  445. TRAINING_ITERATION: args.stop_iters,
  446. }
  447. config = base_config
  448. # Enhance the `base_config`, based on provided `args`.
  449. if not keep_config:
  450. # Set the framework.
  451. config.framework(args.framework)
  452. # Add an env specifier (only if not already set in config)?
  453. if args.env is not None and config.env is None:
  454. config.environment(args.env)
  455. # Disable the new API stack?
  456. if args.old_api_stack:
  457. config.api_stack(
  458. enable_rl_module_and_learner=False,
  459. enable_env_runner_and_connector_v2=False,
  460. )
  461. # Define EnvRunner scaling and behavior.
  462. if args.num_env_runners is not None:
  463. config.env_runners(num_env_runners=args.num_env_runners)
  464. if args.num_envs_per_env_runner is not None:
  465. config.env_runners(num_envs_per_env_runner=args.num_envs_per_env_runner)
  466. # Define compute resources used automatically (only using the --num-learners
  467. # and --num-gpus-per-learner args).
  468. # New stack.
  469. if config.enable_rl_module_and_learner:
  470. if args.num_gpus is not None and args.num_gpus > 0:
  471. raise ValueError(
  472. "--num-gpus is not supported on the new API stack! To train on "
  473. "GPUs, use the command line options `--num-gpus-per-learner=1` and "
  474. "`--num-learners=[your number of available GPUs]`, instead."
  475. )
  476. # Do we have GPUs available in the cluster?
  477. num_gpus_available = ray.cluster_resources().get("GPU", 0)
  478. # Number of actual Learner instances (including the local Learner if
  479. # `num_learners=0`).
  480. num_actual_learners = (
  481. args.num_learners
  482. if args.num_learners is not None
  483. else config.num_learners
  484. ) or 1 # 1: There is always a local Learner, if num_learners=0.
  485. # How many were hard-requested by the user
  486. # (through explicit `--num-gpus-per-learner >= 1`).
  487. num_gpus_requested = (args.num_gpus_per_learner or 0) * num_actual_learners
  488. # Number of GPUs needed, if `num_gpus_per_learner=None` (auto).
  489. num_gpus_needed_if_available = (
  490. args.num_gpus_per_learner
  491. if args.num_gpus_per_learner is not None
  492. else 1
  493. ) * num_actual_learners
  494. # Define compute resources used.
  495. config.resources(num_gpus=0) # @OldAPIStack
  496. if args.num_learners is not None:
  497. config.learners(num_learners=args.num_learners)
  498. # User wants to use aggregator actors per Learner.
  499. if args.num_aggregator_actors_per_learner is not None:
  500. config.learners(
  501. num_aggregator_actors_per_learner=(
  502. args.num_aggregator_actors_per_learner
  503. )
  504. )
  505. # User wants to use GPUs if available, but doesn't hard-require them.
  506. if args.num_gpus_per_learner is None:
  507. if num_gpus_available >= num_gpus_needed_if_available:
  508. config.learners(num_gpus_per_learner=1)
  509. else:
  510. config.learners(num_gpus_per_learner=0)
  511. # User hard-requires n GPUs, but they are not available -> Error.
  512. elif num_gpus_available < num_gpus_requested:
  513. raise ValueError(
  514. "You are running your script with --num-learners="
  515. f"{args.num_learners} and --num-gpus-per-learner="
  516. f"{args.num_gpus_per_learner}, but your cluster only has "
  517. f"{num_gpus_available} GPUs!"
  518. )
  519. # All required GPUs are available -> Use them.
  520. else:
  521. config.learners(num_gpus_per_learner=args.num_gpus_per_learner)
  522. # Set CPUs per Learner.
  523. if args.num_cpus_per_learner is not None:
  524. config.learners(num_cpus_per_learner=args.num_cpus_per_learner)
  525. # Old stack (override only if arg was provided by user).
  526. elif args.num_gpus is not None:
  527. config.resources(num_gpus=args.num_gpus)
  528. # Evaluation setup.
  529. if args.evaluation_interval > 0:
  530. config.evaluation(
  531. evaluation_num_env_runners=args.evaluation_num_env_runners,
  532. evaluation_interval=args.evaluation_interval,
  533. evaluation_duration=args.evaluation_duration,
  534. evaluation_duration_unit=args.evaluation_duration_unit,
  535. evaluation_parallel_to_training=args.evaluation_parallel_to_training,
  536. )
  537. # Set the log-level (if applicable).
  538. if args.log_level is not None:
  539. config.debugging(log_level=args.log_level)
  540. # Set the output dir (if applicable).
  541. if args.output is not None:
  542. config.offline_data(output=args.output)
  543. # Run the experiment w/o Tune (directly operate on the RLlib Algorithm object).
  544. if args.no_tune:
  545. assert not args.as_test and not args.as_release_test
  546. algo = config.build()
  547. for i in range(stop.get(TRAINING_ITERATION, args.stop_iters)):
  548. results = algo.train()
  549. if ENV_RUNNER_RESULTS in results:
  550. mean_return = results[ENV_RUNNER_RESULTS].get(
  551. EPISODE_RETURN_MEAN, np.nan
  552. )
  553. print(f"iter={i} R={mean_return}", end="")
  554. if (
  555. EVALUATION_RESULTS in results
  556. and ENV_RUNNER_RESULTS in results[EVALUATION_RESULTS]
  557. ):
  558. Reval = results[EVALUATION_RESULTS][ENV_RUNNER_RESULTS][
  559. EPISODE_RETURN_MEAN
  560. ]
  561. print(f" R(eval)={Reval}", end="")
  562. print()
  563. for key, threshold in stop.items():
  564. val = results
  565. for k in key.split("/"):
  566. try:
  567. val = val[k]
  568. except KeyError:
  569. val = None
  570. break
  571. if val is not None and not np.isnan(val) and val >= threshold:
  572. print(f"Stop criterium ({key}={threshold}) fulfilled!")
  573. if not keep_ray_up:
  574. ray.shutdown()
  575. return results
  576. if not keep_ray_up:
  577. ray.shutdown()
  578. return results
  579. # Run the experiment using Ray Tune.
  580. # Log results using WandB.
  581. tune_callbacks = tune_callbacks or []
  582. if hasattr(args, "wandb_key") and (
  583. args.wandb_key is not None or WANDB_ENV_VAR in os.environ
  584. ):
  585. wandb_key = args.wandb_key or os.environ[WANDB_ENV_VAR]
  586. project = args.wandb_project or (
  587. args.algo.lower() + "-" + re.sub("\\W+", "-", str(config.env).lower())
  588. )
  589. tune_callbacks.append(
  590. WandbLoggerCallback(
  591. api_key=wandb_key,
  592. project=project,
  593. upload_checkpoints=True,
  594. **({"name": args.wandb_run_name} if args.wandb_run_name else {}),
  595. )
  596. )
  597. # Auto-configure a CLIReporter (to log the results to the console).
  598. # Use better ProgressReporter for multi-agent cases: List individual policy rewards.
  599. if progress_reporter is None:
  600. if args.num_agents == 0:
  601. progress_reporter = CLIReporter(
  602. metric_columns={
  603. TRAINING_ITERATION: "iter",
  604. "time_total_s": "total time (s)",
  605. NUM_ENV_STEPS_SAMPLED_LIFETIME: "ts",
  606. f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": "episode return mean",
  607. },
  608. max_report_frequency=args.tune_max_report_freq,
  609. )
  610. else:
  611. progress_reporter = CLIReporter(
  612. metric_columns={
  613. **{
  614. TRAINING_ITERATION: "iter",
  615. "time_total_s": "total time (s)",
  616. NUM_ENV_STEPS_SAMPLED_LIFETIME: "ts",
  617. f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": "combined return",
  618. },
  619. **{
  620. (
  621. f"{ENV_RUNNER_RESULTS}/module_episode_returns_mean/{pid}"
  622. ): f"return {pid}"
  623. for pid in config.policies
  624. },
  625. },
  626. max_report_frequency=args.tune_max_report_freq,
  627. )
  628. # Force Tuner to use old progress output as the new one silently ignores our custom
  629. # `CLIReporter`.
  630. os.environ["RAY_AIR_NEW_OUTPUT"] = "0"
  631. # Run the actual experiment (using Tune).
  632. start_time = time.time()
  633. results = tune.Tuner(
  634. trainable or config.algo_class,
  635. param_space=config,
  636. run_config=tune.RunConfig(
  637. stop=stop,
  638. verbose=args.verbose,
  639. callbacks=tune_callbacks,
  640. checkpoint_config=tune.CheckpointConfig(
  641. checkpoint_frequency=args.checkpoint_freq,
  642. checkpoint_at_end=args.checkpoint_at_end,
  643. ),
  644. progress_reporter=progress_reporter,
  645. ),
  646. tune_config=tune.TuneConfig(
  647. num_samples=args.num_samples,
  648. max_concurrent_trials=args.max_concurrent_trials,
  649. scheduler=scheduler,
  650. ),
  651. ).fit()
  652. time_taken = time.time() - start_time
  653. if not keep_ray_up:
  654. ray.shutdown()
  655. # Error out, if Tuner.fit() failed to run. Otherwise, erroneous examples might pass
  656. # the CI tests w/o us knowing that they are broken (b/c some examples do not have
  657. # a --as-test flag and/or any passing criteria).
  658. if results.errors:
  659. # Might cause an IndexError if the tuple is not long enough; in that case, use repr(e).
  660. errors = [
  661. e.args[0].args[2]
  662. if e.args and hasattr(e.args[0], "args") and len(e.args[0].args) > 2
  663. else repr(e)
  664. for e in results.errors
  665. ]
  666. raise RuntimeError(
  667. f"Running the example script resulted in one or more errors! {errors}"
  668. )
  669. # If run as a test, check whether we reached the specified success criteria.
  670. test_passed = False
  671. if args.as_test:
  672. # Success metric not provided, try extracting it from `stop`.
  673. if success_metric is None:
  674. for try_it in [
  675. f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}",
  676. f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}",
  677. ]:
  678. if try_it in stop:
  679. success_metric = {try_it: stop[try_it]}
  680. break
  681. if success_metric is None:
  682. success_metric = {
  683. f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward,
  684. }
  685. # TODO (sven): Make this work for more than one metric (AND-logic?).
  686. # Get maximum value of `metric` over all trials
  687. # (check if at least one trial achieved some learning, not just the final one).
  688. success_metric_key, success_metric_value = next(iter(success_metric.items()))
  689. best_value = max(
  690. row[success_metric_key] for _, row in results.get_dataframe().iterrows()
  691. )
  692. if best_value >= success_metric_value:
  693. test_passed = True
  694. print(f"`{success_metric_key}` of {success_metric_value} reached! ok")
  695. if args.as_release_test:
  696. trial = results._experiment_analysis.trials[0]
  697. stats = trial.last_result
  698. stats.pop("config", None)
  699. json_summary = {
  700. "time_taken": float(time_taken),
  701. "trial_states": [trial.status],
  702. "last_update": float(time.time()),
  703. "stats": convert_numpy_to_python_primitives(stats),
  704. "passed": [test_passed],
  705. "not_passed": [not test_passed],
  706. "failures": {str(trial): 1} if not test_passed else {},
  707. }
  708. filename = os.environ.get("TEST_OUTPUT_JSON", "/tmp/learning_test.json")
  709. with open(filename, "wt") as f:
  710. json.dump(json_summary, f)
  711. if not test_passed:
  712. if args.as_release_test:
  713. print(
  714. f"`{success_metric_key}` of {success_metric_value} not reached! Best value reached is {best_value}"
  715. )
  716. else:
  717. raise ValueError(
  718. f"`{success_metric_key}` of {success_metric_value} not reached! Best value reached is {best_value}"
  719. )
  720. return results