network_logging.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. # Copyright 2026 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import annotations
  15. import inspect
  16. import json
  17. import os
  18. import threading
  19. import time
  20. from collections import defaultdict
  21. from functools import wraps
  22. from pathlib import Path
  23. from typing import Any
  24. import httpx
  25. from .generic import strtobool
  26. class _NetworkRequestTrace:
  27. def __init__(self, request: httpx.Request):
  28. self.request = request
  29. self.started_at = time.perf_counter()
  30. self.phase_started_at = {}
  31. self.phases_ms = defaultdict(float)
  32. def trace(self, name: str, info: dict[str, Any]) -> None:
  33. parts = name.rsplit(".", 2)
  34. if len(parts) != 3:
  35. return
  36. _, phase, state = parts
  37. now = time.perf_counter()
  38. if state == "started":
  39. self.phase_started_at[phase] = now
  40. elif state in {"complete", "failed"}:
  41. phase_started_at = self.phase_started_at.pop(phase, None)
  42. if phase_started_at is not None:
  43. self.phases_ms[phase] += (now - phase_started_at) * 1000
  44. def build_record(
  45. self,
  46. *,
  47. response: httpx.Response | None = None,
  48. error: BaseException | None = None,
  49. stream: bool = False,
  50. ) -> dict[str, Any]:
  51. total_ms = (time.perf_counter() - self.started_at) * 1000
  52. url = self.request.url
  53. host = url.host or ""
  54. port = url.port
  55. default_port = {"http": 80, "https": 443}.get(url.scheme)
  56. host_display = host if port in (None, default_port) else f"{host}:{port}"
  57. http_version = None
  58. status_code = None
  59. bytes_downloaded = None
  60. response_complete = False
  61. if response is not None:
  62. status_code = response.status_code
  63. response_complete = response.is_closed
  64. raw_http_version = response.extensions.get("http_version")
  65. if isinstance(raw_http_version, bytes):
  66. http_version = raw_http_version.decode("ascii", errors="replace")
  67. elif raw_http_version is not None:
  68. http_version = str(raw_http_version)
  69. if response_complete:
  70. try:
  71. bytes_downloaded = len(response.content)
  72. except httpx.ResponseNotRead:
  73. pass
  74. return {
  75. "method": self.request.method,
  76. "scheme": url.scheme,
  77. "host": host,
  78. "host_display": host_display,
  79. "port": port,
  80. "path": url.path,
  81. "has_query": bool(url.query),
  82. "url": f"{url.scheme}://{host_display}{url.path}{'?...' if url.query else ''}",
  83. "request_id": self.request.headers.get("x-amzn-trace-id") or self.request.headers.get("x-request-id"),
  84. "status_code": status_code,
  85. "http_version": http_version,
  86. "bytes_downloaded": bytes_downloaded,
  87. "total_ms": total_ms,
  88. "stream": stream,
  89. "response_complete": response_complete,
  90. "phases_ms": dict(sorted(self.phases_ms.items())),
  91. "error": None if error is None else f"{type(error).__name__}: {error}",
  92. }
  93. class _NetworkDebugProfiler:
  94. def __init__(self):
  95. self._records = []
  96. self._lock = threading.Lock()
  97. self._enabled = False
  98. self._output_path = None
  99. self._original_client_send = None
  100. self._original_async_client_send = None
  101. self._shared_dir = None
  102. @property
  103. def enabled(self) -> bool:
  104. return self._enabled
  105. def clear(self) -> None:
  106. with self._lock:
  107. self._records = []
  108. def enable(self, output_path: str | os.PathLike | None = None) -> None:
  109. if self._enabled:
  110. self._output_path = None if output_path is None else os.fspath(output_path)
  111. self.clear()
  112. return
  113. self._output_path = None if output_path is None else os.fspath(output_path)
  114. self.clear()
  115. profiler = self
  116. self._original_client_send = httpx.Client.send
  117. self._original_async_client_send = httpx.AsyncClient.send
  118. @wraps(self._original_client_send)
  119. def patched_client_send(client, request, *args, **kwargs):
  120. return profiler._send_with_trace(profiler._original_client_send, client, request, *args, **kwargs)
  121. @wraps(self._original_async_client_send)
  122. async def patched_async_client_send(client, request, *args, **kwargs):
  123. return await profiler._async_send_with_trace(
  124. profiler._original_async_client_send, client, request, *args, **kwargs
  125. )
  126. httpx.Client.send = patched_client_send
  127. httpx.AsyncClient.send = patched_async_client_send
  128. self._enabled = True
  129. def setup_shared_dir(self) -> str | None:
  130. """Create a shared temp directory for xdist workers to dump records into."""
  131. if self._shared_dir is None:
  132. import tempfile
  133. self._shared_dir = tempfile.mkdtemp(prefix="network_debug_")
  134. return self._shared_dir
  135. def set_shared_dir(self, shared_dir: str) -> None:
  136. """Set the shared directory (called in xdist workers)."""
  137. self._shared_dir = shared_dir
  138. def dump_worker_records(self, worker_id: str | None = None) -> None:
  139. """Write this process's records to a file in the shared directory (called in workers)."""
  140. if not self._shared_dir or not self._records:
  141. return
  142. worker_id = worker_id or f"pid{os.getpid()}"
  143. dump_path = os.path.join(self._shared_dir, f"records_{worker_id}.json")
  144. with self._lock:
  145. records = [{**record, "phases_ms": dict(record["phases_ms"])} for record in self._records]
  146. Path(dump_path).write_text(json.dumps(records), encoding="utf-8")
  147. def load_worker_records(self) -> None:
  148. """Load all worker record files from the shared directory (called in controller)."""
  149. if not self._shared_dir or not os.path.isdir(self._shared_dir):
  150. return
  151. import glob as glob_module
  152. for record_file in glob_module.glob(os.path.join(self._shared_dir, "records_*.json")):
  153. try:
  154. records = json.loads(Path(record_file).read_text(encoding="utf-8"))
  155. with self._lock:
  156. for record in records:
  157. record["phases_ms"] = defaultdict(float, record.get("phases_ms", {}))
  158. self._records.append(record)
  159. except (OSError, json.JSONDecodeError):
  160. pass
  161. def cleanup_shared_dir(self) -> None:
  162. """Remove the shared temp directory."""
  163. if self._shared_dir and os.path.isdir(self._shared_dir):
  164. import shutil
  165. shutil.rmtree(self._shared_dir, ignore_errors=True)
  166. self._shared_dir = None
  167. def disable(self) -> None:
  168. if not self._enabled:
  169. return
  170. httpx.Client.send = self._original_client_send
  171. httpx.AsyncClient.send = self._original_async_client_send
  172. self._enabled = False
  173. self._original_client_send = None
  174. self._original_async_client_send = None
  175. self._output_path = None
  176. self.clear()
  177. def _append_record(self, record: dict[str, Any]) -> None:
  178. with self._lock:
  179. self._records.append(record)
  180. def _wrap_trace_callback(self, request: httpx.Request, trace: _NetworkRequestTrace):
  181. existing_trace = request.extensions.get("trace")
  182. def wrapped_trace(name: str, info: dict[str, Any]) -> Any:
  183. trace.trace(name, info)
  184. if existing_trace is not None:
  185. return existing_trace(name, info)
  186. return None
  187. return wrapped_trace
  188. async def _awrap_trace_callback(self, request: httpx.Request, trace: _NetworkRequestTrace):
  189. existing_trace = request.extensions.get("trace")
  190. async def wrapped_trace(name: str, info: dict[str, Any]) -> Any:
  191. trace.trace(name, info)
  192. if existing_trace is not None:
  193. result = existing_trace(name, info)
  194. if inspect.isawaitable(result):
  195. return await result
  196. return result
  197. return None
  198. return wrapped_trace
  199. def _send_with_trace(self, original_send, client, request: httpx.Request, *args, **kwargs):
  200. trace = _NetworkRequestTrace(request)
  201. request.extensions = dict(request.extensions)
  202. request.extensions["trace"] = self._wrap_trace_callback(request, trace)
  203. try:
  204. response = original_send(client, request, *args, **kwargs)
  205. except Exception as error:
  206. self._append_record(trace.build_record(error=error, stream=kwargs.get("stream", False)))
  207. raise
  208. self._append_record(trace.build_record(response=response, stream=kwargs.get("stream", False)))
  209. return response
  210. async def _async_send_with_trace(self, original_send, client, request: httpx.Request, *args, **kwargs):
  211. trace = _NetworkRequestTrace(request)
  212. request.extensions = dict(request.extensions)
  213. request.extensions["trace"] = await self._awrap_trace_callback(request, trace)
  214. try:
  215. response = await original_send(client, request, *args, **kwargs)
  216. except Exception as error:
  217. self._append_record(trace.build_record(error=error, stream=kwargs.get("stream", False)))
  218. raise
  219. self._append_record(trace.build_record(response=response, stream=kwargs.get("stream", False)))
  220. return response
  221. def build_report(self) -> dict[str, Any]:
  222. with self._lock:
  223. records = [
  224. {
  225. **record,
  226. "phases_ms": dict(record["phases_ms"]),
  227. }
  228. for record in self._records
  229. ]
  230. phase_totals_ms = defaultdict(float)
  231. route_totals = {}
  232. for record in records:
  233. for phase, duration_ms in record["phases_ms"].items():
  234. phase_totals_ms[phase] += duration_ms
  235. route_key = (record["method"], record["host_display"], record["path"])
  236. route_total = route_totals.setdefault(
  237. route_key,
  238. {
  239. "method": record["method"],
  240. "host_display": record["host_display"],
  241. "path": record["path"],
  242. "count": 0,
  243. "failures": 0,
  244. "total_ms": 0.0,
  245. "phase_totals_ms": defaultdict(float),
  246. },
  247. )
  248. route_total["count"] += 1
  249. route_total["total_ms"] += record["total_ms"]
  250. route_total["failures"] += int(record["error"] is not None)
  251. for phase, duration_ms in record["phases_ms"].items():
  252. route_total["phase_totals_ms"][phase] += duration_ms
  253. routes = []
  254. for route_total in route_totals.values():
  255. route_total["avg_ms"] = route_total["total_ms"] / route_total["count"]
  256. route_total["phase_totals_ms"] = dict(sorted(route_total["phase_totals_ms"].items()))
  257. routes.append(route_total)
  258. routes.sort(key=lambda route: route["total_ms"], reverse=True)
  259. total_time_ms = sum(record["total_ms"] for record in records)
  260. return {
  261. "enabled": self._enabled,
  262. "output_path": self._output_path,
  263. "total_requests": len(records),
  264. "failed_requests": sum(int(record["error"] is not None) for record in records),
  265. "total_time_ms": total_time_ms,
  266. "phase_totals_ms": dict(sorted(phase_totals_ms.items())),
  267. "requests": records,
  268. "routes": routes,
  269. }
  270. def maybe_write_report(self) -> str | None:
  271. if self._output_path is None:
  272. return None
  273. report_path = Path(self._output_path)
  274. report_path.parent.mkdir(parents=True, exist_ok=True)
  275. report_path.write_text(json.dumps(self.build_report(), indent=2, sort_keys=True), encoding="utf-8")
  276. return str(report_path)
  277. _NETWORK_DEBUG_PROFILER = _NetworkDebugProfiler()
  278. _DEFAULT_REPORT_PATH = "network_debug_report.json"
  279. def _parse_network_debug_env() -> tuple[bool, str]:
  280. enabled_raw = os.environ.get("NETWORK_DEBUG_REPORT", "").strip()
  281. try:
  282. enabled = bool(strtobool(enabled_raw)) if enabled_raw else False
  283. except ValueError:
  284. enabled = False
  285. output_path = os.environ.get("NETWORK_DEBUG_REPORT_PATH", "").strip() or _DEFAULT_REPORT_PATH
  286. return enabled, output_path
  287. def _enable_network_debug_report(output_path: str | os.PathLike | None = None) -> None:
  288. _NETWORK_DEBUG_PROFILER.enable(output_path=output_path)
  289. def _disable_network_debug_report() -> None:
  290. _NETWORK_DEBUG_PROFILER.disable()
  291. def _clear_network_debug_report() -> None:
  292. _NETWORK_DEBUG_PROFILER.clear()
  293. def _get_network_debug_report() -> dict[str, Any]:
  294. return _NETWORK_DEBUG_PROFILER.build_report()
  295. def _enable_network_debug_report_from_env() -> bool:
  296. enabled, output_path = _parse_network_debug_env()
  297. if not enabled:
  298. return False
  299. _enable_network_debug_report(output_path=output_path)
  300. return True
  301. def _format_network_debug_report(max_requests: int = 20, max_routes: int = 10) -> str:
  302. report = _get_network_debug_report()
  303. if report["total_requests"] == 0:
  304. return "Network debug report: no httpx requests captured."
  305. lines = [
  306. "Network debug report",
  307. f"Requests captured: {report['total_requests']}",
  308. f"Failed requests: {report['failed_requests']}",
  309. f"Cumulative request time: {report['total_time_ms']:.1f} ms",
  310. ]
  311. if report["phase_totals_ms"]:
  312. phase_summary = ", ".join(
  313. f"{phase}={duration_ms:.1f} ms"
  314. for phase, duration_ms in sorted(report["phase_totals_ms"].items(), key=lambda item: item[1], reverse=True)
  315. )
  316. lines.append(f"Phase totals: {phase_summary}")
  317. lines.append("")
  318. lines.append("Slowest requests:")
  319. for idx, record in enumerate(
  320. sorted(report["requests"], key=lambda request: request["total_ms"], reverse=True)[:max_requests],
  321. start=1,
  322. ):
  323. status = record["error"] or f"status={record['status_code']}"
  324. phase_bits = []
  325. for phase in ("connect_tcp", "start_tls", "receive_response_headers", "receive_response_body"):
  326. duration_ms = record["phases_ms"].get(phase)
  327. if duration_ms is not None:
  328. phase_bits.append(f"{phase}={duration_ms:.1f} ms")
  329. phase_suffix = f" ({', '.join(phase_bits)})" if phase_bits else ""
  330. incomplete_suffix = " incomplete" if record["stream"] and not record["response_complete"] else ""
  331. lines.append(
  332. f"{idx:>2}. {record['method']} {record['url']} {record['total_ms']:.1f} ms {status}{incomplete_suffix}{phase_suffix}"
  333. )
  334. lines.append("")
  335. lines.append("Slowest routes:")
  336. for idx, route in enumerate(report["routes"][:max_routes], start=1):
  337. lines.append(
  338. f"{idx:>2}. {route['method']} {route['host_display']}{route['path']} count={route['count']} "
  339. f"total={route['total_ms']:.1f} ms avg={route['avg_ms']:.1f} ms failures={route['failures']}"
  340. )
  341. return "\n".join(lines)
  342. class NetworkDebugPlugin:
  343. """Pytest plugin that handles all network debug orchestration including xdist coordination."""
  344. def pytest_configure(self, config):
  345. _enable_network_debug_report_from_env()
  346. if not _NETWORK_DEBUG_PROFILER.enabled:
  347. return
  348. # xdist controller: create shared dir for workers to dump network records
  349. if not hasattr(config, "workerinput"):
  350. shared_dir = _NETWORK_DEBUG_PROFILER.setup_shared_dir()
  351. if shared_dir:
  352. config._network_debug_shared_dir = shared_dir
  353. else:
  354. # xdist worker: receive shared dir from controller
  355. shared_dir = config.workerinput.get("network_debug_shared_dir")
  356. if shared_dir:
  357. _NETWORK_DEBUG_PROFILER.set_shared_dir(shared_dir)
  358. def pytest_configure_node(self, node):
  359. """xdist hook: called on the controller to configure each worker node."""
  360. shared_dir = getattr(node.config, "_network_debug_shared_dir", None)
  361. if shared_dir:
  362. node.workerinput["network_debug_shared_dir"] = shared_dir
  363. def pytest_sessionfinish(self, session, exitstatus):
  364. # xdist worker: dump network debug records for the controller to aggregate
  365. if hasattr(session.config, "workerinput"):
  366. worker_id = session.config.workerinput.get("workerid", f"pid{os.getpid()}")
  367. _NETWORK_DEBUG_PROFILER.dump_worker_records(worker_id=worker_id)
  368. def pytest_terminal_summary(self, terminalreporter):
  369. if not _NETWORK_DEBUG_PROFILER.enabled:
  370. return
  371. # Skip report generation in xdist worker processes; only the controller should aggregate and report.
  372. if hasattr(terminalreporter.config, "workerinput"):
  373. return
  374. # Aggregate worker records if running under xdist.
  375. _NETWORK_DEBUG_PROFILER.load_worker_records()
  376. report_path = None
  377. try:
  378. report_path = _NETWORK_DEBUG_PROFILER.maybe_write_report()
  379. except OSError as error:
  380. report_path = f"Failed to write JSON report: {error}"
  381. terminalreporter.section("Network debug", sep="=")
  382. for line in _format_network_debug_report().splitlines():
  383. terminalreporter.write_line(line)
  384. if report_path is not None:
  385. terminalreporter.write_line(f"JSON report: {report_path}")
  386. _NETWORK_DEBUG_PROFILER.cleanup_shared_dir()
  387. def register_network_debug_plugin(config) -> None:
  388. """Register the network debug pytest plugin. Single entry point for conftest.py."""
  389. config.pluginmanager.register(NetworkDebugPlugin(), "network_debug")
  390. __all__ = [
  391. "register_network_debug_plugin",
  392. ]