import argparse
import collections
import datetime
import logging
import math
import numbers
import os
import sys
import textwrap
import time
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Collection, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import pandas as pd

import ray
from ray._private.dict import flatten_dict, unflattened_lookup
from ray._private.thirdparty.tabulate.tabulate import (
    DataRow,
    Line,
    TableFormat,
    tabulate,
)
from ray.air._internal.usage import AirEntrypoint
from ray.air.constants import TRAINING_ITERATION
from ray.tune import Checkpoint
from ray.tune.callback import Callback
from ray.tune.experiment.trial import Trial
from ray.tune.result import (
    AUTO_RESULT_KEYS,
    EPISODE_REWARD_MEAN,
    MEAN_ACCURACY,
    MEAN_LOSS,
    TIME_TOTAL_S,
    TIMESTEPS_TOTAL,
)
from ray.tune.search.sample import Domain
from ray.tune.utils.log import Verbosity

# `rich` is an optional dependency: when it cannot be imported the name is
# bound to None so that later code can test for its availability.
try:
    import rich
    import rich.layout
    import rich.live
except ImportError:
    rich = None

logger = logging.getLogger(__name__)

# Maps a key of a trial result to the (shorter) column title that is printed
# in the trial table for it.
# Note: this mapping is ordered, and the metric columns of the table follow
# this order.
DEFAULT_COLUMNS = collections.OrderedDict(
    {
        MEAN_ACCURACY: "acc",
        MEAN_LOSS: "loss",
        TRAINING_ITERATION: "iter",
        TIME_TOTAL_S: "total time (s)",
        TIMESTEPS_TOTAL: "ts",
        EPISODE_REWARD_MEAN: "reward",
    }
)

# Result keys that are excluded when an intermediate or final training/tuning
# result is printed as a table.
BLACKLISTED_KEYS = {
    "config",
    "date",
    "done",
    "hostname",
    "iterations_since_restore",
    "node_ip",
    "pid",
    "time_since_restore",
    "timestamp",
    "trial_id",
    "experiment_tag",
    "should_checkpoint",
    "_report_on",  # LIGHTNING_REPORT_STAGE_KEY
}

# Exact value types (compared with `type(value) in ...`, so subclasses such as
# `bool` do not match) that qualify a reported metric for being inferred as a
# user metric column.
VALID_SUMMARY_TYPES = {
    int,
    float,
    np.float32,
    np.float64,
    np.int32,
    np.int64,
    type(None),
}

# The order of summarizing trials.
# Trial statuses in the order in which their rows are emitted into the
# progress table.
ORDER = [
    Trial.RUNNING,
    Trial.TERMINATED,
    Trial.PAUSED,
    Trial.PENDING,
    Trial.ERROR,
]


class AirVerbosity(IntEnum):
    """Verbosity levels understood by the reporters in this module."""

    SILENT = 0
    DEFAULT = 1
    VERBOSE = 2

    def __repr__(self):
        # Render as the bare integer value instead of `<AirVerbosity.X: n>`.
        return str(self.value)


# Evaluated once, at import time.
IS_NOTEBOOK = ray.widgets.util.in_notebook()


def get_air_verbosity(
    verbose: Union[int, AirVerbosity, Verbosity]
) -> Optional[AirVerbosity]:
    """Translate a verbosity setting into an ``AirVerbosity``.

    Args:
        verbose: An ``AirVerbosity`` (returned unchanged), a plain integer,
            or an object exposing the level as ``.value``.

    Returns:
        The matching ``AirVerbosity``, or ``None`` when the environment
        variable ``RAY_AIR_NEW_OUTPUT`` is set to ``"0"``.
    """
    if os.environ.get("RAY_AIR_NEW_OUTPUT", "1") == "0":
        return None

    if isinstance(verbose, AirVerbosity):
        return verbose

    verbose_int = verbose if isinstance(verbose, int) else verbose.value

    # Verbosity 2 and 3 both map to AirVerbosity 2
    verbose_int = min(2, verbose_int)

    return AirVerbosity(verbose_int)


def _infer_params(config: Dict[str, Any]) -> List[str]:
    """Collect the flattened config keys that describe a search space.

    A key is reported when its value is a ``Domain`` instance, or when the
    flattened key ends in ``/grid_search`` (in which case that suffix is
    removed from the reported name).

    Args:
        config: Possibly nested configuration dictionary.

    Returns:
        Flattened parameter names in the iteration order of the flattened
        config.
    """
    grid_search_suffix = "/grid_search"
    params = []
    flat_config = flatten_dict(config)
    for key, val in flat_config.items():
        if isinstance(val, Domain):
            params.append(key)
        # Grid search is a special named field. Because we flattened
        # the whole config, we look it up per string
        if key.endswith(grid_search_suffix):
            # Truncate `/grid_search`
            params.append(key[: -len(grid_search_suffix)])
    return params


