output.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043
  1. import argparse
  2. import collections
  3. import datetime
  4. import logging
  5. import math
  6. import numbers
  7. import os
  8. import sys
  9. import textwrap
  10. import time
  11. from dataclasses import dataclass
  12. from enum import IntEnum
  13. from typing import Any, Collection, Dict, Iterable, List, Optional, Tuple, Union
  14. import numpy as np
  15. import pandas as pd
  16. import ray
  17. from ray._private.dict import flatten_dict, unflattened_lookup
  18. from ray._private.thirdparty.tabulate.tabulate import (
  19. DataRow,
  20. Line,
  21. TableFormat,
  22. tabulate,
  23. )
  24. from ray.air._internal.usage import AirEntrypoint
  25. from ray.air.constants import TRAINING_ITERATION
  26. from ray.tune import Checkpoint
  27. from ray.tune.callback import Callback
  28. from ray.tune.experiment.trial import Trial
  29. from ray.tune.result import (
  30. AUTO_RESULT_KEYS,
  31. EPISODE_REWARD_MEAN,
  32. MEAN_ACCURACY,
  33. MEAN_LOSS,
  34. TIME_TOTAL_S,
  35. TIMESTEPS_TOTAL,
  36. )
  37. from ray.tune.search.sample import Domain
  38. from ray.tune.utils.log import Verbosity
  39. try:
  40. import rich
  41. import rich.layout
  42. import rich.live
  43. except ImportError:
  44. rich = None
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Maps a key in a trial result to the (shorter) column title printed in the
# trial table. Keys listed here are also skipped when inferring user metrics.
# Note this is ordered! The metric columns are considered in this order.
DEFAULT_COLUMNS = collections.OrderedDict(
    {
        MEAN_ACCURACY: "acc",
        MEAN_LOSS: "loss",
        TRAINING_ITERATION: "iter",
        TIME_TOTAL_S: "total time (s)",
        TIMESTEPS_TOTAL: "ts",
        EPISODE_REWARD_MEAN: "reward",
    }
)

# Result keys that are excluded when a trial's intermediate/final result is
# printed as a table (passed as ``exclude`` by ``ProgressReporter._print_result``).
BLACKLISTED_KEYS = {
    "config",
    "date",
    "done",
    "hostname",
    "iterations_since_restore",
    "node_ip",
    "pid",
    "time_since_restore",
    "timestamp",
    "trial_id",
    "experiment_tag",
    "should_checkpoint",
    "_report_on",  # LIGHTNING_REPORT_STAGE_KEY
}

# Value types that qualify a result entry as an inferred user metric.
# ``_infer_user_metrics`` matches on the exact ``type(value)``, so subclasses
# of these types (e.g. ``bool``) do not qualify.
VALID_SUMMARY_TYPES = {
    int,
    float,
    np.float32,
    np.float64,
    np.int32,
    np.int64,
    type(None),
}

# The order of summarizing trials: per-status groups appear in the trial table
# in this order.
ORDER = [
    Trial.RUNNING,
    Trial.TERMINATED,
    Trial.PAUSED,
    Trial.PENDING,
    Trial.ERROR,
]
  91. class AirVerbosity(IntEnum):
  92. SILENT = 0
  93. DEFAULT = 1
  94. VERBOSE = 2
  95. def __repr__(self):
  96. return str(self.value)
