| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596 |
- from __future__ import print_function
- import collections
- import datetime
- import numbers
- import sys
- import textwrap
- import time
- import warnings
- from pathlib import Path
- from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Union
- import numpy as np
- import pandas as pd
- import ray
- from ray._private.dict import flatten_dict
- from ray._private.thirdparty.tabulate.tabulate import tabulate
- from ray.air.constants import EXPR_ERROR_FILE, TRAINING_ITERATION
- from ray.air.util.node import _force_on_current_node
- from ray.experimental.tqdm_ray import safe_print
- from ray.tune.callback import Callback
- from ray.tune.experiment.trial import DEBUG_PRINT_INTERVAL, Trial, _Location
- from ray.tune.logger import pretty_print
- from ray.tune.result import (
- AUTO_RESULT_KEYS,
- DEFAULT_METRIC,
- DONE,
- EPISODE_REWARD_MEAN,
- EXPERIMENT_TAG,
- MEAN_ACCURACY,
- MEAN_LOSS,
- NODE_IP,
- PID,
- TIME_TOTAL_S,
- TIMESTEPS_TOTAL,
- TRIAL_ID,
- )
- from ray.tune.trainable import Trainable
- from ray.tune.utils import unflattened_lookup
- from ray.tune.utils.log import Verbosity, has_verbosity, set_verbosity
- from ray.util.annotations import DeveloperAPI, PublicAPI
- from ray.util.queue import Empty, Queue
- from ray.widgets import Template
- try:
- from collections.abc import Mapping, MutableMapping
- except ImportError:
- from collections import Mapping, MutableMapping
# Result keys that are skipped when results are included in a report.
SKIP_RESULTS_IN_REPORT = {DONE, EXPERIMENT_TAG, TRIAL_ID, "config"}
# Evaluated once, at import time: whether we are running inside a notebook.
IS_NOTEBOOK = ray.widgets.util.in_notebook()
@PublicAPI
class ProgressReporter:
    """Interface for reporting the progress of a Tune experiment.

    Tune first asks ``should_report()``; when it returns True, ``report()``
    is invoked. Tune calls these hooks after events such as trial state
    transitions and the arrival of new training results.
    """

    def setup(
        self,
        start_time: Optional[float] = None,
        total_samples: Optional[int] = None,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        **kwargs,
    ):
        """Prepare the reporter for a new Ray Tune run.

        Receives parameters that are only known at runtime and is called
        before any other method. The base implementation does nothing.

        Args:
            start_time: Timestamp at which the Ray Tune run started.
            total_samples: Number of samples the Ray Tune run will run.
            metric: Metric to optimize.
            mode: Either "min" or "max"; whether the objective minimizes or
                maximizes ``metric``.
            **kwargs: Extra keyword arguments, accepted for
                forward-compatibility.
        """
        return None

    def should_report(self, trials: List[Trial], done: bool = False):
        """Decide whether progress should be reported right now.

        Args:
            trials: Trials that would be reported on.
            done: True if this is the final reporting attempt.
        """
        raise NotImplementedError

    def report(self, trials: List[Trial], done: bool, *sys_info: Dict):
        """Report progress across the given trials.

        Args:
            trials: Trials to report on.
            done: True if this is the final reporting attempt.
            sys_info: System info to include in the report.
        """
        raise NotImplementedError
