commands.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. import logging
  2. import operator
  3. import os
  4. import shutil
  5. import subprocess
  6. from datetime import datetime
  7. from pathlib import Path
  8. from typing import List, Optional
  9. import click
  10. import pandas as pd
  11. from pandas.api.types import is_numeric_dtype, is_string_dtype
  12. from ray._private.thirdparty.tabulate.tabulate import tabulate
  13. from ray.air.constants import EXPR_RESULT_FILE
  14. from ray.tune import TuneError
  15. from ray.tune.analysis import ExperimentAnalysis
  16. from ray.tune.result import (
  17. CONFIG_PREFIX,
  18. DEFAULT_EXPERIMENT_INFO_KEYS,
  19. DEFAULT_RESULT_KEYS,
  20. )
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Editor command launched by add_note(); read from the EDITOR environment
# variable, falling back to "vim" when it is unset.
EDITOR = os.getenv("EDITOR", "vim")

# strftime format used by list_trials() to render "last_update_time" values.
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)"

# Columns list_trials() keeps by default (config-prefixed columns are kept
# in addition to these).
DEFAULT_CLI_KEYS = DEFAULT_EXPERIMENT_INFO_KEYS + DEFAULT_RESULT_KEYS

# Columns list_experiments() shows when the caller passes no info_keys.
# Only keys actually present in the collected data end up displayed.
DEFAULT_PROJECT_INFO_KEYS = (
    "name",
    "total_trials",
    "last_updated",
)

# Terminal size captured once at import time; (100, 100) is used when the
# size cannot be determined. TERM_WIDTH bounds the table width in
# print_format_output().
TERM_WIDTH, TERM_HEIGHT = shutil.get_terminal_size(fallback=(100, 100))

