| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306 |
- import logging
- import operator
- import os
- import shutil
- import subprocess
- from datetime import datetime
- from pathlib import Path
- from typing import List, Optional
- import click
- import pandas as pd
- from pandas.api.types import is_numeric_dtype, is_string_dtype
- from ray._private.thirdparty.tabulate.tabulate import tabulate
- from ray.air.constants import EXPR_RESULT_FILE
- from ray.tune import TuneError
- from ray.tune.analysis import ExperimentAnalysis
- from ray.tune.result import (
- CONFIG_PREFIX,
- DEFAULT_EXPERIMENT_INFO_KEYS,
- DEFAULT_RESULT_KEYS,
- )
logger = logging.getLogger(__name__)

# Editor launched by `add_note`: taken from $EDITOR, falling back to vim.
EDITOR = os.environ.get("EDITOR", "vim")

# strftime pattern used to render trial `last_update_time` values.
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)"

# Columns `list_trials` keeps by default (besides config/... columns).
DEFAULT_CLI_KEYS = DEFAULT_EXPERIMENT_INFO_KEYS + DEFAULT_RESULT_KEYS

# Columns `list_experiments` shows when no info_keys are given.
DEFAULT_PROJECT_INFO_KEYS = ("name", "total_trials", "last_updated")

# Terminal size, sampled once at import time; (100, 100) when unknown.
TERM_WIDTH, TERM_HEIGHT = shutil.get_terminal_size(fallback=(100, 100))

# Comparison operators accepted in a "<column> <operator> <value>" filter.
OPERATORS = dict(
    (
        ("<", operator.lt),
        ("<=", operator.le),
        ("==", operator.eq),
        ("!=", operator.ne),
        (">=", operator.ge),
        (">", operator.gt),
    )
)
def _check_tabulate():
    """Raise ImportError when the module-level `tabulate` callable is missing."""
    if tabulate is not None:
        return
    raise ImportError("Tabulate not installed. Please run `pip install tabulate`.")
def print_format_output(dataframe):
    """Prints output of given dataframe to fit into terminal.

    Columns are added left to right, so the column order of ``dataframe``
    sets display priority. Columns whose values are all null are skipped
    and reported separately. As soon as the rendered table would be wider
    than ``TERM_WIDTH`` characters, the column that overflowed and every
    column after it are dropped.

    Args:
        dataframe: DataFrame to print.

    Returns:
        table: Final outputted table (the rendered string).
        dropped_cols: Columns dropped due to terminal size.
        empty_cols: Empty columns (dropped on default).
    """
    print_df = pd.DataFrame()
    dropped_cols = []
    empty_cols = []
    columns = list(dataframe.columns)

    # column display priority is based on the info_keys passed in
    for i, col in enumerate(columns):
        if dataframe[col].isnull().all():
            # Don't add col to print_df if it is fully empty.
            empty_cols.append(col)
            continue
        print_df[col] = dataframe[col]
        # Measure with the same options as the final render below. Rendering
        # with the index (tabulate's default) would overestimate the width
        # and drop columns that actually fit.
        test_table = tabulate(
            print_df, headers="keys", tablefmt="psql", showindex="never"
        )
        # Width of the first (border) line; split() also copes with a
        # newline-free render, where str.index("\n") would raise.
        table_width = len(str(test_table).split("\n", 1)[0])
        if table_width > TERM_WIDTH:
            # Drop this column and all columns beyond terminal width.
            print_df.drop(col, axis=1, inplace=True)
            dropped_cols += columns[i:]
            break

    table = tabulate(print_df, headers="keys", tablefmt="psql", showindex="never")
    print(table)
    if dropped_cols:
        click.secho("Dropped columns: {}".format(dropped_cols), fg="yellow")
        click.secho("Please increase your terminal size to view remaining columns.")
    if empty_cols:
        click.secho("Empty columns: {}".format(empty_cols), fg="yellow")
    return table, dropped_cols, empty_cols
