_unflatten.py 952 B

123456789101112131415161718192021222324252627282930
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. from collections import defaultdict
  3. import torch
  4. from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry
  5. def _outline_submodules(orig_graph: torch.fx.Graph) -> torch.fx.GraphModule:
  6. # Create an empty GraphModule to hold the outlined modules
  7. new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
  8. seen_nodes: dict[str, torch.fx.Node] = {}
  9. seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list)
  10. seen_attrs: dict[str, set[str]] = defaultdict(set)
  11. created_modules: dict[str, torch.nn.Module] = {}
  12. _ModuleFrame(
  13. orig_graph,
  14. tuple(orig_graph.nodes),
  15. seen_nodes,
  16. seen_modules,
  17. seen_attrs,
  18. created_modules,
  19. None,
  20. [("", None, 0)],
  21. "",
  22. {},
  23. module=new_module,
  24. ).run_outer()
  25. new_module.graph.lint()
  26. new_module.recompile()
  27. return new_module