test_utils.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217
  1. import logging
  2. import os
  3. import pprint
  4. import random
  5. import time
  6. from typing import (
  7. TYPE_CHECKING,
  8. Any,
  9. Dict,
  10. Optional,
  11. Tuple,
  12. Type,
  13. )
  14. import gymnasium as gym
  15. import numpy as np
  16. import tree # pip install dm_tree
  17. from gymnasium.spaces import (
  18. Box,
  19. Dict as GymDict,
  20. Discrete,
  21. MultiBinary,
  22. MultiDiscrete,
  23. Tuple as GymTuple,
  24. )
  25. import ray
  26. from ray import tune
  27. from ray._common.deprecation import Deprecated
  28. from ray.rllib.core import DEFAULT_MODULE_ID, Columns
  29. from ray.rllib.env.wrappers.atari_wrappers import is_atari, wrap_deepmind
  30. from ray.rllib.utils.annotations import OldAPIStack
  31. from ray.rllib.utils.error import UnsupportedSpaceException
  32. from ray.rllib.utils.framework import try_import_jax, try_import_tf, try_import_torch
  33. from ray.rllib.utils.metrics import (
  34. DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
  35. ENV_RUNNER_RESULTS,
  36. EVALUATION_RESULTS,
  37. NUM_ENV_STEPS_TRAINED,
  38. )
  39. from ray.rllib.utils.typing import ResultDict
  40. from ray.tune.result import TRAINING_ITERATION
  41. if TYPE_CHECKING:
  42. from ray.rllib.algorithms import Algorithm, AlgorithmConfig
  43. from ray.rllib.offline.dataset_reader import DatasetReader
