trial.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073
  1. import copy
  2. import json
  3. import logging
  4. import os
  5. import platform
  6. import re
  7. import time
  8. import uuid
  9. from contextlib import contextmanager
  10. from functools import partial
  11. from numbers import Number
  12. from pathlib import Path
  13. from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
  14. import ray
  15. import ray.cloudpickle as cloudpickle
  16. from ray._common.utils import binary_to_hex, hex_to_binary
  17. from ray.air.constants import (
  18. EXPR_ERROR_FILE,
  19. EXPR_ERROR_PICKLE_FILE,
  20. TRAINING_ITERATION,
  21. )
  22. from ray.exceptions import RayActorError, RayTaskError
  23. from ray.train._internal.checkpoint_manager import _CheckpointManager
  24. from ray.train._internal.session import _FutureTrainingResult, _TrainingResult
  25. from ray.train._internal.storage import StorageContext, _exists_at_fs_path
  26. from ray.train.constants import (
  27. RAY_CHDIR_TO_TRIAL_DIR,
  28. RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE,
  29. )
  30. from ray.tune import Checkpoint, CheckpointConfig
  31. from ray.tune.error import TuneError
  32. from ray.tune.execution.placement_groups import (
  33. PlacementGroupFactory,
  34. resource_dict_to_pg_factory,
  35. )
  36. from ray.tune.logger import NoopLogger
  37. # NOTE(rkn): We import ray.tune.registry here instead of importing the names we
  38. # need because there are cyclic imports that may cause specific names to not
  39. # have been defined yet. See https://github.com/ray-project/ray/issues/1716.
  40. from ray.tune.registry import get_trainable_cls, validate_trainable
  41. from ray.tune.result import (
  42. DEBUG_METRICS,
  43. DONE,
  44. NODE_IP,
  45. PID,
  46. STDERR_FILE,
  47. STDOUT_FILE,
  48. TRIAL_ID,
  49. TRIAL_INFO,
  50. )
  51. from ray.tune.trainable.metadata import _TrainingRunMetadata
  52. from ray.tune.utils import date_str, flatten_dict
  53. from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder
  54. from ray.util import log_once
  55. from ray.util.annotations import Deprecated, DeveloperAPI
