| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802 |
- import enum
- import os
- import pickle
- import urllib
- import warnings
- from numbers import Number
- from types import ModuleType
- from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
- import numpy as np
- import pyarrow.fs
- import ray
- from ray import logger
- from ray._common.utils import load_class
- from ray.air._internal import usage as air_usage
- from ray.air.constants import TRAINING_ITERATION
- from ray.air.util.node import _force_on_current_node
- from ray.train._internal.session import get_session
- from ray.train._internal.syncer import DEFAULT_SYNC_TIMEOUT
- from ray.tune.experiment import Trial
- from ray.tune.logger import LoggerCallback
- from ray.tune.utils import flatten_dict
- from ray.util import PublicAPI
- from ray.util.queue import Queue
- try:
- import wandb
- from wandb.sdk.data_types.base_types.wb_value import WBValue
- from wandb.sdk.data_types.image import Image
- from wandb.sdk.data_types.video import Video
- from wandb.sdk.lib.disabled import RunDisabled
- from wandb.util import json_dumps_safer
- from wandb.wandb_run import Run
- except ImportError:
- wandb = json_dumps_safer = Run = RunDisabled = WBValue = None
# Environment variable holding the W&B API key. ``_set_api_key`` reads it and
# writes it once a key has been resolved.
WANDB_ENV_VAR = "WANDB_API_KEY"
# Environment variable consulted for the W&B project name when no project is
# passed explicitly.
WANDB_PROJECT_ENV_VAR = "WANDB_PROJECT_NAME"
# Environment variable consulted for the W&B group name when no group is
# passed explicitly.
WANDB_GROUP_ENV_VAR = "WANDB_GROUP_NAME"
# Environment variable selecting the W&B mode. API key resolution is skipped
# entirely when it is set to "offline" or "disabled".
WANDB_MODE_ENV_VAR = "WANDB_MODE"
# Name of the environment variable that points to a hook which is invoked
# before wandb.init to populate the API key if it isn't already set.
# The hook doesn't take in any arguments and returns the W&B API key.
# The variable's value is an importable path that is resolved with `load_class`.
# Example: "your.module.wandb_setup_api_key_hook".
WANDB_SETUP_API_KEY_HOOK = "WANDB_SETUP_API_KEY_HOOK"
# Name of the environment variable that points to a hook which is invoked
# before wandb.init to populate environment variables that specify the location
# (project and group) of the W&B run. It only runs when no project was passed
# and WANDB_PROJECT_NAME is not set yet.
# The hook doesn't take in any arguments and doesn't return anything, but it is
# expected to populate WANDB_PROJECT_NAME and WANDB_GROUP_NAME.
# Example: "your.module.wandb_populate_run_location_hook".
WANDB_POPULATE_RUN_LOCATION_HOOK = "WANDB_POPULATE_RUN_LOCATION_HOOK"
# Name of the environment variable that points to a hook which is invoked
# after running wandb.init to process information about the W&B run.
# The hook takes in a W&B run object and doesn't return anything.
# Example: "your.module.wandb_process_run_info_hook".
WANDB_PROCESS_RUN_INFO_HOOK = "WANDB_PROCESS_RUN_INFO_HOOK"
@PublicAPI(stability="alpha")
def setup_wandb(
    config: Optional[Dict] = None,
    api_key: Optional[str] = None,
    api_key_file: Optional[str] = None,
    rank_zero_only: bool = True,
    **kwargs,
) -> Union[Run, RunDisabled]:
    """Initialize a Weights & Biases run from inside a training or tuning function.

    When called inside a session, the trial ID, the trial name and the
    experiment name of that session are used as defaults for the run ID, the
    run name and the run group. Outside of a session these defaults are
    ``None``. Any of ``trial_id``, ``trial_name`` and ``group`` given in
    ``kwargs`` replaces the corresponding default, and all remaining keyword
    arguments are forwarded to ``wandb.init()``.

    With ``rank_zero_only=True`` (the default), a worker whose session reports
    a world rank other than ``0`` gets a ``RunDisabled`` object back instead of
    a real run, so that a distributed run is not logged once per worker. Pass
    ``rank_zero_only=False`` to initialize wandb on every worker.

    If neither ``api_key`` nor ``api_key_file`` is given, authentication falls
    back to the environment and to locally stored credentials (created for
    instance by running ``wandb login``).

    Args:
        config: Configuration dict that is logged as the run configuration.
        api_key: API key used to authenticate with Weights & Biases.
        api_key_file: Path to a file containing the Weights & Biases API key.
        rank_zero_only: If True, only the rank 0 worker of a distributed run
            receives an initialized run. If False, every worker initializes
            its own run.
        kwargs: Passed on to ``wandb.init()``; these take precedence over the
            defaults described above.

    Returns:
        The initialized run, or a ``RunDisabled`` instance on non-zero ranks
        when ``rank_zero_only`` is set.

    Raises:
        RuntimeError: If the ``wandb`` package is not installed.

    Example:

        .. code-block:: python

            from ray.air.integrations.wandb import setup_wandb

            def training_loop(config):
                wandb = setup_wandb(config)
                # ...
                wandb.log({"loss": 0.123})

    """
    if not wandb:
        raise RuntimeError(
            "Wandb was not found - please install with `pip install wandb`"
        )

    # May be falsy when we are not running inside a train session.
    session = get_session()

    if rank_zero_only and session:
        rank = session.world_rank
        # A rank of ``None`` is treated like rank 0.
        if rank is not None and rank != 0:
            return RunDisabled()

    if session:
        defaults = {
            "trial_id": session.trial_id,
            "trial_name": session.trial_name,
            "group": session.experiment_name,
        }
    else:
        defaults = {"trial_id": None, "trial_name": None, "group": None}

    # Explicitly passed kwargs win over the session-derived defaults.
    init_kwargs = {**defaults, **kwargs}

    return _setup_wandb(
        config=config,
        api_key=api_key,
        api_key_file=api_key_file,
        **init_kwargs,
    )