def list_trials(
    experiment_path: str,
    sort: Optional[List[str]] = None,
    output: Optional[str] = None,
    filter_op: Optional[str] = None,
    info_keys: Optional[List[str]] = None,
    limit: Optional[int] = None,
    desc: bool = False,
):
    """Lists trials in the directory subtree starting at the given path.

    Args:
        experiment_path: Directory where trials are located.
            Like Experiment.local_dir/Experiment.name/experiment*.json.
        sort: Keys to sort by.
        output: Name of file where output is saved (.p/.pkl/.pickle or .csv).
        filter_op: Filter operation in the format
            "<column> <operator> <value>", e.g. "episode_reward_mean > 10".
            The value may itself contain spaces.
        info_keys: Keys that are displayed.
        limit: Number of rows to display (falsy means no limit).
        desc: Sort in descending order if True, ascending otherwise.

    Raises:
        click.ClickException: If no trial data is found, a key, filter,
            or sort key is invalid, or the output filetype is unsupported.
    """
    _check_tabulate()

    try:
        checkpoints_df = ExperimentAnalysis(experiment_path).dataframe()  # last result
    except TuneError as e:
        raise click.ClickException("No trial data found!") from e

    config_prefix = CONFIG_PREFIX + "/"

    def key_filter(k):
        return k in DEFAULT_CLI_KEYS or k.startswith(config_prefix)

    col_keys = [k for k in checkpoints_df.columns if key_filter(k)]

    if info_keys:
        for k in info_keys:
            if k not in checkpoints_df.columns:
                raise click.ClickException(
                    "Provided key invalid: {}. "
                    "Available keys: {}.".format(k, list(checkpoints_df.columns))
                )
        col_keys = [k for k in checkpoints_df.columns if k in info_keys]

    if not col_keys:
        raise click.ClickException("No columns to output.")

    # Copy so the column assignments below modify an independent frame
    # rather than a slice (avoids pandas' SettingWithCopyWarning).
    checkpoints_df = checkpoints_df[col_keys].copy()

    if "last_update_time" in checkpoints_df:
        # Drop NaN and +/-inf before conversion; datetime.fromtimestamp cannot
        # handle them. Done explicitly instead of via the "mode.use_inf_as_null"
        # option, which is deprecated/absent in recent pandas releases.
        datetime_series = (
            checkpoints_df["last_update_time"]
            .replace([float("inf"), float("-inf")], float("nan"))
            .dropna()
        )
        datetime_series = datetime_series.apply(
            lambda t: datetime.fromtimestamp(t).strftime(TIMESTAMP_FORMAT)
        )
        # Index alignment leaves dropped (non-finite) rows as NaN.
        checkpoints_df["last_update_time"] = datetime_series

    if "logdir" in checkpoints_df:
        # logdir often too long to view in table, so drop experiment_path.
        # regex=False: the path is a literal and may contain regex
        # metacharacters such as "." or "+".
        checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace(
            experiment_path, "", regex=False
        )

    if filter_op:
        # maxsplit=2 tolerates repeated whitespace and lets the value
        # contain spaces.
        parts = filter_op.split(maxsplit=2)
        if len(parts) != 3:
            raise click.ClickException(
                "Invalid filter {!r}: expected "
                "'<column> <operator> <value>'.".format(filter_op)
            )
        col, op, val = parts
        if col not in checkpoints_df:
            raise click.ClickException(
                "{} not in: {}".format(col, list(checkpoints_df))
            )
        if op not in OPERATORS:
            raise click.ClickException(
                "Unsupported operator {!r}. Choose from: {}".format(
                    op, list(OPERATORS)
                )
            )
        col_type = checkpoints_df[col].dtype
        if is_numeric_dtype(col_type):
            try:
                val = float(val)
            except ValueError as e:
                raise click.ClickException(
                    "Expected a numeric value for {}, got {!r}".format(col, val)
                ) from e
        elif is_string_dtype(col_type):
            val = str(val)
        # TODO(Andrew): add support for datetime and boolean
        else:
            raise click.ClickException(
                "Unsupported dtype for {}: {}".format(col, col_type)
            )
        filtered_index = OPERATORS[op](checkpoints_df[col], val)
        checkpoints_df = checkpoints_df[filtered_index]

    if sort:
        for key in sort:
            if key not in checkpoints_df:
                raise click.ClickException(
                    "{} not in: {}".format(key, list(checkpoints_df))
                )
        ascending = not desc
        checkpoints_df = checkpoints_df.sort_values(by=sort, ascending=ascending)

    if limit:
        checkpoints_df = checkpoints_df.head(limit)

    print_format_output(checkpoints_df)

    if output:
        file_extension = os.path.splitext(output)[1].lower()
        if file_extension in (".p", ".pkl", ".pickle"):
            checkpoints_df.to_pickle(output)
        elif file_extension == ".csv":
            checkpoints_df.to_csv(output, index=False)
        else:
            raise click.ClickException("Unsupported filetype: {}".format(output))
        click.secho("Output saved at {}".format(output), fg="green")
