# partitioner_utils.py
  1. # mypy: allow-untyped-defs
  2. from enum import Enum
  3. from typing import NamedTuple
  4. from torch.fx.node import map_arg, Node
  5. class Partition:
  6. """Partition class contains all the information about an individual partition.
  7. It also provides necessary methods for manipulation the partition.
  8. """
  9. def __init__(self, partition_id: int) -> None:
  10. self.nodes: set[Node] = set()
  11. self.partition_id = partition_id
  12. self.parents: set[Partition] = set()
  13. self.children: set[Partition] = set()
  14. self.bfs_level: int = -1
  15. self.used_mem_bytes: int = 0
  16. self.logical_device_ids: list[int] = []
  17. def __str__(self):
  18. return str(self.partition_id)
  19. def recalculate_mem_size(self):
  20. self.used_mem_bytes = 0
  21. for node in self.nodes:
  22. self.used_mem_bytes += get_extra_size_of(node, self.nodes)
  23. def add_node(self, node):
  24. input_nodes: dict[Node, None] = {}
  25. map_arg(node.args, input_nodes.setdefault)
  26. map_arg(node.kwargs, input_nodes.setdefault)
  27. # Add current node's input nodes if they are placeholder or constants
  28. for n in input_nodes:
  29. if n.op in {"placeholder", "get_attr"}:
  30. self.nodes.add(n)
  31. self.nodes.add(node)
  32. self.recalculate_mem_size()
  33. def remove_node(self, node):
  34. # Remove a node only if the node is in the partition
  35. if node in self.nodes:
  36. self.nodes.remove(node)
  37. # Collect the node's input nodes
  38. input_nodes: dict[Node, None] = {}
  39. map_arg(node.args, input_nodes.setdefault)
  40. map_arg(node.kwargs, input_nodes.setdefault)
  41. # Check if an input node is a placeholder or get_attr,
  42. # and this input node is not used by some other nodes in this partition,
  43. # the remove this input node
  44. for input_node in input_nodes:
  45. if all(
  46. n not in self.nodes for n in input_node.users
  47. ) and input_node.op in {"placeholder", "get_attr"}:
  48. self.nodes.remove(input_node)
  49. self.recalculate_mem_size()
  50. class Device(NamedTuple):
  51. name: str
  52. available_mem_bytes: int
  53. logical_id: int
  54. class NodeLatency(NamedTuple):
  55. # Latency due to the memory bandwidth
  56. mem_latency_sec: float
  57. # Latency due to the computation
  58. computer_latency_sec: float
  59. class PartitionLatency(NamedTuple):
  60. # Sum of all nodes' memory latency on the critical path
  61. mem_latency_sec: float
  62. # Sum of all nodes' compute latency on the critical path
  63. computer_latency_sec: float
  64. # Latency of the critical path
  65. overall_latency_sec: float
  66. class PartitionMode(Enum):
  67. size_based = 0
  68. sparse_nn = 1
  69. cost_aware = 2
  70. kl_based = 3
  71. aot_based = 4
  72. class PartitionerConfig(NamedTuple):
  73. devices: list[Device]
  74. mode: PartitionMode = PartitionMode.size_based
  75. transfer_rate_bytes_per_sec: float = 0.0
  76. node_to_latency_mapping: dict[Node, NodeLatency] = {}
  77. node_to_partition_mapping: dict[Node, int] = {}
  78. partition_to_logical_device_mapping: dict[int, list[int]] = {}
  79. # Saturate host by replicating partitions to the remaining idle devices.
  80. saturate_host: bool = False
  81. def get_extra_size_of(node: Node, nodes: set[Node]) -> int:
  82. """Given a node and a set of nodes,
  83. this function return the extra size that needed
  84. if this node is included in this set.
  85. """
  86. # Find all its input nodes
  87. input_nodes: dict[Node, None] = {}
  88. map_arg(node.args, input_nodes.setdefault)
  89. map_arg(node.kwargs, input_nodes.setdefault)
  90. # Calculate total size of related nodes
  91. total_size_of_input_nodes = 0
  92. for n in input_nodes:
  93. # Make sure this node hasn't been in this set yet
  94. if n not in nodes:
  95. size_bytes = getattr(n, "size_bytes", None)
  96. if size_bytes:
  97. total_size_of_input_nodes += size_bytes.output_size
  98. else:
  99. raise RuntimeError("node has no size_bytes attr")
  100. # Don't forget the op node itself
  101. size_bytes = getattr(node, "size_bytes", None)
  102. if size_bytes:
  103. total_size_of_input_nodes += size_bytes.total_size
  104. else:
  105. raise RuntimeError("node has no size_bytes attr")
  106. return total_size_of_input_nodes
  107. def get_latency_of_one_partition(
  108. partition: Partition, node_to_latency_mapping: dict[Node, NodeLatency]
  109. ) -> PartitionLatency:
  110. """Given a partition and its nodes' latency, return a PartitionLatency for this partition"""
  111. def get_top_nodes(partition: Partition) -> list[Node]:
  112. """Given a partition, return a list of nodes on the top bfs level"""
  113. top_nodes: list[Node] = []
  114. for node in partition.nodes:
  115. # Skip placeholder and get_attr nodes
  116. if node.op in {"placeholder", "get_attr"}:
  117. continue
  118. input_nodes: dict[Node, None] = {}
  119. map_arg(node.args, input_nodes.setdefault)
  120. map_arg(node.kwargs, input_nodes.setdefault)
  121. # If a node has no input nodes in this partition,
  122. # or its input nodes in this partition are placeholders and get_attrs
  123. # this node is on the top bfs level in this partition
  124. if not any(
  125. n in partition.nodes and n.op not in {"placeholder", "get_attr"}
  126. for n in input_nodes
  127. ):
  128. top_nodes.append(node)
  129. return top_nodes
  130. def dfs_helper(node: Node, partition_latency) -> PartitionLatency:
  131. """Given a top node of a partition, this function returns
  132. the latency of the critical path in the partition
  133. """
  134. node_latency = node_to_latency_mapping[node]
  135. # Calculate the current overall latency of the partition
  136. overall_latency_sec = partition_latency.overall_latency_sec + max(
  137. node_latency.computer_latency_sec, node_latency.mem_latency_sec
  138. )
  139. # Update the mem latency of this path
  140. mem_latency_sec = (
  141. partition_latency.mem_latency_sec + node_latency.mem_latency_sec
  142. )
  143. # Update the compute latency of this path
  144. computer_latency_sec = (
  145. partition_latency.computer_latency_sec + node_latency.computer_latency_sec
  146. )
  147. # Get all users of this node that are in this partition
  148. users = set(node.users).intersection(partition.nodes)
  149. if users:
  150. max_latency = PartitionLatency(
  151. mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
  152. )
  153. for n in users:
  154. # Get new partition latency recursively
  155. new_partition_latency = dfs_helper(
  156. n,
  157. PartitionLatency(
  158. mem_latency_sec, computer_latency_sec, overall_latency_sec
  159. ),
  160. )
  161. if (
  162. new_partition_latency.overall_latency_sec
  163. > max_latency.overall_latency_sec
  164. ):
  165. max_latency = new_partition_latency
  166. return max_latency
  167. # If there is no user, the node is at bottom of the partition
  168. return PartitionLatency(
  169. mem_latency_sec, computer_latency_sec, overall_latency_sec
  170. )
  171. # Main part starts
  172. # Get all top level nodes of this partition
  173. top_nodes = get_top_nodes(partition)
  174. critical_path_latency = PartitionLatency(
  175. mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
  176. )
  177. # Go through all top nodes and find the largest latency (critical pass latency)
  178. for node in top_nodes:
  179. partition_latency = dfs_helper(
  180. node,
  181. PartitionLatency(
  182. mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
  183. ),
  184. )
  185. if (
  186. partition_latency.overall_latency_sec
  187. > critical_path_latency.overall_latency_sec
  188. ):
  189. critical_path_latency = partition_latency
  190. return critical_path_latency
  191. def get_partition_to_latency_mapping(
  192. partitions: list[Partition], node_to_latency_mapping: dict[Node, NodeLatency]
  193. ) -> dict[Partition, PartitionLatency]:
  194. """Given all the partitions and node_to_latency_mapping dictionary,
  195. return a mapping dictionary of each partition to its overall latency
  196. """
  197. partition_to_latency_mapping: dict[Partition, PartitionLatency] = {}
  198. # Go through each partition and get its latency
  199. for partition in partitions:
  200. partition_latency = get_latency_of_one_partition(
  201. partition, node_to_latency_mapping
  202. )
  203. partition_to_latency_mapping[partition] = partition_latency
  204. return partition_to_latency_mapping
  205. def get_comm_latency_between(
  206. parent_partition: Partition,
  207. child_partition: Partition,
  208. transfer_rate_bytes_per_sec: float,
  209. ):
  210. """Given two partitions (parent and child),
  211. calculate the communication latency between the two.
  212. """
  213. # If two partitions are on the same device, the comm latency is 0.
  214. if (
  215. parent_partition.logical_device_ids != []
  216. and child_partition.logical_device_ids != []
  217. and parent_partition.logical_device_ids == child_partition.logical_device_ids
  218. ):
  219. return 0.0
  220. # Keep tracking the communication size between parent and child
  221. comm_size = 0
  222. # Keep tracking all the counted node
  223. visited_nodes = set()
  224. # Go through all nodes in the child partition
  225. # If a node has input nodes from the parent partition,
  226. # the output size of those input nodes will be counted
  227. # and added to comm_size
  228. for node in child_partition.nodes:
  229. input_nodes: dict[Node, None] = {}
  230. map_arg(node.args, input_nodes.setdefault)
  231. map_arg(node.kwargs, input_nodes.setdefault)
  232. for n in input_nodes:
  233. if n in parent_partition.nodes and n not in visited_nodes:
  234. size_bytes = getattr(n, "size_bytes", None)
  235. if size_bytes is not None:
  236. comm_size += size_bytes.output_size
  237. visited_nodes.add(n)
  238. return comm_size / transfer_rate_bytes_per_sec
  239. def get_latency_of_partitioned_graph(
  240. partitions: list[Partition],
  241. partition_to_latency_mapping: dict[Partition, PartitionLatency],
  242. transfer_rate_bytes_per_sec: float,
  243. ):
  244. """Given all partitions in a graph, find the critical path among all partitions
  245. and return its latency as the latency of the whole graph
  246. """
  247. def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
  248. """This function helps to recursively get the latency of a path of partitions"""
  249. # Update latency by adding current partition's latency
  250. latency_so_far_sec += partition_to_latency_mapping[
  251. partition
  252. ].overall_latency_sec
  253. if partition.children:
  254. max_latency_sec = 0.0
  255. for child in partition.children:
  256. # Calculate latency between
  257. comm_latency_sec = get_comm_latency_between(
  258. partition, child, transfer_rate_bytes_per_sec
  259. )
  260. new_latency_sec = dfs_helper(
  261. child, latency_so_far_sec + comm_latency_sec
  262. )
  263. if new_latency_sec > max_latency_sec:
  264. max_latency_sec = new_latency_sec
  265. return max_latency_sec
  266. return latency_so_far_sec
  267. def get_top_partitions(partitions: list[Partition]) -> list[Partition]:
  268. """This function is to return all the partitions without parents
  269. as the starting points of all the paths
  270. """
  271. # If a partition has no parents, then it is a top partition
  272. top_partitions = [
  273. partition for partition in partitions if len(partition.parents) == 0
  274. ]
  275. return top_partitions
  276. top_partitions = get_top_partitions(partitions)
  277. critical_path_latency_sec = 0.0
  278. for partition in top_partitions:
  279. latency_sec = dfs_helper(partition, 0.0)
  280. if latency_sec > critical_path_latency_sec:
  281. critical_path_latency_sec = latency_sec
  282. return critical_path_latency_sec