rpdb.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. # Some code in this file is from
  2. # https://github.com/ionelmc/python-remote-pdb/blob/07d563331c4ab9eb45731bb272b158816d98236e/src/remote_pdb.py
  3. # (BSD 2-Clause "Simplified" License)
  4. import errno
  5. import inspect
  6. import json
  7. import logging
  8. import os
  9. import re
  10. import select
  11. import socket
  12. import sys
  13. import time
  14. import traceback
  15. import uuid
  16. from pdb import Pdb
  17. from typing import Callable
  18. import ray
  19. from ray._common.network_utils import build_address, is_ipv6
  20. from ray._private import ray_constants
  21. from ray.experimental.internal_kv import _internal_kv_del, _internal_kv_put
  22. from ray.util.annotations import DeveloperAPI
  23. log = logging.getLogger(__name__)
  24. def _cry(message, stderr=sys.__stderr__):
  25. print(message, file=stderr)
  26. stderr.flush()
  27. class _LF2CRLF_FileWrapper(object):
  28. def __init__(self, connection):
  29. self.connection = connection
  30. self.stream = fh = connection.makefile("rw")
  31. self.read = fh.read
  32. self.readline = fh.readline
  33. self.readlines = fh.readlines
  34. self.close = fh.close
  35. self.flush = fh.flush
  36. self.fileno = fh.fileno
  37. if hasattr(fh, "encoding"):
  38. self._send = lambda data: connection.sendall(
  39. data.encode(fh.encoding, errors="replace")
  40. )
  41. else:
  42. self._send = connection.sendall
  43. @property
  44. def encoding(self):
  45. return self.stream.encoding
  46. def __iter__(self):
  47. return self.stream.__iter__()
  48. def write(self, data, nl_rex=re.compile("\r?\n")):
  49. data = nl_rex.sub("\r\n", data)
  50. self._send(data)
  51. def writelines(self, lines, nl_rex=re.compile("\r?\n")):
  52. for line in lines:
  53. self.write(line, nl_rex)
  54. class _PdbWrap(Pdb):
  55. """Wrap PDB to run a custom exit hook on continue."""
  56. def __init__(self, exit_hook: Callable[[], None]):
  57. self._exit_hook = exit_hook
  58. Pdb.__init__(self)
  59. def do_continue(self, arg):
  60. self._exit_hook()
  61. return Pdb.do_continue(self, arg)
  62. do_c = do_cont = do_continue
  63. class _RemotePdb(Pdb):
  64. """
  65. This will run pdb as a ephemeral telnet service. Once you connect no one
  66. else can connect. On construction this object will block execution till a
  67. client has connected.
  68. Based on https://github.com/tamentis/rpdb I think ...
  69. To use this::
  70. RemotePdb(host="0.0.0.0", port=4444).set_trace()
  71. Then run: telnet 127.0.0.1 4444
  72. """
  73. active_instance = None
  74. def __init__(
  75. self,
  76. breakpoint_uuid,
  77. host,
  78. port,
  79. ip_address,
  80. patch_stdstreams=False,
  81. quiet=False,
  82. ):
  83. self._breakpoint_uuid = breakpoint_uuid
  84. self._quiet = quiet
  85. self._patch_stdstreams = patch_stdstreams
  86. self._listen_socket = socket.socket(
  87. socket.AF_INET6 if is_ipv6(host) else socket.AF_INET, socket.SOCK_STREAM
  88. )
  89. self._listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
  90. self._listen_socket.bind((host, port))
  91. self._ip_address = ip_address
  92. def listen(self):
  93. if not self._quiet:
  94. _cry(
  95. "RemotePdb session open at %s, "
  96. "use 'ray debug' to connect..."
  97. % build_address(self._ip_address, self._listen_socket.getsockname()[1])
  98. )
  99. self._listen_socket.listen(1)
  100. connection, address = self._listen_socket.accept()
  101. if not self._quiet:
  102. _cry(f"RemotePdb accepted connection from {address}")
  103. self.handle = _LF2CRLF_FileWrapper(connection)
  104. Pdb.__init__(
  105. self,
  106. completekey="tab",
  107. stdin=self.handle,
  108. stdout=self.handle,
  109. skip=["ray.*"],
  110. )
  111. self.backup = []
  112. if self._patch_stdstreams:
  113. for name in (
  114. "stderr",
  115. "stdout",
  116. "__stderr__",
  117. "__stdout__",
  118. "stdin",
  119. "__stdin__",
  120. ):
  121. self.backup.append((name, getattr(sys, name)))
  122. setattr(sys, name, self.handle)
  123. _RemotePdb.active_instance = self
  124. def __restore(self):
  125. if self.backup and not self._quiet:
  126. _cry("Restoring streams: %s ..." % self.backup)
  127. for name, fh in self.backup:
  128. setattr(sys, name, fh)
  129. self.handle.close()
  130. _RemotePdb.active_instance = None
  131. def do_quit(self, arg):
  132. self.__restore()
  133. return Pdb.do_quit(self, arg)
  134. do_q = do_exit = do_quit
  135. def do_continue(self, arg):
  136. self.__restore()
  137. self.handle.connection.close()
  138. return Pdb.do_continue(self, arg)
  139. do_c = do_cont = do_continue
  140. def set_trace(self, frame=None):
  141. if frame is None:
  142. frame = sys._getframe().f_back
  143. try:
  144. Pdb.set_trace(self, frame)
  145. except IOError as exc:
  146. if exc.errno != errno.ECONNRESET:
  147. raise
  148. def post_mortem(self, traceback=None):
  149. # See https://github.com/python/cpython/blob/
  150. # 022bc7572f061e1d1132a4db9d085b29707701e7/Lib/pdb.py#L1617
  151. try:
  152. t = sys.exc_info()[2]
  153. self.reset()
  154. Pdb.interaction(self, None, t)
  155. except IOError as exc:
  156. if exc.errno != errno.ECONNRESET:
  157. raise
  158. def do_remote(self, arg):
  159. """remote
  160. Skip into the next remote call.
  161. """
  162. # Tell the next task to drop into the debugger.
  163. ray._private.worker.global_worker.debugger_breakpoint = self._breakpoint_uuid
  164. # Tell the debug loop to connect to the next task.
  165. data = json.dumps(
  166. {
  167. "job_id": ray.get_runtime_context().get_job_id(),
  168. }
  169. )
  170. _internal_kv_put(
  171. "RAY_PDB_CONTINUE_{}".format(self._breakpoint_uuid),
  172. data,
  173. namespace=ray_constants.KV_NAMESPACE_PDB,
  174. )
  175. self.__restore()
  176. self.handle.connection.close()
  177. return Pdb.do_continue(self, arg)
  178. def do_get(self, arg):
  179. """get
  180. Skip to where the current task returns to.
  181. """
  182. ray._private.worker.global_worker.debugger_get_breakpoint = (
  183. self._breakpoint_uuid
  184. )
  185. self.__restore()
  186. self.handle.connection.close()
  187. return Pdb.do_continue(self, arg)
  188. def _connect_ray_pdb(
  189. host=None,
  190. port=None,
  191. patch_stdstreams=False,
  192. quiet=None,
  193. breakpoint_uuid=None,
  194. debugger_external=False,
  195. ):
  196. """
  197. Opens a remote PDB on first available port.
  198. """
  199. if debugger_external:
  200. assert not host, "Cannot specify both host and debugger_external"
  201. host = "0.0.0.0"
  202. elif host is None:
  203. host = os.environ.get("REMOTE_PDB_HOST", "127.0.0.1")
  204. if port is None:
  205. port = int(os.environ.get("REMOTE_PDB_PORT", "0"))
  206. if quiet is None:
  207. quiet = bool(os.environ.get("REMOTE_PDB_QUIET", ""))
  208. if not breakpoint_uuid:
  209. breakpoint_uuid = uuid.uuid4().hex
  210. if debugger_external:
  211. ip_address = ray._private.worker.global_worker.node_ip_address
  212. else:
  213. ip_address = "localhost"
  214. rdb = _RemotePdb(
  215. breakpoint_uuid=breakpoint_uuid,
  216. host=host,
  217. port=port,
  218. ip_address=ip_address,
  219. patch_stdstreams=patch_stdstreams,
  220. quiet=quiet,
  221. )
  222. sockname = rdb._listen_socket.getsockname()
  223. pdb_address = build_address(ip_address, sockname[1])
  224. parentframeinfo = inspect.getouterframes(inspect.currentframe())[2]
  225. data = {
  226. "proctitle": ray._raylet.getproctitle(),
  227. "pdb_address": pdb_address,
  228. "filename": parentframeinfo.filename,
  229. "lineno": parentframeinfo.lineno,
  230. "traceback": "\n".join(traceback.format_exception(*sys.exc_info())),
  231. "timestamp": time.time(),
  232. "job_id": ray.get_runtime_context().get_job_id(),
  233. "node_id": ray.get_runtime_context().get_node_id(),
  234. "worker_id": ray.get_runtime_context().get_worker_id(),
  235. "actor_id": ray.get_runtime_context().get_actor_id(),
  236. "task_id": ray.get_runtime_context().get_task_id(),
  237. }
  238. _internal_kv_put(
  239. "RAY_PDB_{}".format(breakpoint_uuid),
  240. json.dumps(data),
  241. overwrite=True,
  242. namespace=ray_constants.KV_NAMESPACE_PDB,
  243. )
  244. rdb.listen()
  245. _internal_kv_del(
  246. "RAY_PDB_{}".format(breakpoint_uuid), namespace=ray_constants.KV_NAMESPACE_PDB
  247. )
  248. return rdb
  249. @DeveloperAPI
  250. def set_trace(breakpoint_uuid=None):
  251. """Interrupt the flow of the program and drop into the Ray debugger.
  252. Can be used within a Ray task or actor.
  253. """
  254. if os.environ.get("RAY_DEBUG", "1") == "1":
  255. return ray.util.ray_debugpy.set_trace(breakpoint_uuid)
  256. if os.environ.get("RAY_DEBUG", "1") == "legacy":
  257. # If there is an active debugger already, we do not want to
  258. # start another one, so "set_trace" is just a no-op in that case.
  259. if ray._private.worker.global_worker.debugger_breakpoint == b"":
  260. frame = sys._getframe().f_back
  261. rdb = _connect_ray_pdb(
  262. host=None,
  263. port=None,
  264. patch_stdstreams=False,
  265. quiet=None,
  266. breakpoint_uuid=breakpoint_uuid.decode() if breakpoint_uuid else None,
  267. debugger_external=ray._private.worker.global_worker.ray_debugger_external, # noqa: E501
  268. )
  269. rdb.set_trace(frame=frame)
  270. def _driver_set_trace():
  271. """The breakpoint hook to use for the driver.
  272. This disables Ray driver logs temporarily so that the PDB console is not
  273. spammed: https://github.com/ray-project/ray/issues/18172
  274. """
  275. if os.environ.get("RAY_DEBUG", "1") == "1":
  276. return ray.util.ray_debugpy.set_trace()
  277. if os.environ.get("RAY_DEBUG", "1") == "legacy":
  278. print("*** Temporarily disabling Ray worker logs ***")
  279. ray._private.worker._worker_logs_enabled = False
  280. def enable_logging():
  281. print("*** Re-enabling Ray worker logs ***")
  282. ray._private.worker._worker_logs_enabled = True
  283. pdb = _PdbWrap(enable_logging)
  284. frame = sys._getframe().f_back
  285. pdb.set_trace(frame)
  286. def _is_ray_debugger_post_mortem_enabled():
  287. return os.environ.get("RAY_DEBUG_POST_MORTEM", "0") == "1"
  288. def _post_mortem():
  289. if os.environ.get("RAY_DEBUG", "1") == "1":
  290. return ray.util.ray_debugpy._post_mortem()
  291. rdb = _connect_ray_pdb(
  292. host=None,
  293. port=None,
  294. patch_stdstreams=False,
  295. quiet=None,
  296. debugger_external=ray._private.worker.global_worker.ray_debugger_external,
  297. )
  298. rdb.post_mortem()
  299. def _connect_pdb_client(host, port):
  300. if sys.platform == "win32":
  301. import msvcrt
  302. s = socket.socket(
  303. socket.AF_INET6 if is_ipv6(host) else socket.AF_INET, socket.SOCK_STREAM
  304. )
  305. s.connect((host, port))
  306. while True:
  307. # Get the list of sockets which are readable.
  308. if sys.platform == "win32":
  309. ready_to_read = select.select([s], [], [], 1)[0]
  310. if msvcrt.kbhit():
  311. ready_to_read.append(sys.stdin)
  312. if not ready_to_read and not sys.stdin.isatty():
  313. # in tests, when using pexpect, the pipe makes
  314. # the msvcrt.kbhit() trick fail. Assume we are waiting
  315. # for stdin, since this will block waiting for input
  316. ready_to_read.append(sys.stdin)
  317. else:
  318. ready_to_read, write_sockets, error_sockets = select.select(
  319. [sys.stdin, s], [], []
  320. )
  321. for sock in ready_to_read:
  322. if sock == s:
  323. # Incoming message from remote debugger.
  324. data = sock.recv(4096)
  325. if not data:
  326. return
  327. else:
  328. sys.stdout.write(data.decode())
  329. sys.stdout.flush()
  330. else:
  331. # User entered a message.
  332. msg = sys.stdin.readline()
  333. s.send(msg.encode())