wandb.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802
  1. import enum
  2. import os
  3. import pickle
  4. import urllib
  5. import warnings
  6. from numbers import Number
  7. from types import ModuleType
  8. from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
  9. import numpy as np
  10. import pyarrow.fs
  11. import ray
  12. from ray import logger
  13. from ray._common.utils import load_class
  14. from ray.air._internal import usage as air_usage
  15. from ray.air.constants import TRAINING_ITERATION
  16. from ray.air.util.node import _force_on_current_node
  17. from ray.train._internal.session import get_session
  18. from ray.train._internal.syncer import DEFAULT_SYNC_TIMEOUT
  19. from ray.tune.experiment import Trial
  20. from ray.tune.logger import LoggerCallback
  21. from ray.tune.utils import flatten_dict
  22. from ray.util import PublicAPI
  23. from ray.util.queue import Queue
  24. try:
  25. import wandb
  26. from wandb.sdk.data_types.base_types.wb_value import WBValue
  27. from wandb.sdk.data_types.image import Image
  28. from wandb.sdk.data_types.video import Video
  29. from wandb.sdk.lib.disabled import RunDisabled
  30. from wandb.util import json_dumps_safer
  31. from wandb.wandb_run import Run
  32. except ImportError:
  33. wandb = json_dumps_safer = Run = RunDisabled = WBValue = None
  34. WANDB_ENV_VAR = "WANDB_API_KEY"
  35. WANDB_PROJECT_ENV_VAR = "WANDB_PROJECT_NAME"
  36. WANDB_GROUP_ENV_VAR = "WANDB_GROUP_NAME"
  37. WANDB_MODE_ENV_VAR = "WANDB_MODE"
  38. # Hook that is invoked before wandb.init in the setup method of WandbLoggerCallback
  39. # to populate the API key if it isn't already set when initializing the callback.
  40. # It doesn't take in any arguments and returns the W&B API key.
  41. # Example: "your.module.wandb_setup_api_key_hook".
  42. WANDB_SETUP_API_KEY_HOOK = "WANDB_SETUP_API_KEY_HOOK"
  43. # Hook that is invoked before wandb.init in the setup method of WandbLoggerCallback
  44. # to populate environment variables to specify the location
  45. # (project and group) of the W&B run.
  46. # It doesn't take in any arguments and doesn't return anything, but it does populate
  47. # WANDB_PROJECT_NAME and WANDB_GROUP_NAME.
  48. # Example: "your.module.wandb_populate_run_location_hook".
  49. WANDB_POPULATE_RUN_LOCATION_HOOK = "WANDB_POPULATE_RUN_LOCATION_HOOK"
  50. # Hook that is invoked after running wandb.init in WandbLoggerCallback
  51. # to process information about the W&B run.
  52. # It takes in a W&B run object and doesn't return anything.
  53. # Example: "your.module.wandb_process_run_info_hook".
  54. WANDB_PROCESS_RUN_INFO_HOOK = "WANDB_PROCESS_RUN_INFO_HOOK"
  55. @PublicAPI(stability="alpha")
  56. def setup_wandb(
  57. config: Optional[Dict] = None,
  58. api_key: Optional[str] = None,
  59. api_key_file: Optional[str] = None,
  60. rank_zero_only: bool = True,
  61. **kwargs,
  62. ) -> Union[Run, RunDisabled]:
  63. """Set up a Weights & Biases session.
  64. This function can be used to initialize a Weights & Biases session in a
  65. (distributed) training or tuning run.
  66. By default, the run ID is the trial ID, the run name is the trial name, and
  67. the run group is the experiment name. These settings can be overwritten by
  68. passing the respective arguments as ``kwargs``, which will be passed to
  69. ``wandb.init()``.
  70. In distributed training with Ray Train, only the zero-rank worker will initialize
  71. wandb. All other workers will return a disabled run object, so that logging is not
  72. duplicated in a distributed run. This can be disabled by passing
  73. ``rank_zero_only=False``, which will then initialize wandb in every training
  74. worker.
  75. The ``config`` argument will be passed to Weights and Biases and will be logged
  76. as the run configuration.
  77. If no API key or key file are passed, wandb will try to authenticate
  78. using locally stored credentials, created for instance by running ``wandb login``.
  79. Keyword arguments passed to ``setup_wandb()`` will be passed to
  80. ``wandb.init()`` and take precedence over any potential default settings.
  81. Args:
  82. config: Configuration dict to be logged to Weights and Biases. Can contain
  83. arguments for ``wandb.init()`` as well as authentication information.
  84. api_key: API key to use for authentication with Weights and Biases.
  85. api_key_file: File pointing to API key for with Weights and Biases.
  86. rank_zero_only: If True, will return an initialized session only for the
  87. rank 0 worker in distributed training. If False, will initialize a
  88. session for all workers.
  89. kwargs: Passed to ``wandb.init()``.
  90. Example:
  91. .. code-block:: python
  92. from ray.air.integrations.wandb import setup_wandb
  93. def training_loop(config):
  94. wandb = setup_wandb(config)
  95. # ...
  96. wandb.log({"loss": 0.123})
  97. """
  98. if not wandb:
  99. raise RuntimeError(
  100. "Wandb was not found - please install with `pip install wandb`"
  101. )
  102. default_trial_id = None
  103. default_trial_name = None
  104. default_experiment_name = None
  105. # Do a try-catch here if we are not in a train session
  106. session = get_session()
  107. if rank_zero_only:
  108. # Check if we are in a train session and if we are not the rank 0 worker
  109. if session and session.world_rank is not None and session.world_rank != 0:
  110. return RunDisabled()
  111. if session:
  112. default_trial_id = session.trial_id
  113. default_trial_name = session.trial_name
  114. default_experiment_name = session.experiment_name
  115. # Default init kwargs
  116. wandb_init_kwargs = {
  117. "trial_id": kwargs.get("trial_id") or default_trial_id,
  118. "trial_name": kwargs.get("trial_name") or default_trial_name,
  119. "group": kwargs.get("group") or default_experiment_name,
  120. }
  121. # Passed kwargs take precedence over default kwargs
  122. wandb_init_kwargs.update(kwargs)
  123. return _setup_wandb(
  124. config=config, api_key=api_key, api_key_file=api_key_file, **wandb_init_kwargs
  125. )
  126. def _setup_wandb(
  127. trial_id: str,
  128. trial_name: str,
  129. config: Optional[Dict] = None,
  130. api_key: Optional[str] = None,
  131. api_key_file: Optional[str] = None,
  132. _wandb: Optional[ModuleType] = None,
  133. **kwargs,
  134. ) -> Union[Run, RunDisabled]:
  135. _config = config.copy() if config else {}
  136. # If key file is specified, set
  137. if api_key_file:
  138. api_key_file = os.path.expanduser(api_key_file)
  139. _set_api_key(api_key_file, api_key)
  140. project = _get_wandb_project(kwargs.pop("project", None))
  141. group = kwargs.pop("group", os.environ.get(WANDB_GROUP_ENV_VAR))
  142. # Remove unpickleable items.
  143. _config = _clean_log(_config)
  144. wandb_init_kwargs = dict(
  145. id=trial_id,
  146. name=trial_name,
  147. resume=True,
  148. reinit=True,
  149. allow_val_change=True,
  150. config=_config,
  151. project=project,
  152. group=group,
  153. )
  154. # Update config (e.g. set any other parameters in the call to wandb.init)
  155. wandb_init_kwargs.update(**kwargs)
  156. # On windows, we can't fork
  157. if os.name == "nt":
  158. os.environ["WANDB_START_METHOD"] = "thread"
  159. else:
  160. os.environ["WANDB_START_METHOD"] = "fork"
  161. _wandb = _wandb or wandb
  162. run = _wandb.init(**wandb_init_kwargs)
  163. _run_wandb_process_run_info_hook(run)
  164. # Record `setup_wandb` usage when everything has setup successfully.
  165. air_usage.tag_setup_wandb()
  166. return run
  167. def _is_allowed_type(obj):
  168. """Return True if type is allowed for logging to wandb"""
  169. if isinstance(obj, np.ndarray) and obj.size == 1:
  170. return isinstance(obj.item(), Number)
  171. if isinstance(obj, Sequence) and len(obj) > 0:
  172. return isinstance(obj[0], (Image, Video, WBValue))
  173. return isinstance(obj, (Number, WBValue))
  174. def _clean_log(
  175. obj: Any,
  176. *,
  177. video_kwargs: Optional[Dict[str, Any]] = None,
  178. image_kwargs: Optional[Dict[str, Any]] = None,
  179. ):
  180. # Fixes https://github.com/ray-project/ray/issues/10631
  181. if video_kwargs is None:
  182. video_kwargs = {}
  183. if image_kwargs is None:
  184. image_kwargs = {}
  185. if isinstance(obj, dict):
  186. return {
  187. k: _clean_log(v, video_kwargs=video_kwargs, image_kwargs=image_kwargs)
  188. for k, v in obj.items()
  189. }
  190. elif isinstance(obj, (list, set)):
  191. return [
  192. _clean_log(v, video_kwargs=video_kwargs, image_kwargs=image_kwargs)
  193. for v in obj
  194. ]
  195. elif isinstance(obj, tuple):
  196. return tuple(
  197. _clean_log(v, video_kwargs=video_kwargs, image_kwargs=image_kwargs)
  198. for v in obj
  199. )
  200. elif isinstance(obj, np.ndarray) and obj.ndim == 3:
  201. # Must be single image (H, W, C).
  202. return Image(obj, **image_kwargs)
  203. elif isinstance(obj, np.ndarray) and obj.ndim == 4:
  204. # Must be batch of images (N >= 1, H, W, C).
  205. return (
  206. _clean_log(
  207. [Image(v, **image_kwargs) for v in obj],
  208. video_kwargs=video_kwargs,
  209. image_kwargs=image_kwargs,
  210. )
  211. if obj.shape[0] > 1
  212. else Image(obj[0], **image_kwargs)
  213. )
  214. elif isinstance(obj, np.ndarray) and obj.ndim == 5:
  215. # Must be batch of videos (N >= 1, T, C, W, H).
  216. return (
  217. _clean_log(
  218. [Video(v, **video_kwargs) for v in obj],
  219. video_kwargs=video_kwargs,
  220. image_kwargs=image_kwargs,
  221. )
  222. if obj.shape[0] > 1
  223. else Video(obj[0], **video_kwargs)
  224. )
  225. elif _is_allowed_type(obj):
  226. return obj
  227. # Else
  228. try:
  229. # This is what wandb uses internally. If we cannot dump
  230. # an object using this method, wandb will raise an exception.
  231. json_dumps_safer(obj)
  232. # This is probably unnecessary, but left here to be extra sure.
  233. pickle.dumps(obj)
  234. return obj
  235. except Exception:
  236. # give up, similar to _SafeFallBackEncoder
  237. fallback = str(obj)
  238. # Try to convert to int
  239. try:
  240. fallback = int(fallback)
  241. return fallback
  242. except ValueError:
  243. pass
  244. # Try to convert to float
  245. try:
  246. fallback = float(fallback)
  247. return fallback
  248. except ValueError:
  249. pass
  250. # Else, return string
  251. return fallback
  252. def _get_wandb_project(project: Optional[str] = None) -> Optional[str]:
  253. """Get W&B project from environment variable or external hook if not passed
  254. as and argument."""
  255. if (
  256. not project
  257. and not os.environ.get(WANDB_PROJECT_ENV_VAR)
  258. and os.environ.get(WANDB_POPULATE_RUN_LOCATION_HOOK)
  259. ):
  260. # Try to populate WANDB_PROJECT_ENV_VAR and WANDB_GROUP_ENV_VAR
  261. # from external hook
  262. try:
  263. load_class(os.environ[WANDB_POPULATE_RUN_LOCATION_HOOK])()
  264. except Exception as e:
  265. logger.exception(
  266. f"Error executing {WANDB_POPULATE_RUN_LOCATION_HOOK} to "
  267. f"populate {WANDB_PROJECT_ENV_VAR} and {WANDB_GROUP_ENV_VAR}: {e}",
  268. exc_info=e,
  269. )
  270. if not project and os.environ.get(WANDB_PROJECT_ENV_VAR):
  271. # Try to get project and group from environment variables if not
  272. # passed through WandbLoggerCallback.
  273. project = os.environ.get(WANDB_PROJECT_ENV_VAR)
  274. return project
  275. def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = None):
  276. """Set WandB API key from `wandb_config`. Will pop the
  277. `api_key_file` and `api_key` keys from `wandb_config` parameter.
  278. The order of fetching the API key is:
  279. 1) From `api_key` or `api_key_file` arguments
  280. 2) From WANDB_API_KEY environment variables
  281. 3) User already logged in to W&B (wandb.api.api_key set)
  282. 4) From external hook WANDB_SETUP_API_KEY_HOOK
  283. """
  284. if os.environ.get(WANDB_MODE_ENV_VAR) in {"offline", "disabled"}:
  285. return
  286. if api_key_file:
  287. if api_key:
  288. raise ValueError("Both WandB `api_key_file` and `api_key` set.")
  289. with open(api_key_file, "rt") as fp:
  290. api_key = fp.readline().strip()
  291. if not api_key and not os.environ.get(WANDB_ENV_VAR):
  292. # Check if user is already logged into wandb.
  293. try:
  294. wandb.ensure_configured()
  295. if wandb.api.api_key:
  296. logger.info("Already logged into W&B.")
  297. return
  298. except AttributeError:
  299. pass
  300. # Try to get API key from external hook
  301. if WANDB_SETUP_API_KEY_HOOK in os.environ:
  302. try:
  303. api_key = load_class(os.environ[WANDB_SETUP_API_KEY_HOOK])()
  304. except Exception as e:
  305. logger.exception(
  306. f"Error executing {WANDB_SETUP_API_KEY_HOOK} to setup API key: {e}",
  307. exc_info=e,
  308. )
  309. if api_key:
  310. os.environ[WANDB_ENV_VAR] = api_key
  311. elif not os.environ.get(WANDB_ENV_VAR):
  312. raise ValueError(
  313. "No WandB API key found. Either set the {} environment "
  314. "variable, pass `api_key` or `api_key_file` to the"
  315. "`WandbLoggerCallback` class as arguments, "
  316. "or run `wandb login` from the command line".format(WANDB_ENV_VAR)
  317. )
  318. def _run_wandb_process_run_info_hook(run: Any) -> None:
  319. """Run external hook to process information about wandb run"""
  320. if WANDB_PROCESS_RUN_INFO_HOOK in os.environ:
  321. try:
  322. load_class(os.environ[WANDB_PROCESS_RUN_INFO_HOOK])(run)
  323. except Exception as e:
  324. logger.exception(
  325. f"Error calling {WANDB_PROCESS_RUN_INFO_HOOK}: {e}", exc_info=e
  326. )
  327. class _QueueItem(enum.Enum):
  328. END = enum.auto()
  329. RESULT = enum.auto()
  330. CHECKPOINT = enum.auto()
  331. class _WandbLoggingActor:
  332. """
  333. Wandb assumes that each trial's information should be logged from a
  334. separate process. We use Ray actors as forking multiprocessing
  335. processes is not supported by Ray and spawn processes run into pickling
  336. problems.
  337. We use a queue for the driver to communicate with the logging process.
  338. The queue accepts the following items:
  339. - If it's a dict, it is assumed to be a result and will be logged using
  340. ``wandb.log()``
  341. - If it's a checkpoint object, it will be saved using ``wandb.log_artifact()``.
  342. """
  343. def __init__(
  344. self,
  345. logdir: str,
  346. queue: Queue,
  347. exclude: List[str],
  348. to_config: List[str],
  349. *args,
  350. **kwargs,
  351. ):
  352. import wandb
  353. self._wandb = wandb
  354. os.chdir(logdir)
  355. self.queue = queue
  356. self._exclude = set(exclude)
  357. self._to_config = set(to_config)
  358. self.args = args
  359. self.kwargs = kwargs
  360. self._trial_name = self.kwargs.get("name", "unknown")
  361. self._logdir = logdir
  362. def run(self):
  363. # Since we're running in a separate process already, use threads.
  364. os.environ["WANDB_START_METHOD"] = "thread"
  365. run = self._wandb.init(*self.args, **self.kwargs)
  366. run.config.trial_log_path = self._logdir
  367. _run_wandb_process_run_info_hook(run)
  368. while True:
  369. item_type, item_content = self.queue.get()
  370. if item_type == _QueueItem.END:
  371. break
  372. if item_type == _QueueItem.CHECKPOINT:
  373. self._handle_checkpoint(item_content)
  374. continue
  375. assert item_type == _QueueItem.RESULT
  376. log, config_update = self._handle_result(item_content)
  377. try:
  378. self._wandb.config.update(config_update, allow_val_change=True)
  379. self._wandb.log(log, step=log.get(TRAINING_ITERATION))
  380. except urllib.error.HTTPError as e:
  381. # Ignore HTTPError. Missing a few data points is not a
  382. # big issue, as long as things eventually recover.
  383. logger.warning("Failed to log result to w&b: {}".format(str(e)))
  384. except FileNotFoundError as e:
  385. logger.error(
  386. "FileNotFoundError: Did not log result to Weights & Biases. "
  387. "Possible cause: relative file path used instead of absolute path. "
  388. "Error: %s",
  389. e,
  390. )
  391. self._wandb.finish()
  392. def _handle_checkpoint(self, checkpoint_path: str):
  393. artifact = self._wandb.Artifact(
  394. name=f"checkpoint_{self._trial_name}", type="model"
  395. )
  396. artifact.add_dir(checkpoint_path)
  397. self._wandb.log_artifact(artifact)
  398. def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]:
  399. config_update = result.get("config", {}).copy()
  400. log = {}
  401. flat_result = flatten_dict(result, delimiter="/")
  402. for k, v in flat_result.items():
  403. if any(k.startswith(item + "/") or k == item for item in self._exclude):
  404. continue
  405. elif any(k.startswith(item + "/") or k == item for item in self._to_config):
  406. config_update[k] = v
  407. elif not _is_allowed_type(v):
  408. continue
  409. else:
  410. log[k] = v
  411. config_update.pop("callbacks", None) # Remove callbacks
  412. return log, config_update
  413. @PublicAPI(stability="alpha")
  414. class WandbLoggerCallback(LoggerCallback):
  415. """WandbLoggerCallback
  416. Weights and biases (https://www.wandb.ai/) is a tool for experiment
  417. tracking, model optimization, and dataset versioning. This Ray Tune
  418. ``LoggerCallback`` sends metrics to Wandb for automatic tracking and
  419. visualization.
  420. Example:
  421. .. testcode::
  422. import random
  423. from ray import tune
  424. from ray.air.integrations.wandb import WandbLoggerCallback
  425. def train_func(config):
  426. offset = random.random() / 5
  427. for epoch in range(2, config["epochs"]):
  428. acc = 1 - (2 + config["lr"]) ** -epoch - random.random() / epoch - offset
  429. loss = (2 + config["lr"]) ** -epoch + random.random() / epoch + offset
  430. train.report({"acc": acc, "loss": loss})
  431. tuner = tune.Tuner(
  432. train_func,
  433. param_space={
  434. "lr": tune.grid_search([0.001, 0.01, 0.1, 1.0]),
  435. "epochs": 10,
  436. },
  437. run_config=tune.RunConfig(
  438. callbacks=[WandbLoggerCallback(project="Optimization_Project")]
  439. ),
  440. )
  441. results = tuner.fit()
  442. .. testoutput::
  443. :hide:
  444. ...
  445. Args:
  446. project: Name of the Wandb project. Mandatory.
  447. group: Name of the Wandb group. Defaults to the trainable
  448. name.
  449. api_key_file: Path to file containing the Wandb API KEY. This
  450. file only needs to be present on the node running the Tune script
  451. if using the WandbLogger.
  452. api_key: Wandb API Key. Alternative to setting ``api_key_file``.
  453. excludes: List of metrics and config that should be excluded from
  454. the log.
  455. log_config: Boolean indicating if the ``config`` parameter of
  456. the ``results`` dict should be logged. This makes sense if
  457. parameters will change during training, e.g. with
  458. PopulationBasedTraining. Defaults to False.
  459. upload_checkpoints: If ``True``, model checkpoints will be uploaded to
  460. Wandb as artifacts. Defaults to ``False``.
  461. video_kwargs: Dictionary of keyword arguments passed to wandb.Video()
  462. when logging videos. Videos have to be logged as 5D numpy arrays
  463. to be affected by this parameter. For valid keyword arguments, see
  464. https://docs.wandb.ai/ref/python/data-types/video/. Defaults to ``None``.
  465. image_kwargs: Dictionary of keyword arguments passed to wandb.Image()
  466. when logging images. Images have to be logged as 3D or 4D numpy arrays
  467. to be affected by this parameter. For valid keyword arguments, see
  468. https://docs.wandb.ai/ref/python/data-types/image/. Defaults to ``None``.
  469. **kwargs: The keyword arguments will be passed to ``wandb.init()``.
  470. Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
  471. by Tune, but can be overwritten by filling out the respective configuration
  472. values.
  473. Please see here for all other valid configuration settings:
  474. https://docs.wandb.ai/ref/python/init/
  475. """ # noqa: E501
  476. # Do not log these result keys
  477. _exclude_results = ["done", "should_checkpoint"]
  478. AUTO_CONFIG_KEYS = [
  479. "trial_id",
  480. "experiment_tag",
  481. "node_ip",
  482. "experiment_id",
  483. "hostname",
  484. "pid",
  485. "date",
  486. ]
  487. """Results that are saved with `wandb.config` instead of `wandb.log`."""
  488. _logger_actor_cls = _WandbLoggingActor
  489. def __init__(
  490. self,
  491. project: Optional[str] = None,
  492. group: Optional[str] = None,
  493. api_key_file: Optional[str] = None,
  494. api_key: Optional[str] = None,
  495. excludes: Optional[List[str]] = None,
  496. log_config: bool = False,
  497. upload_checkpoints: bool = False,
  498. save_checkpoints: bool = False,
  499. upload_timeout: int = DEFAULT_SYNC_TIMEOUT,
  500. video_kwargs: Optional[dict] = None,
  501. image_kwargs: Optional[dict] = None,
  502. **kwargs,
  503. ):
  504. if not wandb:
  505. raise RuntimeError(
  506. "Wandb was not found - please install with `pip install wandb`"
  507. )
  508. if save_checkpoints:
  509. warnings.warn(
  510. "`save_checkpoints` is deprecated. Use `upload_checkpoints` instead.",
  511. DeprecationWarning,
  512. )
  513. upload_checkpoints = save_checkpoints
  514. self.project = project
  515. self.group = group
  516. self.api_key_path = api_key_file
  517. self.api_key = api_key
  518. self.excludes = excludes or []
  519. self.log_config = log_config
  520. self.upload_checkpoints = upload_checkpoints
  521. self._upload_timeout = upload_timeout
  522. self.video_kwargs = video_kwargs or {}
  523. self.image_kwargs = image_kwargs or {}
  524. self.kwargs = kwargs
  525. self._remote_logger_class = None
  526. self._trial_logging_actors: Dict[
  527. "Trial", ray.actor.ActorHandle[_WandbLoggingActor]
  528. ] = {}
  529. self._trial_logging_futures: Dict["Trial", ray.ObjectRef] = {}
  530. self._logging_future_to_trial: Dict[ray.ObjectRef, "Trial"] = {}
  531. self._trial_queues: Dict["Trial", Queue] = {}
  532. def setup(self, *args, **kwargs):
  533. self.api_key_file = (
  534. os.path.expanduser(self.api_key_path) if self.api_key_path else None
  535. )
  536. _set_api_key(self.api_key_file, self.api_key)
  537. self.project = _get_wandb_project(self.project)
  538. if not self.project:
  539. raise ValueError(
  540. "Please pass the project name as argument or through "
  541. f"the {WANDB_PROJECT_ENV_VAR} environment variable."
  542. )
  543. if not self.group and os.environ.get(WANDB_GROUP_ENV_VAR):
  544. self.group = os.environ.get(WANDB_GROUP_ENV_VAR)
  545. def log_trial_start(self, trial: "Trial"):
  546. config = trial.config.copy()
  547. config.pop("callbacks", None) # Remove callbacks
  548. exclude_results = self._exclude_results.copy()
  549. # Additional excludes
  550. exclude_results += self.excludes
  551. # Log config keys on each result?
  552. if not self.log_config:
  553. exclude_results += ["config"]
  554. # Fill trial ID and name
  555. trial_id = trial.trial_id if trial else None
  556. trial_name = str(trial) if trial else None
  557. # Project name for Wandb
  558. wandb_project = self.project
  559. # Grouping
  560. wandb_group = self.group or trial.experiment_dir_name if trial else None
  561. # remove unpickleable items!
  562. config = _clean_log(config)
  563. config = {
  564. key: value for key, value in config.items() if key not in self.excludes
  565. }
  566. wandb_init_kwargs = dict(
  567. id=trial_id,
  568. name=trial_name,
  569. resume=False,
  570. reinit=True,
  571. allow_val_change=True,
  572. group=wandb_group,
  573. project=wandb_project,
  574. config=config,
  575. )
  576. wandb_init_kwargs.update(self.kwargs)
  577. self._start_logging_actor(trial, exclude_results, **wandb_init_kwargs)
  578. def _start_logging_actor(
  579. self, trial: "Trial", exclude_results: List[str], **wandb_init_kwargs
  580. ):
  581. # Reuse actor if one already exists.
  582. # This can happen if the trial is restarted.
  583. if trial in self._trial_logging_futures:
  584. return
  585. if not self._remote_logger_class:
  586. env_vars = {}
  587. # API key env variable is not set if authenticating through `wandb login`
  588. if WANDB_ENV_VAR in os.environ:
  589. env_vars[WANDB_ENV_VAR] = os.environ[WANDB_ENV_VAR]
  590. self._remote_logger_class = ray.remote(
  591. num_cpus=0,
  592. **_force_on_current_node(),
  593. runtime_env={"env_vars": env_vars},
  594. max_restarts=-1,
  595. max_task_retries=-1,
  596. )(self._logger_actor_cls)
  597. self._trial_queues[trial] = Queue(
  598. actor_options={
  599. "num_cpus": 0,
  600. **_force_on_current_node(),
  601. "max_restarts": -1,
  602. "max_task_retries": -1,
  603. }
  604. )
  605. self._trial_logging_actors[trial] = self._remote_logger_class.remote(
  606. logdir=trial.local_path,
  607. queue=self._trial_queues[trial],
  608. exclude=exclude_results,
  609. to_config=self.AUTO_CONFIG_KEYS,
  610. **wandb_init_kwargs,
  611. )
  612. logging_future = self._trial_logging_actors[trial].run.remote()
  613. self._trial_logging_futures[trial] = logging_future
  614. self._logging_future_to_trial[logging_future] = trial
  615. def _signal_logging_actor_stop(self, trial: "Trial"):
  616. self._trial_queues[trial].put((_QueueItem.END, None))
  617. def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
  618. if trial not in self._trial_logging_actors:
  619. self.log_trial_start(trial)
  620. result = _clean_log(
  621. result, video_kwargs=self.video_kwargs, image_kwargs=self.image_kwargs
  622. )
  623. self._trial_queues[trial].put((_QueueItem.RESULT, result))
  624. def log_trial_save(self, trial: "Trial"):
  625. if self.upload_checkpoints and trial.checkpoint:
  626. checkpoint_root = None
  627. if isinstance(trial.checkpoint.filesystem, pyarrow.fs.LocalFileSystem):
  628. checkpoint_root = trial.checkpoint.path
  629. if checkpoint_root:
  630. self._trial_queues[trial].put((_QueueItem.CHECKPOINT, checkpoint_root))
  631. def log_trial_end(self, trial: "Trial", failed: bool = False):
  632. self._signal_logging_actor_stop(trial=trial)
  633. self._cleanup_logging_actors()
  634. def _cleanup_logging_actor(self, trial: "Trial"):
  635. del self._trial_queues[trial]
  636. del self._trial_logging_futures[trial]
  637. ray.kill(self._trial_logging_actors[trial])
  638. del self._trial_logging_actors[trial]
  639. def _cleanup_logging_actors(self, timeout: int = 0, kill_on_timeout: bool = False):
  640. """Clean up logging actors that have finished uploading to wandb.
  641. Waits for `timeout` seconds to collect finished logging actors.
  642. Args:
  643. timeout: The number of seconds to wait. Defaults to 0 to clean up
  644. any immediate logging actors during the run.
  645. This is set to a timeout threshold to wait for pending uploads
  646. on experiment end.
  647. kill_on_timeout: Whether or not to kill and cleanup the logging actor if
  648. it hasn't finished within the timeout.
  649. """
  650. futures = list(self._trial_logging_futures.values())
  651. done, remaining = ray.wait(futures, num_returns=len(futures), timeout=timeout)
  652. for ready_future in done:
  653. finished_trial = self._logging_future_to_trial.pop(ready_future)
  654. self._cleanup_logging_actor(finished_trial)
  655. if kill_on_timeout:
  656. for remaining_future in remaining:
  657. trial = self._logging_future_to_trial.pop(remaining_future)
  658. self._cleanup_logging_actor(trial)
  659. def on_experiment_end(self, trials: List["Trial"], **info):
  660. """Wait for the actors to finish their call to `wandb.finish`.
  661. This includes uploading all logs + artifacts to wandb."""
  662. self._cleanup_logging_actors(timeout=self._upload_timeout, kill_on_timeout=True)
  663. def __del__(self):
  664. if ray.is_initialized():
  665. for trial in list(self._trial_logging_actors):
  666. self._signal_logging_actor_stop(trial=trial)
  667. self._cleanup_logging_actors(timeout=2, kill_on_timeout=True)
  668. self._trial_logging_actors = {}
  669. self._trial_logging_futures = {}
  670. self._logging_future_to_trial = {}
  671. self._trial_queues = {}