graph_utils.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from collections import deque
  2. from typing import Any, Optional
  3. import torch
  4. from torch.fx import Graph, map_arg, Node
  5. from torch.utils._ordered_set import OrderedSet
  6. from torch.utils._pytree import tree_flatten
  7. # flattens with support for slices
  8. # Note: a better way to do this would
  9. # be register/unregister slices as pytree nodes
  10. # but there is no unregister API in the pytorch
  11. # pytree impl
  12. def _get_flat_args(
  13. node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
  14. ) -> list[Node]:
  15. args = list[Any]()
  16. map_arg((node.args, node.kwargs), args.append)
  17. if node in node_to_additional_deps:
  18. args.extend(node_to_additional_deps[node])
  19. return args
  20. def _get_flat_args_unique(
  21. node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
  22. ) -> OrderedSet[Node]:
  23. args = OrderedSet[Node]()
  24. map_arg((node.args, node.kwargs), args.add)
  25. if node in node_to_additional_deps:
  26. args.update(node_to_additional_deps[node])
  27. return args
  28. def _detect_cycles(
  29. graph: Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
  30. ) -> str:
  31. current_path: deque[Node] = deque()
  32. current_path_set: set[Node] = set()
  33. pending: deque[tuple[Node, Node]] = deque()
  34. def add_to_current_path(node: Node) -> None:
  35. current_path.append(node)
  36. current_path_set.add(node)
  37. def pop_current_path() -> None:
  38. node = current_path.pop()
  39. current_path_set.remove(node)
  40. def current_path_head() -> Node:
  41. return current_path[-1]
  42. for origin in graph.find_nodes(op="output"):
  43. current_path.clear()
  44. current_path_set.clear()
  45. add_to_current_path(origin)
  46. for child in _get_flat_args_unique(origin, node_to_additional_deps):
  47. pending.append((child, origin))
  48. while pending:
  49. cur_node, parent = pending.pop()
  50. # handle backtracking
  51. while current_path and current_path_head() != parent:
  52. pop_current_path()
  53. if not isinstance(cur_node, Node):
  54. continue
  55. if cur_node in current_path_set:
  56. current_path.append(cur_node)
  57. return f"cycle detected in path: {current_path}"
  58. add_to_current_path(cur_node)
  59. for child in _get_flat_args_unique(cur_node, node_to_additional_deps):
  60. pending.append((child, cur_node))
  61. return "no cycle detected"
  62. def _graph_device_type(graph: Optional[Graph]) -> str:
  63. if graph is None:
  64. return "cpu"
  65. def _device_type(x: Any) -> str:
  66. if isinstance(x, torch.device):
  67. return x.type
  68. if isinstance(x, torch.Tensor):
  69. return x.device.type
  70. return "cpu"
  71. def _flatten_meta(node: Node, key: str) -> list[Any]:
  72. if key not in node.meta:
  73. return []
  74. flat, _ = tree_flatten(node.meta[key])
  75. return flat
  76. for node in graph.nodes:
  77. for key in ("val", "example_value"):
  78. for obj in _flatten_meta(node, key):
  79. return _device_type(obj)
  80. # Check for device conversions
  81. if node.op == "call_method":
  82. for gpu in ["cuda", "xpu"]:
  83. if node.target == gpu:
  84. return gpu
  85. if node.target == "to" and gpu in node.args:
  86. return gpu
  87. # Check args/kwargs for non-CPU device specs
  88. flat_args, _ = tree_flatten((node.args, node.kwargs))
  89. for obj in flat_args:
  90. return _device_type(obj)
  91. return "cpu"