import asyncio import json import logging import os import socket import threading import time from abc import ABC, abstractmethod from collections.abc import Callable, Iterable from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from urllib.parse import parse_qs, urlparse from jinja2 import DictLoader, Environment from torch.distributed.debug._store import get_world_size, tcpstore_client logger: logging.Logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Base types # --------------------------------------------------------------------------- @dataclass(slots=True) class Response: status_code: int text: str def raise_for_status(self): if self.status_code != 200: raise RuntimeError(f"HTTP {self.status_code}: {self.text}") def json(self): return json.loads(self.text) @dataclass(slots=True) class NavLink: path: str label: str @dataclass(slots=True) class Route: path: str handler: Callable[["HTTPRequestHandler"], bytes] class DebugHandler(ABC): @abstractmethod def routes(self) -> list[Route]: ... @abstractmethod def nav_links(self) -> list[NavLink]: ... def templates(self) -> dict[str, str]: return {} def dump(self) -> str | None: return None def dump_filename(self) -> str: return type(self).__name__.lower() # --------------------------------------------------------------------------- # Network helpers # --------------------------------------------------------------------------- def fetch_thread_pool(urls: list[str]) -> Iterable[Response]: # late import for optional dependency import requests max_workers = 20 def get(url: str) -> Response: resp = requests.post(url) return Response(resp.status_code, resp.text) with ThreadPoolExecutor(max_workers=max_workers) as executor: resps = executor.map(get, urls) return resps def fetch_aiohttp(urls: list[str]) -> Iterable[Response]: # late import for optional dependency # pyrefly: ignore [missing-import] import aiohttp async def fetch(session: aiohttp.ClientSession, url: str) -> Response: async with session.post(url) as resp: text = await resp.text() return Response(resp.status, text) async def gather(urls: list[str]) -> Iterable[Response]: async with aiohttp.ClientSession() as session: return await asyncio.gather(*[fetch(session, url) for url in urls]) return asyncio.run(gather(urls)) def fetch_all(endpoint: str, args: str = "") -> tuple[list[str], Iterable[Response]]: store = tcpstore_client() keys = [f"rank{r}" for r in range(get_world_size())] addrs = store.multi_get(keys) addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs] try: resps = fetch_aiohttp(addrs) except ImportError: resps = fetch_thread_pool(addrs) return addrs, resps def format_json(blob: str): parsed = json.loads(blob) return json.dumps(parsed, indent=2) # --------------------------------------------------------------------------- # Template constants # --------------------------------------------------------------------------- BASE_TEMPLATE = """ {% block title %}{% endblock %} - PyTorch Distributed
{% block header %}{% endblock %} {% block content %}{% endblock %}
""" RAW_RESP_TEMPLATE = """ {% extends "base.html" %} {% block header %}

{% block title %}{{title}}{% endblock %}

{% endblock %} {% block content %} {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}

Rank {{ i }}: {{ addr }}

{% if resp.status_code != 200 %}

Failed to fetch: status={{ resp.status_code }}

{{ resp.text }}
{% else %}
{{ resp.text }}
{% endif %} {% endfor %} {% endblock %} """ JSON_RESP_TEMPLATE = """ {% extends "base.html" %} {% block header %}

{% block title %}{{ title }}{% endblock %}

{% endblock %} {% block content %} {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}

Rank {{ i }}: {{ addr }}

{% if resp.status_code != 200 %}

Failed to fetch: status={{ resp.status_code }}