@DeveloperAPI
class TuneReporterBase(ProgressReporter):
    """Abstract base class for the default Tune reporters.

    If metric_columns is not overridden, Tune will attempt to automatically
    infer the metrics being outputted, up to 'infer_limit' number of
    metrics.

    Args:
        metric_columns: Names of metrics to
            include in progress table. If this is a dict, the keys should
            be metric names and the values should be the displayed names.
            If this is a list, the metric name is used directly.
        parameter_columns: Names of parameters to
            include in progress table. If this is a dict, the keys should
            be parameter names and the values should be the displayed names.
            If this is a list, the parameter name is used directly. If empty,
            defaults to all available parameters.
        total_samples: Number of samples the Ray Tune run will run. Shown
            in the "Number of trials" line of the progress output.
        max_progress_rows: Maximum number of rows to print
            in the progress table. The progress table describes the
            progress of each trial. Defaults to 20.
        max_error_rows: Maximum number of rows to print in the
            error table. The error table lists the error file, if any,
            corresponding to each trial. Defaults to 20.
        max_column_length: Maximum column length (in characters). Column
            headers and values longer than this will be abbreviated.
        max_report_frequency: Maximum report frequency in seconds.
            Defaults to 5s.
        infer_limit: Maximum number of metrics to automatically infer
            from tune results.
        print_intermediate_tables: Print intermediate result
            tables. If None (default), will be set to True for verbosity
            levels above 3, otherwise False. If True, intermediate tables
            will be printed with experiment progress. If False, tables
            will only be printed at the end of the tuning run for verbosity
            levels greater than 2.
        metric: Metric used to determine best current trial.
        mode: One of [min, max]. Determines whether objective is
            minimizing or maximizing the metric attribute.
        sort_by_metric: Sort terminated trials by metric in the
            intermediate table. Defaults to False.
    """

    # Truncated representations of column names (to accommodate small screens).
    DEFAULT_COLUMNS = collections.OrderedDict(
        {
            MEAN_ACCURACY: "acc",
            MEAN_LOSS: "loss",
            TRAINING_ITERATION: "iter",
            TIME_TOTAL_S: "total time (s)",
            TIMESTEPS_TOTAL: "ts",
            EPISODE_REWARD_MEAN: "reward",
        }
    )
    # Exact value types (compared with ``type(value) in ...``, so subclasses
    # such as ``bool`` do not match) that ``_infer_user_metrics`` accepts.
    VALID_SUMMARY_TYPES = {
        int,
        float,
        np.float32,
        np.float64,
        np.int32,
        np.int64,
        type(None),
    }

    def __init__(
        self,
        *,
        metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
        parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
        total_samples: Optional[int] = None,
        max_progress_rows: int = 20,
        max_error_rows: int = 20,
        max_column_length: int = 20,
        max_report_frequency: int = 5,
        infer_limit: int = 3,
        print_intermediate_tables: Optional[bool] = None,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        sort_by_metric: bool = False,
    ):
        self._total_samples = total_samples
        # Explicitly given metric columns disable automatic metric inference.
        self._metrics_override = metric_columns is not None
        self._inferred_metrics = {}
        self._metric_columns = metric_columns or self.DEFAULT_COLUMNS.copy()
        self._parameter_columns = parameter_columns or []
        self._max_progress_rows = max_progress_rows
        self._max_error_rows = max_error_rows
        self._max_column_length = max_column_length
        self._infer_limit = infer_limit

        if print_intermediate_tables is None:
            self._print_intermediate_tables = has_verbosity(
                Verbosity.V3_TRIAL_DETAILS
            )
        else:
            self._print_intermediate_tables = print_intermediate_tables

        # Attribute name kept as-is (sic); it is read by ``should_report``.
        self._max_report_freqency = max_report_frequency
        self._last_report_time = 0
        self._start_time = time.time()
        self._metric = metric
        self._mode = mode
        self._sort_by_metric = sort_by_metric

    def setup(
        self,
        start_time: Optional[float] = None,
        total_samples: Optional[int] = None,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        **kwargs,
    ):
        """Configure start time, sample count and search properties.

        Raises:
            ValueError: If ``metric``/``mode`` were already set on the
                reporter and are passed here again.
        """
        self.set_start_time(start_time)
        self.set_total_samples(total_samples)
        self.set_search_properties(metric=metric, mode=mode)

    def set_search_properties(self, metric: Optional[str], mode: Optional[str]):
        """Set the metric and mode used to find the current best trial.

        If only ``mode`` ends up set, ``DEFAULT_METRIC`` is used as the metric.

        Returns:
            Always True.

        Raises:
            ValueError: If the reporter already has a metric (or mode) and
                another one is passed.
        """
        if (self._metric and metric) or (self._mode and mode):
            raise ValueError(
                "You passed a `metric` or `mode` argument to `tune.TuneConfig()`, but "
                "the reporter you are using was already instantiated with their "
                "own `metric` and `mode` parameters. Either remove the arguments "
                "from your reporter or from your call to `tune.TuneConfig()`"
            )
        if metric:
            self._metric = metric
        if mode:
            self._mode = mode
        if self._metric is None and self._mode:
            # If only a mode was passed, use anonymous metric
            self._metric = DEFAULT_METRIC
        return True

    def set_total_samples(self, total_samples: int):
        """Set the number of samples shown in the progress output."""
        self._total_samples = total_samples

    def set_start_time(self, timestamp: Optional[float] = None):
        """Set the reference time used to compute the elapsed running time.

        Args:
            timestamp: POSIX timestamp of the run start. If None, the current
                time is used.
        """
        # Always store a real timestamp: ``_progress_str`` passes
        # ``_start_time`` to ``_time_passed_str``, which cannot handle None.
        self._start_time = time.time() if timestamp is None else timestamp

    def should_report(self, trials: List[Trial], done: bool = False):
        """Report at most once per ``max_report_frequency`` seconds, or when done."""
        if time.time() - self._last_report_time > self._max_report_freqency:
            self._last_report_time = time.time()
            return True
        return done

    def add_metric_column(self, metric: str, representation: Optional[str] = None):
        """Adds a metric to the existing columns.

        Calling this disables automatic metric inference.

        Args:
            metric: Metric to add. This must be a metric being returned
                in training step results.
            representation: Representation to use in table. Defaults to
                `metric`.

        Raises:
            ValueError: If the column already exists, or if a differing
                ``representation`` is given while columns are a list.
        """
        self._metrics_override = True
        if metric in self._metric_columns:
            raise ValueError("Column {} already exists.".format(metric))

        if isinstance(self._metric_columns, MutableMapping):
            representation = representation or metric
            self._metric_columns[metric] = representation
        else:
            if representation is not None and representation != metric:
                raise ValueError(
                    "`representation` cannot differ from `metric` "
                    "if this reporter was initialized with a list "
                    "of metric columns."
                )
            self._metric_columns.append(metric)

    def add_parameter_column(
        self, parameter: str, representation: Optional[str] = None
    ):
        """Adds a parameter to the existing columns.

        Args:
            parameter: Parameter to add. This must be a parameter
                specified in the configuration.
            representation: Representation to use in table. Defaults to
                `parameter`.

        Raises:
            ValueError: If the column already exists, or if a differing
                ``representation`` is given while columns are a list.
        """
        if parameter in self._parameter_columns:
            raise ValueError("Column {} already exists.".format(parameter))

        if isinstance(self._parameter_columns, MutableMapping):
            representation = representation or parameter
            self._parameter_columns[parameter] = representation
        else:
            if representation is not None and representation != parameter:
                raise ValueError(
                    "`representation` cannot differ from `parameter` "
                    "if this reporter was initialized with a list "
                    "of parameter columns."
                )
            self._parameter_columns.append(parameter)

    def _progress_str(
        self,
        trials: List[Trial],
        done: bool,
        *sys_info: Dict,
        fmt: str = "psql",
        delim: str = "\n",
    ):
        """Returns full progress string.

        This string contains a progress table and error table. The progress
        table describes the progress of each trial. The error table lists
        the error file, if any, corresponding to each trial. The latter only
        exists if errors have occurred.

        Args:
            trials: Trials to report on.
            done: Whether this is the last progress report attempt. When
                True, the progress and error tables are not truncated.
            sys_info: Extra messages inserted right after the time line.
            fmt: Table format. See `tablefmt` in tabulate API.
            delim: Delimiter between messages.
        """
        if self._sort_by_metric and (self._metric is None or self._mode is None):
            self._sort_by_metric = False
            warnings.warn(
                "Both 'metric' and 'mode' must be set to be able "
                "to sort by metric. No sorting is performed."
            )
        if not self._metrics_override:
            user_metrics = self._infer_user_metrics(trials, self._infer_limit)
            self._metric_columns.update(user_metrics)

        messages = [
            "== Status ==",
            _time_passed_str(self._start_time, time.time()),
            *sys_info,
        ]

        if done:
            max_progress = None
            max_error = None
        else:
            max_progress = self._max_progress_rows
            max_error = self._max_error_rows

        current_best_trial, metric = self._current_best_trial(trials)
        if current_best_trial:
            messages.append(
                _best_trial_str(current_best_trial, metric, self._parameter_columns)
            )

        if has_verbosity(Verbosity.V1_EXPERIMENT):
            # Will filter the table in `trial_progress_str`
            messages.append(
                _trial_progress_str(
                    trials,
                    metric_columns=self._metric_columns,
                    parameter_columns=self._parameter_columns,
                    total_samples=self._total_samples,
                    force_table=self._print_intermediate_tables,
                    fmt=fmt,
                    max_rows=max_progress,
                    max_column_length=self._max_column_length,
                    done=done,
                    metric=self._metric,
                    mode=self._mode,
                    sort_by_metric=self._sort_by_metric,
                )
            )
            messages.append(_trial_errors_str(trials, fmt=fmt, max_rows=max_error))

        return delim.join(messages) + delim

    def _infer_user_metrics(self, trials: List[Trial], limit: int = 4):
        """Try to infer the metrics to print out.

        Collects result keys that are neither default columns nor automatic
        result keys and whose value has one of ``VALID_SUMMARY_TYPES``. Stops
        once ``limit`` metrics are known. The cached result is reused only
        when it has already reached ``limit``.
        """
        if len(self._inferred_metrics) >= limit:
            return self._inferred_metrics
        self._inferred_metrics = {}
        for t in trials:
            if not t.last_result:
                continue
            for metric, value in t.last_result.items():
                if (
                    metric in self.DEFAULT_COLUMNS
                    or metric in AUTO_RESULT_KEYS
                    or type(value) not in self.VALID_SUMMARY_TYPES
                ):
                    continue
                self._inferred_metrics[metric] = metric
                if len(self._inferred_metrics) >= limit:
                    return self._inferred_metrics
        return self._inferred_metrics

    def _current_best_trial(self, trials: List[Trial]):
        """Find the trial with the best last-reported value of the metric.

        Returns:
            A ``(best_trial, metric)`` tuple. ``best_trial`` is None when there
            are no trials, when no metric or mode is available, or when no
            trial has reported a non-null value for the metric.
        """
        if not trials:
            return None, None

        metric, mode = self._metric, self._mode
        # If no metric has been set, see if exactly one has been reported
        # and use that one. `mode` must still be set.
        if not metric:
            if len(self._inferred_metrics) == 1:
                metric = list(self._inferred_metrics.keys())[0]

        if not metric or not mode:
            return None, metric

        # Flip the sign for "min" so a single ">" comparison works for both.
        metric_op = 1.0 if mode == "max" else -1.0
        best_metric = float("-inf")
        best_trial = None
        for t in trials:
            if not t.last_result:
                continue
            metric_value = unflattened_lookup(metric, t.last_result, default=None)
            if pd.isnull(metric_value):
                continue
            if not best_trial or metric_value * metric_op > best_metric:
                best_metric = metric_value * metric_op
                best_trial = t
        return best_trial, metric
