cudagraphs.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # mypy: allow-untyped-defs
  2. import operator
  3. import torch
  4. from torch.fx.passes.fake_tensor_prop import FakeTensorProp
  5. from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
  6. from torch.fx.passes.operator_support import OperatorSupport
  7. from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
  8. from torch.utils import _pytree as pytree
  9. class CudaGraphsSupport(OperatorSupport):
  10. # TODO: why is submodules passed here
  11. def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
  12. if node.op not in CALLABLE_NODE_OPS:
  13. return False
  14. if node.target is torch.ops.aten.embedding_dense_backward.default:
  15. return False
  16. if node.target is operator.getitem:
  17. return True
  18. found_not_cuda = False
  19. def meta_fk(meta):
  20. return meta["val"] if "val" in meta else meta["fake_result"]
  21. def find_not_cuda(t):
  22. nonlocal found_not_cuda
  23. if isinstance(t, torch.Tensor) and t.device.type != "cuda":
  24. found_not_cuda = True
  25. for n in node.all_input_nodes:
  26. pytree.tree_map_(find_not_cuda, meta_fk(n.meta))
  27. pytree.tree_map_(find_not_cuda, meta_fk(node.meta))
  28. # NB: factory function is accounted for because the result would be
  29. # cpu or cuda
  30. return not found_not_cuda
  31. def partition_cudagraphs(gm, inputs):
  32. """
  33. Partition an FX graph into sub-GraphModules that can be validly run under
  34. CUDA graphs. For a subgraph to be runnable under CUDA, all of the operations
  35. must involve CUDA tensors only/
  36. """
  37. FakeTensorProp(gm).propagate(*inputs)
  38. supported_ops = CudaGraphsSupport()
  39. # TODO: single node partition may be wrong due to the pessimization
  40. # from copying in and out the data. Check in benchmarks, perhaps
  41. partitioner = CapabilityBasedPartitioner(
  42. gm, supported_ops, allows_single_node_partition=True
  43. )
  44. partitions = partitioner.propose_partitions()
  45. fused_graph = partitioner.fuse_partitions(partitions)
  46. return fused_graph