| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
- from collections import deque
- from typing import Any, Optional
- import torch
- from torch.fx import Graph, map_arg, Node
- from torch.utils._ordered_set import OrderedSet
- from torch.utils._pytree import tree_flatten
- # flattens with support for slices
- # Note: a better way to do this would
- # be register/unregister slices as pytree nodes
- # but there is no unregister API in the pytorch
- # pytree impl
- def _get_flat_args(
- node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
- ) -> list[Node]:
- args = list[Any]()
- map_arg((node.args, node.kwargs), args.append)
- if node in node_to_additional_deps:
- args.extend(node_to_additional_deps[node])
- return args
- def _get_flat_args_unique(
- node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
- ) -> OrderedSet[Node]:
- args = OrderedSet[Node]()
- map_arg((node.args, node.kwargs), args.add)
- if node in node_to_additional_deps:
- args.update(node_to_additional_deps[node])
- return args
def _detect_cycles(
    graph: Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> str:
    """Look for a dependency cycle reachable from the graph's output nodes.

    Runs an iterative depth-first search from every ``output`` node. It
    follows each node's args/kwargs plus any extra edges recorded in
    ``node_to_additional_deps`` (collected with ``_get_flat_args_unique``). A
    cycle is reported as soon as the search reaches a node that is already
    on the current path.

    Returns:
        ``"cycle detected in path: ..."`` with the offending path (ending in
        the repeated node) if a cycle is found, otherwise
        ``"no cycle detected"``.
    """
    current_path: deque[Node] = deque()
    current_path_set: set[Node] = set()
    # Nodes whose whole dependency subtree has already been walked without
    # finding a cycle. Re-walking them cannot uncover a new cycle, so they
    # are skipped and each node's dependencies are expanded at most once.
    # Without this, shared subgraphs (diamonds) are re-explored once per
    # distinct path, which blows up exponentially.
    finished: set[Node] = set()
    # Stack of (node to visit, node that depends on it).
    pending: deque[tuple[Node, Node]] = deque()

    def add_to_current_path(node: Node) -> None:
        current_path.append(node)
        current_path_set.add(node)

    def pop_current_path() -> None:
        node = current_path.pop()
        current_path_set.remove(node)
        # ``pending`` is processed LIFO. By the time a node is backtracked
        # off the path, every item pushed for its subtree has been handled.
        finished.add(node)

    def current_path_head() -> Node:
        return current_path[-1]

    for origin in graph.find_nodes(op="output"):
        # Unwind whatever path the previous origin left behind; those nodes
        # were fully explored as well.
        while current_path:
            pop_current_path()
        add_to_current_path(origin)
        for child in _get_flat_args_unique(origin, node_to_additional_deps):
            pending.append((child, origin))

        while pending:
            cur_node, parent = pending.pop()

            # Backtrack until the node that depends on cur_node is the head.
            while current_path and current_path_head() != parent:
                pop_current_path()

            if not isinstance(cur_node, Node) or cur_node in finished:
                continue

            if cur_node in current_path_set:
                current_path.append(cur_node)
                return f"cycle detected in path: {current_path}"

            add_to_current_path(cur_node)
            for child in _get_flat_args_unique(cur_node, node_to_additional_deps):
                pending.append((child, cur_node))

    return "no cycle detected"
def _graph_device_type(graph: Optional[Graph]) -> str:
    """Infer the non-CPU device type a graph targets, defaulting to ``"cpu"``.

    Nodes are scanned in graph order. For each node, the following sources
    are checked in turn and the first non-CPU device type found is returned:

    1. ``torch.device`` objects or tensors inside ``node.meta["val"]`` or
       ``node.meta["example_value"]``;
    2. a ``call_method`` node that is ``.cuda()`` / ``.xpu()``, or ``.to(...)``
       with ``"cuda"`` / ``"xpu"`` as a positional string argument;
    3. ``torch.device`` objects or tensors among the node's args/kwargs.

    Anything that is not a ``torch.device`` or ``torch.Tensor``, and any
    CPU device, is skipped rather than ending the scan.

    Args:
        graph: The FX graph to inspect, or None.

    Returns:
        The device type string (e.g. ``"cuda"``), or ``"cpu"`` when ``graph``
        is None or no non-CPU device is found.
    """
    if graph is None:
        return "cpu"

    def _device_type(x: Any) -> Optional[str]:
        # Only devices and tensors carry device information.
        if isinstance(x, torch.device):
            return x.type
        if isinstance(x, torch.Tensor):
            return x.device.type
        return None

    def _first_non_cpu(objs: list[Any]) -> Optional[str]:
        # First non-CPU device type among objs, or None if there is none.
        for obj in objs:
            device_type = _device_type(obj)
            if device_type is not None and device_type != "cpu":
                return device_type
        return None

    def _flatten_meta(node: Node, key: str) -> list[Any]:
        if key not in node.meta:
            return []
        flat, _ = tree_flatten(node.meta[key])
        return flat

    for node in graph.nodes:
        for key in ("val", "example_value"):
            found = _first_non_cpu(_flatten_meta(node, key))
            if found is not None:
                return found

        # Check for explicit device conversions: x.cuda(), x.xpu(), x.to("cuda").
        if node.op == "call_method":
            for gpu in ("cuda", "xpu"):
                if node.target == gpu:
                    return gpu
                # Compare only string args; ``gpu == tensor`` would fall
                # through to the reflected Tensor.__eq__.
                if node.target == "to" and any(
                    isinstance(arg, str) and arg == gpu for arg in node.args
                ):
                    return gpu

        # Check args/kwargs for non-CPU device specs.
        flat_args, _ = tree_flatten((node.args, node.kwargs))
        found = _first_non_cpu(flat_args)
        if found is not None:
            return found

    return "cpu"
|