def list_experiments(
    project_path: str,
    sort: Optional[List[str]] = None,
    output: Optional[str] = None,
    filter_op: Optional[str] = None,
    info_keys: Optional[List[str]] = None,
    limit: Optional[int] = None,
    desc: bool = False,
):
    """Lists experiments in the directory subtree.

    Each immediate subdirectory of ``project_path`` is treated as one
    experiment. Its ``total_trials`` is the number of directories beneath
    it (inclusive) that contain an ``EXPR_RESULT_FILE``.

    Args:
        project_path: Directory where experiments are located.
            Corresponds to Experiment.local_dir.
        sort: Keys to sort by.
        output: Name of file where output is saved (.p/.pkl/.pickle or .csv).
        filter_op: Filter operation in the format
            "<column> <operator> <value>", e.g. "total_trials >= 4".
            The value may itself contain spaces.
        info_keys: Keys that are displayed.
        limit: Number of rows to display (falsy means no limit).
        desc: Sort in descending order if True, ascending otherwise.

    Raises:
        click.ClickException: If ``project_path`` is not a readable
            directory, no experiments are found, a key, filter, or sort
            key is invalid, or the output filetype is unsupported.
    """
    _check_tabulate()

    try:
        # os.walk yields nothing (instead of raising) for a missing,
        # non-directory, or unreadable path; next() would then leak a bare
        # StopIteration.
        base, experiment_folders, _ = next(os.walk(project_path))
    except StopIteration:
        raise click.ClickException(
            "{} is not a valid directory.".format(project_path)
        ) from None

    experiment_data_collection = []
    # Sorted for a deterministic row order; os.walk lists entries arbitrarily.
    for experiment_dir in sorted(experiment_folders):
        num_trials = sum(
            EXPR_RESULT_FILE in files
            for _, _, files in os.walk(os.path.join(base, experiment_dir))
        )
        # NOTE: only "name" and "total_trials" are computed, so a requested
        # "last_updated" key is silently filtered out by col_keys below.
        experiment_data = {"name": experiment_dir, "total_trials": num_trials}
        experiment_data_collection.append(experiment_data)

    if not experiment_data_collection:
        raise click.ClickException("No experiments found!")

    info_df = pd.DataFrame(experiment_data_collection)
    if not info_keys:
        info_keys = DEFAULT_PROJECT_INFO_KEYS
    col_keys = [k for k in list(info_keys) if k in info_df]
    if not col_keys:
        raise click.ClickException(
            "None of keys {} in experiment data!".format(info_keys)
        )
    info_df = info_df[col_keys]

    if filter_op:
        # maxsplit=2 tolerates repeated whitespace and lets the value
        # contain spaces.
        parts = filter_op.split(maxsplit=2)
        if len(parts) != 3:
            raise click.ClickException(
                "Invalid filter {!r}: expected "
                "'<column> <operator> <value>'.".format(filter_op)
            )
        col, op, val = parts
        if col not in info_df:
            raise click.ClickException("{} not in: {}".format(col, list(info_df)))
        if op not in OPERATORS:
            raise click.ClickException(
                "Unsupported operator {!r}. Choose from: {}".format(
                    op, list(OPERATORS)
                )
            )
        col_type = info_df[col].dtype
        if is_numeric_dtype(col_type):
            try:
                val = float(val)
            except ValueError as e:
                raise click.ClickException(
                    "Expected a numeric value for {}, got {!r}".format(col, val)
                ) from e
        elif is_string_dtype(col_type):
            val = str(val)
        # TODO(Andrew): add support for datetime and boolean
        else:
            raise click.ClickException(
                "Unsupported dtype for {}: {}".format(col, col_type)
            )
        filtered_index = OPERATORS[op](info_df[col], val)
        info_df = info_df[filtered_index]

    if sort:
        for key in sort:
            if key not in info_df:
                raise click.ClickException("{} not in: {}".format(key, list(info_df)))
        ascending = not desc
        info_df = info_df.sort_values(by=sort, ascending=ascending)

    if limit:
        info_df = info_df.head(limit)

    print_format_output(info_df)

    if output:
        file_extension = os.path.splitext(output)[1].lower()
        if file_extension in (".p", ".pkl", ".pickle"):
            info_df.to_pickle(output)
        elif file_extension == ".csv":
            info_df.to_csv(output, index=False)
        else:
            raise click.ClickException("Unsupported filetype: {}".format(output))
        click.secho("Output saved at {}".format(output), fg="green")
def add_note(path: str, filename: str = "note.txt"):
    """Opens a txt file at the given path where user can add and save notes.

    The note is opened with the module-level ``EDITOR`` command, which is
    read from $EDITOR and defaults to vim.

    Args:
        path: Directory where note will be saved. ``~`` is expanded.
        filename: Name of note. Defaults to "note.txt".

    Raises:
        click.ClickException: If ``path`` is not an existing directory.
    """
    import shlex

    path = Path(path).expanduser()
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not path.is_dir():
        raise click.ClickException("{} is not a valid directory.".format(path))
    filepath = path / filename

    # Record existence *before* editing. Checking only afterwards cannot tell
    # a newly created note from an updated one.
    existed_before = filepath.exists()

    try:
        # EDITOR may be a plain executable (possibly a path containing spaces)
        # or a command with arguments such as "code --wait". Use it verbatim
        # when it resolves to an executable, otherwise split it shell-style.
        editor_cmd = [EDITOR] if shutil.which(EDITOR) else shlex.split(EDITOR)
        if not editor_cmd:
            raise ValueError("EDITOR is empty")
        subprocess.call(editor_cmd + [filepath.as_posix()])
    except Exception as exc:
        click.secho("Editing note failed: {}".format(str(exc)), fg="red")
        return

    if existed_before:
        print("Note updated at:", filepath.as_posix())
    elif filepath.exists():
        print("Note created at:", filepath.as_posix())
    else:
        # The editor exited without writing the file.
        print("No note saved at:", filepath.as_posix())
|