| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043 |
- import argparse
- import collections
- import datetime
- import logging
- import math
- import numbers
- import os
- import sys
- import textwrap
- import time
- from dataclasses import dataclass
- from enum import IntEnum
- from typing import Any, Collection, Dict, Iterable, List, Optional, Tuple, Union
- import numpy as np
- import pandas as pd
- import ray
- from ray._private.dict import flatten_dict, unflattened_lookup
- from ray._private.thirdparty.tabulate.tabulate import (
- DataRow,
- Line,
- TableFormat,
- tabulate,
- )
- from ray.air._internal.usage import AirEntrypoint
- from ray.air.constants import TRAINING_ITERATION
- from ray.tune import Checkpoint
- from ray.tune.callback import Callback
- from ray.tune.experiment.trial import Trial
- from ray.tune.result import (
- AUTO_RESULT_KEYS,
- EPISODE_REWARD_MEAN,
- MEAN_ACCURACY,
- MEAN_LOSS,
- TIME_TOTAL_S,
- TIMESTEPS_TOTAL,
- )
- from ray.tune.search.sample import Domain
- from ray.tune.utils.log import Verbosity
- try:
- import rich
- import rich.layout
- import rich.live
- except ImportError:
- rich = None
- logger = logging.getLogger(__name__)
# Result key -> column title used in the trial table.
# Insertion order is the column order, so keep the sequence as is.
DEFAULT_COLUMNS = collections.OrderedDict(
    [
        (MEAN_ACCURACY, "acc"),
        (MEAN_LOSS, "loss"),
        (TRAINING_ITERATION, "iter"),
        (TIME_TOTAL_S, "total time (s)"),
        (TIMESTEPS_TOTAL, "ts"),
        (EPISODE_REWARD_MEAN, "reward"),
    ]
)

# Result keys that are never printed in intermediate/final result tables.
BLACKLISTED_KEYS = {
    # trial bookkeeping
    "config",
    "done",
    "trial_id",
    "experiment_tag",
    "should_checkpoint",
    "iterations_since_restore",
    # host / process information
    "hostname",
    "node_ip",
    "pid",
    # wall-clock information
    "date",
    "timestamp",
    "time_since_restore",
    "_report_on",  # LIGHTNING_REPORT_STAGE_KEY
}

# Exact value types (compared via ``type(value)``) that qualify a result entry
# as a metric column when user metrics are inferred.
VALID_SUMMARY_TYPES = {
    type(None),
    int,
    np.int32,
    np.int64,
    float,
    np.float32,
    np.float64,
}

# The order in which trial statuses are summarized in the trial table.
ORDER = [
    Trial.RUNNING,
    Trial.TERMINATED,
    Trial.PAUSED,
    Trial.PENDING,
    Trial.ERROR,
]
class AirVerbosity(IntEnum):
    """Verbosity levels used by the reporters in this module.

    Members are integers, so levels can be compared with ``<`` / ``>=``.
    """

    SILENT = 0
    DEFAULT = 1
    VERBOSE = 2

    def __repr__(self):
        # Render as the bare integer (e.g. ``1``) rather than the default
        # ``<AirVerbosity.DEFAULT: 1>`` enum representation.
        return f"{self.value}"
# Evaluated once at import time: the value reported by
# ``ray.widgets.util.in_notebook()`` for the current process.
IS_NOTEBOOK = ray.widgets.util.in_notebook()
def get_air_verbosity(
    verbose: Union[int, AirVerbosity, Verbosity]
) -> Optional[AirVerbosity]:
    """Translate a user-facing verbosity setting into an ``AirVerbosity``.

    Args:
        verbose: An ``AirVerbosity`` (returned unchanged), a plain integer,
            or an object exposing the level through ``.value``.

    Returns:
        The matching ``AirVerbosity``, or ``None`` when the environment
        variable ``RAY_AIR_NEW_OUTPUT`` is set to ``"0"``.
    """
    new_output_enabled = os.environ.get("RAY_AIR_NEW_OUTPUT", "1") != "0"
    if not new_output_enabled:
        return None

    if isinstance(verbose, AirVerbosity):
        return verbose

    if isinstance(verbose, int):
        level = verbose
    else:
        level = verbose.value

    # Verbosity 2 and 3 both map to AirVerbosity 2
    return AirVerbosity(min(2, level))