def _setup_wandb(
    trial_id: str,
    trial_name: str,
    config: Optional[Dict] = None,
    api_key: Optional[str] = None,
    api_key_file: Optional[str] = None,
    _wandb: Optional[ModuleType] = None,
    **kwargs,
) -> Union[Run, RunDisabled]:
    """Resolve credentials and run location, then call ``wandb.init()``.

    Args:
        trial_id: Used as the wandb run ``id``.
        trial_name: Used as the wandb run ``name``.
        config: Run configuration. A shallow copy is sanitized with
            ``_clean_log`` before it is handed to wandb; the caller's dict is
            not modified.
        api_key: API key used for authentication.
        api_key_file: Path to a file containing the API key. ``~`` is expanded.
        _wandb: Optional replacement for the imported ``wandb`` module; when
            given, its ``init`` is called instead.
        **kwargs: Extra arguments for ``wandb.init()``. ``project`` and
            ``group`` are resolved specially (argument first, then the
            ``WANDB_PROJECT_NAME`` / ``WANDB_GROUP_NAME`` environment
            variables); everything else overrides the defaults built here.

    Returns:
        The object returned by ``wandb.init()``.
    """
    _config = config.copy() if config else {}

    # Expand "~" so that a key file given relative to the home directory works.
    if api_key_file:
        api_key_file = os.path.expanduser(api_key_file)

    _set_api_key(api_key_file, api_key)
    project = _get_wandb_project(kwargs.pop("project", None))
    # Fall back to the environment whenever no usable group was passed.
    # ``setup_wandb`` always forwards a ``group`` key (``None`` outside of a
    # session), so only checking for a missing key would never reach the
    # environment variable. The lookup happens after ``_get_wandb_project`` so
    # that a run-location hook had the chance to populate the variable.
    group = kwargs.pop("group", None) or os.environ.get(WANDB_GROUP_ENV_VAR)

    # Remove unpickleable items.
    _config = _clean_log(_config)

    wandb_init_kwargs = dict(
        id=trial_id,
        name=trial_name,
        resume=True,
        reinit=True,
        allow_val_change=True,
        config=_config,
        project=project,
        group=group,
    )

    # Update config (e.g. set any other parameters in the call to wandb.init)
    wandb_init_kwargs.update(**kwargs)

    # On windows, we can't fork
    if os.name == "nt":
        os.environ["WANDB_START_METHOD"] = "thread"
    else:
        os.environ["WANDB_START_METHOD"] = "fork"

    _wandb = _wandb or wandb

    run = _wandb.init(**wandb_init_kwargs)
    _run_wandb_process_run_info_hook(run)

    # Record `setup_wandb` usage when everything has setup successfully.
    air_usage.tag_setup_wandb()

    return run
