| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682 |
- # mypy: allow-untyped-defs
- import inspect
- import logging
- from collections import OrderedDict
- from collections.abc import Callable
- from typing import Any, Optional
- import torch
- from torch.fx._compatibility import compatibility
- from torch.fx._utils import lazy_format_graph_code
- from torch.fx.graph_module import GraphModule
- from torch.fx.node import Node
__all__ = [
    "Partition",
    "split_module",
]

# Module logger. ``log`` and ``_LOGGER`` name the same logger object; split_module
# uses ``log`` for per-step debug output and ``_LOGGER`` for the region dumps.
_LOGGER = logging.getLogger(__name__)
log = _LOGGER
@compatibility(is_backward_compatible=True)
class Partition:
    """Bookkeeping record for a single partition built by ``split_module``.

    Tracks which original nodes were assigned to the partition, which values
    cross its boundary, which other partitions it depends on (and which depend
    on it), and the new ``Graph`` being assembled for its submodule.
    """

    def __init__(self, name: str):
        # Partition identifier ("<idx>" or "<affix>_<idx>").
        self.name: str = name
        # Attribute name the submodule is installed under in the split module.
        self.submod_name = f"submod_{name}"
        # Names of the original graph nodes placed in this partition, in order.
        self.node_names: list[str] = []
        # The four fields below are insertion-ordered "sets" (dict keys, None values).
        # Names of nodes consumed from outside this partition.
        self.inputs: dict[str, None] = {}
        # Names of nodes this partition must hand to other partitions / the output.
        self.outputs: dict[str, None] = {}
        # Names of partitions that must run before this one.
        self.dependencies: dict[str, None] = {}
        # Names of partitions that consume values produced here.
        self.dependents: dict[str, None] = {}
        # Graph of the submodule being built.
        self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
        # Maps original-graph nodes to their counterparts inside ``self.graph``.
        self.environment: dict[Node, Node] = {}
        # Attributes (modules / values) the submodule needs, keyed by target name.
        self.targets: dict[str, Any] = {}

    def __repr__(self) -> str:
        labelled = (
            ("name", self.name),
            ("nodes", self.node_names),
            ("inputs", self.inputs),
            ("outputs", self.outputs),
            ("partitions depended on", self.dependencies),
            ("partition dependents", self.dependents),
        )
        return ",\n ".join(f"{label}: {value}" for label, value in labelled)
def _get_attr_from_qualname(mod: torch.nn.Module, qualname: str) -> Any:
    """Resolve a dotted attribute path (e.g. ``"layer.0.weight"``) on ``mod``.

    Each component is looked up with a single ``getattr`` call (EAFP) instead of
    ``hasattr`` followed by ``getattr``, which resolved every component twice.

    Args:
        mod: Object (typically an ``nn.Module``) to start the lookup from.
        qualname: Dot-separated attribute path.

    Returns:
        The object found at the end of the path.

    Raises:
        AttributeError: ``"Node target <qualname> not found!"`` if any
            component of the path does not exist.
    """
    attr_val: Any = mod
    for atom in qualname.split("."):
        try:
            attr_val = getattr(attr_val, atom)
        except AttributeError:
            # Same error type and message as before; hide the internal lookup
            # failure so callers see a single, clean error.
            raise AttributeError(f"Node target {qualname} not found!") from None
    return attr_val
