# debug.py (811 B, 33 lines)
  1. from collections.abc import Sequence
  2. import torch.fx as fx
  3. __all__ = ["set_trace"]
  4. def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
  5. """
  6. Sets a breakpoint in `gm`'s generated python code. It drops into pdb when
  7. `gm` gets run.
  8. Args:
  9. gm: graph module to insert breakpoint. It is then recompiled for it to
  10. take effect.
  11. Returns:
  12. the `gm` with breakpoint inserted.
  13. """
  14. def insert_pdb(body: Sequence[str]) -> list[str]:
  15. return ["import pdb; pdb.set_trace()\n", *body]
  16. with gm.graph.on_generate_code(
  17. make_transformer=lambda cur_transform: (
  18. # new code transformer to register
  19. lambda body: (insert_pdb(cur_transform(body) if cur_transform else body))
  20. )
  21. ):
  22. gm.recompile()
  23. return gm