- def _is_allowed_type(obj):
- """Return True if type is allowed for logging to wandb"""
- if isinstance(obj, np.ndarray) and obj.size == 1:
- return isinstance(obj.item(), Number)
- if isinstance(obj, Sequence) and len(obj) > 0:
- return isinstance(obj[0], (Image, Video, WBValue))
- return isinstance(obj, (Number, WBValue))
- def _clean_log(
- obj: Any,
- *,
- video_kwargs: Optional[Dict[str, Any]] = None,
- image_kwargs: Optional[Dict[str, Any]] = None,
- ):
- # Fixes https://github.com/ray-project/ray/issues/10631
- if video_kwargs is None:
- video_kwargs = {}
- if image_kwargs is None:
- image_kwargs = {}
- if isinstance(obj, dict):
- return {
- k: _clean_log(v, video_kwargs=video_kwargs, image_kwargs=image_kwargs)
- for k, v in obj.items()
- }
- elif isinstance(obj, (list, set)):
- return [
- _clean_log(v, video_kwargs=video_kwargs, image_kwargs=image_kwargs)
- for v in obj
- ]
- elif isinstance(obj, tuple):
- return tuple(
- _clean_log(v, video_kwargs=video_kwargs, image_kwargs=image_kwargs)
- for v in obj
- )
- elif isinstance(obj, np.ndarray) and obj.ndim == 3:
- # Must be single image (H, W, C).
- return Image(obj, **image_kwargs)
- elif isinstance(obj, np.ndarray) and obj.ndim == 4:
- # Must be batch of images (N >= 1, H, W, C).
- return (
- _clean_log(
- [Image(v, **image_kwargs) for v in obj],
- video_kwargs=video_kwargs,
- image_kwargs=image_kwargs,
- )
- if obj.shape[0] > 1
- else Image(obj[0], **image_kwargs)
- )
- elif isinstance(obj, np.ndarray) and obj.ndim == 5:
- # Must be batch of videos (N >= 1, T, C, W, H).
- return (
- _clean_log(
- [Video(v, **video_kwargs) for v in obj],
- video_kwargs=video_kwargs,
- image_kwargs=image_kwargs,
- )
- if obj.shape[0] > 1
- else Video(obj[0], **video_kwargs)
- )
- elif _is_allowed_type(obj):
- return obj
- # Else
- try:
- # This is what wandb uses internally. If we cannot dump
- # an object using this method, wandb will raise an exception.
- json_dumps_safer(obj)
- # This is probably unnecessary, but left here to be extra sure.
- pickle.dumps(obj)
- return obj
- except Exception:
- # give up, similar to _SafeFallBackEncoder
- fallback = str(obj)
- # Try to convert to int
- try:
- fallback = int(fallback)
- return fallback
- except ValueError:
- pass
- # Try to convert to float
- try:
- fallback = float(fallback)
- return fallback
- except ValueError:
- pass
- # Else, return string
- return fallback
- def _get_wandb_project(project: Optional[str] = None) -> Optional[str]:
- """Get W&B project from environment variable or external hook if not passed
- as and argument."""
- if (
- not project
- and not os.environ.get(WANDB_PROJECT_ENV_VAR)
- and os.environ.get(WANDB_POPULATE_RUN_LOCATION_HOOK)
- ):
- # Try to populate WANDB_PROJECT_ENV_VAR and WANDB_GROUP_ENV_VAR
- # from external hook
- try:
- load_class(os.environ[WANDB_POPULATE_RUN_LOCATION_HOOK])()
- except Exception as e:
- logger.exception(
- f"Error executing {WANDB_POPULATE_RUN_LOCATION_HOOK} to "
- f"populate {WANDB_PROJECT_ENV_VAR} and {WANDB_GROUP_ENV_VAR}: {e}",
- exc_info=e,
- )
- if not project and os.environ.get(WANDB_PROJECT_ENV_VAR):
- # Try to get project and group from environment variables if not
- # passed through WandbLoggerCallback.
- project = os.environ.get(WANDB_PROJECT_ENV_VAR)
- return project
def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = None):
    """Make sure a W&B API key is available, exporting it to the environment.

    Nothing is done when the ``WANDB_MODE`` environment variable is set to
    ``offline`` or ``disabled``.

    The order of fetching the API key is:
      1) From the ``api_key`` or ``api_key_file`` arguments
      2) From the ``WANDB_API_KEY`` environment variable
      3) User already logged in to W&B (``wandb.api.api_key`` set)
      4) From the external hook ``WANDB_SETUP_API_KEY_HOOK``

    A key obtained from the arguments or from the hook is written to the
    ``WANDB_API_KEY`` environment variable.

    Args:
        api_key_file: Path to a file whose first line is the API key.
        api_key: The API key itself.

    Raises:
        ValueError: If both ``api_key_file`` and ``api_key`` are given, or if
            no API key could be found through any of the sources above.
    """
    if os.environ.get(WANDB_MODE_ENV_VAR) in {"offline", "disabled"}:
        return

    if api_key_file:
        if api_key:
            raise ValueError("Both WandB `api_key_file` and `api_key` set.")
        with open(api_key_file, "rt") as fp:
            # Only the first line is considered to be the key.
            api_key = fp.readline().strip()

    if not api_key and not os.environ.get(WANDB_ENV_VAR):
        # Check if user is already logged into wandb.
        try:
            wandb.ensure_configured()
            if wandb.api.api_key:
                logger.info("Already logged into W&B.")
                return
        except AttributeError:
            # The installed wandb does not expose these attributes; fall
            # through to the remaining key sources.
            pass
        # Try to get API key from external hook
        if WANDB_SETUP_API_KEY_HOOK in os.environ:
            try:
                api_key = load_class(os.environ[WANDB_SETUP_API_KEY_HOOK])()
            except Exception as e:
                # A failing hook is logged; the final check below decides
                # whether the missing key is fatal.
                logger.exception(
                    f"Error executing {WANDB_SETUP_API_KEY_HOOK} to setup API key: {e}",
                    exc_info=e,
                )

    if api_key:
        os.environ[WANDB_ENV_VAR] = api_key
    elif not os.environ.get(WANDB_ENV_VAR):
        raise ValueError(
            "No WandB API key found. Either set the {} environment "
            "variable, pass `api_key` or `api_key_file` to the "
            "`WandbLoggerCallback` class as arguments, "
            "or run `wandb login` from the command line".format(WANDB_ENV_VAR)
        )