@DeveloperAPI
class RemoteReporterMixin:
    """Abstract mixin for reporters that can display output remotely.

    When running Ray Client, subclasses send their output through a Ray
    Queue so that it is displayed on the driver side.
    """

    @property
    def output_queue(self) -> Queue:
        # ``_output_queue`` only exists once the setter has run.
        try:
            return self._output_queue
        except AttributeError:
            return None

    @output_queue.setter
    def output_queue(self, value: Queue):
        self._output_queue = value

    def display(self, string: str) -> None:
        """Show a progress string to the user.

        Args:
            string: The progress string to show.
        """
        raise NotImplementedError
@PublicAPI
class JupyterNotebookReporter(TuneReporterBase, RemoteReporterMixin):
    """Reporter for Jupyter notebooks that updates its output cell in place.

    Args:
        overwrite: Whether to clear the cell contents before the first
            display.
        metric_columns: Names of metrics to include in the progress table.
            For a dict, keys are metric names and values are the displayed
            names; for a list, the metric name is displayed as-is.
        parameter_columns: Names of parameters to include in the progress
            table. For a dict, keys are parameter names and values are the
            displayed names; for a list, the parameter name is displayed
            as-is. If empty, all available parameters are used.
        max_progress_rows: Maximum number of rows in the progress table,
            which describes the progress of each trial. Defaults to 20.
        max_error_rows: Maximum number of rows in the error table, which
            lists the error file (if any) of each trial. Defaults to 20.
        max_column_length: Maximum column length (in characters). Longer
            column headers and values are abbreviated.
        max_report_frequency: Maximum report frequency in seconds.
            Defaults to 5s.
        infer_limit: Maximum number of metrics to infer automatically from
            tune results.
        print_intermediate_tables: Whether to print intermediate result
            tables. If None (default), True for verbosity levels above 3,
            otherwise False. If True, intermediate tables are printed with
            the experiment progress. If False, tables are only printed at
            the end of the tuning run for verbosity levels greater than 2.
        metric: Metric used to determine the current best trial.
        mode: One of [min, max]; whether the objective minimizes or
            maximizes the metric.
        sort_by_metric: Sort terminated trials by metric in the
            intermediate table. Defaults to False.
    """

    def __init__(
        self,
        *,
        overwrite: bool = True,
        metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
        parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
        total_samples: Optional[int] = None,
        max_progress_rows: int = 20,
        max_error_rows: int = 20,
        max_column_length: int = 20,
        max_report_frequency: int = 5,
        infer_limit: int = 3,
        print_intermediate_tables: Optional[bool] = None,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        sort_by_metric: bool = False,
    ):
        super().__init__(
            metric_columns=metric_columns,
            parameter_columns=parameter_columns,
            total_samples=total_samples,
            max_progress_rows=max_progress_rows,
            max_error_rows=max_error_rows,
            max_column_length=max_column_length,
            max_report_frequency=max_report_frequency,
            infer_limit=infer_limit,
            print_intermediate_tables=print_intermediate_tables,
            metric=metric,
            mode=mode,
            sort_by_metric=sort_by_metric,
        )

        if not IS_NOTEBOOK:
            warnings.warn(
                "You are using the `JupyterNotebookReporter`, but not "
                "IPython/Jupyter-compatible environment was detected. "
                "If this leads to unformatted output (e.g. like "
                "<IPython.core.display.HTML object>), consider passing "
                "a `CLIReporter` as the `progress_reporter` argument "
                "to `tune.RunConfig()` instead."
            )

        self._overwrite = overwrite
        self._display_handle = None
        # Create an empty display now; later reports update it in place.
        self.display("")

    def report(self, trials: List[Trial], done: bool, *sys_info: Dict):
        """Render progress as HTML, then send it to the queue or display it."""
        html = self._progress_html(trials, done, *sys_info)
        queue = self.output_queue
        if queue is None:
            self.display(html)
        else:
            # An output queue routes the HTML to the driver for display.
            queue.put(html)

    def display(self, string: str) -> None:
        """Show ``string`` as HTML, updating the existing display if present."""
        from IPython.display import HTML, clear_output, display

        if self._display_handle:
            self._display_handle.update(HTML(string))
            return

        if self._overwrite:
            clear_output(wait=True)
        self._display_handle = display(HTML(string), display_id=True)

    def _progress_html(self, trials: List[Trial], done: bool, *sys_info) -> str:
        """Render the progress update as HTML for display in a notebook.

        Args:
            trials: Trials whose progress should be displayed.
            done: Whether the trials are finished. When True, the progress
                and error tables are not truncated.
            *sys_info: System information, passed to
                ``_generate_sys_info_str``.

        Returns:
            HTML rendered from the ``tune_status.html.j2`` template. It holds
            a status table (current time, running time, memory usage), the
            system info, the trial progress table and, if any exist, status
            messages (memory warning, progress messages, trial errors).
        """
        if not self._metrics_override:
            inferred = self._infer_user_metrics(trials, self._infer_limit)
            self._metric_columns.update(inferred)

        progress_row_limit = None if done else self._max_progress_rows
        error_row_limit = None if done else self._max_error_rows

        now_str, elapsed_str = _get_time_str(self._start_time, time.time())
        used_gb, total_gb, memory_message = _get_memory_usage()
        status_rows = [
            ("Current time:", now_str),
            ("Running for:", elapsed_str),
            ("Memory:", f"{used_gb}/{total_gb} GiB"),
        ]
        status_table = tabulate(status_rows, tablefmt="html")

        progress_parts = _trial_progress_table(
            trials=trials,
            metric_columns=self._metric_columns,
            parameter_columns=self._parameter_columns,
            fmt="html",
            max_rows=progress_row_limit,
            metric=self._metric,
            mode=self._mode,
            sort_by_metric=self._sort_by_metric,
            max_column_length=self._max_column_length,
        )
        # Element 0 is used as the progress table, the rest as messages.
        progress_table = progress_parts[0]
        progress_messages = progress_parts[1:]
        errors_html = _trial_errors_str(trials, fmt="html", max_rows=error_row_limit)

        messages_html = None
        if memory_message or progress_messages or errors_html:
            messages_html = Template("tune_status_messages.html.j2").render(
                memory_message=memory_message,
                trial_progress_messages=progress_messages,
                trial_errors=errors_html,
            )

        return Template("tune_status.html.j2").render(
            status_table=status_table,
            sys_info_message=_generate_sys_info_str(*sys_info),
            trial_progress=progress_table,
            messages=messages_html,
        )
