partitioner.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. # mypy: allow-untyped-defs
  2. import collections
  3. import itertools
  4. import logging
  5. import operator
  6. from collections.abc import Iterable, Sequence
  7. from typing import Optional
  8. from torch.fx.graph_module import GraphModule
  9. from torch.fx.node import _get_qualified_name, Node
  10. from torch.fx.passes.operator_support import OperatorSupportBase
  11. from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
  12. logger = logging.getLogger(__name__)
  13. logger.setLevel(logging.WARNING)
  14. class Partition:
  15. def __init__(
  16. self,
  17. id: Optional[int] = None,
  18. nodes: Optional[Iterable[Node]] = None,
  19. node_orders: Optional[Iterable[int]] = None,
  20. ):
  21. self.id = id
  22. self.nodes: dict[Node, Optional[int]] = {}
  23. if nodes is not None:
  24. if node_orders is None:
  25. self.nodes = dict.fromkeys(nodes, None)
  26. else:
  27. nodes_list = list(nodes)
  28. node_orders_list = list(node_orders)
  29. if len(nodes_list) != len(node_orders_list):
  30. raise AssertionError(
  31. "nodes and node_orders must have the same length"
  32. )
  33. self.nodes = dict(zip(nodes_list, node_orders_list))
  34. def __repr__(self) -> str:
  35. return str(self.nodes)
  36. def add_node(self, node: Node, node_order: Optional[int] = None):
  37. self.nodes.update({node: node_order})
  38. def remove_node(self, node: Node):
  39. del self.nodes[node]
  40. def size(self):
  41. return len(self.nodes)
  42. class _DependencyViewer:
  43. def __init__(self, graph_module: GraphModule):
  44. self.downstreams = collections.defaultdict(set)
  45. for node in reversed(graph_module.graph.nodes):
  46. for output_node in node.users:
  47. # add output_node and output_node's downstream dependency
  48. self.downstreams[node].add(output_node)
  49. self.downstreams[node].update(self.downstreams[output_node])
  50. def downstreams_of(self, node: Node) -> set[Node]:
  51. return self.downstreams[node]
  52. class CapabilityBasedPartitioner:
  53. def __init__(
  54. self,
  55. graph_module: GraphModule,
  56. operator_support: OperatorSupportBase,
  57. allows_single_node_partition: bool = False,
  58. non_compute_ops: Optional[Sequence[str]] = None,
  59. allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
  60. ) -> None:
  61. self.graph_module = graph_module
  62. self.operator_support = operator_support
  63. self.allows_single_node_partition = allows_single_node_partition
  64. self.non_compute_ops = non_compute_ops if non_compute_ops is not None else []
  65. self.allowed_single_node_partition_ops = (
  66. allowed_single_node_partition_ops
  67. if allowed_single_node_partition_ops is not None
  68. else []
  69. )
  70. self.dependency_viewer = _DependencyViewer(graph_module)
  71. def _is_node_supported(self, node: Node) -> bool:
  72. return self.operator_support.is_node_supported(
  73. dict(self.graph_module.named_modules()), node
  74. )
  75. def propose_partitions(self) -> list[Partition]:
  76. # partition_map is a mapping from partition id to a set of partition id's.
  77. # The value set contains all the partition ids that can be reached by doing a
  78. # DFS starting from the partition id in the key.
  79. partition_map: dict[int, set] = collections.defaultdict(set)
  80. # assumptions: nodes in candidate list is sorted in topological order
  81. assignment: dict[Node, int] = {} # mapping from node to partition_id
  82. partitions_by_id: dict[
  83. int, Partition
  84. ] = {} # mapping from partition_id to partition
  85. nodes_order: dict[
  86. Node, int
  87. ] = {} # mapping from nodes to reversed topological order
  88. partitions_order: dict[
  89. int, int
  90. ] = {} # mapping from partition_id to minimum topo order of nodes in partition
  91. partition_users: dict[
  92. int, set
  93. ] = {} # mapping from partition_id to partition users
  94. new_partition_id = itertools.count()
  95. # try to merge partition other_id into partition self_id
  96. # merge only happens if the end graph doesn't contain cyclic dependency
  97. # returns `True` when merge happens, `False` otherwise.
  98. def maybe_merge_partition(self_id: int, other_id: int):
  99. # merged_nodes is the union of nodes in two partition to-be-merged
  100. self_nodes = partitions_by_id[self_id].nodes
  101. other_nodes = partitions_by_id[other_id].nodes
  102. def dfs_iter_find_cycle(all_user_nodes: set[Node]):
  103. for user_node in all_user_nodes:
  104. visited_partition_ids = set()
  105. for path_node in self.dependency_viewer.downstreams_of(user_node):
  106. # If any of the nodes in the dfs path of this node are in the merged_nodes
  107. # list then there is a cycle in the graph.
  108. if path_node in self_nodes or path_node in other_nodes:
  109. return True
  110. # If any of the nodes in the dfs path of this node are in the assignment
  111. # map then we have to make sure that the partitions that these nodes belong
  112. # to do not form a cycle with the current partitions being merged. This means
  113. # iterating through all the nodes in all the parititons that are traversed in
  114. # the dfs path and checking if they are in the merged_nodes list.
  115. if path_node in assignment:
  116. partition_id = assignment[path_node]
  117. # If the partition id has already been visited then we know that it doesn't
  118. # form a cycle with the current partitions being merged.
  119. if partition_id in visited_partition_ids:
  120. continue
  121. p_map = partition_map[partition_id]
  122. if self_id in p_map or other_id in p_map:
  123. return True
  124. visited_partition_ids.add(partition_id)
  125. return False
  126. # find new partition users if merge.
  127. all_user_nodes = partition_users[self_id] | partition_users[other_id]
  128. all_user_nodes.difference_update(other_nodes, self_nodes)
  129. # check if merge would create cyclic dependency.
  130. if dfs_iter_find_cycle(all_user_nodes):
  131. # return false indicating cyclic dependency found and
  132. # merge is aborted
  133. return self_id, False
  134. # merge the smaller partition into the larger.
  135. merge_id, removed_id = self_id, other_id
  136. if len(self_nodes) < len(other_nodes):
  137. merge_id, removed_id = removed_id, merge_id
  138. # no cyclic dependency found, move forward with the merge
  139. # updating partition nodes
  140. partitions_by_id[merge_id].nodes.update(partitions_by_id[removed_id].nodes)
  141. # updating assignment map
  142. for node in partitions_by_id[removed_id].nodes:
  143. assignment[node] = merge_id
  144. # delete other partition
  145. del partitions_by_id[removed_id]
  146. partitions_order[merge_id] = min(
  147. partitions_order[merge_id], partitions_order[removed_id]
  148. )
  149. del partitions_order[removed_id]
  150. partition_map[merge_id] = partition_map[merge_id].union(
  151. partition_map[removed_id]
  152. )
  153. del partition_map[removed_id]
  154. partition_users[merge_id] = all_user_nodes
  155. del partition_users[removed_id]
  156. return merge_id, True
  157. def merge_single_node(node: Node, node_order: Optional[int], id: Optional[int]):
  158. def _update_partition_map(node: Node, id: int):
  159. # Iterate through all the users of this node and update the partition map to indicate
  160. # that there is a path from the partition id of this node to the target partition id.
  161. for user_node in node.users:
  162. target_id = assignment.get(user_node)
  163. if target_id is not None:
  164. partition_map[id].add(target_id)
  165. partition_map[id].update(partition_map[target_id])
  166. if node in assignment:
  167. partitions_by_id[assignment[node]].remove_node(node)
  168. if id is None:
  169. assignment.pop(node)
  170. elif id not in partitions_by_id:
  171. assignment[node] = id
  172. if node_order is None:
  173. raise AssertionError("node_order is required for new partitions")
  174. partitions_by_id[id] = Partition(
  175. id=id, nodes=[node], node_orders=[node_order]
  176. )
  177. partition_users[id] = set(node.users)
  178. _update_partition_map(node, id)
  179. else:
  180. assignment[node] = id
  181. partitions_by_id[id].add_node(node, node_order)
  182. logger.debug("Proposing partitions...")
  183. for node_order, node in enumerate(reversed(self.graph_module.graph.nodes)):
  184. # use Dict as an ordered set to ensure deterministic partitioning result, don't care value
  185. merge_candidates: dict[int, None] = {}
  186. # Note a limited horizontal fusion is enabled:
  187. # when `node` is not supported, the code below attempts to fuse consumer of `node`.
  188. #
  189. # I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut
  190. # the fusion by adding an `else` block here to skip horizontal fusion.
  191. if self._is_node_supported(node) and node not in assignment:
  192. partition_id = next(new_partition_id)
  193. nodes_order[node] = partition_id
  194. partitions_order[partition_id] = partition_id
  195. merge_single_node(node, node_order, partition_id)
  196. merge_candidates[partition_id] = None
  197. # merge all possible partitions
  198. for partition_id, _ in sorted(
  199. partitions_order.items(), key=operator.itemgetter(1)
  200. ):
  201. merge_candidates[partition_id] = None
  202. merge_candidates_list = list(merge_candidates.keys())
  203. if len(merge_candidates_list) > 1:
  204. self_id = merge_candidates_list[0]
  205. for other_id in merge_candidates_list[1:]:
  206. # note: merge partitions if it doesn't create cyclic dependency
  207. # in the graph, otherwise, this is a no-op
  208. self_id, _ = maybe_merge_partition(self_id, other_id)
  209. # sort partition nodes based on descending node order
  210. for partition in partitions_by_id.values():
  211. partition.nodes = dict(
  212. sorted(
  213. partition.nodes.items(), key=operator.itemgetter(1), reverse=True
  214. )
  215. )
  216. # post processing to re-assign "getitem" nodes into upstream partition
  217. # Run iteratively until no more changes, to handle nested getitem chains
  218. # (e.g., getitem_619 = getitem_618[0] where getitem_618 = with_effects_167[1])
  219. logger.debug("Reassigning getitem nodes to its producer node's partition...")
  220. while True:
  221. nodes_reassignment: dict[Node, int] = {}
  222. for node in self.graph_module.graph.nodes:
  223. is_tuple_output = True
  224. for user in node.users:
  225. if (
  226. user.op != "call_function"
  227. or _get_qualified_name(user.target) != "_operator.getitem"
  228. ): # type: ignore[arg-type]
  229. is_tuple_output = False
  230. break
  231. # node has tuple outputs, re-assign all following getitem node into node's partition
  232. if is_tuple_output:
  233. id = assignment.get(node) # type: ignore[arg-type]
  234. for user in node.users:
  235. if assignment.get(user) != id: # type: ignore[arg-type]
  236. nodes_reassignment[user] = id # type: ignore[assignment]
  237. # no more re-assignments
  238. if not nodes_reassignment:
  239. break
  240. for node, id in nodes_reassignment.items():
  241. merge_single_node(node, None, id)
  242. # filter out single node partitions
  243. if not self.allows_single_node_partition:
  244. logger.debug("Filtering out single node partitions...")
  245. default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
  246. non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
  247. partitions_to_remove: list[int] = []
  248. for id, partition in partitions_by_id.items():
  249. compute_node_count = 0
  250. for node in partition.nodes:
  251. if node.op == "call_function":
  252. if not callable(node.target):
  253. raise AssertionError(
  254. f"Expected callable target, got {type(node.target)}"
  255. )
  256. if _get_qualified_name(node.target) not in non_compute_ops:
  257. compute_node_count += 1
  258. if (
  259. _get_qualified_name(node.target)
  260. in self.allowed_single_node_partition_ops
  261. ):
  262. compute_node_count += 1
  263. if compute_node_count <= 1:
  264. partitions_to_remove.append(id)
  265. for id in partitions_to_remove:
  266. del partitions_by_id[id]
  267. logger.debug("Partitions proposed:")
  268. for id, partition in partitions_by_id.items():
  269. logger.debug(
  270. "partition #%s: %s", id, [node.name for node in partition.nodes]
  271. )
  272. return [
  273. partition for partition in partitions_by_id.values() if partition.size() > 0
  274. ]
  275. def fuse_partitions(
  276. self, partitions: list[Partition], prefix: str = "fused_"
  277. ) -> GraphModule:
  278. logger.debug("Fusing partitions...")
  279. # fuse_by_partitions expects partitions in List[Dict[Node, None]]: [ {node0 : None}, {node1 : None} ]
  280. return fuse_by_partitions(
  281. self.graph_module,
  282. [partition.nodes for partition in partitions],
  283. prefix=prefix,
  284. )
  285. # remove non-compute-ops that sits at the boundary of a partition.
  286. def remove_bookend_non_compute_ops(self, partitions: list[Partition]):
  287. non_compute_ops = set(self.non_compute_ops)
  288. def is_non_compute_node(node: Node):
  289. return (
  290. node.op == "call_function"
  291. and _get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type]
  292. )
  293. # cache transparent nodes
  294. transparent_input_nodes: dict[Node, bool] = {}
  295. transparent_output_nodes: dict[Node, bool] = {}
  296. def is_transparent_input_node(
  297. node: Node, partition: set[Node], removed_nodes: set[Node]
  298. ):
  299. if (
  300. node.op == "placeholder"
  301. or (node not in partition)
  302. or (node in removed_nodes)
  303. ):
  304. return True
  305. if node in transparent_input_nodes:
  306. return transparent_input_nodes[node]
  307. if is_non_compute_node(node):
  308. for input_n in node.all_input_nodes:
  309. if not is_transparent_input_node(input_n, partition, removed_nodes):
  310. transparent_input_nodes[node] = False
  311. return False
  312. transparent_input_nodes[node] = True
  313. return True
  314. transparent_input_nodes[node] = False
  315. return False
  316. def is_transparent_output_node(
  317. node: Node, partition: set[Node], removed_nodes: set[Node]
  318. ):
  319. if (
  320. node.op == "placeholder"
  321. or (node not in partition)
  322. or (node in removed_nodes)
  323. ):
  324. return True
  325. if node in transparent_output_nodes:
  326. return transparent_output_nodes[node]
  327. if is_non_compute_node(node):
  328. for output_n in node.users:
  329. if not is_transparent_output_node(
  330. output_n, partition, removed_nodes
  331. ):
  332. transparent_output_nodes[node] = False
  333. return False
  334. transparent_output_nodes[node] = True
  335. return True
  336. transparent_output_nodes[node] = False
  337. return False
  338. for partition in partitions:
  339. # Note it's ok to use `set` here, since we are only query if a node
  340. # has been removed. We are NEVER going to iterate on nodes inside
  341. # the set.
  342. remove_node: set[Node] = set()
  343. for node in partition.nodes:
  344. if is_non_compute_node(node) and (
  345. is_transparent_input_node(node, set(partition.nodes), remove_node)
  346. or is_transparent_output_node(
  347. node, set(partition.nodes), remove_node
  348. )
  349. ):
  350. remove_node.add(node)
  351. if len(remove_node) != 0:
  352. for node in remove_node:
  353. partition.nodes.pop(node, None)
  354. def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule:
  355. partitions = self.propose_partitions()
  356. fused_gm = self.fuse_partitions(partitions, prefix=prefix)
  357. return fused_gm