_handlers.py 776 B

1234567891011121314151617181920212223
  1. import pathlib
  2. import tempfile
  3. import time
  4. from torch._C._distributed_c10d import _register_handler, _Request, _Response
  5. from torch.profiler import _ExperimentalConfig, profile
  6. def _torch_profile(req: _Request, resp: _Response) -> None:
  7. experimental_config = _ExperimentalConfig(
  8. profile_all_threads=True,
  9. )
  10. duration = float(req.get_param("duration"))
  11. with profile(record_shapes=True, experimental_config=experimental_config) as prof:
  12. time.sleep(duration)
  13. with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f:
  14. prof.export_chrome_trace(f.name)
  15. resp.set_content(pathlib.Path(f.name).read_bytes(), "application/json")
  16. resp.set_status(200)
# Register _torch_profile under the name "torch_profile" via
# torch._C._distributed_c10d._register_handler. This runs at module import
# time, so importing this module is (presumably) what makes the handler
# reachable -- confirm against the handler server's lookup logic.
_register_handler("torch_profile", _torch_profile)