def _infer_params(config: Dict[str, Any]) -> List[str]:
    """Collect the flattened config keys that describe a search space.

    A key is reported when its value is a ``Domain`` instance, or when the
    flattened key ends in ``/grid_search`` (in which case that suffix is
    stripped from the reported name).

    Args:
        config: Possibly nested configuration dictionary.

    Returns:
        Flattened parameter names, in the iteration order of the flattened
        config.
    """
    grid_suffix = "/grid_search"
    tunable = []
    for name, value in flatten_dict(config).items():
        if isinstance(value, Domain):
            tunable.append(name)
        # Grid search is a special named field. Because the whole config is
        # flattened, it is recognised by the key string itself.
        if name.endswith(grid_suffix):
            tunable.append(name[: -len(grid_suffix)])
    return tunable
- def _get_time_str(start_time: float, current_time: float) -> Tuple[str, str]:
- """Get strings representing the current and elapsed time.
- Args:
- start_time: POSIX timestamp of the start of the tune run
- current_time: POSIX timestamp giving the current time
- Returns:
- Current time and elapsed time for the current run
- """
- current_time_dt = datetime.datetime.fromtimestamp(current_time)
- start_time_dt = datetime.datetime.fromtimestamp(start_time)
- delta: datetime.timedelta = current_time_dt - start_time_dt
- rest = delta.total_seconds()
- days = int(rest // (60 * 60 * 24))
- rest -= days * (60 * 60 * 24)
- hours = int(rest // (60 * 60))
- rest -= hours * (60 * 60)
- minutes = int(rest // 60)
- seconds = int(rest - minutes * 60)
- running_for_str = ""
- if days > 0:
- running_for_str += f"{days:d}d "
- if hours > 0 or running_for_str:
- running_for_str += f"{hours:d}hr "
- if minutes > 0 or running_for_str:
- running_for_str += f"{minutes:d}min "
- running_for_str += f"{seconds:d}s"
- return f"{current_time_dt:%Y-%m-%d %H:%M:%S}", running_for_str
- def _get_trials_by_state(trials: List[Trial]) -> Dict[str, List[Trial]]:
- trials_by_state = collections.defaultdict(list)
- for t in trials:
- trials_by_state[t.status].append(t)
- return trials_by_state
- def _get_trials_with_error(trials: List[Trial]) -> List[Trial]:
- return [t for t in trials if t.error_file]
def _infer_user_metrics(trials: List[Trial], limit: int = 4) -> List[str]:
    """Try to infer the metrics to print out.

    Walks the ``last_result`` of each trial and collects keys that are neither
    default columns nor auto-filled result keys and whose value type is one of
    ``VALID_SUMMARY_TYPES``. Collection stops as soon as ``limit`` distinct
    metrics have been found (4 by default).

    Args:
        trials: Trials whose last results are inspected.
        limit: Maximum number of metric names to return.

    Returns:
        Metric names in first-seen order.
    """
    # An OrderedDict serves as an ordered set; the values are irrelevant.
    inferred = collections.OrderedDict()
    for trial in trials:
        last_result = trial.last_result
        if not last_result:
            continue
        for name, value in last_result.items():
            if name in DEFAULT_COLUMNS or name in AUTO_RESULT_KEYS:
                continue
            # Exact type match on purpose (no subclass check).
            if type(value) not in VALID_SUMMARY_TYPES:
                continue
            inferred[name] = ""
            if len(inferred) >= limit:
                return list(inferred.keys())
    return list(inferred.keys())
def _current_best_trial(
    trials: List[Trial], metric: Optional[str], mode: Optional[str]
) -> Tuple[Optional[Trial], Optional[str]]:
    """Find the trial with the best value of ``metric`` so far.

    Trials without a last result, or whose metric value is null according to
    ``pd.isnull``, are ignored.

    Args:
        trials: List of trials.
        metric: Metric that trials are being ranked.
        mode: One of "min" or "max". Any value other than "max" ranks the
            smallest metric value as best.

    Returns:
        Best trial and the metric key. ``(None, None)`` if ``trials``,
        ``metric`` or ``mode`` is empty or None. If no trial has a usable
        metric value, the trial component is ``None``.
    """
    if not (trials and metric and mode):
        return None, None

    # Flip the sign for "min" so that a larger score is always better.
    sign = 1.0 if mode == "max" else -1.0
    best_score = float("-inf")
    best_trial = None

    for trial in trials:
        if not trial.last_result:
            continue
        value = unflattened_lookup(metric, trial.last_result, default=None)
        if pd.isnull(value):
            continue
        score = value * sign
        # The first usable trial is always accepted, whatever its score.
        if not best_trial or score > best_score:
            best_score = score
            best_trial = trial

    return best_trial, metric
- @dataclass
- class _PerStatusTrialTableData:
- trial_infos: List[List[str]]
- more_info: str
- @dataclass
- class _TrialTableData:
- header: List[str]
- data: List[_PerStatusTrialTableData]
- def _max_len(value: Any, max_len: int = 20, wrap: bool = False) -> Any:
- """Abbreviate a string representation of an object to `max_len` characters.
- For numbers, booleans and None, the original value will be returned for
- correct rendering in the table formatting tool.
- Args:
- value: Object to be represented as a string.
- max_len: Maximum return string length.
- """
- if value is None or isinstance(value, (int, float, numbers.Number, bool)):
- return value
- string = str(value)
- if len(string) <= max_len:
- return string
- if wrap:
- # Maximum two rows.
- # Todo: Make this configurable in the refactor
- if len(value) > max_len * 2:
- value = "..." + string[(3 - (max_len * 2)) :]
- wrapped = textwrap.wrap(value, width=max_len)
- return "\n".join(wrapped)
- result = "..." + string[(3 - max_len) :]
- return result
def _get_trial_info(
    trial: Trial, param_keys: List[str], metric_keys: List[str]
) -> List[str]:
    """Build one table row for a trial.

    The row layout is: name | status | params... | metrics...

    Args:
        trial: Trial to get information for.
        param_keys: Names of parameters to include (looked up in the trial
            config).
        metric_keys: Names of metrics to include (looked up in the trial's
            last result).

    Returns:
        The row cells; missing parameters/metrics are rendered as None.
    """
    last_result = trial.last_result
    row = [str(trial), trial.status]

    # params
    for param in param_keys:
        row.append(_max_len(unflattened_lookup(param, trial.config, default=None)))

    # metrics
    for metric in metric_keys:
        row.append(_max_len(unflattened_lookup(metric, last_result, default=None)))

    return row
def _get_trial_table_data_per_status(
    status: str,
    trials: List[Trial],
    param_keys: List[str],
    metric_keys: List[str],
    force_max_rows: bool = False,
) -> Optional[_PerStatusTrialTableData]:
    """Gather the table rows for all trials of one `status`.

    Args:
        status: The trial status of interest.
        trials: all the trials of that status.
        param_keys: *Ordered* list of parameters to be displayed in the table.
        metric_keys: *Ordered* list of metrics to be displayed in the table.
            Including both default and user defined.
        force_max_rows: Whether or not to enforce a max row number for this
            status. If True, only a max of `5` rows will be shown.

    Returns:
        Rows (plus a "N more <status>" note when rows were cut off), or None
        when there are no trials.
    """
    if not trials:
        return None

    # TODO: configure it.
    row_limit = 5 if force_max_rows else math.inf

    rows = []
    overflow_note = None
    for position, trial in enumerate(trials):
        if position >= row_limit:
            # Everything from here on is only summarised.
            overflow_note = f"{len(trials) - row_limit} more {status}"
            break
        rows.append(_get_trial_info(trial, param_keys, metric_keys))

    return _PerStatusTrialTableData(rows, overflow_note)
def _get_trial_table_data(
    trials: List[Trial],
    param_keys: List[str],
    metric_keys: List[str],
    all_rows: bool = False,
    wrap_headers: bool = False,
) -> _TrialTableData:
    """Generate the data of a table showing the progress of tuning trials.

    Args:
        trials: List of trials for which progress is to be shown.
        param_keys: Ordered list of parameters to be displayed in the table.
        metric_keys: Ordered list of metrics to be displayed in the table.
            Including both default and user defined.
            Will only be shown if at least one trial is having the key.
        all_rows: Force to show all rows.
        wrap_headers: If True, header columns can be wrapped with ``\\n``.

    Returns:
        Trial table data, including header and trial table per each status.
    """
    # TODO: configure
    max_trial_num_to_show = 20
    max_column_length = 20

    grouped = _get_trials_by_state(trials)

    # Keep only the metrics that at least one trial has reported.
    visible_metrics = []
    for key in metric_keys:
        reported = any(
            unflattened_lookup(key, t.last_result, default=None) is not None
            for t in trials
        )
        if reported:
            visible_metrics.append(key)

    header = ["Trial name", "status"]
    for param in param_keys:
        header.append(_max_len(param, max_len=max_column_length, wrap=wrap_headers))
    for metric in visible_metrics:
        if metric in DEFAULT_COLUMNS:
            # Map to the abbreviated version.
            header.append(DEFAULT_COLUMNS[metric])
        else:
            header.append(
                _max_len(metric, max_len=max_column_length, wrap=wrap_headers)
            )

    # Rows per status are capped only for large experiments, unless the
    # caller asked for all rows.
    limit_rows = not all_rows and len(trials) > max_trial_num_to_show

    per_status_tables = []
    for status in ORDER:
        table = _get_trial_table_data_per_status(
            status,
            grouped[status],
            param_keys=param_keys,
            metric_keys=visible_metrics,
            force_max_rows=limit_rows,
        )
        if table:
            per_status_tables.append(table)

    return _TrialTableData(header, per_status_tables)
def _best_trial_str(
    trial: Trial,
    metric: str,
):
    """Returns a readable message stating the current best trial.

    The message names the trial id, the value of ``metric`` in the trial's
    last result, and the top-level entries of the ``config`` stored in that
    result, e.g.::

        Current best trial: 18ae7_00005 with loss=0.59 and params={'lr': 0.05}
    """
    best_value = unflattened_lookup(metric, trial.last_result, default=None)
    config = trial.last_result.get("config", {})

    params = {}
    for name in list(config.keys()):
        params[name] = unflattened_lookup(name, config)

    message = f"Current best trial: {trial.trial_id} with {metric}={best_value} and "
    message += f"params={params}"
    return message
def _render_table_item(
    key: str, item: Any, prefix: str = ""
) -> Iterable[Tuple[str, str]]:
    """Yield ``(key, rendered value)`` rows for one entry of a dict table.

    Args:
        key: Name of the entry.
        item: Value of the entry. An ``argparse.Namespace`` is treated like
            its ``__dict__``. Dicts are flattened and yield one row per leaf,
            sorted by flattened key.
        prefix: String prepended to ``key``.

    Yields:
        Key/value pairs ready to be passed to the table formatter.
    """
    full_key = prefix + key

    if isinstance(item, argparse.Namespace):
        item = item.__dict__

    if isinstance(item, float):
        # tabulate does not work well with mixed-type columns, so we format
        # numbers ourselves.
        yield full_key, f"{item:.5f}".rstrip("0")
        return

    if isinstance(item, dict):
        for sub_key, sub_value in sorted(flatten_dict(item).items()):
            yield "/".join((full_key, str(sub_key))), _max_len(sub_value)
        return

    yield full_key, _max_len(item, 20)
def _get_dict_as_table_data(
    data: Dict,
    include: Optional[Collection] = None,
    exclude: Optional[Collection] = None,
    upper_keys: Optional[Collection] = None,
):
    """Get ``data`` dict as table rows.

    If specified, excluded keys are removed. Excluded keys can either be
    fully specified (e.g. ``foo/bar/baz``) or specify a top-level dictionary
    (e.g. ``foo``), but no intermediate levels (e.g. ``foo/bar``). If this is
    needed, we can revisit the logic at a later point.

    The same is true for included keys. If a top-level key is included
    (e.g. ``foo``) then all sub keys will be included, too, except if they are
    excluded.

    If keys are both excluded and included, exclusion takes precedence. Thus,
    if ``foo`` is excluded but ``foo/bar`` is included, it won't show up in the
    output.

    Args:
        data: Dictionary to render; entries are processed in sorted key order.
        include: If non-empty, only these (top-level or full) keys are kept.
        exclude: Top-level or full keys to drop.
        upper_keys: Top-level keys whose rows are placed before all other
            rows.

    Returns:
        List of ``[key, value]`` rows: rows of ``upper_keys`` first, then the
        remaining rows.
    """
    include = include or set()
    exclude = exclude or set()
    upper_keys = upper_keys or set()

    upper, lower = [], []
    for key, value in sorted(data.items()):
        # Exclude top-level keys
        if key in exclude:
            continue

        target = upper if key in upper_keys else lower
        for full_key, rendered in _render_table_item(str(key), value):
            # full_key is the complete sub key, e.g. config/nested/key,
            # which can be excluded on its own.
            if full_key in exclude:
                continue
            # With includes given, a row is kept if either its top-level key
            # or its full key is included.
            if include and not (key in include or full_key in include):
                continue
            target.append([full_key, rendered])

    return upper + lower
# Pick the table border style once at import time, based on what the output
# stream can encode. Encoding names are compared case-insensitively because
# the same codec may be reported as e.g. ``utf-8`` or ``UTF-8``.
if (
    sys.stdout
    and sys.stdout.encoding
    and sys.stdout.encoding.lower().startswith("utf")
):
    # Copied/adjusted from tabulate
    AIR_TABULATE_TABLEFMT = TableFormat(
        lineabove=Line("╭", "─", "─", "╮"),
        linebelowheader=Line("├", "─", "─", "┤"),
        linebetweenrows=None,
        linebelow=Line("╰", "─", "─", "╯"),
        headerrow=DataRow("│", " ", "│"),
        datarow=DataRow("│", " ", "│"),
        padding=1,
        with_header_hide=None,
    )
else:
    # For non-utf output, use ascii-compatible characters.
    # This prevents errors e.g. when legacy windows encoding is used.
    AIR_TABULATE_TABLEFMT = TableFormat(
        lineabove=Line("+", "-", "-", "+"),
        linebelowheader=Line("+", "-", "-", "+"),
        linebetweenrows=None,
        linebelow=Line("+", "-", "-", "+"),
        headerrow=DataRow("|", " ", "|"),
        datarow=DataRow("|", " ", "|"),
        padding=1,
        with_header_hide=None,
    )
def _print_dict_as_table(
    data: Dict,
    header: Optional[str] = None,
    include: Optional[Collection[str]] = None,
    exclude: Optional[Collection[str]] = None,
    division: Optional[Collection[str]] = None,
):
    """Print ``data`` as a two-column key/value table.

    Nothing is printed when no rows remain after filtering.

    Args:
        data: Dictionary to print.
        header: Optional title placed above the key column.
        include: Keys to keep (see ``_get_dict_as_table_data``).
        exclude: Keys to drop (see ``_get_dict_as_table_data``).
        division: Top-level keys whose rows are listed first.
    """
    rows = _get_dict_as_table_data(
        data=data, include=include, exclude=exclude, upper_keys=division
    )
    if not rows:
        return

    table = tabulate(
        rows,
        headers=[header, ""] if header else [],
        colalign=("left", "right"),
        tablefmt=AIR_TABULATE_TABLEFMT,
    )
    print(table)
class ProgressReporter(Callback):
    """Periodically prints out status update.

    Base class of the console reporters in this module. Subclasses set the
    class-level verbosity thresholds and ``_addressing_tmpl`` and implement
    ``_print_heartbeat``.

    Output is grouped into "blocks": consecutive prints that belong to the
    same event share a block, and a blank line is printed whenever the block
    changes (see ``_start_block`` / ``_end_block``).
    """

    # TODO: Make this configurable
    _heartbeat_freq = 30  # every 30 sec
    # to be updated by subclasses.
    _heartbeat_threshold = None
    _start_end_verbosity = None
    _intermediate_result_verbosity = None
    _addressing_tmpl = None

    def __init__(
        self,
        verbosity: AirVerbosity,
        progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
    ):
        """
        Args:
            verbosity: AirVerbosity level.
            progress_metrics: If given, restricts which result keys are
                printed in result tables.
        """
        self._verbosity = verbosity
        # Fallback start time; may be replaced by ``setup``.
        self._start_time = time.time()
        # -inf guarantees that the very first heartbeat is printed.
        self._last_heartbeat_time = float("-inf")
        self._progress_metrics = progress_metrics
        # trial_id -> training iteration of the last printed result.
        self._trial_last_printed_results = {}
        self._in_block = None

    @property
    def verbosity(self) -> AirVerbosity:
        return self._verbosity

    def setup(
        self,
        start_time: Optional[float] = None,
        **kwargs,
    ):
        """Set the experiment start time used for elapsed-time output.

        Args:
            start_time: POSIX timestamp of the experiment start. If omitted,
                the time at which the reporter was constructed is kept, so
                that elapsed-time computations never see ``None``.
        """
        if start_time is not None:
            self._start_time = start_time

    def _start_block(self, indicator: Any):
        """Switch to the output block ``indicator``, closing any other block."""
        if self._in_block != indicator:
            self._end_block()
        self._in_block = indicator

    def _end_block(self):
        """Close the current output block (if any) with a blank line."""
        if self._in_block:
            print("")
        self._in_block = None

    def on_experiment_end(self, trials: List["Trial"], **info):
        self._end_block()

    def experiment_started(
        self,
        experiment_name: str,
        experiment_path: str,
        searcher_str: str,
        scheduler_str: str,
        total_num_samples: int,
        tensorboard_path: Optional[str] = None,
        **kwargs,
    ):
        """Print where the experiment results (and TensorBoard logs) live."""
        self._start_block("exp_start")
        print(f"\nView detailed results here: {experiment_path}")

        if tensorboard_path:
            print(
                f"To visualize your results with TensorBoard, run: "
                f"`tensorboard --logdir {tensorboard_path}`"
            )

    @property
    def _time_heartbeat_str(self):
        current_time_str, running_time_str = _get_time_str(
            self._start_time, time.time()
        )
        return (
            f"Current time: {current_time_str}. Total running time: " + running_time_str
        )

    def print_heartbeat(self, trials, *args, force: bool = False):
        """Print a heartbeat if the verbosity allows it and it is due.

        A heartbeat is due when ``force`` is set or at least
        ``_heartbeat_freq`` seconds passed since the previous one.
        """
        if self._verbosity < self._heartbeat_threshold:
            return
        if force or time.time() - self._last_heartbeat_time >= self._heartbeat_freq:
            self._print_heartbeat(trials, *args, force=force)
            self._last_heartbeat_time = time.time()

    def _print_heartbeat(self, trials, *args, force: bool = False):
        raise NotImplementedError

    def _print_result(self, trial, result: Optional[Dict] = None, force: bool = False):
        """Only print result if a different result has been reported, or force=True"""
        # Fall back to an empty dict so a trial without any result (e.g. one
        # that failed before reporting) does not raise here.
        result = result or trial.last_result or {}

        last_result_iter = self._trial_last_printed_results.get(trial.trial_id, -1)
        this_iter = result.get(TRAINING_ITERATION, 0)

        if this_iter != last_result_iter or force:
            _print_dict_as_table(
                result,
                header=f"{self._addressing_tmpl.format(trial)} result",
                include=self._progress_metrics,
                exclude=BLACKLISTED_KEYS,
                division=AUTO_RESULT_KEYS,
            )
            self._trial_last_printed_results[trial.trial_id] = this_iter

    def _print_config(self, trial):
        _print_dict_as_table(
            trial.config, header=f"{self._addressing_tmpl.format(trial)} config"
        )

    def on_trial_result(
        self,
        iteration: int,
        trials: List[Trial],
        trial: Trial,
        result: Dict,
        **info,
    ):
        if self.verbosity < self._intermediate_result_verbosity:
            return
        self._start_block(f"trial_{trial}_result_{result[TRAINING_ITERATION]}")
        curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"finished iteration {result[TRAINING_ITERATION]} "
            f"at {curr_time_str}. Total running time: " + running_time_str
        )
        self._print_result(trial, result)

    def on_trial_complete(
        self, iteration: int, trials: List[Trial], trial: Trial, **info
    ):
        if self.verbosity < self._start_end_verbosity:
            return
        curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
        finished_iter = 0
        if trial.last_result and TRAINING_ITERATION in trial.last_result:
            finished_iter = trial.last_result[TRAINING_ITERATION]

        self._start_block(f"trial_{trial}_complete")
        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"completed after {finished_iter} iterations "
            f"at {curr_time_str}. Total running time: " + running_time_str
        )
        self._print_result(trial)

    def on_trial_error(
        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
    ):
        # No verbosity check here: errors are reported at every level.
        curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
        finished_iter = 0
        if trial.last_result and TRAINING_ITERATION in trial.last_result:
            finished_iter = trial.last_result[TRAINING_ITERATION]

        self._start_block(f"trial_{trial}_error")
        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"errored after {finished_iter} iterations "
            f"at {curr_time_str}. Total running time: {running_time_str}\n"
            f"Error file: {trial.error_file}"
        )
        self._print_result(trial)

    def on_trial_recover(
        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
    ):
        # A recovering trial is reported exactly like an errored one.
        self.on_trial_error(iteration=iteration, trials=trials, trial=trial, **info)

    def on_checkpoint(
        self,
        iteration: int,
        trials: List[Trial],
        trial: Trial,
        checkpoint: Checkpoint,
        **info,
    ):
        if self._verbosity < self._intermediate_result_verbosity:
            return
        # don't think this is supposed to happen but just to be safe.
        saved_iter = "?"
        if trial.last_result and TRAINING_ITERATION in trial.last_result:
            saved_iter = trial.last_result[TRAINING_ITERATION]

        # Same block id as the matching result output, so both are grouped.
        self._start_block(f"trial_{trial}_result_{saved_iter}")

        loc = f"({checkpoint.filesystem.type_name}){checkpoint.path}"

        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"saved a checkpoint for iteration {saved_iter} "
            f"at: {loc}"
        )

    def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial, **info):
        if self.verbosity < self._start_end_verbosity:
            return
        has_config = bool(trial.config)

        self._start_block(f"trial_{trial}_start")
        if has_config:
            print(
                f"{self._addressing_tmpl.format(trial)} " f"started with configuration:"
            )
            self._print_config(trial)
        else:
            print(
                f"{self._addressing_tmpl.format(trial)} "
                f"started without custom configuration."
            )