def _get_time_str(start_time: float, current_time: float) -> Tuple[str, str]:
    """Get strings representing the current and elapsed time.

    Args:
        start_time: POSIX timestamp of the start of the tune run
        current_time: POSIX timestamp giving the current time

    Returns:
        Current time and elapsed time for the current run
    """
    current_time_dt = datetime.datetime.fromtimestamp(current_time)
    start_time_dt = datetime.datetime.fromtimestamp(start_time)
    delta: datetime.timedelta = current_time_dt - start_time_dt

    # Peel off days, hours and minutes from the elapsed seconds in turn.
    rest = delta.total_seconds()
    days = int(rest // (60 * 60 * 24))

    rest -= days * (60 * 60 * 24)
    hours = int(rest // (60 * 60))

    rest -= hours * (60 * 60)
    minutes = int(rest // 60)

    seconds = int(rest - minutes * 60)

    # A smaller unit is always printed once a larger one has been printed,
    # e.g. "1d 0hr 0min 5s".
    running_for_str = ""
    if days > 0:
        running_for_str += f"{days:d}d "

    if hours > 0 or running_for_str:
        running_for_str += f"{hours:d}hr "

    if minutes > 0 or running_for_str:
        running_for_str += f"{minutes:d}min "

    running_for_str += f"{seconds:d}s"

    return f"{current_time_dt:%Y-%m-%d %H:%M:%S}", running_for_str


def _get_trials_by_state(trials: List[Trial]) -> Dict[str, List[Trial]]:
    """Group trials by their ``status``.

    Returns:
        A ``defaultdict(list)`` keyed by status; looking up a status without
        trials yields an empty list.
    """
    trials_by_state = collections.defaultdict(list)
    for t in trials:
        trials_by_state[t.status].append(t)
    return trials_by_state


def _get_trials_with_error(trials: List[Trial]) -> List[Trial]:
    """Return the trials whose ``error_file`` attribute is truthy."""
    return [t for t in trials if t.error_file]


def _infer_user_metrics(trials: List[Trial], limit: int = 4) -> List[str]:
    """Try to infer the metrics to print out.

    By default, only the first 4 meaningful metrics in `last_result` will be
    inferred as user implied metrics.

    A metric is considered meaningful when it is neither a default column nor
    an auto-filled result key and the exact type of its value is listed in
    ``VALID_SUMMARY_TYPES``.
    """
    # Using OrderedDict for OrderedSet.
    result = collections.OrderedDict()
    for t in trials:
        if not t.last_result:
            continue
        for metric, value in t.last_result.items():
            if (
                metric not in DEFAULT_COLUMNS
                and metric not in AUTO_RESULT_KEYS
                and type(value) in VALID_SUMMARY_TYPES
            ):
                result[metric] = ""  # not important

            if len(result) >= limit:
                return list(result.keys())
    return list(result.keys())


def _current_best_trial(
    trials: List[Trial], metric: Optional[str], mode: Optional[str]
) -> Tuple[Optional[Trial], Optional[str]]:
    """
    Returns the best trial and the metric key.

    If anything is empty or None, returns a trivial result of None, None.

    Args:
        trials: List of trials.
        metric: Metric that trials are being ranked.
        mode: One of "min" or "max".

    Returns:
        Best trial and the metric key.
    """
    if not trials or not metric or not mode:
        return None, None

    # Multiplying by -1 turns minimisation into maximisation.
    metric_op = 1.0 if mode == "max" else -1.0
    best_metric = float("-inf")
    best_trial = None
    for t in trials:
        if not t.last_result:
            continue
        metric_value = unflattened_lookup(metric, t.last_result, default=None)
        if pd.isnull(metric_value):
            continue
        if best_trial is None or metric_value * metric_op > best_metric:
            best_metric = metric_value * metric_op
            best_trial = t
    return best_trial, metric


@dataclass
class _PerStatusTrialTableData:
    # One row per displayed trial of a single status.
    trial_infos: List[List[str]]
    # "<n> more <status>" when rows were cut off, otherwise None.
    more_info: Optional[str]


@dataclass
class _TrialTableData:
    # Column titles of the trial table.
    header: List[str]
    # Rows grouped per status.
    data: List[_PerStatusTrialTableData]


def _max_len(value: Any, max_len: int = 20, wrap: bool = False) -> Any:
    """Abbreviate a string representation of an object to `max_len` characters.

    For numbers, booleans and None, the original value will be returned for
    correct rendering in the table formatting tool.

    Args:
        value: Object to be represented as a string.
        max_len: Maximum return string length.
        wrap: If True, wrap the text over at most two rows of ``max_len``
            characters (joined by a newline) instead of cutting it down to a
            single row.

    Returns:
        ``value`` itself for None/numbers/booleans, otherwise a string. Text
        that had to be cut keeps its tail and is prefixed with ``...``.
    """
    if value is None or isinstance(value, (int, float, numbers.Number, bool)):
        return value

    string = str(value)
    if len(string) <= max_len:
        return string

    if wrap:
        # Maximum two rows.
        # Todo: Make this configurable in the refactor
        # Always operate on the string form: measuring or wrapping `value`
        # itself is wrong for objects that are not already strings.
        if len(string) > max_len * 2:
            string = "..." + string[(3 - (max_len * 2)) :]

        wrapped = textwrap.wrap(string, width=max_len)
        return "\n".join(wrapped)

    result = "..." + string[(3 - max_len) :]
    return result


def _get_trial_info(
    trial: Trial, param_keys: List[str], metric_keys: List[str]
) -> List[str]:
    """Returns the following information about a trial:

    name | status | params... | metrics...

    Args:
        trial: Trial to get information for.
        param_keys: Names of parameters to include.
        metric_keys: Names of metrics to include.
    """
    result = trial.last_result
    trial_info = [str(trial), trial.status]

    # params
    trial_info.extend(
        [
            _max_len(
                unflattened_lookup(param, trial.config, default=None),
            )
            for param in param_keys
        ]
    )
    # metrics
    trial_info.extend(
        [
            _max_len(
                unflattened_lookup(metric, result, default=None),
            )
            for metric in metric_keys
        ]
    )
    return trial_info


def _get_trial_table_data_per_status(
    status: str,
    trials: List[Trial],
    param_keys: List[str],
    metric_keys: List[str],
    force_max_rows: bool = False,
) -> Optional[_PerStatusTrialTableData]:
    """Gather all information of trials pertained to one `status`.

    Args:
        status: The trial status of interest.
        trials: all the trials of that status.
        param_keys: *Ordered* list of parameters to be displayed in the table.
        metric_keys: *Ordered* list of metrics to be displayed in the table.
            Including both default and user defined.
        force_max_rows: Whether or not to enforce a max row number for this
            status. If True, only a max of `5` rows will be shown.

    Returns:
        All information of trials pertained to the `status`, or None when
        there are no trials.
    """
    # TODO: configure it.
    max_row = 5 if force_max_rows else math.inf
    if not trials:
        return None

    trial_infos = list()
    more_info = None
    for t in trials:
        if len(trial_infos) >= max_row:
            remaining = len(trials) - max_row
            more_info = f"{remaining} more {status}"
            break
        trial_infos.append(_get_trial_info(t, param_keys, metric_keys))
    return _PerStatusTrialTableData(trial_infos, more_info)


def _get_trial_table_data(
    trials: List[Trial],
    param_keys: List[str],
    metric_keys: List[str],
    all_rows: bool = False,
    wrap_headers: bool = False,
) -> _TrialTableData:
    """Generate a table showing the current progress of tuning trials.

    Args:
        trials: List of trials for which progress is to be shown.
        param_keys: Ordered list of parameters to be displayed in the table.
        metric_keys: Ordered list of metrics to be displayed in the table.
            Including both default and user defined.
            Will only be shown if at least one trial is having the key.
        all_rows: Force to show all rows.
        wrap_headers: If True, header columns can be wrapped with ``\\n``.

    Returns:
        Trial table data, including header and trial table per each status.
    """
    # TODO: configure
    max_trial_num_to_show = 20
    max_column_length = 20
    trials_by_state = _get_trials_by_state(trials)

    # get the right metric to show.
    metric_keys = [
        k
        for k in metric_keys
        if any(
            unflattened_lookup(k, t.last_result, default=None) is not None
            for t in trials
        )
    ]

    # get header from metric keys
    formatted_metric_columns = [
        _max_len(k, max_len=max_column_length, wrap=wrap_headers) for k in metric_keys
    ]

    formatted_param_columns = [
        _max_len(k, max_len=max_column_length, wrap=wrap_headers) for k in param_keys
    ]

    # Map to the abbreviated version if necessary.
    metric_header = [
        DEFAULT_COLUMNS[metric] if metric in DEFAULT_COLUMNS else formatted
        for metric, formatted in zip(metric_keys, formatted_metric_columns)
    ]

    param_header = formatted_param_columns

    header = ["Trial name", "status"] + param_header + metric_header

    # Rows are only capped per status when there are more trials than
    # `max_trial_num_to_show` and the caller did not ask for all rows.
    force_max_rows = not all_rows and len(trials) > max_trial_num_to_show

    trial_data = list()
    for t_status in ORDER:
        trial_data_per_status = _get_trial_table_data_per_status(
            t_status,
            trials_by_state[t_status],
            param_keys=param_keys,
            metric_keys=metric_keys,
            force_max_rows=force_max_rows,
        )
        if trial_data_per_status:
            trial_data.append(trial_data_per_status)
    return _TrialTableData(header, trial_data)


def _best_trial_str(
    trial: Trial,
    metric: str,
):
    """Returns a readable message stating the current best trial."""
    # returns something like
    # Current best trial: 18ae7_00005 with loss=0.5918508041056858 and params={'train_loop_config': {'lr': 0.059253447253394785}}. # noqa
    val = unflattened_lookup(metric, trial.last_result, default=None)
    config = trial.last_result.get("config", {})
    parameter_columns = list(config.keys())
    params = {p: unflattened_lookup(p, config) for p in parameter_columns}
    return (
        f"Current best trial: {trial.trial_id} with {metric}={val} and "
        f"params={params}"
    )


def _render_table_item(
    key: str, item: Any, prefix: str = ""
) -> Iterable[Tuple[str, str]]:
    """Yield ``(key, rendered value)`` rows for one entry of a dict.

    Floats are formatted with five decimals and trailing zeros stripped,
    dicts are flattened into one row per leaf (``key/sub/key``), and
    ``argparse.Namespace`` objects are treated like their ``__dict__``.
    Everything else is abbreviated with ``_max_len``.
    """
    key = prefix + key

    if isinstance(item, argparse.Namespace):
        item = item.__dict__

    if isinstance(item, float):
        # tabulate does not work well with mixed-type columns, so we format
        # numbers ourselves.
        yield key, f"{item:.5f}".rstrip("0")
    elif isinstance(item, dict):
        flattened = flatten_dict(item)
        for k, v in sorted(flattened.items()):
            yield key + "/" + str(k), _max_len(v)
    else:
        yield key, _max_len(item, 20)


def _get_dict_as_table_data(
    data: Dict,
    include: Optional[Collection] = None,
    exclude: Optional[Collection] = None,
    upper_keys: Optional[Collection] = None,
):
    """Get ``data`` dict as table rows.

    If specified, excluded keys are removed. Excluded keys can either be
    fully specified (e.g. ``foo/bar/baz``) or specify a top-level dictionary
    (e.g. ``foo``), but no intermediate levels (e.g. ``foo/bar``). If this is
    needed, we can revisit the logic at a later point.

    The same is true for included keys. If a top-level key is included (e.g. ``foo``)
    then all sub keys will be included, too, except if they are excluded.

    If keys are both excluded and included, exclusion takes precedence. Thus, if
    ``foo`` is excluded but ``foo/bar`` is included, it won't show up in the output.

    Rows whose top-level key is in ``upper_keys`` are placed before all other
    rows.
    """
    include = include or set()
    exclude = exclude or set()
    upper_keys = upper_keys or set()

    upper = []
    lower = []

    for key, value in sorted(data.items()):
        # Exclude top-level keys
        if key in exclude:
            continue

        for k, v in _render_table_item(str(key), value):
            # k is now the full subkey, e.g. config/nested/key
            # We can exclude the full key
            if k in exclude:
                continue

            # If we specify includes, top-level includes should take precedence
            # (e.g. if `config` is in include, include config always).
            if include and key not in include and k not in include:
                continue

            if key in upper_keys:
                upper.append([k, v])
            else:
                lower.append([k, v])

    # Concatenation covers the cases where either part is empty.
    return upper + lower


if sys.stdout and sys.stdout.encoding and sys.stdout.encoding.startswith("utf"):
    # Copied/adjusted from tabulate
    AIR_TABULATE_TABLEFMT = TableFormat(
        lineabove=Line("╭", "─", "─", "╮"),
        linebelowheader=Line("├", "─", "─", "┤"),
        linebetweenrows=None,
        linebelow=Line("╰", "─", "─", "╯"),
        headerrow=DataRow("│", " ", "│"),
        datarow=DataRow("│", " ", "│"),
        padding=1,
        with_header_hide=None,
    )
else:
    # For non-utf output, use ascii-compatible characters.
    # This prevents errors e.g. when legacy windows encoding is used.
    AIR_TABULATE_TABLEFMT = TableFormat(
        lineabove=Line("+", "-", "-", "+"),
        linebelowheader=Line("+", "-", "-", "+"),
        linebetweenrows=None,
        linebelow=Line("+", "-", "-", "+"),
        headerrow=DataRow("|", " ", "|"),
        datarow=DataRow("|", " ", "|"),
        padding=1,
        with_header_hide=None,
    )


def _print_dict_as_table(
    data: Dict,
    header: Optional[str] = None,
    include: Optional[Collection[str]] = None,
    exclude: Optional[Collection[str]] = None,
    division: Optional[Collection[str]] = None,
):
    """Print ``data`` as a two-column table; prints nothing without rows.

    Args:
        data: Dictionary to print.
        header: Optional title placed above the key column.
        include: Keys to restrict the output to (see
            ``_get_dict_as_table_data``).
        exclude: Keys to leave out.
        division: Top-level keys whose rows are listed first.
    """
    table_data = _get_dict_as_table_data(
        data=data, include=include, exclude=exclude, upper_keys=division
    )

    headers = [header, ""] if header else []

    if not table_data:
        return

    print(
        tabulate(
            table_data,
            headers=headers,
            colalign=("left", "right"),
            tablefmt=AIR_TABULATE_TABLEFMT,
        )
    )


class ProgressReporter(Callback):
    """Periodically prints out status update."""

    # TODO: Make this configurable
    _heartbeat_freq = 30  # every 30 sec
    # to be updated by subclasses.
    _heartbeat_threshold = None
    _start_end_verbosity = None
    _intermediate_result_verbosity = None
    _addressing_tmpl = None

    def __init__(
        self,
        verbosity: AirVerbosity,
        progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
    ):
        """

        Args:
            verbosity: AirVerbosity level.
            progress_metrics: Passed as the ``include`` filter when a trial
                result is printed.
        """
        self._verbosity = verbosity
        self._start_time = time.time()
        # -inf guarantees that the very first heartbeat is due immediately.
        self._last_heartbeat_time = float("-inf")
        self._progress_metrics = progress_metrics
        # trial_id -> training iteration of the last result that was printed.
        self._trial_last_printed_results = {}
        # Identifier of the output block currently being written, if any.
        self._in_block = None

    @property
    def verbosity(self) -> AirVerbosity:
        return self._verbosity

    def setup(
        self,
        start_time: Optional[float] = None,
        **kwargs,
    ):
        """Set the experiment start time used for the running-time strings.

        When ``start_time`` is None the time recorded at construction is kept;
        storing None would make every later running-time computation fail.
        """
        if start_time is not None:
            self._start_time = start_time

    def _start_block(self, indicator: Any):
        """Switch to the output block ``indicator``, closing a different one."""
        if self._in_block != indicator:
            self._end_block()
        self._in_block = indicator

    def _end_block(self):
        """Terminate the current output block with an empty line, if any."""
        if self._in_block:
            print("")
        self._in_block = None

    def on_experiment_end(self, trials: List["Trial"], **info):
        self._end_block()

    def experiment_started(
        self,
        experiment_name: str,
        experiment_path: str,
        searcher_str: str,
        scheduler_str: str,
        total_num_samples: int,
        tensorboard_path: Optional[str] = None,
        **kwargs,
    ):
        """Print where results are stored and, if given, a TensorBoard hint."""
        self._start_block("exp_start")
        print(f"\nView detailed results here: {experiment_path}")

        if tensorboard_path:
            print(
                f"To visualize your results with TensorBoard, run: "
                f"`tensorboard --logdir {tensorboard_path}`"
            )

    @property
    def _time_heartbeat_str(self):
        current_time_str, running_time_str = _get_time_str(
            self._start_time, time.time()
        )
        return (
            f"Current time: {current_time_str}. Total running time: "
            + running_time_str
        )

    def print_heartbeat(self, trials, *args, force: bool = False):
        """Print a heartbeat if one is due (or ``force`` is set).

        Nothing is printed when the verbosity is below the heartbeat
        threshold of the reporter class.
        """
        if self._verbosity < self._heartbeat_threshold:
            return
        if force or time.time() - self._last_heartbeat_time >= self._heartbeat_freq:
            self._print_heartbeat(trials, *args, force=force)
            self._last_heartbeat_time = time.time()

    def _print_heartbeat(self, trials, *args, force: bool = False):
        raise NotImplementedError

    def _print_result(self, trial, result: Optional[Dict] = None, force: bool = False):
        """Only print result if a different result has been reported, or force=True"""
        result = result or trial.last_result

        last_result_iter = self._trial_last_printed_results.get(trial.trial_id, -1)
        this_iter = result.get(TRAINING_ITERATION, 0)

        if this_iter != last_result_iter or force:
            _print_dict_as_table(
                result,
                header=f"{self._addressing_tmpl.format(trial)} result",
                include=self._progress_metrics,
                exclude=BLACKLISTED_KEYS,
                division=AUTO_RESULT_KEYS,
            )
            self._trial_last_printed_results[trial.trial_id] = this_iter

    def _print_config(self, trial):
        _print_dict_as_table(
            trial.config, header=f"{self._addressing_tmpl.format(trial)} config"
        )

    def on_trial_result(
        self,
        iteration: int,
        trials: List[Trial],
        trial: Trial,
        result: Dict,
        **info,
    ):
        if self.verbosity < self._intermediate_result_verbosity:
            return
        self._start_block(f"trial_{trial}_result_{result[TRAINING_ITERATION]}")
        curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"finished iteration {result[TRAINING_ITERATION]} "
            f"at {curr_time_str}. Total running time: " + running_time_str
        )
        self._print_result(trial, result)

    def on_trial_complete(
        self, iteration: int, trials: List[Trial], trial: Trial, **info
    ):
        if self.verbosity < self._start_end_verbosity:
            return
        curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
        finished_iter = 0
        if trial.last_result and TRAINING_ITERATION in trial.last_result:
            finished_iter = trial.last_result[TRAINING_ITERATION]

        self._start_block(f"trial_{trial}_complete")
        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"completed after {finished_iter} iterations "
            f"at {curr_time_str}. Total running time: " + running_time_str
        )
        self._print_result(trial)

    def on_trial_error(
        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
    ):
        # Unlike the other hooks there is no verbosity check here: the error
        # message is printed at every verbosity level.
        curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
        finished_iter = 0
        if trial.last_result and TRAINING_ITERATION in trial.last_result:
            finished_iter = trial.last_result[TRAINING_ITERATION]

        self._start_block(f"trial_{trial}_error")
        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"errored after {finished_iter} iterations "
            f"at {curr_time_str}. Total running time: {running_time_str}\n"
            f"Error file: {trial.error_file}"
        )
        self._print_result(trial)

    def on_trial_recover(
        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
    ):
        # A recovering trial is reported exactly like an errored one.
        self.on_trial_error(iteration=iteration, trials=trials, trial=trial, **info)

    def on_checkpoint(
        self,
        iteration: int,
        trials: List[Trial],
        trial: Trial,
        checkpoint: Checkpoint,
        **info,
    ):
        if self._verbosity < self._intermediate_result_verbosity:
            return
        # don't think this is supposed to happen but just to be safe.
        saved_iter = "?"
        if trial.last_result and TRAINING_ITERATION in trial.last_result:
            saved_iter = trial.last_result[TRAINING_ITERATION]

        # Same block identifier as the matching result, so both are printed
        # together without a separating empty line.
        self._start_block(f"trial_{trial}_result_{saved_iter}")

        loc = f"({checkpoint.filesystem.type_name}){checkpoint.path}"

        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"saved a checkpoint for iteration {saved_iter} "
            f"at: {loc}"
        )

    def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial, **info):
        if self.verbosity < self._start_end_verbosity:
            return
        has_config = bool(trial.config)

        self._start_block(f"trial_{trial}_start")

        if has_config:
            print(
                f"{self._addressing_tmpl.format(trial)} " f"started with configuration:"
            )
            self._print_config(trial)
        else:
            print(
                f"{self._addressing_tmpl.format(trial)} "
                f"started without custom configuration."
            )


