progress_reporter.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596
  1. from __future__ import print_function
  2. import collections
  3. import datetime
  4. import numbers
  5. import sys
  6. import textwrap
  7. import time
  8. import warnings
  9. from pathlib import Path
  10. from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Union
  11. import numpy as np
  12. import pandas as pd
  13. import ray
  14. from ray._private.dict import flatten_dict
  15. from ray._private.thirdparty.tabulate.tabulate import tabulate
  16. from ray.air.constants import EXPR_ERROR_FILE, TRAINING_ITERATION
  17. from ray.air.util.node import _force_on_current_node
  18. from ray.experimental.tqdm_ray import safe_print
  19. from ray.tune.callback import Callback
  20. from ray.tune.experiment.trial import DEBUG_PRINT_INTERVAL, Trial, _Location
  21. from ray.tune.logger import pretty_print
  22. from ray.tune.result import (
  23. AUTO_RESULT_KEYS,
  24. DEFAULT_METRIC,
  25. DONE,
  26. EPISODE_REWARD_MEAN,
  27. EXPERIMENT_TAG,
  28. MEAN_ACCURACY,
  29. MEAN_LOSS,
  30. NODE_IP,
  31. PID,
  32. TIME_TOTAL_S,
  33. TIMESTEPS_TOTAL,
  34. TRIAL_ID,
  35. )
  36. from ray.tune.trainable import Trainable
  37. from ray.tune.utils import unflattened_lookup
  38. from ray.tune.utils.log import Verbosity, has_verbosity, set_verbosity
  39. from ray.util.annotations import DeveloperAPI, PublicAPI
  40. from ray.util.queue import Empty, Queue
  41. from ray.widgets import Template
  42. try:
  43. from collections.abc import Mapping, MutableMapping
  44. except ImportError:
  45. from collections import Mapping, MutableMapping