# Optional deep-learning framework handles, resolved once at module import time.
# The helpers below treat a `None` handle as "framework not available"
# (e.g. `check()` only touches `tf1` / `torch` after an `is not None` test).
jax, _ = try_import_jax()
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  48. @Deprecated(
  49. old="ray.rllib.utils.test_utils.add_rllib_example_script_args",
  50. new="ray.rllib.examples.utils.add_rllib_example_script_args",
  51. error=False,
  52. )
  53. def add_rllib_example_script_args(*args, **kwargs):
  54. from ray.rllib.examples.utils import add_rllib_example_script_args
  55. return add_rllib_example_script_args(*args, **kwargs)
  56. @Deprecated(
  57. old="ray.rllib.utils.test_utils.should_stop",
  58. new="ray.rllib.examples.utils.should_stop",
  59. error=False,
  60. )
  61. def should_stop(*args, **kwargs):
  62. from ray.rllib.examples.utils import should_stop
  63. return should_stop(*args, **kwargs)
  64. @Deprecated(
  65. old="ray.rllib.utils.test_utils.run_rllib_example_script_experiment",
  66. new="ray.rllib.examples.utils.run_rllib_example_script_experiment",
  67. error=False,
  68. )
  69. def run_rllib_example_script_experiment(*args, **kwargs):
  70. from ray.rllib.examples.utils import run_rllib_example_script_experiment
  71. return run_rllib_example_script_experiment(*args, **kwargs)
  72. def check(x, y, decimals=5, atol=None, rtol=None, false=False):
  73. """
  74. Checks two structures (dict, tuple, list,
  75. np.array, float, int, etc..) for (almost) numeric identity.
  76. All numbers in the two structures have to match up to `decimal` digits
  77. after the floating point. Uses assertions.
  78. Args:
  79. x: The value to be compared (to the expectation: `y`). This
  80. may be a Tensor.
  81. y: The expected value to be compared to `x`. This must not
  82. be a tf-Tensor, but may be a tf/torch-Tensor.
  83. decimals: The number of digits after the floating point up to
  84. which all numeric values have to match.
  85. atol: Absolute tolerance of the difference between x and y
  86. (overrides `decimals` if given).
  87. rtol: Relative tolerance of the difference between x and y
  88. (overrides `decimals` if given).
  89. false: Whether to check that x and y are NOT the same.
  90. """
  91. # A dict type.
  92. if isinstance(x, dict):
  93. assert isinstance(y, dict), "ERROR: If x is dict, y needs to be a dict as well!"
  94. y_keys = set(x.keys())
  95. for key, value in x.items():
  96. assert key in y, f"ERROR: y does not have x's key='{key}'! y={y}"
  97. check(value, y[key], decimals=decimals, atol=atol, rtol=rtol, false=false)
  98. y_keys.remove(key)
  99. assert not y_keys, "ERROR: y contains keys ({}) that are not in x! y={}".format(
  100. list(y_keys), y
  101. )
  102. # A tuple type.
  103. elif isinstance(x, (tuple, list)):
  104. assert isinstance(
  105. y, (tuple, list)
  106. ), "ERROR: If x is tuple/list, y needs to be a tuple/list as well!"
  107. assert len(y) == len(
  108. x
  109. ), "ERROR: y does not have the same length as x ({} vs {})!".format(
  110. len(y), len(x)
  111. )
  112. for i, value in enumerate(x):
  113. check(value, y[i], decimals=decimals, atol=atol, rtol=rtol, false=false)
  114. # Boolean comparison.
  115. elif isinstance(x, (np.bool_, bool)):
  116. if false is True:
  117. assert bool(x) is not bool(y), f"ERROR: x ({x}) is y ({y})!"
  118. else:
  119. assert bool(x) is bool(y), f"ERROR: x ({x}) is not y ({y})!"
  120. # Nones or primitives (excluding int vs float, which should be compared with
  121. # tolerance/decimals as well).
  122. elif (
  123. x is None
  124. or y is None
  125. or isinstance(x, str)
  126. or (isinstance(x, int) and isinstance(y, int))
  127. ):
  128. if false is True:
  129. assert x != y, f"ERROR: x ({x}) is the same as y ({y})!"
  130. else:
  131. assert x == y, f"ERROR: x ({x}) is not the same as y ({y})!"
  132. # String/byte comparisons.
  133. elif (
  134. hasattr(x, "dtype") and (x.dtype == object or str(x.dtype).startswith("<U"))
  135. ) or isinstance(x, bytes):
  136. try:
  137. np.testing.assert_array_equal(x, y)
  138. if false is True:
  139. assert False, f"ERROR: x ({x}) is the same as y ({y})!"
  140. except AssertionError as e:
  141. if false is False:
  142. raise e
  143. # Everything else (assume numeric or tf/torch.Tensor).
  144. # Also includes int vs float comparison, which is performed with tolerance/decimals.
  145. else:
  146. if tf1 is not None:
  147. # y should never be a Tensor (y=expected value).
  148. if isinstance(y, (tf1.Tensor, tf1.Variable)):
  149. # In eager mode, numpyize tensors.
  150. if tf.executing_eagerly():
  151. y = y.numpy()
  152. else:
  153. raise ValueError(
  154. "`y` (expected value) must not be a Tensor. "
  155. "Use numpy.ndarray instead"
  156. )
  157. if isinstance(x, (tf1.Tensor, tf1.Variable)):
  158. # In eager mode, numpyize tensors.
  159. if tf1.executing_eagerly():
  160. x = x.numpy()
  161. # Otherwise, use a new tf-session.
  162. else:
  163. with tf1.Session() as sess:
  164. x = sess.run(x)
  165. return check(
  166. x, y, decimals=decimals, atol=atol, rtol=rtol, false=false
  167. )
  168. if torch is not None:
  169. if isinstance(x, torch.Tensor):
  170. x = x.detach().cpu().numpy()
  171. if isinstance(y, torch.Tensor):
  172. y = y.detach().cpu().numpy()
  173. # Stats objects.
  174. from ray.rllib.utils.metrics.stats import StatsBase
  175. if isinstance(x, StatsBase):
  176. x = x.peek()
  177. if isinstance(y, StatsBase):
  178. y = y.peek()
  179. # Using decimals.
  180. if atol is None and rtol is None:
  181. # Assert equality of both values.
  182. try:
  183. np.testing.assert_almost_equal(x, y, decimal=decimals)
  184. # Both values are not equal.
  185. except AssertionError as e:
  186. # Raise error in normal case.
  187. if false is False:
  188. raise e
  189. # Both values are equal.
  190. else:
  191. # If false is set -> raise error (not expected to be equal).
  192. if false is True:
  193. assert False, f"ERROR: x ({x}) is the same as y ({y})!"
  194. # Using atol/rtol.
  195. else:
  196. # Provide defaults for either one of atol/rtol.
  197. if atol is None:
  198. atol = 0
  199. if rtol is None:
  200. rtol = 1e-7
  201. try:
  202. np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
  203. except AssertionError as e:
  204. if false is False:
  205. raise e
  206. else:
  207. if false is True:
  208. assert False, f"ERROR: x ({x}) is the same as y ({y})!"
  209. def check_compute_single_action(
  210. algorithm, include_state=False, include_prev_action_reward=False
  211. ):
  212. """Tests different combinations of args for algorithm.compute_single_action.
  213. Args:
  214. algorithm: The Algorithm object to test.
  215. include_state: Whether to include the initial state of the Policy's
  216. Model in the `compute_single_action` call.
  217. include_prev_action_reward: Whether to include the prev-action and
  218. -reward in the `compute_single_action` call.
  219. Raises:
  220. ValueError: If anything unexpected happens.
  221. """
  222. # Have to import this here to avoid circular dependency.
  223. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
  224. # Some Algorithms may not abide to the standard API.
  225. pid = DEFAULT_POLICY_ID
  226. try:
  227. # Multi-agent: Pick any learnable policy (or DEFAULT_POLICY if it's the only
  228. # one).
  229. pid = next(iter(algorithm.env_runner.get_policies_to_train()))
  230. pol = algorithm.get_policy(pid)
  231. except AttributeError:
  232. pol = algorithm.policy
  233. # Get the policy's model.
  234. model = pol.model
  235. action_space = pol.action_space
  236. def _test(
  237. what, method_to_test, obs_space, full_fetch, explore, timestep, unsquash, clip
  238. ):
  239. call_kwargs = {}
  240. if what is algorithm:
  241. call_kwargs["full_fetch"] = full_fetch
  242. call_kwargs["policy_id"] = pid
  243. obs = obs_space.sample()
  244. if isinstance(obs_space, Box):
  245. obs = np.clip(obs, -1.0, 1.0)
  246. state_in = None
  247. if include_state:
  248. state_in = model.get_initial_state()
  249. if not state_in:
  250. state_in = []
  251. i = 0
  252. while f"state_in_{i}" in model.view_requirements:
  253. state_in.append(
  254. model.view_requirements[f"state_in_{i}"].space.sample()
  255. )
  256. i += 1
  257. action_in = action_space.sample() if include_prev_action_reward else None
  258. reward_in = 1.0 if include_prev_action_reward else None
  259. if method_to_test == "input_dict":
  260. assert what is pol
  261. input_dict = {SampleBatch.OBS: obs}
  262. if include_prev_action_reward:
  263. input_dict[SampleBatch.PREV_ACTIONS] = action_in
  264. input_dict[SampleBatch.PREV_REWARDS] = reward_in
  265. if state_in:
  266. if what.config.get("enable_rl_module_and_learner", False):
  267. input_dict["state_in"] = state_in
  268. else:
  269. for i, s in enumerate(state_in):
  270. input_dict[f"state_in_{i}"] = s
  271. input_dict_batched = SampleBatch(
  272. tree.map_structure(lambda s: np.expand_dims(s, 0), input_dict)
  273. )
  274. action = pol.compute_actions_from_input_dict(
  275. input_dict=input_dict_batched,
  276. explore=explore,
  277. timestep=timestep,
  278. **call_kwargs,
  279. )
  280. # Unbatch everything to be able to compare against single
  281. # action below.
  282. # ARS and ES return action batches as lists.
  283. if isinstance(action[0], list):
  284. action = (np.array(action[0]), action[1], action[2])
  285. action = tree.map_structure(lambda s: s[0], action)
  286. try:
  287. action2 = pol.compute_single_action(
  288. input_dict=input_dict,
  289. explore=explore,
  290. timestep=timestep,
  291. **call_kwargs,
  292. )
  293. # Make sure these are the same, unless we have exploration
  294. # switched on (or noisy layers).
  295. if not explore and not pol.config.get("noisy"):
  296. check(action, action2)
  297. except TypeError:
  298. pass
  299. else:
  300. action = what.compute_single_action(
  301. obs,
  302. state_in,
  303. prev_action=action_in,
  304. prev_reward=reward_in,
  305. explore=explore,
  306. timestep=timestep,
  307. unsquash_action=unsquash,
  308. clip_action=clip,
  309. **call_kwargs,
  310. )
  311. state_out = None
  312. if state_in or full_fetch or what is pol:
  313. action, state_out, _ = action
  314. if state_out:
  315. for si, so in zip(tree.flatten(state_in), tree.flatten(state_out)):
  316. if tf.is_tensor(si):
  317. # If si is a tensor of Dimensions, we need to convert it
  318. # We expect this to be the case for TF RLModules who's initial
  319. # states are Tf Tensors.
  320. si_shape = si.shape.as_list()
  321. else:
  322. si_shape = list(si.shape)
  323. check(si_shape, so.shape)
  324. if unsquash is None:
  325. unsquash = what.config["normalize_actions"]
  326. if clip is None:
  327. clip = what.config["clip_actions"]
  328. # Test whether unsquash/clipping works on the Algorithm's
  329. # compute_single_action method: Both flags should force the action
  330. # to be within the space's bounds.
  331. if method_to_test == "single" and what == algorithm:
  332. if not action_space.contains(action) and (
  333. clip or unsquash or not isinstance(action_space, Box)
  334. ):
  335. raise ValueError(
  336. f"Returned action ({action}) of algorithm/policy {what} "
  337. f"not in Env's action_space {action_space}"
  338. )
  339. # We are operating in normalized space: Expect only smaller action
  340. # values.
  341. if (
  342. isinstance(action_space, Box)
  343. and not unsquash
  344. and what.config.get("normalize_actions")
  345. and np.any(np.abs(action) > 15.0)
  346. ):
  347. raise ValueError(
  348. f"Returned action ({action}) of algorithm/policy {what} "
  349. "should be in normalized space, but seems too large/small "
  350. "for that!"
  351. )
  352. # Loop through: Policy vs Algorithm; Different API methods to calculate
  353. # actions; unsquash option; clip option; full fetch or not.
  354. for what in [pol, algorithm]:
  355. if what is algorithm:
  356. # Get the obs-space from Workers.env (not Policy) due to possible
  357. # pre-processor up front.
  358. worker_set = getattr(algorithm, "env_runner_group", None)
  359. assert worker_set
  360. if not worker_set.local_env_runner:
  361. obs_space = algorithm.get_policy(pid).observation_space
  362. else:
  363. obs_space = worker_set.local_env_runner.for_policy(
  364. lambda p: p.observation_space, policy_id=pid
  365. )
  366. obs_space = getattr(obs_space, "original_space", obs_space)
  367. else:
  368. obs_space = pol.observation_space
  369. for method_to_test in ["single"] + (["input_dict"] if what is pol else []):
  370. for explore in [True, False]:
  371. for full_fetch in [False, True] if what is algorithm else [False]:
  372. timestep = random.randint(0, 100000)
  373. for unsquash in [True, False, None]:
  374. for clip in [False] if unsquash else [True, False, None]:
  375. print("-" * 80)
  376. print(f"what={what}")
  377. print(f"method_to_test={method_to_test}")
  378. print(f"explore={explore}")
  379. print(f"full_fetch={full_fetch}")
  380. print(f"unsquash={unsquash}")
  381. print(f"clip={clip}")
  382. _test(
  383. what,
  384. method_to_test,
  385. obs_space,
  386. full_fetch,
  387. explore,
  388. timestep,
  389. unsquash,
  390. clip,
  391. )
  392. def check_inference_w_connectors(policy, env_name, max_steps: int = 100):
  393. """Checks whether the given policy can infer actions from an env with connectors.
  394. Args:
  395. policy: The policy to check.
  396. env_name: Name of the environment to check
  397. max_steps: The maximum number of steps to run the environment for.
  398. Raises:
  399. ValueError: If the policy cannot infer actions from the environment.
  400. """
  401. # Avoids circular import
  402. from ray.rllib.utils.policy import local_policy_inference
  403. env = gym.make(env_name)
  404. # Potentially wrap the env like we do in RolloutWorker
  405. if is_atari(env):
  406. env = wrap_deepmind(
  407. env,
  408. dim=policy.config["model"]["dim"],
  409. framestack=policy.config["model"].get("framestack"),
  410. )
  411. obs, info = env.reset()
  412. reward, terminated, truncated = 0.0, False, False
  413. ts = 0
  414. while not terminated and not truncated and ts < max_steps:
  415. action_out = local_policy_inference(
  416. policy,
  417. env_id=0,
  418. agent_id=0,
  419. obs=obs,
  420. reward=reward,
  421. terminated=terminated,
  422. truncated=truncated,
  423. info=info,
  424. )
  425. obs, reward, terminated, truncated, info = env.step(action_out[0][0])
  426. ts += 1
  427. def check_learning_achieved(
  428. tune_results: "tune.ResultGrid",
  429. min_value: float,
  430. evaluation: Optional[bool] = None,
  431. metric: str = f"{ENV_RUNNER_RESULTS}/episode_return_mean",
  432. ):
  433. """Throws an error if `min_reward` is not reached within tune_results.
  434. Checks the last iteration found in tune_results for its
  435. "episode_return_mean" value and compares it to `min_reward`.
  436. Args:
  437. tune_results: The tune.Tuner().fit() returned results object.
  438. min_reward: The min reward that must be reached.
  439. evaluation: If True, use `evaluation/env_runners/[metric]`, if False, use
  440. `env_runners/[metric]`, if None, use evaluation sampler results if
  441. available otherwise, use train sampler results.
  442. Raises:
  443. ValueError: If `min_reward` not reached.
  444. """
  445. # Get maximum value of `metrics` over all trials
  446. # (check if at least one trial achieved some learning, not just the final one).
  447. recorded_values = []
  448. for _, row in tune_results.get_dataframe().iterrows():
  449. if evaluation or (
  450. evaluation is None and f"{EVALUATION_RESULTS}/{metric}" in row
  451. ):
  452. recorded_values.append(row[f"{EVALUATION_RESULTS}/{metric}"])
  453. else:
  454. recorded_values.append(row[metric])
  455. best_value = max(recorded_values)
  456. if best_value < min_value:
  457. raise ValueError(f"`{metric}` of {min_value} not reached!")
  458. print(f"`{metric}` of {min_value} reached! ok")
  459. def check_off_policyness(
  460. results: ResultDict,
  461. upper_limit: float,
  462. lower_limit: float = 0.0,
  463. ) -> Optional[float]:
  464. """Verifies that the off-policy'ness of some update is within some range.
  465. Off-policy'ness is defined as the average (across n workers) diff
  466. between the number of gradient updates performed on the policy used
  467. for sampling vs the number of gradient updates that have been performed
  468. on the trained policy (usually the one on the local worker).
  469. Uses the published DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY metric inside
  470. a training results dict and compares to the given bounds.
  471. Note: Only works with single-agent results thus far.
  472. Args:
  473. results: The training results dict.
  474. upper_limit: The upper limit to for the off_policy_ness value.
  475. lower_limit: The lower limit to for the off_policy_ness value.
  476. Returns:
  477. The off-policy'ness value (described above).
  478. Raises:
  479. AssertionError: If the value is out of bounds.
  480. """
  481. # Have to import this here to avoid circular dependency.
  482. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
  483. from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
  484. # Assert that the off-policy'ness is within the given bounds.
  485. learner_info = results["info"][LEARNER_INFO]
  486. if DEFAULT_POLICY_ID not in learner_info:
  487. return None
  488. off_policy_ness = learner_info[DEFAULT_POLICY_ID][
  489. DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY
  490. ]
  491. # Roughly: Reaches up to 0.4 for 2 rollout workers and up to 0.2 for
  492. # 1 rollout worker.
  493. if not (lower_limit <= off_policy_ness <= upper_limit):
  494. raise AssertionError(
  495. f"`off_policy_ness` ({off_policy_ness}) is outside the given bounds "
  496. f"({lower_limit} - {upper_limit})!"
  497. )
  498. return off_policy_ness
  499. def check_train_results_new_api_stack(train_results: ResultDict) -> None:
  500. """Checks proper structure of a Algorithm.train() returned dict.
  501. Args:
  502. train_results: The train results dict to check.
  503. Raises:
  504. AssertionError: If `train_results` doesn't have the proper structure or
  505. data in it.
  506. """
  507. # Import these here to avoid circular dependencies.
  508. from ray.rllib.utils.metrics import (
  509. ENV_RUNNER_RESULTS,
  510. FAULT_TOLERANCE_STATS,
  511. LEARNER_RESULTS,
  512. TIMERS,
  513. )
  514. # Assert that some keys are where we would expect them.
  515. for key in [
  516. ENV_RUNNER_RESULTS,
  517. FAULT_TOLERANCE_STATS,
  518. LEARNER_RESULTS,
  519. TIMERS,
  520. TRAINING_ITERATION,
  521. "config",
  522. ]:
  523. assert (
  524. key in train_results
  525. ), f"'{key}' not found in `train_results` ({train_results})!"
  526. # Make sure, `config` is an actual dict, not an AlgorithmConfig object.
  527. assert isinstance(
  528. train_results["config"], dict
  529. ), "`config` in results not a python dict!"
  530. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  531. is_multi_agent = (
  532. AlgorithmConfig()
  533. .update_from_dict({"policies": train_results["config"]["policies"]})
  534. .is_multi_agent
  535. )
  536. # Check in particular the "info" dict.
  537. learner_results = train_results[LEARNER_RESULTS]
  538. # Make sure we have a `DEFAULT_MODULE_ID key if we are not in a
  539. # multi-agent setup.
  540. if not is_multi_agent:
  541. assert len(learner_results) == 0 or DEFAULT_MODULE_ID in learner_results, (
  542. f"'{DEFAULT_MODULE_ID}' not found in "
  543. f"train_results['{LEARNER_RESULTS}']!"
  544. )
  545. for module_id, module_metrics in learner_results.items():
  546. # The ModuleID can be __all_modules__ in multi-agent case when the new learner
  547. # stack is enabled.
  548. if module_id == "__all_modules__":
  549. continue
  550. # On the new API stack, policy has no LEARNER_STATS_KEY under it anymore.
  551. for key, value in module_metrics.items():
  552. # Min- and max-stats should be single values.
  553. if key.endswith("_min") or key.endswith("_max"):
  554. assert np.isscalar(value), f"'key' value not a scalar ({value})!"
  555. return train_results
  556. @OldAPIStack
  557. def check_train_results(train_results: ResultDict):
  558. """Checks proper structure of a Algorithm.train() returned dict.
  559. Args:
  560. train_results: The train results dict to check.
  561. Raises:
  562. AssertionError: If `train_results` doesn't have the proper structure or
  563. data in it.
  564. """
  565. # Import these here to avoid circular dependencies.
  566. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
  567. from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
  568. # Assert that some keys are where we would expect them.
  569. for key in [
  570. "config",
  571. "custom_metrics",
  572. ENV_RUNNER_RESULTS,
  573. "info",
  574. "iterations_since_restore",
  575. "num_healthy_workers",
  576. "perf",
  577. "time_since_restore",
  578. "time_this_iter_s",
  579. "timers",
  580. "time_total_s",
  581. TRAINING_ITERATION,
  582. ]:
  583. assert (
  584. key in train_results
  585. ), f"'{key}' not found in `train_results` ({train_results})!"
  586. for key in [
  587. "episode_len_mean",
  588. "episode_reward_max",
  589. "episode_reward_mean",
  590. "episode_reward_min",
  591. "hist_stats",
  592. "policy_reward_max",
  593. "policy_reward_mean",
  594. "policy_reward_min",
  595. "sampler_perf",
  596. ]:
  597. assert key in train_results[ENV_RUNNER_RESULTS], (
  598. f"'{key}' not found in `train_results[ENV_RUNNER_RESULTS]` "
  599. f"({train_results[ENV_RUNNER_RESULTS]})!"
  600. )
  601. # Make sure, `config` is an actual dict, not an AlgorithmConfig object.
  602. assert isinstance(
  603. train_results["config"], dict
  604. ), "`config` in results not a python dict!"
  605. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  606. is_multi_agent = (
  607. AlgorithmConfig()
  608. .update_from_dict({"policies": train_results["config"]["policies"]})
  609. .is_multi_agent
  610. )
  611. # Check in particular the "info" dict.
  612. info = train_results["info"]
  613. assert LEARNER_INFO in info, f"'learner' not in train_results['infos'] ({info})!"
  614. assert (
  615. "num_steps_trained" in info or NUM_ENV_STEPS_TRAINED in info
  616. ), f"'num_(env_)?steps_trained' not in train_results['infos'] ({info})!"
  617. learner_info = info[LEARNER_INFO]
  618. # Make sure we have a default_policy key if we are not in a
  619. # multi-agent setup.
  620. if not is_multi_agent:
  621. # APEX algos sometimes have an empty learner info dict (no metrics
  622. # collected yet).
  623. assert len(learner_info) == 0 or DEFAULT_POLICY_ID in learner_info, (
  624. f"'{DEFAULT_POLICY_ID}' not found in "
  625. f"train_results['infos']['learner'] ({learner_info})!"
  626. )
  627. for pid, policy_stats in learner_info.items():
  628. if pid == "batch_count":
  629. continue
  630. # the pid can be __all__ in multi-agent case when the new learner stack is
  631. # enabled.
  632. if pid == "__all__":
  633. continue
  634. # On the new API stack, policy has no LEARNER_STATS_KEY under it anymore.
  635. if LEARNER_STATS_KEY in policy_stats:
  636. learner_stats = policy_stats[LEARNER_STATS_KEY]
  637. else:
  638. learner_stats = policy_stats
  639. for key, value in learner_stats.items():
  640. # Min- and max-stats should be single values.
  641. if key.startswith("min_") or key.startswith("max_"):
  642. assert np.isscalar(value), f"'key' value not a scalar ({value})!"
  643. return train_results
  644. def check_same_batch(batch1, batch2) -> None:
  645. """Check if both batches are (almost) identical.
  646. For MultiAgentBatches, the step count and individual policy's
  647. SampleBatches are checked for identity. For SampleBatches, identity is
  648. checked as the almost numerical key-value-pair identity between batches
  649. with ray.rllib.utils.test_utils.check(). unroll_id is compared only if
  650. both batches have an unroll_id.
  651. Args:
  652. batch1: Batch to compare against batch2
  653. batch2: Batch to compare against batch1
  654. """
  655. # Avoids circular import
  656. from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
  657. assert type(batch1) is type(
  658. batch2
  659. ), "Input batches are of different types {} and {}".format(
  660. str(type(batch1)), str(type(batch2))
  661. )
  662. def check_sample_batches(_batch1, _batch2, _policy_id=None):
  663. unroll_id_1 = _batch1.get("unroll_id", None)
  664. unroll_id_2 = _batch2.get("unroll_id", None)
  665. # unroll IDs only have to fit if both batches have them
  666. if unroll_id_1 is not None and unroll_id_2 is not None:
  667. assert unroll_id_1 == unroll_id_2
  668. batch1_keys = set()
  669. for k, v in _batch1.items():
  670. # unroll_id is compared above already
  671. if k == "unroll_id":
  672. continue
  673. check(v, _batch2[k])
  674. batch1_keys.add(k)
  675. batch2_keys = set(_batch2.keys())
  676. # unroll_id is compared above already
  677. batch2_keys.discard("unroll_id")
  678. _difference = batch1_keys.symmetric_difference(batch2_keys)
  679. # Cases where one batch has info and the other has not
  680. if _policy_id:
  681. assert not _difference, (
  682. "SampleBatches for policy with ID {} "
  683. "don't share information on the "
  684. "following information: \n{}"
  685. "".format(_policy_id, _difference)
  686. )
  687. else:
  688. assert not _difference, (
  689. "SampleBatches don't share information "
  690. "on the following information: \n{}"
  691. "".format(_difference)
  692. )
  693. if type(batch1) is SampleBatch:
  694. check_sample_batches(batch1, batch2)
  695. elif type(batch1) is MultiAgentBatch:
  696. assert batch1.count == batch2.count
  697. batch1_ids = set()
  698. for policy_id, policy_batch in batch1.policy_batches.items():
  699. check_sample_batches(
  700. policy_batch, batch2.policy_batches[policy_id], policy_id
  701. )
  702. batch1_ids.add(policy_id)
  703. # Case where one ma batch has info on a policy the other has not
  704. batch2_ids = set(batch2.policy_batches.keys())
  705. difference = batch1_ids.symmetric_difference(batch2_ids)
  706. assert (
  707. not difference
  708. ), f"MultiAgentBatches don't share the following information: \n{difference}."
  709. else:
  710. raise ValueError("Unsupported batch type " + str(type(batch1)))
  711. def check_reproducibilty(
  712. algo_class: Type["Algorithm"],
  713. algo_config: "AlgorithmConfig",
  714. *,
  715. fw_kwargs: Dict[str, Any],
  716. training_iteration: int = 1,
  717. ) -> None:
  718. # TODO @kourosh: we can get rid of examples/deterministic_training.py once
  719. # this is added to all algorithms
  720. """Check if the algorithm is reproducible across different testing conditions:
  721. frameworks: all input frameworks
  722. num_gpus: int(os.environ.get("RLLIB_NUM_GPUS", "0"))
  723. num_workers: 0 (only local workers) or
  724. 4 ((1) local workers + (4) remote workers)
  725. num_envs_per_env_runner: 2
  726. Args:
  727. algo_class: Algorithm class to test.
  728. algo_config: Base config to use for the algorithm.
  729. fw_kwargs: Framework iterator keyword arguments.
  730. training_iteration: Number of training iterations to run.
  731. Returns:
  732. None
  733. Raises:
  734. It raises an AssertionError if the algorithm is not reproducible.
  735. """
  736. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
  737. from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
  738. stop_dict = {TRAINING_ITERATION: training_iteration}
  739. # use 0 and 2 workers (for more that 4 workers we have to make sure the instance
  740. # type in ci build has enough resources)
  741. for num_workers in [0, 2]:
  742. algo_config = (
  743. algo_config.debugging(seed=42).env_runners(
  744. num_env_runners=num_workers, num_envs_per_env_runner=2
  745. )
  746. # new API
  747. .learners(
  748. num_gpus_per_learner=int(os.environ.get("RLLIB_NUM_GPUS", "0")),
  749. )
  750. # old API
  751. .resources(
  752. num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")),
  753. )
  754. )
  755. print(
  756. f"Testing reproducibility of {algo_class.__name__}"
  757. f" with {num_workers} workers"
  758. )
  759. print("/// config")
  760. pprint.pprint(algo_config.to_dict())
  761. # test tune.Tuner().fit() reproducibility
  762. results1 = tune.Tuner(
  763. algo_class,
  764. param_space=algo_config.to_dict(),
  765. run_config=tune.RunConfig(stop=stop_dict, verbose=1),
  766. ).fit()
  767. results1 = results1.get_best_result().metrics
  768. results2 = tune.Tuner(
  769. algo_class,
  770. param_space=algo_config.to_dict(),
  771. run_config=tune.RunConfig(stop=stop_dict, verbose=1),
  772. ).fit()
  773. results2 = results2.get_best_result().metrics
  774. # Test rollout behavior.
  775. check(
  776. results1[ENV_RUNNER_RESULTS]["hist_stats"],
  777. results2[ENV_RUNNER_RESULTS]["hist_stats"],
  778. )
  779. # As well as training behavior (minibatch sequence during SGD
  780. # iterations).
  781. # As well as training behavior (minibatch sequence during SGD
  782. # iterations).
  783. if algo_config.enable_rl_module_and_learner:
  784. check(
  785. results1["info"][LEARNER_INFO][DEFAULT_POLICY_ID],
  786. results2["info"][LEARNER_INFO][DEFAULT_POLICY_ID],
  787. )
  788. else:
  789. check(
  790. results1["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"],
  791. results2["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"],
  792. )
  793. def get_cartpole_dataset_reader(batch_size: int = 1) -> "DatasetReader":
  794. """Returns a DatasetReader for the cartpole dataset.
  795. Args:
  796. batch_size: The batch size to use for the reader.
  797. Returns:
  798. A rllib DatasetReader for the cartpole dataset.
  799. """
  800. from ray.rllib.algorithms import AlgorithmConfig
  801. from ray.rllib.offline import IOContext
  802. from ray.rllib.offline.dataset_reader import (
  803. DatasetReader,
  804. get_dataset_and_shards,
  805. )
  806. path = "offline/tests/data/cartpole/large.json"
  807. input_config = {"format": "json", "paths": path}
  808. dataset, _ = get_dataset_and_shards(
  809. AlgorithmConfig().offline_data(input_="dataset", input_config=input_config)
  810. )
  811. ioctx = IOContext(
  812. config=(
  813. AlgorithmConfig()
  814. .training(train_batch_size=batch_size)
  815. .offline_data(actions_in_input_normalized=True)
  816. ),
  817. worker_index=0,
  818. )
  819. reader = DatasetReader(dataset, ioctx)
  820. return reader
  821. class ModelChecker:
  822. """Helper class to compare architecturally identical Models across frameworks.
  823. Holds a ModelConfig, such that individual models can be added simply via their
  824. framework string (by building them with config.build(framework=...).
  825. A call to `check()` forces all added models to be compared in terms of their
  826. number of trainable and non-trainable parameters, as well as, their
  827. computation results given a common weights structure and values and identical
  828. inputs to the models.
  829. """
  830. def __init__(self, config):
  831. self.config = config
  832. # To compare number of params between frameworks.
  833. self.param_counts = {}
  834. # To compare computed outputs from fixed-weights-nets between frameworks.
  835. self.output_values = {}
  836. # We will pass an observation filled with this one random value through
  837. # all DL networks (after they have been set to fixed-weights) to compare
  838. # the computed outputs.
  839. self.random_fill_input_value = np.random.uniform(-0.01, 0.01)
  840. # Dict of models to check against each other.
  841. self.models = {}
  842. def add(self, framework: str = "torch", obs=True, state=False) -> Any:
  843. """Builds a new Model for the given framework."""
  844. model = self.models[framework] = self.config.build(framework=framework)
  845. # Pass a B=1 observation through the model.
  846. inputs = np.full(
  847. [1] + ([1] if state else []) + list(self.config.input_dims),
  848. self.random_fill_input_value,
  849. )
  850. if obs:
  851. inputs = {Columns.OBS: inputs}
  852. if state:
  853. inputs[Columns.STATE_IN] = tree.map_structure(
  854. lambda s: np.zeros(shape=[1] + list(s)), state
  855. )
  856. if framework == "torch":
  857. from ray.rllib.utils.torch_utils import convert_to_torch_tensor
  858. inputs = convert_to_torch_tensor(inputs)
  859. # w/ old specs: inputs = model.input_specs.fill(self.random_fill_input_value)
  860. outputs = model(inputs)
  861. # Bring model into a reproducible, comparable state (so we can compare
  862. # computations across frameworks). Use only a value-sequence of len=1 here
  863. # as it could possibly be that the layers are stored in different order
  864. # across the different frameworks.
  865. model._set_to_dummy_weights(value_sequence=(self.random_fill_input_value,))
  866. # Perform another forward pass.
  867. comparable_outputs = model(inputs)
  868. # Store the number of parameters for this framework's net.
  869. self.param_counts[framework] = model.get_num_parameters()
  870. # Store the fixed-weights-net outputs for this framework's net.
  871. if framework == "torch":
  872. self.output_values[framework] = tree.map_structure(
  873. lambda s: s.detach().numpy() if s is not None else None,
  874. comparable_outputs,
  875. )
  876. else:
  877. self.output_values[framework] = tree.map_structure(
  878. lambda s: s.numpy() if s is not None else None, comparable_outputs
  879. )
  880. return outputs
  881. def check(self):
  882. """Compares all added Models with each other and possibly raises errors."""
  883. main_key = next(iter(self.models.keys()))
  884. # Compare number of trainable and non-trainable params between all
  885. # frameworks.
  886. for c in self.param_counts.values():
  887. check(c, self.param_counts[main_key])
  888. # Compare dummy outputs by exact values given that all nets received the
  889. # same input and all nets have the same (dummy) weight values.
  890. for v in self.output_values.values():
  891. check(v, self.output_values[main_key], atol=0.0005)
  892. def _get_mean_action_from_algorithm(alg: "Algorithm", obs: np.ndarray) -> np.ndarray:
  893. """Returns the mean action computed by the given algorithm.
  894. Note: This makes calls to `Algorithm.compute_single_action`
  895. Args:
  896. alg: The constructed algorithm to run inference on.
  897. obs: The observation to compute the action for.
  898. Returns:
  899. The mean action computed by the algorithm over 5000 samples.
  900. """
  901. out = []
  902. for _ in range(5000):
  903. out.append(float(alg.compute_single_action(obs)))
  904. return np.mean(out)
  905. def check_supported_spaces(
  906. alg: str,
  907. config: "AlgorithmConfig",
  908. train: bool = True,
  909. check_bounds: bool = False,
  910. frameworks: Optional[Tuple[str, ...]] = None,
  911. use_gpu: bool = False,
  912. ):
  913. """Checks whether the given algorithm supports different action and obs spaces.
  914. Performs the checks by constructing an rllib algorithm from the config and
  915. checking to see that the model inside the policy is the correct one given
  916. the action and obs spaces. For example if the action space is discrete and
  917. the obs space is an image, then the model should be a vision network with
  918. a categorical action distribution.
  919. Args:
  920. alg: The name of the algorithm to test.
  921. config: The config to use for the algorithm.
  922. train: Whether to train the algorithm for a few iterations.
  923. check_bounds: Whether to check the bounds of the action space.
  924. frameworks: The frameworks to test the algorithm with.
  925. use_gpu: Whether to check support for training on a gpu.
  926. """
  927. # Do these imports here because otherwise we have circular imports.
  928. from ray.rllib.examples.envs.classes.random_env import RandomEnv
  929. from ray.rllib.models.torch.complex_input_net import (
  930. ComplexInputNetwork as TorchComplexNet,
  931. )
  932. from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNet
  933. from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVisionNet
  934. action_spaces_to_test = {
  935. # Test discrete twice here until we support multi_binary action spaces
  936. "discrete": Discrete(5),
  937. "continuous": Box(-1.0, 1.0, (5,), dtype=np.float32),
  938. "int_actions": Box(0, 3, (2, 3), dtype=np.int32),
  939. "multidiscrete": MultiDiscrete([1, 2, 3, 4]),
  940. "tuple": GymTuple(
  941. [Discrete(2), Discrete(3), Box(-1.0, 1.0, (5,), dtype=np.float32)]
  942. ),
  943. "dict": GymDict(
  944. {
  945. "action_choice": Discrete(3),
  946. "parameters": Box(-1.0, 1.0, (1,), dtype=np.float32),
  947. "yet_another_nested_dict": GymDict(
  948. {"a": GymTuple([Discrete(2), Discrete(3)])}
  949. ),
  950. }
  951. ),
  952. }
  953. observation_spaces_to_test = {
  954. "multi_binary": MultiBinary([3, 10, 10]),
  955. "discrete": Discrete(5),
  956. "continuous": Box(-1.0, 1.0, (5,), dtype=np.float32),
  957. "vector2d": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
  958. "image": Box(-1.0, 1.0, (84, 84, 1), dtype=np.float32),
  959. "tuple": GymTuple([Discrete(10), Box(-1.0, 1.0, (5,), dtype=np.float32)]),
  960. "dict": GymDict(
  961. {
  962. "task": Discrete(10),
  963. "position": Box(-1.0, 1.0, (5,), dtype=np.float32),
  964. }
  965. ),
  966. }
  967. # The observation spaces that we test RLModules with
  968. rlmodule_supported_observation_spaces = [
  969. "multi_binary",
  970. "discrete",
  971. "continuous",
  972. "image",
  973. "tuple",
  974. "dict",
  975. ]
  976. # The action spaces that we test RLModules with
  977. rlmodule_supported_action_spaces = ["discrete", "continuous"]
  978. default_observation_space = default_action_space = "discrete"
  979. config["log_level"] = "ERROR"
  980. config["env"] = RandomEnv
  981. def _do_check(alg, config, a_name, o_name):
  982. # We need to copy here so that this validation does not affect the actual
  983. # validation method call further down the line.
  984. config_copy = config.copy()
  985. config_copy.validate()
  986. # If RLModules are enabled, we need to skip a few tests for now:
  987. if config_copy.enable_rl_module_and_learner:
  988. # Skip PPO cases in which RLModules don't support the given spaces yet.
  989. if o_name not in rlmodule_supported_observation_spaces:
  990. logger.warning(
  991. "Skipping PPO test with RLModules for obs space {}".format(o_name)
  992. )
  993. return
  994. if a_name not in rlmodule_supported_action_spaces:
  995. logger.warning(
  996. "Skipping PPO test with RLModules for action space {}".format(
  997. a_name
  998. )
  999. )
  1000. return
  1001. fw = config["framework"]
  1002. action_space = action_spaces_to_test[a_name]
  1003. obs_space = observation_spaces_to_test[o_name]
  1004. print(
  1005. "=== Testing {} (fw={}) action_space={} obs_space={} ===".format(
  1006. alg, fw, action_space, obs_space
  1007. )
  1008. )
  1009. t0 = time.time()
  1010. config.update_from_dict(
  1011. dict(
  1012. env_config=dict(
  1013. action_space=action_space,
  1014. observation_space=obs_space,
  1015. reward_space=Box(1.0, 1.0, shape=(), dtype=np.float32),
  1016. p_terminated=1.0,
  1017. check_action_bounds=check_bounds,
  1018. )
  1019. )
  1020. )
  1021. stat = "ok"
  1022. try:
  1023. algo = config.build()
  1024. except ray.exceptions.RayActorError as e:
  1025. if len(e.args) >= 2 and isinstance(e.args[2], UnsupportedSpaceException):
  1026. stat = "unsupported"
  1027. elif isinstance(e.args[0].args[2], UnsupportedSpaceException):
  1028. stat = "unsupported"
  1029. else:
  1030. raise
  1031. except UnsupportedSpaceException:
  1032. stat = "unsupported"
  1033. else:
  1034. if alg not in ["SAC", "PPO"]:
  1035. # 2D (image) input: Expect VisionNet.
  1036. if o_name in ["atari", "image"]:
  1037. assert isinstance(algo.get_policy().model, TorchVisionNet)
  1038. # 1D input: Expect FCNet.
  1039. elif o_name == "continuous":
  1040. assert isinstance(algo.get_policy().model, TorchFCNet)
  1041. # Could be either one: ComplexNet (if disabled Preprocessor)
  1042. # or FCNet (w/ Preprocessor).
  1043. elif o_name == "vector2d":
  1044. assert isinstance(
  1045. algo.get_policy().model, (TorchComplexNet, TorchFCNet)
  1046. )
  1047. if train:
  1048. algo.train()
  1049. algo.stop()
  1050. print("Test: {}, ran in {}s".format(stat, time.time() - t0))
  1051. if not frameworks:
  1052. frameworks = ("tf2", "tf", "torch")
  1053. _do_check_remote = ray.remote(_do_check)
  1054. _do_check_remote = _do_check_remote.options(num_gpus=1 if use_gpu else 0)
  1055. # Test all action spaces first.
  1056. for a_name in action_spaces_to_test.keys():
  1057. o_name = default_observation_space
  1058. ray.get(_do_check_remote.remote(alg, config, a_name, o_name))
  1059. # Now test all observation spaces.
  1060. for o_name in observation_spaces_to_test.keys():
  1061. a_name = default_action_space
  1062. ray.get(_do_check_remote.remote(alg, config, a_name, o_name))