# Maps the operator token of a "<column> <operator> <value>" filter string
# to the comparison function applied to the column.
OPERATORS = {
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    ">": operator.gt,
}
  39. def _check_tabulate():
  40. """Checks whether tabulate is installed."""
  41. if tabulate is None:
  42. raise ImportError("Tabulate not installed. Please run `pip install tabulate`.")
  43. def print_format_output(dataframe):
  44. """Prints output of given dataframe to fit into terminal.
  45. Returns:
  46. table: Final outputted dataframe.
  47. dropped_cols: Columns dropped due to terminal size.
  48. empty_cols: Empty columns (dropped on default).
  49. """
  50. print_df = pd.DataFrame()
  51. dropped_cols = []
  52. empty_cols = []
  53. # column display priority is based on the info_keys passed in
  54. for i, col in enumerate(dataframe):
  55. if dataframe[col].isnull().all():
  56. # Don't add col to print_df if is fully empty
  57. empty_cols += [col]
  58. continue
  59. print_df[col] = dataframe[col]
  60. test_table = tabulate(print_df, headers="keys", tablefmt="psql")
  61. if str(test_table).index("\n") > TERM_WIDTH:
  62. # Drop all columns beyond terminal width
  63. print_df.drop(col, axis=1, inplace=True)
  64. dropped_cols += list(dataframe.columns)[i:]
  65. break
  66. table = tabulate(print_df, headers="keys", tablefmt="psql", showindex="never")
  67. print(table)
  68. if dropped_cols:
  69. click.secho("Dropped columns: {}".format(dropped_cols), fg="yellow")
  70. click.secho("Please increase your terminal size to view remaining columns.")
  71. if empty_cols:
  72. click.secho("Empty columns: {}".format(empty_cols), fg="yellow")
  73. return table, dropped_cols, empty_cols
  74. def list_trials(
  75. experiment_path: str,
  76. sort: Optional[List[str]] = None,
  77. output: Optional[str] = None,
  78. filter_op: Optional[str] = None,
  79. info_keys: Optional[List[str]] = None,
  80. limit: int = None,
  81. desc: bool = False,
  82. ):
  83. """Lists trials in the directory subtree starting at the given path.
  84. Args:
  85. experiment_path: Directory where trials are located.
  86. Like Experiment.local_dir/Experiment.name/experiment*.json.
  87. sort: Keys to sort by.
  88. output: Name of file where output is saved.
  89. filter_op: Filter operation in the format
  90. "<column> <operator> <value>".
  91. info_keys: Keys that are displayed.
  92. limit: Number of rows to display.
  93. desc: Sort ascending vs. descending.
  94. """
  95. _check_tabulate()
  96. try:
  97. checkpoints_df = ExperimentAnalysis(experiment_path).dataframe() # last result
  98. except TuneError as e:
  99. raise click.ClickException("No trial data found!") from e
  100. config_prefix = CONFIG_PREFIX + "/"
  101. def key_filter(k):
  102. return k in DEFAULT_CLI_KEYS or k.startswith(config_prefix)
  103. col_keys = [k for k in checkpoints_df.columns if key_filter(k)]
  104. if info_keys:
  105. for k in info_keys:
  106. if k not in checkpoints_df.columns:
  107. raise click.ClickException(
  108. "Provided key invalid: {}. "
  109. "Available keys: {}.".format(k, checkpoints_df.columns)
  110. )
  111. col_keys = [k for k in checkpoints_df.columns if k in info_keys]
  112. if not col_keys:
  113. raise click.ClickException("No columns to output.")
  114. checkpoints_df = checkpoints_df[col_keys]
  115. if "last_update_time" in checkpoints_df:
  116. with pd.option_context("mode.use_inf_as_null", True):
  117. datetime_series = checkpoints_df["last_update_time"].dropna()
  118. datetime_series = datetime_series.apply(
  119. lambda t: datetime.fromtimestamp(t).strftime(TIMESTAMP_FORMAT)
  120. )
  121. checkpoints_df["last_update_time"] = datetime_series
  122. if "logdir" in checkpoints_df:
  123. # logdir often too long to view in table, so drop experiment_path
  124. checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace(
  125. experiment_path, ""
  126. )
  127. if filter_op:
  128. col, op, val = filter_op.split(" ")
  129. col_type = checkpoints_df[col].dtype
  130. if is_numeric_dtype(col_type):
  131. val = float(val)
  132. elif is_string_dtype(col_type):
  133. val = str(val)
  134. # TODO(Andrew): add support for datetime and boolean
  135. else:
  136. raise click.ClickException(
  137. "Unsupported dtype for {}: {}".format(val, col_type)
  138. )
  139. op = OPERATORS[op]
  140. filtered_index = op(checkpoints_df[col], val)
  141. checkpoints_df = checkpoints_df[filtered_index]
  142. if sort:
  143. for key in sort:
  144. if key not in checkpoints_df:
  145. raise click.ClickException(
  146. "{} not in: {}".format(key, list(checkpoints_df))
  147. )
  148. ascending = not desc
  149. checkpoints_df = checkpoints_df.sort_values(by=sort, ascending=ascending)
  150. if limit:
  151. checkpoints_df = checkpoints_df[:limit]
  152. print_format_output(checkpoints_df)
  153. if output:
  154. file_extension = os.path.splitext(output)[1].lower()
  155. if file_extension in (".p", ".pkl", ".pickle"):
  156. checkpoints_df.to_pickle(output)
  157. elif file_extension == ".csv":
  158. checkpoints_df.to_csv(output, index=False)
  159. else:
  160. raise click.ClickException("Unsupported filetype: {}".format(output))
  161. click.secho("Output saved at {}".format(output), fg="green")
  162. def list_experiments(
  163. project_path: str,
  164. sort: Optional[List[str]] = None,
  165. output: str = None,
  166. filter_op: str = None,
  167. info_keys: Optional[List[str]] = None,
  168. limit: int = None,
  169. desc: bool = False,
  170. ):
  171. """Lists experiments in the directory subtree.
  172. Args:
  173. project_path: Directory where experiments are located.
  174. Corresponds to Experiment.local_dir.
  175. sort: Keys to sort by.
  176. output: Name of file where output is saved.
  177. filter_op: Filter operation in the format
  178. "<column> <operator> <value>".
  179. info_keys: Keys that are displayed.
  180. limit: Number of rows to display.
  181. desc: Sort ascending vs. descending.
  182. """
  183. _check_tabulate()
  184. base, experiment_folders, _ = next(os.walk(project_path))
  185. experiment_data_collection = []
  186. for experiment_dir in experiment_folders:
  187. num_trials = sum(
  188. EXPR_RESULT_FILE in files
  189. for _, _, files in os.walk(os.path.join(base, experiment_dir))
  190. )
  191. experiment_data = {"name": experiment_dir, "total_trials": num_trials}
  192. experiment_data_collection.append(experiment_data)
  193. if not experiment_data_collection:
  194. raise click.ClickException("No experiments found!")
  195. info_df = pd.DataFrame(experiment_data_collection)
  196. if not info_keys:
  197. info_keys = DEFAULT_PROJECT_INFO_KEYS
  198. col_keys = [k for k in list(info_keys) if k in info_df]
  199. if not col_keys:
  200. raise click.ClickException(
  201. "None of keys {} in experiment data!".format(info_keys)
  202. )
  203. info_df = info_df[col_keys]
  204. if filter_op:
  205. col, op, val = filter_op.split(" ")
  206. col_type = info_df[col].dtype
  207. if is_numeric_dtype(col_type):
  208. val = float(val)
  209. elif is_string_dtype(col_type):
  210. val = str(val)
  211. # TODO(Andrew): add support for datetime and boolean
  212. else:
  213. raise click.ClickException(
  214. "Unsupported dtype for {}: {}".format(val, col_type)
  215. )
  216. op = OPERATORS[op]
  217. filtered_index = op(info_df[col], val)
  218. info_df = info_df[filtered_index]
  219. if sort:
  220. for key in sort:
  221. if key not in info_df:
  222. raise click.ClickException("{} not in: {}".format(key, list(info_df)))
  223. ascending = not desc
  224. info_df = info_df.sort_values(by=sort, ascending=ascending)
  225. if limit:
  226. info_df = info_df[:limit]
  227. print_format_output(info_df)
  228. if output:
  229. file_extension = os.path.splitext(output)[1].lower()
  230. if file_extension in (".p", ".pkl", ".pickle"):
  231. info_df.to_pickle(output)
  232. elif file_extension == ".csv":
  233. info_df.to_csv(output, index=False)
  234. else:
  235. raise click.ClickException("Unsupported filetype: {}".format(output))
  236. click.secho("Output saved at {}".format(output), fg="green")
  237. def add_note(path: str, filename: str = "note.txt"):
  238. """Opens a txt file at the given path where user can add and save notes.
  239. Args:
  240. path: Directory where note will be saved.
  241. filename: Name of note. Defaults to "note.txt"
  242. """
  243. path = Path(path).expanduser()
  244. assert path.is_dir(), "{} is not a valid directory.".format(path)
  245. filepath = path / filename
  246. try:
  247. subprocess.call([EDITOR, filepath.as_posix()])
  248. except Exception as exc:
  249. click.secho("Editing note failed: {}".format(str(exc)), fg="red")
  250. if filepath.exists():
  251. print("Note updated at:", filepath.as_posix())
  252. else:
  253. print("Note created at:", filepath.as_posix())