util.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649
  1. import copy
  2. import glob
  3. import inspect
  4. import logging
  5. import os
  6. import threading
  7. import time
  8. import uuid
  9. from collections import defaultdict
  10. from datetime import datetime
  11. from numbers import Number
  12. from threading import Thread
  13. from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
  14. import numpy as np
  15. import ray
  16. from ray._private.dict import ( # noqa: F401
  17. deep_update,
  18. flatten_dict,
  19. merge_dicts,
  20. unflatten_dict,
  21. unflatten_list_dict,
  22. unflattened_lookup,
  23. )
  24. from ray.air._internal.json import SafeFallbackEncoder # noqa
  25. from ray.air._internal.util import is_nan, is_nan_or_inf # noqa: F401
  26. from ray.util.annotations import DeveloperAPI, PublicAPI
  27. import psutil
  28. logger = logging.getLogger(__name__)
  29. def _import_gputil():
  30. try:
  31. import GPUtil
  32. except ImportError:
  33. GPUtil = None
  34. return GPUtil
  35. START_OF_TIME = time.time()
  36. @DeveloperAPI
  37. class UtilMonitor(Thread):
  38. """Class for system usage utilization monitoring.
  39. It keeps track of CPU, RAM, GPU, VRAM usage (each gpu separately) by
  40. pinging for information every x seconds in a separate thread.
  41. Requires psutil and GPUtil to be installed. Can be enabled with
  42. Tuner(param_space={"log_sys_usage": True}).
  43. """
  44. def __init__(self, start=True, delay=0.7):
  45. self.stopped = True
  46. GPUtil = _import_gputil()
  47. self.GPUtil = GPUtil
  48. if GPUtil is None and start:
  49. logger.warning("Install gputil for GPU system monitoring.")
  50. if psutil is None and start:
  51. logger.warning("Install psutil to monitor system performance.")
  52. if GPUtil is None and psutil is None:
  53. return
  54. super(UtilMonitor, self).__init__()
  55. self.delay = delay # Time between calls to GPUtil
  56. self.values = defaultdict(list)
  57. self.lock = threading.Lock()
  58. self.daemon = True
  59. if start:
  60. self.start()
  61. def _read_utilization(self):
  62. with self.lock:
  63. if psutil is not None:
  64. self.values["cpu_util_percent"].append(
  65. float(psutil.cpu_percent(interval=None))
  66. )
  67. self.values["ram_util_percent"].append(
  68. float(psutil.virtual_memory().percent)
  69. )
  70. if self.GPUtil is not None:
  71. gpu_list = []
  72. try:
  73. gpu_list = self.GPUtil.getGPUs()
  74. except Exception:
  75. logger.debug("GPUtil failed to retrieve GPUs.")
  76. for gpu in gpu_list:
  77. self.values["gpu_util_percent" + str(gpu.id)].append(
  78. float(gpu.load)
  79. )
  80. self.values["vram_util_percent" + str(gpu.id)].append(
  81. float(gpu.memoryUtil)
  82. )
  83. def get_data(self):
  84. if self.stopped:
  85. return {}
  86. with self.lock:
  87. ret_values = copy.deepcopy(self.values)
  88. for key, val in self.values.items():
  89. del val[:]
  90. return {"perf": {k: np.mean(v) for k, v in ret_values.items() if len(v) > 0}}
  91. def run(self):
  92. self.stopped = False
  93. while not self.stopped:
  94. self._read_utilization()
  95. time.sleep(self.delay)
  96. def stop(self):
  97. self.stopped = True
  98. @DeveloperAPI
  99. def retry_fn(
  100. fn: Callable[[], Any],
  101. exception_type: Union[Type[Exception], Sequence[Type[Exception]]] = Exception,
  102. num_retries: int = 3,
  103. sleep_time: int = 1,
  104. timeout: Optional[Number] = None,
  105. ) -> bool:
  106. errored = threading.Event()
  107. def _try_fn():
  108. try:
  109. fn()
  110. except exception_type as e:
  111. logger.warning(e)
  112. errored.set()
  113. for i in range(num_retries):
  114. errored.clear()
  115. proc = threading.Thread(target=_try_fn)
  116. proc.daemon = True
  117. proc.start()
  118. proc.join(timeout=timeout)
  119. if proc.is_alive():
  120. logger.debug(
  121. f"Process timed out (try {i+1}/{num_retries}): "
  122. f"{getattr(fn, '__name__', None)}"
  123. )
  124. elif not errored.is_set():
  125. return True
  126. # Timed out, sleep and try again
  127. time.sleep(sleep_time)
  128. # Timed out, so return False
  129. return False
  130. @DeveloperAPI
  131. class warn_if_slow:
  132. """Prints a warning if a given operation is slower than 500ms.
  133. Example:
  134. >>> from ray.tune.utils.util import warn_if_slow
  135. >>> something = ... # doctest: +SKIP
  136. >>> with warn_if_slow("some_operation"): # doctest: +SKIP
  137. ... ray.get(something) # doctest: +SKIP
  138. """
  139. DEFAULT_THRESHOLD = float(os.environ.get("TUNE_WARN_THRESHOLD_S", 0.5))
  140. DEFAULT_MESSAGE = (
  141. "The `{name}` operation took {duration:.3f} s, "
  142. "which may be a performance bottleneck."
  143. )
  144. def __init__(
  145. self,
  146. name: str,
  147. threshold: Optional[float] = None,
  148. message: Optional[str] = None,
  149. disable: bool = False,
  150. ):
  151. self.name = name
  152. self.threshold = threshold or self.DEFAULT_THRESHOLD
  153. self.message = message or self.DEFAULT_MESSAGE
  154. self.too_slow = False
  155. self.disable = disable
  156. def __enter__(self):
  157. self.start = time.time()
  158. return self
  159. def __exit__(self, type, value, traceback):
  160. now = time.time()
  161. if self.disable:
  162. return
  163. if now - self.start > self.threshold and now - START_OF_TIME > 60.0:
  164. self.too_slow = True
  165. duration = now - self.start
  166. logger.warning(self.message.format(name=self.name, duration=duration))
  167. @DeveloperAPI
  168. class Tee(object):
  169. def __init__(self, stream1, stream2):
  170. self.stream1 = stream1
  171. self.stream2 = stream2
  172. # If True, we are currently handling a warning.
  173. # We use this flag to avoid infinite recursion.
  174. self._handling_warning = False
  175. def _warn(self, op, s, args, kwargs):
  176. # If we are already handling a warning, this is because
  177. # `logger.warning` below triggered the same object again
  178. # (e.g. because stderr is redirected to this object).
  179. # In that case, exit early to avoid recursion.
  180. if self._handling_warning:
  181. return
  182. msg = f"ValueError when calling '{op}' on stream ({s}). "
  183. msg += f"args: {args} kwargs: {kwargs}"
  184. self._handling_warning = True
  185. logger.warning(msg)
  186. self._handling_warning = False
  187. def seek(self, *args, **kwargs):
  188. for s in [self.stream1, self.stream2]:
  189. try:
  190. s.seek(*args, **kwargs)
  191. except ValueError:
  192. self._warn("seek", s, args, kwargs)
  193. def write(self, *args, **kwargs):
  194. for s in [self.stream1, self.stream2]:
  195. try:
  196. s.write(*args, **kwargs)
  197. except ValueError:
  198. self._warn("write", s, args, kwargs)
  199. def flush(self, *args, **kwargs):
  200. for s in [self.stream1, self.stream2]:
  201. try:
  202. s.flush(*args, **kwargs)
  203. except ValueError:
  204. self._warn("flush", s, args, kwargs)
  205. @property
  206. def encoding(self):
  207. if hasattr(self.stream1, "encoding"):
  208. return self.stream1.encoding
  209. return self.stream2.encoding
  210. @property
  211. def error(self):
  212. if hasattr(self.stream1, "error"):
  213. return self.stream1.error
  214. return self.stream2.error
  215. @property
  216. def newlines(self):
  217. if hasattr(self.stream1, "newlines"):
  218. return self.stream1.newlines
  219. return self.stream2.newlines
  220. def detach(self):
  221. raise NotImplementedError
  222. def read(self, *args, **kwargs):
  223. raise NotImplementedError
  224. def readline(self, *args, **kwargs):
  225. raise NotImplementedError
  226. def tell(self, *args, **kwargs):
  227. raise NotImplementedError
  228. @DeveloperAPI
  229. def date_str():
  230. return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
  231. def _to_pinnable(obj):
  232. """Converts obj to a form that can be pinned in object store memory.
  233. Currently only numpy arrays are pinned in memory, if you have a strong
  234. reference to the array value.
  235. """
  236. return (obj, np.zeros(1))
  237. def _from_pinnable(obj):
  238. """Retrieve from _to_pinnable format."""
  239. return obj[0]
  240. @DeveloperAPI
  241. def diagnose_serialization(trainable: Callable):
  242. """Utility for detecting why your trainable function isn't serializing.
  243. Args:
  244. trainable: The trainable object passed to
  245. tune.Tuner(trainable). Currently only supports
  246. Function API.
  247. Returns:
  248. bool | set of unserializable objects.
  249. Example:
  250. .. code-block:: python
  251. import threading
  252. # this is not serializable
  253. e = threading.Event()
  254. def test():
  255. print(e)
  256. diagnose_serialization(test)
  257. # should help identify that 'e' should be moved into
  258. # the `test` scope.
  259. # correct implementation
  260. def test():
  261. e = threading.Event()
  262. print(e)
  263. assert diagnose_serialization(test) is True
  264. """
  265. from ray.tune.registry import _check_serializability, register_trainable
  266. def check_variables(objects, failure_set, printer):
  267. for var_name, variable in objects.items():
  268. msg = None
  269. try:
  270. _check_serializability(var_name, variable)
  271. status = "PASSED"
  272. except Exception as e:
  273. status = "FAILED"
  274. msg = f"{e.__class__.__name__}: {str(e)}"
  275. failure_set.add(var_name)
  276. printer(f"{str(variable)}[name='{var_name}'']... {status}")
  277. if msg:
  278. printer(msg)
  279. print(f"Trying to serialize {trainable}...")
  280. try:
  281. register_trainable("__test:" + str(trainable), trainable, warn=False)
  282. print("Serialization succeeded!")
  283. return True
  284. except Exception as e:
  285. print(f"Serialization failed: {e}")
  286. print(
  287. "Inspecting the scope of the trainable by running "
  288. f"`inspect.getclosurevars({str(trainable)})`..."
  289. )
  290. closure = inspect.getclosurevars(trainable)
  291. failure_set = set()
  292. if closure.globals:
  293. print(
  294. f"Detected {len(closure.globals)} global variables. "
  295. "Checking serializability..."
  296. )
  297. check_variables(closure.globals, failure_set, lambda s: print(" " + s))
  298. if closure.nonlocals:
  299. print(
  300. f"Detected {len(closure.nonlocals)} nonlocal variables. "
  301. "Checking serializability..."
  302. )
  303. check_variables(closure.nonlocals, failure_set, lambda s: print(" " + s))
  304. if not failure_set:
  305. print(
  306. "Nothing was found to have failed the diagnostic test, though "
  307. "serialization did not succeed. Feel free to raise an "
  308. "issue on github."
  309. )
  310. return failure_set
  311. else:
  312. print(
  313. f"Variable(s) {failure_set} was found to be non-serializable. "
  314. "Consider either removing the instantiation/imports "
  315. "of these objects or moving them into the scope of "
  316. "the trainable. "
  317. )
  318. return failure_set
  319. def _atomic_save(state: Dict, checkpoint_dir: str, file_name: str, tmp_file_name: str):
  320. """Atomically saves the state object to the checkpoint directory.
  321. This is automatically used by Tuner().fit during a Tune job.
  322. Args:
  323. state: Object state to be serialized.
  324. checkpoint_dir: Directory location for the checkpoint.
  325. file_name: Final name of file.
  326. tmp_file_name: Temporary name of file. We prepend a .uuid- prefix.
  327. """
  328. import ray.cloudpickle as cloudpickle
  329. tmp_search_ckpt_path = os.path.join(
  330. checkpoint_dir, f".{str(uuid.uuid4())}-{tmp_file_name}"
  331. )
  332. with open(tmp_search_ckpt_path, "wb") as f:
  333. cloudpickle.dump(state, f)
  334. os.replace(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))
  335. def _load_newest_checkpoint(dirpath: str, ckpt_pattern: str) -> Optional[Dict]:
  336. """Returns the most recently modified checkpoint.
  337. Assumes files are saved with an ordered name, most likely by
  338. :obj:atomic_save.
  339. Args:
  340. dirpath: Directory in which to look for the checkpoint file.
  341. ckpt_pattern: File name pattern to match to find checkpoint
  342. files.
  343. Returns:
  344. (dict) Deserialized state dict.
  345. """
  346. import ray.cloudpickle as cloudpickle
  347. full_paths = glob.glob(os.path.join(dirpath, ckpt_pattern))
  348. if not full_paths:
  349. return
  350. most_recent_checkpoint = max(full_paths)
  351. with open(most_recent_checkpoint, "rb") as f:
  352. checkpoint_state = cloudpickle.load(f)
  353. return checkpoint_state
  354. @PublicAPI(stability="beta")
  355. def wait_for_gpu(
  356. gpu_id: Optional[Union[int, str]] = None,
  357. target_util: float = 0.01,
  358. retry: int = 20,
  359. delay_s: int = 5,
  360. gpu_memory_limit: Optional[float] = None,
  361. ):
  362. """Checks if a given GPU has freed memory.
  363. Requires ``gputil`` to be installed: ``pip install gputil``.
  364. Args:
  365. gpu_id: GPU id or uuid to check.
  366. Must be found within GPUtil.getGPUs(). If none, resorts to
  367. the first item returned from `ray.get_gpu_ids()`.
  368. target_util: The utilization threshold to reach to unblock.
  369. Set this to 0 to block until the GPU is completely free.
  370. retry: Number of times to check GPU limit. Sleeps `delay_s`
  371. seconds between checks.
  372. delay_s: Seconds to wait before check.
  373. Returns:
  374. bool: True if free.
  375. Raises:
  376. RuntimeError: If GPUtil is not found, if no GPUs are detected
  377. or if the check fails.
  378. Example:
  379. .. code-block:: python
  380. def tune_func(config):
  381. tune.utils.wait_for_gpu()
  382. train()
  383. tuner = tune.Tuner(
  384. tune.with_resources(
  385. tune_func,
  386. resources={"gpu": 1}
  387. ),
  388. tune_config=tune.TuneConfig(num_samples=10)
  389. )
  390. tuner.fit()
  391. """
  392. GPUtil = _import_gputil()
  393. if GPUtil is None:
  394. raise RuntimeError("GPUtil must be installed if calling `wait_for_gpu`.")
  395. if gpu_id is None:
  396. gpu_id_list = ray.get_gpu_ids()
  397. if not gpu_id_list:
  398. raise RuntimeError(
  399. "No GPU ids found from `ray.get_gpu_ids()`. "
  400. "Did you set Tune resources correctly?"
  401. )
  402. gpu_id = gpu_id_list[0]
  403. gpu_attr = "id"
  404. if isinstance(gpu_id, str):
  405. if gpu_id.isdigit():
  406. # GPU ID returned from `ray.get_gpu_ids()` is a str representation
  407. # of the int GPU ID
  408. gpu_id = int(gpu_id)
  409. else:
  410. # Could not coerce gpu_id to int, so assume UUID
  411. # and compare against `uuid` attribute e.g.,
  412. # 'GPU-04546190-b68d-65ac-101b-035f8faed77d'
  413. gpu_attr = "uuid"
  414. elif not isinstance(gpu_id, int):
  415. raise ValueError(f"gpu_id ({type(gpu_id)}) must be type str/int.")
  416. def gpu_id_fn(g):
  417. # Returns either `g.id` or `g.uuid` depending on
  418. # the format of the input `gpu_id`
  419. return getattr(g, gpu_attr)
  420. gpu_ids = {gpu_id_fn(g) for g in GPUtil.getGPUs()}
  421. if gpu_id not in gpu_ids:
  422. raise ValueError(
  423. f"{gpu_id} not found in set of available GPUs: {gpu_ids}. "
  424. "`wait_for_gpu` takes either GPU ordinal ID (e.g., '0') or "
  425. "UUID (e.g., 'GPU-04546190-b68d-65ac-101b-035f8faed77d')."
  426. )
  427. for i in range(int(retry)):
  428. gpu_object = next(g for g in GPUtil.getGPUs() if gpu_id_fn(g) == gpu_id)
  429. if gpu_object.memoryUtil > target_util:
  430. logger.info(
  431. f"Waiting for GPU util to reach {target_util}. "
  432. f"Util: {gpu_object.memoryUtil:0.3f}"
  433. )
  434. time.sleep(delay_s)
  435. else:
  436. return True
  437. raise RuntimeError("GPU memory was not freed.")
  438. @DeveloperAPI
  439. def validate_save_restore(
  440. trainable_cls: Type,
  441. config: Optional[Dict] = None,
  442. num_gpus: int = 0,
  443. ):
  444. """Helper method to check if your Trainable class will resume correctly.
  445. Args:
  446. trainable_cls: Trainable class for evaluation.
  447. config: Config to pass to Trainable when testing.
  448. num_gpus: GPU resources to allocate when testing.
  449. use_object_store: Whether to save and restore to Ray's object
  450. store. Recommended to set this to True if planning to use
  451. algorithms that pause training (i.e., PBT, HyperBand).
  452. """
  453. assert ray.is_initialized(), "Need Ray to be initialized."
  454. remote_cls = ray.remote(num_gpus=num_gpus)(trainable_cls)
  455. trainable_1 = remote_cls.remote(config=config)
  456. trainable_2 = remote_cls.remote(config=config)
  457. from ray.air.constants import TRAINING_ITERATION
  458. for _ in range(3):
  459. res = ray.get(trainable_1.train.remote())
  460. assert res.get(TRAINING_ITERATION), (
  461. "Validation will not pass because it requires `training_iteration` "
  462. "to be returned."
  463. )
  464. ray.get(trainable_2.restore.remote(trainable_1.save.remote()))
  465. res = ray.get(trainable_2.train.remote())
  466. assert res[TRAINING_ITERATION] == 4
  467. res = ray.get(trainable_2.train.remote())
  468. assert res[TRAINING_ITERATION] == 5
  469. return True
  470. def _detect_config_single(func):
  471. """Check if func({}) works."""
  472. func_sig = inspect.signature(func)
  473. use_config_single = True
  474. try:
  475. func_sig.bind({})
  476. except Exception as e:
  477. logger.debug(str(e))
  478. use_config_single = False
  479. return use_config_single
  480. @PublicAPI()
  481. def validate_warmstart(
  482. parameter_names: List[str],
  483. points_to_evaluate: List[Union[List, Dict]],
  484. evaluated_rewards: List,
  485. validate_point_name_lengths: bool = True,
  486. ):
  487. """Generic validation of a Searcher's warm start functionality.
  488. Raises exceptions in case of type and length mismatches between
  489. parameters.
  490. If ``validate_point_name_lengths`` is False, the equality of lengths
  491. between ``points_to_evaluate`` and ``parameter_names`` will not be
  492. validated.
  493. """
  494. if points_to_evaluate:
  495. if not isinstance(points_to_evaluate, list):
  496. raise TypeError(
  497. "points_to_evaluate expected to be a list, got {}.".format(
  498. type(points_to_evaluate)
  499. )
  500. )
  501. for point in points_to_evaluate:
  502. if not isinstance(point, (dict, list)):
  503. raise TypeError(
  504. f"points_to_evaluate expected to include list or dict, "
  505. f"got {point}."
  506. )
  507. if validate_point_name_lengths and (not len(point) == len(parameter_names)):
  508. raise ValueError(
  509. "Dim of point {}".format(point)
  510. + " and parameter_names {}".format(parameter_names)
  511. + " do not match."
  512. )
  513. if points_to_evaluate and evaluated_rewards:
  514. if not isinstance(evaluated_rewards, list):
  515. raise TypeError(
  516. "evaluated_rewards expected to be a list, got {}.".format(
  517. type(evaluated_rewards)
  518. )
  519. )
  520. if not len(evaluated_rewards) == len(points_to_evaluate):
  521. raise ValueError(
  522. "Dim of evaluated_rewards {}".format(evaluated_rewards)
  523. + " and points_to_evaluate {}".format(points_to_evaluate)
  524. + " do not match."
  525. )