function_trainable.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import inspect
  2. import logging
  3. import os
  4. import queue
  5. from functools import partial
  6. from numbers import Number
  7. from typing import Any, Callable, Dict, Optional, Type
  8. from ray.air._internal.util import RunnerThread, StartTraceback
  9. from ray.air.constants import _ERROR_FETCH_TIMEOUT
  10. from ray.train._internal.checkpoint_manager import _TrainingResult
  11. from ray.train._internal.session import (
  12. TrialInfo,
  13. _TrainSession,
  14. get_session,
  15. init_session,
  16. shutdown_session,
  17. )
  18. from ray.tune.execution.placement_groups import PlacementGroupFactory
  19. from ray.tune.result import DEFAULT_METRIC, RESULT_DUPLICATE, SHOULD_CHECKPOINT
  20. from ray.tune.trainable.trainable import Trainable
  21. from ray.tune.utils import _detect_config_single
  22. from ray.util.annotations import DeveloperAPI
  23. logger = logging.getLogger(__name__)
  24. # Time between FunctionTrainable checks when fetching
  25. # new results after signaling the reporter to continue
  26. NULL_MARKER = ".null_marker"
  27. TEMP_MARKER = ".temp_marker"
  28. @DeveloperAPI
  29. class FunctionTrainable(Trainable):
  30. """Trainable that runs a user function reporting results.
  31. This mode of execution does not support checkpoint/restore."""
  32. _name = "func"
  33. def setup(self, config):
  34. init_session(
  35. training_func=lambda: self._trainable_func(self.config),
  36. trial_info=TrialInfo(
  37. name=self.trial_name,
  38. id=self.trial_id,
  39. resources=self.trial_resources,
  40. logdir=self._storage.trial_driver_staging_path,
  41. driver_ip=None,
  42. driver_node_id=None,
  43. experiment_name=self._storage.experiment_dir_name,
  44. ),
  45. storage=self._storage,
  46. synchronous_result_reporting=True,
  47. # Set all Train-specific properties to None.
  48. world_rank=None,
  49. local_rank=None,
  50. node_rank=None,
  51. local_world_size=None,
  52. world_size=None,
  53. dataset_shard=None,
  54. checkpoint=None,
  55. )
  56. self._last_training_result: Optional[_TrainingResult] = None
  57. def _trainable_func(self, config: Dict[str, Any]):
  58. """Subclasses can override this to set the trainable func."""
  59. raise NotImplementedError
  60. def _start(self):
  61. def entrypoint():
  62. try:
  63. return self._trainable_func(self.config)
  64. except Exception as e:
  65. raise StartTraceback from e
  66. # the runner thread is not started until the first call to _train
  67. self._runner = RunnerThread(
  68. target=entrypoint, error_queue=self._error_queue, daemon=True
  69. )
  70. # if not alive, try to start
  71. self._status_reporter._start()
  72. try:
  73. self._runner.start()
  74. except RuntimeError:
  75. # If this is reached, it means the thread was started and is
  76. # now done or has raised an exception.
  77. pass
  78. def step(self):
  79. """Implements train() for a Function API.
  80. If the RunnerThread finishes without reporting "done",
  81. Tune will automatically provide a magic keyword __duplicate__
  82. along with a result with "done=True". The TrialRunner will handle the
  83. result accordingly (see tune/tune_controller.py).
  84. """
  85. session: _TrainSession = get_session()
  86. if not session.training_started:
  87. session.start()
  88. training_result: Optional[_TrainingResult] = session.get_next()
  89. if not training_result:
  90. # The `RESULT_DUPLICATE` result should have been the last
  91. # result reported by the session, which triggers cleanup.
  92. raise RuntimeError(
  93. "Should not have reached here. The TuneController should not "
  94. "have scheduled another `train` remote call."
  95. "It should have scheduled a `stop` instead "
  96. "after the training function exits."
  97. )
  98. metrics = training_result.metrics
  99. # This keyword appears if the train_func using the Function API
  100. # finishes without "done=True". This duplicates the last result, but
  101. # the TuneController will not log this result again.
  102. # TuneController will also inject done=True to the result,
  103. # and proceed to queue up a STOP decision for the trial.
  104. if RESULT_DUPLICATE in metrics:
  105. metrics[SHOULD_CHECKPOINT] = False
  106. self._last_training_result = training_result
  107. if training_result.checkpoint is not None:
  108. # TODO(justinvyu): Result/checkpoint reporting can be combined.
  109. # For now, since result/checkpoint reporting is separate, this
  110. # special key will tell Tune to pull the checkpoint from
  111. # the `last_training_result`.
  112. metrics[SHOULD_CHECKPOINT] = True
  113. return metrics
  114. def execute(self, fn):
  115. return fn(self)
  116. def save_checkpoint(self, checkpoint_dir: str = ""):
  117. if checkpoint_dir:
  118. raise ValueError("Checkpoint dir should not be used with function API.")
  119. # TODO(justinvyu): This currently breaks the `save_checkpoint` interface.
  120. # TRAIN -> SAVE remote calls get processed sequentially,
  121. # so `_last_training_result.checkpoint` holds onto the latest ckpt.
  122. return self._last_training_result
  123. def load_checkpoint(self, checkpoint_result: _TrainingResult):
  124. # TODO(justinvyu): This currently breaks the `load_checkpoint` interface.
  125. session = get_session()
  126. session.loaded_checkpoint = checkpoint_result.checkpoint
  127. def cleanup(self):
  128. session = get_session()
  129. try:
  130. # session.finish raises any Exceptions from training.
  131. # Do not wait for thread termination here (timeout=0).
  132. session.finish(timeout=0)
  133. finally:
  134. # Check for any errors that might have been missed.
  135. session._report_thread_runner_error()
  136. # Shutdown session even if session.finish() raises an Exception.
  137. shutdown_session()
  138. def reset_config(self, new_config):
  139. session = get_session()
  140. # Wait for thread termination so it is save to re-use the same actor.
  141. thread_timeout = int(os.environ.get("TUNE_FUNCTION_THREAD_TIMEOUT_S", 2))
  142. session.finish(timeout=thread_timeout)
  143. if session.training_thread.is_alive():
  144. # Did not finish within timeout, reset unsuccessful.
  145. return False
  146. session.reset(
  147. training_func=lambda: self._trainable_func(self.config),
  148. trial_info=TrialInfo(
  149. name=self.trial_name,
  150. id=self.trial_id,
  151. resources=self.trial_resources,
  152. logdir=self._storage.trial_working_directory,
  153. driver_ip=None,
  154. driver_node_id=None,
  155. experiment_name=self._storage.experiment_dir_name,
  156. ),
  157. storage=self._storage,
  158. )
  159. self._last_result = {}
  160. return True
  161. def _report_thread_runner_error(self, block=False):
  162. try:
  163. e = self._error_queue.get(block=block, timeout=_ERROR_FETCH_TIMEOUT)
  164. raise StartTraceback from e
  165. except queue.Empty:
  166. pass
  167. @DeveloperAPI
  168. def wrap_function(
  169. train_func: Callable[[Any], Any], name: Optional[str] = None
  170. ) -> Type["FunctionTrainable"]:
  171. inherit_from = (FunctionTrainable,)
  172. if hasattr(train_func, "__mixins__"):
  173. inherit_from = train_func.__mixins__ + inherit_from
  174. func_args = inspect.getfullargspec(train_func).args
  175. use_config_single = _detect_config_single(train_func)
  176. if not use_config_single:
  177. raise ValueError(
  178. "Unknown argument found in the Trainable function. "
  179. "The function args must include a single 'config' positional parameter.\n"
  180. "Found: {}".format(func_args)
  181. )
  182. resources = getattr(train_func, "_resources", None)
  183. class ImplicitFunc(*inherit_from):
  184. _name = name or (
  185. train_func.__name__ if hasattr(train_func, "__name__") else "func"
  186. )
  187. def __repr__(self):
  188. return self._name
  189. def _trainable_func(self, config):
  190. fn = partial(train_func, config)
  191. def handle_output(output):
  192. if not output:
  193. return
  194. elif isinstance(output, dict):
  195. get_session().report(output)
  196. elif isinstance(output, Number):
  197. get_session().report({DEFAULT_METRIC: output})
  198. else:
  199. raise ValueError(
  200. "Invalid return or yield value. Either return/yield "
  201. "a single number or a dictionary object in your "
  202. "trainable function."
  203. )
  204. output = None
  205. if inspect.isgeneratorfunction(train_func):
  206. for output in fn():
  207. handle_output(output)
  208. else:
  209. output = fn()
  210. handle_output(output)
  211. # If train_func returns, we need to notify the main event loop
  212. # of the last result while avoiding double logging. This is done
  213. # with the keyword RESULT_DUPLICATE -- see tune/tune_controller.py.
  214. get_session().report({RESULT_DUPLICATE: True})
  215. return output
  216. @classmethod
  217. def default_resource_request(
  218. cls, config: Dict[str, Any]
  219. ) -> Optional[PlacementGroupFactory]:
  220. if not isinstance(resources, PlacementGroupFactory) and callable(resources):
  221. return resources(config)
  222. return resources
  223. return ImplicitFunc