trainable.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990
  1. import copy
  2. import logging
  3. import os
  4. import platform
  5. import shutil
  6. import sys
  7. import tempfile
  8. import time
  9. from contextlib import redirect_stderr, redirect_stdout
  10. from datetime import datetime
  11. from pathlib import Path
  12. from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
  13. import ray
  14. import ray.cloudpickle as ray_pickle
  15. from ray._common.utils import try_to_create_directory
  16. from ray.air._internal.util import exception_cause, skip_exceptions
  17. from ray.air.constants import TIME_THIS_ITER_S, TIMESTAMP, TRAINING_ITERATION
  18. from ray.train._internal.checkpoint_manager import _TrainingResult
  19. from ray.train._internal.storage import StorageContext, _exists_at_fs_path
  20. from ray.train.constants import DEFAULT_STORAGE_PATH
  21. from ray.tune.execution.placement_groups import PlacementGroupFactory
  22. from ray.tune.result import (
  23. DEBUG_METRICS,
  24. DONE,
  25. EPISODES_THIS_ITER,
  26. EPISODES_TOTAL,
  27. HOSTNAME,
  28. NODE_IP,
  29. PID,
  30. RESULT_DUPLICATE,
  31. SHOULD_CHECKPOINT,
  32. STDERR_FILE,
  33. STDOUT_FILE,
  34. TIME_TOTAL_S,
  35. TIMESTEPS_THIS_ITER,
  36. TIMESTEPS_TOTAL,
  37. TRIAL_ID,
  38. TRIAL_INFO,
  39. )
  40. from ray.tune.utils import UtilMonitor
  41. from ray.tune.utils.log import disable_ipython
  42. from ray.tune.utils.util import Tee
  43. from ray.util.annotations import DeveloperAPI, PublicAPI
  44. if TYPE_CHECKING:
  45. from ray.tune.logger import Logger
  46. logger = logging.getLogger(__name__)
  47. SETUP_TIME_THRESHOLD = 10
  48. # File containing dict data returned by user from `Trainable.save_checkpoint`
  49. _DICT_CHECKPOINT_FILE_NAME = "_dict_checkpoint.pkl"
  50. @PublicAPI
  51. class Trainable:
  52. """Abstract class for trainable models, functions, etc.
  53. A call to ``train()`` on a trainable will execute one logical iteration of
  54. training. As a rule of thumb, the execution time of one train call should
  55. be large enough to avoid overheads (i.e. more than a few seconds), but
  56. short enough to report progress periodically (i.e. at most a few minutes).
  57. Calling ``save()`` should save the training state of a trainable to disk,
  58. and ``restore(path)`` should restore a trainable to the given state.
  59. Generally you only need to implement ``setup``, ``step``,
  60. ``save_checkpoint``, and ``load_checkpoint`` when subclassing Trainable.
  61. Other implementation methods that may be helpful to override are
  62. ``log_result``, ``reset_config``, ``cleanup``, and ``_export_model``.
  63. Tune will convert this class into a Ray actor, which runs on a separate process.
  64. By default, Tune will also change the current working directory of this process to
  65. its corresponding trial-level log directory ``self.logdir``.
  66. This is designed so that different trials that run on the same physical node won't
  67. accidentally write to the same location and overstep each other.
  68. The behavior of changing the working directory can be disabled by setting the
  69. `RAY_CHDIR_TO_TRIAL_DIR=0` environment variable. This allows access to files
  70. in the original working directory, but relative paths should be used for read only
  71. purposes, and you must make sure that the directory is synced on all nodes if
  72. running on multiple machines.
  73. The `TUNE_ORIG_WORKING_DIR` environment variable was the original workaround for
  74. accessing paths relative to the original working directory. This environment
  75. variable is deprecated, and the `RAY_CHDIR_TO_TRIAL_DIR` environment variable
  76. described above should be used instead.
  77. This class supports checkpointing to and restoring from remote storage.
  78. """
  79. def __init__(
  80. self,
  81. config: Dict[str, Any] = None,
  82. logger_creator: Callable[[Dict[str, Any]], "Logger"] = None, # Deprecated (2.7)
  83. storage: Optional[StorageContext] = None,
  84. ):
  85. """Initialize a Trainable.
  86. Sets up logging and points ``self.logdir`` to a directory in which
  87. training outputs should be placed.
  88. Subclasses should prefer defining ``setup()`` instead of overriding
  89. ``__init__()`` directly.
  90. Args:
  91. config: Trainable-specific configuration data. By default
  92. will be saved as ``self.config``.
  93. logger_creator: (Deprecated) Function that creates a ray.tune.Logger
  94. object. If unspecified, a default logger is created.
  95. storage: StorageContext object that contains persistent storage paths
  96. """
  97. self.config = config or {}
  98. trial_info = self.config.pop(TRIAL_INFO, None)
  99. if self.is_actor():
  100. disable_ipython()
  101. # TODO(ml-team): Remove `logger_creator` in 2.7.
  102. # TODO(justinvyu): Rename/remove logdir.
  103. self._result_logger = self._logdir = None
  104. self._create_logger(self.config, logger_creator)
  105. self._stdout_context = self._stdout_fp = self._stdout_stream = None
  106. self._stderr_context = self._stderr_fp = self._stderr_stream = None
  107. self._stderr_logging_handler = None
  108. stdout_file = self.config.pop(STDOUT_FILE, None)
  109. stderr_file = self.config.pop(STDERR_FILE, None)
  110. self._iteration = 0
  111. self._time_total = 0.0
  112. self._timesteps_total = None
  113. self._episodes_total = None
  114. self._time_since_restore = 0.0
  115. self._timesteps_since_restore = 0
  116. self._iterations_since_restore = 0
  117. self._last_result = None
  118. self._restored = False
  119. self._trial_info = trial_info
  120. self._stdout_file = stdout_file
  121. self._stderr_file = stderr_file
  122. self._start_time = time.time()
  123. self._local_ip = ray.util.get_node_ip_address()
  124. self._storage = storage
  125. if storage:
  126. assert storage.trial_fs_path
  127. logger.debug(f"StorageContext on the TRAINABLE:\n{storage}")
  128. self._open_logfiles(stdout_file, stderr_file)
  129. self.setup(copy.deepcopy(self.config))
  130. setup_time = time.time() - self._start_time
  131. if setup_time > SETUP_TIME_THRESHOLD:
  132. logger.info(
  133. "Trainable.setup took {:.3f} seconds. If your "
  134. "trainable is slow to initialize, consider setting "
  135. "reuse_actors=True to reduce actor creation "
  136. "overheads.".format(setup_time)
  137. )
  138. log_sys_usage = self.config.get("log_sys_usage", False)
  139. self._monitor = UtilMonitor(start=log_sys_usage)
  140. @classmethod
  141. def default_resource_request(
  142. cls, config: Dict[str, Any]
  143. ) -> Optional[PlacementGroupFactory]:
  144. """Provides a static resource requirement for the given configuration.
  145. This can be overridden by sub-classes to set the correct trial resource
  146. allocation, so the user does not need to.
  147. .. testcode::
  148. @classmethod
  149. def default_resource_request(cls, config):
  150. return PlacementGroupFactory([{"CPU": 1}, {"CPU": 1}])
  151. Args:
  152. config[Dict[str, Any]]: The Trainable's config dict.
  153. Returns:
  154. PlacementGroupFactory: A PlacementGroupFactory consumed by Tune
  155. for queueing.
  156. """
  157. return None
  158. @classmethod
  159. def resource_help(cls, config: Dict):
  160. """Returns a help string for configuring this trainable's resources.
  161. Args:
  162. config: The Trainer's config dict.
  163. """
  164. return ""
  165. def get_current_ip_pid(self):
  166. return self._local_ip, os.getpid()
  167. def get_auto_filled_metrics(
  168. self,
  169. now: Optional[datetime] = None,
  170. time_this_iter: Optional[float] = None,
  171. timestamp: Optional[int] = None,
  172. debug_metrics_only: bool = False,
  173. ) -> dict:
  174. """Return a dict with metrics auto-filled by the trainable.
  175. If ``debug_metrics_only`` is True, only metrics that don't
  176. require at least one iteration will be returned
  177. (``ray.tune.result.DEBUG_METRICS``).
  178. """
  179. if now is None:
  180. now = datetime.today()
  181. autofilled = {
  182. TRIAL_ID: self.trial_id,
  183. "date": now.strftime("%Y-%m-%d_%H-%M-%S"),
  184. "timestamp": timestamp if timestamp else int(time.mktime(now.timetuple())),
  185. TIME_THIS_ITER_S: time_this_iter,
  186. TIME_TOTAL_S: self._time_total,
  187. PID: os.getpid(),
  188. HOSTNAME: platform.node(),
  189. NODE_IP: self._local_ip,
  190. "config": self.config,
  191. "time_since_restore": self._time_since_restore,
  192. "iterations_since_restore": self._iterations_since_restore,
  193. }
  194. if self._timesteps_since_restore:
  195. autofilled["timesteps_since_restore"] = self._timesteps_since_restore
  196. if debug_metrics_only:
  197. autofilled = {k: v for k, v in autofilled.items() if k in DEBUG_METRICS}
  198. return autofilled
  199. def is_actor(self):
  200. try:
  201. actor_id = ray._private.worker.global_worker.actor_id
  202. return actor_id != actor_id.nil()
  203. except Exception:
  204. # If global_worker is not instantiated, we're not in an actor
  205. return False
  206. def train_buffered(self, buffer_time_s: float, max_buffer_length: int = 1000):
  207. """Runs multiple iterations of training.
  208. Calls ``train()`` internally. Collects and combines multiple results.
  209. This function will run ``self.train()`` repeatedly until one of
  210. the following conditions is met: 1) the maximum buffer length is
  211. reached, 2) the maximum buffer time is reached, or 3) a checkpoint
  212. was created. Even if the maximum time is reached, it will always
  213. block until at least one result is received.
  214. Args:
  215. buffer_time_s: Maximum time to buffer. The next result
  216. received after this amount of time has passed will return
  217. the whole buffer.
  218. max_buffer_length: Maximum number of results to buffer.
  219. """
  220. results = []
  221. now = time.time()
  222. send_buffer_at = now + buffer_time_s
  223. while now < send_buffer_at or not results: # At least one result
  224. result = self.train()
  225. results.append(result)
  226. if result.get(DONE, False):
  227. # If the trial is done, return
  228. break
  229. elif result.get(SHOULD_CHECKPOINT, False):
  230. # If a checkpoint was created, return
  231. break
  232. elif result.get(RESULT_DUPLICATE):
  233. # If the function API trainable completed, return
  234. break
  235. elif len(results) >= max_buffer_length:
  236. # If the buffer is full, return
  237. break
  238. now = time.time()
  239. return results
  240. def train(self):
  241. """Runs one logical iteration of training.
  242. Calls ``step()`` internally. Subclasses should override ``step()``
  243. instead to return results.
  244. This method automatically fills the following fields in the result:
  245. `done` (bool): training is terminated. Filled only if not provided.
  246. `time_this_iter_s` (float): Time in seconds this iteration
  247. took to run. This may be overridden in order to override the
  248. system-computed time difference.
  249. `time_total_s` (float): Accumulated time in seconds for this
  250. entire experiment.
  251. `training_iteration` (int): The index of this
  252. training iteration, e.g. call to train(). This is incremented
  253. after `step()` is called.
  254. `pid` (str): The pid of the training process.
  255. `date` (str): A formatted date of when the result was processed.
  256. `timestamp` (str): A UNIX timestamp of when the result
  257. was processed. This may be overridden.
  258. `hostname` (str): Hostname of the machine hosting the training
  259. process.
  260. `node_ip` (str): Node ip of the machine hosting the training
  261. process.
  262. Returns:
  263. A dict that describes training progress.
  264. """
  265. start = time.time()
  266. try:
  267. result = self.step()
  268. except Exception as e:
  269. skipped = skip_exceptions(e)
  270. raise skipped from exception_cause(skipped)
  271. assert isinstance(result, dict), "step() needs to return a dict."
  272. # We do not modify internal state nor update this result if duplicate.
  273. if RESULT_DUPLICATE in result:
  274. return result
  275. result = result.copy()
  276. self._iteration += 1
  277. self._iterations_since_restore += 1
  278. if result.get(TIME_THIS_ITER_S) is not None:
  279. time_this_iter = result[TIME_THIS_ITER_S]
  280. else:
  281. time_this_iter = time.time() - start
  282. self._time_total += time_this_iter
  283. self._time_since_restore += time_this_iter
  284. result_timestamp = result.get(TIMESTAMP, None)
  285. result.setdefault(DONE, False)
  286. # self._timesteps_total should only be tracked if increments are provided
  287. if result.get(TIMESTEPS_THIS_ITER) is not None:
  288. if self._timesteps_total is None:
  289. self._timesteps_total = 0
  290. self._timesteps_total += result[TIMESTEPS_THIS_ITER]
  291. self._timesteps_since_restore += result[TIMESTEPS_THIS_ITER]
  292. # self._episodes_total should only be tracked if increments provided
  293. if result.get(EPISODES_THIS_ITER) is not None:
  294. if self._episodes_total is None:
  295. self._episodes_total = 0
  296. self._episodes_total += result[EPISODES_THIS_ITER]
  297. # self._timesteps_total should not override user-provided total
  298. if self._timesteps_total is not None:
  299. result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total)
  300. if self._episodes_total is not None:
  301. result.setdefault(EPISODES_TOTAL, self._episodes_total)
  302. result.setdefault(TRAINING_ITERATION, self._iteration)
  303. now = datetime.today()
  304. result.update(
  305. self.get_auto_filled_metrics(
  306. now=now, time_this_iter=time_this_iter, timestamp=result_timestamp
  307. )
  308. )
  309. monitor_data = self._monitor.get_data()
  310. if monitor_data:
  311. result.update(monitor_data)
  312. self.log_result(result)
  313. if self._stdout_context:
  314. self._stdout_stream.flush()
  315. if self._stderr_context:
  316. self._stderr_stream.flush()
  317. self._last_result = result
  318. if self._storage:
  319. # Launch background tasks to sync artifacts at some specified frequency.
  320. self._storage.persist_artifacts()
  321. return result
  322. def get_state(self):
  323. return {
  324. "iteration": self._iteration,
  325. "timesteps_total": self._timesteps_total,
  326. "time_total": self._time_total,
  327. "episodes_total": self._episodes_total,
  328. "last_result": self._last_result,
  329. "ray_version": ray.__version__,
  330. }
  331. def _report_class_trainable_checkpoint(
  332. self, checkpoint_dir: str, checkpoint_dict_or_path: Union[str, Dict]
  333. ) -> _TrainingResult:
  334. """Report a checkpoint saved via Trainable.save_checkpoint.
  335. Need to handle both dict or path checkpoint returned by the user's
  336. `save_checkpoint` method.
  337. This is to get class trainables to work with storage backend used by
  338. function trainables.
  339. This basically re-implements `tune.report` for class trainables,
  340. making sure to persist the checkpoint to storage.
  341. """
  342. if isinstance(checkpoint_dict_or_path, dict):
  343. with Path(checkpoint_dir, _DICT_CHECKPOINT_FILE_NAME).open("wb") as f:
  344. ray_pickle.dump(checkpoint_dict_or_path, f)
  345. elif isinstance(checkpoint_dict_or_path, str):
  346. if checkpoint_dict_or_path != checkpoint_dir:
  347. raise ValueError(
  348. "The returned checkpoint path from `save_checkpoint` "
  349. "must be None or the same as the provided path argument."
  350. f"Got {checkpoint_dict_or_path} != {checkpoint_dir}"
  351. )
  352. local_checkpoint = ray.tune.Checkpoint.from_directory(checkpoint_dir)
  353. metrics = self._last_result.copy() if self._last_result else {}
  354. if self._storage:
  355. # The checkpoint index is updated with the current result.
  356. # NOTE: This is no longer using "iteration" as the folder indexing
  357. # to be consistent with fn trainables.
  358. self._storage._update_checkpoint_index(metrics)
  359. persisted_checkpoint = self._storage.persist_current_checkpoint(
  360. local_checkpoint
  361. )
  362. checkpoint_result = _TrainingResult(
  363. checkpoint=persisted_checkpoint, metrics=metrics
  364. )
  365. # Persist trial artifacts to storage.
  366. self._storage.persist_artifacts(
  367. force=self._storage.sync_config.sync_artifacts_on_checkpoint
  368. )
  369. else:
  370. # `storage=None` only happens when initializing the
  371. # Trainable manually, outside of Tune/Train.
  372. # In this case, no storage is set, so the default behavior
  373. # is to just not upload anything and report a local checkpoint.
  374. # This is fine for the main use case of local debugging.
  375. checkpoint_result = _TrainingResult(
  376. checkpoint=local_checkpoint, metrics=metrics
  377. )
  378. return checkpoint_result
  379. @DeveloperAPI
  380. def save(self, checkpoint_dir: Optional[str] = None) -> _TrainingResult:
  381. """Saves the current model state to a checkpoint.
  382. Subclasses should override ``save_checkpoint()`` instead to save state.
  383. Args:
  384. checkpoint_dir: Optional dir to place the checkpoint.
  385. Returns:
  386. The given or created checkpoint directory.
  387. Note the return value matches up with what is expected of `restore()`.
  388. """
  389. if not isinstance(self, ray.tune.trainable.FunctionTrainable):
  390. # Use a temporary directory if no checkpoint_dir is provided.
  391. use_temp_dir = not checkpoint_dir
  392. checkpoint_dir = checkpoint_dir or tempfile.mkdtemp()
  393. os.makedirs(checkpoint_dir, exist_ok=True)
  394. checkpoint_dict_or_path = self.save_checkpoint(checkpoint_dir)
  395. checkpoint_result = self._report_class_trainable_checkpoint(
  396. checkpoint_dir, checkpoint_dict_or_path
  397. )
  398. # Clean up the temporary directory, since it's already been
  399. # reported + persisted to storage. If no storage is set, the user is
  400. # running the Trainable locally and is responsible for cleaning
  401. # up the checkpoint directory themselves.
  402. if use_temp_dir and self._storage:
  403. shutil.rmtree(checkpoint_dir, ignore_errors=True)
  404. else:
  405. checkpoint_result: _TrainingResult = self.save_checkpoint(None)
  406. assert isinstance(checkpoint_result, _TrainingResult)
  407. assert self._last_result
  408. # Update the checkpoint result to include auto-filled metrics.
  409. checkpoint_result.metrics.update(self._last_result)
  410. return checkpoint_result
  411. @DeveloperAPI
  412. def restore(
  413. self, checkpoint_path: Union[str, "ray.tune.Checkpoint", _TrainingResult]
  414. ):
  415. """Restores training state from a given model checkpoint.
  416. These checkpoints are returned from calls to save().
  417. Subclasses should override ``load_checkpoint()`` instead to
  418. restore state.
  419. This method restores additional metadata saved with the checkpoint.
  420. `checkpoint_path` should match with the return from ``save()``.
  421. Args:
  422. checkpoint_path: training result that was returned by a
  423. previous call to `save()`.
  424. """
  425. # TODO(justinvyu): This also supports restoring from a Checkpoint object
  426. # or a path, which are legacy APIs that RLlib depends on.
  427. # RLlib should remove this dependency since `restore` is a DeveloperAPI.
  428. if isinstance(checkpoint_path, str):
  429. checkpoint_path = ray.tune.Checkpoint.from_directory(checkpoint_path)
  430. if isinstance(checkpoint_path, ray.tune.Checkpoint):
  431. checkpoint_result = _TrainingResult(checkpoint=checkpoint_path, metrics={})
  432. else:
  433. checkpoint_result: _TrainingResult = checkpoint_path
  434. assert isinstance(checkpoint_result, _TrainingResult), type(checkpoint_result)
  435. checkpoint = checkpoint_result.checkpoint
  436. checkpoint_metrics = checkpoint_result.metrics
  437. self._iteration = checkpoint_metrics.get(TRAINING_ITERATION, 0)
  438. self._time_total = checkpoint_metrics.get(TIME_TOTAL_S, 0)
  439. self._time_since_restore = 0.0
  440. self._iterations_since_restore = 0
  441. # TODO(justinvyu): This stuff should be moved to rllib.
  442. self._timesteps_total = checkpoint_metrics.get(TIMESTEPS_TOTAL)
  443. self._timesteps_since_restore = 0
  444. self._episodes_total = checkpoint_metrics.get(EPISODES_TOTAL)
  445. if not _exists_at_fs_path(checkpoint.filesystem, checkpoint.path):
  446. raise ValueError(
  447. f"Could not recover from checkpoint as it does not exist on "
  448. f"storage anymore. "
  449. f"Got storage fs type `{checkpoint.filesystem.type_name}` and "
  450. f"path: {checkpoint.path}"
  451. )
  452. # TODO(justinvyu): [cls_trainable_support]
  453. # This is to conform to the public class Trainable `load_checkpoint` API.
  454. if not isinstance(self, ray.tune.trainable.FunctionTrainable):
  455. # Need to convert Checkpoint -> local path or dict
  456. # (depending on what the output of save_checkpoint was)
  457. with checkpoint.as_directory() as checkpoint_dir:
  458. checkpoint_path = Path(checkpoint_dir)
  459. dict_checkpoint_file = checkpoint_path / _DICT_CHECKPOINT_FILE_NAME
  460. if dict_checkpoint_file.exists():
  461. # If this was a dict checkpoint, load it as a dict
  462. with open(dict_checkpoint_file, "rb") as f:
  463. checkpoint_dict = ray_pickle.load(f)
  464. self.load_checkpoint(checkpoint_dict)
  465. else:
  466. self.load_checkpoint(checkpoint_dir)
  467. else:
  468. # TODO(justinvyu): The Function Trainable case doesn't conform
  469. # to the load_checkpoint API at the moment.
  470. self.load_checkpoint(checkpoint_result)
  471. self._restored = True
  472. logger.info(f"Restored on {self._local_ip} from checkpoint: {checkpoint}")
  473. def export_model(
  474. self, export_formats: Union[List[str], str], export_dir: Optional[str] = None
  475. ):
  476. """Exports model based on export_formats.
  477. Subclasses should override _export_model() to actually
  478. export model to local directory.
  479. Args:
  480. export_formats: Format or list of (str) formats
  481. that should be exported.
  482. export_dir: Optional dir to place the exported model.
  483. Defaults to self.logdir.
  484. Returns:
  485. A dict that maps ExportFormats to successfully exported models.
  486. """
  487. if isinstance(export_formats, str):
  488. export_formats = [export_formats]
  489. export_dir = export_dir or self.logdir
  490. return self._export_model(export_formats, export_dir)
  491. def reset(self, new_config, logger_creator=None, storage=None):
  492. """Resets trial for use with new config.
  493. Subclasses should override reset_config() to actually
  494. reset actor behavior for the new config."""
  495. self.config = new_config
  496. self._storage = storage
  497. trial_info = new_config.pop(TRIAL_INFO, None)
  498. if trial_info:
  499. self._trial_info = trial_info
  500. self._result_logger.flush()
  501. self._result_logger.close()
  502. if logger_creator:
  503. logger.debug("Logger reset.")
  504. self._create_logger(new_config.copy(), logger_creator)
  505. else:
  506. logger.debug(
  507. "Did not reset logger. Got: "
  508. f"trainable.reset(logger_creator={logger_creator})."
  509. )
  510. stdout_file = new_config.pop(STDOUT_FILE, None)
  511. stderr_file = new_config.pop(STDERR_FILE, None)
  512. self._close_logfiles()
  513. self._open_logfiles(stdout_file, stderr_file)
  514. success = self.reset_config(new_config)
  515. if not success:
  516. return False
  517. # Reset attributes. Will be overwritten by `restore` if a checkpoint
  518. # is provided.
  519. self._iteration = 0
  520. self._time_total = 0.0
  521. self._timesteps_total = None
  522. self._episodes_total = None
  523. self._time_since_restore = 0.0
  524. self._timesteps_since_restore = 0
  525. self._iterations_since_restore = 0
  526. self._restored = False
  527. return True
  528. def reset_config(self, new_config: Dict) -> bool:
  529. """Resets configuration without restarting the trial.
  530. This method is optional, but can be implemented to speed up algorithms
  531. such as PBT, and to allow performance optimizations such as running
  532. experiments with reuse_actors=True.
  533. Args:
  534. new_config: Updated hyperparameter configuration
  535. for the trainable.
  536. Returns:
  537. True if reset was successful else False.
  538. """
  539. return False
  540. def _create_logger(
  541. self,
  542. config: Dict[str, Any],
  543. logger_creator: Callable[[Dict[str, Any]], "Logger"] = None,
  544. ):
  545. """Create logger from logger creator.
  546. Sets _logdir and _result_logger.
  547. `_logdir` is the **per trial** directory for the Trainable.
  548. """
  549. if logger_creator:
  550. self._result_logger = logger_creator(config)
  551. self._logdir = self._result_logger.logdir
  552. else:
  553. from ray.tune.logger import UnifiedLogger
  554. logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
  555. try_to_create_directory(DEFAULT_STORAGE_PATH)
  556. self._logdir = tempfile.mkdtemp(
  557. prefix=logdir_prefix, dir=DEFAULT_STORAGE_PATH
  558. )
  559. self._result_logger = UnifiedLogger(config, self._logdir, loggers=None)
  560. def _open_logfiles(self, stdout_file, stderr_file):
  561. """Create loggers. Open stdout and stderr logfiles."""
  562. if stdout_file:
  563. stdout_path = (Path(self._logdir) / stdout_file).expanduser().as_posix()
  564. self._stdout_fp = open(stdout_path, "a+")
  565. self._stdout_stream = Tee(sys.stdout, self._stdout_fp)
  566. self._stdout_context = redirect_stdout(self._stdout_stream)
  567. self._stdout_context.__enter__()
  568. if stderr_file:
  569. stderr_path = (Path(self._logdir) / stderr_file).expanduser().as_posix()
  570. self._stderr_fp = open(stderr_path, "a+")
  571. self._stderr_stream = Tee(sys.stderr, self._stderr_fp)
  572. self._stderr_context = redirect_stderr(self._stderr_stream)
  573. self._stderr_context.__enter__()
  574. # Add logging handler to root ray logger
  575. formatter = logging.Formatter(
  576. "[%(levelname)s %(asctime)s] "
  577. "%(filename)s: %(lineno)d "
  578. "%(message)s"
  579. )
  580. self._stderr_logging_handler = logging.StreamHandler(self._stderr_fp)
  581. self._stderr_logging_handler.setFormatter(formatter)
  582. ray.logger.addHandler(self._stderr_logging_handler)
  583. def _close_logfiles(self):
  584. """Close stdout and stderr logfiles."""
  585. if self._stderr_logging_handler:
  586. ray.logger.removeHandler(self._stderr_logging_handler)
  587. if self._stdout_context:
  588. self._stdout_stream.flush()
  589. self._stdout_context.__exit__(None, None, None)
  590. self._stdout_fp.close()
  591. self._stdout_context = None
  592. if self._stderr_context:
  593. self._stderr_stream.flush()
  594. self._stderr_context.__exit__(None, None, None)
  595. self._stderr_fp.close()
  596. self._stderr_context = None
  597. def stop(self):
  598. """Releases all resources used by this trainable.
  599. Calls ``Trainable.cleanup`` internally. Subclasses should override
  600. ``Trainable.cleanup`` for custom cleanup procedures.
  601. """
  602. self._result_logger.flush()
  603. self._result_logger.close()
  604. if self._monitor.is_alive():
  605. self._monitor.stop()
  606. self._monitor.join()
  607. self.cleanup()
  608. self._close_logfiles()
  609. @property
  610. def logdir(self):
  611. """Directory of the results and checkpoints for this Trainable.
  612. Note that the current working directory will also be changed to this.
  613. """
  614. return self._logdir
  615. @property
  616. def trial_name(self):
  617. """Trial name for the corresponding trial of this Trainable.
  618. This is not set if not using Tune.
  619. .. testcode::
  620. from ray.tune import Trainable
  621. name = Trainable().trial_name
  622. """
  623. if self._trial_info:
  624. return self._trial_info.trial_name
  625. else:
  626. return "default"
  627. @property
  628. def trial_id(self):
  629. """Trial ID for the corresponding trial of this Trainable.
  630. This is not set if not using Tune.
  631. .. testcode::
  632. from ray.tune import Trainable
  633. trial_id = Trainable().trial_id
  634. """
  635. if self._trial_info:
  636. return self._trial_info.trial_id
  637. else:
  638. return "default"
  639. @property
  640. def trial_resources(self) -> Optional[PlacementGroupFactory]:
  641. """Resources currently assigned to the trial of this Trainable.
  642. This is not set if not using Tune.
  643. .. testcode::
  644. from ray.tune import Trainable
  645. trial_resources = Trainable().trial_resources
  646. """
  647. if self._trial_info:
  648. return self._trial_info.trial_resources
  649. else:
  650. return None
  651. @property
  652. def iteration(self):
  653. """Current training iteration.
  654. This value is automatically incremented every time `train()` is called
  655. and is automatically inserted into the training result dict.
  656. """
  657. return self._iteration
  658. @property
  659. def training_iteration(self):
  660. """Current training iteration (same as `self.iteration`).
  661. This value is automatically incremented every time `train()` is called
  662. and is automatically inserted into the training result dict.
  663. """
  664. return self._iteration
  665. def get_config(self):
  666. """Returns configuration passed in by Tune."""
  667. return self.config
  668. def step(self):
  669. """Subclasses should override this to implement train().
  670. The return value will be automatically passed to the loggers. Users
  671. can also return `tune.result.DONE` or `tune.result.SHOULD_CHECKPOINT`
  672. as a key to manually trigger termination or checkpointing of this
  673. trial. Note that manual checkpointing only works when subclassing
  674. Trainables.
  675. .. versionadded:: 0.8.7
  676. Returns:
  677. A dict that describes training progress.
  678. """
  679. raise NotImplementedError
  680. def save_checkpoint(self, checkpoint_dir: str) -> Optional[Dict]:
  681. """Subclasses should override this to implement ``save()``.
  682. Warning:
  683. Do not rely on absolute paths in the implementation of
  684. ``Trainable.save_checkpoint`` and ``Trainable.load_checkpoint``.
  685. Use ``validate_save_restore`` to catch ``Trainable.save_checkpoint``/
  686. ``Trainable.load_checkpoint`` errors before execution.
  687. >>> from ray.tune.utils import validate_save_restore
  688. >>> MyTrainableClass = ... # doctest: +SKIP
  689. >>> validate_save_restore(MyTrainableClass) # doctest: +SKIP
  690. .. versionadded:: 0.8.7
  691. Args:
  692. checkpoint_dir: The directory where the checkpoint
  693. file must be stored. In a Tune run, if the trial is paused,
  694. the provided path may be temporary and moved.
  695. Returns:
  696. A dict or None. If dict, the return value will
  697. be automatically serialized by Tune. In that case,
  698. ``Trainable.load_checkpoint()`` will receive the dict upon restore.
  699. Example:
  700. >>> trainable, trainable1, trainable2 = ... # doctest: +SKIP
  701. >>> print(trainable1.save_checkpoint("/tmp/checkpoint_1")) # doctest: +SKIP
  702. "/tmp/checkpoint_1"
  703. >>> print(trainable2.save_checkpoint("/tmp/checkpoint_2")) # doctest: +SKIP
  704. {"some": "data"}
  705. >>> trainable.save_checkpoint("/tmp/bad_example") # doctest: +SKIP
  706. "/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.
  707. """
  708. raise NotImplementedError
  709. def load_checkpoint(self, checkpoint: Optional[Dict]):
  710. """Subclasses should override this to implement restore().
  711. Warning:
  712. In this method, do not rely on absolute paths. The absolute
  713. path of the checkpoint_dir used in ``Trainable.save_checkpoint``
  714. may be changed.
  715. If ``Trainable.save_checkpoint`` returned a prefixed string, the
  716. prefix of the checkpoint string returned by
  717. ``Trainable.save_checkpoint`` may be changed.
  718. This is because trial pausing depends on temporary directories.
  719. The directory structure under the checkpoint_dir provided to
  720. ``Trainable.save_checkpoint`` is preserved.
  721. See the examples below.
  722. Example:
  723. >>> import os
  724. >>> from ray.tune.trainable import Trainable
  725. >>> class Example(Trainable):
  726. ... def save_checkpoint(self, checkpoint_path):
  727. ... my_checkpoint_path = os.path.join(checkpoint_path, "my/path")
  728. ... return my_checkpoint_path
  729. ... def load_checkpoint(self, my_checkpoint_path):
  730. ... print(my_checkpoint_path)
  731. >>> trainer = Example()
  732. >>> # This is used when PAUSED.
  733. >>> checkpoint_result = trainer.save() # doctest: +SKIP
  734. >>> trainer.restore(checkpoint_result) # doctest: +SKIP
  735. If `Trainable.save_checkpoint` returned a dict, then Tune will directly pass
  736. the dict data as the argument to this method.
  737. Example:
  738. >>> from ray.tune.trainable import Trainable
  739. >>> class Example(Trainable):
  740. ... def save_checkpoint(self, checkpoint_path):
  741. ... return {"my_data": 1}
  742. ... def load_checkpoint(self, checkpoint_dict):
  743. ... print(checkpoint_dict["my_data"])
  744. .. versionadded:: 0.8.7
  745. Args:
  746. checkpoint: If dict, the return value is as
  747. returned by ``save_checkpoint``. Otherwise, the directory
  748. the checkpoint was stored in.
  749. """
  750. raise NotImplementedError
  751. def setup(self, config: Dict):
  752. """Subclasses should override this for custom initialization.
  753. .. versionadded:: 0.8.7
  754. Args:
  755. config: Hyperparameters and other configs given.
  756. Copy of `self.config`.
  757. """
  758. pass
  759. def log_result(self, result: Dict):
  760. """Subclasses can optionally override this to customize logging.
  761. The logging here is done on the worker process rather than
  762. the driver.
  763. .. versionadded:: 0.8.7
  764. Args:
  765. result: Training result returned by step().
  766. """
  767. self._result_logger.on_result(result)
  768. def cleanup(self):
  769. """Subclasses should override this for any cleanup on stop.
  770. If any Ray actors are launched in the Trainable (i.e., with a RLlib
  771. trainer), be sure to kill the Ray actor process here.
  772. This process should be lightweight. Per default,
  773. You can kill a Ray actor by calling `ray.kill(actor)`
  774. on the actor or removing all references to it and waiting for garbage
  775. collection
  776. .. versionadded:: 0.8.7
  777. """
  778. pass
  779. def _export_model(self, export_formats: List[str], export_dir: str):
  780. """Subclasses should override this to export model.
  781. Args:
  782. export_formats: List of formats that should be exported.
  783. export_dir: Directory to place exported models.
  784. Return:
  785. A dict that maps ExportFormats to successfully exported models.
  786. """
  787. return {}
  788. def _implements_method(self, key):
  789. return hasattr(self, key) and callable(getattr(self, key))