| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 |
- import copy
- import logging
- import os
- import queue
- import threading
- from typing import Optional
- import numpy as np
- from ray.air.constants import _ERROR_REPORT_TIMEOUT
- logger = logging.getLogger(__name__)
def is_nan(value):
    """Report whether ``value`` is NaN, delegating the check to ``np.isnan``."""
    nan_flag = np.isnan(value)
    return nan_flag
def is_nan_or_inf(value):
    """Return a truthy result when ``value`` is NaN or infinite.

    The NaN check runs first; the infinity check (``np.isinf``) is only
    evaluated when the NaN check is falsy.
    """
    nan_result = is_nan(value)
    if nan_result:
        return nan_result
    return np.isinf(value)
class StartTraceback(Exception):
    """Marker exception that `skip_exceptions` can skip, tracebacks included."""
class StartTracebackWithWorkerRank(StartTraceback):
    """A `StartTraceback` that additionally carries a ``worker_rank``."""

    def __init__(self, worker_rank: int) -> None:
        super().__init__()
        self.worker_rank = worker_rank

    def __reduce__(self):
        # The base initializer receives no arguments, so ``self.args`` is
        # empty; supply the rank explicitly so that reconstruction calls
        # the constructor with its required argument.
        rebuild_args = (self.worker_rank,)
        return (self.__class__, rebuild_args)
def skip_exceptions(exc: Optional[Exception]) -> Optional[Exception]:
    """Skip all contained `StartTracebacks` to reduce traceback output.

    Returns a shallow copy of the exception with all `StartTracebacks` removed.

    If the RAY_AIR_FULL_TRACEBACKS environment variable is set to a non-zero
    integer, the original exception (not a copy) is returned.

    Args:
        exc: The exception to process, or ``None``.

    Returns:
        ``None`` if ``exc`` is ``None``. Otherwise the first exception in the
        ``__cause__`` chain that is not a `StartTraceback` (shallow-copied,
        with its own cause chain filtered the same way), or ``None`` if the
        chain consists only of `StartTracebacks`.
    """
    if exc is None:
        # Reached e.g. when a `StartTraceback` was raised without an explicit
        # `from` clause and therefore has no `__cause__` to continue with.
        return None

    should_not_shorten = bool(int(os.environ.get("RAY_AIR_FULL_TRACEBACKS", "0")))
    if should_not_shorten:
        return exc

    if isinstance(exc, StartTraceback):
        # If this is a StartTraceback, skip
        return skip_exceptions(exc.__cause__)

    # Perform a shallow copy to prevent recursive __cause__/__context__.
    new_exc = copy.copy(exc).with_traceback(exc.__traceback__)

    # Make sure nested exceptions are properly skipped. Compare against None
    # explicitly so that an exception type that happens to be falsy is not
    # silently dropped from the chain.
    cause = getattr(exc, "__cause__", None)
    if cause is not None:
        new_exc.__cause__ = skip_exceptions(cause)

    return new_exc
def exception_cause(exc: Optional[Exception]) -> Optional[Exception]:
    """Return the ``__cause__`` of ``exc``, or ``None`` when there is none.

    A falsy ``exc`` (such as ``None``) yields ``None`` as well.
    """
    return getattr(exc, "__cause__", None) if exc else None
class RunnerThread(threading.Thread):
    """Supervisor thread that runs your script.

    The return value of the thread's target is stored and handed back by
    `join`. Exceptions raised by the target are not re-raised in the thread;
    they are put on ``error_queue`` instead, except for the cases that `run`
    interprets as a graceful termination.

    Args:
        *args: Positional arguments forwarded to ``threading.Thread``.
        error_queue: Queue-like object whose ``put(item, block=..., timeout=...)``
            is used to report exceptions raised by the target.
        **kwargs: Keyword arguments forwarded to ``threading.Thread``.
    """

    def __init__(self, *args, error_queue, **kwargs):
        super().__init__(*args, **kwargs)
        self._error_queue = error_queue
        # Return value of the target; stays None until the target returns.
        self._ret = None

    def _propagate_exception(self, e: BaseException):
        """Put ``e`` on the error queue, logging if the queue stays full."""
        try:
            # report the error but avoid indefinite blocking which would
            # prevent the exception from being propagated in the unlikely
            # case that something went terribly wrong
            self._error_queue.put(e, block=True, timeout=_ERROR_REPORT_TIMEOUT)
        except queue.Full:
            logger.critical(
                "Runner Thread was unable to report error to main "
                "function runner thread. This means a previous error "
                "was not processed. This should never happen."
            )

    def run(self):
        """Run the target, store its result and report failures via the queue."""
        try:
            self._ret = self._target(*self._args, **self._kwargs)
        except StopIteration:
            logger.debug(
                "Thread runner raised StopIteration. Interpreting it as a "
                "signal to terminate the thread without error."
            )
        except SystemExit as e:
            # Do not propagate up for graceful termination. A bare
            # `sys.exit()` carries code None, which is equivalent to exit
            # status zero and is therefore treated like an explicit 0.
            if e.code is None or e.code == 0:
                logger.debug(
                    "Thread runner raised SystemExit with error code 0. "
                    "Interpreting it as a signal to terminate the thread "
                    "without error."
                )
            else:
                # If non-zero exit code, then raise exception to main thread.
                self._propagate_exception(e)
        except BaseException as e:
            # Propagate all other exceptions to the main thread.
            self._propagate_exception(e)

    def join(self, timeout=None):
        """Join the thread and return the target's stored return value.

        The result is ``None`` if the target has not returned a value (for
        example because it raised, or has not finished within ``timeout``).
        """
        super().join(timeout)
        return self._ret
|