@PublicAPI
class CLIReporter(TuneReporterBase):
    """Reporter that prints progress to the command line.

    Args:
        metric_columns: Names of metrics to include in the progress table.
            For a dict, keys are metric names and values are the displayed
            names; for a list, the metric name is displayed as-is.
        parameter_columns: Names of parameters to include in the progress
            table. For a dict, keys are parameter names and values are the
            displayed names; for a list, the parameter name is displayed
            as-is. If empty, all available parameters are used.
        max_progress_rows: Maximum number of rows in the progress table,
            which describes the progress of each trial. Defaults to 20.
        max_error_rows: Maximum number of rows in the error table, which
            lists the error file (if any) of each trial. Defaults to 20.
        max_column_length: Maximum column length (in characters). Longer
            column headers and values are abbreviated.
        max_report_frequency: Maximum report frequency in seconds.
            Defaults to 5s.
        infer_limit: Maximum number of metrics to infer automatically from
            tune results.
        print_intermediate_tables: Whether to print intermediate result
            tables. If None (default), True for verbosity levels above 3,
            otherwise False. If True, intermediate tables are printed with
            the experiment progress. If False, tables are only printed at
            the end of the tuning run for verbosity levels greater than 2.
        metric: Metric used to determine the current best trial.
        mode: One of [min, max]; whether the objective minimizes or
            maximizes the metric.
        sort_by_metric: Sort terminated trials by metric in the
            intermediate table. Defaults to False.
    """

    def __init__(
        self,
        *,
        metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
        parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
        total_samples: Optional[int] = None,
        max_progress_rows: int = 20,
        max_error_rows: int = 20,
        max_column_length: int = 20,
        max_report_frequency: int = 5,
        infer_limit: int = 3,
        print_intermediate_tables: Optional[bool] = None,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        sort_by_metric: bool = False,
    ):
        super().__init__(
            metric_columns=metric_columns,
            parameter_columns=parameter_columns,
            total_samples=total_samples,
            max_progress_rows=max_progress_rows,
            max_error_rows=max_error_rows,
            max_column_length=max_column_length,
            max_report_frequency=max_report_frequency,
            infer_limit=infer_limit,
            print_intermediate_tables=print_intermediate_tables,
            metric=metric,
            mode=mode,
            sort_by_metric=sort_by_metric,
        )

    def _print(self, msg: str):
        """Write ``msg`` to the console via ``safe_print``."""
        safe_print(msg)

    def report(self, trials: List[Trial], done: bool, *sys_info: Dict):
        """Build the full progress string and print it."""
        progress = self._progress_str(trials, done, *sys_info)
        self._print(progress)
- def _get_memory_usage() -> Tuple[float, float, Optional[str]]:
- """Get the current memory consumption.
- Returns:
- Memory used, memory available, and optionally a warning
- message to be shown to the user when memory consumption is higher
- than 90% or if `psutil` is not installed
- """
- try:
- import ray # noqa F401
- import psutil
- total_gb = psutil.virtual_memory().total / (1024**3)
- used_gb = total_gb - psutil.virtual_memory().available / (1024**3)
- if used_gb > total_gb * 0.9:
- message = (
- ": ***LOW MEMORY*** less than 10% of the memory on "
- "this node is available for use. This can cause "
- "unexpected crashes. Consider "
- "reducing the memory used by your application "
- "or reducing the Ray object store size by setting "
- "`object_store_memory` when calling `ray.init`."
- )
- else:
- message = None
- return round(used_gb, 1), round(total_gb, 1), message
- except ImportError:
- return (
- np.nan,
- np.nan,
- "Unknown memory usage. Please run `pip install psutil` to resolve",
- )
- def _get_time_str(start_time: float, current_time: float) -> Tuple[str, str]:
- """Get strings representing the current and elapsed time.
- Args:
- start_time: POSIX timestamp of the start of the tune run
- current_time: POSIX timestamp giving the current time
- Returns:
- Current time and elapsed time for the current run
- """
- current_time_dt = datetime.datetime.fromtimestamp(current_time)
- start_time_dt = datetime.datetime.fromtimestamp(start_time)
- delta: datetime.timedelta = current_time_dt - start_time_dt
- rest = delta.total_seconds()
- days = rest // (60 * 60 * 24)
- rest -= days * (60 * 60 * 24)
- hours = rest // (60 * 60)
- rest -= hours * (60 * 60)
- minutes = rest // 60
- seconds = rest - minutes * 60
- if days > 0:
- running_for_str = f"{days:.0f} days, "
- else:
- running_for_str = ""
- running_for_str += f"{hours:02.0f}:{minutes:02.0f}:{seconds:05.2f}"
- return f"{current_time_dt:%Y-%m-%d %H:%M:%S}", running_for_str
def _time_passed_str(start_time: float, current_time: float) -> str:
    """Build a user-facing line with the current time and elapsed run time.

    Args:
        start_time: POSIX timestamp at which the tune run started.
        current_time: POSIX timestamp of "now".

    Returns:
        A string of the form ``"Current time: <now> (running for <elapsed>)"``.
    """
    now_str, elapsed_str = _get_time_str(start_time, current_time)
    return "Current time: {} (running for {})".format(now_str, elapsed_str)
def _get_trials_by_state(trials: List[Trial]):
    """Group trials by their ``status`` attribute.

    Returns:
        A ``collections.defaultdict(list)`` mapping each status to its trials,
        in input order. Looking up an absent status yields an empty list.
    """
    grouped = collections.defaultdict(list)
    for trial in trials:
        bucket = grouped[trial.status]
        bucket.append(trial)
    return grouped
