checkpoints.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075
  1. import abc
  2. import inspect
  3. import json
  4. import logging
  5. import os
  6. import pathlib
  7. import re
  8. import tempfile
  9. from types import MappingProxyType
  10. from typing import Any, Collection, Dict, List, Optional, Tuple, Union
  11. import pyarrow.fs
  12. from packaging import version
  13. import ray
  14. import ray.cloudpickle as pickle
  15. from ray.rllib.core import (
  16. COMPONENT_LEARNER,
  17. COMPONENT_LEARNER_GROUP,
  18. COMPONENT_RL_MODULE,
  19. )
  20. from ray.rllib.utils import force_list
  21. from ray.rllib.utils.actor_manager import FaultTolerantActorManager
  22. from ray.rllib.utils.annotations import (
  23. OldAPIStack,
  24. OverrideToImplementCustomLogic_CallToSuperRecommended,
  25. )
  26. from ray.rllib.utils.serialization import NOT_SERIALIZABLE, serialize_type
  27. from ray.rllib.utils.typing import StateDict
  28. from ray.train import Checkpoint as Checkpoint_train
  29. from ray.tune import Checkpoint as Checkpoint_tune
  30. from ray.tune.utils.file_transfer import sync_dir_between_nodes
  31. from ray.util import log_once
  32. from ray.util.annotations import PublicAPI
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# The current checkpoint version used by RLlib for Algorithm and Policy checkpoints.
# History:
# 0.1: Ray 2.0.0
#  A single `checkpoint-[iter num]` file for Algorithm checkpoints
#  within the checkpoint directory. Policy checkpoints not supported across all
#  DL frameworks.
# 1.0: Ray >=2.1.0
#  An algorithm_state.pkl file for the state of the Algorithm (excluding
#  individual policy states).
#  One sub-dir inside the "policies" sub-dir for each policy with a
#  dedicated policy_state.pkl in it for the policy state.
# 1.1: Same as 1.0, but has a new "format" field in the rllib_checkpoint.json file
#  indicating, whether the checkpoint is `cloudpickle` (default) or `msgpack`.
# 1.2: Introduces the checkpoint for the new Learner API if the Learner API is enabled.
# 2.0: Introduces the Checkpointable API for all components on the new API stack
#  (if the Learner-, RLModule, EnvRunner, and ConnectorV2 APIs are enabled).
# NOTE(review): The history above ends at 2.0, while
# `CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER` below is 2.1 - TODO document what
# changed in 2.1.
CHECKPOINT_VERSION = version.Version("1.1")
# Version written into a Checkpointable's metadata file by
# `Checkpointable.save_to_path()` if `get_metadata()` does not provide a
# "checkpoint_version" key itself.
CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER = version.Version("2.1")
@PublicAPI(stability="alpha")
class Checkpointable(abc.ABC):
    """Abstract base class for a component of RLlib that can be checkpointed to disk.

    Subclasses must implement the following abstract APIs:
    - get_state()
    - set_state()
    - get_ctor_args_and_kwargs()

    The following APIs come with a default implementation in this class and may be
    overridden by subclasses:
    - save_to_path()
    - restore_from_path()
    - from_checkpoint()
    - get_metadata()
    - get_checkpointable_components()
    """

    # The state file for the implementing class.
    # This file contains any state information that does NOT belong to any subcomponent
    # of the implementing class (which are `Checkpointable` themselves and thus should
    # have their own state- and metadata files).
    # After a `save_to_path([path])` this file can be found directly in: `path/`.
    # Note that this is the file name WITHOUT extension; `save_to_path()` appends
    # `.pkl` or `.msgpack`.
    STATE_FILE_NAME = "state"

    # The filename of the pickle file that contains the class information of the
    # Checkpointable as well as all constructor args to be passed to such a class in
    # order to construct a new instance.
    CLASS_AND_CTOR_ARGS_FILE_NAME = "class_and_ctor_args.pkl"

    # Subclasses may set this to their own metadata filename.
    # The dict returned by self.get_metadata() is stored in this JSON file.
    METADATA_FILE_NAME = "metadata.json"
  78. def save_to_path(
  79. self,
  80. path: Optional[Union[str, pathlib.Path]] = None,
  81. *,
  82. state: Optional[StateDict] = None,
  83. filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  84. use_msgpack: bool = False,
  85. ) -> str:
  86. """Saves the state of the implementing class (or `state`) to `path`.
  87. The state of the implementing class is always saved in the following format:
  88. .. testcode::
  89. :skipif: True
  90. path/
  91. [component1]/
  92. [component1 subcomponentA]/
  93. ...
  94. [component1 subcomponentB]/
  95. ...
  96. [component2]/
  97. ...
  98. [cls.METADATA_FILE_NAME] (json)
  99. [cls.STATE_FILE_NAME] (pkl|msgpack)
  100. The main logic is to loop through all subcomponents of this Checkpointable
  101. and call their respective `save_to_path` methods. Then save the remaining
  102. (non subcomponent) state to this Checkpointable's STATE_FILE_NAME.
  103. In the exception that a component is a FaultTolerantActorManager instance,
  104. instead of calling `save_to_path` directly on that manager, the first healthy
  105. actor is interpreted as the component and its `save_to_path` method is called.
  106. Even if that actor is located on another node, the created file is automatically
  107. synced to the local node.
  108. Args:
  109. path: The path to the directory to save the state of the implementing class
  110. to. If `path` doesn't exist or is None, then a new directory will be
  111. created (and returned).
  112. state: An optional state dict to be used instead of getting a new state of
  113. the implementing class through `self.get_state()`.
  114. filesystem: PyArrow FileSystem to use to access data at the `path`.
  115. If not specified, this is inferred from the URI scheme of `path`.
  116. use_msgpack: Whether the state file should be written using msgpack and
  117. msgpack_numpy (file extension is `.msgpack`), rather than pickle (file
  118. extension is `.pkl`).
  119. Returns:
  120. The path (str) where the state has been saved.
  121. """
  122. # If no path is given create a local temporary directory.
  123. if path is None:
  124. import uuid
  125. # Get the location of the temporary directory on the OS.
  126. tmp_dir = pathlib.Path(tempfile.gettempdir())
  127. # Create a random directory name.
  128. random_dir_name = str(uuid.uuid4())
  129. # Create the path, but do not craet the directory on the
  130. # filesystem, yet. This is done by `PyArrow`.
  131. path = path or tmp_dir / random_dir_name
  132. # We need a string path for `pyarrow.fs.FileSystem.from_uri`.
  133. path = path if isinstance(path, str) else path.as_posix()
  134. # If we have no filesystem, figure it out.
  135. if path and not filesystem:
  136. # Note the path needs to be a path that is relative to the
  137. # filesystem (e.g. `gs://tmp/...` -> `tmp/...`).
  138. filesystem, path = pyarrow.fs.FileSystem.from_uri(path)
  139. # Make sure, path exists.
  140. filesystem.create_dir(path, recursive=True)
  141. # Convert to `pathlib.Path` for easy handling.
  142. path = pathlib.Path(path)
  143. # Write metadata file to disk.
  144. metadata = self.get_metadata()
  145. if "checkpoint_version" not in metadata:
  146. metadata["checkpoint_version"] = str(
  147. CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER
  148. )
  149. with filesystem.open_output_stream(
  150. (path / self.METADATA_FILE_NAME).as_posix()
  151. ) as f:
  152. f.write(json.dumps(metadata).encode("utf-8"))
  153. # Write the class and constructor args information to disk. Always use pickle
  154. # for this, because this information contains classes and maybe other
  155. # non-serializable data.
  156. with filesystem.open_output_stream(
  157. (path / self.CLASS_AND_CTOR_ARGS_FILE_NAME).as_posix()
  158. ) as f:
  159. pickle.dump(
  160. {
  161. "class": type(self),
  162. "ctor_args_and_kwargs": self.get_ctor_args_and_kwargs(),
  163. },
  164. f,
  165. )
  166. # Get the entire state of this Checkpointable, or use provided `state`.
  167. _state_provided = state is not None
  168. # Get only the non-checkpointable components of the state. Checkpointable
  169. # components are saved to path by their own `save_to_path` in the loop below.
  170. state = state or self.get_state(
  171. not_components=[c[0] for c in self.get_checkpointable_components()]
  172. )
  173. # Write components of `self` that themselves are `Checkpointable`.
  174. for comp_name, comp in self.get_checkpointable_components():
  175. # If subcomponent's name is not in `state`, ignore it and don't write this
  176. # subcomponent's state to disk.
  177. if _state_provided and comp_name not in state:
  178. continue
  179. comp_path = path / comp_name
  180. # If component is an ActorManager, save the manager's first healthy
  181. # actor's state to disk (even if it's on another node, in which case, we'll
  182. # sync the generated file(s) back to this node).
  183. if isinstance(comp, FaultTolerantActorManager):
  184. actor_to_use = comp.healthy_actor_ids()[0]
  185. def _get_ip(_=None):
  186. import ray
  187. return ray.util.get_node_ip_address()
  188. _result = next(
  189. iter(
  190. comp.foreach_actor(
  191. _get_ip,
  192. remote_actor_ids=[actor_to_use],
  193. )
  194. )
  195. )
  196. if not _result.ok:
  197. raise _result.get()
  198. worker_ip_addr = _result.get()
  199. self_ip_addr = _get_ip()
  200. # Save the state to a temporary location on the `actor_to_use`'s
  201. # node.
  202. comp_state_ref = None
  203. if _state_provided:
  204. comp_state_ref = ray.put(state.pop(comp_name))
  205. # If worker_addr == self_addr, save directly to the path
  206. # provided by the user, make sure to use filesystem.
  207. if worker_ip_addr == self_ip_addr:
  208. comp.foreach_actor(
  209. lambda w, _path=comp_path, _filesystem=filesystem, _state=comp_state_ref, _use_msgpack=use_msgpack: ( # noqa
  210. w.save_to_path(
  211. path=_path,
  212. filesystem=_filesystem,
  213. state=(
  214. ray.get(_state)
  215. if _state is not None
  216. else w.get_state()
  217. ),
  218. use_msgpack=_use_msgpack,
  219. )
  220. ),
  221. remote_actor_ids=[actor_to_use],
  222. )
  223. # Transfer state files from the worker node to the head node
  224. else:
  225. # Save the checkpoint to the temporary directory on the worker.
  226. def _save(w, _state=comp_state_ref, _use_msgpack=use_msgpack):
  227. import tempfile
  228. # Create a temporary directory on the worker.
  229. tmpdir = tempfile.mkdtemp()
  230. w.save_to_path(
  231. path=tmpdir,
  232. state=(
  233. ray.get(_state) if _state is not None else w.get_state()
  234. ),
  235. use_msgpack=_use_msgpack,
  236. )
  237. return tmpdir
  238. _result = next(
  239. iter(comp.foreach_actor(_save, remote_actor_ids=[actor_to_use]))
  240. )
  241. if not _result.ok:
  242. raise _result.get()
  243. worker_temp_dir = _result.get()
  244. # Sync the temporary directory from the worker to this node.
  245. sync_dir_between_nodes(
  246. worker_ip_addr,
  247. worker_temp_dir,
  248. self_ip_addr,
  249. str(comp_path),
  250. )
  251. # Remove the temporary directory on the worker.
  252. def _rmdir(_, _dir=worker_temp_dir):
  253. import shutil
  254. shutil.rmtree(_dir)
  255. comp.foreach_actor(_rmdir, remote_actor_ids=[actor_to_use])
  256. # Local component (instance stored in a property of `self`).
  257. else:
  258. if _state_provided:
  259. comp_state = state.pop(comp_name)
  260. else:
  261. comp_state = self.get_state(components=comp_name)[comp_name]
  262. # By providing the `state` arg, we make sure that the component does not
  263. # have to call its own `get_state()` anymore, but uses what's provided
  264. # here.
  265. comp.save_to_path(
  266. path=comp_path,
  267. filesystem=filesystem,
  268. state=comp_state,
  269. use_msgpack=use_msgpack,
  270. )
  271. # Write all the remaining state to disk.
  272. filename = path / (
  273. self.STATE_FILE_NAME + (".msgpack" if use_msgpack else ".pkl")
  274. )
  275. with filesystem.open_output_stream(filename.as_posix()) as f:
  276. if use_msgpack:
  277. msgpack = try_import_msgpack(error=True)
  278. msgpack.dump(state, f)
  279. else:
  280. pickle.dump(state, f)
  281. return str(path)
  282. def restore_from_path(
  283. self,
  284. path: Union[str, pathlib.Path],
  285. *,
  286. component: Optional[str] = None,
  287. filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  288. **kwargs,
  289. ) -> None:
  290. """Restores the state of the implementing class from the given path.
  291. If the `component` arg is provided, `path` refers to a checkpoint of a
  292. subcomponent of `self`, thus allowing the user to load only the subcomponent's
  293. state into `self` without affecting any of the other state information (for
  294. example, loading only the NN state into a Checkpointable, which contains such
  295. an NN, but also has other state information that should NOT be changed by
  296. calling this method).
  297. The given `path` should have the following structure and contain the following
  298. files:
  299. .. testcode::
  300. :skipif: True
  301. path/
  302. [component1]/
  303. [component1 subcomponentA]/
  304. ...
  305. [component1 subcomponentB]/
  306. ...
  307. [component2]/
  308. ...
  309. [cls.METADATA_FILE_NAME] (json)
  310. [cls.STATE_FILE_NAME] (pkl|msgpack)
  311. Note that the self.METADATA_FILE_NAME file is not required to restore the state.
  312. Args:
  313. path: The path to load the implementing class' state from or to load the
  314. state of only one subcomponent's state of the implementing class (if
  315. `component` is provided).
  316. component: If provided, `path` is interpreted as the checkpoint path of only
  317. the subcomponent and thus, only that subcomponent's state is
  318. restored/loaded. All other state of `self` remains unchanged in this
  319. case.
  320. filesystem: PyArrow FileSystem to use to access data at the `path`. If not
  321. specified, this is inferred from the URI scheme of `path`.
  322. **kwargs: Forward compatibility kwargs.
  323. """
  324. path = path if isinstance(path, str) else path.as_posix()
  325. if path and not filesystem:
  326. # Note the path needs to be a path that is relative to the
  327. # filesystem (e.g. `gs://tmp/...` -> `tmp/...`).
  328. filesystem, path = pyarrow.fs.FileSystem.from_uri(path)
  329. # Only here convert to a `Path` instance b/c otherwise
  330. # cloud path gets broken (i.e. 'gs://' -> 'gs:/').
  331. path = pathlib.Path(path)
  332. if not _exists_at_fs_path(filesystem, path.as_posix()):
  333. raise FileNotFoundError(f"`path` ({path}) not found!")
  334. # Restore components of `self` that themselves are `Checkpointable`.
  335. orig_comp_names = {c[0] for c in self.get_checkpointable_components()}
  336. self._restore_all_subcomponents_from_path(
  337. path=path,
  338. filesystem=filesystem,
  339. component=component,
  340. **kwargs,
  341. )
  342. # Restore the "base" state (not individual subcomponents).
  343. if component is None:
  344. filename = path / self.STATE_FILE_NAME
  345. if filename.with_suffix(".msgpack").is_file():
  346. msgpack = try_import_msgpack(error=True)
  347. with filesystem.open_input_stream(
  348. filename.with_suffix(".msgpack").as_posix()
  349. ) as f:
  350. state = msgpack.load(f, strict_map_key=False)
  351. else:
  352. with filesystem.open_input_stream(
  353. filename.with_suffix(".pkl").as_posix()
  354. ) as f:
  355. state = pickle.load(f)
  356. self.set_state(state)
  357. new_comp_names = {c[0] for c in self.get_checkpointable_components()}
  358. diff_comp_names = new_comp_names - orig_comp_names
  359. if diff_comp_names:
  360. self._restore_all_subcomponents_from_path(
  361. path=path,
  362. filesystem=filesystem,
  363. only_comp_names=diff_comp_names,
  364. **kwargs,
  365. )
  366. @classmethod
  367. def from_checkpoint(
  368. cls,
  369. path: Union[str, pathlib.Path],
  370. filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  371. **kwargs,
  372. ) -> "Checkpointable":
  373. """Creates a new Checkpointable instance from the given location and returns it.
  374. Args:
  375. path: The checkpoint path to load (a) the information on how to construct
  376. a new instance of the implementing class and (b) the state to restore
  377. the created instance to.
  378. filesystem: PyArrow FileSystem to use to access data at the `path`. If not
  379. specified, this is inferred from the URI scheme of `path`.
  380. kwargs: Forward compatibility kwargs. Note that these kwargs are sent to
  381. each subcomponent's `from_checkpoint()` call.
  382. Returns:
  383. A new instance of the implementing class, already set to the state stored
  384. under `path`.
  385. """
  386. # We need a string path for the `PyArrow` filesystem.
  387. path = path if isinstance(path, str) else path.as_posix()
  388. # If no filesystem is passed in create one.
  389. if path and not filesystem:
  390. # Note the path needs to be a path that is relative to the
  391. # filesystem (e.g. `gs://tmp/...` -> `tmp/...`).
  392. filesystem, path = pyarrow.fs.FileSystem.from_uri(path)
  393. # Only here convert to a `Path` instance b/c otherwise
  394. # cloud path gets broken (i.e. 'gs://' -> 'gs:/').
  395. path = pathlib.Path(path)
  396. # Get the class constructor to call and its args/kwargs.
  397. # Try reading the pickle file first.
  398. try:
  399. with filesystem.open_input_stream(
  400. (path / cls.CLASS_AND_CTOR_ARGS_FILE_NAME).as_posix()
  401. ) as f:
  402. ctor_info = pickle.load(f)
  403. ctor = ctor_info["class"]
  404. ctor_args = force_list(ctor_info["ctor_args_and_kwargs"][0])
  405. ctor_kwargs = ctor_info["ctor_args_and_kwargs"][1]
  406. # Inspect the ctor to see, which arguments in ctor_info should be replaced
  407. # with the user provided **kwargs.
  408. for i, (param_name, param) in enumerate(
  409. inspect.signature(ctor).parameters.items()
  410. ):
  411. if param_name in kwargs:
  412. val = kwargs.pop(param_name)
  413. if (
  414. param.kind == inspect._ParameterKind.POSITIONAL_OR_KEYWORD
  415. and len(ctor_args) > i
  416. ):
  417. ctor_args[i] = val
  418. else:
  419. ctor_kwargs[param_name] = val
  420. # If the pickle file is from another python version, use provided
  421. # args instead.
  422. except Exception:
  423. # Use class that this method was called on.
  424. ctor = cls
  425. # Use only user provided **kwargs.
  426. ctor_args = []
  427. ctor_kwargs = kwargs
  428. # Check, whether the constructor actually goes together with `cls`.
  429. if not issubclass(ctor, cls):
  430. raise ValueError(
  431. f"The class ({ctor}) stored in checkpoint ({path}) does not seem to be "
  432. f"a subclass of `cls` ({cls})!"
  433. )
  434. elif not issubclass(ctor, Checkpointable):
  435. raise ValueError(
  436. f"The class ({ctor}) stored in checkpoint ({path}) does not seem to be "
  437. "an implementer of the `Checkpointable` API!"
  438. )
  439. # Construct the initial object (without any particular state).
  440. obj = ctor(*ctor_args, **ctor_kwargs)
  441. # Restore the state of the constructed object.
  442. obj.restore_from_path(path, filesystem=filesystem, **kwargs)
  443. # Return the new object.
  444. return obj
    @abc.abstractmethod
    def get_state(
        self,
        components: Optional[Union[str, Collection[str]]] = None,
        *,
        not_components: Optional[Union[str, Collection[str]]] = None,
        **kwargs,
    ) -> StateDict:
        """Returns the implementing class's current state as a dict.

        The returned dict must only contain msgpack-serializable data if you want to
        use the `AlgorithmConfig._msgpack_checkpoints` option. Consider returning your
        non msgpack-serializable data from the `Checkpointable.get_ctor_args_and_kwargs`
        method, instead.

        Args:
            components: An optional collection of string keys to be included in the
                returned state. This might be useful, if getting certain components
                of the state is expensive (e.g. reading/compiling the weights of a large
                NN) and at the same time, these components are not required by the
                caller.
            not_components: An optional list of string keys to be excluded in the
                returned state, even if the same string is part of `components`.
                This is useful to get the complete state of the class, except
                one or a few components.
            kwargs: Forward-compatibility kwargs.

        Returns:
            The current state of the implementing class (or only the `components`
            specified, w/o those in `not_components`).
        """
    @abc.abstractmethod
    def set_state(self, state: StateDict) -> None:
        """Sets the implementing class' state to the given state dict.

        If component keys are missing in `state`, these components of the implementing
        class will not be updated/set.

        Args:
            state: The state dict to restore the state from. Maps component keys
                to the corresponding subcomponent's own state.
        """
    @abc.abstractmethod
    def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
        """Returns the args/kwargs used to create `self` from its constructor.

        The returned value is pickled into `self.CLASS_AND_CTOR_ARGS_FILE_NAME` by
        `save_to_path()` and read back by `from_checkpoint()` to construct a new
        instance.

        Returns:
            A tuple of the args (as a tuple) and kwargs (as a Dict[str, Any]) used to
            construct `self` from its class constructor.
        """
  489. @OverrideToImplementCustomLogic_CallToSuperRecommended
  490. def get_metadata(self) -> Dict:
  491. """Returns JSON writable metadata further describing the implementing class.
  492. Note that this metadata is NOT part of any state and is thus NOT needed to
  493. restore the state of a Checkpointable instance from a directory. Rather, the
  494. metadata will be written into `self.METADATA_FILE_NAME` when calling
  495. `self.save_to_path()` for the user's convenience.
  496. Returns:
  497. A JSON-encodable dict of metadata information.
  498. """
  499. return {
  500. "class_and_ctor_args_file": self.CLASS_AND_CTOR_ARGS_FILE_NAME,
  501. "state_file": self.STATE_FILE_NAME,
  502. "ray_version": ray.__version__,
  503. "ray_commit": ray.__commit__,
  504. }
  505. def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]:
  506. """Returns the implementing class's own Checkpointable subcomponents.
  507. Returns:
  508. A list of 2-tuples (name, subcomponent) describing the implementing class'
  509. subcomponents, all of which have to be `Checkpointable` themselves and
  510. whose state is therefore written into subdirectories (rather than the main
  511. state file (self.STATE_FILE_NAME) when calling `self.save_to_path()`).
  512. """
  513. return []
  514. def _check_component(self, name, components, not_components) -> bool:
  515. """Returns True if a component should be checkpointed.
  516. Args:
  517. name: The checkpoint name.
  518. components: A list of components that should be checkpointed.
  519. non_components: A list of components that should not be checkpointed.
  520. Returns:
  521. True, if the component should be checkpointed and otherwise False.
  522. """
  523. comp_list = force_list(components)
  524. not_comp_list = force_list(not_components)
  525. if (
  526. components is None
  527. or any(c.startswith(name + "/") for c in comp_list)
  528. or name in comp_list
  529. ) and (not_components is None or name not in not_comp_list):
  530. return True
  531. return False
  532. def _get_subcomponents(self, name, components):
  533. if components is None:
  534. return None
  535. components = force_list(components)
  536. subcomponents = []
  537. for comp in components:
  538. if comp.startswith(name + "/"):
  539. subcomponents.append(comp[len(name) + 1 :])
  540. return None if not subcomponents else subcomponents
    def _restore_all_subcomponents_from_path(
        self, path, filesystem, only_comp_names=None, component=None, **kwargs
    ):
        """Calls `restore_from_path()` on the Checkpointable subcomponents of `self`.

        Args:
            path: The checkpoint directory (a `pathlib.Path`) of `self` or - if
                `component` is provided - of the subcomponent to restore.
            filesystem: The PyArrow FileSystem used to access `path`.
            only_comp_names: If not None, only subcomponents whose name is contained
                herein are restored.
            component: If None, each subcomponent is restored from its own
                sub-directory `path/[name]` (and skipped, if that does not exist).
                Otherwise, only the subcomponent matching `component` (exactly or as
                the first part of a `[name]/...` path) is restored directly from
                `path`.
            **kwargs: Passed on to the subcomponents' `restore_from_path()` calls.
        """
        for comp_name, comp in self.get_checkpointable_components():
            if only_comp_names is not None and comp_name not in only_comp_names:
                continue

            # The value of the `component` argument for the upcoming
            # `[subcomponent].restore_from_path(.., component=..)` call.
            comp_arg = None

            if component is None:
                comp_dir = path / comp_name
                # If subcomponent's dir is not in path, ignore it and don't restore this
                # subcomponent's state from disk.
                if not _exists_at_fs_path(filesystem, comp_dir.as_posix()):
                    continue
            else:
                # `path` already points to the (sub)component's own checkpoint.
                comp_dir = path

                # `component` is a path that starts with `comp` -> Remove the name of
                # `comp` from the `component` arg in the upcoming call to `restore_..`.
                if component.startswith(comp_name + "/"):
                    comp_arg = component[len(comp_name) + 1 :]
                # `component` has nothing to do with `comp` -> Skip.
                elif component != comp_name:
                    continue

            # If component is an ActorManager, restore all the manager's healthy
            # actors' states from disk (even if they are on another node, in which case,
            # we'll sync checkpoint file(s) to the respective node).
            if isinstance(comp, FaultTolerantActorManager):
                head_node_ip = ray.util.get_node_ip_address()
                all_healthy_actors = comp.healthy_actor_ids()

                # Executed on each healthy actor. Per-iteration values are bound as
                # default args, whereas `filesystem` is captured from the enclosing
                # scope.
                def _restore(
                    w,
                    _kwargs=MappingProxyType(kwargs),
                    _path=comp_dir,
                    _head_ip=head_node_ip,
                    _comp_arg=comp_arg,
                ):
                    import tempfile

                    import ray

                    worker_node_ip = ray.util.get_node_ip_address()
                    # If the worker is on the same node as the head, load the checkpoint
                    # directly from the path otherwise sync the checkpoint from the head
                    # to the worker and load it from there.
                    if worker_node_ip == _head_ip:
                        w.restore_from_path(
                            path=_path,
                            filesystem=filesystem,
                            component=_comp_arg,
                            **_kwargs,
                        )
                    else:
                        # The synced copy lives in a temporary directory on the
                        # worker's node and is restored without passing `filesystem`.
                        with tempfile.TemporaryDirectory() as temp_dir:
                            sync_dir_between_nodes(
                                _head_ip, _path, worker_node_ip, temp_dir
                            )
                            w.restore_from_path(
                                temp_dir, component=_comp_arg, **_kwargs
                            )

                # NOTE(review): The results of this call are not inspected, so errors
                # raised inside `_restore` are presumably not re-raised here - TODO
                # confirm that this is intended.
                comp.foreach_actor(_restore, remote_actor_ids=all_healthy_actors)

            # Call `restore_from_path()` on local subcomponent, thereby passing in the
            # **kwargs.
            else:
                comp.restore_from_path(
                    comp_dir, filesystem=filesystem, component=comp_arg, **kwargs
                )
  606. def _exists_at_fs_path(fs: pyarrow.fs.FileSystem, path: str) -> bool:
  607. """Returns `True` if the path can be found in the filesystem."""
  608. valid = fs.get_file_info(path)
  609. return valid.type != pyarrow.fs.FileType.NotFound
  610. def _is_dir(file_info: pyarrow.fs.FileInfo) -> bool:
  611. """Returns `True`, if the file info is from a directory."""
  612. return file_info.type == pyarrow.fs.FileType.Directory
  613. @OldAPIStack
  614. def get_checkpoint_info(
  615. checkpoint: Union[str, Checkpoint_train, Checkpoint_tune],
  616. filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  617. ) -> Dict[str, Any]:
  618. """Returns a dict with information about an Algorithm/Policy checkpoint.
  619. If the given checkpoint is a >=v1.0 checkpoint directory, try reading all
  620. information from the contained `rllib_checkpoint.json` file.
  621. Args:
  622. checkpoint: The checkpoint directory (str) or a Checkpoint object.
  623. filesystem: PyArrow FileSystem to use to access data at the `checkpoint`. If not
  624. specified, this is inferred from the URI scheme provided by `checkpoint`.
  625. Returns:
  626. A dict containing the keys:
  627. "type": One of "Policy" or "Algorithm".
  628. "checkpoint_version": A version tuple, e.g. v1.0, indicating the checkpoint
  629. version. This will help RLlib to remain backward compatible wrt. future
  630. Ray and checkpoint versions.
  631. "checkpoint_dir": The directory with all the checkpoint files in it. This might
  632. be the same as the incoming `checkpoint` arg.
  633. "state_file": The main file with the Algorithm/Policy's state information in it.
  634. This is usually a pickle-encoded file.
  635. "policy_ids": An optional set of PolicyIDs in case we are dealing with an
  636. Algorithm checkpoint. None if `checkpoint` is a Policy checkpoint.
  637. """
  638. # Default checkpoint info.
  639. info = {
  640. "type": "Algorithm",
  641. "format": "cloudpickle",
  642. "checkpoint_version": CHECKPOINT_VERSION,
  643. "checkpoint_dir": None,
  644. "state_file": None,
  645. "policy_ids": None,
  646. "module_ids": None,
  647. }
  648. # `checkpoint` is a Checkpoint instance: Translate to directory and continue.
  649. if isinstance(checkpoint, (Checkpoint_train, Checkpoint_tune)):
  650. checkpoint = checkpoint.to_directory()
  651. if checkpoint and not filesystem:
  652. # Note the path needs to be a path that is relative to the
  653. # filesystem (e.g. `gs://tmp/...` -> `tmp/...`).
  654. filesystem, checkpoint = pyarrow.fs.FileSystem.from_uri(checkpoint)
  655. # Only here convert to a `Path` instance b/c otherwise
  656. # cloud path gets broken (i.e. 'gs://' -> 'gs:/').
  657. checkpoint = pathlib.Path(checkpoint)
  658. # Checkpoint is dir.
  659. if _exists_at_fs_path(filesystem, checkpoint.as_posix()) and _is_dir(
  660. filesystem.get_file_info(checkpoint.as_posix())
  661. ):
  662. info.update({"checkpoint_dir": str(checkpoint)})
  663. # Figure out whether this is an older checkpoint format
  664. # (with a `checkpoint-\d+` file in it).
  665. file_info_list = filesystem.get_file_info(
  666. pyarrow.fs.FileSelector(checkpoint.as_posix(), recursive=False)
  667. )
  668. for file_info in file_info_list:
  669. if file_info.is_file:
  670. if re.match("checkpoint-\\d+", file_info.base_name):
  671. info.update(
  672. {
  673. "checkpoint_version": version.Version("0.1"),
  674. "state_file": str(file_info.base_name),
  675. }
  676. )
  677. return info
  678. # No old checkpoint file found.
  679. # If rllib_checkpoint.json file present, read available information from it
  680. # and then continue with the checkpoint analysis (possibly overriding further
  681. # information).
  682. if _exists_at_fs_path(
  683. filesystem, (checkpoint / "rllib_checkpoint.json").as_posix()
  684. ):
  685. # if (checkpoint / "rllib_checkpoint.json").is_file():
  686. with filesystem.open_input_stream(
  687. (checkpoint / "rllib_checkpoint.json").as_posix()
  688. ) as f:
  689. # with open(checkpoint / "rllib_checkpoint.json") as f:
  690. rllib_checkpoint_info = json.load(fp=f)
  691. if "checkpoint_version" in rllib_checkpoint_info:
  692. rllib_checkpoint_info["checkpoint_version"] = version.Version(
  693. rllib_checkpoint_info["checkpoint_version"]
  694. )
  695. info.update(rllib_checkpoint_info)
  696. else:
  697. # No rllib_checkpoint.json file present: Warn and continue trying to figure
  698. # out checkpoint info ourselves.
  699. if log_once("no_rllib_checkpoint_json_file"):
  700. logger.warning(
  701. "No `rllib_checkpoint.json` file found in checkpoint directory "
  702. f"{checkpoint}! Trying to extract checkpoint info from other files "
  703. f"found in that dir."
  704. )
  705. # Policy checkpoint file found.
  706. for extension in ["pkl", "msgpck"]:
  707. if _exists_at_fs_path(
  708. filesystem, (checkpoint / ("policy_state." + extension)).as_posix()
  709. ):
  710. # if (checkpoint / ("policy_state." + extension)).is_file():
  711. info.update(
  712. {
  713. "type": "Policy",
  714. "format": "cloudpickle" if extension == "pkl" else "msgpack",
  715. "checkpoint_version": CHECKPOINT_VERSION,
  716. "state_file": str(checkpoint / f"policy_state.{extension}"),
  717. }
  718. )
  719. return info
  720. # Valid Algorithm checkpoint >v0 file found?
  721. format = None
  722. for extension in ["pkl", "msgpck", "msgpack"]:
  723. state_file = checkpoint / f"algorithm_state.{extension}"
  724. if (
  725. _exists_at_fs_path(filesystem, state_file.as_posix())
  726. and filesystem.get_file_info(state_file.as_posix()).is_file
  727. ):
  728. format = "cloudpickle" if extension == "pkl" else "msgpack"
  729. break
  730. if format is None:
  731. raise ValueError(
  732. "Given checkpoint does not seem to be valid! No file with the name "
  733. "`algorithm_state.[pkl|msgpack|msgpck]` (or `checkpoint-[0-9]+`) found."
  734. )
  735. info.update(
  736. {
  737. "format": format,
  738. "state_file": str(state_file),
  739. }
  740. )
  741. # Collect all policy IDs in the sub-dir "policies/".
  742. policies_dir = checkpoint / "policies"
  743. if _exists_at_fs_path(filesystem, policies_dir.as_posix()) and _is_dir(
  744. filesystem.get_file_info(policies_dir.as_posix())
  745. ):
  746. policy_ids = set()
  747. file_info_list = filesystem.get_file_info(
  748. pyarrow.fs.FileSelector(policies_dir.as_posix(), recursive=False)
  749. )
  750. for file_info in file_info_list:
  751. policy_ids.add(file_info.base_name)
  752. info.update({"policy_ids": policy_ids})
  753. # Collect all module IDs in the sub-dir "learner/module_state/".
  754. modules_dir = (
  755. checkpoint
  756. / COMPONENT_LEARNER_GROUP
  757. / COMPONENT_LEARNER
  758. / COMPONENT_RL_MODULE
  759. )
  760. if _exists_at_fs_path(filesystem, checkpoint.as_posix()) and _is_dir(
  761. filesystem.get_file_info(modules_dir.as_posix())
  762. ):
  763. module_ids = set()
  764. file_info_list = filesystem.get_file_info(
  765. pyarrow.fs.FileSelector(modules_dir.as_posix(), recursive=False)
  766. )
  767. for file_info in file_info_list:
  768. # Only add subdirs (those are the ones where the RLModule data
  769. # is stored, not files (could be json metadata files).
  770. module_dir = modules_dir / file_info.base_name
  771. if _is_dir(filesystem.get_file_info(module_dir.as_posix())):
  772. module_ids.add(file_info.base_name)
  773. info.update({"module_ids": module_ids})
  774. # Checkpoint is a file: Use as-is (interpreting it as old Algorithm checkpoint
  775. # version).
  776. elif (
  777. _exists_at_fs_path(filesystem, checkpoint.as_posix())
  778. and filesystem.get_file_info(checkpoint.as_posix()).is_file
  779. ):
  780. info.update(
  781. {
  782. "checkpoint_version": version.Version("0.1"),
  783. "checkpoint_dir": str(checkpoint.parent),
  784. "state_file": str(checkpoint),
  785. }
  786. )
  787. else:
  788. raise ValueError(
  789. f"Given checkpoint ({str(checkpoint)}) not found! Must be a "
  790. "checkpoint directory (or a file for older checkpoint versions)."
  791. )
  792. return info
  793. @OldAPIStack
  794. def convert_to_msgpack_checkpoint(
  795. checkpoint: Union[str, Checkpoint_train, Checkpoint_tune],
  796. msgpack_checkpoint_dir: str,
  797. ) -> str:
  798. """Converts an Algorithm checkpoint (pickle based) to a msgpack based one.
  799. Msgpack has the advantage of being python version independent.
  800. Args:
  801. checkpoint: The directory, in which to find the Algorithm checkpoint (pickle
  802. based).
  803. msgpack_checkpoint_dir: The directory, in which to create the new msgpack
  804. based checkpoint.
  805. Returns:
  806. The directory in which the msgpack checkpoint has been created. Note that
  807. this is the same as `msgpack_checkpoint_dir`.
  808. """
  809. from ray.rllib.algorithms import Algorithm
  810. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  811. from ray.rllib.core.rl_module import validate_module_id
  812. # Try to import msgpack and msgpack_numpy.
  813. msgpack = try_import_msgpack(error=True)
  814. # Restore the Algorithm using the python version dependent checkpoint.
  815. algo = Algorithm.from_checkpoint(checkpoint)
  816. state = algo.__getstate__()
  817. # Convert all code in state into serializable data.
  818. # Serialize the algorithm class.
  819. state["algorithm_class"] = serialize_type(state["algorithm_class"])
  820. # Serialize the algorithm's config object.
  821. if not isinstance(state["config"], dict):
  822. state["config"] = state["config"].serialize()
  823. else:
  824. state["config"] = AlgorithmConfig._serialize_dict(state["config"])
  825. # Extract policy states from worker state (Policies get their own
  826. # checkpoint sub-dirs).
  827. policy_states = {}
  828. if "worker" in state and "policy_states" in state["worker"]:
  829. policy_states = state["worker"].pop("policy_states", {})
  830. # Policy mapping fn.
  831. state["worker"]["policy_mapping_fn"] = NOT_SERIALIZABLE
  832. # Is Policy to train function.
  833. state["worker"]["is_policy_to_train"] = NOT_SERIALIZABLE
  834. # Add RLlib checkpoint version (as string).
  835. state["checkpoint_version"] = str(CHECKPOINT_VERSION)
  836. # Write state (w/o policies) to disk.
  837. state_file = os.path.join(msgpack_checkpoint_dir, "algorithm_state.msgpck")
  838. with open(state_file, "wb") as f:
  839. msgpack.dump(state, f)
  840. # Write rllib_checkpoint.json.
  841. with open(os.path.join(msgpack_checkpoint_dir, "rllib_checkpoint.json"), "w") as f:
  842. json.dump(
  843. {
  844. "type": "Algorithm",
  845. "checkpoint_version": state["checkpoint_version"],
  846. "format": "msgpack",
  847. "state_file": state_file,
  848. "policy_ids": list(policy_states.keys()),
  849. "ray_version": ray.__version__,
  850. "ray_commit": ray.__commit__,
  851. },
  852. f,
  853. )
  854. # Write individual policies to disk, each in their own subdirectory.
  855. for pid, policy_state in policy_states.items():
  856. # From here on, disallow policyIDs that would not work as directory names.
  857. validate_module_id(pid, error=True)
  858. policy_dir = os.path.join(msgpack_checkpoint_dir, "policies", pid)
  859. os.makedirs(policy_dir, exist_ok=True)
  860. policy = algo.get_policy(pid)
  861. policy.export_checkpoint(
  862. policy_dir,
  863. policy_state=policy_state,
  864. checkpoint_format="msgpack",
  865. )
  866. # Release all resources used by the Algorithm.
  867. algo.stop()
  868. return msgpack_checkpoint_dir
  869. @OldAPIStack
  870. def convert_to_msgpack_policy_checkpoint(
  871. policy_checkpoint: Union[str, Checkpoint_train, Checkpoint_tune],
  872. msgpack_checkpoint_dir: str,
  873. ) -> str:
  874. """Converts a Policy checkpoint (pickle based) to a msgpack based one.
  875. Msgpack has the advantage of being python version independent.
  876. Args:
  877. policy_checkpoint: The directory, in which to find the Policy checkpoint (pickle
  878. based).
  879. msgpack_checkpoint_dir: The directory, in which to create the new msgpack
  880. based checkpoint.
  881. Returns:
  882. The directory in which the msgpack checkpoint has been created. Note that
  883. this is the same as `msgpack_checkpoint_dir`.
  884. """
  885. from ray.rllib.policy.policy import Policy
  886. policy = Policy.from_checkpoint(policy_checkpoint)
  887. os.makedirs(msgpack_checkpoint_dir, exist_ok=True)
  888. policy.export_checkpoint(
  889. msgpack_checkpoint_dir,
  890. policy_state=policy.get_state(),
  891. checkpoint_format="msgpack",
  892. )
  893. # Release all resources used by the Policy.
  894. del policy
  895. return msgpack_checkpoint_dir
  896. @PublicAPI
  897. def try_import_msgpack(error: bool = False):
  898. """Tries importing msgpack and msgpack_numpy and returns the patched msgpack module.
  899. Returns None if error is False and msgpack or msgpack_numpy is not installed.
  900. Raises an error, if error is True and the modules could not be imported.
  901. Args:
  902. error: Whether to raise an error if msgpack/msgpack_numpy cannot be imported.
  903. Returns:
  904. The `msgpack` module, with the msgpack_numpy module already patched in. This
  905. means you can already encde and decode numpy arrays with the returned module.
  906. Raises:
  907. ImportError: If error=True and msgpack/msgpack_numpy is not installed.
  908. """
  909. try:
  910. import msgpack
  911. import msgpack_numpy
  912. # Make msgpack_numpy look like msgpack.
  913. msgpack_numpy.patch()
  914. return msgpack
  915. except Exception:
  916. if error:
  917. raise ImportError(
  918. "Could not import or setup msgpack and msgpack_numpy! "
  919. "Try running `pip install msgpack msgpack_numpy` first."
  920. )