_frontend.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. import asyncio
  2. import json
  3. import logging
  4. import os
  5. import socket
  6. import threading
  7. import time
  8. from abc import ABC, abstractmethod
  9. from collections.abc import Callable, Iterable
  10. from concurrent.futures import ThreadPoolExecutor
  11. from dataclasses import dataclass
  12. from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
  13. from urllib.parse import parse_qs, urlparse
  14. from jinja2 import DictLoader, Environment
  15. from torch.distributed.debug._store import get_world_size, tcpstore_client
  16. logger: logging.Logger = logging.getLogger(__name__)
  17. # ---------------------------------------------------------------------------
  18. # Base types
  19. # ---------------------------------------------------------------------------
  20. @dataclass(slots=True)
  21. class Response:
  22. status_code: int
  23. text: str
  24. def raise_for_status(self):
  25. if self.status_code != 200:
  26. raise RuntimeError(f"HTTP {self.status_code}: {self.text}")
  27. def json(self):
  28. return json.loads(self.text)
  29. @dataclass(slots=True)
  30. class NavLink:
  31. path: str
  32. label: str
  33. @dataclass(slots=True)
  34. class Route:
  35. path: str
  36. handler: Callable[["HTTPRequestHandler"], bytes]
  37. class DebugHandler(ABC):
  38. @abstractmethod
  39. def routes(self) -> list[Route]: ...
  40. @abstractmethod
  41. def nav_links(self) -> list[NavLink]: ...
  42. def templates(self) -> dict[str, str]:
  43. return {}
  44. def dump(self) -> str | None:
  45. return None
  46. def dump_filename(self) -> str:
  47. return type(self).__name__.lower()
  48. # ---------------------------------------------------------------------------
  49. # Network helpers
  50. # ---------------------------------------------------------------------------
  51. def fetch_thread_pool(urls: list[str]) -> Iterable[Response]:
  52. # late import for optional dependency
  53. import requests
  54. max_workers = 20
  55. def get(url: str) -> Response:
  56. resp = requests.post(url)
  57. return Response(resp.status_code, resp.text)
  58. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  59. resps = executor.map(get, urls)
  60. return resps
  61. def fetch_aiohttp(urls: list[str]) -> Iterable[Response]:
  62. # late import for optional dependency
  63. # pyrefly: ignore [missing-import]
  64. import aiohttp
  65. async def fetch(session: aiohttp.ClientSession, url: str) -> Response:
  66. async with session.post(url) as resp:
  67. text = await resp.text()
  68. return Response(resp.status, text)
  69. async def gather(urls: list[str]) -> Iterable[Response]:
  70. async with aiohttp.ClientSession() as session:
  71. return await asyncio.gather(*[fetch(session, url) for url in urls])
  72. return asyncio.run(gather(urls))
  73. def fetch_all(endpoint: str, args: str = "") -> tuple[list[str], Iterable[Response]]:
  74. store = tcpstore_client()
  75. keys = [f"rank{r}" for r in range(get_world_size())]
  76. addrs = store.multi_get(keys)
  77. addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs]
  78. try:
  79. resps = fetch_aiohttp(addrs)
  80. except ImportError:
  81. resps = fetch_thread_pool(addrs)
  82. return addrs, resps
  83. def format_json(blob: str):
  84. parsed = json.loads(blob)
  85. return json.dumps(parsed, indent=2)
  86. # ---------------------------------------------------------------------------
  87. # Template constants
  88. # ---------------------------------------------------------------------------
  89. BASE_TEMPLATE = """
  90. <!doctype html>
  91. <head>
  92. <title>{% block title %}{% endblock %} - PyTorch Distributed</title>
  93. <link rel="shortcut icon" type="image/x-icon" href="https://pytorch.org/favicon.ico?">
  94. <style>
  95. body {
  96. margin: 0;
  97. font-family:
  98. -apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,
  99. "Helvetica Neue",Arial,"Noto Sans",sans-serif,"Apple Color Emoji",
  100. "Segoe UI Emoji","Segoe UI Symbol","Noto Color Emoji";
  101. font-size: 1rem;
  102. font-weight: 400;
  103. line-height: 1.5;
  104. color: #212529;
  105. text-align: left;
  106. background-color: #fff;
  107. }
  108. h1, h2, h2, h4, h5, h6, .h1, .h2, .h2, .h4, .h5, .h6 {
  109. margin-bottom: .5rem;
  110. font-weight: 500;
  111. line-height: 1.2;
  112. }
  113. nav {
  114. background-color: rgba(0, 0, 0, 0.17);
  115. padding: 10px;
  116. display: flex;
  117. align-items: center;
  118. padding: 16px;
  119. justify-content: flex-start;
  120. }
  121. nav h1 {
  122. display: inline-block;
  123. margin: 0;
  124. }
  125. nav a {
  126. margin: 0 8px;
  127. }
  128. section {
  129. max-width: 1280px;
  130. padding: 16px;
  131. margin: 0 auto;
  132. }
  133. pre {
  134. white-space: pre-wrap;
  135. max-width: 100%;
  136. }
  137. </style>
  138. </head>
  139. <nav>
  140. <h1>Torch Distributed Debug Server</h1>
  141. {{ nav_links | safe }}
  142. </nav>
  143. <section class="content">
  144. {% block header %}{% endblock %}
  145. {% block content %}{% endblock %}
  146. </section>
  147. """
  148. RAW_RESP_TEMPLATE = """
  149. {% extends "base.html" %}
  150. {% block header %}
  151. <h1>{% block title %}{{title}}{% endblock %}</h1>
  152. {% endblock %}
  153. {% block content %}
  154. {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
  155. <h2>Rank {{ i }}: {{ addr }}</h2>
  156. {% if resp.status_code != 200 %}
  157. <p>Failed to fetch: status={{ resp.status_code }}</p>
  158. <pre>{{ resp.text }}</pre>
  159. {% else %}
  160. <pre>{{ resp.text }}</pre>
  161. {% endif %}
  162. {% endfor %}
  163. {% endblock %}
  164. """
  165. JSON_RESP_TEMPLATE = """
  166. {% extends "base.html" %}
  167. {% block header %}
  168. <h1>{% block title %}{{ title }}{% endblock %}</h1>
  169. {% endblock %}
  170. {% block content %}
  171. {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
  172. <h2>Rank {{ i }}: {{ addr }}</h2>
  173. {% if resp.status_code != 200 %}
  174. <p>Failed to fetch: status={{ resp.status_code }}</p>
  175. <pre>{{ resp.text }}</pre>
  176. {% else %}
  177. <pre>{{ format_json(resp.text) }}</pre>
  178. {% endif %}
  179. {% endfor %}
  180. {% endblock %}
  181. """
  182. # ---------------------------------------------------------------------------
  183. # PeriodicDumper
  184. # ---------------------------------------------------------------------------
  185. class PeriodicDumper:
  186. def __init__(
  187. self,
  188. handlers: list[DebugHandler],
  189. output_dir: str,
  190. interval_seconds: float = 60.0,
  191. ) -> None:
  192. self._handlers = handlers
  193. self._output_dir = output_dir
  194. self._interval_seconds = interval_seconds
  195. self._stop_event = threading.Event()
  196. self._thread: threading.Thread | None = None
  197. def start(self) -> None:
  198. os.makedirs(self._output_dir, exist_ok=True)
  199. self._thread = threading.Thread(
  200. target=self._run,
  201. daemon=True,
  202. name="distributed.debug.PeriodicDumper",
  203. )
  204. self._thread.start()
  205. def stop(self) -> None:
  206. self._stop_event.set()
  207. if self._thread is not None:
  208. self._thread.join()
  209. def _run(self) -> None:
  210. while not self._stop_event.is_set():
  211. for handler in self._handlers:
  212. try:
  213. content = handler.dump()
  214. except Exception:
  215. logger.exception("Failed to dump %s", handler.dump_filename())
  216. continue
  217. if content is None:
  218. continue
  219. timestamp = time.strftime("%Y%m%d_%H%M%S")
  220. filename = f"{handler.dump_filename()}_{timestamp}.txt"
  221. path = os.path.join(self._output_dir, filename)
  222. try:
  223. with open(path, "w") as f:
  224. f.write(content)
  225. except Exception:
  226. logger.exception("Failed to write dump to %s", path)
  227. self._stop_event.wait(self._interval_seconds)
  228. # ---------------------------------------------------------------------------
  229. # HTTP server
  230. # ---------------------------------------------------------------------------
  231. class _IPv6HTTPServer(ThreadingHTTPServer):
  232. address_family: socket.AddressFamily = socket.AF_INET6 # pyre-ignore
  233. request_queue_size: int = 1024
  234. class HTTPRequestHandler(BaseHTTPRequestHandler):
  235. frontend: "FrontendServer"
  236. def log_message(self, format, *args):
  237. logger.info(
  238. "%s %s",
  239. self.client_address[0],
  240. format % args,
  241. )
  242. def do_GET(self):
  243. self.frontend._handle_request(self)
  244. def get_path(self) -> str:
  245. return urlparse(self.path).path
  246. def get_query(self) -> dict[str, list[str]]:
  247. return parse_qs(self.get_raw_query())
  248. def get_raw_query(self) -> str:
  249. return urlparse(self.path).query
  250. def get_query_arg(
  251. self, name: str, default: object = None, type: type = str
  252. ) -> object:
  253. query = self.get_query()
  254. if name not in query:
  255. return default
  256. return type(query[name][0])
  257. class FrontendServer:
  258. def __init__(
  259. self,
  260. port: int,
  261. handlers: list[DebugHandler] | None = None,
  262. ):
  263. if handlers is None:
  264. from torch.distributed.debug._debug_handlers import default_handlers
  265. handlers = default_handlers()
  266. # Build nav HTML from handlers
  267. nav_html = "\n".join(
  268. f' <a href="{link.path}">{link.label}</a> <!--@lint-ignore-->'
  269. for handler in handlers
  270. for link in handler.nav_links()
  271. )
  272. # Merge all handler templates + shared templates
  273. all_templates: dict[str, str] = {
  274. "base.html": BASE_TEMPLATE,
  275. "raw_resp.html": RAW_RESP_TEMPLATE,
  276. "json_resp.html": JSON_RESP_TEMPLATE,
  277. }
  278. for handler in handlers:
  279. all_templates.update(handler.templates())
  280. loader = DictLoader(all_templates)
  281. self._jinja_env = Environment(loader=loader, enable_async=True)
  282. self._jinja_env.globals.update(
  283. zip=zip,
  284. format_json=format_json,
  285. enumerate=enumerate,
  286. nav_links=nav_html,
  287. )
  288. # Build route table from handlers
  289. self._routes: dict[str, Callable[[HTTPRequestHandler], bytes]] = {}
  290. for handler in handlers:
  291. for route in handler.routes():
  292. self._routes[route.path] = route.handler
  293. self._handlers = handlers
  294. # Create HTTP server
  295. RequestHandlerClass = type(
  296. "HTTPRequestHandler",
  297. (HTTPRequestHandler,),
  298. {"frontend": self},
  299. )
  300. server_address = ("", port)
  301. self._server = _IPv6HTTPServer(server_address, RequestHandlerClass)
  302. self._thread = threading.Thread(
  303. target=self._serve,
  304. args=(),
  305. daemon=True,
  306. name="distributed.debug.FrontendServer",
  307. )
  308. self._thread.start()
  309. def _serve(self) -> None:
  310. try:
  311. self._server.serve_forever()
  312. except Exception:
  313. logger.exception("got exception in frontend server")
  314. def join(self) -> None:
  315. self._thread.join()
  316. def _handle_request(self, req: HTTPRequestHandler) -> None:
  317. path = req.get_path()
  318. if path not in self._routes:
  319. req.send_error(404, f"Handler not found: {path}")
  320. return
  321. handler = self._routes[path]
  322. try:
  323. resp = handler(req)
  324. # Catch SystemExit to not crash when FlightRecorder errors.
  325. except (Exception, SystemExit) as e:
  326. logger.exception(
  327. "Exception in frontend server when handling %s",
  328. path,
  329. )
  330. req.send_error(500, f"Exception: {repr(e)}")
  331. return
  332. req.send_response(200)
  333. req.send_header("Content-type", "text/html")
  334. req.end_headers()
  335. req.wfile.write(resp)
  336. def render_template(self, template: str, **kwargs: object) -> bytes:
  337. return self._jinja_env.get_template(template).render(**kwargs).encode()
  338. def main(
  339. port: int,
  340. dump_dir: str | None,
  341. dump_interval: float,
  342. handlers: list[DebugHandler],
  343. enabled_dumps: set[str],
  344. ) -> None:
  345. logger.setLevel(logging.INFO)
  346. server = FrontendServer(port=port, handlers=handlers)
  347. logger.info("Frontend server started on port %d", server._server.server_port)
  348. dumper: PeriodicDumper | None = None
  349. if dump_dir is not None:
  350. dumper = PeriodicDumper(
  351. [
  352. handler
  353. for handler in handlers
  354. if handler.dump_filename() in enabled_dumps
  355. ],
  356. dump_dir,
  357. dump_interval,
  358. )
  359. dumper.start()
  360. logger.info(
  361. "Periodic dumper started, writing to %s every %.0fs",
  362. dump_dir,
  363. dump_interval,
  364. )
  365. try:
  366. server.join()
  367. finally:
  368. if dumper is not None:
  369. dumper.stop()