def _detect_reporter(
    verbosity: AirVerbosity,
    num_samples: int,
    entrypoint: Optional[AirEntrypoint] = None,
    metric: Optional[str] = None,
    mode: Optional[str] = None,
    config: Optional[Dict] = None,
    progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
):
    """Create the reporter that matches the entrypoint of the run.

    Tune entrypoints (``TUNE_RUN``, ``TUNE_RUN_EXPERIMENTS``, ``TUNER``) get a
    ``TuneTerminalReporter``; every other entrypoint (including ``None``) gets
    a ``TrainReporter``.
    """
    tune_entrypoints = {
        AirEntrypoint.TUNE_RUN,
        AirEntrypoint.TUNE_RUN_EXPERIMENTS,
        AirEntrypoint.TUNER,
    }
    if entrypoint not in tune_entrypoints:
        return TrainReporter(verbosity, progress_metrics=progress_metrics)

    return TuneTerminalReporter(
        verbosity,
        num_samples=num_samples,
        metric=metric,
        mode=mode,
        config=config,
        progress_metrics=progress_metrics,
    )
class TuneReporterBase(ProgressReporter):
    """Shared heartbeat logic for reporters of multi-trial (Tune) runs.

    Builds the heartbeat lines and the trial table data; subclasses decide how
    to render them by implementing ``_print_heartbeat``.
    """

    _heartbeat_threshold = AirVerbosity.DEFAULT
    _wrap_headers = False
    _intermediate_result_verbosity = AirVerbosity.VERBOSE
    _start_end_verbosity = AirVerbosity.DEFAULT
    _addressing_tmpl = "Trial {}"

    def __init__(
        self,
        verbosity: AirVerbosity,
        num_samples: int = 0,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        config: Optional[Dict] = None,
        progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
    ):
        """
        Args:
            verbosity: AirVerbosity level.
            num_samples: Number of samples of the experiment.
            metric: Metric used to determine the current best trial.
            mode: One of "min" or "max"; how ``metric`` is ranked.
            config: Experiment config; its search-space parameters become the
                parameter columns of the trial table.
            progress_metrics: If given, restricts which result keys are
                printed in result tables.
        """
        self._num_samples = num_samples
        self._metric = metric
        self._mode = mode
        # will be populated when first result comes in.
        self._inferred_metric = None
        self._inferred_params = _infer_params(config or {})
        super().__init__(verbosity=verbosity, progress_metrics=progress_metrics)

    def setup(
        self,
        start_time: Optional[float] = None,
        total_samples: Optional[int] = None,
        **kwargs,
    ):
        """Set the start time and, if given, the total number of samples.

        Args:
            start_time: POSIX timestamp of the experiment start.
            total_samples: Total number of samples. When omitted, the value
                passed to the constructor is kept instead of being replaced
                by ``None``.
        """
        super().setup(start_time=start_time)
        if total_samples is not None:
            self._num_samples = total_samples

    def _get_overall_trial_progress_str(self, trials):
        """Return e.g. ``Trial status: 1 RUNNING | 7 PENDING``."""
        trials_by_state = _get_trials_by_state(trials)
        result = " | ".join(
            f"{len(group)} {status}" for status, group in trials_by_state.items()
        )
        return f"Trial status: {result}"

    # TODO: Return a more structured type to share code with Jupyter flow.
    def _get_heartbeat(
        self, trials, *sys_args, force_full_output: bool = False
    ) -> Tuple[List[str], _TrialTableData]:
        """Assemble the heartbeat text lines and the trial table data.

        Args:
            trials: All trials of the experiment.
            *sys_args: Additional pre-formatted lines (e.g. resource usage)
                inserted after the time line.
            force_full_output: If True, the table is not truncated.

        Returns:
            The heartbeat lines and the trial table data.
        """
        result = list()
        # Trial status: 1 RUNNING | 7 PENDING
        result.append(self._get_overall_trial_progress_str(trials))
        # Current time: 2023-02-24 12:35:39. Total running time: 37s
        result.append(self._time_heartbeat_str)
        # Logical resource usage: 8.0/64 CPUs, 0/0 GPUs
        result.extend(sys_args)
        # Current best trial: TRIAL NAME, metrics: {...}, parameters: {...}
        current_best_trial, metric = _current_best_trial(
            trials, self._metric, self._mode
        )
        if current_best_trial:
            result.append(_best_trial_str(current_best_trial, metric))

        # Now populating the trial table data.
        if not self._inferred_metric:
            # Nothing inferred yet (None or empty) - try inferring again.
            self._inferred_metric = _infer_user_metrics(trials)

        all_metrics = list(DEFAULT_COLUMNS.keys()) + self._inferred_metric

        trial_table_data = _get_trial_table_data(
            trials,
            param_keys=self._inferred_params,
            metric_keys=all_metrics,
            all_rows=force_full_output,
            wrap_headers=self._wrap_headers,
        )
        return result, trial_table_data

    def _print_heartbeat(self, trials, *sys_args, force: bool = False):
        raise NotImplementedError
