# util.py (3.8 KB)

  1. import copy
  2. import logging
  3. import os
  4. import queue
  5. import threading
  6. from typing import Optional
  7. import numpy as np
  8. from ray.air.constants import _ERROR_REPORT_TIMEOUT
  9. logger = logging.getLogger(__name__)
  10. def is_nan(value):
  11. return np.isnan(value)
  12. def is_nan_or_inf(value):
  13. return is_nan(value) or np.isinf(value)
  14. class StartTraceback(Exception):
  15. """These exceptions (and their tracebacks) can be skipped with `skip_exceptions`"""
  16. pass
  17. class StartTracebackWithWorkerRank(StartTraceback):
  18. def __init__(self, worker_rank: int) -> None:
  19. super().__init__()
  20. self.worker_rank = worker_rank
  21. def __reduce__(self):
  22. return (self.__class__, (self.worker_rank,))
  23. def skip_exceptions(exc: Optional[Exception]) -> Exception:
  24. """Skip all contained `StartTracebacks` to reduce traceback output.
  25. Returns a shallow copy of the exception with all `StartTracebacks` removed.
  26. If the RAY_AIR_FULL_TRACEBACKS environment variable is set,
  27. the original exception (not a copy) is returned.
  28. """
  29. should_not_shorten = bool(int(os.environ.get("RAY_AIR_FULL_TRACEBACKS", "0")))
  30. if should_not_shorten:
  31. return exc
  32. if isinstance(exc, StartTraceback):
  33. # If this is a StartTraceback, skip
  34. return skip_exceptions(exc.__cause__)
  35. # Perform a shallow copy to prevent recursive __cause__/__context__.
  36. new_exc = copy.copy(exc).with_traceback(exc.__traceback__)
  37. # Make sure nested exceptions are properly skipped.
  38. cause = getattr(exc, "__cause__", None)
  39. if cause:
  40. new_exc.__cause__ = skip_exceptions(cause)
  41. return new_exc
  42. def exception_cause(exc: Optional[Exception]) -> Optional[Exception]:
  43. if not exc:
  44. return None
  45. return getattr(exc, "__cause__", None)
  46. class RunnerThread(threading.Thread):
  47. """Supervisor thread that runs your script."""
  48. def __init__(self, *args, error_queue, **kwargs):
  49. threading.Thread.__init__(self, *args, **kwargs)
  50. self._error_queue = error_queue
  51. self._ret = None
  52. def _propagate_exception(self, e: BaseException):
  53. try:
  54. # report the error but avoid indefinite blocking which would
  55. # prevent the exception from being propagated in the unlikely
  56. # case that something went terribly wrong
  57. self._error_queue.put(e, block=True, timeout=_ERROR_REPORT_TIMEOUT)
  58. except queue.Full:
  59. logger.critical(
  60. (
  61. "Runner Thread was unable to report error to main "
  62. "function runner thread. This means a previous error "
  63. "was not processed. This should never happen."
  64. )
  65. )
  66. def run(self):
  67. try:
  68. self._ret = self._target(*self._args, **self._kwargs)
  69. except StopIteration:
  70. logger.debug(
  71. (
  72. "Thread runner raised StopIteration. Interpreting it as a "
  73. "signal to terminate the thread without error."
  74. )
  75. )
  76. except SystemExit as e:
  77. # Do not propagate up for graceful termination.
  78. if e.code == 0:
  79. logger.debug(
  80. (
  81. "Thread runner raised SystemExit with error code 0. "
  82. "Interpreting it as a signal to terminate the thread "
  83. "without error."
  84. )
  85. )
  86. else:
  87. # If non-zero exit code, then raise exception to main thread.
  88. self._propagate_exception(e)
  89. except BaseException as e:
  90. # Propagate all other exceptions to the main thread.
  91. self._propagate_exception(e)
  92. def join(self, timeout=None):
  93. super(RunnerThread, self).join(timeout)
  94. return self._ret