retry.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. from __future__ import annotations
  2. import abc
  3. import asyncio
  4. import datetime
  5. import functools
  6. import logging
  7. import os
  8. import random
  9. import threading
  10. import time
  11. from collections.abc import Awaitable
  12. from typing import Any, Callable, Generic, TypeVar
  13. import wandb
  14. import wandb.errors
  15. from wandb.util import CheckRetryFnType
  16. logger = logging.getLogger(__name__)
  17. # To let tests mock out the retry logic's now()/sleep() funcs, this file
  18. # should only use these variables, not call the stdlib funcs directly.
  19. NOW_FN = datetime.datetime.now
  20. SLEEP_FN = time.sleep
  21. SLEEP_ASYNC_FN = asyncio.sleep
  22. class RetryCancelledError(wandb.errors.Error):
  23. """A retry did not occur because it was cancelled."""
  24. class TransientError(Exception):
  25. """Exception type designated for errors that may only be temporary.
  26. Can have its own message and/or wrap another exception.
  27. """
  28. def __init__(
  29. self, msg: str | None = None, exc: BaseException | None = None
  30. ) -> None:
  31. super().__init__(msg)
  32. self.message = msg
  33. self.exception = exc
  34. _R = TypeVar("_R")
  35. class Retry(Generic[_R]):
  36. """Create a retryable version of a function.
  37. Calling this will call the passed function, retrying if any exceptions in
  38. retryable_exceptions are caught, with exponential backoff.
  39. """
  40. MAX_SLEEP_SECONDS = 5 * 60
  41. def __init__(
  42. self,
  43. call_fn: Callable[..., _R],
  44. retry_timedelta: datetime.timedelta | None = None,
  45. retry_cancel_event: threading.Event | None = None,
  46. num_retries: int | None = None,
  47. check_retry_fn: CheckRetryFnType = lambda e: True,
  48. retryable_exceptions: tuple[type[Exception], ...] | None = None,
  49. error_prefix: str = "Network error",
  50. retry_callback: Callable[[int, str], Any] | None = None,
  51. ) -> None:
  52. self._call_fn = call_fn
  53. self._check_retry_fn = check_retry_fn
  54. self._error_prefix = error_prefix
  55. self._last_print = datetime.datetime.now() - datetime.timedelta(minutes=1)
  56. self._retry_timedelta = retry_timedelta
  57. self._retry_cancel_event = retry_cancel_event
  58. self._num_retries = num_retries
  59. if retryable_exceptions is not None:
  60. self._retryable_exceptions = retryable_exceptions
  61. else:
  62. self._retryable_exceptions = (TransientError,)
  63. self.retry_callback = retry_callback
  64. self._num_iter = 0
  65. def _sleep_check_cancelled(
  66. self, wait_seconds: float, cancel_event: threading.Event | None
  67. ) -> bool:
  68. if not cancel_event:
  69. SLEEP_FN(wait_seconds)
  70. return False
  71. cancelled = cancel_event.wait(wait_seconds)
  72. return cancelled
  73. @property
  74. def num_iters(self) -> int:
  75. """The number of iterations the previous __call__ retried."""
  76. return self._num_iter
  77. def __call__(
  78. self,
  79. *args: Any,
  80. num_retries: int | None = None,
  81. retry_timedelta: datetime.timedelta | None = None,
  82. retry_sleep_base: float | None = None,
  83. retry_cancel_event: threading.Event | None = None,
  84. check_retry_fn: CheckRetryFnType | None = None,
  85. **kwargs: Any,
  86. ) -> _R:
  87. """Call the wrapped function, with retries.
  88. Args:
  89. num_retries: The number of retries after which to give up.
  90. retry_timedelta: An amount of time after which to give up.
  91. retry_sleep_base: Number of seconds to sleep for the first retry.
  92. This is used as the base for exponential backoff.
  93. retry_cancel_event: An event that causes this to raise
  94. a RetryCancelledException on the next attempted retry.
  95. check_retry_fn: A custom check for deciding whether an exception
  96. should be retried. Retrying is prevented if this returns a falsy
  97. value, even if more retries are left. This may also return a
  98. timedelta that represents a shorter timeout: retrying is
  99. prevented if the value is less than the amount of time that has
  100. passed since the last timedelta was returned.
  101. """
  102. if os.environ.get("WANDB_TEST"):
  103. max_retries = 0
  104. elif num_retries is not None:
  105. max_retries = num_retries
  106. elif self._num_retries is not None:
  107. max_retries = self._num_retries
  108. else:
  109. max_retries = 1000000
  110. if retry_timedelta is not None:
  111. timeout = retry_timedelta
  112. elif self._retry_timedelta is not None:
  113. timeout = self._retry_timedelta
  114. else:
  115. timeout = datetime.timedelta(days=365)
  116. if retry_sleep_base is not None:
  117. initial_sleep = retry_sleep_base
  118. else:
  119. initial_sleep = 1
  120. retry_loop = _RetryLoop(
  121. max_retries=max_retries,
  122. timeout=timeout,
  123. initial_sleep=initial_sleep,
  124. max_sleep=self.MAX_SLEEP_SECONDS,
  125. cancel_event=retry_cancel_event or self._retry_cancel_event,
  126. retry_check=check_retry_fn or self._check_retry_fn,
  127. )
  128. start_time = NOW_FN()
  129. self._num_iter = 0
  130. while True:
  131. try:
  132. result = self._call_fn(*args, **kwargs)
  133. except self._retryable_exceptions as e:
  134. if not retry_loop.should_retry(e):
  135. raise
  136. if self._num_iter == 2:
  137. logger.info("Retry attempt failed:", exc_info=e)
  138. self._print_entered_retry_loop(e)
  139. retry_loop.wait_before_retry()
  140. self._num_iter += 1
  141. else:
  142. if self._num_iter > 2:
  143. self._print_recovered(start_time)
  144. return result
  145. def _print_entered_retry_loop(self, exception: Exception) -> None:
  146. """Emit a message saying we've begun retrying.
  147. Either calls the retry callback or prints a warning to console.
  148. Args:
  149. exception: The most recent exception we will retry.
  150. """
  151. from requests import HTTPError
  152. if (
  153. isinstance(exception, HTTPError)
  154. and exception.response is not None
  155. and self.retry_callback is not None
  156. ):
  157. self.retry_callback(
  158. exception.response.status_code,
  159. exception.response.text,
  160. )
  161. else:
  162. wandb.termlog(
  163. f"{self._error_prefix}"
  164. + f" ({exception.__class__.__name__}), entering retry loop."
  165. )
  166. def _print_recovered(self, start_time: datetime.datetime) -> None:
  167. """Emit a message saying we've recovered after retrying.
  168. Args:
  169. start_time: When we started retrying.
  170. """
  171. if not self.retry_callback:
  172. return
  173. now = NOW_FN()
  174. if now - self._last_print < datetime.timedelta(minutes=1):
  175. return
  176. self._last_print = now
  177. time_to_recover = now - start_time
  178. self.retry_callback(
  179. 200,
  180. (
  181. f"{self._error_prefix} resolved after"
  182. f" {time_to_recover}, resuming normal operation."
  183. ),
  184. )
  185. class _RetryLoop:
  186. """An invocation of a Retry instance."""
  187. def __init__(
  188. self,
  189. *,
  190. max_retries: int,
  191. timeout: datetime.timedelta,
  192. initial_sleep: float,
  193. max_sleep: float,
  194. cancel_event: threading.Event | None,
  195. retry_check: CheckRetryFnType,
  196. ) -> None:
  197. """Start a new call of a Retry instance.
  198. Args:
  199. max_retries: The number of retries after which to give up.
  200. timeout: An amount of time after which to give up.
  201. initial_sleep: Number of seconds to sleep for the first retry.
  202. This is used as the base for exponential backoff.
  203. max_sleep: Maximum number of seconds to sleep between retries.
  204. cancel_event: An event that's set when the function is cancelled.
  205. retry_check: A custom check for deciding whether an exception should
  206. be retried. Retrying is prevented if this returns a falsy value,
  207. even if more retries are left. This may also return a timedelta
  208. that represents a shorter timeout: retrying is prevented if the
  209. value is less than the amount of time that has passed since the
  210. last timedelta was returned.
  211. """
  212. self._max_retries = max_retries
  213. self._total_retries = 0
  214. self._timeout = timeout
  215. self._start_time = NOW_FN()
  216. self._next_sleep_time = initial_sleep
  217. self._max_sleep = max_sleep
  218. self._cancel_event = cancel_event
  219. self._retry_check = retry_check
  220. self._last_custom_timeout: datetime.datetime | None = None
  221. def should_retry(self, exception: Exception) -> bool:
  222. """Returns whether an exception should be retried."""
  223. if self._total_retries >= self._max_retries:
  224. return False
  225. self._total_retries += 1
  226. now = NOW_FN()
  227. if now - self._start_time >= self._timeout:
  228. return False
  229. retry_check_result = self._retry_check(exception)
  230. if not retry_check_result:
  231. return False
  232. if isinstance(retry_check_result, datetime.timedelta):
  233. if not self._last_custom_timeout:
  234. self._last_custom_timeout = now
  235. if now - self._last_custom_timeout >= retry_check_result:
  236. return False
  237. return True
  238. def wait_before_retry(self) -> None:
  239. """Block until the next retry should happen.
  240. Raises:
  241. RetryCancelledError: If the operation is cancelled.
  242. """
  243. sleep_amount = self._next_sleep_time * (1 + random.random() * 0.25)
  244. if self._cancel_event:
  245. cancelled = self._cancel_event.wait(sleep_amount)
  246. if cancelled:
  247. raise RetryCancelledError("Cancelled while retrying.")
  248. else:
  249. SLEEP_FN(sleep_amount)
  250. self._next_sleep_time *= 2
  251. if self._next_sleep_time > self._max_sleep:
  252. self._next_sleep_time = self._max_sleep
  253. _F = TypeVar("_F", bound=Callable)
  254. def retriable(*args: Any, **kargs: Any) -> Callable[[_F], _F]:
  255. def decorator(fn: _F) -> _F:
  256. retrier: Retry[Any] = Retry(fn, *args, **kargs)
  257. @functools.wraps(fn)
  258. def wrapped_fn(*args: Any, **kargs: Any) -> Any:
  259. return retrier(*args, **kargs)
  260. return wrapped_fn # type: ignore
  261. return decorator
  262. class Backoff(abc.ABC):
  263. """A backoff strategy: decides whether to sleep or give up when an exception is raised."""
  264. @abc.abstractmethod
  265. def next_sleep_or_reraise(self, exc: Exception) -> datetime.timedelta:
  266. raise NotImplementedError # pragma: no cover
  267. class ExponentialBackoff(Backoff):
  268. """Jittered exponential backoff: sleep times increase ~exponentially up to some limit."""
  269. def __init__(
  270. self,
  271. initial_sleep: datetime.timedelta,
  272. max_sleep: datetime.timedelta,
  273. max_retries: int | None = None,
  274. timeout_at: datetime.datetime | None = None,
  275. ) -> None:
  276. self._next_sleep = min(max_sleep, initial_sleep)
  277. self._max_sleep = max_sleep
  278. self._remaining_retries = max_retries
  279. self._timeout_at = timeout_at
  280. def next_sleep_or_reraise(self, exc: Exception) -> datetime.timedelta:
  281. if self._remaining_retries is not None:
  282. if self._remaining_retries <= 0:
  283. raise exc
  284. self._remaining_retries -= 1
  285. if self._timeout_at is not None and NOW_FN() > self._timeout_at:
  286. raise exc
  287. result, self._next_sleep = (
  288. self._next_sleep,
  289. min(self._max_sleep, self._next_sleep * (1 + random.random())),
  290. )
  291. return result
  292. class FilteredBackoff(Backoff):
  293. """Re-raise any exceptions that fail a predicate; delegate others to another Backoff."""
  294. def __init__(self, filter: Callable[[Exception], bool], wrapped: Backoff) -> None:
  295. self._filter = filter
  296. self._wrapped = wrapped
  297. def next_sleep_or_reraise(self, exc: Exception) -> datetime.timedelta:
  298. if not self._filter(exc):
  299. raise exc
  300. return self._wrapped.next_sleep_or_reraise(exc)
  301. async def retry_async(
  302. backoff: Backoff,
  303. fn: Callable[..., Awaitable[_R]],
  304. *args: Any,
  305. on_exc: Callable[[Exception], None] | None = None,
  306. **kwargs: Any,
  307. ) -> _R:
  308. """Call `fn` repeatedly until either it succeeds, or `backoff` decides we should give up.
  309. Each time `fn` fails, `on_exc` is called with the exception.
  310. """
  311. while True:
  312. try:
  313. return await fn(*args, **kwargs)
  314. except Exception as e:
  315. if on_exc is not None:
  316. on_exc(e)
  317. await SLEEP_ASYNC_FN(backoff.next_sleep_or_reraise(e).total_seconds())