def _run_wandb_process_run_info_hook(run: Any) -> None:
    """Invoke the optional user hook that post-processes a wandb run.

    The hook is looked up through the ``WANDB_PROCESS_RUN_INFO_HOOK``
    environment variable and called with ``run`` as its only argument. Any
    error raised while loading or calling the hook is logged and swallowed.
    """
    hook_path = os.environ.get(WANDB_PROCESS_RUN_INFO_HOOK)
    if hook_path is None:
        # No hook configured.
        return

    try:
        load_class(hook_path)(run)
    except Exception as e:
        logger.exception(
            f"Error calling {WANDB_PROCESS_RUN_INFO_HOOK}: {e}", exc_info=e
        )
- class _QueueItem(enum.Enum):
- END = enum.auto()
- RESULT = enum.auto()
- CHECKPOINT = enum.auto()
class _WandbLoggingActor:
    """
    Wandb assumes that each trial's information should be logged from a
    separate process. We use Ray actors as forking multiprocessing
    processes is not supported by Ray and spawn processes run into pickling
    problems.

    We use a queue for the driver to communicate with the logging process.
    Every queue item is a ``(_QueueItem, content)`` tuple:

    - ``_QueueItem.RESULT``: ``content`` is a result dict that is logged using
      ``wandb.log()``.
    - ``_QueueItem.CHECKPOINT``: ``content`` is a checkpoint directory path
      that is uploaded using ``wandb.log_artifact()``.
    - ``_QueueItem.END``: stops the logging loop and finishes the run.
    """

    def __init__(
        self,
        logdir: str,
        queue: Queue,
        exclude: List[str],
        to_config: List[str],
        *args,
        **kwargs,
    ):
        """Store the logging configuration and switch into ``logdir``.

        Args:
            logdir: Directory the process changes into; also recorded in the
                run config as ``trial_log_path``.
            queue: Queue from which ``run`` reads its work items.
            exclude: Result keys (or key prefixes) that are never logged.
            to_config: Result keys (or key prefixes) that are written to the
                wandb config instead of being logged as metrics.
            *args: Positional arguments forwarded to ``wandb.init()``.
            **kwargs: Keyword arguments forwarded to ``wandb.init()``.
        """
        # Imported here so that the module is resolved inside the actor process.
        import wandb

        self._wandb = wandb

        os.chdir(logdir)
        self.queue = queue
        self._exclude = set(exclude)
        self._to_config = set(to_config)
        self.args = args
        self.kwargs = kwargs

        self._trial_name = self.kwargs.get("name", "unknown")
        self._logdir = logdir

    def run(self):
        """Initialize the wandb run and process queue items until ``END``."""
        # ``import urllib`` alone does not guarantee that the ``urllib.error``
        # submodule is loaded; import it explicitly so that the ``except``
        # clause below cannot fail with an AttributeError.
        import urllib.error

        # Since we're running in a separate process already, use threads.
        os.environ["WANDB_START_METHOD"] = "thread"
        run = self._wandb.init(*self.args, **self.kwargs)
        run.config.trial_log_path = self._logdir

        _run_wandb_process_run_info_hook(run)

        while True:
            item_type, item_content = self.queue.get()
            if item_type == _QueueItem.END:
                break

            if item_type == _QueueItem.CHECKPOINT:
                self._handle_checkpoint(item_content)
                continue

            assert item_type == _QueueItem.RESULT
            log, config_update = self._handle_result(item_content)
            try:
                self._wandb.config.update(config_update, allow_val_change=True)
                self._wandb.log(log, step=log.get(TRAINING_ITERATION))
            except urllib.error.HTTPError as e:
                # Ignore HTTPError. Missing a few data points is not a
                # big issue, as long as things eventually recover.
                logger.warning("Failed to log result to w&b: {}".format(str(e)))
            except FileNotFoundError as e:
                logger.error(
                    "FileNotFoundError: Did not log result to Weights & Biases. "
                    "Possible cause: relative file path used instead of absolute path. "
                    "Error: %s",
                    e,
                )
        self._wandb.finish()

    def _handle_checkpoint(self, checkpoint_path: str):
        """Upload the directory at ``checkpoint_path`` as a model artifact."""
        artifact = self._wandb.Artifact(
            name=f"checkpoint_{self._trial_name}", type="model"
        )
        artifact.add_dir(checkpoint_path)
        self._wandb.log_artifact(artifact)

    def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]:
        """Split a result dict into metrics to log and config updates.

        The result is flattened with ``/`` as delimiter. A flattened key
        matches an entry of the exclude / to-config sets if it equals the
        entry or starts with the entry followed by ``/``.

        Args:
            result: The (possibly nested) result dict.

        Returns:
            A ``(log, config_update)`` tuple. ``log`` holds the values that
            pass ``_is_allowed_type``; ``config_update`` starts as a copy of
            ``result["config"]`` (if present) and additionally receives the
            keys matching the to-config set.
        """
        config_update = result.get("config", {}).copy()
        log = {}
        flat_result = flatten_dict(result, delimiter="/")

        for k, v in flat_result.items():
            # Exclusion has priority over routing to the config.
            if any(k.startswith(item + "/") or k == item for item in self._exclude):
                continue
            elif any(k.startswith(item + "/") or k == item for item in self._to_config):
                config_update[k] = v
            elif not _is_allowed_type(v):
                continue
            else:
                log[k] = v

        config_update.pop("callbacks", None)  # Remove callbacks
        return log, config_update
