| 1234567891011121314151617181920212223 |
- import pathlib
- import tempfile
- import time
- from torch._C._distributed_c10d import _register_handler, _Request, _Response
- from torch.profiler import _ExperimentalConfig, profile
def _torch_profile(req: _Request, resp: _Response) -> None:
    """Profile the process for a requested duration and reply with a Chrome trace.

    Reads the ``duration`` request parameter, runs ``torch.profiler.profile``
    (with ``record_shapes=True`` and ``profile_all_threads=True``) while
    sleeping for that many seconds, then sends the exported Chrome trace back
    as ``application/json`` with status 200.

    Args:
        req: Incoming request; its ``duration`` parameter is parsed with
            ``float`` and passed to ``time.sleep``.
        resp: Response object that receives the trace bytes and the status.

    Raises:
        ValueError: If ``duration`` is not a valid float string, or is a value
            ``time.sleep`` rejects (e.g. negative).
    """
    experimental_config = _ExperimentalConfig(
        profile_all_threads=True,
    )
    duration = float(req.get_param("duration"))

    with profile(record_shapes=True, experimental_config=experimental_config) as prof:
        # This thread only waits; the profiler is what does the recording.
        # NOTE(review): presumably profile_all_threads is what lets activity
        # on other threads show up in the trace -- confirm.
        time.sleep(duration)

    # Export into a private temporary directory rather than reopening a
    # still-open NamedTemporaryFile by name: reopening by name is not
    # portable (it fails on Windows while the original handle is open).
    with tempfile.TemporaryDirectory(prefix="torch_debug") as tmp_dir:
        trace_path = pathlib.Path(tmp_dir) / "trace.json"
        prof.export_chrome_trace(str(trace_path))
        trace_bytes = trace_path.read_bytes()

    resp.set_content(trace_bytes, "application/json")
    resp.set_status(200)
# Register the profiling handler under the name "torch_profile". This runs at
# import time, so importing this module is what performs the registration.
# NOTE(review): presumably the name is how the c10d debug server routes
# requests to this handler -- confirm against _register_handler.
- _register_handler("torch_profile", _torch_profile)
|