class TuneTerminalReporter(TuneReporterBase):
    """Tune reporter that renders heartbeats as plain tables on stdout."""

    def experiment_started(
        self,
        experiment_name: str,
        experiment_path: str,
        searcher_str: str,
        scheduler_str: str,
        total_num_samples: int,
        tensorboard_path: Optional[str] = None,
        **kwargs,
    ):
        """Print the experiment configuration table, then the base output."""
        # Sample counts beyond ``sys.maxsize`` are shown as "infinite".
        if total_num_samples > sys.maxsize:
            num_trials_str = "infinite"
        else:
            num_trials_str = str(total_num_samples)

        config_rows = [
            ["Search algorithm", searcher_str],
            ["Scheduler", scheduler_str],
            ["Number of trials", num_trials_str],
        ]
        print(
            tabulate(
                config_rows,
                headers=["Configuration for experiment", experiment_name],
                tablefmt=AIR_TABULATE_TABLEFMT,
            )
        )
        super().experiment_started(
            experiment_name=experiment_name,
            experiment_path=experiment_path,
            searcher_str=searcher_str,
            scheduler_str=scheduler_str,
            total_num_samples=total_num_samples,
            tensorboard_path=tensorboard_path,
            **kwargs,
        )

    def _print_heartbeat(self, trials, *sys_args, force: bool = False):
        """Print status lines and the trial table; with ``force`` also errors.

        Args:
            trials: All trials of the experiment.
            *sys_args: Extra pre-formatted status lines.
            force: Print regardless of verbosity, show all table rows, and
                append the table of errored trials.
        """
        if not force and self._verbosity < self._heartbeat_threshold:
            return

        status_lines, table_data = self._get_heartbeat(
            trials, *sys_args, force_full_output=force
        )

        self._start_block("heartbeat")
        for line in status_lines:
            print(line)

        # now print the table using Tabulate
        rows = []
        truncation_notes = []
        for per_status in table_data.data:
            rows.extend(per_status.trial_infos)
            if per_status.more_info:
                truncation_notes.append(per_status.more_info)

        print(
            tabulate(
                rows,
                headers=table_data.header,
                tablefmt=AIR_TABULATE_TABLEFMT,
                showindex=False,
            )
        )
        if truncation_notes:
            print(", ".join(truncation_notes))

        if not force:
            # Only print error table at end of training
            return

        errored_trials = _get_trials_with_error(trials)
        if not errored_trials:
            return

        self._start_block("status_errored")
        print(f"Number of errored trials: {len(errored_trials)}")

        error_rows = []
        for trial in errored_trials:
            # Trials that have an error file but are not in ERROR state get
            # a "*" marker next to their failure count.
            marker = "" if trial.status == Trial.ERROR else "*"
            error_rows.append(
                [
                    str(trial),
                    str(trial.run_metadata.num_failures) + marker,
                    trial.error_file,
                ]
            )

        print(
            tabulate(
                error_rows,
                headers=["Trial name", "# failures", "error file"],
                tablefmt=AIR_TABULATE_TABLEFMT,
                showindex=False,
                colalign=("left", "right", "left"),
            )
        )
        if any(trial.status == Trial.TERMINATED for trial in errored_trials):
            print("* The trial terminated successfully after retrying.")