# Result of `ray.widgets.util.in_notebook()`, evaluated once at import time.
# Presumably True when running inside an IPython/Jupyter notebook -- confirm
# against the helper's implementation.
IS_NOTEBOOK = ray.widgets.util.in_notebook()
# Result keys grouped under a "skip in report" name; the code consuming this
# set is not visible in this part of the file.
SKIP_RESULTS_IN_REPORT = {"config", TRIAL_ID, EXPERIMENT_TAG, DONE}
  48. @PublicAPI
  49. class ProgressReporter:
  50. """Abstract class for experiment progress reporting.
  51. `should_report()` is called to determine whether or not `report()` should
  52. be called. Tune will call these functions after trial state transitions,
  53. receiving training results, and so on.
  54. """
  55. def setup(
  56. self,
  57. start_time: Optional[float] = None,
  58. total_samples: Optional[int] = None,
  59. metric: Optional[str] = None,
  60. mode: Optional[str] = None,
  61. **kwargs,
  62. ):
  63. """Setup progress reporter for a new Ray Tune run.
  64. This function is used to initialize parameters that are set on runtime.
  65. It will be called before any of the other methods.
  66. Defaults to no-op.
  67. Args:
  68. start_time: Timestamp when the Ray Tune run is started.
  69. total_samples: Number of samples the Ray Tune run will run.
  70. metric: Metric to optimize.
  71. mode: Must be one of [min, max]. Determines whether objective is
  72. minimizing or maximizing the metric attribute.
  73. **kwargs: Keyword arguments for forward-compatibility.
  74. """
  75. pass
  76. def should_report(self, trials: List[Trial], done: bool = False):
  77. """Returns whether or not progress should be reported.
  78. Args:
  79. trials: Trials to report on.
  80. done: Whether this is the last progress report attempt.
  81. """
  82. raise NotImplementedError
  83. def report(self, trials: List[Trial], done: bool, *sys_info: Dict):
  84. """Reports progress across trials.
  85. Args:
  86. trials: Trials to report on.
  87. done: Whether this is the last progress report attempt.
  88. sys_info: System info.
  89. """
  90. raise NotImplementedError
  91. @DeveloperAPI
  92. class TuneReporterBase(ProgressReporter):
  93. """Abstract base class for the default Tune reporters.
  94. If metric_columns is not overridden, Tune will attempt to automatically
  95. infer the metrics being outputted, up to 'infer_limit' number of
  96. metrics.
  97. Args:
  98. metric_columns: Names of metrics to
  99. include in progress table. If this is a dict, the keys should
  100. be metric names and the values should be the displayed names.
  101. If this is a list, the metric name is used directly.
  102. parameter_columns: Names of parameters to
  103. include in progress table. If this is a dict, the keys should
  104. be parameter names and the values should be the displayed names.
  105. If this is a list, the parameter name is used directly. If empty,
  106. defaults to all available parameters.
  107. max_progress_rows: Maximum number of rows to print
  108. in the progress table. The progress table describes the
  109. progress of each trial. Defaults to 20.
  110. max_error_rows: Maximum number of rows to print in the
  111. error table. The error table lists the error file, if any,
  112. corresponding to each trial. Defaults to 20.
  113. max_column_length: Maximum column length (in characters). Column
  114. headers and values longer than this will be abbreviated.
  115. max_report_frequency: Maximum report frequency in seconds.
  116. Defaults to 5s.
  117. infer_limit: Maximum number of metrics to automatically infer
  118. from tune results.
  119. print_intermediate_tables: Print intermediate result
  120. tables. If None (default), will be set to True for verbosity
  121. levels above 3, otherwise False. If True, intermediate tables
  122. will be printed with experiment progress. If False, tables
  123. will only be printed at then end of the tuning run for verbosity
  124. levels greater than 2.
  125. metric: Metric used to determine best current trial.
  126. mode: One of [min, max]. Determines whether objective is
  127. minimizing or maximizing the metric attribute.
  128. sort_by_metric: Sort terminated trials by metric in the
  129. intermediate table. Defaults to False.
  130. """
  131. # Truncated representations of column names (to accommodate small screens).
  132. DEFAULT_COLUMNS = collections.OrderedDict(
  133. {
  134. MEAN_ACCURACY: "acc",
  135. MEAN_LOSS: "loss",
  136. TRAINING_ITERATION: "iter",
  137. TIME_TOTAL_S: "total time (s)",
  138. TIMESTEPS_TOTAL: "ts",
  139. EPISODE_REWARD_MEAN: "reward",
  140. }
  141. )
  142. VALID_SUMMARY_TYPES = {
  143. int,
  144. float,
  145. np.float32,
  146. np.float64,
  147. np.int32,
  148. np.int64,
  149. type(None),
  150. }
  151. def __init__(
  152. self,
  153. *,
  154. metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
  155. parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
  156. total_samples: Optional[int] = None,
  157. max_progress_rows: int = 20,
  158. max_error_rows: int = 20,
  159. max_column_length: int = 20,
  160. max_report_frequency: int = 5,
  161. infer_limit: int = 3,
  162. print_intermediate_tables: Optional[bool] = None,
  163. metric: Optional[str] = None,
  164. mode: Optional[str] = None,
  165. sort_by_metric: bool = False,
  166. ):
  167. self._total_samples = total_samples
  168. self._metrics_override = metric_columns is not None
  169. self._inferred_metrics = {}
  170. self._metric_columns = metric_columns or self.DEFAULT_COLUMNS.copy()
  171. self._parameter_columns = parameter_columns or []
  172. self._max_progress_rows = max_progress_rows
  173. self._max_error_rows = max_error_rows
  174. self._max_column_length = max_column_length
  175. self._infer_limit = infer_limit
  176. if print_intermediate_tables is None:
  177. self._print_intermediate_tables = has_verbosity(Verbosity.V3_TRIAL_DETAILS)
  178. else:
  179. self._print_intermediate_tables = print_intermediate_tables
  180. self._max_report_freqency = max_report_frequency
  181. self._last_report_time = 0
  182. self._start_time = time.time()
  183. self._metric = metric
  184. self._mode = mode
  185. self._sort_by_metric = sort_by_metric
  186. def setup(
  187. self,
  188. start_time: Optional[float] = None,
  189. total_samples: Optional[int] = None,
  190. metric: Optional[str] = None,
  191. mode: Optional[str] = None,
  192. **kwargs,
  193. ):
  194. self.set_start_time(start_time)
  195. self.set_total_samples(total_samples)
  196. self.set_search_properties(metric=metric, mode=mode)
  197. def set_search_properties(self, metric: Optional[str], mode: Optional[str]):
  198. if (self._metric and metric) or (self._mode and mode):
  199. raise ValueError(
  200. "You passed a `metric` or `mode` argument to `tune.TuneConfig()`, but "
  201. "the reporter you are using was already instantiated with their "
  202. "own `metric` and `mode` parameters. Either remove the arguments "
  203. "from your reporter or from your call to `tune.TuneConfig()`"
  204. )
  205. if metric:
  206. self._metric = metric
  207. if mode:
  208. self._mode = mode
  209. if self._metric is None and self._mode:
  210. # If only a mode was passed, use anonymous metric
  211. self._metric = DEFAULT_METRIC
  212. return True
  213. def set_total_samples(self, total_samples: int):
  214. self._total_samples = total_samples
  215. def set_start_time(self, timestamp: Optional[float] = None):
  216. if timestamp is not None:
  217. self._start_time = time.time()
  218. else:
  219. self._start_time = timestamp
  220. def should_report(self, trials: List[Trial], done: bool = False):
  221. if time.time() - self._last_report_time > self._max_report_freqency:
  222. self._last_report_time = time.time()
  223. return True
  224. return done
  225. def add_metric_column(self, metric: str, representation: Optional[str] = None):
  226. """Adds a metric to the existing columns.
  227. Args:
  228. metric: Metric to add. This must be a metric being returned
  229. in training step results.
  230. representation: Representation to use in table. Defaults to
  231. `metric`.
  232. """
  233. self._metrics_override = True
  234. if metric in self._metric_columns:
  235. raise ValueError("Column {} already exists.".format(metric))
  236. if isinstance(self._metric_columns, MutableMapping):
  237. representation = representation or metric
  238. self._metric_columns[metric] = representation
  239. else:
  240. if representation is not None and representation != metric:
  241. raise ValueError(
  242. "`representation` cannot differ from `metric` "
  243. "if this reporter was initialized with a list "
  244. "of metric columns."
  245. )
  246. self._metric_columns.append(metric)
  247. def add_parameter_column(
  248. self, parameter: str, representation: Optional[str] = None
  249. ):
  250. """Adds a parameter to the existing columns.
  251. Args:
  252. parameter: Parameter to add. This must be a parameter
  253. specified in the configuration.
  254. representation: Representation to use in table. Defaults to
  255. `parameter`.
  256. """
  257. if parameter in self._parameter_columns:
  258. raise ValueError("Column {} already exists.".format(parameter))
  259. if isinstance(self._parameter_columns, MutableMapping):
  260. representation = representation or parameter
  261. self._parameter_columns[parameter] = representation
  262. else:
  263. if representation is not None and representation != parameter:
  264. raise ValueError(
  265. "`representation` cannot differ from `parameter` "
  266. "if this reporter was initialized with a list "
  267. "of metric columns."
  268. )
  269. self._parameter_columns.append(parameter)
  270. def _progress_str(
  271. self,
  272. trials: List[Trial],
  273. done: bool,
  274. *sys_info: Dict,
  275. fmt: str = "psql",
  276. delim: str = "\n",
  277. ):
  278. """Returns full progress string.
  279. This string contains a progress table and error table. The progress
  280. table describes the progress of each trial. The error table lists
  281. the error file, if any, corresponding to each trial. The latter only
  282. exists if errors have occurred.
  283. Args:
  284. trials: Trials to report on.
  285. done: Whether this is the last progress report attempt.
  286. fmt: Table format. See `tablefmt` in tabulate API.
  287. delim: Delimiter between messages.
  288. """
  289. if self._sort_by_metric and (self._metric is None or self._mode is None):
  290. self._sort_by_metric = False
  291. warnings.warn(
  292. "Both 'metric' and 'mode' must be set to be able "
  293. "to sort by metric. No sorting is performed."
  294. )
  295. if not self._metrics_override:
  296. user_metrics = self._infer_user_metrics(trials, self._infer_limit)
  297. self._metric_columns.update(user_metrics)
  298. messages = [
  299. "== Status ==",
  300. _time_passed_str(self._start_time, time.time()),
  301. *sys_info,
  302. ]
  303. if done:
  304. max_progress = None
  305. max_error = None
  306. else:
  307. max_progress = self._max_progress_rows
  308. max_error = self._max_error_rows
  309. current_best_trial, metric = self._current_best_trial(trials)
  310. if current_best_trial:
  311. messages.append(
  312. _best_trial_str(current_best_trial, metric, self._parameter_columns)
  313. )
  314. if has_verbosity(Verbosity.V1_EXPERIMENT):
  315. # Will filter the table in `trial_progress_str`
  316. messages.append(
  317. _trial_progress_str(
  318. trials,
  319. metric_columns=self._metric_columns,
  320. parameter_columns=self._parameter_columns,
  321. total_samples=self._total_samples,
  322. force_table=self._print_intermediate_tables,
  323. fmt=fmt,
  324. max_rows=max_progress,
  325. max_column_length=self._max_column_length,
  326. done=done,
  327. metric=self._metric,
  328. mode=self._mode,
  329. sort_by_metric=self._sort_by_metric,
  330. )
  331. )
  332. messages.append(_trial_errors_str(trials, fmt=fmt, max_rows=max_error))
  333. return delim.join(messages) + delim
  334. def _infer_user_metrics(self, trials: List[Trial], limit: int = 4):
  335. """Try to infer the metrics to print out."""
  336. if len(self._inferred_metrics) >= limit:
  337. return self._inferred_metrics
  338. self._inferred_metrics = {}
  339. for t in trials:
  340. if not t.last_result:
  341. continue
  342. for metric, value in t.last_result.items():
  343. if metric not in self.DEFAULT_COLUMNS:
  344. if metric not in AUTO_RESULT_KEYS:
  345. if type(value) in self.VALID_SUMMARY_TYPES:
  346. self._inferred_metrics[metric] = metric
  347. if len(self._inferred_metrics) >= limit:
  348. return self._inferred_metrics
  349. return self._inferred_metrics
  350. def _current_best_trial(self, trials: List[Trial]):
  351. if not trials:
  352. return None, None
  353. metric, mode = self._metric, self._mode
  354. # If no metric has been set, see if exactly one has been reported
  355. # and use that one. `mode` must still be set.
  356. if not metric:
  357. if len(self._inferred_metrics) == 1:
  358. metric = list(self._inferred_metrics.keys())[0]
  359. if not metric or not mode:
  360. return None, metric
  361. metric_op = 1.0 if mode == "max" else -1.0
  362. best_metric = float("-inf")
  363. best_trial = None
  364. for t in trials:
  365. if not t.last_result:
  366. continue
  367. metric_value = unflattened_lookup(metric, t.last_result, default=None)
  368. if pd.isnull(metric_value):
  369. continue
  370. if not best_trial or metric_value * metric_op > best_metric:
  371. best_metric = metric_value * metric_op
  372. best_trial = t
  373. return best_trial, metric
  374. @DeveloperAPI
  375. class RemoteReporterMixin:
  376. """Remote reporter abstract mixin class.
  377. Subclasses of this class will use a Ray Queue to display output
  378. on the driver side when running Ray Client."""
  379. @property
  380. def output_queue(self) -> Queue:
  381. return getattr(self, "_output_queue", None)
  382. @output_queue.setter
  383. def output_queue(self, value: Queue):
  384. self._output_queue = value
  385. def display(self, string: str) -> None:
  386. """Display the progress string.
  387. Args:
  388. string: String to display.
  389. """
  390. raise NotImplementedError
  391. @PublicAPI
  392. class JupyterNotebookReporter(TuneReporterBase, RemoteReporterMixin):
  393. """Jupyter notebook-friendly Reporter that can update display in-place.
  394. Args:
  395. overwrite: Flag for overwriting the cell contents before initialization.
  396. metric_columns: Names of metrics to
  397. include in progress table. If this is a dict, the keys should
  398. be metric names and the values should be the displayed names.
  399. If this is a list, the metric name is used directly.
  400. parameter_columns: Names of parameters to
  401. include in progress table. If this is a dict, the keys should
  402. be parameter names and the values should be the displayed names.
  403. If this is a list, the parameter name is used directly. If empty,
  404. defaults to all available parameters.
  405. max_progress_rows: Maximum number of rows to print
  406. in the progress table. The progress table describes the
  407. progress of each trial. Defaults to 20.
  408. max_error_rows: Maximum number of rows to print in the
  409. error table. The error table lists the error file, if any,
  410. corresponding to each trial. Defaults to 20.
  411. max_column_length: Maximum column length (in characters). Column
  412. headers and values longer than this will be abbreviated.
  413. max_report_frequency: Maximum report frequency in seconds.
  414. Defaults to 5s.
  415. infer_limit: Maximum number of metrics to automatically infer
  416. from tune results.
  417. print_intermediate_tables: Print intermediate result
  418. tables. If None (default), will be set to True for verbosity
  419. levels above 3, otherwise False. If True, intermediate tables
  420. will be printed with experiment progress. If False, tables
  421. will only be printed at then end of the tuning run for verbosity
  422. levels greater than 2.
  423. metric: Metric used to determine best current trial.
  424. mode: One of [min, max]. Determines whether objective is
  425. minimizing or maximizing the metric attribute.
  426. sort_by_metric: Sort terminated trials by metric in the
  427. intermediate table. Defaults to False.
  428. """
  429. def __init__(
  430. self,
  431. *,
  432. overwrite: bool = True,
  433. metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
  434. parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
  435. total_samples: Optional[int] = None,
  436. max_progress_rows: int = 20,
  437. max_error_rows: int = 20,
  438. max_column_length: int = 20,
  439. max_report_frequency: int = 5,
  440. infer_limit: int = 3,
  441. print_intermediate_tables: Optional[bool] = None,
  442. metric: Optional[str] = None,
  443. mode: Optional[str] = None,
  444. sort_by_metric: bool = False,
  445. ):
  446. super(JupyterNotebookReporter, self).__init__(
  447. metric_columns=metric_columns,
  448. parameter_columns=parameter_columns,
  449. total_samples=total_samples,
  450. max_progress_rows=max_progress_rows,
  451. max_error_rows=max_error_rows,
  452. max_column_length=max_column_length,
  453. max_report_frequency=max_report_frequency,
  454. infer_limit=infer_limit,
  455. print_intermediate_tables=print_intermediate_tables,
  456. metric=metric,
  457. mode=mode,
  458. sort_by_metric=sort_by_metric,
  459. )
  460. if not IS_NOTEBOOK:
  461. warnings.warn(
  462. "You are using the `JupyterNotebookReporter`, but not "
  463. "IPython/Jupyter-compatible environment was detected. "
  464. "If this leads to unformatted output (e.g. like "
  465. "<IPython.core.display.HTML object>), consider passing "
  466. "a `CLIReporter` as the `progress_reporter` argument "
  467. "to `tune.RunConfig()` instead."
  468. )
  469. self._overwrite = overwrite
  470. self._display_handle = None
  471. self.display("") # initialize empty display to update later
  472. def report(self, trials: List[Trial], done: bool, *sys_info: Dict):
  473. progress = self._progress_html(trials, done, *sys_info)
  474. if self.output_queue is not None:
  475. # If an output queue is set, send string
  476. self.output_queue.put(progress)
  477. else:
  478. # Else, output directly
  479. self.display(progress)
  480. def display(self, string: str) -> None:
  481. from IPython.display import HTML, clear_output, display
  482. if not self._display_handle:
  483. if self._overwrite:
  484. clear_output(wait=True)
  485. self._display_handle = display(HTML(string), display_id=True)
  486. else:
  487. self._display_handle.update(HTML(string))
  488. def _progress_html(self, trials: List[Trial], done: bool, *sys_info) -> str:
  489. """Generate an HTML-formatted progress update.
  490. Args:
  491. trials: List of trials for which progress should be
  492. displayed
  493. done: True if the trials are finished, False otherwise
  494. *sys_info: System information to be displayed
  495. Returns:
  496. Progress update to be rendered in a notebook, including HTML
  497. tables and formatted error messages. Includes
  498. - Duration of the tune job
  499. - Memory consumption
  500. - Trial progress table, with information about each experiment
  501. """
  502. if not self._metrics_override:
  503. user_metrics = self._infer_user_metrics(trials, self._infer_limit)
  504. self._metric_columns.update(user_metrics)
  505. current_time, running_for = _get_time_str(self._start_time, time.time())
  506. used_gb, total_gb, memory_message = _get_memory_usage()
  507. status_table = tabulate(
  508. [
  509. ("Current time:", current_time),
  510. ("Running for:", running_for),
  511. ("Memory:", f"{used_gb}/{total_gb} GiB"),
  512. ],
  513. tablefmt="html",
  514. )
  515. trial_progress_data = _trial_progress_table(
  516. trials=trials,
  517. metric_columns=self._metric_columns,
  518. parameter_columns=self._parameter_columns,
  519. fmt="html",
  520. max_rows=None if done else self._max_progress_rows,
  521. metric=self._metric,
  522. mode=self._mode,
  523. sort_by_metric=self._sort_by_metric,
  524. max_column_length=self._max_column_length,
  525. )
  526. trial_progress = trial_progress_data[0]
  527. trial_progress_messages = trial_progress_data[1:]
  528. trial_errors = _trial_errors_str(
  529. trials, fmt="html", max_rows=None if done else self._max_error_rows
  530. )
  531. if any([memory_message, trial_progress_messages, trial_errors]):
  532. msg = Template("tune_status_messages.html.j2").render(
  533. memory_message=memory_message,
  534. trial_progress_messages=trial_progress_messages,
  535. trial_errors=trial_errors,
  536. )
  537. else:
  538. msg = None
  539. return Template("tune_status.html.j2").render(
  540. status_table=status_table,
  541. sys_info_message=_generate_sys_info_str(*sys_info),
  542. trial_progress=trial_progress,
  543. messages=msg,
  544. )
  545. @PublicAPI
  546. class CLIReporter(TuneReporterBase):
  547. """Command-line reporter
  548. Args:
  549. metric_columns: Names of metrics to
  550. include in progress table. If this is a dict, the keys should
  551. be metric names and the values should be the displayed names.
  552. If this is a list, the metric name is used directly.
  553. parameter_columns: Names of parameters to
  554. include in progress table. If this is a dict, the keys should
  555. be parameter names and the values should be the displayed names.
  556. If this is a list, the parameter name is used directly. If empty,
  557. defaults to all available parameters.
  558. max_progress_rows: Maximum number of rows to print
  559. in the progress table. The progress table describes the
  560. progress of each trial. Defaults to 20.
  561. max_error_rows: Maximum number of rows to print in the
  562. error table. The error table lists the error file, if any,
  563. corresponding to each trial. Defaults to 20.
  564. max_column_length: Maximum column length (in characters). Column
  565. headers and values longer than this will be abbreviated.
  566. max_report_frequency: Maximum report frequency in seconds.
  567. Defaults to 5s.
  568. infer_limit: Maximum number of metrics to automatically infer
  569. from tune results.
  570. print_intermediate_tables: Print intermediate result
  571. tables. If None (default), will be set to True for verbosity
  572. levels above 3, otherwise False. If True, intermediate tables
  573. will be printed with experiment progress. If False, tables
  574. will only be printed at then end of the tuning run for verbosity
  575. levels greater than 2.
  576. metric: Metric used to determine best current trial.
  577. mode: One of [min, max]. Determines whether objective is
  578. minimizing or maximizing the metric attribute.
  579. sort_by_metric: Sort terminated trials by metric in the
  580. intermediate table. Defaults to False.
  581. """
  582. def __init__(
  583. self,
  584. *,
  585. metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
  586. parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
  587. total_samples: Optional[int] = None,
  588. max_progress_rows: int = 20,
  589. max_error_rows: int = 20,
  590. max_column_length: int = 20,
  591. max_report_frequency: int = 5,
  592. infer_limit: int = 3,
  593. print_intermediate_tables: Optional[bool] = None,
  594. metric: Optional[str] = None,
  595. mode: Optional[str] = None,
  596. sort_by_metric: bool = False,
  597. ):
  598. super(CLIReporter, self).__init__(
  599. metric_columns=metric_columns,
  600. parameter_columns=parameter_columns,
  601. total_samples=total_samples,
  602. max_progress_rows=max_progress_rows,
  603. max_error_rows=max_error_rows,
  604. max_column_length=max_column_length,
  605. max_report_frequency=max_report_frequency,
  606. infer_limit=infer_limit,
  607. print_intermediate_tables=print_intermediate_tables,
  608. metric=metric,
  609. mode=mode,
  610. sort_by_metric=sort_by_metric,
  611. )
  612. def _print(self, msg: str):
  613. safe_print(msg)
  614. def report(self, trials: List[Trial], done: bool, *sys_info: Dict):
  615. self._print(self._progress_str(trials, done, *sys_info))
  616. def _get_memory_usage() -> Tuple[float, float, Optional[str]]:
  617. """Get the current memory consumption.
  618. Returns:
  619. Memory used, memory available, and optionally a warning
  620. message to be shown to the user when memory consumption is higher
  621. than 90% or if `psutil` is not installed
  622. """
  623. try:
  624. import ray # noqa F401
  625. import psutil
  626. total_gb = psutil.virtual_memory().total / (1024**3)
  627. used_gb = total_gb - psutil.virtual_memory().available / (1024**3)
  628. if used_gb > total_gb * 0.9:
  629. message = (
  630. ": ***LOW MEMORY*** less than 10% of the memory on "
  631. "this node is available for use. This can cause "
  632. "unexpected crashes. Consider "
  633. "reducing the memory used by your application "
  634. "or reducing the Ray object store size by setting "
  635. "`object_store_memory` when calling `ray.init`."
  636. )
  637. else:
  638. message = None
  639. return round(used_gb, 1), round(total_gb, 1), message
  640. except ImportError:
  641. return (
  642. np.nan,
  643. np.nan,
  644. "Unknown memory usage. Please run `pip install psutil` to resolve",
  645. )
  646. def _get_time_str(start_time: float, current_time: float) -> Tuple[str, str]:
  647. """Get strings representing the current and elapsed time.
  648. Args:
  649. start_time: POSIX timestamp of the start of the tune run
  650. current_time: POSIX timestamp giving the current time
  651. Returns:
  652. Current time and elapsed time for the current run
  653. """
  654. current_time_dt = datetime.datetime.fromtimestamp(current_time)
  655. start_time_dt = datetime.datetime.fromtimestamp(start_time)
  656. delta: datetime.timedelta = current_time_dt - start_time_dt
  657. rest = delta.total_seconds()
  658. days = rest // (60 * 60 * 24)
  659. rest -= days * (60 * 60 * 24)
  660. hours = rest // (60 * 60)
  661. rest -= hours * (60 * 60)
  662. minutes = rest // 60
  663. seconds = rest - minutes * 60
  664. if days > 0:
  665. running_for_str = f"{days:.0f} days, "
  666. else:
  667. running_for_str = ""
  668. running_for_str += f"{hours:02.0f}:{minutes:02.0f}:{seconds:05.2f}"
  669. return f"{current_time_dt:%Y-%m-%d %H:%M:%S}", running_for_str
  670. def _time_passed_str(start_time: float, current_time: float) -> str:
  671. """Generate a message describing the current and elapsed time in the run.
  672. Args:
  673. start_time: POSIX timestamp of the start of the tune run
  674. current_time: POSIX timestamp giving the current time
  675. Returns:
  676. Message with the current and elapsed time for the current tune run,
  677. formatted to be displayed to the user
  678. """
  679. current_time_str, running_for_str = _get_time_str(start_time, current_time)
  680. return f"Current time: {current_time_str} " f"(running for {running_for_str})"
  681. def _get_trials_by_state(trials: List[Trial]):
  682. trials_by_state = collections.defaultdict(list)
  683. for t in trials:
  684. trials_by_state[t.status].append(t)
  685. return trials_by_state
  686. def _trial_progress_str(
  687. trials: List[Trial],
  688. metric_columns: Union[List[str], Dict[str, str]],
  689. parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
  690. total_samples: int = 0,
  691. force_table: bool = False,
  692. fmt: str = "psql",
  693. max_rows: Optional[int] = None,
  694. max_column_length: int = 20,
  695. done: bool = False,
  696. metric: Optional[str] = None,
  697. mode: Optional[str] = None,
  698. sort_by_metric: bool = False,
  699. ):
  700. """Returns a human readable message for printing to the console.
  701. This contains a table where each row represents a trial, its parameters
  702. and the current values of its metrics.
  703. Args:
  704. trials: List of trials to get progress string for.
  705. metric_columns: Names of metrics to include.
  706. If this is a dict, the keys are metric names and the values are
  707. the names to use in the message. If this is a list, the metric
  708. name is used in the message directly.
  709. parameter_columns: Names of parameters to
  710. include. If this is a dict, the keys are parameter names and the
  711. values are the names to use in the message. If this is a list,
  712. the parameter name is used in the message directly. If this is
  713. empty, all parameters are used in the message.
  714. total_samples: Total number of trials that will be generated.
  715. force_table: Force printing a table. If False, a table will
  716. be printed only at the end of the training for verbosity levels
  717. above `Verbosity.V2_TRIAL_NORM`.
  718. fmt: Output format (see tablefmt in tabulate API).
  719. max_rows: Maximum number of rows in the trial table. Defaults to
  720. unlimited.
  721. max_column_length: Maximum column length (in characters).
  722. done: True indicates that the tuning run finished.
  723. metric: Metric used to sort trials.
  724. mode: One of [min, max]. Determines whether objective is
  725. minimizing or maximizing the metric attribute.
  726. sort_by_metric: Sort terminated trials by metric in the
  727. intermediate table. Defaults to False.
  728. """
  729. messages = []
  730. delim = "<br>" if fmt == "html" else "\n"
  731. if len(trials) < 1:
  732. return delim.join(messages)
  733. num_trials = len(trials)
  734. trials_by_state = _get_trials_by_state(trials)
  735. for local_dir in sorted({t.local_experiment_path for t in trials}):
  736. messages.append("Result logdir: {}".format(local_dir))
  737. num_trials_strs = [
  738. "{} {}".format(len(trials_by_state[state]), state)
  739. for state in sorted(trials_by_state)
  740. ]
  741. if total_samples and total_samples >= sys.maxsize:
  742. total_samples = "infinite"
  743. messages.append(
  744. "Number of trials: {}{} ({})".format(
  745. num_trials,
  746. f"/{total_samples}" if total_samples else "",
  747. ", ".join(num_trials_strs),
  748. )
  749. )
  750. if force_table or (has_verbosity(Verbosity.V2_TRIAL_NORM) and done):
  751. messages += _trial_progress_table(
  752. trials=trials,
  753. metric_columns=metric_columns,
  754. parameter_columns=parameter_columns,
  755. fmt=fmt,
  756. max_rows=max_rows,
  757. metric=metric,
  758. mode=mode,
  759. sort_by_metric=sort_by_metric,
  760. max_column_length=max_column_length,
  761. )
  762. return delim.join(messages)
  763. def _max_len(
  764. value: Any, max_len: int = 20, add_addr: bool = False, wrap: bool = False
  765. ) -> Any:
  766. """Abbreviate a string representation of an object to `max_len` characters.
  767. For numbers, booleans and None, the original value will be returned for
  768. correct rendering in the table formatting tool.
  769. Args:
  770. value: Object to be represented as a string.
  771. max_len: Maximum return string length.
  772. add_addr: If True, will add part of the object address to the end of the
  773. string, e.g. to identify different instances of the same class. If
  774. False, three dots (``...``) will be used instead.
  775. """
  776. if value is None or isinstance(value, (int, float, numbers.Number, bool)):
  777. return value
  778. string = str(value)
  779. if len(string) <= max_len:
  780. return string
  781. if wrap:
  782. # Maximum two rows.
  783. # Todo: Make this configurable in the refactor
  784. if len(value) > max_len * 2:
  785. value = "..." + string[(3 - (max_len * 2)) :]
  786. wrapped = textwrap.wrap(value, width=max_len)
  787. return "\n".join(wrapped)
  788. if add_addr and not isinstance(value, (int, float, bool)):
  789. result = f"{string[: (max_len - 5)]}_{hex(id(value))[-4:]}"
  790. return result
  791. result = "..." + string[(3 - max_len) :]
  792. return result
  793. def _get_progress_table_data(
  794. trials: List[Trial],
  795. metric_columns: Union[List[str], Dict[str, str]],
  796. parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
  797. max_rows: Optional[int] = None,
  798. metric: Optional[str] = None,
  799. mode: Optional[str] = None,
  800. sort_by_metric: bool = False,
  801. max_column_length: int = 20,
  802. ) -> Tuple[List, List[str], Tuple[bool, str]]:
  803. """Generate a table showing the current progress of tuning trials.
  804. Args:
  805. trials: List of trials for which progress is to be shown.
  806. metric_columns: Metrics to be displayed in the table.
  807. parameter_columns: List of parameters to be included in the data
  808. max_rows: Maximum number of rows to show. If there's overflow, a
  809. message will be shown to the user indicating that some rows
  810. are not displayed
  811. metric: Metric which is being tuned
  812. mode: Sort the table in descending order if mode is "max";
  813. ascending otherwise
  814. sort_by_metric: If true, the table will be sorted by the metric
  815. max_column_length: Max number of characters in each column
  816. Returns:
  817. - Trial data
  818. - List of column names
  819. - Overflow tuple:
  820. - boolean indicating whether the table has rows which are hidden
  821. - string with info about the overflowing rows
  822. """
  823. num_trials = len(trials)
  824. trials_by_state = _get_trials_by_state(trials)
  825. # Sort terminated trials by metric and mode, descending if mode is "max"
  826. if sort_by_metric:
  827. trials_by_state[Trial.TERMINATED] = sorted(
  828. trials_by_state[Trial.TERMINATED],
  829. reverse=(mode == "max"),
  830. key=lambda t: unflattened_lookup(metric, t.last_result, default=None),
  831. )
  832. state_tbl_order = [
  833. Trial.RUNNING,
  834. Trial.PAUSED,
  835. Trial.PENDING,
  836. Trial.TERMINATED,
  837. Trial.ERROR,
  838. ]
  839. max_rows = max_rows or float("inf")
  840. if num_trials > max_rows:
  841. # TODO(ujvl): suggestion for users to view more rows.
  842. trials_by_state_trunc = _fair_filter_trials(
  843. trials_by_state, max_rows, sort_by_metric
  844. )
  845. trials = []
  846. overflow_strs = []
  847. for state in state_tbl_order:
  848. if state not in trials_by_state:
  849. continue
  850. trials += trials_by_state_trunc[state]
  851. num = len(trials_by_state[state]) - len(trials_by_state_trunc[state])
  852. if num > 0:
  853. overflow_strs.append("{} {}".format(num, state))
  854. # Build overflow string.
  855. overflow = num_trials - max_rows
  856. overflow_str = ", ".join(overflow_strs)
  857. else:
  858. overflow = False
  859. overflow_str = ""
  860. trials = []
  861. for state in state_tbl_order:
  862. if state not in trials_by_state:
  863. continue
  864. trials += trials_by_state[state]
  865. # Pre-process trials to figure out what columns to show.
  866. if isinstance(metric_columns, Mapping):
  867. metric_keys = list(metric_columns.keys())
  868. else:
  869. metric_keys = metric_columns
  870. metric_keys = [
  871. k
  872. for k in metric_keys
  873. if any(
  874. unflattened_lookup(k, t.last_result, default=None) is not None
  875. for t in trials
  876. )
  877. ]
  878. if not parameter_columns:
  879. parameter_keys = sorted(set().union(*[t.evaluated_params for t in trials]))
  880. elif isinstance(parameter_columns, Mapping):
  881. parameter_keys = list(parameter_columns.keys())
  882. else:
  883. parameter_keys = parameter_columns
  884. # Build trial rows.
  885. trial_table = [
  886. _get_trial_info(
  887. trial, parameter_keys, metric_keys, max_column_length=max_column_length
  888. )
  889. for trial in trials
  890. ]
  891. # Format column headings
  892. if isinstance(metric_columns, Mapping):
  893. formatted_metric_columns = [
  894. _max_len(
  895. metric_columns[k], max_len=max_column_length, add_addr=False, wrap=True
  896. )
  897. for k in metric_keys
  898. ]
  899. else:
  900. formatted_metric_columns = [
  901. _max_len(k, max_len=max_column_length, add_addr=False, wrap=True)
  902. for k in metric_keys
  903. ]
  904. if isinstance(parameter_columns, Mapping):
  905. formatted_parameter_columns = [
  906. _max_len(
  907. parameter_columns[k],
  908. max_len=max_column_length,
  909. add_addr=False,
  910. wrap=True,
  911. )
  912. for k in parameter_keys
  913. ]
  914. else:
  915. formatted_parameter_columns = [
  916. _max_len(k, max_len=max_column_length, add_addr=False, wrap=True)
  917. for k in parameter_keys
  918. ]
  919. columns = (
  920. ["Trial name", "status", "loc"]
  921. + formatted_parameter_columns
  922. + formatted_metric_columns
  923. )
  924. return trial_table, columns, (overflow, overflow_str)
  925. def _trial_progress_table(
  926. trials: List[Trial],
  927. metric_columns: Union[List[str], Dict[str, str]],
  928. parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
  929. fmt: str = "psql",
  930. max_rows: Optional[int] = None,
  931. metric: Optional[str] = None,
  932. mode: Optional[str] = None,
  933. sort_by_metric: bool = False,
  934. max_column_length: int = 20,
  935. ) -> List[str]:
  936. """Generate a list of trial progress table messages.
  937. Args:
  938. trials: List of trials for which progress is to be shown.
  939. metric_columns: Metrics to be displayed in the table.
  940. parameter_columns: List of parameters to be included in the data
  941. fmt: Format of the table; passed to tabulate as the fmtstr argument
  942. max_rows: Maximum number of rows to show. If there's overflow, a
  943. message will be shown to the user indicating that some rows
  944. are not displayed
  945. metric: Metric which is being tuned
  946. mode: Sort the table in descenting order if mode is "max";
  947. ascending otherwise
  948. sort_by_metric: If true, the table will be sorted by the metric
  949. max_column_length: Max number of characters in each column
  950. Returns:
  951. Messages to be shown to the user containing progress tables
  952. """
  953. data, columns, (overflow, overflow_str) = _get_progress_table_data(
  954. trials,
  955. metric_columns,
  956. parameter_columns,
  957. max_rows,
  958. metric,
  959. mode,
  960. sort_by_metric,
  961. max_column_length,
  962. )
  963. messages = [tabulate(data, headers=columns, tablefmt=fmt, showindex=False)]
  964. if overflow:
  965. messages.append(f"... {overflow} more trials not shown ({overflow_str})")
  966. return messages
  967. def _generate_sys_info_str(*sys_info) -> str:
  968. """Format system info into a string.
  969. *sys_info: System info strings to be included.
  970. Returns:
  971. Formatted string containing system information.
  972. """
  973. if sys_info:
  974. return "<br>".join(sys_info).replace("\n", "<br>")
  975. return ""
  976. def _trial_errors_str(
  977. trials: List[Trial], fmt: str = "psql", max_rows: Optional[int] = None
  978. ):
  979. """Returns a readable message regarding trial errors.
  980. Args:
  981. trials: List of trials to get progress string for.
  982. fmt: Output format (see tablefmt in tabulate API).
  983. max_rows: Maximum number of rows in the error table. Defaults to
  984. unlimited.
  985. """
  986. messages = []
  987. failed = [t for t in trials if t.error_file]
  988. num_failed = len(failed)
  989. if num_failed > 0:
  990. messages.append("Number of errored trials: {}".format(num_failed))
  991. if num_failed > (max_rows or float("inf")):
  992. messages.append(
  993. "Table truncated to {} rows ({} overflow)".format(
  994. max_rows, num_failed - max_rows
  995. )
  996. )
  997. fail_header = ["Trial name", "# failures", "error file"]
  998. fail_table_data = [
  999. [
  1000. str(trial),
  1001. str(trial.run_metadata.num_failures)
  1002. + ("" if trial.status == Trial.ERROR else "*"),
  1003. trial.error_file,
  1004. ]
  1005. for trial in failed[:max_rows]
  1006. ]
  1007. messages.append(
  1008. tabulate(
  1009. fail_table_data,
  1010. headers=fail_header,
  1011. tablefmt=fmt,
  1012. showindex=False,
  1013. colalign=("left", "right", "left"),
  1014. )
  1015. )
  1016. if any(trial.status == Trial.TERMINATED for trial in failed[:max_rows]):
  1017. messages.append("* The trial terminated successfully after retrying.")
  1018. delim = "<br>" if fmt == "html" else "\n"
  1019. return delim.join(messages)
  1020. def _best_trial_str(
  1021. trial: Trial,
  1022. metric: str,
  1023. parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
  1024. ):
  1025. """Returns a readable message stating the current best trial."""
  1026. val = unflattened_lookup(metric, trial.last_result, default=None)
  1027. config = trial.last_result.get("config", {})
  1028. parameter_columns = parameter_columns or list(config.keys())
  1029. if isinstance(parameter_columns, Mapping):
  1030. parameter_columns = parameter_columns.keys()
  1031. params = {p: unflattened_lookup(p, config) for p in parameter_columns}
  1032. return (
  1033. f"Current best trial: {trial.trial_id} with {metric}={val} and "
  1034. f"parameters={params}"
  1035. )
  1036. def _fair_filter_trials(
  1037. trials_by_state: Dict[str, List[Trial]],
  1038. max_trials: int,
  1039. sort_by_metric: bool = False,
  1040. ):
  1041. """Filters trials such that each state is represented fairly.
  1042. The oldest trials are truncated if necessary.
  1043. Args:
  1044. trials_by_state: Maximum number of trials to return.
  1045. Returns:
  1046. Dict mapping state to List of fairly represented trials.
  1047. """
  1048. num_trials_by_state = collections.defaultdict(int)
  1049. no_change = False
  1050. # Determine number of trials to keep per state.
  1051. while max_trials > 0 and not no_change:
  1052. no_change = True
  1053. for state in sorted(trials_by_state):
  1054. if num_trials_by_state[state] < len(trials_by_state[state]):
  1055. no_change = False
  1056. max_trials -= 1
  1057. num_trials_by_state[state] += 1
  1058. # Sort by start time, descending if the trails is not sorted by metric.
  1059. sorted_trials_by_state = dict()
  1060. for state in sorted(trials_by_state):
  1061. if state == Trial.TERMINATED and sort_by_metric:
  1062. sorted_trials_by_state[state] = trials_by_state[state]
  1063. else:
  1064. sorted_trials_by_state[state] = sorted(
  1065. trials_by_state[state], reverse=False, key=lambda t: t.trial_id
  1066. )
  1067. # Truncate oldest trials.
  1068. filtered_trials = {
  1069. state: sorted_trials_by_state[state][: num_trials_by_state[state]]
  1070. for state in sorted(trials_by_state)
  1071. }
  1072. return filtered_trials
  1073. def _get_trial_location(trial: Trial, result: dict) -> _Location:
  1074. # we get the location from the result, as the one in trial will be
  1075. # reset when trial terminates
  1076. node_ip, pid = result.get(NODE_IP, None), result.get(PID, None)
  1077. if node_ip and pid:
  1078. location = _Location(node_ip, pid)
  1079. else:
  1080. # fallback to trial location if there hasn't been a report yet
  1081. location = trial.temporary_state.location
  1082. return location
  1083. def _get_trial_info(
  1084. trial: Trial, parameters: List[str], metrics: List[str], max_column_length: int = 20
  1085. ):
  1086. """Returns the following information about a trial:
  1087. name | status | loc | params... | metrics...
  1088. Args:
  1089. trial: Trial to get information for.
  1090. parameters: Names of trial parameters to include.
  1091. metrics: Names of metrics to include.
  1092. max_column_length: Maximum column length (in characters).
  1093. """
  1094. result = trial.last_result
  1095. config = trial.config
  1096. location = _get_trial_location(trial, result)
  1097. trial_info = [str(trial), trial.status, str(location)]
  1098. trial_info += [
  1099. _max_len(
  1100. unflattened_lookup(param, config, default=None),
  1101. max_len=max_column_length,
  1102. add_addr=True,
  1103. )
  1104. for param in parameters
  1105. ]
  1106. trial_info += [
  1107. _max_len(
  1108. unflattened_lookup(metric, result, default=None),
  1109. max_len=max_column_length,
  1110. add_addr=True,
  1111. )
  1112. for metric in metrics
  1113. ]
  1114. return trial_info
  1115. @DeveloperAPI
  1116. class TrialProgressCallback(Callback):
  1117. """Reports (prints) intermediate trial progress.
  1118. This callback is automatically added to the callback stack. When a
  1119. result is obtained, this callback will print the results according to
  1120. the specified verbosity level.
  1121. For ``Verbosity.V3_TRIAL_DETAILS``, a full result list is printed.
  1122. For ``Verbosity.V2_TRIAL_NORM``, only one line is printed per received
  1123. result.
  1124. All other verbosity levels do not print intermediate trial progress.
  1125. Result printing is throttled on a per-trial basis. Per default, results are
  1126. printed only once every 30 seconds. Results are always printed when a trial
  1127. finished or errored.
  1128. """
  1129. def __init__(
  1130. self, metric: Optional[str] = None, progress_metrics: Optional[List[str]] = None
  1131. ):
  1132. self._last_print = collections.defaultdict(float)
  1133. self._last_print_iteration = collections.defaultdict(int)
  1134. self._completed_trials = set()
  1135. self._last_result_str = {}
  1136. self._metric = metric
  1137. self._progress_metrics = set(progress_metrics or [])
  1138. # Only use progress metrics if at least two metrics are in there
  1139. if self._metric and self._progress_metrics:
  1140. self._progress_metrics.add(self._metric)
  1141. self._last_result = {}
  1142. self._display_handle = None
  1143. def _print(self, msg: str):
  1144. safe_print(msg)
  1145. def on_trial_result(
  1146. self,
  1147. iteration: int,
  1148. trials: List["Trial"],
  1149. trial: "Trial",
  1150. result: Dict,
  1151. **info,
  1152. ):
  1153. self.log_result(trial, result, error=False)
  1154. def on_trial_error(
  1155. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  1156. ):
  1157. self.log_result(trial, trial.last_result, error=True)
  1158. def on_trial_complete(
  1159. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  1160. ):
  1161. # Only log when we never logged that a trial was completed
  1162. if trial not in self._completed_trials:
  1163. self._completed_trials.add(trial)
  1164. print_result_str = self._print_result(trial.last_result)
  1165. last_result_str = self._last_result_str.get(trial, "")
  1166. # If this is a new result, print full result string
  1167. if print_result_str != last_result_str:
  1168. self.log_result(trial, trial.last_result, error=False)
  1169. else:
  1170. self._print(f"Trial {trial} completed. Last result: {print_result_str}")
  1171. def log_result(self, trial: "Trial", result: Dict, error: bool = False):
  1172. done = result.get("done", False) is True
  1173. last_print = self._last_print[trial]
  1174. should_print = done or error or time.time() - last_print > DEBUG_PRINT_INTERVAL
  1175. if done and trial not in self._completed_trials:
  1176. self._completed_trials.add(trial)
  1177. if should_print:
  1178. if IS_NOTEBOOK:
  1179. self.display_result(trial, result, error, done)
  1180. else:
  1181. self.print_result(trial, result, error, done)
  1182. self._last_print[trial] = time.time()
  1183. if TRAINING_ITERATION in result:
  1184. self._last_print_iteration[trial] = result[TRAINING_ITERATION]
  1185. def print_result(self, trial: Trial, result: Dict, error: bool, done: bool):
  1186. """Print the most recent results for the given trial to stdout.
  1187. Args:
  1188. trial: Trial for which results are to be printed
  1189. result: Result to be printed
  1190. error: True if an error has occurred, False otherwise
  1191. done: True if the trial is finished, False otherwise
  1192. """
  1193. last_print_iteration = self._last_print_iteration[trial]
  1194. if has_verbosity(Verbosity.V3_TRIAL_DETAILS):
  1195. if result.get(TRAINING_ITERATION) != last_print_iteration:
  1196. self._print(f"Result for {trial}:")
  1197. self._print(" {}".format(pretty_print(result).replace("\n", "\n ")))
  1198. if done:
  1199. self._print(f"Trial {trial} completed.")
  1200. elif has_verbosity(Verbosity.V2_TRIAL_NORM):
  1201. metric_name = self._metric or "_metric"
  1202. metric_value = result.get(metric_name, -99.0)
  1203. error_file = Path(trial.local_path, EXPR_ERROR_FILE).as_posix()
  1204. info = ""
  1205. if done:
  1206. info = " This trial completed."
  1207. print_result_str = self._print_result(result)
  1208. self._last_result_str[trial] = print_result_str
  1209. if error:
  1210. message = (
  1211. f"The trial {trial} errored with "
  1212. f"parameters={trial.config}. "
  1213. f"Error file: {error_file}"
  1214. )
  1215. elif self._metric:
  1216. message = (
  1217. f"Trial {trial} reported "
  1218. f"{metric_name}={metric_value:.2f} "
  1219. f"with parameters={trial.config}.{info}"
  1220. )
  1221. else:
  1222. message = (
  1223. f"Trial {trial} reported "
  1224. f"{print_result_str} "
  1225. f"with parameters={trial.config}.{info}"
  1226. )
  1227. self._print(message)
  1228. def generate_trial_table(
  1229. self, trials: Dict[Trial, Dict], columns: List[str]
  1230. ) -> str:
  1231. """Generate an HTML table of trial progress info.
  1232. Trials (rows) are sorted by name; progress stats (columns) are sorted
  1233. as well.
  1234. Args:
  1235. trials: Trials and their associated latest results
  1236. columns: Columns to show in the table; must be a list of valid
  1237. keys for each Trial result
  1238. Returns:
  1239. HTML template containing a rendered table of progress info
  1240. """
  1241. data = []
  1242. columns = sorted(columns)
  1243. sorted_trials = collections.OrderedDict(
  1244. sorted(self._last_result.items(), key=lambda item: str(item[0]))
  1245. )
  1246. for trial, result in sorted_trials.items():
  1247. data.append([str(trial)] + [result.get(col, "") for col in columns])
  1248. return Template("trial_progress.html.j2").render(
  1249. table=tabulate(
  1250. data, tablefmt="html", headers=["Trial name"] + columns, showindex=False
  1251. )
  1252. )
  1253. def display_result(self, trial: Trial, result: Dict, error: bool, done: bool):
  1254. """Display a formatted HTML table of trial progress results.
  1255. Trial progress is only shown if verbosity is set to level 2 or 3.
  1256. Args:
  1257. trial: Trial for which results are to be printed
  1258. result: Result to be printed
  1259. error: True if an error has occurred, False otherwise
  1260. done: True if the trial is finished, False otherwise
  1261. """
  1262. from IPython.display import HTML, display
  1263. self._last_result[trial] = result
  1264. if has_verbosity(Verbosity.V3_TRIAL_DETAILS):
  1265. ignored_keys = {
  1266. "config",
  1267. "hist_stats",
  1268. }
  1269. elif has_verbosity(Verbosity.V2_TRIAL_NORM):
  1270. ignored_keys = {
  1271. "config",
  1272. "hist_stats",
  1273. "trial_id",
  1274. "experiment_tag",
  1275. "done",
  1276. } | set(AUTO_RESULT_KEYS)
  1277. else:
  1278. return
  1279. table = self.generate_trial_table(
  1280. self._last_result, set(result.keys()) - ignored_keys
  1281. )
  1282. if not self._display_handle:
  1283. self._display_handle = display(HTML(table), display_id=True)
  1284. else:
  1285. self._display_handle.update(HTML(table))
  1286. def _print_result(self, result: Dict):
  1287. if self._progress_metrics:
  1288. # If progress metrics are given, only report these
  1289. flat_result = flatten_dict(result)
  1290. print_result = {}
  1291. for metric in self._progress_metrics:
  1292. print_result[metric] = flat_result.get(metric)
  1293. else:
  1294. # Else, skip auto populated results
  1295. print_result = result.copy()
  1296. for skip_result in SKIP_RESULTS_IN_REPORT:
  1297. print_result.pop(skip_result, None)
  1298. for auto_result in AUTO_RESULT_KEYS:
  1299. print_result.pop(auto_result, None)
  1300. print_result_str = ",".join(
  1301. [f"{k}={v}" for k, v in print_result.items() if v is not None]
  1302. )
  1303. return print_result_str
  1304. def _detect_reporter(_trainer_api: bool = False, **kwargs) -> TuneReporterBase:
  1305. """Detect progress reporter class.
  1306. Will return a :class:`JupyterNotebookReporter` if a IPython/Jupyter-like
  1307. session was detected, and a :class:`CLIReporter` otherwise.
  1308. Keyword arguments are passed on to the reporter class.
  1309. """
  1310. if IS_NOTEBOOK and not _trainer_api:
  1311. kwargs.setdefault("overwrite", not has_verbosity(Verbosity.V2_TRIAL_NORM))
  1312. progress_reporter = JupyterNotebookReporter(**kwargs)
  1313. else:
  1314. progress_reporter = CLIReporter(**kwargs)
  1315. return progress_reporter
  1316. def _detect_progress_metrics(
  1317. trainable: Optional[Union["Trainable", Callable]]
  1318. ) -> Optional[Collection[str]]:
  1319. """Detect progress metrics to report."""
  1320. if not trainable:
  1321. return None
  1322. return getattr(trainable, "_progress_metrics", None)
  1323. def _prepare_progress_reporter_for_ray_client(
  1324. progress_reporter: ProgressReporter,
  1325. verbosity: Union[int, Verbosity],
  1326. string_queue: Optional[Queue] = None,
  1327. ) -> Tuple[ProgressReporter, Queue]:
  1328. """Prepares progress reported for Ray Client by setting the string queue.
  1329. The string queue will be created if it's None."""
  1330. set_verbosity(verbosity)
  1331. progress_reporter = progress_reporter or _detect_reporter()
  1332. # JupyterNotebooks don't work with remote tune runs out of the box
  1333. # (e.g. via Ray client) as they don't have access to the main
  1334. # process stdout. So we introduce a queue here that accepts
  1335. # strings, which will then be displayed on the driver side.
  1336. if isinstance(progress_reporter, RemoteReporterMixin):
  1337. if string_queue is None:
  1338. string_queue = Queue(
  1339. actor_options={"num_cpus": 0, **_force_on_current_node(None)}
  1340. )
  1341. progress_reporter.output_queue = string_queue
  1342. return progress_reporter, string_queue
  1343. def _stream_client_output(
  1344. remote_future: ray.ObjectRef,
  1345. progress_reporter: ProgressReporter,
  1346. string_queue: Queue,
  1347. ) -> Any:
  1348. """
  1349. Stream items from string queue to progress_reporter until remote_future resolves
  1350. """
  1351. if string_queue is None:
  1352. return
  1353. def get_next_queue_item():
  1354. try:
  1355. return string_queue.get(block=False)
  1356. except Empty:
  1357. return None
  1358. def _handle_string_queue():
  1359. string_item = get_next_queue_item()
  1360. while string_item is not None:
  1361. # This happens on the driver side
  1362. progress_reporter.display(string_item)
  1363. string_item = get_next_queue_item()
  1364. # ray.wait(...)[1] returns futures that are not ready, yet
  1365. while ray.wait([remote_future], timeout=0.2)[1]:
  1366. # Check if we have items to execute
  1367. _handle_string_queue()
  1368. # Handle queue one last time
  1369. _handle_string_queue()