memory.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. from collections import defaultdict
  2. from typing import DefaultDict, List, Optional, Set
  3. import numpy as np
  4. import tree # pip install dm_tree
  5. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
  6. from ray.rllib.utils.annotations import DeveloperAPI
  7. from ray.util.debug import Suspect, _test_some_code_for_memory_leaks
  8. @DeveloperAPI
  9. def check_memory_leaks(
  10. algorithm,
  11. to_check: Optional[Set[str]] = None,
  12. repeats: Optional[int] = None,
  13. max_num_trials: int = 3,
  14. ) -> DefaultDict[str, List[Suspect]]:
  15. """Diagnoses the given Algorithm for possible memory leaks.
  16. Isolates single components inside the Algorithm's local worker, e.g. the env,
  17. policy, etc.. and calls some of their methods repeatedly, while checking
  18. the memory footprints and keeping track of which lines in the code add
  19. un-GC'd items to memory.
  20. Args:
  21. algorithm: The Algorithm instance to test.
  22. to_check: Set of strings to indentify components to test. Allowed strings
  23. are: "env", "policy", "model", "rollout_worker". By default, check all
  24. of these.
  25. repeats: Number of times the test code block should get executed (per trial).
  26. If a trial fails, a new trial may get started with a larger number of
  27. repeats: actual_repeats = `repeats` * (trial + 1) (1st trial == 0).
  28. max_num_trials: The maximum number of trials to run each check for.
  29. Raises:
  30. A defaultdict(list) with keys being the `to_check` strings and values being
  31. lists of Suspect instances that were found.
  32. """
  33. local_worker = algorithm.env_runner
  34. # Which components should we test?
  35. to_check = to_check or {"env", "model", "policy", "rollout_worker"}
  36. results_per_category = defaultdict(list)
  37. # Test a single sub-env (first in the VectorEnv)?
  38. if "env" in to_check:
  39. assert local_worker.async_env is not None, (
  40. "ERROR: Cannot test 'env' since given Algorithm does not have one "
  41. "in its local worker. Try setting `create_local_env_runner=True`."
  42. )
  43. # Isolate the first sub-env in the vectorized setup and test it.
  44. env = local_worker.async_env.get_sub_environments()[0]
  45. action_space = env.action_space
  46. # Always use same action to avoid numpy random caused memory leaks.
  47. action_sample = action_space.sample()
  48. def code():
  49. ts = 0
  50. env.reset()
  51. while True:
  52. # If masking is used, try something like this:
  53. # np.random.choice(
  54. # action_space.n, p=(obs["action_mask"] / sum(obs["action_mask"])))
  55. _, _, done, _, _ = env.step(action_sample)
  56. ts += 1
  57. if done:
  58. break
  59. test = _test_some_code_for_memory_leaks(
  60. desc="Looking for leaks in env, running through episodes.",
  61. init=None,
  62. code=code,
  63. # How many times to repeat the function call?
  64. repeats=repeats or 200,
  65. max_num_trials=max_num_trials,
  66. )
  67. if test:
  68. results_per_category["env"].extend(test)
  69. # Test the policy (single-agent case only so far).
  70. if "policy" in to_check:
  71. policy = local_worker.policy_map[DEFAULT_POLICY_ID]
  72. # Get a fixed obs (B=10).
  73. obs = tree.map_structure(
  74. lambda s: np.stack([s] * 10, axis=0), policy.observation_space.sample()
  75. )
  76. print("Looking for leaks in Policy")
  77. def code():
  78. policy.compute_actions_from_input_dict(
  79. {
  80. "obs": obs,
  81. }
  82. )
  83. # Call `compute_actions_from_input_dict()` n times.
  84. test = _test_some_code_for_memory_leaks(
  85. desc="Calling `compute_actions_from_input_dict()`.",
  86. init=None,
  87. code=code,
  88. # How many times to repeat the function call?
  89. repeats=repeats or 400,
  90. # How many times to re-try if we find a suspicious memory
  91. # allocation?
  92. max_num_trials=max_num_trials,
  93. )
  94. if test:
  95. results_per_category["policy"].extend(test)
  96. # Testing this only makes sense if the learner API is disabled.
  97. if not policy.config.get("enable_rl_module_and_learner", False):
  98. # Call `learn_on_batch()` n times.
  99. dummy_batch = policy._get_dummy_batch_from_view_requirements(batch_size=16)
  100. test = _test_some_code_for_memory_leaks(
  101. desc="Calling `learn_on_batch()`.",
  102. init=None,
  103. code=lambda: policy.learn_on_batch(dummy_batch),
  104. # How many times to repeat the function call?
  105. repeats=repeats or 100,
  106. max_num_trials=max_num_trials,
  107. )
  108. if test:
  109. results_per_category["policy"].extend(test)
  110. # Test only the model.
  111. if "model" in to_check:
  112. policy = local_worker.policy_map[DEFAULT_POLICY_ID]
  113. # Get a fixed obs.
  114. obs = tree.map_structure(lambda s: s[None], policy.observation_space.sample())
  115. print("Looking for leaks in Model")
  116. # Call `compute_actions_from_input_dict()` n times.
  117. test = _test_some_code_for_memory_leaks(
  118. desc="Calling `[model]()`.",
  119. init=None,
  120. code=lambda: policy.model({SampleBatch.OBS: obs}),
  121. # How many times to repeat the function call?
  122. repeats=repeats or 400,
  123. # How many times to re-try if we find a suspicious memory
  124. # allocation?
  125. max_num_trials=max_num_trials,
  126. )
  127. if test:
  128. results_per_category["model"].extend(test)
  129. # Test the RolloutWorker.
  130. if "rollout_worker" in to_check:
  131. print("Looking for leaks in local RolloutWorker")
  132. def code():
  133. local_worker.sample()
  134. local_worker.get_metrics()
  135. # Call `compute_actions_from_input_dict()` n times.
  136. test = _test_some_code_for_memory_leaks(
  137. desc="Calling `sample()` and `get_metrics()`.",
  138. init=None,
  139. code=code,
  140. # How many times to repeat the function call?
  141. repeats=repeats or 50,
  142. # How many times to re-try if we find a suspicious memory
  143. # allocation?
  144. max_num_trials=max_num_trials,
  145. )
  146. if test:
  147. results_per_category["rollout_worker"].extend(test)
  148. if "learner" in to_check and algorithm.config.get(
  149. "enable_rl_module_and_learner", False
  150. ):
  151. learner_group = algorithm.learner_group
  152. assert learner_group._is_local, (
  153. "This test will miss leaks hidden in remote "
  154. "workers. Please make sure that there is a "
  155. "local learner inside the learner group for "
  156. "this test."
  157. )
  158. dummy_batch = (
  159. algorithm.get_policy()
  160. ._get_dummy_batch_from_view_requirements(batch_size=16)
  161. .as_multi_agent()
  162. )
  163. print("Looking for leaks in Learner")
  164. def code():
  165. learner_group.update(dummy_batch)
  166. # Call `compute_actions_from_input_dict()` n times.
  167. test = _test_some_code_for_memory_leaks(
  168. desc="Calling `LearnerGroup.update()`.",
  169. init=None,
  170. code=code,
  171. # How many times to repeat the function call?
  172. repeats=repeats or 400,
  173. # How many times to re-try if we find a suspicious memory
  174. # allocation?
  175. max_num_trials=max_num_trials,
  176. )
  177. if test:
  178. results_per_category["learner"].extend(test)
  179. return results_per_category