def _trial_progress_str(
    trials: List[Trial],
    metric_columns: Union[List[str], Dict[str, str]],
    parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
    total_samples: int = 0,
    force_table: bool = False,
    fmt: str = "psql",
    max_rows: Optional[int] = None,
    max_column_length: int = 20,
    done: bool = False,
    metric: Optional[str] = None,
    mode: Optional[str] = None,
    sort_by_metric: bool = False,
):
    """Build a human-readable progress message for the console.

    The message lists the result directories and the trial counts per
    state. Optionally, it includes a table with one row per trial showing
    its parameters and current metric values.

    Args:
        trials: Trials to describe. An empty list yields an empty string.
        metric_columns: Metrics to include. For a dict, keys are metric
            names and values are the displayed names; for a list, names are
            displayed as-is.
        parameter_columns: Parameters to include. For a dict, keys are
            parameter names and values are the displayed names; for a list,
            names are displayed as-is. If empty, all parameters are used.
        total_samples: Total number of trials that will be generated. Values
            of ``sys.maxsize`` or more are shown as "infinite".
        force_table: Always include the table. Otherwise it is only included
            when ``done`` and the verbosity is at least
            ``Verbosity.V2_TRIAL_NORM``.
        fmt: Output format (see ``tablefmt`` in the tabulate API). With
            "html", lines are joined with ``<br>``; otherwise with newlines.
        max_rows: Maximum number of rows in the trial table. Defaults to
            unlimited.
        max_column_length: Maximum column length (in characters).
        done: True once the tuning run has finished.
        metric: Metric used to sort trials.
        mode: One of [min, max]; whether the objective minimizes or
            maximizes the metric.
        sort_by_metric: Sort terminated trials by metric in the
            intermediate table. Defaults to False.
    """
    if not trials:
        return ""

    delim = "<br>" if fmt == "html" else "\n"
    by_state = _get_trials_by_state(trials)

    lines = [
        "Result logdir: {}".format(logdir)
        for logdir in sorted({t.local_experiment_path for t in trials})
    ]

    state_counts = ", ".join(
        "{} {}".format(len(by_state[state]), state) for state in sorted(by_state)
    )
    if total_samples and total_samples >= sys.maxsize:
        samples_suffix = "/infinite"
    elif total_samples:
        samples_suffix = f"/{total_samples}"
    else:
        samples_suffix = ""
    lines.append(
        "Number of trials: {}{} ({})".format(len(trials), samples_suffix, state_counts)
    )

    if force_table or (has_verbosity(Verbosity.V2_TRIAL_NORM) and done):
        lines.extend(
            _trial_progress_table(
                trials=trials,
                metric_columns=metric_columns,
                parameter_columns=parameter_columns,
                fmt=fmt,
                max_rows=max_rows,
                metric=metric,
                mode=mode,
                sort_by_metric=sort_by_metric,
                max_column_length=max_column_length,
            )
        )

    return delim.join(lines)
- def _max_len(
- value: Any, max_len: int = 20, add_addr: bool = False, wrap: bool = False
- ) -> Any:
- """Abbreviate a string representation of an object to `max_len` characters.
- For numbers, booleans and None, the original value will be returned for
- correct rendering in the table formatting tool.
- Args:
- value: Object to be represented as a string.
- max_len: Maximum return string length.
- add_addr: If True, will add part of the object address to the end of the
- string, e.g. to identify different instances of the same class. If
- False, three dots (``...``) will be used instead.
- """
- if value is None or isinstance(value, (int, float, numbers.Number, bool)):
- return value
- string = str(value)
- if len(string) <= max_len:
- return string
- if wrap:
- # Maximum two rows.
- # Todo: Make this configurable in the refactor
- if len(value) > max_len * 2:
- value = "..." + string[(3 - (max_len * 2)) :]
- wrapped = textwrap.wrap(value, width=max_len)
- return "\n".join(wrapped)
- if add_addr and not isinstance(value, (int, float, bool)):
- result = f"{string[: (max_len - 5)]}_{hex(id(value))[-4:]}"
- return result
- result = "..." + string[(3 - max_len) :]
- return result
def _sort_trials_by_metric_value(
    trials: List[Trial], metric: Optional[str], mode: Optional[str]
) -> List[Trial]:
    """Order trials by the last reported value of ``metric``.

    Trials are sorted descending if ``mode == "max"`` and ascending otherwise.
    Trials whose lookup yields None cannot be compared with numbers. They are
    appended after the sorted ones in their original order, instead of making
    ``sorted`` raise a TypeError.
    """
    keyed = [
        (unflattened_lookup(metric, trial.last_result, default=None), trial)
        for trial in trials
    ]
    reported = [pair for pair in keyed if pair[0] is not None]
    # Sort on the value only (never compare trials); the sort is stable.
    reported.sort(key=lambda pair: pair[0], reverse=(mode == "max"))
    missing = [trial for value, trial in keyed if value is None]
    return [trial for _, trial in reported] + missing


def _format_table_headers(
    keys: List[str],
    display_names: Optional[Union[List[str], Dict[str, str]]],
    max_column_length: int,
) -> List[str]:
    """Turn column keys into header strings wrapped to ``max_column_length``.

    If ``display_names`` is a mapping, each key is shown under its mapped name;
    otherwise the key itself is used.
    """
    rename = isinstance(display_names, Mapping)
    return [
        _max_len(
            display_names[key] if rename else key,
            max_len=max_column_length,
            add_addr=False,
            wrap=True,
        )
        for key in keys
    ]


def _get_progress_table_data(
    trials: List[Trial],
    metric_columns: Union[List[str], Dict[str, str]],
    parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
    max_rows: Optional[int] = None,
    metric: Optional[str] = None,
    mode: Optional[str] = None,
    sort_by_metric: bool = False,
    max_column_length: int = 20,
) -> Tuple[List, List[str], Tuple[bool, str]]:
    """Generate a table showing the current progress of tuning trials.

    Args:
        trials: List of trials for which progress is to be shown.
        metric_columns: Metrics to be displayed in the table. A mapping maps
            metric names to the header names to display.
        parameter_columns: List of parameters to be included in the data. A
            mapping maps parameter names to header names; if empty, all
            evaluated parameters of the trials are used.
        max_rows: Maximum number of rows to show (falsy means unlimited). If
            there's overflow, a message will be shown to the user indicating
            that some rows are not displayed
        metric: Metric which is being tuned
        mode: Sort the table in descending order if mode is "max";
            ascending otherwise
        sort_by_metric: If true, terminated trials are sorted by the metric.
            Trials that have not reported it are listed after the others.
        max_column_length: Max number of characters in each column

    Returns:
        - Trial data (one row per shown trial)
        - List of column names
        - Overflow tuple:
            - the number of hidden trials if rows were truncated (an int,
              possibly 0), otherwise False
            - string with info about the overflowing rows per state
    """
    num_trials = len(trials)
    trials_by_state = _get_trials_by_state(trials)

    # Sort terminated trials by metric and mode, descending if mode is "max".
    if sort_by_metric:
        trials_by_state[Trial.TERMINATED] = _sort_trials_by_metric_value(
            trials_by_state[Trial.TERMINATED], metric, mode
        )

    # Order in which trial states appear in the table.
    state_tbl_order = [
        Trial.RUNNING,
        Trial.PAUSED,
        Trial.PENDING,
        Trial.TERMINATED,
        Trial.ERROR,
    ]
    max_rows = max_rows or float("inf")
    truncated = num_trials > max_rows
    if truncated:
        # TODO(ujvl): suggestion for users to view more rows.
        shown_by_state = _fair_filter_trials(
            trials_by_state, max_rows, sort_by_metric
        )
    else:
        shown_by_state = trials_by_state

    trials = []
    overflow_strs = []
    for state in state_tbl_order:
        if state not in trials_by_state:
            continue
        trials += shown_by_state[state]
        num_hidden = len(trials_by_state[state]) - len(shown_by_state[state])
        if num_hidden > 0:
            overflow_strs.append("{} {}".format(num_hidden, state))

    if truncated:
        # Report what was actually left out rather than assuming that exactly
        # `max_rows` trials were kept.
        overflow = num_trials - len(trials)
        overflow_str = ", ".join(overflow_strs)
    else:
        overflow = False
        overflow_str = ""

    # Pre-process trials to figure out what columns to show.
    if isinstance(metric_columns, Mapping):
        metric_keys = list(metric_columns.keys())
    else:
        metric_keys = metric_columns
    # Only keep metrics that at least one shown trial has reported.
    metric_keys = [
        k
        for k in metric_keys
        if any(
            unflattened_lookup(k, t.last_result, default=None) is not None
            for t in trials
        )
    ]

    if not parameter_columns:
        parameter_keys = sorted(set().union(*[t.evaluated_params for t in trials]))
    elif isinstance(parameter_columns, Mapping):
        parameter_keys = list(parameter_columns.keys())
    else:
        parameter_keys = parameter_columns

    # Build trial rows.
    trial_table = [
        _get_trial_info(
            trial, parameter_keys, metric_keys, max_column_length=max_column_length
        )
        for trial in trials
    ]

    # Format column headings.
    formatted_metric_columns = _format_table_headers(
        metric_keys, metric_columns, max_column_length
    )
    formatted_parameter_columns = _format_table_headers(
        parameter_keys, parameter_columns, max_column_length
    )
    columns = (
        ["Trial name", "status", "loc"]
        + formatted_parameter_columns
        + formatted_metric_columns
    )
    return trial_table, columns, (overflow, overflow_str)
