# experiment.py (16 KB)

  1. import copy
  2. import datetime
  3. import logging
  4. import pprint as pp
  5. import traceback
  6. from functools import partial
  7. from pathlib import Path
  8. from pickle import PicklingError
  9. from typing import (
  10. TYPE_CHECKING,
  11. Any,
  12. Callable,
  13. Dict,
  14. List,
  15. Mapping,
  16. Optional,
  17. Sequence,
  18. Type,
  19. Union,
  20. )
  21. import ray
  22. from ray.exceptions import RpcError
  23. from ray.train._internal.storage import StorageContext
  24. from ray.train.constants import DEFAULT_STORAGE_PATH
  25. from ray.tune import CheckpointConfig, SyncConfig
  26. from ray.tune.error import TuneError
  27. from ray.tune.registry import is_function_trainable, register_trainable
  28. from ray.tune.stopper import CombinedStopper, FunctionStopper, Stopper, TimeoutStopper
  29. from ray.util.annotations import Deprecated, DeveloperAPI
  30. if TYPE_CHECKING:
  31. import pyarrow.fs
  32. from ray.tune import PlacementGroupFactory
  33. from ray.tune.experiment import Trial
  34. logger = logging.getLogger(__name__)
  35. def _validate_log_to_file(log_to_file):
  36. """Validate ``tune.RunConfig``'s ``log_to_file`` parameter. Return
  37. validated relative stdout and stderr filenames."""
  38. if not log_to_file:
  39. stdout_file = stderr_file = None
  40. elif isinstance(log_to_file, bool) and log_to_file:
  41. stdout_file = "stdout"
  42. stderr_file = "stderr"
  43. elif isinstance(log_to_file, str):
  44. stdout_file = stderr_file = log_to_file
  45. elif isinstance(log_to_file, Sequence):
  46. if len(log_to_file) != 2:
  47. raise ValueError(
  48. "If you pass a Sequence to `log_to_file` it has to have "
  49. "a length of 2 (for stdout and stderr, respectively). The "
  50. "Sequence you passed has length {}.".format(len(log_to_file))
  51. )
  52. stdout_file, stderr_file = log_to_file
  53. else:
  54. raise ValueError(
  55. "You can pass a boolean, a string, or a Sequence of length 2 to "
  56. "`log_to_file`, but you passed something else ({}).".format(
  57. type(log_to_file)
  58. )
  59. )
  60. return stdout_file, stderr_file
  61. @DeveloperAPI
  62. class Experiment:
  63. """Tracks experiment specifications.
  64. Implicitly registers the Trainable if needed. The args here take
  65. the same meaning as the arguments defined `tune.py:run`.
  66. .. code-block:: python
  67. experiment_spec = Experiment(
  68. "my_experiment_name",
  69. my_func,
  70. stop={"mean_accuracy": 100},
  71. config={
  72. "alpha": tune.grid_search([0.2, 0.4, 0.6]),
  73. "beta": tune.grid_search([1, 2]),
  74. },
  75. resources_per_trial={
  76. "cpu": 1,
  77. "gpu": 0
  78. },
  79. num_samples=10,
  80. local_dir="~/ray_results",
  81. checkpoint_freq=10,
  82. max_failures=2)
  83. """
  84. # Keys that will be present in `public_spec` dict.
  85. PUBLIC_KEYS = {"stop", "num_samples", "time_budget_s"}
  86. _storage_context_cls = StorageContext
  87. def __init__(
  88. self,
  89. name: str,
  90. run: Union[str, Callable, Type],
  91. *,
  92. stop: Optional[Union[Mapping, Stopper, Callable[[str, Mapping], bool]]] = None,
  93. time_budget_s: Optional[Union[int, float, datetime.timedelta]] = None,
  94. config: Optional[Dict[str, Any]] = None,
  95. resources_per_trial: Union[
  96. None, Mapping[str, Union[float, int, Mapping]], "PlacementGroupFactory"
  97. ] = None,
  98. num_samples: int = 1,
  99. storage_path: Optional[str] = None,
  100. storage_filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  101. sync_config: Optional[Union[SyncConfig, dict]] = None,
  102. checkpoint_config: Optional[Union[CheckpointConfig, dict]] = None,
  103. trial_name_creator: Optional[Callable[["Trial"], str]] = None,
  104. trial_dirname_creator: Optional[Callable[["Trial"], str]] = None,
  105. log_to_file: bool = False,
  106. export_formats: Optional[Sequence] = None,
  107. max_failures: int = 0,
  108. restore: Optional[str] = None,
  109. # Deprecated
  110. local_dir: Optional[str] = None,
  111. ):
  112. if isinstance(checkpoint_config, dict):
  113. checkpoint_config = CheckpointConfig(**checkpoint_config)
  114. else:
  115. checkpoint_config = checkpoint_config or CheckpointConfig()
  116. if is_function_trainable(run):
  117. if checkpoint_config.checkpoint_at_end:
  118. raise ValueError(
  119. "'checkpoint_at_end' cannot be used with a function trainable. "
  120. "You should include one last call to "
  121. "`ray.tune.report(metrics=..., checkpoint=...)` "
  122. "at the end of your training loop to get this behavior."
  123. )
  124. if checkpoint_config.checkpoint_frequency:
  125. raise ValueError(
  126. "'checkpoint_frequency' cannot be set for a function trainable. "
  127. "You will need to report a checkpoint every "
  128. "`checkpoint_frequency` iterations within your training loop using "
  129. "`ray.tune.report(metrics=..., checkpoint=...)` "
  130. "to get this behavior."
  131. )
  132. try:
  133. self._run_identifier = Experiment.register_if_needed(run)
  134. except RpcError as e:
  135. if e.rpc_code == ray._raylet.GRPC_STATUS_CODE_RESOURCE_EXHAUSTED:
  136. raise TuneError(
  137. f"The Trainable/training function is too large for grpc resource "
  138. f"limit. Check that its definition is not implicitly capturing a "
  139. f"large array or other object in scope. "
  140. f"Tip: use tune.with_parameters() to put large objects "
  141. f"in the Ray object store. \n"
  142. f"Original exception: {traceback.format_exc()}"
  143. )
  144. else:
  145. raise e
  146. if not name:
  147. name = StorageContext.get_experiment_dir_name(run)
  148. storage_path = storage_path or DEFAULT_STORAGE_PATH
  149. self.storage = self._storage_context_cls(
  150. storage_path=storage_path,
  151. storage_filesystem=storage_filesystem,
  152. sync_config=sync_config,
  153. experiment_dir_name=name,
  154. )
  155. logger.debug(f"StorageContext on the DRIVER:\n{self.storage}")
  156. config = config or {}
  157. if not isinstance(config, dict):
  158. raise ValueError(
  159. f"`Experiment(config)` must be a dict, got: {type(config)}. "
  160. "Please convert your search space to a dict before passing it in."
  161. )
  162. self._stopper = None
  163. stopping_criteria = {}
  164. if not stop:
  165. pass
  166. elif isinstance(stop, list):
  167. bad_stoppers = [s for s in stop if not isinstance(s, Stopper)]
  168. if bad_stoppers:
  169. stopper_types = [type(s) for s in stop]
  170. raise ValueError(
  171. "If you pass a list as the `stop` argument to "
  172. "`tune.RunConfig()`, each element must be an instance of "
  173. f"`tune.stopper.Stopper`. Got {stopper_types}."
  174. )
  175. self._stopper = CombinedStopper(*stop)
  176. elif isinstance(stop, dict):
  177. stopping_criteria = stop
  178. elif callable(stop):
  179. if FunctionStopper.is_valid_function(stop):
  180. self._stopper = FunctionStopper(stop)
  181. elif isinstance(stop, Stopper):
  182. self._stopper = stop
  183. else:
  184. raise ValueError(
  185. "Provided stop object must be either a dict, "
  186. "a function, or a subclass of "
  187. f"`ray.tune.Stopper`. Got {type(stop)}."
  188. )
  189. else:
  190. raise ValueError(
  191. f"Invalid stop criteria: {stop}. Must be a "
  192. f"callable or dict. Got {type(stop)}."
  193. )
  194. if time_budget_s:
  195. if self._stopper:
  196. self._stopper = CombinedStopper(
  197. self._stopper, TimeoutStopper(time_budget_s)
  198. )
  199. else:
  200. self._stopper = TimeoutStopper(time_budget_s)
  201. stdout_file, stderr_file = _validate_log_to_file(log_to_file)
  202. spec = {
  203. "run": self._run_identifier,
  204. "stop": stopping_criteria,
  205. "time_budget_s": time_budget_s,
  206. "config": config,
  207. "resources_per_trial": resources_per_trial,
  208. "num_samples": num_samples,
  209. "checkpoint_config": checkpoint_config,
  210. "trial_name_creator": trial_name_creator,
  211. "trial_dirname_creator": trial_dirname_creator,
  212. "log_to_file": (stdout_file, stderr_file),
  213. "export_formats": export_formats or [],
  214. "max_failures": max_failures,
  215. "restore": (
  216. Path(restore).expanduser().absolute().as_posix() if restore else None
  217. ),
  218. "storage": self.storage,
  219. }
  220. self.spec = spec
  221. @classmethod
  222. def from_json(cls, name: str, spec: dict):
  223. """Generates an Experiment object from JSON.
  224. Args:
  225. name: Name of Experiment.
  226. spec: JSON configuration of experiment.
  227. """
  228. if "run" not in spec:
  229. raise TuneError("No trainable specified!")
  230. # Special case the `env` param for RLlib by automatically
  231. # moving it into the `config` section.
  232. if "env" in spec:
  233. spec["config"] = spec.get("config", {})
  234. spec["config"]["env"] = spec["env"]
  235. del spec["env"]
  236. if "sync_config" in spec and isinstance(spec["sync_config"], dict):
  237. spec["sync_config"] = SyncConfig(**spec["sync_config"])
  238. if "checkpoint_config" in spec and isinstance(spec["checkpoint_config"], dict):
  239. spec["checkpoint_config"] = CheckpointConfig(**spec["checkpoint_config"])
  240. spec = copy.deepcopy(spec)
  241. run_value = spec.pop("run")
  242. try:
  243. exp = cls(name, run_value, **spec)
  244. except TypeError as e:
  245. raise TuneError(
  246. f"Failed to load the following Tune experiment "
  247. f"specification:\n\n {pp.pformat(spec)}.\n\n"
  248. f"Please check that the arguments are valid. "
  249. f"Experiment creation failed with the following "
  250. f"error:\n {e}"
  251. )
  252. return exp
  253. @classmethod
  254. def get_trainable_name(cls, run_object: Union[str, Callable, Type]):
  255. """Get Trainable name.
  256. Args:
  257. run_object: Trainable to run. If string,
  258. assumes it is an ID and does not modify it. Otherwise,
  259. returns a string corresponding to the run_object name.
  260. Returns:
  261. A string representing the trainable identifier.
  262. Raises:
  263. TuneError: if ``run_object`` passed in is invalid.
  264. """
  265. from ray.tune.search.sample import Domain
  266. if isinstance(run_object, str) or isinstance(run_object, Domain):
  267. return run_object
  268. elif isinstance(run_object, type) or callable(run_object):
  269. name = "DEFAULT"
  270. if hasattr(run_object, "_name"):
  271. name = run_object._name
  272. elif hasattr(run_object, "__name__"):
  273. fn_name = run_object.__name__
  274. if fn_name == "<lambda>":
  275. name = "lambda"
  276. elif fn_name.startswith("<"):
  277. name = "DEFAULT"
  278. else:
  279. name = fn_name
  280. elif (
  281. isinstance(run_object, partial)
  282. and hasattr(run_object, "func")
  283. and hasattr(run_object.func, "__name__")
  284. ):
  285. name = run_object.func.__name__
  286. else:
  287. logger.warning("No name detected on trainable. Using {}.".format(name))
  288. return name
  289. else:
  290. raise TuneError("Improper 'run' - not string nor trainable.")
  291. @classmethod
  292. def register_if_needed(cls, run_object: Union[str, Callable, Type]):
  293. """Registers Trainable or Function at runtime.
  294. Assumes already registered if run_object is a string.
  295. Also, does not inspect interface of given run_object.
  296. Args:
  297. run_object: Trainable to run. If string,
  298. assumes it is an ID and does not modify it. Otherwise,
  299. returns a string corresponding to the run_object name.
  300. Returns:
  301. A string representing the trainable identifier.
  302. """
  303. from ray.tune.search.sample import Domain
  304. if isinstance(run_object, str):
  305. return run_object
  306. elif isinstance(run_object, Domain):
  307. logger.warning("Not registering trainable. Resolving as variant.")
  308. return run_object
  309. name = cls.get_trainable_name(run_object)
  310. try:
  311. register_trainable(name, run_object)
  312. except (TypeError, PicklingError) as e:
  313. extra_msg = (
  314. "Other options: "
  315. "\n-Try reproducing the issue by calling "
  316. "`pickle.dumps(trainable)`. "
  317. "\n-If the error is typing-related, try removing "
  318. "the type annotations and try again."
  319. )
  320. raise type(e)(str(e) + " " + extra_msg) from None
  321. return name
  322. @property
  323. def stopper(self):
  324. return self._stopper
  325. @property
  326. def local_path(self) -> Optional[str]:
  327. return self.storage.experiment_driver_staging_path
  328. @property
  329. @Deprecated("Replaced by `local_path`")
  330. def local_dir(self):
  331. # TODO(justinvyu): [Deprecated] Remove in 2.11.
  332. raise DeprecationWarning("Use `local_path` instead of `local_dir`.")
  333. @property
  334. def remote_path(self) -> Optional[str]:
  335. return self.storage.experiment_fs_path
  336. @property
  337. def path(self) -> Optional[str]:
  338. return self.remote_path or self.local_path
  339. @property
  340. def checkpoint_config(self):
  341. return self.spec.get("checkpoint_config")
  342. @property
  343. @Deprecated("Replaced by `local_path`")
  344. def checkpoint_dir(self):
  345. # TODO(justinvyu): [Deprecated] Remove in 2.11.
  346. raise DeprecationWarning("Use `local_path` instead of `checkpoint_dir`.")
  347. @property
  348. def run_identifier(self):
  349. """Returns a string representing the trainable identifier."""
  350. return self._run_identifier
  351. @property
  352. def public_spec(self) -> Dict[str, Any]:
  353. """Returns the spec dict with only the public-facing keys.
  354. Intended to be used for passing information to callbacks,
  355. Searchers and Schedulers.
  356. """
  357. return {k: v for k, v in self.spec.items() if k in self.PUBLIC_KEYS}
  358. def _convert_to_experiment_list(experiments: Union[Experiment, List[Experiment], Dict]):
  359. """Produces a list of Experiment objects.
  360. Converts input from dict, single experiment, or list of
  361. experiments to list of experiments. If input is None,
  362. will return an empty list.
  363. Arguments:
  364. experiments: Experiments to run.
  365. Returns:
  366. List of experiments.
  367. """
  368. exp_list = experiments
  369. # Transform list if necessary
  370. if experiments is None:
  371. exp_list = []
  372. elif isinstance(experiments, Experiment):
  373. exp_list = [experiments]
  374. elif type(experiments) is dict:
  375. exp_list = [
  376. Experiment.from_json(name, spec) for name, spec in experiments.items()
  377. ]
  378. # Validate exp_list
  379. if type(exp_list) is list and all(isinstance(exp, Experiment) for exp in exp_list):
  380. if len(exp_list) > 1:
  381. logger.info(
  382. "Running with multiple concurrent experiments. "
  383. "All experiments will be using the same SearchAlgorithm."
  384. )
  385. else:
  386. raise TuneError("Invalid argument: {}".format(experiments))
  387. return exp_list