def _detect_reporter(
    verbosity: AirVerbosity,
    num_samples: int,
    entrypoint: Optional[AirEntrypoint] = None,
    metric: Optional[str] = None,
    mode: Optional[str] = None,
    config: Optional[Dict] = None,
    progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
):
    """Create the reporter matching the entrypoint.

    Tune entrypoints (``TUNE_RUN``, ``TUNE_RUN_EXPERIMENTS``, ``TUNER``) get a
    ``TuneTerminalReporter``; every other entrypoint gets a ``TrainReporter``.
    """
    if entrypoint in {
        AirEntrypoint.TUNE_RUN,
        AirEntrypoint.TUNE_RUN_EXPERIMENTS,
        AirEntrypoint.TUNER,
    }:
        reporter = TuneTerminalReporter(
            verbosity,
            num_samples=num_samples,
            metric=metric,
            mode=mode,
            config=config,
            progress_metrics=progress_metrics,
        )
    else:
        reporter = TrainReporter(verbosity, progress_metrics=progress_metrics)
    return reporter


class TuneReporterBase(ProgressReporter):
    """Shared heartbeat content generation for Tune reporters."""

    _heartbeat_threshold = AirVerbosity.DEFAULT
    _wrap_headers = False
    _intermediate_result_verbosity = AirVerbosity.VERBOSE
    _start_end_verbosity = AirVerbosity.DEFAULT
    _addressing_tmpl = "Trial {}"

    def __init__(
        self,
        verbosity: AirVerbosity,
        num_samples: int = 0,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        config: Optional[Dict] = None,
        progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
    ):
        self._num_samples = num_samples
        self._metric = metric
        self._mode = mode
        # will be populated when first result comes in.
        self._inferred_metric = None
        self._inferred_params = _infer_params(config or {})
        super().__init__(verbosity=verbosity, progress_metrics=progress_metrics)

    def setup(
        self,
        start_time: Optional[float] = None,
        total_samples: Optional[int] = None,
        **kwargs,
    ):
        super().setup(start_time=start_time)
        self._num_samples = total_samples

    def _get_overall_trial_progress_str(self, trials):
        """Return e.g. ``Trial status: 1 RUNNING | 7 PENDING``."""
        status_counts = [
            f"{len(status_trials)} {status}"
            for status, status_trials in _get_trials_by_state(trials).items()
        ]
        result = " | ".join(status_counts)
        return f"Trial status: {result}"

    # TODO: Return a more structured type to share code with Jupyter flow.
    def _get_heartbeat(
        self, trials, *sys_args, force_full_output: bool = False
    ) -> Tuple[List[str], _TrialTableData]:
        """Assemble the heartbeat lines and the trial table data.

        Args:
            trials: Trials to report on.
            *sys_args: Extra lines appended after the time line.
            force_full_output: If True, the trial table is not truncated.

        Returns:
            The text lines to print and the data of the trial table.
        """
        result = list()
        # Trial status: 1 RUNNING | 7 PENDING
        result.append(self._get_overall_trial_progress_str(trials))
        # Current time: 2023-02-24 12:35:39 (running for 00:00:37.40)
        result.append(self._time_heartbeat_str)
        # Logical resource usage: 8.0/64 CPUs, 0/0 GPUs
        result.extend(sys_args)
        # Current best trial: TRIAL NAME, metrics: {...}, parameters: {...}
        current_best_trial, metric = _current_best_trial(
            trials, self._metric, self._mode
        )
        if current_best_trial:
            result.append(_best_trial_str(current_best_trial, metric))
        # Now populating the trial table data.
        if not self._inferred_metric:
            # try inferring again.
            self._inferred_metric = _infer_user_metrics(trials)

        all_metrics = list(DEFAULT_COLUMNS.keys()) + self._inferred_metric

        trial_table_data = _get_trial_table_data(
            trials,
            param_keys=self._inferred_params,
            metric_keys=all_metrics,
            all_rows=force_full_output,
            wrap_headers=self._wrap_headers,
        )
        return result, trial_table_data

    def _print_heartbeat(self, trials, *sys_args, force: bool = False):
        raise NotImplementedError