class TrainReporter(ProgressReporter):
    """Reporter for single-run training; only the first trial is reported."""

    # the minimal verbosity threshold at which heartbeat starts getting printed.
    _heartbeat_threshold = AirVerbosity.VERBOSE
    _intermediate_result_verbosity = AirVerbosity.DEFAULT
    _start_end_verbosity = AirVerbosity.DEFAULT
    _addressing_tmpl = "Training"

    def _get_heartbeat(
        self, trials: List[Trial], force_full_output: bool = False
    ) -> Optional[str]:
        """Build the one-line heartbeat for the training run.

        Args:
            trials: Trials of the run; only the first one is looked at.
            force_full_output: Unused here.

        Returns:
            The heartbeat line, or None when there is no trial yet.
        """
        # Training on iteration 1. Current time: 2023-03-22 15:29:25. Total running time: 3s # noqa
        if not trials:
            return None
        trial = trials[0]
        if trial.status != Trial.RUNNING:
            return " ".join(
                [f"Training is in {trial.status} status.", self._time_heartbeat_str]
            )
        if not trial.last_result or TRAINING_ITERATION not in trial.last_result:
            iter_num = 1
        else:
            # The last result belongs to the finished iteration; the one in
            # progress is the next.
            iter_num = trial.last_result[TRAINING_ITERATION] + 1
        return " ".join(
            [f"Training on iteration {iter_num}.", self._time_heartbeat_str]
        )

    def _print_heartbeat(self, trials, *args, force: bool = False):
        heartbeat = self._get_heartbeat(trials, force_full_output=force)
        # Without trials there is nothing to report; do not print "None".
        if heartbeat is not None:
            print(heartbeat)

    def on_trial_result(
        self,
        iteration: int,
        trials: List[Trial],
        trial: Trial,
        result: Dict,
        **info,
    ):
        # A reported result restarts the heartbeat timer.
        self._last_heartbeat_time = time.time()
        super().on_trial_result(
            iteration=iteration, trials=trials, trial=trial, result=result, **info
        )
|