@PublicAPI(stability="alpha")
class WandbLoggerCallback(LoggerCallback):
    """WandbLoggerCallback

    Weights and biases (https://www.wandb.ai/) is a tool for experiment
    tracking, model optimization, and dataset versioning. This Ray Tune
    ``LoggerCallback`` sends metrics to Wandb for automatic tracking and
    visualization.

    Example:

        .. testcode::

            import random

            from ray import tune
            from ray.air.integrations.wandb import WandbLoggerCallback


            def train_func(config):
                offset = random.random() / 5
                for epoch in range(2, config["epochs"]):
                    acc = 1 - (2 + config["lr"]) ** -epoch - random.random() / epoch - offset
                    loss = (2 + config["lr"]) ** -epoch + random.random() / epoch + offset
                    tune.report({"acc": acc, "loss": loss})


            tuner = tune.Tuner(
                train_func,
                param_space={
                    "lr": tune.grid_search([0.001, 0.01, 0.1, 1.0]),
                    "epochs": 10,
                },
                run_config=tune.RunConfig(
                    callbacks=[WandbLoggerCallback(project="Optimization_Project")]
                ),
            )
            results = tuner.fit()

        .. testoutput::
            :hide:

            ...

    Args:
        project: Name of the Wandb project. Required, either as argument or
            through the ``WANDB_PROJECT_NAME`` environment variable.
        group: Name of the Wandb group. Defaults to the
            ``WANDB_GROUP_NAME`` environment variable if set, otherwise to the
            trial's experiment directory name.
        api_key_file: Path to file containing the Wandb API KEY. This
            file only needs to be present on the node running the Tune script
            if using the WandbLogger.
        api_key: Wandb API Key. Alternative to setting ``api_key_file``.
        excludes: List of metrics and config that should be excluded from
            the log.
        log_config: Boolean indicating if the ``config`` parameter of
            the ``results`` dict should be logged. This makes sense if
            parameters will change during training, e.g. with
            PopulationBasedTraining. Defaults to False.
        upload_checkpoints: If ``True``, model checkpoints will be uploaded to
            Wandb as artifacts. Defaults to ``False``.
        save_checkpoints: Deprecated, use ``upload_checkpoints`` instead. If
            truthy, a ``DeprecationWarning`` is emitted and the value is used
            for ``upload_checkpoints``.
        upload_timeout: Number of seconds to wait at the end of the experiment
            for the logging actors to finish before they are killed.
        video_kwargs: Dictionary of keyword arguments passed to wandb.Video()
            when logging videos. Videos have to be logged as 5D numpy arrays
            to be affected by this parameter. For valid keyword arguments, see
            https://docs.wandb.ai/ref/python/data-types/video/. Defaults to ``None``.
        image_kwargs: Dictionary of keyword arguments passed to wandb.Image()
            when logging images. Images have to be logged as 3D or 4D numpy arrays
            to be affected by this parameter. For valid keyword arguments, see
            https://docs.wandb.ai/ref/python/data-types/image/. Defaults to ``None``.
        **kwargs: The keyword arguments will be passed to ``wandb.init()``.

    Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
    by Tune, but can be overwritten by filling out the respective configuration
    values.

    Please see here for all other valid configuration settings:
    https://docs.wandb.ai/ref/python/init/
    """  # noqa: E501

    # Do not log these result keys
    _exclude_results = ["done", "should_checkpoint"]

    AUTO_CONFIG_KEYS = [
        "trial_id",
        "experiment_tag",
        "node_ip",
        "experiment_id",
        "hostname",
        "pid",
        "date",
    ]
    """Results that are saved with `wandb.config` instead of `wandb.log`."""

    # Actor class that is turned into a remote class for each callback
    # instance; kept as a class attribute so subclasses can swap it.
    _logger_actor_cls = _WandbLoggingActor

    def __init__(
        self,
        project: Optional[str] = None,
        group: Optional[str] = None,
        api_key_file: Optional[str] = None,
        api_key: Optional[str] = None,
        excludes: Optional[List[str]] = None,
        log_config: bool = False,
        upload_checkpoints: bool = False,
        save_checkpoints: bool = False,
        upload_timeout: int = DEFAULT_SYNC_TIMEOUT,
        video_kwargs: Optional[dict] = None,
        image_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        if not wandb:
            raise RuntimeError(
                "Wandb was not found - please install with `pip install wandb`"
            )

        if save_checkpoints:
            warnings.warn(
                "`save_checkpoints` is deprecated. Use `upload_checkpoints` instead.",
                DeprecationWarning,
            )
            upload_checkpoints = save_checkpoints

        self.project = project
        self.group = group
        self.api_key_path = api_key_file
        self.api_key = api_key
        self.excludes = excludes or []
        self.log_config = log_config
        self.upload_checkpoints = upload_checkpoints
        self._upload_timeout = upload_timeout
        self.video_kwargs = video_kwargs or {}
        self.image_kwargs = image_kwargs or {}
        self.kwargs = kwargs

        # Created lazily by ``_start_logging_actor``.
        self._remote_logger_class = None

        # Per-trial bookkeeping of actor handle, the future of its ``run``
        # call, the reverse future -> trial lookup, and the queue feeding it.
        self._trial_logging_actors: Dict[
            "Trial", ray.actor.ActorHandle[_WandbLoggingActor]
        ] = {}
        self._trial_logging_futures: Dict["Trial", ray.ObjectRef] = {}
        self._logging_future_to_trial: Dict[ray.ObjectRef, "Trial"] = {}
        self._trial_queues: Dict["Trial", Queue] = {}

    def setup(self, *args, **kwargs):
        """Resolve API key, project and group before any trial is logged.

        Raises:
            ValueError: If no project name was passed and none is available
                through the ``WANDB_PROJECT_NAME`` environment variable.
        """
        self.api_key_file = (
            os.path.expanduser(self.api_key_path) if self.api_key_path else None
        )
        _set_api_key(self.api_key_file, self.api_key)

        self.project = _get_wandb_project(self.project)
        if not self.project:
            raise ValueError(
                "Please pass the project name as argument or through "
                f"the {WANDB_PROJECT_ENV_VAR} environment variable."
            )
        if not self.group and os.environ.get(WANDB_GROUP_ENV_VAR):
            self.group = os.environ.get(WANDB_GROUP_ENV_VAR)

    def log_trial_start(self, trial: "Trial"):
        """Assemble the ``wandb.init()`` arguments and start the trial's actor."""
        config = trial.config.copy()

        config.pop("callbacks", None)  # Remove callbacks

        exclude_results = self._exclude_results.copy()

        # Additional excludes
        exclude_results += self.excludes

        # Log config keys on each result?
        if not self.log_config:
            exclude_results += ["config"]

        # Fill trial ID and name
        trial_id = trial.trial_id if trial else None
        trial_name = str(trial) if trial else None

        # Project name for Wandb
        wandb_project = self.project

        # Grouping. Parenthesized to make the evaluation order explicit: the
        # conditional applies to the whole ``or`` expression.
        wandb_group = (self.group or trial.experiment_dir_name) if trial else None

        # remove unpickleable items!
        config = _clean_log(config)
        config = {
            key: value for key, value in config.items() if key not in self.excludes
        }

        wandb_init_kwargs = dict(
            id=trial_id,
            name=trial_name,
            resume=False,
            reinit=True,
            allow_val_change=True,
            group=wandb_group,
            project=wandb_project,
            config=config,
        )
        # User-supplied kwargs override the values chosen above.
        wandb_init_kwargs.update(self.kwargs)

        self._start_logging_actor(trial, exclude_results, **wandb_init_kwargs)

    def _start_logging_actor(
        self, trial: "Trial", exclude_results: List[str], **wandb_init_kwargs
    ):
        """Create the queue and the logging actor for ``trial`` and start it.

        Does nothing if a logging future is already registered for the trial.
        """
        # Reuse actor if one already exists.
        # This can happen if the trial is restarted.
        if trial in self._trial_logging_futures:
            return

        if not self._remote_logger_class:
            env_vars = {}
            # API key env variable is not set if authenticating through `wandb login`
            if WANDB_ENV_VAR in os.environ:
                env_vars[WANDB_ENV_VAR] = os.environ[WANDB_ENV_VAR]
            self._remote_logger_class = ray.remote(
                num_cpus=0,
                **_force_on_current_node(),
                runtime_env={"env_vars": env_vars},
                max_restarts=-1,
                max_task_retries=-1,
            )(self._logger_actor_cls)

        self._trial_queues[trial] = Queue(
            actor_options={
                "num_cpus": 0,
                **_force_on_current_node(),
                "max_restarts": -1,
                "max_task_retries": -1,
            }
        )
        self._trial_logging_actors[trial] = self._remote_logger_class.remote(
            logdir=trial.local_path,
            queue=self._trial_queues[trial],
            exclude=exclude_results,
            to_config=self.AUTO_CONFIG_KEYS,
            **wandb_init_kwargs,
        )
        logging_future = self._trial_logging_actors[trial].run.remote()
        self._trial_logging_futures[trial] = logging_future
        self._logging_future_to_trial[logging_future] = trial

    def _signal_logging_actor_stop(self, trial: "Trial"):
        """Enqueue the ``END`` marker for the trial's logging actor."""
        self._trial_queues[trial].put((_QueueItem.END, None))

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Sanitize ``result`` and enqueue it for the trial's logging actor."""
        if trial not in self._trial_logging_actors:
            self.log_trial_start(trial)

        result = _clean_log(
            result, video_kwargs=self.video_kwargs, image_kwargs=self.image_kwargs
        )
        self._trial_queues[trial].put((_QueueItem.RESULT, result))

    def log_trial_save(self, trial: "Trial"):
        """Enqueue the trial's checkpoint for upload, if enabled.

        Only checkpoints that live on a local filesystem are enqueued.
        """
        if self.upload_checkpoints and trial.checkpoint:
            checkpoint_root = None
            if isinstance(trial.checkpoint.filesystem, pyarrow.fs.LocalFileSystem):
                checkpoint_root = trial.checkpoint.path

            if checkpoint_root:
                self._trial_queues[trial].put((_QueueItem.CHECKPOINT, checkpoint_root))

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Ask the trial's actor to stop and reap actors that already finished."""
        self._signal_logging_actor_stop(trial=trial)
        self._cleanup_logging_actors()

    def _cleanup_logging_actor(self, trial: "Trial"):
        """Drop the trial's queue and future, kill its actor and forget it."""
        del self._trial_queues[trial]
        del self._trial_logging_futures[trial]
        ray.kill(self._trial_logging_actors[trial])
        del self._trial_logging_actors[trial]

    def _cleanup_logging_actors(self, timeout: int = 0, kill_on_timeout: bool = False):
        """Clean up logging actors that have finished uploading to wandb.
        Waits for `timeout` seconds to collect finished logging actors.

        Args:
            timeout: The number of seconds to wait. Defaults to 0 to clean up
                any immediate logging actors during the run.
                This is set to a timeout threshold to wait for pending uploads
                on experiment end.
            kill_on_timeout: Whether or not to kill and cleanup the logging actor if
                it hasn't finished within the timeout.
        """

        futures = list(self._trial_logging_futures.values())
        done, remaining = ray.wait(futures, num_returns=len(futures), timeout=timeout)
        for ready_future in done:
            finished_trial = self._logging_future_to_trial.pop(ready_future)
            self._cleanup_logging_actor(finished_trial)

        if kill_on_timeout:
            for remaining_future in remaining:
                trial = self._logging_future_to_trial.pop(remaining_future)
                self._cleanup_logging_actor(trial)

    def on_experiment_end(self, trials: List["Trial"], **info):
        """Wait for the actors to finish their call to `wandb.finish`.
        This includes uploading all logs + artifacts to wandb."""
        self._cleanup_logging_actors(timeout=self._upload_timeout, kill_on_timeout=True)

    def __del__(self):
        # ``__init__`` raises before any bookkeeping attribute is assigned
        # when wandb is not installed; there is nothing to clean up then.
        if not hasattr(self, "_trial_logging_actors"):
            return

        if ray.is_initialized():
            for trial in list(self._trial_logging_actors):
                self._signal_logging_actor_stop(trial=trial)

            self._cleanup_logging_actors(timeout=2, kill_on_timeout=True)

        self._trial_logging_actors = {}
        self._trial_logging_futures = {}
        self._logging_future_to_trial = {}
        self._trial_queues = {}
|