class TuneTerminalReporter(TuneReporterBase):
    """Tune reporter that prints its tables with ``tabulate``."""

    def experiment_started(
        self,
        experiment_name: str,
        experiment_path: str,
        searcher_str: str,
        scheduler_str: str,
        total_num_samples: int,
        tensorboard_path: Optional[str] = None,
        **kwargs,
    ):
        """Print the experiment configuration table, then the base output."""
        # A sample count beyond sys.maxsize is displayed as "infinite".
        if total_num_samples > sys.maxsize:
            total_num_samples_str = "infinite"
        else:
            total_num_samples_str = str(total_num_samples)

        print(
            tabulate(
                [
                    ["Search algorithm", searcher_str],
                    ["Scheduler", scheduler_str],
                    ["Number of trials", total_num_samples_str],
                ],
                headers=["Configuration for experiment", experiment_name],
                tablefmt=AIR_TABULATE_TABLEFMT,
            )
        )
        super().experiment_started(
            experiment_name=experiment_name,
            experiment_path=experiment_path,
            searcher_str=searcher_str,
            scheduler_str=scheduler_str,
            total_num_samples=total_num_samples,
            tensorboard_path=tensorboard_path,
            **kwargs,
        )

    def _print_heartbeat(self, trials, *sys_args, force: bool = False):
        """Print status lines and the trial table; with ``force`` also errors."""
        if self._verbosity < self._heartbeat_threshold and not force:
            return
        heartbeat_strs, table_data = self._get_heartbeat(
            trials, *sys_args, force_full_output=force
        )

        self._start_block("heartbeat")
        for s in heartbeat_strs:
            print(s)
        # now print the table using Tabulate
        more_infos = []
        all_data = []
        trial_header = table_data.header
        for sub_table in table_data.data:
            all_data.extend(sub_table.trial_infos)
            if sub_table.more_info:
                more_infos.append(sub_table.more_info)

        print(
            tabulate(
                all_data,
                headers=trial_header,
                tablefmt=AIR_TABULATE_TABLEFMT,
                showindex=False,
            )
        )
        if more_infos:
            print(", ".join(more_infos))

        if not force:
            # Only print error table at end of training
            return

        trials_with_error = _get_trials_with_error(trials)
        if not trials_with_error:
            return

        self._start_block("status_errored")
        print(f"Number of errored trials: {len(trials_with_error)}")
        fail_header = ["Trial name", "# failures", "error file"]
        # A trailing "*" marks trials that have an error file but are not in
        # ERROR status.
        fail_table_data = [
            [
                str(trial),
                str(trial.run_metadata.num_failures)
                + ("" if trial.status == Trial.ERROR else "*"),
                trial.error_file,
            ]
            for trial in trials_with_error
        ]
        print(
            tabulate(
                fail_table_data,
                headers=fail_header,
                tablefmt=AIR_TABULATE_TABLEFMT,
                showindex=False,
                colalign=("left", "right", "left"),
            )
        )
        if any(trial.status == Trial.TERMINATED for trial in trials_with_error):
            print("* The trial terminated successfully after retrying.")