def _trial_progress_table(
    trials: List[Trial],
    metric_columns: Union[List[str], Dict[str, str]],
    parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
    fmt: str = "psql",
    max_rows: Optional[int] = None,
    metric: Optional[str] = None,
    mode: Optional[str] = None,
    sort_by_metric: bool = False,
    max_column_length: int = 20,
) -> List[str]:
    """Render the trial progress table as a list of message strings.

    Args:
        trials: List of trials for which progress is to be shown.
        metric_columns: Metrics to be displayed in the table.
        parameter_columns: List of parameters to be included in the data
        fmt: Format of the table; passed to tabulate as the ``tablefmt``
            argument
        max_rows: Maximum number of rows to show. If there's overflow, a
            message will be shown to the user indicating that some rows
            are not displayed
        metric: Metric which is being tuned
        mode: Sort the table in descending order if mode is "max";
            ascending otherwise
        sort_by_metric: If true, the table will be sorted by the metric
        max_column_length: Max number of characters in each column

    Returns:
        The rendered table, followed by a line summarizing the hidden trials
        when some rows were left out.
    """
    rows, headers, (hidden, hidden_summary) = _get_progress_table_data(
        trials=trials,
        metric_columns=metric_columns,
        parameter_columns=parameter_columns,
        max_rows=max_rows,
        metric=metric,
        mode=mode,
        sort_by_metric=sort_by_metric,
        max_column_length=max_column_length,
    )
    rendered = tabulate(rows, headers=headers, tablefmt=fmt, showindex=False)
    if not hidden:
        return [rendered]
    return [rendered, f"... {hidden} more trials not shown ({hidden_summary})"]
- def _generate_sys_info_str(*sys_info) -> str:
- """Format system info into a string.
- *sys_info: System info strings to be included.
- Returns:
- Formatted string containing system information.
- """
- if sys_info:
- return "<br>".join(sys_info).replace("\n", "<br>")
- return ""
def _trial_errors_str(
    trials: List[Trial], fmt: str = "psql", max_rows: Optional[int] = None
):
    """Returns a readable message regarding trial errors.

    Every trial with an ``error_file`` is listed. Trials that are
    ``TERMINATED`` despite having an error file are marked with ``*`` in the
    failure-count column and explained by a footnote.

    Args:
        trials: List of trials to get progress string for.
        fmt: Output format (see tablefmt in tabulate API).
        max_rows: Maximum number of rows in the error table. ``None`` (the
            default) or ``0`` means unlimited.

    Returns:
        The message lines joined by ``<br>`` if ``fmt == "html"`` and by
        newlines otherwise; an empty string if no trial has an error file.
    """
    failed = [t for t in trials if t.error_file]
    num_failed = len(failed)
    if num_failed == 0:
        return ""

    # A falsy limit means "unlimited" for both the truncation notice and the
    # slice, so that max_rows=0 does not render an empty table.
    row_limit = max_rows or None
    shown = failed[:row_limit]

    messages = ["Number of errored trials: {}".format(num_failed)]
    if row_limit is not None and num_failed > row_limit:
        messages.append(
            "Table truncated to {} rows ({} overflow)".format(
                row_limit, num_failed - row_limit
            )
        )

    fail_header = ["Trial name", "# failures", "error file"]
    fail_table_data = [
        [
            str(trial),
            str(trial.run_metadata.num_failures)
            + ("" if trial.status == Trial.ERROR else "*"),
            trial.error_file,
        ]
        for trial in shown
    ]
    messages.append(
        tabulate(
            fail_table_data,
            headers=fail_header,
            tablefmt=fmt,
            showindex=False,
            colalign=("left", "right", "left"),
        )
    )
    if any(trial.status == Trial.TERMINATED for trial in shown):
        messages.append("* The trial terminated successfully after retrying.")

    delim = "<br>" if fmt == "html" else "\n"
    return delim.join(messages)
def _best_trial_str(
    trial: Trial,
    metric: str,
    parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
):
    """Returns a readable message stating the current best trial.

    Args:
        trial: The trial to describe.
        metric: Name of the metric to report. It is looked up in the trial's
            last result with ``unflattened_lookup`` (None if not found).
        parameter_columns: Parameters to include. A mapping contributes its
            keys. If empty or None, every key of
            ``trial.last_result["config"]`` is used.

    Returns:
        A one-line summary with the trial id, metric value and parameters.
    """
    best_value = unflattened_lookup(metric, trial.last_result, default=None)
    config = trial.last_result.get("config", {})

    if not parameter_columns:
        selected = list(config.keys())
    elif isinstance(parameter_columns, Mapping):
        selected = parameter_columns.keys()
    else:
        selected = parameter_columns

    params = {name: unflattened_lookup(name, config) for name in selected}
    return (
        f"Current best trial: {trial.trial_id} with {metric}={best_value} and "
        f"parameters={params}"
    )
