exit_hooks.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. from __future__ import annotations
  2. import sys
  3. import traceback
  4. from types import TracebackType
  5. from typing import TYPE_CHECKING
  6. import wandb
  7. from wandb.errors import Error
  8. if TYPE_CHECKING:
  9. from typing import NoReturn
  10. class ExitHooks:
  11. exception: BaseException | None = None
  12. def __init__(self) -> None:
  13. self.exit_code = 0
  14. self.exception = None
  15. def hook(self) -> None:
  16. self._orig_exit = sys.exit
  17. sys.exit = self.exit
  18. self._orig_excepthook = (
  19. sys.excepthook
  20. if sys.excepthook
  21. != sys.__excepthook__ # respect hooks by other libraries like pdb
  22. else None
  23. )
  24. sys.excepthook = self.exc_handler # type: ignore
  25. def exit(self, code: object = 0) -> NoReturn:
  26. orig_code = code
  27. code = code if code is not None else 0
  28. code = code if isinstance(code, int) else 1
  29. self.exit_code = code
  30. self._orig_exit(orig_code) # type: ignore
  31. def was_ctrl_c(self) -> bool:
  32. return isinstance(self.exception, KeyboardInterrupt)
  33. def exc_handler(
  34. self, exc_type: type[BaseException], exc: BaseException, tb: TracebackType
  35. ) -> None:
  36. self.exit_code = 1
  37. self.exception = exc
  38. if issubclass(exc_type, Error):
  39. wandb.termerror(str(exc), repeat=False)
  40. if self.was_ctrl_c():
  41. self.exit_code = 255
  42. traceback.print_exception(exc_type, exc, tb)
  43. if self._orig_excepthook:
  44. self._orig_excepthook(exc_type, exc, tb)