{{ resp.text }}
{% else %}
{{ format_json(resp.text) }}
{% endif %} {% endfor %} {% endblock %} """ # --------------------------------------------------------------------------- # PeriodicDumper # --------------------------------------------------------------------------- class PeriodicDumper: def __init__( self, handlers: list[DebugHandler], output_dir: str, interval_seconds: float = 60.0, ) -> None: self._handlers = handlers self._output_dir = output_dir self._interval_seconds = interval_seconds self._stop_event = threading.Event() self._thread: threading.Thread | None = None def start(self) -> None: os.makedirs(self._output_dir, exist_ok=True) self._thread = threading.Thread( target=self._run, daemon=True, name="distributed.debug.PeriodicDumper", ) self._thread.start() def stop(self) -> None: self._stop_event.set() if self._thread is not None: self._thread.join() def _run(self) -> None: while not self._stop_event.is_set(): for handler in self._handlers: try: content = handler.dump() except Exception: logger.exception("Failed to dump %s", handler.dump_filename()) continue if content is None: continue timestamp = time.strftime("%Y%m%d_%H%M%S") filename = f"{handler.dump_filename()}_{timestamp}.txt" path = os.path.join(self._output_dir, filename) try: with open(path, "w") as f: f.write(content) except Exception: logger.exception("Failed to write dump to %s", path) self._stop_event.wait(self._interval_seconds) # --------------------------------------------------------------------------- # HTTP server # --------------------------------------------------------------------------- class _IPv6HTTPServer(ThreadingHTTPServer): address_family: socket.AddressFamily = socket.AF_INET6 # pyre-ignore request_queue_size: int = 1024 class HTTPRequestHandler(BaseHTTPRequestHandler): frontend: "FrontendServer" def log_message(self, format, *args): logger.info( "%s %s", self.client_address[0], format % args, ) def do_GET(self): self.frontend._handle_request(self) def get_path(self) -> str: return urlparse(self.path).path def get_query(self) -> dict[str, list[str]]: return parse_qs(self.get_raw_query()) def get_raw_query(self) -> str: return urlparse(self.path).query def get_query_arg( self, name: str, default: object = None, type: type = str ) -> object: query = self.get_query() if name not in query: return default return type(query[name][0]) class FrontendServer: def __init__( self, port: int, handlers: list[DebugHandler] | None = None, ): if handlers is None: from torch.distributed.debug._debug_handlers import default_handlers handlers = default_handlers() # Build nav HTML from handlers nav_html = "\n".join( f' {link.label} ' for handler in handlers for link in handler.nav_links() ) # Merge all handler templates + shared templates all_templates: dict[str, str] = { "base.html": BASE_TEMPLATE, "raw_resp.html": RAW_RESP_TEMPLATE, "json_resp.html": JSON_RESP_TEMPLATE, } for handler in handlers: all_templates.update(handler.templates()) loader = DictLoader(all_templates) self._jinja_env = Environment(loader=loader, enable_async=True) self._jinja_env.globals.update( zip=zip, format_json=format_json, enumerate=enumerate, nav_links=nav_html, ) # Build route table from handlers self._routes: dict[str, Callable[[HTTPRequestHandler], bytes]] = {} for handler in handlers: for route in handler.routes(): self._routes[route.path] = route.handler self._handlers = handlers # Create HTTP server RequestHandlerClass = type( "HTTPRequestHandler", (HTTPRequestHandler,), {"frontend": self}, ) server_address = ("", port) self._server = _IPv6HTTPServer(server_address, RequestHandlerClass) self._thread = threading.Thread( target=self._serve, args=(), daemon=True, name="distributed.debug.FrontendServer", ) self._thread.start() def _serve(self) -> None: try: self._server.serve_forever() except Exception: logger.exception("got exception in frontend server") def join(self) -> None: self._thread.join() def _handle_request(self, req: HTTPRequestHandler) -> None: path = req.get_path() if path not in self._routes: req.send_error(404, f"Handler not found: {path}") return handler = self._routes[path] try: resp = handler(req) # Catch SystemExit to not crash when FlightRecorder errors. except (Exception, SystemExit) as e: logger.exception( "Exception in frontend server when handling %s", path, ) req.send_error(500, f"Exception: {repr(e)}") return req.send_response(200) req.send_header("Content-type", "text/html") req.end_headers() req.wfile.write(resp) def render_template(self, template: str, **kwargs: object) -> bytes: return self._jinja_env.get_template(template).render(**kwargs).encode() def main( port: int, dump_dir: str | None, dump_interval: float, handlers: list[DebugHandler], enabled_dumps: set[str], ) -> None: logger.setLevel(logging.INFO) server = FrontendServer(port=port, handlers=handlers) logger.info("Frontend server started on port %d", server._server.server_port) dumper: PeriodicDumper | None = None if dump_dir is not None: dumper = PeriodicDumper( [ handler for handler in handlers if handler.dump_filename() in enabled_dumps ], dump_dir, dump_interval, ) dumper.start() logger.info( "Periodic dumper started, writing to %s every %.0fs", dump_dir, dump_interval, ) try: server.join() finally: if dumper is not None: dumper.stop()