- def _fair_filter_trials(
- trials_by_state: Dict[str, List[Trial]],
- max_trials: int,
- sort_by_metric: bool = False,
- ):
- """Filters trials such that each state is represented fairly.
- The oldest trials are truncated if necessary.
- Args:
- trials_by_state: Maximum number of trials to return.
- Returns:
- Dict mapping state to List of fairly represented trials.
- """
- num_trials_by_state = collections.defaultdict(int)
- no_change = False
- # Determine number of trials to keep per state.
- while max_trials > 0 and not no_change:
- no_change = True
- for state in sorted(trials_by_state):
- if num_trials_by_state[state] < len(trials_by_state[state]):
- no_change = False
- max_trials -= 1
- num_trials_by_state[state] += 1
- # Sort by start time, descending if the trails is not sorted by metric.
- sorted_trials_by_state = dict()
- for state in sorted(trials_by_state):
- if state == Trial.TERMINATED and sort_by_metric:
- sorted_trials_by_state[state] = trials_by_state[state]
- else:
- sorted_trials_by_state[state] = sorted(
- trials_by_state[state], reverse=False, key=lambda t: t.trial_id
- )
- # Truncate oldest trials.
- filtered_trials = {
- state: sorted_trials_by_state[state][: num_trials_by_state[state]]
- for state in sorted(trials_by_state)
- }
- return filtered_trials
def _get_trial_location(trial: Trial, result: dict) -> _Location:
    """Return where a trial runs, preferring the location in its last result.

    The location is read from the result because the one stored on the trial
    is reset when the trial terminates. If the result does not carry both a
    node IP and a PID (e.g. no report yet), the trial's temporary-state
    location is used instead.
    """
    node_ip = result.get(NODE_IP, None)
    pid = result.get(PID, None)
    if not (node_ip and pid):
        return trial.temporary_state.location
    return _Location(node_ip, pid)
def _get_trial_info(
    trial: Trial, parameters: List[str], metrics: List[str], max_column_length: int = 20
):
    """Build one table row for a trial.

    The row layout is: name | status | loc | params... | metrics...

    Args:
        trial: Trial to get information for.
        parameters: Names of trial parameters to include (looked up in
            ``trial.config``).
        metrics: Names of metrics to include (looked up in
            ``trial.last_result``).
        max_column_length: Maximum column length (in characters).

    Returns:
        A list of cell values; missing parameters/metrics appear as None.
    """
    result = trial.last_result
    config = trial.config
    location = _get_trial_location(trial, result)

    def _cell(source, key):
        # Abbreviated value with an id suffix so distinct objects stay distinct.
        return _max_len(
            unflattened_lookup(key, source, default=None),
            max_len=max_column_length,
            add_addr=True,
        )

    return (
        [str(trial), trial.status, str(location)]
        + [_cell(config, name) for name in parameters]
        + [_cell(result, name) for name in metrics]
    )
@DeveloperAPI
class TrialProgressCallback(Callback):
    """Reports (prints) intermediate trial progress.

    This callback is automatically added to the callback stack. When a
    result is obtained, this callback will print the results according to
    the specified verbosity level.

    For ``Verbosity.V3_TRIAL_DETAILS``, a full result list is printed.

    For ``Verbosity.V2_TRIAL_NORM``, only one line is printed per received
    result.

    All other verbosity levels do not print intermediate trial progress.

    Result printing is throttled on a per-trial basis: a result is printed at
    most once every ``DEBUG_PRINT_INTERVAL`` seconds. Results are always
    printed when a trial finished or errored.
    """

    def __init__(
        self, metric: Optional[str] = None, progress_metrics: Optional[List[str]] = None
    ):
        """
        Args:
            metric: Name of the tuned metric; at ``V2_TRIAL_NORM`` verbosity
                this is the single value reported per result.
            progress_metrics: If given, only these keys of the flattened
                result (plus ``metric``) are included in result strings.
        """
        # time.time() of the last print, per trial.
        self._last_print = collections.defaultdict(float)
        # Training iteration of the last printed result, per trial.
        self._last_print_iteration = collections.defaultdict(int)
        # Trials whose completion has already been logged.
        self._completed_trials = set()
        # Last result string printed at V2 verbosity, per trial.
        self._last_result_str = {}
        self._metric = metric
        self._progress_metrics = set(progress_metrics or [])
        # When an explicit set of progress metrics is given, always include
        # the tuned metric in it as well.
        if self._metric and self._progress_metrics:
            self._progress_metrics.add(self._metric)
        # Latest result per trial, rendered as an HTML table in notebooks.
        self._last_result = {}
        # IPython display handle, created on the first notebook display.
        self._display_handle = None

    def _print(self, msg: str):
        safe_print(msg)

    def on_trial_result(
        self,
        iteration: int,
        trials: List["Trial"],
        trial: "Trial",
        result: Dict,
        **info,
    ):
        self.log_result(trial, result, error=False)

    def on_trial_error(
        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
    ):
        self.log_result(trial, trial.last_result, error=True)

    def on_trial_complete(
        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
    ):
        # Only log when we never logged that a trial was completed
        if trial not in self._completed_trials:
            self._completed_trials.add(trial)

            print_result_str = self._print_result(trial.last_result)
            last_result_str = self._last_result_str.get(trial, "")
            # If this is a new result, print full result string
            if print_result_str != last_result_str:
                self.log_result(trial, trial.last_result, error=False)
            else:
                self._print(f"Trial {trial} completed. Last result: {print_result_str}")

    def log_result(self, trial: "Trial", result: Dict, error: bool = False):
        """Print or display ``result`` unless throttled.

        A result is shown when the trial is done, errored, or when more than
        ``DEBUG_PRINT_INTERVAL`` seconds passed since its last print.
        """
        done = result.get("done", False) is True
        last_print = self._last_print[trial]
        should_print = done or error or time.time() - last_print > DEBUG_PRINT_INTERVAL

        if done and trial not in self._completed_trials:
            self._completed_trials.add(trial)

        if should_print:
            if IS_NOTEBOOK:
                self.display_result(trial, result, error, done)
            else:
                self.print_result(trial, result, error, done)

            self._last_print[trial] = time.time()
            if TRAINING_ITERATION in result:
                self._last_print_iteration[trial] = result[TRAINING_ITERATION]

    @staticmethod
    def _format_metric_value(value: Any) -> str:
        """Format a metric value with two decimals, falling back to ``str``.

        Non-numeric values (e.g. None, strings, lists) do not support the
        ``.2f`` format spec and would otherwise raise inside the callback.
        """
        try:
            return f"{value:.2f}"
        except (TypeError, ValueError):
            return str(value)

    def print_result(self, trial: Trial, result: Dict, error: bool, done: bool):
        """Print the most recent results for the given trial to stdout.

        Args:
            trial: Trial for which results are to be printed
            result: Result to be printed
            error: True if an error has occurred, False otherwise
            done: True if the trial is finished, False otherwise
        """
        last_print_iteration = self._last_print_iteration[trial]

        if has_verbosity(Verbosity.V3_TRIAL_DETAILS):
            # Skip re-printing a result for an already printed iteration.
            if result.get(TRAINING_ITERATION) != last_print_iteration:
                self._print(f"Result for {trial}:")
                self._print(" {}".format(pretty_print(result).replace("\n", "\n ")))
            if done:
                self._print(f"Trial {trial} completed.")

        elif has_verbosity(Verbosity.V2_TRIAL_NORM):
            metric_name = self._metric or "_metric"
            metric_value = result.get(metric_name, -99.0)

            info = ""
            if done:
                info = " This trial completed."

            print_result_str = self._print_result(result)
            self._last_result_str[trial] = print_result_str

            if error:
                # Only build the error-file path when it is reported.
                error_file = Path(trial.local_path, EXPR_ERROR_FILE).as_posix()
                message = (
                    f"The trial {trial} errored with "
                    f"parameters={trial.config}. "
                    f"Error file: {error_file}"
                )
            elif self._metric:
                message = (
                    f"Trial {trial} reported "
                    f"{metric_name}={self._format_metric_value(metric_value)} "
                    f"with parameters={trial.config}.{info}"
                )
            else:
                message = (
                    f"Trial {trial} reported "
                    f"{print_result_str} "
                    f"with parameters={trial.config}.{info}"
                )

            self._print(message)

    def generate_trial_table(
        self, trials: Dict[Trial, Dict], columns: List[str]
    ) -> str:
        """Generate an HTML table of trial progress info.

        Trials (rows) are sorted by name; progress stats (columns) are sorted
        as well.

        Args:
            trials: Trials and their associated latest results
            columns: Columns to show in the table; must be a list of valid
                keys for each Trial result. Missing keys render as "".

        Returns:
            HTML template containing a rendered table of progress info
        """
        columns = sorted(columns)
        # Render the trials that were passed in (previously this argument was
        # ignored in favor of self._last_result).
        sorted_trials = sorted(trials.items(), key=lambda item: str(item[0]))
        data = [
            [str(trial)] + [result.get(col, "") for col in columns]
            for trial, result in sorted_trials
        ]

        return Template("trial_progress.html.j2").render(
            table=tabulate(
                data, tablefmt="html", headers=["Trial name"] + columns, showindex=False
            )
        )

    def display_result(self, trial: Trial, result: Dict, error: bool, done: bool):
        """Display a formatted HTML table of trial progress results.

        Trial progress is only shown if verbosity is set to level 2 or 3.

        Args:
            trial: Trial for which results are to be printed
            result: Result to be printed
            error: True if an error has occurred, False otherwise
            done: True if the trial is finished, False otherwise
        """
        from IPython.display import HTML, display

        self._last_result[trial] = result
        if has_verbosity(Verbosity.V3_TRIAL_DETAILS):
            ignored_keys = {
                "config",
                "hist_stats",
            }

        elif has_verbosity(Verbosity.V2_TRIAL_NORM):
            ignored_keys = {
                "config",
                "hist_stats",
                "trial_id",
                "experiment_tag",
                "done",
            } | set(AUTO_RESULT_KEYS)
        else:
            return

        table = self.generate_trial_table(
            self._last_result, set(result.keys()) - ignored_keys
        )
        # Update the same output cell in place after the first display.
        if not self._display_handle:
            self._display_handle = display(HTML(table), display_id=True)
        else:
            self._display_handle.update(HTML(table))

    def _print_result(self, result: Dict):
        """Render ``result`` as a compact ``k=v,k=v`` string (None values omitted)."""
        if self._progress_metrics:
            # If progress metrics are given, only report these
            flat_result = flatten_dict(result)

            print_result = {}
            for metric in self._progress_metrics:
                print_result[metric] = flat_result.get(metric)

        else:
            # Else, skip auto populated results
            print_result = result.copy()

            for skip_result in SKIP_RESULTS_IN_REPORT:
                print_result.pop(skip_result, None)

            for auto_result in AUTO_RESULT_KEYS:
                print_result.pop(auto_result, None)

        print_result_str = ",".join(
            [f"{k}={v}" for k, v in print_result.items() if v is not None]
        )
        return print_result_str