class TrainReporter(ProgressReporter):
    """Reporter for a single training run (reports on the first trial only)."""

    # the minimal verbosity threshold at which heartbeat starts getting printed.
    _heartbeat_threshold = AirVerbosity.VERBOSE
    _intermediate_result_verbosity = AirVerbosity.DEFAULT
    _start_end_verbosity = AirVerbosity.DEFAULT
    _addressing_tmpl = "Training"

    def _get_heartbeat(
        self, trials: List[Trial], force_full_output: bool = False
    ) -> Optional[str]:
        """Return the heartbeat line, or None when there are no trials."""
        # Training on iteration 1. Current time: 2023-03-22 15:29:25 (running for 00:00:03.24)  # noqa
        if len(trials) == 0:
            return None
        trial = trials[0]
        if trial.status != Trial.RUNNING:
            return " ".join(
                [f"Training is in {trial.status} status.", self._time_heartbeat_str]
            )
        if not trial.last_result or TRAINING_ITERATION not in trial.last_result:
            iter_num = 1
        else:
            iter_num = trial.last_result[TRAINING_ITERATION] + 1
        return " ".join(
            [f"Training on iteration {iter_num}.", self._time_heartbeat_str]
        )

    def _print_heartbeat(self, trials, *args, force: bool = False):
        heartbeat = self._get_heartbeat(trials, force_full_output=force)
        # Without trials there is nothing to report; do not print "None".
        if heartbeat is not None:
            print(heartbeat)

    def on_trial_result(
        self,
        iteration: int,
        trials: List[Trial],
        trial: Trial,
        result: Dict,
        **info,
    ):
        # A freshly reported result counts as a heartbeat, which postpones the
        # next periodic one.
        self._last_heartbeat_time = time.time()
        super().on_trial_result(
            iteration=iteration, trials=trials, trial=trial, result=result, **info
        )