| 123456789101112131415161718192021222324252627282930313233 |
- from collections.abc import Sequence
- import torch.fx as fx
- __all__ = ["set_trace"]
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """Make `gm` stop in pdb as soon as its generated ``forward`` starts running.

    While ``gm.recompile()`` runs, a code transformer is registered on
    ``gm.graph`` (through ``on_generate_code``, used as a context manager).
    If a transformer was already registered, it runs first on the generated
    body lines. The transformer then puts an ``import pdb; pdb.set_trace()``
    line in front of the result.

    Args:
        gm: graph module to instrument. It is recompiled in place so the
            breakpoint line becomes part of its generated code.

    Returns:
        The same ``gm`` object, now containing the breakpoint.
    """
    pdb_line = "import pdb; pdb.set_trace()\n"

    def build_transformer(previous):
        # `previous` is whatever transformer the graph already had (may be
        # falsy when none is registered); keep it in the chain.
        def prepend_breakpoint(body: Sequence[str]) -> list[str]:
            if previous:
                body = previous(body)
            return [pdb_line, *body]

        return prepend_breakpoint

    with gm.graph.on_generate_code(make_transformer=build_transformer):
        gm.recompile()
    return gm
|