| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485 |
- # Copyright 2026 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import annotations
- import inspect
- import json
- import os
- import threading
- import time
- from collections import defaultdict
- from functools import wraps
- from pathlib import Path
- from typing import Any
- import httpx
- from .generic import strtobool
- class _NetworkRequestTrace:
- def __init__(self, request: httpx.Request):
- self.request = request
- self.started_at = time.perf_counter()
- self.phase_started_at = {}
- self.phases_ms = defaultdict(float)
- def trace(self, name: str, info: dict[str, Any]) -> None:
- parts = name.rsplit(".", 2)
- if len(parts) != 3:
- return
- _, phase, state = parts
- now = time.perf_counter()
- if state == "started":
- self.phase_started_at[phase] = now
- elif state in {"complete", "failed"}:
- phase_started_at = self.phase_started_at.pop(phase, None)
- if phase_started_at is not None:
- self.phases_ms[phase] += (now - phase_started_at) * 1000
- def build_record(
- self,
- *,
- response: httpx.Response | None = None,
- error: BaseException | None = None,
- stream: bool = False,
- ) -> dict[str, Any]:
- total_ms = (time.perf_counter() - self.started_at) * 1000
- url = self.request.url
- host = url.host or ""
- port = url.port
- default_port = {"http": 80, "https": 443}.get(url.scheme)
- host_display = host if port in (None, default_port) else f"{host}:{port}"
- http_version = None
- status_code = None
- bytes_downloaded = None
- response_complete = False
- if response is not None:
- status_code = response.status_code
- response_complete = response.is_closed
- raw_http_version = response.extensions.get("http_version")
- if isinstance(raw_http_version, bytes):
- http_version = raw_http_version.decode("ascii", errors="replace")
- elif raw_http_version is not None:
- http_version = str(raw_http_version)
- if response_complete:
- try:
- bytes_downloaded = len(response.content)
- except httpx.ResponseNotRead:
- pass
- return {
- "method": self.request.method,
- "scheme": url.scheme,
- "host": host,
- "host_display": host_display,
- "port": port,
- "path": url.path,
- "has_query": bool(url.query),
- "url": f"{url.scheme}://{host_display}{url.path}{'?...' if url.query else ''}",
- "request_id": self.request.headers.get("x-amzn-trace-id") or self.request.headers.get("x-request-id"),
- "status_code": status_code,
- "http_version": http_version,
- "bytes_downloaded": bytes_downloaded,
- "total_ms": total_ms,
- "stream": stream,
- "response_complete": response_complete,
- "phases_ms": dict(sorted(self.phases_ms.items())),
- "error": None if error is None else f"{type(error).__name__}: {error}",
- }
- class _NetworkDebugProfiler:
- def __init__(self):
- self._records = []
- self._lock = threading.Lock()
- self._enabled = False
- self._output_path = None
- self._original_client_send = None
- self._original_async_client_send = None
- self._shared_dir = None
- @property
- def enabled(self) -> bool:
- return self._enabled
- def clear(self) -> None:
- with self._lock:
- self._records = []
- def enable(self, output_path: str | os.PathLike | None = None) -> None:
- if self._enabled:
- self._output_path = None if output_path is None else os.fspath(output_path)
- self.clear()
- return
- self._output_path = None if output_path is None else os.fspath(output_path)
- self.clear()
- profiler = self
- self._original_client_send = httpx.Client.send
- self._original_async_client_send = httpx.AsyncClient.send
- @wraps(self._original_client_send)
- def patched_client_send(client, request, *args, **kwargs):
- return profiler._send_with_trace(profiler._original_client_send, client, request, *args, **kwargs)
- @wraps(self._original_async_client_send)
- async def patched_async_client_send(client, request, *args, **kwargs):
- return await profiler._async_send_with_trace(
- profiler._original_async_client_send, client, request, *args, **kwargs
- )
- httpx.Client.send = patched_client_send
- httpx.AsyncClient.send = patched_async_client_send
- self._enabled = True
- def setup_shared_dir(self) -> str | None:
- """Create a shared temp directory for xdist workers to dump records into."""
- if self._shared_dir is None:
- import tempfile
- self._shared_dir = tempfile.mkdtemp(prefix="network_debug_")
- return self._shared_dir
- def set_shared_dir(self, shared_dir: str) -> None:
- """Set the shared directory (called in xdist workers)."""
- self._shared_dir = shared_dir
- def dump_worker_records(self, worker_id: str | None = None) -> None:
- """Write this process's records to a file in the shared directory (called in workers)."""
- if not self._shared_dir or not self._records:
- return
- worker_id = worker_id or f"pid{os.getpid()}"
- dump_path = os.path.join(self._shared_dir, f"records_{worker_id}.json")
- with self._lock:
- records = [{**record, "phases_ms": dict(record["phases_ms"])} for record in self._records]
- Path(dump_path).write_text(json.dumps(records), encoding="utf-8")
- def load_worker_records(self) -> None:
- """Load all worker record files from the shared directory (called in controller)."""
- if not self._shared_dir or not os.path.isdir(self._shared_dir):
- return
- import glob as glob_module
- for record_file in glob_module.glob(os.path.join(self._shared_dir, "records_*.json")):
- try:
- records = json.loads(Path(record_file).read_text(encoding="utf-8"))
- with self._lock:
- for record in records:
- record["phases_ms"] = defaultdict(float, record.get("phases_ms", {}))
- self._records.append(record)
- except (OSError, json.JSONDecodeError):
- pass
- def cleanup_shared_dir(self) -> None:
- """Remove the shared temp directory."""
- if self._shared_dir and os.path.isdir(self._shared_dir):
- import shutil
- shutil.rmtree(self._shared_dir, ignore_errors=True)
- self._shared_dir = None
- def disable(self) -> None:
- if not self._enabled:
- return
- httpx.Client.send = self._original_client_send
- httpx.AsyncClient.send = self._original_async_client_send
- self._enabled = False
- self._original_client_send = None
- self._original_async_client_send = None
- self._output_path = None
- self.clear()
- def _append_record(self, record: dict[str, Any]) -> None:
- with self._lock:
- self._records.append(record)
- def _wrap_trace_callback(self, request: httpx.Request, trace: _NetworkRequestTrace):
- existing_trace = request.extensions.get("trace")
- def wrapped_trace(name: str, info: dict[str, Any]) -> Any:
- trace.trace(name, info)
- if existing_trace is not None:
- return existing_trace(name, info)
- return None
- return wrapped_trace
- async def _awrap_trace_callback(self, request: httpx.Request, trace: _NetworkRequestTrace):
- existing_trace = request.extensions.get("trace")
- async def wrapped_trace(name: str, info: dict[str, Any]) -> Any:
- trace.trace(name, info)
- if existing_trace is not None:
- result = existing_trace(name, info)
- if inspect.isawaitable(result):
- return await result
- return result
- return None
- return wrapped_trace
- def _send_with_trace(self, original_send, client, request: httpx.Request, *args, **kwargs):
- trace = _NetworkRequestTrace(request)
- request.extensions = dict(request.extensions)
- request.extensions["trace"] = self._wrap_trace_callback(request, trace)
- try:
- response = original_send(client, request, *args, **kwargs)
- except Exception as error:
- self._append_record(trace.build_record(error=error, stream=kwargs.get("stream", False)))
- raise
- self._append_record(trace.build_record(response=response, stream=kwargs.get("stream", False)))
- return response
- async def _async_send_with_trace(self, original_send, client, request: httpx.Request, *args, **kwargs):
- trace = _NetworkRequestTrace(request)
- request.extensions = dict(request.extensions)
- request.extensions["trace"] = await self._awrap_trace_callback(request, trace)
- try:
- response = await original_send(client, request, *args, **kwargs)
- except Exception as error:
- self._append_record(trace.build_record(error=error, stream=kwargs.get("stream", False)))
- raise
- self._append_record(trace.build_record(response=response, stream=kwargs.get("stream", False)))
- return response
- def build_report(self) -> dict[str, Any]:
- with self._lock:
- records = [
- {
- **record,
- "phases_ms": dict(record["phases_ms"]),
- }
- for record in self._records
- ]
- phase_totals_ms = defaultdict(float)
- route_totals = {}
- for record in records:
- for phase, duration_ms in record["phases_ms"].items():
- phase_totals_ms[phase] += duration_ms
- route_key = (record["method"], record["host_display"], record["path"])
- route_total = route_totals.setdefault(
- route_key,
- {
- "method": record["method"],
- "host_display": record["host_display"],
- "path": record["path"],
- "count": 0,
- "failures": 0,
- "total_ms": 0.0,
- "phase_totals_ms": defaultdict(float),
- },
- )
- route_total["count"] += 1
- route_total["total_ms"] += record["total_ms"]
- route_total["failures"] += int(record["error"] is not None)
- for phase, duration_ms in record["phases_ms"].items():
- route_total["phase_totals_ms"][phase] += duration_ms
- routes = []
- for route_total in route_totals.values():
- route_total["avg_ms"] = route_total["total_ms"] / route_total["count"]
- route_total["phase_totals_ms"] = dict(sorted(route_total["phase_totals_ms"].items()))
- routes.append(route_total)
- routes.sort(key=lambda route: route["total_ms"], reverse=True)
- total_time_ms = sum(record["total_ms"] for record in records)
- return {
- "enabled": self._enabled,
- "output_path": self._output_path,
- "total_requests": len(records),
- "failed_requests": sum(int(record["error"] is not None) for record in records),
- "total_time_ms": total_time_ms,
- "phase_totals_ms": dict(sorted(phase_totals_ms.items())),
- "requests": records,
- "routes": routes,
- }
- def maybe_write_report(self) -> str | None:
- if self._output_path is None:
- return None
- report_path = Path(self._output_path)
- report_path.parent.mkdir(parents=True, exist_ok=True)
- report_path.write_text(json.dumps(self.build_report(), indent=2, sort_keys=True), encoding="utf-8")
- return str(report_path)
- _NETWORK_DEBUG_PROFILER = _NetworkDebugProfiler()
- _DEFAULT_REPORT_PATH = "network_debug_report.json"
- def _parse_network_debug_env() -> tuple[bool, str]:
- enabled_raw = os.environ.get("NETWORK_DEBUG_REPORT", "").strip()
- try:
- enabled = bool(strtobool(enabled_raw)) if enabled_raw else False
- except ValueError:
- enabled = False
- output_path = os.environ.get("NETWORK_DEBUG_REPORT_PATH", "").strip() or _DEFAULT_REPORT_PATH
- return enabled, output_path
- def _enable_network_debug_report(output_path: str | os.PathLike | None = None) -> None:
- _NETWORK_DEBUG_PROFILER.enable(output_path=output_path)
- def _disable_network_debug_report() -> None:
- _NETWORK_DEBUG_PROFILER.disable()
- def _clear_network_debug_report() -> None:
- _NETWORK_DEBUG_PROFILER.clear()
- def _get_network_debug_report() -> dict[str, Any]:
- return _NETWORK_DEBUG_PROFILER.build_report()
- def _enable_network_debug_report_from_env() -> bool:
- enabled, output_path = _parse_network_debug_env()
- if not enabled:
- return False
- _enable_network_debug_report(output_path=output_path)
- return True
- def _format_network_debug_report(max_requests: int = 20, max_routes: int = 10) -> str:
- report = _get_network_debug_report()
- if report["total_requests"] == 0:
- return "Network debug report: no httpx requests captured."
- lines = [
- "Network debug report",
- f"Requests captured: {report['total_requests']}",
- f"Failed requests: {report['failed_requests']}",
- f"Cumulative request time: {report['total_time_ms']:.1f} ms",
- ]
- if report["phase_totals_ms"]:
- phase_summary = ", ".join(
- f"{phase}={duration_ms:.1f} ms"
- for phase, duration_ms in sorted(report["phase_totals_ms"].items(), key=lambda item: item[1], reverse=True)
- )
- lines.append(f"Phase totals: {phase_summary}")
- lines.append("")
- lines.append("Slowest requests:")
- for idx, record in enumerate(
- sorted(report["requests"], key=lambda request: request["total_ms"], reverse=True)[:max_requests],
- start=1,
- ):
- status = record["error"] or f"status={record['status_code']}"
- phase_bits = []
- for phase in ("connect_tcp", "start_tls", "receive_response_headers", "receive_response_body"):
- duration_ms = record["phases_ms"].get(phase)
- if duration_ms is not None:
- phase_bits.append(f"{phase}={duration_ms:.1f} ms")
- phase_suffix = f" ({', '.join(phase_bits)})" if phase_bits else ""
- incomplete_suffix = " incomplete" if record["stream"] and not record["response_complete"] else ""
- lines.append(
- f"{idx:>2}. {record['method']} {record['url']} {record['total_ms']:.1f} ms {status}{incomplete_suffix}{phase_suffix}"
- )
- lines.append("")
- lines.append("Slowest routes:")
- for idx, route in enumerate(report["routes"][:max_routes], start=1):
- lines.append(
- f"{idx:>2}. {route['method']} {route['host_display']}{route['path']} count={route['count']} "
- f"total={route['total_ms']:.1f} ms avg={route['avg_ms']:.1f} ms failures={route['failures']}"
- )
- return "\n".join(lines)
- class NetworkDebugPlugin:
- """Pytest plugin that handles all network debug orchestration including xdist coordination."""
- def pytest_configure(self, config):
- _enable_network_debug_report_from_env()
- if not _NETWORK_DEBUG_PROFILER.enabled:
- return
- # xdist controller: create shared dir for workers to dump network records
- if not hasattr(config, "workerinput"):
- shared_dir = _NETWORK_DEBUG_PROFILER.setup_shared_dir()
- if shared_dir:
- config._network_debug_shared_dir = shared_dir
- else:
- # xdist worker: receive shared dir from controller
- shared_dir = config.workerinput.get("network_debug_shared_dir")
- if shared_dir:
- _NETWORK_DEBUG_PROFILER.set_shared_dir(shared_dir)
- def pytest_configure_node(self, node):
- """xdist hook: called on the controller to configure each worker node."""
- shared_dir = getattr(node.config, "_network_debug_shared_dir", None)
- if shared_dir:
- node.workerinput["network_debug_shared_dir"] = shared_dir
- def pytest_sessionfinish(self, session, exitstatus):
- # xdist worker: dump network debug records for the controller to aggregate
- if hasattr(session.config, "workerinput"):
- worker_id = session.config.workerinput.get("workerid", f"pid{os.getpid()}")
- _NETWORK_DEBUG_PROFILER.dump_worker_records(worker_id=worker_id)
- def pytest_terminal_summary(self, terminalreporter):
- if not _NETWORK_DEBUG_PROFILER.enabled:
- return
- # Skip report generation in xdist worker processes; only the controller should aggregate and report.
- if hasattr(terminalreporter.config, "workerinput"):
- return
- # Aggregate worker records if running under xdist.
- _NETWORK_DEBUG_PROFILER.load_worker_records()
- report_path = None
- try:
- report_path = _NETWORK_DEBUG_PROFILER.maybe_write_report()
- except OSError as error:
- report_path = f"Failed to write JSON report: {error}"
- terminalreporter.section("Network debug", sep="=")
- for line in _format_network_debug_report().splitlines():
- terminalreporter.write_line(line)
- if report_path is not None:
- terminalreporter.write_line(f"JSON report: {report_path}")
- _NETWORK_DEBUG_PROFILER.cleanup_shared_dir()
- def register_network_debug_plugin(config) -> None:
- """Register the network debug pytest plugin. Single entry point for conftest.py."""
- config.pluginmanager.register(NetworkDebugPlugin(), "network_debug")
- __all__ = [
- "register_network_debug_plugin",
- ]
|