# NOTE(review): not referenced in the visible part of this module; presumably
# an interval for debug printing -- confirm units and usage with the consumer.
DEBUG_PRINT_INTERVAL = 5
# Fallback returned by `_get_max_path_length()` when `os.pathconf` is not
# available (the Windows case).
_DEFAULT_WIN_MAX_PATH_LENGTH = 260
# NOTE(review): presumably the file name under which trial metadata is
# persisted; the code that reads/writes it is outside the visible part.
TRIAL_STATE_FILENAME = "trial_metadata.json"
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
  60. class _Location:
  61. """Describes the location at which Trial is placed to run."""
  62. def __init__(self, hostname=None, pid=None):
  63. self.hostname = hostname
  64. self.pid = pid
  65. def __str__(self):
  66. if not self.pid:
  67. return ""
  68. elif self.hostname == platform.node():
  69. return "pid={}".format(self.pid)
  70. else:
  71. return "{}:{}".format(self.hostname, self.pid)
  72. @DeveloperAPI
  73. class ExportFormat:
  74. """Describes the format to import/export the trial Trainable.
  75. This may correspond to different file formats based on the
  76. Trainable implementation.
  77. """
  78. CHECKPOINT = "checkpoint"
  79. MODEL = "model"
  80. ONNX = "onnx"
  81. H5 = "h5"
  82. @staticmethod
  83. def validate(formats):
  84. """Validates formats.
  85. Raises:
  86. ValueError: if the format is unknown.
  87. """
  88. for i in range(len(formats)):
  89. formats[i] = formats[i].strip().lower()
  90. if formats[i] not in [
  91. ExportFormat.CHECKPOINT,
  92. ExportFormat.MODEL,
  93. ExportFormat.ONNX,
  94. ExportFormat.H5,
  95. ]:
  96. raise TuneError("Unsupported import/export format: " + formats[i])
  97. class _TrialInfo:
  98. """Serializable struct for holding information for a Trial.
  99. Attributes:
  100. trial_name: String name of the current trial.
  101. trial_id: trial_id of the trial
  102. trial_resources: resources used by trial.
  103. """
  104. def __init__(self, trial: "Trial"):
  105. self._trial_name = str(trial)
  106. self._trial_id = trial.trial_id
  107. self._trial_resources = trial.placement_group_factory
  108. self._experiment_name = trial.experiment_dir_name
  109. @property
  110. def experiment_name(self):
  111. return self._experiment_name
  112. @property
  113. def trial_name(self):
  114. return self._trial_name
  115. @property
  116. def trial_id(self):
  117. return self._trial_id
  118. @property
  119. def trial_resources(self) -> PlacementGroupFactory:
  120. return self._trial_resources
  121. @trial_resources.setter
  122. def trial_resources(self, new_resources: PlacementGroupFactory):
  123. self._trial_resources = new_resources
  124. class _TemporaryTrialState:
  125. """Temporary trial state.
  126. Values saved here should not be restored on resume.
  127. """
  128. def __init__(self):
  129. self.location = _Location()
  130. self.ray_actor: Optional[ray.actor.ActorHandle] = None
  131. self.saving_to: Optional[_FutureTrainingResult] = None
  132. self.restoring_from: Optional[_TrainingResult] = None
  133. self.num_restore_failures: int = 0
  134. def __getstate__(self):
  135. return {}
  136. def _get_max_path_length() -> int:
  137. if hasattr(os, "pathconf"):
  138. return os.pathconf("/", "PC_PATH_MAX")
  139. # Windows
  140. return _DEFAULT_WIN_MAX_PATH_LENGTH
  141. def _create_unique_logdir_name(root: str, relative_logdir: str) -> str:
  142. candidate = Path(root).expanduser().joinpath(relative_logdir)
  143. if candidate.exists():
  144. relative_logdir_old = relative_logdir
  145. relative_logdir += "_" + uuid.uuid4().hex[:4]
  146. logger.info(
  147. f"Creating a new dirname {relative_logdir} because "
  148. f"trial dirname '{relative_logdir_old}' already exists."
  149. )
  150. return relative_logdir
  151. def _noop_logger_creator(config: Dict[str, Any], logdir: str):
  152. # Upon remote process setup, record the actor's original working dir before
  153. # changing to the Tune logdir
  154. os.environ.setdefault("TUNE_ORIG_WORKING_DIR", os.getcwd())
  155. os.makedirs(logdir, exist_ok=True)
  156. if bool(int(os.environ.get(RAY_CHDIR_TO_TRIAL_DIR, "1"))):
  157. # Set the working dir to the trial directory in the remote process,
  158. # for user file writes
  159. if not ray._private.worker._mode() == ray._private.worker.LOCAL_MODE:
  160. os.chdir(logdir)
  161. return NoopLogger(config, logdir)
  162. def _get_trainable_kwargs(trial: "Trial") -> Dict[str, Any]:
  163. trial.init_local_path()
  164. logger_creator = partial(
  165. _noop_logger_creator, logdir=trial.storage.trial_working_directory
  166. )
  167. trial_config = copy.deepcopy(trial.config)
  168. trial_config[TRIAL_INFO] = _TrialInfo(trial)
  169. stdout_file, stderr_file = trial.log_to_file
  170. trial_config[STDOUT_FILE] = stdout_file
  171. trial_config[STDERR_FILE] = stderr_file
  172. assert trial.storage.trial_dir_name
  173. kwargs = {
  174. "config": trial_config,
  175. "logger_creator": logger_creator,
  176. "storage": trial.storage,
  177. }
  178. return kwargs
  179. @contextmanager
  180. def _change_working_directory(trial):
  181. """Context manager changing working directory to trial logdir.
  182. Used in local mode.
  183. For non-local mode it is no-op.
  184. """
  185. if ray._private.worker._mode() == ray._private.worker.LOCAL_MODE:
  186. old_dir = os.getcwd()
  187. try:
  188. os.chdir(trial.local_path)
  189. yield
  190. finally:
  191. os.chdir(old_dir)
  192. else:
  193. yield
  194. @DeveloperAPI
  195. class Trial:
  196. """A trial object holds the state for one model training run.
  197. Trials are themselves managed by the TrialRunner class, which implements
  198. the event loop for submitting trial runs to a Ray cluster.
  199. Trials start in the PENDING state, and transition to RUNNING once started.
  200. On error, it transitions to ERROR, otherwise TERMINATED on success.
  201. There are resources allocated to each trial. These should be specified
  202. using ``PlacementGroupFactory``.
  203. Attributes:
  204. trainable_name: Name of the trainable object to be executed.
  205. config: Provided configuration dictionary with evaluated params.
  206. trial_id: Unique identifier for the trial.
  207. path: Path where results for this trial are stored. Can be on
  208. the local node or on cloud storage.
  209. local_path: Path on the local disk where results are stored.
  210. remote_path: Path on cloud storage where results are stored,
  211. or None if not set.
  212. relative_logdir: Directory of the trial relative to its
  213. experiment directory.
  214. evaluated_params: Evaluated parameters by search algorithm.
  215. experiment_tag: Identifying trial name to show in the console
  216. status: One of PENDING, RUNNING, PAUSED, TERMINATED, ERROR.
  217. error_file: Path to the errors that this trial has raised.
  218. """
  219. _nonjson_fields = [
  220. "results",
  221. "extra_arg",
  222. "placement_group_factory",
  223. "_resources",
  224. "_default_placement_group_factory",
  225. ]
  226. PENDING = "PENDING"
  227. RUNNING = "RUNNING"
  228. PAUSED = "PAUSED"
  229. TERMINATED = "TERMINATED"
  230. ERROR = "ERROR"
  231. def __init__(
  232. self,
  233. trainable_name: str,
  234. *,
  235. config: Optional[Dict] = None,
  236. trial_id: Optional[str] = None,
  237. storage: Optional[StorageContext] = None,
  238. evaluated_params: Optional[Dict] = None,
  239. experiment_tag: str = "",
  240. placement_group_factory: Optional[PlacementGroupFactory] = None,
  241. stopping_criterion: Optional[Dict[str, float]] = None,
  242. checkpoint_config: Optional[CheckpointConfig] = None,
  243. export_formats: Optional[List[str]] = None,
  244. restore_path: Optional[str] = None,
  245. trial_name_creator: Optional[Callable[["Trial"], str]] = None,
  246. trial_dirname_creator: Optional[Callable[["Trial"], str]] = None,
  247. log_to_file: Union[Optional[str], Tuple[Optional[str], Optional[str]]] = None,
  248. max_failures: int = 0,
  249. stub: bool = False,
  250. _setup_default_resource: bool = True,
  251. ):
  252. """Initialize a new trial.
  253. The args here take the same meaning as the command line flags defined
  254. in ray.tune.experiment.config_parser.
  255. Args:
  256. _setup_default_resource: Whether to set up default resources.
  257. When initializing trials from checkpoints, this field is set to false,
  258. so that setting up default resources can be delayed till after
  259. ``trial.config`` is loaded from checkpoints.
  260. """
  261. # If this is set, trainables are not validated or looked up.
  262. # This can be used e.g. to initialize Trial objects from checkpoints
  263. # without loading the trainable first.
  264. self.stub = stub
  265. if not self.stub:
  266. validate_trainable(trainable_name)
  267. # Trial config
  268. self.trainable_name = trainable_name
  269. self.trial_id = Trial.generate_id() if trial_id is None else trial_id
  270. self.temporary_state = _TemporaryTrialState()
  271. self.run_metadata = _TrainingRunMetadata()
  272. # Create a copy, since `init_local_path` updates the context with the
  273. # generated trial dirname.
  274. self.storage = copy.copy(storage)
  275. self.config = config or {}
  276. # Save a copy of the original unresolved config so that we can swap
  277. # out and update any reference config values after restoration.
  278. self.__unresolved_config = self.config
  279. # Parameters that Tune varies across searches.
  280. self.evaluated_params = evaluated_params or {}
  281. self.experiment_tag = experiment_tag
  282. self.stopping_criterion = stopping_criterion or {}
  283. self._setup_default_resource = _setup_default_resource
  284. if placement_group_factory and not isinstance(
  285. placement_group_factory, PlacementGroupFactory
  286. ):
  287. placement_group_factory = resource_dict_to_pg_factory(
  288. placement_group_factory
  289. )
  290. self._default_placement_group_factory = placement_group_factory
  291. # Will be created in create_placement_group_factory().
  292. self.placement_group_factory = None
  293. self.log_to_file = log_to_file
  294. # Make sure `stdout_file, stderr_file = Trial.log_to_file` works
  295. if (
  296. not self.log_to_file
  297. or not isinstance(self.log_to_file, Sequence)
  298. or not len(self.log_to_file) == 2
  299. ):
  300. self.log_to_file = (None, None)
  301. self.max_failures = max_failures
  302. # Local trial state that is updated during the run
  303. self._default_result_or_future: Union[ray.ObjectRef, dict, None] = None
  304. self.export_formats = export_formats
  305. self.status = Trial.PENDING
  306. self.relative_logdir = None
  307. self.trial_name_creator = trial_name_creator
  308. self.trial_dirname_creator = trial_dirname_creator
  309. self.custom_trial_name = None
  310. self.custom_dirname = None
  311. # Checkpoint config
  312. checkpoint_config = checkpoint_config or CheckpointConfig()
  313. self.run_metadata.checkpoint_manager = _CheckpointManager(
  314. checkpoint_config=checkpoint_config
  315. )
  316. # Restoration fields
  317. self.restore_path = restore_path
  318. self._restore_checkpoint_result: Optional[_TrainingResult] = None
  319. if restore_path:
  320. # tune.run(restore) passes in a path without metrics.
  321. self._restore_checkpoint_result = _TrainingResult(
  322. checkpoint=Checkpoint.from_directory(restore_path), metrics={}
  323. )
  324. if trial_name_creator:
  325. self.custom_trial_name = trial_name_creator(self)
  326. if trial_dirname_creator:
  327. self.custom_dirname = trial_dirname_creator(self)
  328. if os.path.sep in self.custom_dirname:
  329. raise ValueError(
  330. f"Trial dirname must not contain '/'. Got {self.custom_dirname}"
  331. )
  332. self._state_json = None
  333. def create_placement_group_factory(self):
  334. """Compute placement group factory if needed.
  335. Note: this must be called after all the placeholders in
  336. self.config are resolved.
  337. """
  338. trainable_cls = self.get_trainable_cls()
  339. if not trainable_cls or not self._setup_default_resource:
  340. # Create placement group factory using default resources.
  341. self.placement_group_factory = (
  342. self._default_placement_group_factory or resource_dict_to_pg_factory()
  343. )
  344. return
  345. default_resources = trainable_cls.default_resource_request(self.config)
  346. # If Trainable returns resources, do not allow manual override via
  347. # `resources_per_trial` by the user.
  348. if default_resources and self._default_placement_group_factory:
  349. raise TuneError(
  350. "Resources for {} have been automatically set to {} "
  351. "by its `default_resource_request()` method. Please "
  352. "clear the `resources_per_trial` option.".format(
  353. trainable_cls, default_resources
  354. )
  355. )
  356. if default_resources and not isinstance(
  357. default_resources, PlacementGroupFactory
  358. ):
  359. default_resources = resource_dict_to_pg_factory(default_resources)
  360. self.placement_group_factory = (
  361. # default_resource_request
  362. default_resources
  363. # resources_per_trial
  364. or self._default_placement_group_factory
  365. # cpu=1
  366. or resource_dict_to_pg_factory()
  367. )
  368. def _get_default_result_or_future(self) -> Optional[dict]:
  369. """Calls ray.get on self._default_result_or_future and assigns back.
  370. Returns None in case of exceptions.
  371. Will also set the trial location if runner is set.
  372. """
  373. if self._default_result_or_future and isinstance(
  374. self._default_result_or_future, ray.ObjectRef
  375. ):
  376. try:
  377. self._default_result_or_future = ray.get(self._default_result_or_future)
  378. except RayActorError: # error during initialization
  379. self._default_result_or_future = None
  380. if self._default_result_or_future and self.temporary_state.ray_actor:
  381. self.set_location(
  382. _Location(
  383. self._default_result_or_future.get(NODE_IP),
  384. self._default_result_or_future.get(PID),
  385. )
  386. )
  387. return self._default_result_or_future
  388. def resolve_config_placeholders(self, placeholder_resolvers: Dict[Tuple, Any]):
  389. from ray.tune.impl.placeholder import resolve_placeholders
  390. # Make a copy of the unresolved config before resolve it.
  391. self.config = copy.deepcopy(self.__unresolved_config)
  392. resolve_placeholders(self.config, placeholder_resolvers)
  393. @property
  394. def last_result(self) -> dict:
  395. # The logic in here is as follows:
  396. # 1. If the trial has reported at least once, last_result would have
  397. # been set and therefore would not be empty. We can just return it.
  398. # 2. If the trial has not reported at least once but we have the
  399. # future for the default results dict, (obtained through
  400. # Trainable.get_auto_filled_metrics), we get that future
  401. # and return it.
  402. # 3. In the worst case where we have nothing, we just set the
  403. # trial_id and return that.
  404. result = self.run_metadata.last_result
  405. if not {k for k in result if k != TRIAL_ID}:
  406. self._get_default_result_or_future()
  407. result = self._default_result_or_future or result
  408. result.setdefault(TRIAL_ID, self.trial_id)
  409. return result
  410. @property
  411. def metric_analysis(self):
  412. return self.run_metadata.metric_analysis
  413. @property
  414. def metric_n_steps(self):
  415. return self.run_metadata.metric_n_steps
  416. def get_ray_actor_ip(self) -> Optional[str]:
  417. if self.temporary_state.location.hostname:
  418. return self.temporary_state.location.hostname
  419. if not self.temporary_state.ray_actor:
  420. return None
  421. hostname, pid = ray.get(
  422. self.temporary_state.ray_actor.get_current_ip_pid.remote()
  423. )
  424. self.temporary_state.location = _Location(hostname, pid)
  425. return self.temporary_state.location.hostname
  426. @property
  427. @Deprecated("Replaced by `local_experiment_path`")
  428. def local_dir(self):
  429. return self.local_experiment_path
  430. @property
  431. def experiment_dir_name(self):
  432. return self.storage.experiment_dir_name
  433. @property
  434. def remote_experiment_path(self) -> str:
  435. return self.storage.experiment_fs_path
  436. @property
  437. def local_experiment_path(self) -> str:
  438. return self.storage.experiment_driver_staging_path
  439. @property
  440. @Deprecated("Replaced by `local_path`")
  441. def logdir(self) -> Optional[str]:
  442. # TODO(justinvyu): [Deprecated] Remove in 2.11.
  443. raise DeprecationWarning("Use `local_path` instead of `logdir`.")
  444. @property
  445. def local_path(self) -> Optional[str]:
  446. return self.storage.trial_driver_staging_path
  447. @property
  448. def path(self) -> Optional[str]:
  449. return self.storage.trial_fs_path
  450. @property
  451. def has_reported_at_least_once(self) -> bool:
  452. return bool(self.run_metadata.last_result)
  453. @property
  454. def node_ip(self):
  455. return self.temporary_state.location.hostname
  456. @property
  457. def checkpoint_at_end(self):
  458. config = self.run_metadata.checkpoint_manager.checkpoint_config
  459. return config.checkpoint_at_end
  460. @property
  461. def checkpoint_freq(self):
  462. config = self.run_metadata.checkpoint_manager.checkpoint_config
  463. return config.checkpoint_frequency
  464. @property
  465. def latest_checkpoint_result(self) -> Optional[_TrainingResult]:
  466. # NOTE: Fallback to the checkpoint passed in from `tune.run(restore)`
  467. # if the trial hasn't saved any checkpoints itself yet.
  468. return (
  469. self.run_metadata.checkpoint_manager.latest_checkpoint_result
  470. or self._restore_checkpoint_result
  471. )
  472. @property
  473. def checkpoint(self) -> Optional[Checkpoint]:
  474. """Returns the most recent checkpoint if one has been saved."""
  475. return (
  476. self.latest_checkpoint_result.checkpoint
  477. if self.latest_checkpoint_result
  478. else None
  479. )
  480. @classmethod
  481. def generate_id(cls):
  482. return str(uuid.uuid4().hex)[:8]
  483. def reset(self) -> "Trial":
  484. # If there is `default_resource_request` associated with the trainable,
  485. # clear `resources` and `placement_group_factory`.
  486. # This is mainly relevant for RLlib tuning jobs, where we save users
  487. # of the trouble to specify the resources themselves by having some
  488. # default resources for popular RLlib algorithms.
  489. trainable_cls = self.get_trainable_cls()
  490. clear_resources = trainable_cls and trainable_cls.default_resource_request(
  491. self.config
  492. )
  493. placement_group_factory = (
  494. self.placement_group_factory if not clear_resources else None
  495. )
  496. checkpoint_config = self.run_metadata.checkpoint_manager.checkpoint_config
  497. return Trial(
  498. self.trainable_name,
  499. config=self.config,
  500. trial_id=None,
  501. evaluated_params=self.evaluated_params,
  502. experiment_tag=self.experiment_tag,
  503. placement_group_factory=placement_group_factory,
  504. stopping_criterion=self.stopping_criterion,
  505. checkpoint_config=checkpoint_config,
  506. export_formats=self.export_formats,
  507. restore_path=self.restore_path,
  508. trial_name_creator=self.trial_name_creator,
  509. trial_dirname_creator=self.trial_dirname_creator,
  510. log_to_file=self.log_to_file,
  511. max_failures=self.max_failures,
  512. storage=self.storage,
  513. )
  514. @Deprecated("Replaced by `init_local_path()`")
  515. def init_logdir(self):
  516. # TODO(justinvyu): [Deprecated] Remove in 2.11.
  517. raise DeprecationWarning("Use `init_local_path` instead of `init_logdir`.")
  518. def init_local_path(self):
  519. """Init logdir."""
  520. if not self.relative_logdir:
  521. self.relative_logdir = _create_unique_logdir_name(
  522. str(self.local_experiment_path), self._generate_dirname()
  523. )
  524. # Populate the storage context with the trial dir name we just generated.
  525. self.storage.trial_dir_name = self.relative_logdir
  526. assert self.local_path
  527. logdir_path = Path(self.local_path)
  528. max_path_length = _get_max_path_length()
  529. if len(str(logdir_path)) >= max_path_length:
  530. logger.warning(
  531. f"The path to the trial log directory is too long "
  532. f"(max length: {max_path_length}. "
  533. f"Consider using `trial_dirname_creator` to shorten the path. "
  534. f"Path: {logdir_path}"
  535. )
  536. logdir_path.mkdir(parents=True, exist_ok=True)
  537. self.invalidate_json_state()
  538. def update_resources(self, resources: Union[dict, PlacementGroupFactory]):
  539. """EXPERIMENTAL: Updates the resource requirements.
  540. Should only be called when the trial is not running.
  541. Raises:
  542. ValueError: if trial status is running.
  543. """
  544. if self.status is Trial.RUNNING:
  545. raise ValueError("Cannot update resources while Trial is running.")
  546. placement_group_factory = resources
  547. if isinstance(resources, dict):
  548. placement_group_factory = resource_dict_to_pg_factory(resources)
  549. self.placement_group_factory = placement_group_factory
  550. self.invalidate_json_state()
  551. def set_ray_actor(self, ray_actor):
  552. self.temporary_state.ray_actor = ray_actor
  553. if ray_actor:
  554. # Do not block here, the result will be gotten when last_result
  555. # property is accessed
  556. self._default_result_or_future = ray_actor.get_auto_filled_metrics.remote(
  557. debug_metrics_only=True
  558. )
  559. def set_location(self, location):
  560. """Sets the location of the trial."""
  561. self.temporary_state.location = location
  562. def set_status(self, status):
  563. """Sets the status of the trial."""
  564. self.status = status
  565. if status == Trial.RUNNING:
  566. if self.run_metadata.start_time is None:
  567. self.run_metadata.start_time = time.time()
  568. self.invalidate_json_state()
  569. def set_config(self, config):
  570. self.config = config
  571. self.invalidate_json_state()
  572. def set_experiment_tag(self, experiment_tag):
  573. self.experiment_tag = experiment_tag
  574. self.invalidate_json_state()
  575. def set_storage(self, new_storage: StorageContext):
  576. """Updates the storage context of the trial.
  577. If the `storage_path` or `experiment_dir_name` has changed, then this setter
  578. also updates the paths of all checkpoints tracked by the checkpoint manager.
  579. This enables restoration from a checkpoint if the user moves the directory.
  580. """
  581. original_storage = self.storage
  582. checkpoint_manager = self.run_metadata.checkpoint_manager
  583. for checkpoint_result in checkpoint_manager.best_checkpoint_results:
  584. checkpoint_result.checkpoint = Checkpoint(
  585. path=checkpoint_result.checkpoint.path.replace(
  586. original_storage.trial_fs_path, new_storage.trial_fs_path, 1
  587. ),
  588. filesystem=new_storage.storage_filesystem,
  589. )
  590. latest_checkpoint_result = checkpoint_manager.latest_checkpoint_result
  591. if latest_checkpoint_result:
  592. latest_checkpoint_result.checkpoint = Checkpoint(
  593. path=latest_checkpoint_result.checkpoint.path.replace(
  594. original_storage.trial_fs_path, new_storage.trial_fs_path, 1
  595. ),
  596. filesystem=new_storage.storage_filesystem,
  597. )
  598. self.storage = new_storage
  599. self.invalidate_json_state()
  600. @property
  601. def num_failures(self):
  602. return self.run_metadata.num_failures
  603. @property
  604. def num_failures_after_restore(self):
  605. return self.run_metadata.num_failures_after_restore
  606. @property
  607. def error_file(self):
  608. if not self.local_path or not self.run_metadata.error_filename:
  609. return None
  610. return Path(self.local_path, self.run_metadata.error_filename).as_posix()
  611. @property
  612. def pickled_error_file(self):
  613. if not self.local_path or not self.run_metadata.pickled_error_filename:
  614. return None
  615. return Path(
  616. self.local_path, self.run_metadata.pickled_error_filename
  617. ).as_posix()
  618. def get_pickled_error(self) -> Optional[Exception]:
  619. """Returns the pickled error object if it exists in storage.
  620. This is a pickled version of the latest error that the trial encountered.
  621. """
  622. error_filename = self.run_metadata.pickled_error_filename
  623. if error_filename is None:
  624. return None
  625. fs = self.storage.storage_filesystem
  626. pickled_error_fs_path = Path(
  627. self.storage.trial_fs_path, error_filename
  628. ).as_posix()
  629. if _exists_at_fs_path(fs=fs, fs_path=pickled_error_fs_path):
  630. with fs.open_input_stream(pickled_error_fs_path) as f:
  631. return cloudpickle.loads(f.readall())
  632. return None
  633. def get_error(self) -> Optional[TuneError]:
  634. """Returns the error text file trace as a TuneError object
  635. if it exists in storage.
  636. This is a text trace of the latest error that the trial encountered,
  637. which is used in the case that the error is not picklable.
  638. """
  639. error_filename = self.run_metadata.error_filename
  640. if error_filename is None:
  641. return None
  642. fs = self.storage.storage_filesystem
  643. txt_error_fs_path = Path(self.storage.trial_fs_path, error_filename).as_posix()
  644. if _exists_at_fs_path(fs=fs, fs_path=txt_error_fs_path):
  645. with fs.open_input_stream(txt_error_fs_path) as f:
  646. return f.readall().decode()
  647. return None
  648. def _handle_restore_error(self, exc: Exception):
  649. # For Restoration errors, we only increment the restore failure count
  650. # if the number of failures exceeds the restore retry limit.
  651. if self.temporary_state.num_restore_failures >= int(
  652. os.environ.get("TUNE_RESTORE_RETRY_NUM", 0)
  653. ):
  654. self.run_metadata.num_failures += 1
  655. else:
  656. self.temporary_state.num_restore_failures += 1
  657. def _handle_ray_actor_error(self, exc: RayActorError):
  658. count_preemption_errors = bool(
  659. int(os.environ.get(RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE, "0"))
  660. )
  661. if not exc.preempted or count_preemption_errors:
  662. # Only count non-preempted actor errors as failures.
  663. self.run_metadata.num_failures += 1
  664. def _handle_ray_task_error(self, exc: RayTaskError):
  665. cause = exc.as_instanceof_cause()
  666. if isinstance(cause, RayActorError):
  667. # Handle the RayActorError directly (ex: Ray Train worker actor errors)
  668. return self._handle_ray_actor_error(cause)
  669. # Increment failures for all user errors (which get raised as RayTaskError)
  670. self.run_metadata.num_failures += 1
  def handle_error(
      self, exc: Optional[Union[TuneError, RayTaskError, RayActorError]] = None
  ):
      """Update failure counters and persist the error for this trial.

      The failure accounting depends on the situation: errors raised while
      restoring go through ``_handle_restore_error``; actor and task errors
      go through their dedicated handlers; anything else counts as one
      failure. When the trial has a ``local_path``, a text record is appended
      to the error file and, for Ray task/actor errors, the exception object
      itself is pickled next to it.

      Args:
          exc: The error that occurred. May be ``None``.
      """
      # Step 1: failure accounting.
      if self.is_restoring:
          self._handle_restore_error(exc)
      elif isinstance(exc, RayActorError):
          self._handle_ray_actor_error(exc)
      elif isinstance(exc, RayTaskError):
          self._handle_ray_task_error(exc)
      else:
          self.run_metadata.num_failures += 1

      # Step 2: persist the error, only if there is a local directory.
      if self.local_path:
          self.run_metadata.error_filename = EXPR_ERROR_FILE

          is_ray_error = isinstance(exc, (RayTaskError, RayActorError))
          if is_ray_error:
              # Piping through the actual error to result grid.
              self.run_metadata.pickled_error_filename = EXPR_ERROR_PICKLE_FILE
              with open(self.pickled_error_file, "wb") as pickle_out:
                  cloudpickle.dump(exc, pickle_out)

          with open(self.error_file, "a+") as text_out:
              header = "Failure # {} (occurred at {})\n".format(
                  self.run_metadata.num_failures, date_str()
              )
              text_out.write(header)
              text_out.write(str(exc) + "\n")

      self.run_metadata.invalidate_cache()
  def should_stop(self, result):
      """Whether the given result meets this trial's stopping criteria.

      Args:
          result: Mapping of reported values. It is queried with ``.get``,
              ``in``, ``[]`` and ``.keys()``.

      Returns:
          ``True`` if the result is flagged as done or if any stopping
          criterion present in ``result`` has reached its threshold
          (``result[criterion] >= stop_value``); ``False`` otherwise.

      Raises:
          ValueError: If a stopping criterion is specified as a nested dict
              instead of a flattened ``key1/key2/key3`` key.
      """
      if result.get(DONE):
          return True

      for criterion, stop_value in self.stopping_criterion.items():
          # A nested specification such as ``{"a": {"b": 1}}`` shows up as a
          # dict *value*. Dict keys can never be dicts (they are unhashable),
          # so checking only the key could never detect this case.
          if isinstance(stop_value, dict) or isinstance(criterion, dict):
              raise ValueError(
                  "Stopping criteria is now flattened by default. "
                  "Use forward slashes to nest values `key1/key2/key3`."
              )
          elif criterion not in result:
              # Warn only once to avoid flooding the log on every result.
              if log_once("tune_trial_stop_criterion_not_found"):
                  logger.warning(
                      f"Stopping criterion '{criterion}' not found in result dict! "
                      f"Available keys are {list(result.keys())}. If '{criterion}' is"
                      " never reported, the run will continue until training is "
                      "finished."
                  )
          elif result[criterion] >= stop_value:
              return True
      return False
  def should_checkpoint(self):
      """Whether this trial is due for checkpointing.

      Returns:
          ``True`` when the last result is flagged as done and
          ``checkpoint_at_end`` is set. Otherwise, the falsy
          ``checkpoint_freq`` itself when periodic checkpointing is off, or
          whether the last reported training iteration (``0`` if missing) is
          a multiple of ``checkpoint_freq``.
      """
      latest = self.last_result or {}
      if latest.get(DONE) and self.checkpoint_at_end:
          return True

      freq = self.checkpoint_freq
      if not freq:
          # Hand back the falsy frequency value unchanged.
          return freq
      return latest.get(TRAINING_ITERATION, 0) % freq == 0
  def has_checkpoint(self) -> bool:
      """Return ``True`` if ``self.checkpoint`` is set (i.e. not ``None``)."""
      return not (self.checkpoint is None)
  def on_checkpoint(self, checkpoint_result: _TrainingResult):
      """Hook for handling checkpoints taken by the Trainable.

      Registers the checkpoint with the checkpoint manager, syncs the storage
      checkpoint index, and invalidates the cached serialized state.

      Args:
          checkpoint_result: The checkpoint result to register. Its
              ``metrics`` are used to update the checkpoint index.
      """
      manager = self.run_metadata.checkpoint_manager
      manager.register_checkpoint(checkpoint_result)

      # Keep the checkpoint index in sync. The index gets restored together
      # with the trial and is passed to the Trainable as the starting
      # checkpoint index.
      self.storage._update_checkpoint_index(checkpoint_result.metrics)

      # Cached serialized state is stale now.
      self.invalidate_json_state()
      self.run_metadata.invalidate_cache()
  def on_restore(self):
      """Handles restoration completion.

      Adopts the metrics of the checkpoint being restored as the last result
      (adding the trial config under ``"config"`` if absent), then clears the
      restoring marker and resets the restore-failure counter.
      """
      assert self.is_restoring

      transient = self.temporary_state
      self.run_metadata.last_result = transient.restoring_from.metrics
      self.run_metadata.last_result.setdefault("config", self.config)

      # Restoration is done: forget the source and reset the retry counter.
      transient.restoring_from = None
      transient.num_restore_failures = 0
  def should_recover(self):
      """Returns whether the trial qualifies for retrying.

      ``num_failures`` is expected to be the number of failures *up to the
      moment this method is called*. With 5 failures and ``max_failures=5``
      the trial should still recover; the limit is only passed on the 6th
      failure. A negative ``max_failures`` means there is no limit.

      Note that this may return true even when there is no checkpoint, either
      because ``self.checkpoint_freq`` is ``0`` or because the trial failed
      before a checkpoint has been made.
      """
      limit = self.max_failures
      within_limit = self.run_metadata.num_failures <= limit
      if within_limit:
          return True
      return limit < 0
  def update_last_result(self, result):
      """Record ``result`` as the trial's latest result and update metrics.

      The result is tagged with the experiment tag (if any), the trial
      location is refreshed from the result's node IP and PID, and every
      numeric value of the flattened result (debug metrics excluded) is
      forwarded to ``run_metadata.update_metric``.

      Args:
          result: The newly reported result dict. It is updated in place when
              an experiment tag is set.
      """
      if self.experiment_tag:
          result.update(experiment_tag=self.experiment_tag)

      location = _Location(result.get(NODE_IP), result.get(PID))
      self.set_location(location)

      self.run_metadata.last_result = result
      self.run_metadata.last_result_time = time.time()

      # Work on a copy so that dropping debug metrics leaves the stored
      # result untouched.
      reportable = self.last_result.copy()
      for debug_key in DEBUG_METRICS:
          reportable.pop(debug_key, None)

      step = result.get("training_iteration")
      for name, val in flatten_dict(reportable).items():
          if not isinstance(val, Number):
              continue
          self.run_metadata.update_metric(name, val, step=step)
  def get_trainable_cls(self):
      """Return the trainable class for this trial, or ``None`` for stubs.

      For non-stub trials this resolves ``self.trainable_name`` through the
      module-level ``get_trainable_cls`` helper.
      """
      return None if self.stub else get_trainable_cls(self.trainable_name)
  def is_finished(self):
      """Return whether the trial status is ``ERROR`` or ``TERMINATED``."""
      terminal_statuses = (Trial.ERROR, Trial.TERMINATED)
      return self.status in terminal_statuses
  @property
  def is_restoring(self):
      """Whether ``temporary_state.restoring_from`` is currently set."""
      return not (self.temporary_state.restoring_from is None)
  @property
  def is_saving(self):
      """Whether ``temporary_state.saving_to`` is currently set."""
      return not (self.temporary_state.saving_to is None)
  787. def __repr__(self):
  788. return self._trainable_name(include_trial_id=True)
  789. def __str__(self):
  790. return self._trainable_name(include_trial_id=True)
  791. def _trainable_name(self, include_trial_id=False):
  792. """Combines ``env`` with ``trainable_name`` and ``trial_id``.
  793. Can be overridden with a custom string creator.
  794. """
  795. if self.custom_trial_name:
  796. return self.custom_trial_name
  797. if "env" in self.config:
  798. env = self.config["env"]
  799. if isinstance(env, type):
  800. env = env.__name__
  801. identifier = "{}_{}".format(self.trainable_name, env)
  802. else:
  803. identifier = self.trainable_name
  804. if include_trial_id:
  805. identifier += "_" + self.trial_id
  806. return identifier.replace("/", "_")
  807. def _generate_dirname(self):
  808. if self.custom_dirname:
  809. generated_dirname = self.custom_dirname
  810. else:
  811. MAX_LEN_IDENTIFIER = int(os.environ.get("TUNE_MAX_LEN_IDENTIFIER", "130"))
  812. generated_dirname = f"{str(self)}_{self.experiment_tag}"
  813. generated_dirname = generated_dirname[:MAX_LEN_IDENTIFIER]
  814. generated_dirname += f"_{date_str()}"
  815. # This is the file path used by rsync. ['/', '(', ')'] are not allowed.
  816. return re.sub("[/()]", "_", generated_dirname)
  def invalidate_json_state(self):
      """Drop the cached JSON state so ``get_json_state`` rebuilds it."""
      self._state_json = None
  def get_json_state(self) -> Tuple[str, str]:
      """Return the trial state and the run metadata, both as JSON strings.

      The trial state JSON (without ``run_metadata``) is cached in
      ``self._state_json`` and only rebuilt when that cache is ``None``. The
      run metadata JSON is requested from ``run_metadata`` on every call.

      Returns:
          A ``(trial_state_json, run_metadata_json)`` tuple.
      """
      if self._state_json is None:
          trial_state = self.__getstate__()
          # Run metadata is serialized separately.
          trial_state.pop("run_metadata", None)
          self._state_json = json.dumps(
              trial_state, indent=2, cls=TuneFunctionEncoder
          )

      metadata_json = self.run_metadata.get_json_state()
      return self._state_json, metadata_json
  @classmethod
  def from_json_state(cls, json_state: str, stub: bool = False) -> "Trial":
      """Create a trial from its JSON-serialized state.

      Args:
          json_state: JSON string, decoded with ``TuneFunctionDecoder``. It
              must contain a ``"trainable_name"`` entry.
          stub: Passed through to the ``Trial`` constructor.

      Returns:
          A new ``Trial`` with the decoded state applied via ``__setstate__``.
      """
      decoded = json.loads(json_state, cls=TuneFunctionDecoder)

      restored = Trial(
          decoded["trainable_name"],
          stub=stub,
          _setup_default_resource=False,
      )
      restored.__setstate__(decoded)
      return restored
  def restore_run_metadata(self, run_metadata: str):
      """Replace ``self.run_metadata`` with one rebuilt from its JSON state.

      Args:
          run_metadata: JSON string passed to
              ``_TrainingRunMetadata.from_json_state``.
      """
      restored = _TrainingRunMetadata.from_json_state(run_metadata)
      self.run_metadata = restored
  @classmethod
  def from_directory(
      cls, path: Union[str, os.PathLike], stub: bool = False
  ) -> "Trial":
      """Create a trial from the state file stored in a directory.

      Args:
          path: Directory expected to contain ``TRIAL_STATE_FILENAME``.
          stub: Passed through to ``from_json_state``.

      Returns:
          The trial built from the JSON state file.

      Raises:
          FileNotFoundError: If the state file does not exist in ``path``.
      """
      metadata_path = Path(path) / TRIAL_STATE_FILENAME
      if metadata_path.exists():
          return cls.from_json_state(metadata_path.read_text(), stub=stub)

      raise FileNotFoundError(
          f"Can't restore trial from path: File `{metadata_path}` not found."
      )
  def __getstate__(self):
      """Memento generator for Trial.

      Returns a copy of the instance ``__dict__`` in which every field listed
      in ``_nonjson_fields`` is cloudpickled and hex-encoded, the
      ``temporary_state`` entry is removed, and the ``_state_json`` and
      ``_default_result_or_future`` entries are reset to ``None``.
      """
      snapshot = dict(self.__dict__)

      # Fields that cannot be represented in JSON are stored as hex strings.
      for field_name in self._nonjson_fields:
          pickled = cloudpickle.dumps(snapshot.get(field_name))
          snapshot[field_name] = binary_to_hex(pickled)

      snapshot.pop("temporary_state", None)
      snapshot["_state_json"] = None
      snapshot["_default_result_or_future"] = None
      return snapshot
  def __setstate__(self, state):
      """Restore the trial from a state dict produced by ``__getstate__``.

      Trials saved as RUNNING are reset to PENDING, hex-encoded fields listed
      in ``_nonjson_fields`` are unpickled, and a fresh temporary state is
      attached. Note that ``state`` is modified in place.

      Args:
          state: The state dict. It must contain a ``"status"`` entry.
      """
      if state["status"] == Trial.RUNNING:
          state["status"] = Trial.PENDING

      for field_name in self._nonjson_fields:
          if field_name not in state:
              continue
          state[field_name] = cloudpickle.loads(hex_to_binary(state[field_name]))

      # Ensure that stub doesn't get overridden: a stub flag already set on
      # this instance is kept even if the stored state says otherwise.
      stored_stub = state.pop("stub", True)
      self.__dict__.update(state)
      self.stub = stored_stub or getattr(self, "stub", False)

      if not self.stub:
          validate_trainable(self.trainable_name)

      self.temporary_state = _TemporaryTrialState()

      assert self.placement_group_factory