def _detect_reporter(_trainer_api: bool = False, **kwargs) -> TuneReporterBase:
    """Detect progress reporter class.

    Returns a :class:`JupyterNotebookReporter` when an IPython/Jupyter-like
    session was detected (and ``_trainer_api`` is False), and a
    :class:`CLIReporter` otherwise. In the notebook case, ``overwrite``
    defaults to True unless verbosity is at least ``V2_TRIAL_NORM``.

    Keyword arguments are passed on to the reporter class.
    """
    use_notebook = IS_NOTEBOOK and not _trainer_api
    if not use_notebook:
        return CLIReporter(**kwargs)
    kwargs.setdefault("overwrite", not has_verbosity(Verbosity.V2_TRIAL_NORM))
    return JupyterNotebookReporter(**kwargs)
- def _detect_progress_metrics(
- trainable: Optional[Union["Trainable", Callable]]
- ) -> Optional[Collection[str]]:
- """Detect progress metrics to report."""
- if not trainable:
- return None
- return getattr(trainable, "_progress_metrics", None)
def _prepare_progress_reporter_for_ray_client(
    progress_reporter: ProgressReporter,
    verbosity: Union[int, Verbosity],
    string_queue: Optional[Queue] = None,
) -> Tuple[ProgressReporter, Queue]:
    """Prepare a progress reporter for Ray Client by attaching a string queue.

    Sets the global verbosity and falls back to the auto-detected reporter
    when none is given. Only reporters implementing ``RemoteReporterMixin``
    get a queue; one is created if ``string_queue`` is None.

    Returns:
        The reporter and the (possibly newly created, possibly None) queue.
    """
    set_verbosity(verbosity)
    reporter = progress_reporter or _detect_reporter()

    # JupyterNotebooks don't work with remote tune runs out of the box
    # (e.g. via Ray client) as they don't have access to the main
    # process stdout. So we introduce a queue here that accepts
    # strings, which will then be displayed on the driver side.
    if not isinstance(reporter, RemoteReporterMixin):
        return reporter, string_queue

    if string_queue is None:
        string_queue = Queue(
            actor_options={"num_cpus": 0, **_force_on_current_node(None)}
        )
    reporter.output_queue = string_queue
    return reporter, string_queue
def _stream_client_output(
    remote_future: ray.ObjectRef,
    progress_reporter: ProgressReporter,
    string_queue: Queue,
) -> Any:
    """
    Stream items from string queue to progress_reporter until remote_future resolves
    """
    if string_queue is None:
        return

    def _drain_queue():
        # Forward queued strings to the reporter until the queue is empty
        # (or a None item is encountered). This happens on the driver side.
        while True:
            try:
                item = string_queue.get(block=False)
            except Empty:
                return
            if item is None:
                return
            progress_reporter.display(item)

    # ray.wait(...)[1] returns futures that are not ready, yet
    while ray.wait([remote_future], timeout=0.2)[1]:
        _drain_queue()

    # Handle queue one last time
    _drain_queue()
|