common.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # mypy: allow-untyped-defs
  2. from torch.fx._compatibility import compatibility
  3. from torch.fx.graph import Graph
  4. from torch.fx.graph_module import GraphModule
  5. from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
  6. from torch.nn import Module
  7. __all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"]
  8. @compatibility(is_backward_compatible=False)
  9. class HolderModule(Module):
  10. """
  11. HolderModule is used to copy all the attributes from original module to submodules
  12. that uses the attributes
  13. """
  14. def __init__(self, d):
  15. super().__init__()
  16. for k, v in d.items():
  17. self.add_module(k, v)
  18. @compatibility(is_backward_compatible=False)
  19. def lift_subgraph_as_module(
  20. gm: GraphModule,
  21. subgraph: Graph,
  22. comp_name: str = "",
  23. class_name: str = "GraphModule",
  24. ) -> tuple[GraphModule, dict[str, str]]:
  25. """
  26. Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module.
  27. Args:
  28. gm (GraphModule): parent graph module
  29. subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph
  30. comp_name (str): name for the new component
  31. class_name (str): name for the submodule
  32. """
  33. # Loop through all module calls (call_module) and param fetches (get_attr)
  34. # in this component, creating HolderModules as necessary to match the path.
  35. # e.g. if in the original module there's a get_attr node fetches "conv.weight".
  36. # We create a HolderModule as root -> add a HolderModule named "conv" ->
  37. # make "weight" a attribute of "conv" HolderModule and point to conv.weight in
  38. # the original module.
  39. submodule = HolderModule({})
  40. orig_to_split_fqn_mapping: dict[str, str] = {}
  41. for n in subgraph.nodes:
  42. if n.op not in ("call_module", "get_attr"):
  43. continue
  44. target = n.target
  45. if not isinstance(target, str):
  46. raise AssertionError(f"Expected str target, got {type(target)}")
  47. target_name_parts = target.split(".")
  48. curr = submodule
  49. orig_gm = gm
  50. for name in target_name_parts[:-1]:
  51. if not hasattr(curr, name):
  52. curr.add_module(name, HolderModule({}))
  53. curr = getattr(curr, name)
  54. orig_gm = getattr(orig_gm, name)
  55. leaf_node_name = target_name_parts[-1]
  56. leaf_node = getattr(orig_gm, leaf_node_name)
  57. orig_to_split_fqn_mapping[target] = f"{comp_name}.{target}"
  58. # Relies on custom __setattr__ magic.
  59. setattr(curr, leaf_node_name, leaf_node)
  60. return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping
  61. @compatibility(is_backward_compatible=False)
  62. def compare_graphs(left: Graph, right: Graph) -> bool:
  63. """
  64. Return True if two graphs are identical, i.e they
  65. - have the same number of outputs in the same order
  66. - have the same number of inputs in the same order
  67. - have the same set of nodes, and identical connectivity
  68. """
  69. matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True)
  70. matches = matcher.match(right)
  71. return len(matches) > 0