from __future__ import annotations import abc import asyncio import datetime import functools import logging import os import random import threading import time from collections.abc import Awaitable from typing import Any, Callable, Generic, TypeVar import wandb import wandb.errors from wandb.util import CheckRetryFnType logger = logging.getLogger(__name__) # To let tests mock out the retry logic's now()/sleep() funcs, this file # should only use these variables, not call the stdlib funcs directly. NOW_FN = datetime.datetime.now SLEEP_FN = time.sleep SLEEP_ASYNC_FN = asyncio.sleep class RetryCancelledError(wandb.errors.Error): """A retry did not occur because it was cancelled.""" class TransientError(Exception): """Exception type designated for errors that may only be temporary. Can have its own message and/or wrap another exception. """ def __init__( self, msg: str | None = None, exc: BaseException | None = None ) -> None: super().__init__(msg) self.message = msg self.exception = exc _R = TypeVar("_R") class Retry(Generic[_R]): """Create a retryable version of a function. Calling this will call the passed function, retrying if any exceptions in retryable_exceptions are caught, with exponential backoff. """ MAX_SLEEP_SECONDS = 5 * 60 def __init__( self, call_fn: Callable[..., _R], retry_timedelta: datetime.timedelta | None = None, retry_cancel_event: threading.Event | None = None, num_retries: int | None = None, check_retry_fn: CheckRetryFnType = lambda e: True, retryable_exceptions: tuple[type[Exception], ...] | None = None, error_prefix: str = "Network error", retry_callback: Callable[[int, str], Any] | None = None, ) -> None: self._call_fn = call_fn self._check_retry_fn = check_retry_fn self._error_prefix = error_prefix self._last_print = datetime.datetime.now() - datetime.timedelta(minutes=1) self._retry_timedelta = retry_timedelta self._retry_cancel_event = retry_cancel_event self._num_retries = num_retries if retryable_exceptions is not None: self._retryable_exceptions = retryable_exceptions else: self._retryable_exceptions = (TransientError,) self.retry_callback = retry_callback self._num_iter = 0 def _sleep_check_cancelled( self, wait_seconds: float, cancel_event: threading.Event | None ) -> bool: if not cancel_event: SLEEP_FN(wait_seconds) return False cancelled = cancel_event.wait(wait_seconds) return cancelled @property def num_iters(self) -> int: """The number of iterations the previous __call__ retried.""" return self._num_iter def __call__( self, *args: Any, num_retries: int | None = None, retry_timedelta: datetime.timedelta | None = None, retry_sleep_base: float | None = None, retry_cancel_event: threading.Event | None = None, check_retry_fn: CheckRetryFnType | None = None, **kwargs: Any, ) -> _R: """Call the wrapped function, with retries. Args: num_retries: The number of retries after which to give up. retry_timedelta: An amount of time after which to give up. retry_sleep_base: Number of seconds to sleep for the first retry. This is used as the base for exponential backoff. retry_cancel_event: An event that causes this to raise a RetryCancelledException on the next attempted retry. check_retry_fn: A custom check for deciding whether an exception should be retried. Retrying is prevented if this returns a falsy value, even if more retries are left. This may also return a timedelta that represents a shorter timeout: retrying is prevented if the value is less than the amount of time that has passed since the last timedelta was returned. """ if os.environ.get("WANDB_TEST"): max_retries = 0 elif num_retries is not None: max_retries = num_retries elif self._num_retries is not None: max_retries = self._num_retries else: max_retries = 1000000 if retry_timedelta is not None: timeout = retry_timedelta elif self._retry_timedelta is not None: timeout = self._retry_timedelta else: timeout = datetime.timedelta(days=365) if retry_sleep_base is not None: initial_sleep = retry_sleep_base else: initial_sleep = 1 retry_loop = _RetryLoop( max_retries=max_retries, timeout=timeout, initial_sleep=initial_sleep, max_sleep=self.MAX_SLEEP_SECONDS, cancel_event=retry_cancel_event or self._retry_cancel_event, retry_check=check_retry_fn or self._check_retry_fn, ) start_time = NOW_FN() self._num_iter = 0 while True: try: result = self._call_fn(*args, **kwargs) except self._retryable_exceptions as e: if not retry_loop.should_retry(e): raise if self._num_iter == 2: logger.info("Retry attempt failed:", exc_info=e) self._print_entered_retry_loop(e) retry_loop.wait_before_retry() self._num_iter += 1 else: if self._num_iter > 2: self._print_recovered(start_time) return result def _print_entered_retry_loop(self, exception: Exception) -> None: """Emit a message saying we've begun retrying. Either calls the retry callback or prints a warning to console. Args: exception: The most recent exception we will retry. """ from requests import HTTPError if ( isinstance(exception, HTTPError) and exception.response is not None and self.retry_callback is not None ): self.retry_callback( exception.response.status_code, exception.response.text, ) else: wandb.termlog( f"{self._error_prefix}" + f" ({exception.__class__.__name__}), entering retry loop." ) def _print_recovered(self, start_time: datetime.datetime) -> None: """Emit a message saying we've recovered after retrying. Args: start_time: When we started retrying. """ if not self.retry_callback: return now = NOW_FN() if now - self._last_print < datetime.timedelta(minutes=1): return self._last_print = now time_to_recover = now - start_time self.retry_callback( 200, ( f"{self._error_prefix} resolved after" f" {time_to_recover}, resuming normal operation." ), ) class _RetryLoop: """An invocation of a Retry instance.""" def __init__( self, *, max_retries: int, timeout: datetime.timedelta, initial_sleep: float, max_sleep: float, cancel_event: threading.Event | None, retry_check: CheckRetryFnType, ) -> None: """Start a new call of a Retry instance. Args: max_retries: The number of retries after which to give up. timeout: An amount of time after which to give up. initial_sleep: Number of seconds to sleep for the first retry. This is used as the base for exponential backoff. max_sleep: Maximum number of seconds to sleep between retries. cancel_event: An event that's set when the function is cancelled. retry_check: A custom check for deciding whether an exception should be retried. Retrying is prevented if this returns a falsy value, even if more retries are left. This may also return a timedelta that represents a shorter timeout: retrying is prevented if the value is less than the amount of time that has passed since the last timedelta was returned. """ self._max_retries = max_retries self._total_retries = 0 self._timeout = timeout self._start_time = NOW_FN() self._next_sleep_time = initial_sleep self._max_sleep = max_sleep self._cancel_event = cancel_event self._retry_check = retry_check self._last_custom_timeout: datetime.datetime | None = None def should_retry(self, exception: Exception) -> bool: """Returns whether an exception should be retried.""" if self._total_retries >= self._max_retries: return False self._total_retries += 1 now = NOW_FN() if now - self._start_time >= self._timeout: return False retry_check_result = self._retry_check(exception) if not retry_check_result: return False if isinstance(retry_check_result, datetime.timedelta): if not self._last_custom_timeout: self._last_custom_timeout = now if now - self._last_custom_timeout >= retry_check_result: return False return True def wait_before_retry(self) -> None: """Block until the next retry should happen. Raises: RetryCancelledError: If the operation is cancelled. """ sleep_amount = self._next_sleep_time * (1 + random.random() * 0.25) if self._cancel_event: cancelled = self._cancel_event.wait(sleep_amount) if cancelled: raise RetryCancelledError("Cancelled while retrying.") else: SLEEP_FN(sleep_amount) self._next_sleep_time *= 2 if self._next_sleep_time > self._max_sleep: self._next_sleep_time = self._max_sleep _F = TypeVar("_F", bound=Callable) def retriable(*args: Any, **kargs: Any) -> Callable[[_F], _F]: def decorator(fn: _F) -> _F: retrier: Retry[Any] = Retry(fn, *args, **kargs) @functools.wraps(fn) def wrapped_fn(*args: Any, **kargs: Any) -> Any: return retrier(*args, **kargs) return wrapped_fn # type: ignore return decorator class Backoff(abc.ABC): """A backoff strategy: decides whether to sleep or give up when an exception is raised.""" @abc.abstractmethod def next_sleep_or_reraise(self, exc: Exception) -> datetime.timedelta: raise NotImplementedError # pragma: no cover class ExponentialBackoff(Backoff): """Jittered exponential backoff: sleep times increase ~exponentially up to some limit.""" def __init__( self, initial_sleep: datetime.timedelta, max_sleep: datetime.timedelta, max_retries: int | None = None, timeout_at: datetime.datetime | None = None, ) -> None: self._next_sleep = min(max_sleep, initial_sleep) self._max_sleep = max_sleep self._remaining_retries = max_retries self._timeout_at = timeout_at def next_sleep_or_reraise(self, exc: Exception) -> datetime.timedelta: if self._remaining_retries is not None: if self._remaining_retries <= 0: raise exc self._remaining_retries -= 1 if self._timeout_at is not None and NOW_FN() > self._timeout_at: raise exc result, self._next_sleep = ( self._next_sleep, min(self._max_sleep, self._next_sleep * (1 + random.random())), ) return result class FilteredBackoff(Backoff): """Re-raise any exceptions that fail a predicate; delegate others to another Backoff.""" def __init__(self, filter: Callable[[Exception], bool], wrapped: Backoff) -> None: self._filter = filter self._wrapped = wrapped def next_sleep_or_reraise(self, exc: Exception) -> datetime.timedelta: if not self._filter(exc): raise exc return self._wrapped.next_sleep_or_reraise(exc) async def retry_async( backoff: Backoff, fn: Callable[..., Awaitable[_R]], *args: Any, on_exc: Callable[[Exception], None] | None = None, **kwargs: Any, ) -> _R: """Call `fn` repeatedly until either it succeeds, or `backoff` decides we should give up. Each time `fn` fails, `on_exc` is called with the exception. """ while True: try: return await fn(*args, **kwargs) except Exception as e: if on_exc is not None: on_exc(e) await SLEEP_ASYNC_FN(backoff.next_sleep_or_reraise(e).total_seconds())