# Creates subgraphs out of main graph
@compatibility(is_backward_compatible=True)
def split_module(
    m: GraphModule,
    root_m: torch.nn.Module,
    split_callback: Callable[[Node], int],
    qualname_map: Optional[dict[str, str]] = None,
    keep_original_order: Optional[bool] = False,
    keep_original_node_name: Optional[bool] = False,
    keep_original_input_name: bool = True,
    *,
    partition_affix: Optional[str] = None,
):
    """
    Creates subgraphs out of main graph

    Args:
        m (GraphModule): Graph module to split
        root_m (torch.nn.Module): root nn module. Not currently used. Included
            because the root nn module is usually transformed via
            torch.fx._symbolic_trace.symbolic_trace (see example below)
        split_callback (Callable[[Node], int]): Callable function
            that maps a given Node instance to a numeric partition identifier.
            split_module will use this function as the policy for which operations
            appear in which partitions in the output Module. It is called exactly
            once for every node that is not a placeholder, get_attr or output
            node; that result is reused for all later bookkeeping, so stateful
            callbacks (such as the counter in the example below) see a
            consistent assignment.
        qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
            mapping from new target names in the module after split to old target
            names in the original module.
        keep_original_order: Optional[bool]: keep the original order of the GraphModule
            or use the Topological order of the new constructed GraphModule
        keep_original_node_name: Optional[bool]: If the partitioned graphs should
            have the same node names as the original graph.
        keep_original_input_name: bool: If the partitioned graphs should
            have the same input names as the original graph.
        partition_affix: Optional[str]: If specified, the submodules' names will contain
            the affix, e.g. "submod_<affix>_<idx>".

    Returns:
        GraphModule: the module after split.

    Raises:
        AssertionError: if the graph contains ``torch._C._set_grad_enabled`` or
            ``torch.amp._enter_autocast`` / ``_exit_autocast`` calls and the
            partition ids are not non-decreasing in graph order, if an autocast
            region is never exited, or if one of the internal sanity checks fails.
        RuntimeError: if the partitions form a dependency cycle.

    Example:

        This is a sample setup:

            import torch

            from torch.fx._symbolic_trace import symbolic_trace
            from torch.fx.graph_module import GraphModule
            from torch.fx.node import Node
            from torch.fx.passes.split_module import split_module


            class MyModule(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.param = torch.nn.Parameter(torch.rand(3, 4))
                    self.linear = torch.nn.Linear(4, 5)

                def forward(self, x, y):
                    z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                    w = self.linear(y).clamp(min=0.0, max=1.0)
                    return z + w


            # symbolically trace model
            my_module = MyModule()
            my_module_traced = symbolic_trace(my_module)

            # random mod partitioning
            partition_counter = 0
            NPARTITIONS = 3


            def mod_partition(node: Node):
                global partition_counter
                partition = partition_counter % NPARTITIONS
                partition_counter = (partition_counter + 1) % NPARTITIONS
                return partition


            # split module in module with submodules
            module_with_submodules = split_module(
                my_module_traced, my_module, mod_partition
            )

        Output looks like this. Original graph is broken into partitions

            > print(module_with_submodules)
            GraphModule(
                (submod_0): GraphModule(
                    (linear): Linear(in_features=4, out_features=5, bias=True)
                )
                (submod_1): GraphModule(
                    (linear): Linear(in_features=4, out_features=5, bias=True)
                )
                (submod_2): GraphModule()
            )

            def forward(self, x, y):
                param = self.param
                submod_0 = self.submod_0(x, param, y);  x = param = y = None
                getitem = submod_0[0]
                getitem_1 = submod_0[1];  submod_0 = None
                submod_1 = self.submod_1(getitem, getitem_1);  getitem = getitem_1 = None
                getitem_2 = submod_1[0]
                getitem_3 = submod_1[1];  submod_1 = None
                submod_2 = self.submod_2(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
                return submod_2

        Output of split module is the same as output of input traced module.
        This is an example within a test setting:

            > orig_out = my_module_traced(x, y)
            > submodules_out = module_with_submodules(x, y)
            > self.assertEqual(orig_out, submodules_out)
            True
    """

    log.debug(
        "%s",
        lazy_format_graph_code("pre split_module", m, colored=True),
    )

    def construct_graph(
        node: Node,
        base_mod_env: dict[str, Node],
        base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule],
    ):
        # Copy placeholder / get_attr nodes of the original graph into the base
        # (top-level) graph. Other ops are ignored here; they live in submodules.
        # ``base_mod_graph`` is bound later in the enclosing function.
        if node.op == "placeholder":
            default_value = (
                node.args[0] if len(node.args) > 0 else inspect.Signature.empty
            )
            if keep_original_node_name:
                args = (
                    () if default_value is inspect.Signature.empty else (default_value,)
                )
                base_mod_env[node.name] = base_mod_graph.create_node(
                    "placeholder",
                    node.name,
                    args=args,  # type: ignore[arg-type]
                    type_expr=node.type,
                )
            else:
                base_mod_env[node.name] = base_mod_graph.placeholder(
                    node.target,  # type: ignore[arg-type]
                    type_expr=node.type,
                    default_value=default_value,
                )
            base_mod_env[node.name].meta = node.meta.copy()
        elif node.op == "get_attr":
            base_mod_env[node.name] = base_mod_graph.get_attr(node.target)  # type: ignore[arg-type]
            base_mod_env[node.name].meta = node.meta.copy()
            if not isinstance(node.target, str):
                raise AssertionError(f"Expected str target, got {type(node.target)}")
            attr_val = _get_attr_from_qualname(m, node.target)
            base_mod_attrs[node.target] = attr_val  # type: ignore[index]
        return base_mod_env, base_mod_attrs

    import sympy

    partitions: dict[str, Partition] = {}
    orig_nodes: dict[str, Node] = {}
    symbol_to_node: dict[sympy.Symbol, Node] = {}
    # Partition index returned by ``split_callback`` for each partitioned node.
    # The callback is called once per node and this cache is used everywhere
    # else, so non-pure callbacks cannot hand out conflicting indices.
    node_partition_idx: dict[Node, int] = {}

    def partition_name_for(partition_idx) -> str:
        # Single source of truth for the index -> partition-name mapping; all
        # lookups into ``partitions`` by index must go through here so that
        # ``partition_affix`` is honoured consistently.
        partition_name = str(partition_idx)
        if partition_affix is not None:
            # For example, if user specifies partition_affix = "pp", then the
            # partition name will be "pp_0", "pp_1", etc
            partition_name = "_".join([partition_affix, partition_name])
        return partition_name

    def record_cross_partition_use(def_node: Node, use_node: Optional[Node]):
        from torch.fx.experimental.symbolic_shapes import free_symbols

        defined = getattr(def_node, "_fx_partition", None)
        used = getattr(use_node, "_fx_partition", None)

        log.debug(
            "record_cross_partition_use %s (%s) %s (%s)",
            def_node.name,
            defined,
            use_node.name if use_node is not None else "-",
            used,
        )

        if defined != used:
            if defined is not None:
                def_partition = partitions[defined]
                def_partition.outputs.setdefault(def_node.name)
                if used is not None:
                    def_partition.dependents.setdefault(used)

            if used is not None:
                use_partition = partitions[used]
                use_partition.inputs.setdefault(def_node.name)
                # We have made def_node an input to the use_partition. If
                # this input has symbolic symbols in its size, those also must
                # be made as inputs to the partition
                if (def_val := def_node.meta.get("example_value")) is not None:
                    for s in sorted(free_symbols(def_val), key=str):
                        s_node = symbol_to_node[s]
                        use_partition.inputs.setdefault(s_node.name)
                        if symbol_to_node[s].op != "placeholder":
                            # If the node that defines the symbol is not a
                            # placeholder, we must make it an output of the
                            # partition. Note that this may be in a different
                            # partition than defined! Although, this doesn't
                            # really make a difference for correctness, since
                            # defined is guaranteed to have the symbol in
                            # scope and can return it; you just get less
                            # optimal codegen in this case.
                            s_defined = getattr(s_node, "_fx_partition", None)
                            if s_defined is not None:
                                s_def_partition = partitions[s_defined]
                                s_def_partition.outputs.setdefault(s_node.name)
                                s_def_partition.dependents.setdefault(used)
                                use_partition.dependencies.setdefault(s_defined)
                if defined is not None:
                    use_partition.dependencies.setdefault(defined)

    def instantiate_node_partition_mapping(node):
        # The only call site of ``split_callback``; the result is cached.
        partition_idx = split_callback(node)
        node_partition_idx[node] = partition_idx
        partition_name = partition_name_for(partition_idx)

        log.debug(
            "instantiate_node_partition_mapping %s (%s)", node.name, partition_name
        )

        # add node to partitions
        partition = partitions.get(partition_name)
        if partition is None:
            partitions[partition_name] = partition = Partition(partition_name)

        partition.node_names.append(node.name)
        node._fx_partition = partition_name

    # Global State Nodes are nodes which by their global state effects,
    # "taint" all downstream nodes while they are active.
    GLOBAL_STATE_NODES = [
        torch.amp._enter_autocast,
        torch.amp._exit_autocast,
        torch._C._set_grad_enabled,
    ]

    # For grad regions:
    # ------------------------
    # 1. first region: we do nothing
    # 2. subsequent regions: we insert the set_grad at the beginning
    grad_regions: OrderedDict[Node, set[int]] = OrderedDict()

    # For autocast regions:
    # ------------------------
    # 1. first region: we will only insert the _exit at the end
    # 2. intermediate regions: we will insert both the
    #    _enter at the beginning and _exit at the end
    # 3. last region: we will only insert _enter at the beginning
    # We will do so in the order in which the autocasts were instantiated.
    autocast_regions: OrderedDict[Node, set[int]] = OrderedDict()
    autocast_exits: dict[Node, Optional[Node]] = {}

    active_grad = None
    active_autocasts = set()

    for node in m.graph.nodes:
        # This will prefer placeholder bindings, because those come first.
        # This is a little dangerous though: it is possible that an unbacked
        # symbol is used without any binding site for it, in which case we
        # will get a KeyError not able to find it. I'd like to fix this by
        # having passes.runtime_assert establish some invariants that I can
        # rely on later, but this needs some extra work. Quick fix first.
        # See https://github.com/pytorch/pytorch/issues/130534
        if (
            (val := node.meta.get("example_value")) is not None
            and isinstance(val, (torch.SymInt, torch.SymFloat))
            and isinstance(s0 := val.node.expr, sympy.Symbol)
            and s0 not in symbol_to_node
        ):
            symbol_to_node[s0] = node

        if node.op in ["placeholder", "get_attr", "output"]:
            continue

        instantiate_node_partition_mapping(node)
        pid = node_partition_idx[node]

        if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
            if node.target is torch._C._set_grad_enabled:
                if len(node.args) != 1:
                    raise AssertionError(
                        f"Expected 1 arg for _set_grad_enabled, got {len(node.args)}"
                    )
                if not isinstance(node.args[0], bool):
                    raise AssertionError(f"Expected bool arg, got {type(node.args[0])}")
                active_grad = node
                grad_regions[active_grad] = {pid}
            elif node.target is torch.amp._enter_autocast:
                # Should all be python constants
                if not all(not isinstance(arg, Node) for arg in node.args):
                    raise AssertionError(
                        "Expected all args to be python constants, not Nodes"
                    )
                active_autocasts.add(node)
                autocast_regions[node] = {pid}
                autocast_exits[node] = None
            elif node.target is torch.amp._exit_autocast:
                if len(node.args) != 1:
                    raise AssertionError(
                        f"Expected 1 arg for _exit_autocast, got {len(node.args)}"
                    )
                autocast_regions[node.args[0]].add(pid)
                active_autocasts.remove(node.args[0])
                autocast_exits[node.args[0]] = node

        # Every node executed while a grad/autocast region is active extends
        # that region to the node's partition.
        if active_grad is not None:
            grad_regions[active_grad].add(pid)

        for a in active_autocasts:
            autocast_regions[a].add(pid)

    if not all(v is not None for v in autocast_exits.values()):
        raise AssertionError("autocast must exit")

    # pyrefly: ignore [bad-assignment]
    autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
    # pyrefly: ignore [bad-assignment]
    grad_regions = {k: sorted(v) for k, v in grad_regions.items()}

    if _LOGGER.isEnabledFor(logging.DEBUG):
        _LOGGER.debug("autocast_regions: %s", autocast_regions)
        _LOGGER.debug("grad_regions: %s", grad_regions)

    # Region prelude/epilogue insertion assumes partitions appear in graph order.
    assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)

    # split nodes into partitions
    highest_partition = -1
    for node in m.graph.nodes:
        orig_nodes[node.name] = node

        # TODO currently placeholders/parameters aren't put into random partitions,
        # rather they're added to the graphs where they are used down below
        if node.op in ["placeholder", "get_attr"]:
            continue
        if node.op == "output":
            torch.fx.graph.map_arg(
                node.args[0], lambda n: record_cross_partition_use(n, None)
            )
            continue

        if assert_monotonically_increasing:
            pid = node_partition_idx[node]
            if highest_partition > pid:
                raise AssertionError(
                    "autocast or set_grad_enabled require monotonically increasing "
                    f"partitions: highest: {highest_partition}, this node's: {pid}"
                )
            highest_partition = pid

        # do not capture cross-partition dependencies for global state nodes as they will be
        # self-contained - their setup and unwind will be isolated to each partition submodule.
        if node.target not in GLOBAL_STATE_NODES:
            torch.fx.graph.map_arg(
                node.args, lambda def_node: record_cross_partition_use(def_node, node)
            )
            torch.fx.graph.map_arg(
                node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
            )  # noqa: B950

    original_partition_order = list(partitions.keys())
    # find partitions with no dependencies
    root_partitions: list[str] = []
    for partition_name, partition in partitions.items():
        if not len(partition.dependencies):
            root_partitions.append(partition_name)

    # check partitions for circular dependencies and create topological partition ordering
    sorted_partitions: list[str] = []
    while root_partitions:
        root_partition = root_partitions.pop()
        sorted_partitions.append(root_partition)
        for dependent in partitions[root_partition].dependents:
            partitions[dependent].dependencies.pop(root_partition)  # noqa: B909
            if not partitions[dependent].dependencies:
                root_partitions.append(dependent)
    if len(sorted_partitions) != len(partitions):
        raise RuntimeError("cycle exists between partitions!")

    # Enter prelude
    for regions_mapping in [autocast_regions, grad_regions]:
        for node, regions in regions_mapping.items():
            if len(regions) == 0:
                raise AssertionError("Expected at least one region for node")
            # The first region already contains the original node itself.
            # pyrefly: ignore [bad-index]
            partitions[partition_name_for(regions[0])].environment[node] = node
            # Later regions re-issue the global-state call at their start.
            # pyrefly: ignore [bad-index, index-error]
            for r in regions[1:]:
                partition = partitions[partition_name_for(r)]
                new_node = partition.graph.create_node(
                    op=node.op,
                    target=node.target,
                    args=tuple(arg for arg in node.args),
                    kwargs={},
                    type_expr=node.type,
                )
                new_node.meta = (
                    node.meta.copy()
                )  # is it really a good idea to copy this?
                partition.environment[node] = new_node

    # add placeholders to partition inputs
    for partition_name in sorted_partitions:
        partition = partitions[partition_name]
        new_inputs: dict[str, None] = {}

        counter = 0

        for inp in partition.inputs:
            orig_node = orig_nodes[inp]
            # We don't pass in get_attr nodes as inputs to the partition, but
            # instead set them as targets and use getattr within the module

            def add_placeholder():
                nonlocal counter
                if keep_original_input_name:
                    name = inp
                else:
                    name = f"arg_{counter}"
                    counter += 1
                placeholder = partition.graph.placeholder(
                    name,
                    type_expr=orig_nodes[inp].type,
                )
                new_inputs[inp] = None
                return placeholder

            if orig_node.op == "get_attr":
                if not isinstance(orig_node.target, str):
                    raise AssertionError(
                        f"Expected str target, got {type(orig_node.target)}"
                    )
                orig_attr = _get_attr_from_qualname(m, orig_node.target)
                if isinstance(orig_attr, torch.nn.Module):
                    placeholder = partition.graph.get_attr(orig_node.target)
                    partition.targets[orig_node.target] = orig_attr
                else:
                    placeholder = add_placeholder()
            else:
                placeholder = add_placeholder()
            placeholder.meta = orig_nodes[inp].meta.copy()
            partition.environment[orig_nodes[inp]] = placeholder
        # Module-valued get_attr inputs became in-graph get_attr nodes, so only
        # the names that turned into placeholders remain partition inputs.
        partition.inputs = new_inputs

    # Transform nodes and collect targets for partition's submodule
    for node in m.graph.nodes:
        if hasattr(node, "_fx_partition"):
            partition = partitions[node._fx_partition]

            # swap out old graph nodes in kw/args with references to new nodes in this submodule
            environment = partition.environment
            gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
            gathered_kwargs = torch.fx.graph.map_arg(
                node.kwargs, lambda n: environment[n]
            )

            if node.op not in ["call_module", "get_attr"]:
                target = node.target
            else:
                target_attr = _get_attr_from_qualname(m, node.target)
                target = node.target.replace(".", "_")
                partition.targets[target] = target_attr
                # Fill in the passed-in mapping from new qualname to old qualname
                if qualname_map is not None:
                    # When creating the split module later, the submodules will have
                    # path prefix matching the corresponding partition's submod_name
                    qualname = f"{partition.submod_name}.{target}"
                    qualname_map[qualname] = node.target

            if not isinstance(gathered_args, tuple):
                raise AssertionError(
                    f"Expected tuple for gathered_args, got {type(gathered_args)}"
                )
            if not isinstance(gathered_kwargs, dict):
                raise AssertionError(
                    f"Expected dict for gathered_kwargs, got {type(gathered_kwargs)}"
                )
            name = node.name if keep_original_node_name else None
            new_node = partition.graph.create_node(
                op=node.op,
                target=target,
                args=gathered_args,
                kwargs=gathered_kwargs,
                type_expr=node.type,
                name=name,
            )
            new_node.meta = node.meta.copy()
            partition.environment[node] = new_node

    # Exit epilogue
    for regions_mapping in [autocast_regions]:
        for node in reversed(regions_mapping):
            regions = regions_mapping[node]
            if len(regions) == 0:
                raise AssertionError("Expected at least one region")
            # Every region except the last (which holds the real _exit call)
            # gets its own _exit_autocast appended.
            # pyrefly: ignore [bad-index, index-error]
            for r in regions[:-1]:
                partition = partitions[partition_name_for(r)]
                exit_node = autocast_exits[node]
                if exit_node is None:
                    raise AssertionError("Missing exit node")
                new_node = partition.graph.create_node(
                    op=exit_node.op,
                    target=exit_node.target,
                    args=(partition.environment[node],),
                    kwargs={},
                    type_expr=exit_node.type,
                )
                new_node.meta = (
                    exit_node.meta.copy()
                )  # is it really a good idea to copy this?

    # original module environment dict mapping node names to nodes
    orig_mod_env: dict[str, Node] = {}
    # Set up values to construct base module
    base_mod_env: dict[str, Node] = {}
    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
    base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule] = {}
    if not keep_original_order:
        for node in m.graph.nodes:
            base_mod_env, base_mod_attrs = construct_graph(
                node, base_mod_env, base_mod_attrs
            )

    else:
        # Go through the graph to construct the mapping dict
        for node in m.graph.nodes:
            orig_mod_env[node.name] = node

    # Do some things iterating over the partitions in topological order again:
    # 1) Finish off submodule Graphs by setting corresponding outputs
    # 2) Construct GraphModules for each submodule
    # 3) Construct the base graph by emitting calls to those submodules in
    #    topological order or original order specified by keep_original_order

    construct_order_partitions = (
        sorted_partitions if not keep_original_order else original_partition_order
    )

    already_constructed_attr_nodes = set()

    # We actually need to insert the placeholder nodes in the original order
    # otherwise graph signature will be wrong.
    original_order = [node for node in m.graph.nodes if node.op == "placeholder"]

    for partition_name in construct_order_partitions:
        partition = partitions[partition_name]

        # Set correct output values
        output_vals = tuple(
            partition.environment[orig_nodes[name]] for name in partition.outputs
        )

        # skip output node generation if there are no output values
        num_output_vals = len(output_vals)
        if num_output_vals == 1:
            partition.graph.output(output_vals[0])
        elif num_output_vals > 1:
            partition.graph.output(output_vals)
        else:
            # Invariant - Graph should always have an output node.
            partition.graph.output(())

        if keep_original_order:
            # first get the attr nodes required by this partition.
            # ``partition.inputs`` holds node *names*; filter on the node's op
            # (placeholders are handled by the loop just below and other ops
            # are no-ops for construct_graph).
            orig_mod_attr_nodes: list[Node] = [
                orig_mod_env[key]
                for key in partition.inputs
                if orig_mod_env[key].op == "get_attr"
            ]

            for node in original_order:
                if node in already_constructed_attr_nodes:
                    continue  # already added this attr to the base graph
                base_mod_env, base_mod_attrs = construct_graph(
                    node, base_mod_env, base_mod_attrs
                )
                already_constructed_attr_nodes.add(node)

            for node in orig_mod_attr_nodes:
                if node in already_constructed_attr_nodes:
                    continue
                base_mod_env, base_mod_attrs = construct_graph(
                    node, base_mod_env, base_mod_attrs
                )
                already_constructed_attr_nodes.add(node)

        # Construct GraphModule for this partition
        base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
            partition.targets, partition.graph
        )  # noqa: B950

        # Emit call in base graph to this submodule
        output_val = base_mod_graph.call_module(
            partition.submod_name,
            tuple(base_mod_env[name] for name in partition.inputs),
        )

        num_outputs = len(partition.outputs)
        if num_outputs > 1:
            # Unpack multiple return values from submodule
            output_val_proxy = torch.fx.proxy.Proxy(output_val)
            for i, output_name in enumerate(partition.outputs):
                base_mod_env[output_name] = output_val_proxy[i].node  # type: ignore[index]
        elif num_outputs == 1:
            base_mod_env[next(iter(partition.outputs))] = output_val

    # When keep_original_order=True and if the graph doesn't have any
    # `call_function` node then `base_mod_graph`, `base_mod_env` and `base_mod_attrs`
    # are never populated.
    # For this case, we call `construct_graph` here which takes care of updating them.
    if keep_original_order and not base_mod_env:
        for node in m.graph.nodes:
            base_mod_env, base_mod_attrs = construct_graph(
                node, base_mod_env, base_mod_attrs
            )

    # Add output node to `base_mod_graph` (i.e. the split graph) which will be returned.
    for node in m.graph.nodes:
        if node.op == "output":
            base_mod_graph.output(
                torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
            )  # noqa: B950

    ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
    log.debug(
        "%s",
        lazy_format_graph_code("post split_module", ret, colored=True),
    )
    return ret
|