# Evaluated once, when this module is imported.
# NOTE(review): presumably True when running inside a Jupyter/IPython
# notebook - confirm against ``ray.widgets.util.in_notebook``.
IS_NOTEBOOK = ray.widgets.util.in_notebook()
  98. def get_air_verbosity(
  99. verbose: Union[int, AirVerbosity, Verbosity]
  100. ) -> Optional[AirVerbosity]:
  101. if os.environ.get("RAY_AIR_NEW_OUTPUT", "1") == "0":
  102. return None
  103. if isinstance(verbose, AirVerbosity):
  104. return verbose
  105. verbose_int = verbose if isinstance(verbose, int) else verbose.value
  106. # Verbosity 2 and 3 both map to AirVerbosity 2
  107. verbose_int = min(2, verbose_int)
  108. return AirVerbosity(verbose_int)
  109. def _infer_params(config: Dict[str, Any]) -> List[str]:
  110. params = []
  111. flat_config = flatten_dict(config)
  112. for key, val in flat_config.items():
  113. if isinstance(val, Domain):
  114. params.append(key)
  115. # Grid search is a special named field. Because we flattened
  116. # the whole config, we look it up per string
  117. if key.endswith("/grid_search"):
  118. # Truncate `/grid_search`
  119. params.append(key[:-12])
  120. return params
  121. def _get_time_str(start_time: float, current_time: float) -> Tuple[str, str]:
  122. """Get strings representing the current and elapsed time.
  123. Args:
  124. start_time: POSIX timestamp of the start of the tune run
  125. current_time: POSIX timestamp giving the current time
  126. Returns:
  127. Current time and elapsed time for the current run
  128. """
  129. current_time_dt = datetime.datetime.fromtimestamp(current_time)
  130. start_time_dt = datetime.datetime.fromtimestamp(start_time)
  131. delta: datetime.timedelta = current_time_dt - start_time_dt
  132. rest = delta.total_seconds()
  133. days = int(rest // (60 * 60 * 24))
  134. rest -= days * (60 * 60 * 24)
  135. hours = int(rest // (60 * 60))
  136. rest -= hours * (60 * 60)
  137. minutes = int(rest // 60)
  138. seconds = int(rest - minutes * 60)
  139. running_for_str = ""
  140. if days > 0:
  141. running_for_str += f"{days:d}d "
  142. if hours > 0 or running_for_str:
  143. running_for_str += f"{hours:d}hr "
  144. if minutes > 0 or running_for_str:
  145. running_for_str += f"{minutes:d}min "
  146. running_for_str += f"{seconds:d}s"
  147. return f"{current_time_dt:%Y-%m-%d %H:%M:%S}", running_for_str
  148. def _get_trials_by_state(trials: List[Trial]) -> Dict[str, List[Trial]]:
  149. trials_by_state = collections.defaultdict(list)
  150. for t in trials:
  151. trials_by_state[t.status].append(t)
  152. return trials_by_state
  153. def _get_trials_with_error(trials: List[Trial]) -> List[Trial]:
  154. return [t for t in trials if t.error_file]
  155. def _infer_user_metrics(trials: List[Trial], limit: int = 4) -> List[str]:
  156. """Try to infer the metrics to print out.
  157. By default, only the first 4 meaningful metrics in `last_result` will be
  158. inferred as user implied metrics.
  159. """
  160. # Using OrderedDict for OrderedSet.
  161. result = collections.OrderedDict()
  162. for t in trials:
  163. if not t.last_result:
  164. continue
  165. for metric, value in t.last_result.items():
  166. if metric not in DEFAULT_COLUMNS:
  167. if metric not in AUTO_RESULT_KEYS:
  168. if type(value) in VALID_SUMMARY_TYPES:
  169. result[metric] = "" # not important
  170. if len(result) >= limit:
  171. return list(result.keys())
  172. return list(result.keys())
  173. def _current_best_trial(
  174. trials: List[Trial], metric: Optional[str], mode: Optional[str]
  175. ) -> Tuple[Optional[Trial], Optional[str]]:
  176. """
  177. Returns the best trial and the metric key. If anything is empty or None,
  178. returns a trivial result of None, None.
  179. Args:
  180. trials: List of trials.
  181. metric: Metric that trials are being ranked.
  182. mode: One of "min" or "max".
  183. Returns:
  184. Best trial and the metric key.
  185. """
  186. if not trials or not metric or not mode:
  187. return None, None
  188. metric_op = 1.0 if mode == "max" else -1.0
  189. best_metric = float("-inf")
  190. best_trial = None
  191. for t in trials:
  192. if not t.last_result:
  193. continue
  194. metric_value = unflattened_lookup(metric, t.last_result, default=None)
  195. if pd.isnull(metric_value):
  196. continue
  197. if not best_trial or metric_value * metric_op > best_metric:
  198. best_metric = metric_value * metric_op
  199. best_trial = t
  200. return best_trial, metric
  201. @dataclass
  202. class _PerStatusTrialTableData:
  203. trial_infos: List[List[str]]
  204. more_info: str
@dataclass
class _TrialTableData:
    # Column titles of the trial table.
    header: List[str]
    # Per-status row groups; as built by ``_get_trial_table_data`` there is one
    # entry for each status in ``ORDER`` that has at least one trial.
    data: List[_PerStatusTrialTableData]
  209. def _max_len(value: Any, max_len: int = 20, wrap: bool = False) -> Any:
  210. """Abbreviate a string representation of an object to `max_len` characters.
  211. For numbers, booleans and None, the original value will be returned for
  212. correct rendering in the table formatting tool.
  213. Args:
  214. value: Object to be represented as a string.
  215. max_len: Maximum return string length.
  216. """
  217. if value is None or isinstance(value, (int, float, numbers.Number, bool)):
  218. return value
  219. string = str(value)
  220. if len(string) <= max_len:
  221. return string
  222. if wrap:
  223. # Maximum two rows.
  224. # Todo: Make this configurable in the refactor
  225. if len(value) > max_len * 2:
  226. value = "..." + string[(3 - (max_len * 2)) :]
  227. wrapped = textwrap.wrap(value, width=max_len)
  228. return "\n".join(wrapped)
  229. result = "..." + string[(3 - max_len) :]
  230. return result
  231. def _get_trial_info(
  232. trial: Trial, param_keys: List[str], metric_keys: List[str]
  233. ) -> List[str]:
  234. """Returns the following information about a trial:
  235. name | status | metrics...
  236. Args:
  237. trial: Trial to get information for.
  238. param_keys: Names of parameters to include.
  239. metric_keys: Names of metrics to include.
  240. """
  241. result = trial.last_result
  242. trial_info = [str(trial), trial.status]
  243. # params
  244. trial_info.extend(
  245. [
  246. _max_len(
  247. unflattened_lookup(param, trial.config, default=None),
  248. )
  249. for param in param_keys
  250. ]
  251. )
  252. # metrics
  253. trial_info.extend(
  254. [
  255. _max_len(
  256. unflattened_lookup(metric, result, default=None),
  257. )
  258. for metric in metric_keys
  259. ]
  260. )
  261. return trial_info
  262. def _get_trial_table_data_per_status(
  263. status: str,
  264. trials: List[Trial],
  265. param_keys: List[str],
  266. metric_keys: List[str],
  267. force_max_rows: bool = False,
  268. ) -> Optional[_PerStatusTrialTableData]:
  269. """Gather all information of trials pertained to one `status`.
  270. Args:
  271. status: The trial status of interest.
  272. trials: all the trials of that status.
  273. param_keys: *Ordered* list of parameters to be displayed in the table.
  274. metric_keys: *Ordered* list of metrics to be displayed in the table.
  275. Including both default and user defined.
  276. force_max_rows: Whether or not to enforce a max row number for this status.
  277. If True, only a max of `5` rows will be shown.
  278. Returns:
  279. All information of trials pertained to the `status`.
  280. """
  281. # TODO: configure it.
  282. max_row = 5 if force_max_rows else math.inf
  283. if not trials:
  284. return None
  285. trial_infos = list()
  286. more_info = None
  287. for t in trials:
  288. if len(trial_infos) >= max_row:
  289. remaining = len(trials) - max_row
  290. more_info = f"{remaining} more {status}"
  291. break
  292. trial_infos.append(_get_trial_info(t, param_keys, metric_keys))
  293. return _PerStatusTrialTableData(trial_infos, more_info)
  294. def _get_trial_table_data(
  295. trials: List[Trial],
  296. param_keys: List[str],
  297. metric_keys: List[str],
  298. all_rows: bool = False,
  299. wrap_headers: bool = False,
  300. ) -> _TrialTableData:
  301. """Generate a table showing the current progress of tuning trials.
  302. Args:
  303. trials: List of trials for which progress is to be shown.
  304. param_keys: Ordered list of parameters to be displayed in the table.
  305. metric_keys: Ordered list of metrics to be displayed in the table.
  306. Including both default and user defined.
  307. Will only be shown if at least one trial is having the key.
  308. all_rows: Force to show all rows.
  309. wrap_headers: If True, header columns can be wrapped with ``\n``.
  310. Returns:
  311. Trial table data, including header and trial table per each status.
  312. """
  313. # TODO: configure
  314. max_trial_num_to_show = 20
  315. max_column_length = 20
  316. trials_by_state = _get_trials_by_state(trials)
  317. # get the right metric to show.
  318. metric_keys = [
  319. k
  320. for k in metric_keys
  321. if any(
  322. unflattened_lookup(k, t.last_result, default=None) is not None
  323. for t in trials
  324. )
  325. ]
  326. # get header from metric keys
  327. formatted_metric_columns = [
  328. _max_len(k, max_len=max_column_length, wrap=wrap_headers) for k in metric_keys
  329. ]
  330. formatted_param_columns = [
  331. _max_len(k, max_len=max_column_length, wrap=wrap_headers) for k in param_keys
  332. ]
  333. metric_header = [
  334. DEFAULT_COLUMNS[metric] if metric in DEFAULT_COLUMNS else formatted
  335. for metric, formatted in zip(metric_keys, formatted_metric_columns)
  336. ]
  337. param_header = formatted_param_columns
  338. # Map to the abbreviated version if necessary.
  339. header = ["Trial name", "status"] + param_header + metric_header
  340. trial_data = list()
  341. for t_status in ORDER:
  342. trial_data_per_status = _get_trial_table_data_per_status(
  343. t_status,
  344. trials_by_state[t_status],
  345. param_keys=param_keys,
  346. metric_keys=metric_keys,
  347. force_max_rows=not all_rows and len(trials) > max_trial_num_to_show,
  348. )
  349. if trial_data_per_status:
  350. trial_data.append(trial_data_per_status)
  351. return _TrialTableData(header, trial_data)
  352. def _best_trial_str(
  353. trial: Trial,
  354. metric: str,
  355. ):
  356. """Returns a readable message stating the current best trial."""
  357. # returns something like
  358. # Current best trial: 18ae7_00005 with loss=0.5918508041056858 and params={'train_loop_config': {'lr': 0.059253447253394785}}. # noqa
  359. val = unflattened_lookup(metric, trial.last_result, default=None)
  360. config = trial.last_result.get("config", {})
  361. parameter_columns = list(config.keys())
  362. params = {p: unflattened_lookup(p, config) for p in parameter_columns}
  363. return (
  364. f"Current best trial: {trial.trial_id} with {metric}={val} and "
  365. f"params={params}"
  366. )
  367. def _render_table_item(
  368. key: str, item: Any, prefix: str = ""
  369. ) -> Iterable[Tuple[str, str]]:
  370. key = prefix + key
  371. if isinstance(item, argparse.Namespace):
  372. item = item.__dict__
  373. if isinstance(item, float):
  374. # tabulate does not work well with mixed-type columns, so we format
  375. # numbers ourselves.
  376. yield key, f"{item:.5f}".rstrip("0")
  377. elif isinstance(item, dict):
  378. flattened = flatten_dict(item)
  379. for k, v in sorted(flattened.items()):
  380. yield key + "/" + str(k), _max_len(v)
  381. else:
  382. yield key, _max_len(item, 20)
  383. def _get_dict_as_table_data(
  384. data: Dict,
  385. include: Optional[Collection] = None,
  386. exclude: Optional[Collection] = None,
  387. upper_keys: Optional[Collection] = None,
  388. ):
  389. """Get ``data`` dict as table rows.
  390. If specified, excluded keys are removed. Excluded keys can either be
  391. fully specified (e.g. ``foo/bar/baz``) or specify a top-level dictionary
  392. (e.g. ``foo``), but no intermediate levels (e.g. ``foo/bar``). If this is
  393. needed, we can revisit the logic at a later point.
  394. The same is true for included keys. If a top-level key is included (e.g. ``foo``)
  395. then all sub keys will be included, too, except if they are excluded.
  396. If keys are both excluded and included, exclusion takes precedence. Thus, if
  397. ``foo`` is excluded but ``foo/bar`` is included, it won't show up in the output.
  398. """
  399. include = include or set()
  400. exclude = exclude or set()
  401. upper_keys = upper_keys or set()
  402. upper = []
  403. lower = []
  404. for key, value in sorted(data.items()):
  405. # Exclude top-level keys
  406. if key in exclude:
  407. continue
  408. for k, v in _render_table_item(str(key), value):
  409. # k is now the full subkey, e.g. config/nested/key
  410. # We can exclude the full key
  411. if k in exclude:
  412. continue
  413. # If we specify includes, top-level includes should take precedence
  414. # (e.g. if `config` is in include, include config always).
  415. if include and key not in include and k not in include:
  416. continue
  417. if key in upper_keys:
  418. upper.append([k, v])
  419. else:
  420. lower.append([k, v])
  421. if not upper:
  422. return lower
  423. elif not lower:
  424. return upper
  425. else:
  426. return upper + lower
  427. if sys.stdout and sys.stdout.encoding and sys.stdout.encoding.startswith("utf"):
  428. # Copied/adjusted from tabulate
  429. AIR_TABULATE_TABLEFMT = TableFormat(
  430. lineabove=Line("╭", "─", "─", "╮"),
  431. linebelowheader=Line("├", "─", "─", "┤"),
  432. linebetweenrows=None,
  433. linebelow=Line("╰", "─", "─", "╯"),
  434. headerrow=DataRow("│", " ", "│"),
  435. datarow=DataRow("│", " ", "│"),
  436. padding=1,
  437. with_header_hide=None,
  438. )
  439. else:
  440. # For non-utf output, use ascii-compatible characters.
  441. # This prevents errors e.g. when legacy windows encoding is used.
  442. AIR_TABULATE_TABLEFMT = TableFormat(
  443. lineabove=Line("+", "-", "-", "+"),
  444. linebelowheader=Line("+", "-", "-", "+"),
  445. linebetweenrows=None,
  446. linebelow=Line("+", "-", "-", "+"),
  447. headerrow=DataRow("|", " ", "|"),
  448. datarow=DataRow("|", " ", "|"),
  449. padding=1,
  450. with_header_hide=None,
  451. )
  452. def _print_dict_as_table(
  453. data: Dict,
  454. header: Optional[str] = None,
  455. include: Optional[Collection[str]] = None,
  456. exclude: Optional[Collection[str]] = None,
  457. division: Optional[Collection[str]] = None,
  458. ):
  459. table_data = _get_dict_as_table_data(
  460. data=data, include=include, exclude=exclude, upper_keys=division
  461. )
  462. headers = [header, ""] if header else []
  463. if not table_data:
  464. return
  465. print(
  466. tabulate(
  467. table_data,
  468. headers=headers,
  469. colalign=("left", "right"),
  470. tablefmt=AIR_TABULATE_TABLEFMT,
  471. )
  472. )
  473. class ProgressReporter(Callback):
  474. """Periodically prints out status update."""
  475. # TODO: Make this configurable
  476. _heartbeat_freq = 30 # every 30 sec
  477. # to be updated by subclasses.
  478. _heartbeat_threshold = None
  479. _start_end_verbosity = None
  480. _intermediate_result_verbosity = None
  481. _addressing_tmpl = None
  482. def __init__(
  483. self,
  484. verbosity: AirVerbosity,
  485. progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
  486. ):
  487. """
  488. Args:
  489. verbosity: AirVerbosity level.
  490. """
  491. self._verbosity = verbosity
  492. self._start_time = time.time()
  493. self._last_heartbeat_time = float("-inf")
  494. self._start_time = time.time()
  495. self._progress_metrics = progress_metrics
  496. self._trial_last_printed_results = {}
  497. self._in_block = None
  498. @property
  499. def verbosity(self) -> AirVerbosity:
  500. return self._verbosity
  501. def setup(
  502. self,
  503. start_time: Optional[float] = None,
  504. **kwargs,
  505. ):
  506. self._start_time = start_time
  507. def _start_block(self, indicator: Any):
  508. if self._in_block != indicator:
  509. self._end_block()
  510. self._in_block = indicator
  511. def _end_block(self):
  512. if self._in_block:
  513. print("")
  514. self._in_block = None
  515. def on_experiment_end(self, trials: List["Trial"], **info):
  516. self._end_block()
  517. def experiment_started(
  518. self,
  519. experiment_name: str,
  520. experiment_path: str,
  521. searcher_str: str,
  522. scheduler_str: str,
  523. total_num_samples: int,
  524. tensorboard_path: Optional[str] = None,
  525. **kwargs,
  526. ):
  527. self._start_block("exp_start")
  528. print(f"\nView detailed results here: {experiment_path}")
  529. if tensorboard_path:
  530. print(
  531. f"To visualize your results with TensorBoard, run: "
  532. f"`tensorboard --logdir {tensorboard_path}`"
  533. )
  534. @property
  535. def _time_heartbeat_str(self):
  536. current_time_str, running_time_str = _get_time_str(
  537. self._start_time, time.time()
  538. )
  539. return (
  540. f"Current time: {current_time_str}. Total running time: " + running_time_str
  541. )
  542. def print_heartbeat(self, trials, *args, force: bool = False):
  543. if self._verbosity < self._heartbeat_threshold:
  544. return
  545. if force or time.time() - self._last_heartbeat_time >= self._heartbeat_freq:
  546. self._print_heartbeat(trials, *args, force=force)
  547. self._last_heartbeat_time = time.time()
  548. def _print_heartbeat(self, trials, *args, force: bool = False):
  549. raise NotImplementedError
  550. def _print_result(self, trial, result: Optional[Dict] = None, force: bool = False):
  551. """Only print result if a different result has been reported, or force=True"""
  552. result = result or trial.last_result
  553. last_result_iter = self._trial_last_printed_results.get(trial.trial_id, -1)
  554. this_iter = result.get(TRAINING_ITERATION, 0)
  555. if this_iter != last_result_iter or force:
  556. _print_dict_as_table(
  557. result,
  558. header=f"{self._addressing_tmpl.format(trial)} result",
  559. include=self._progress_metrics,
  560. exclude=BLACKLISTED_KEYS,
  561. division=AUTO_RESULT_KEYS,
  562. )
  563. self._trial_last_printed_results[trial.trial_id] = this_iter
  564. def _print_config(self, trial):
  565. _print_dict_as_table(
  566. trial.config, header=f"{self._addressing_tmpl.format(trial)} config"
  567. )
  568. def on_trial_result(
  569. self,
  570. iteration: int,
  571. trials: List[Trial],
  572. trial: Trial,
  573. result: Dict,
  574. **info,
  575. ):
  576. if self.verbosity < self._intermediate_result_verbosity:
  577. return
  578. self._start_block(f"trial_{trial}_result_{result[TRAINING_ITERATION]}")
  579. curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
  580. print(
  581. f"{self._addressing_tmpl.format(trial)} "
  582. f"finished iteration {result[TRAINING_ITERATION]} "
  583. f"at {curr_time_str}. Total running time: " + running_time_str
  584. )
  585. self._print_result(trial, result)
  586. def on_trial_complete(
  587. self, iteration: int, trials: List[Trial], trial: Trial, **info
  588. ):
  589. if self.verbosity < self._start_end_verbosity:
  590. return
  591. curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
  592. finished_iter = 0
  593. if trial.last_result and TRAINING_ITERATION in trial.last_result:
  594. finished_iter = trial.last_result[TRAINING_ITERATION]
  595. self._start_block(f"trial_{trial}_complete")
  596. print(
  597. f"{self._addressing_tmpl.format(trial)} "
  598. f"completed after {finished_iter} iterations "
  599. f"at {curr_time_str}. Total running time: " + running_time_str
  600. )
  601. self._print_result(trial)
  602. def on_trial_error(
  603. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  604. ):
  605. curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
  606. finished_iter = 0
  607. if trial.last_result and TRAINING_ITERATION in trial.last_result:
  608. finished_iter = trial.last_result[TRAINING_ITERATION]
  609. self._start_block(f"trial_{trial}_error")
  610. print(
  611. f"{self._addressing_tmpl.format(trial)} "
  612. f"errored after {finished_iter} iterations "
  613. f"at {curr_time_str}. Total running time: {running_time_str}\n"
  614. f"Error file: {trial.error_file}"
  615. )
  616. self._print_result(trial)
  617. def on_trial_recover(
  618. self, iteration: int, trials: List["Trial"], trial: "Trial", **info
  619. ):
  620. self.on_trial_error(iteration=iteration, trials=trials, trial=trial, **info)
  621. def on_checkpoint(
  622. self,
  623. iteration: int,
  624. trials: List[Trial],
  625. trial: Trial,
  626. checkpoint: Checkpoint,
  627. **info,
  628. ):
  629. if self._verbosity < self._intermediate_result_verbosity:
  630. return
  631. # don't think this is supposed to happen but just to be safe.
  632. saved_iter = "?"
  633. if trial.last_result and TRAINING_ITERATION in trial.last_result:
  634. saved_iter = trial.last_result[TRAINING_ITERATION]
  635. self._start_block(f"trial_{trial}_result_{saved_iter}")
  636. loc = f"({checkpoint.filesystem.type_name}){checkpoint.path}"
  637. print(
  638. f"{self._addressing_tmpl.format(trial)} "
  639. f"saved a checkpoint for iteration {saved_iter} "
  640. f"at: {loc}"
  641. )
  642. def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial, **info):
  643. if self.verbosity < self._start_end_verbosity:
  644. return
  645. has_config = bool(trial.config)
  646. self._start_block(f"trial_{trial}_start")
  647. if has_config:
  648. print(
  649. f"{self._addressing_tmpl.format(trial)} " f"started with configuration:"
  650. )
  651. self._print_config(trial)
  652. else:
  653. print(
  654. f"{self._addressing_tmpl.format(trial)} "
  655. f"started without custom configuration."
  656. )
  657. def _detect_reporter(
  658. verbosity: AirVerbosity,
  659. num_samples: int,
  660. entrypoint: Optional[AirEntrypoint] = None,
  661. metric: Optional[str] = None,
  662. mode: Optional[str] = None,
  663. config: Optional[Dict] = None,
  664. progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
  665. ):
  666. if entrypoint in {
  667. AirEntrypoint.TUNE_RUN,
  668. AirEntrypoint.TUNE_RUN_EXPERIMENTS,
  669. AirEntrypoint.TUNER,
  670. }:
  671. reporter = TuneTerminalReporter(
  672. verbosity,
  673. num_samples=num_samples,
  674. metric=metric,
  675. mode=mode,
  676. config=config,
  677. progress_metrics=progress_metrics,
  678. )
  679. else:
  680. reporter = TrainReporter(verbosity, progress_metrics=progress_metrics)
  681. return reporter
  682. class TuneReporterBase(ProgressReporter):
  683. _heartbeat_threshold = AirVerbosity.DEFAULT
  684. _wrap_headers = False
  685. _intermediate_result_verbosity = AirVerbosity.VERBOSE
  686. _start_end_verbosity = AirVerbosity.DEFAULT
  687. _addressing_tmpl = "Trial {}"
  688. def __init__(
  689. self,
  690. verbosity: AirVerbosity,
  691. num_samples: int = 0,
  692. metric: Optional[str] = None,
  693. mode: Optional[str] = None,
  694. config: Optional[Dict] = None,
  695. progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
  696. ):
  697. self._num_samples = num_samples
  698. self._metric = metric
  699. self._mode = mode
  700. # will be populated when first result comes in.
  701. self._inferred_metric = None
  702. self._inferred_params = _infer_params(config or {})
  703. super(TuneReporterBase, self).__init__(
  704. verbosity=verbosity, progress_metrics=progress_metrics
  705. )
  706. def setup(
  707. self,
  708. start_time: Optional[float] = None,
  709. total_samples: Optional[int] = None,
  710. **kwargs,
  711. ):
  712. super().setup(start_time=start_time)
  713. self._num_samples = total_samples
  714. def _get_overall_trial_progress_str(self, trials):
  715. result = " | ".join(
  716. [
  717. f"{len(trials)} {status}"
  718. for status, trials in _get_trials_by_state(trials).items()
  719. ]
  720. )
  721. return f"Trial status: {result}"
  722. # TODO: Return a more structured type to share code with Jupyter flow.
  723. def _get_heartbeat(
  724. self, trials, *sys_args, force_full_output: bool = False
  725. ) -> Tuple[List[str], _TrialTableData]:
  726. result = list()
  727. # Trial status: 1 RUNNING | 7 PENDING
  728. result.append(self._get_overall_trial_progress_str(trials))
  729. # Current time: 2023-02-24 12:35:39 (running for 00:00:37.40)
  730. result.append(self._time_heartbeat_str)
  731. # Logical resource usage: 8.0/64 CPUs, 0/0 GPUs
  732. result.extend(sys_args)
  733. # Current best trial: TRIAL NAME, metrics: {...}, parameters: {...}
  734. current_best_trial, metric = _current_best_trial(
  735. trials, self._metric, self._mode
  736. )
  737. if current_best_trial:
  738. result.append(_best_trial_str(current_best_trial, metric))
  739. # Now populating the trial table data.
  740. if not self._inferred_metric:
  741. # try inferring again.
  742. self._inferred_metric = _infer_user_metrics(trials)
  743. all_metrics = list(DEFAULT_COLUMNS.keys()) + self._inferred_metric
  744. trial_table_data = _get_trial_table_data(
  745. trials,
  746. param_keys=self._inferred_params,
  747. metric_keys=all_metrics,
  748. all_rows=force_full_output,
  749. wrap_headers=self._wrap_headers,
  750. )
  751. return result, trial_table_data
  752. def _print_heartbeat(self, trials, *sys_args, force: bool = False):
  753. raise NotImplementedError
  754. class TuneTerminalReporter(TuneReporterBase):
  755. def experiment_started(
  756. self,
  757. experiment_name: str,
  758. experiment_path: str,
  759. searcher_str: str,
  760. scheduler_str: str,
  761. total_num_samples: int,
  762. tensorboard_path: Optional[str] = None,
  763. **kwargs,
  764. ):
  765. if total_num_samples > sys.maxsize:
  766. total_num_samples_str = "infinite"
  767. else:
  768. total_num_samples_str = str(total_num_samples)
  769. print(
  770. tabulate(
  771. [
  772. ["Search algorithm", searcher_str],
  773. ["Scheduler", scheduler_str],
  774. ["Number of trials", total_num_samples_str],
  775. ],
  776. headers=["Configuration for experiment", experiment_name],
  777. tablefmt=AIR_TABULATE_TABLEFMT,
  778. )
  779. )
  780. super().experiment_started(
  781. experiment_name=experiment_name,
  782. experiment_path=experiment_path,
  783. searcher_str=searcher_str,
  784. scheduler_str=scheduler_str,
  785. total_num_samples=total_num_samples,
  786. tensorboard_path=tensorboard_path,
  787. **kwargs,
  788. )
  789. def _print_heartbeat(self, trials, *sys_args, force: bool = False):
  790. if self._verbosity < self._heartbeat_threshold and not force:
  791. return
  792. heartbeat_strs, table_data = self._get_heartbeat(
  793. trials, *sys_args, force_full_output=force
  794. )
  795. self._start_block("heartbeat")
  796. for s in heartbeat_strs:
  797. print(s)
  798. # now print the table using Tabulate
  799. more_infos = []
  800. all_data = []
  801. fail_header = table_data.header
  802. for sub_table in table_data.data:
  803. all_data.extend(sub_table.trial_infos)
  804. if sub_table.more_info:
  805. more_infos.append(sub_table.more_info)
  806. print(
  807. tabulate(
  808. all_data,
  809. headers=fail_header,
  810. tablefmt=AIR_TABULATE_TABLEFMT,
  811. showindex=False,
  812. )
  813. )
  814. if more_infos:
  815. print(", ".join(more_infos))
  816. if not force:
  817. # Only print error table at end of training
  818. return
  819. trials_with_error = _get_trials_with_error(trials)
  820. if not trials_with_error:
  821. return
  822. self._start_block("status_errored")
  823. print(f"Number of errored trials: {len(trials_with_error)}")
  824. fail_header = ["Trial name", "# failures", "error file"]
  825. fail_table_data = [
  826. [
  827. str(trial),
  828. str(trial.run_metadata.num_failures)
  829. + ("" if trial.status == Trial.ERROR else "*"),
  830. trial.error_file,
  831. ]
  832. for trial in trials_with_error
  833. ]
  834. print(
  835. tabulate(
  836. fail_table_data,
  837. headers=fail_header,
  838. tablefmt=AIR_TABULATE_TABLEFMT,
  839. showindex=False,
  840. colalign=("left", "right", "left"),
  841. )
  842. )
  843. if any(trial.status == Trial.TERMINATED for trial in trials_with_error):
  844. print("* The trial terminated successfully after retrying.")
  845. class TrainReporter(ProgressReporter):
  846. # the minimal verbosity threshold at which heartbeat starts getting printed.
  847. _heartbeat_threshold = AirVerbosity.VERBOSE
  848. _intermediate_result_verbosity = AirVerbosity.DEFAULT
  849. _start_end_verbosity = AirVerbosity.DEFAULT
  850. _addressing_tmpl = "Training"
  851. def _get_heartbeat(self, trials: List[Trial], force_full_output: bool = False):
  852. # Training on iteration 1. Current time: 2023-03-22 15:29:25 (running for 00:00:03.24) # noqa
  853. if len(trials) == 0:
  854. return
  855. trial = trials[0]
  856. if trial.status != Trial.RUNNING:
  857. return " ".join(
  858. [f"Training is in {trial.status} status.", self._time_heartbeat_str]
  859. )
  860. if not trial.last_result or TRAINING_ITERATION not in trial.last_result:
  861. iter_num = 1
  862. else:
  863. iter_num = trial.last_result[TRAINING_ITERATION] + 1
  864. return " ".join(
  865. [f"Training on iteration {iter_num}.", self._time_heartbeat_str]
  866. )
  867. def _print_heartbeat(self, trials, *args, force: bool = False):
  868. print(self._get_heartbeat(trials, force_full_output=force))
  869. def on_trial_result(
  870. self,
  871. iteration: int,
  872. trials: List[Trial],
  873. trial: Trial,
  874. result: Dict,
  875. **info,
  876. ):
  877. self._last_heartbeat_time = time.time()
  878. super().on_trial_result(
  879. iteration=iteration, trials=trials, trial=trial, result=result, **info
  880. )