debugpy.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import importlib
  2. import logging
  3. import os
  4. import sys
  5. import threading
  6. import ray
  7. from ray._common.network_utils import build_address
  8. from ray.util.annotations import DeveloperAPI
  9. log = logging.getLogger(__name__)
  10. POST_MORTEM_ERROR_UUID = "post_mortem_error_uuid"
  11. def _try_import_debugpy():
  12. try:
  13. debugpy = importlib.import_module("debugpy")
  14. if not hasattr(debugpy, "__version__") or debugpy.__version__ < "1.8.0":
  15. raise ImportError()
  16. return debugpy
  17. except (ModuleNotFoundError, ImportError):
  18. log.error(
  19. "Module 'debugpy>=1.8.0' cannot be loaded. "
  20. "Ray Debugpy Debugger will not work without 'debugpy>=1.8.0' installed. "
  21. "Install this module using 'pip install debugpy==1.8.0' "
  22. )
  23. return None
  24. # A lock to ensure that only one thread can open the debugger port.
  25. debugger_port_lock = threading.Lock()
  26. def _override_breakpoint_hooks():
  27. """
  28. This method overrides the breakpoint() function to set_trace()
  29. so that other threads can reuse the same setup logic.
  30. This is based on: https://github.com/microsoft/debugpy/blob/ef9a67fe150179ee4df9997f9273723c26687fab/src/debugpy/_vendored/pydevd/pydev_sitecustomize/sitecustomize.py#L87 # noqa: E501
  31. """
  32. sys.__breakpointhook__ = set_trace
  33. sys.breakpointhook = set_trace
  34. import builtins as __builtin__
  35. __builtin__.breakpoint = set_trace
  36. def _ensure_debugger_port_open_thread_safe():
  37. """
  38. This is a thread safe method that ensure that the debugger port
  39. is open, and if not, open it.
  40. """
  41. # The lock is acquired before checking the debugger port so only
  42. # one thread can open the debugger port.
  43. with debugger_port_lock:
  44. debugpy = _try_import_debugpy()
  45. if not debugpy:
  46. return
  47. debugger_port = ray._private.worker.global_worker.debugger_port
  48. if not debugger_port:
  49. (host, port) = debugpy.listen(
  50. (ray._private.worker.global_worker.node_ip_address, 0)
  51. )
  52. ray._private.worker.global_worker.set_debugger_port(port)
  53. log.info(f"Ray debugger is listening on {build_address(host, port)}")
  54. else:
  55. log.info(f"Ray debugger is already open on {debugger_port}")
  56. @DeveloperAPI
  57. def set_trace(breakpoint_uuid=None):
  58. """Interrupt the flow of the program and drop into the Ray debugger.
  59. Can be used within a Ray task or actor.
  60. """
  61. debugpy = _try_import_debugpy()
  62. if not debugpy:
  63. return
  64. _ensure_debugger_port_open_thread_safe()
  65. # debugpy overrides the breakpoint() function, so we need to set it back
  66. # so other threads can reuse it.
  67. _override_breakpoint_hooks()
  68. with ray._private.worker.global_worker.worker_paused_by_debugger():
  69. msg = (
  70. "Waiting for debugger to attach (see "
  71. "https://docs.ray.io/en/latest/ray-observability/"
  72. "ray-distributed-debugger.html)..."
  73. )
  74. log.info(msg)
  75. debugpy.wait_for_client()
  76. log.info("Debugger client is connected")
  77. if breakpoint_uuid == POST_MORTEM_ERROR_UUID:
  78. _debugpy_excepthook()
  79. else:
  80. _debugpy_breakpoint()
  81. def _debugpy_breakpoint():
  82. """
  83. Drop the user into the debugger on a breakpoint.
  84. """
  85. import pydevd
  86. pydevd.settrace(stop_at_frame=sys._getframe().f_back)
  87. def _debugpy_excepthook():
  88. """
  89. Drop the user into the debugger on an unhandled exception.
  90. """
  91. import threading
  92. import pydevd
  93. py_db = pydevd.get_global_debugger()
  94. thread = threading.current_thread()
  95. additional_info = py_db.set_additional_thread_info(thread)
  96. additional_info.is_tracing += 1
  97. try:
  98. error = sys.exc_info()
  99. py_db.stop_on_unhandled_exception(py_db, thread, additional_info, error)
  100. sys.excepthook(error[0], error[1], error[2])
  101. finally:
  102. additional_info.is_tracing -= 1
  103. def _is_ray_debugger_post_mortem_enabled():
  104. return os.environ.get("RAY_DEBUG_POST_MORTEM", "0") == "1"
  105. def _post_mortem():
  106. return set_trace(POST_MORTEM_ERROR_UUID)