graph_manipulation.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. # mypy: allow-untyped-defs
  2. from typing import Any, NamedTuple, Optional
  3. import torch
  4. from torch.fx._compatibility import compatibility
  5. from torch.fx.graph import Graph
  6. from torch.fx.graph_module import GraphModule
  7. from torch.fx.node import map_arg, Node, Target
  8. from torch.fx.passes.shape_prop import ShapeProp
  9. __all__ = [
  10. "replace_target_nodes_with",
  11. "size_bytes",
  12. "get_size_of_all_nodes",
  13. "get_tensor_meta",
  14. "get_size_of_node",
  15. ]
  16. @compatibility(is_backward_compatible=False)
  17. def replace_target_nodes_with(
  18. fx_module: GraphModule,
  19. old_op: str,
  20. old_target: Target,
  21. new_op: str,
  22. new_target: Target,
  23. ):
  24. """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
  25. and updates them to match the new op code and target"""
  26. new_graph = Graph()
  27. val_map: dict[Node, Node] = {}
  28. for node in fx_module.graph.nodes:
  29. if node.op == old_op and node.target == old_target:
  30. args = map_arg(node.args, lambda n: val_map[n])
  31. kwargs = map_arg(node.kwargs, lambda n: val_map[n])
  32. if not isinstance(args, tuple):
  33. raise AssertionError(f"Expected tuple, got {type(args)}")
  34. if not isinstance(kwargs, dict):
  35. raise AssertionError(f"Expected dict, got {type(kwargs)}")
  36. val_map[node] = new_graph.create_node(
  37. new_op, new_target, args, kwargs, node.name
  38. )
  39. else:
  40. val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
  41. fx_module.graph = new_graph
@compatibility(is_backward_compatible=False)
class size_bytes(NamedTuple):
    """Byte sizes reported by get_size_of_node for a single node."""

    # Bytes occupied by the node's output (element count times element size).
    output_size: int
    # output_size plus the size attributed to the called module's parameters;
    # equal to output_size for nodes that are not ``call_module``.
    total_size: int
  46. @compatibility(is_backward_compatible=False)
  47. def get_size_of_all_nodes(
  48. fx_module: GraphModule, args: Optional[list[torch.Tensor]] = None
  49. ) -> None:
  50. """Given a fx graph module, update each node with its total size (weights + bias + output)
  51. and its output_size(output). For a non-module node, the total size is the output size.
  52. return total size"""
  53. if args is not None:
  54. # Mark shape and dtype for each node (node.shape and node.dtype)
  55. ShapeProp(fx_module).propagate(*args)
  56. # Calculate the total size of the whole fx graph
  57. for node in fx_module.graph.nodes:
  58. if node.op == "output":
  59. break
  60. node.size_bytes = get_size_of_node(fx_module, node)
  61. return
  62. @compatibility(is_backward_compatible=False)
  63. def get_tensor_meta(node: Node) -> Any:
  64. tensor_meta = node.meta.get("tensor_meta")
  65. if not tensor_meta:
  66. raise RuntimeError(
  67. f"Node {node} has no tensor metadata associated with it! "
  68. f"Check that shape propagation has run."
  69. )
  70. return tensor_meta
  71. @compatibility(is_backward_compatible=False)
  72. def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
  73. """Given a node with node.dtype and node.shape, return its total size and its output size.
  74. total_size = weights + bias + output_size
  75. """
  76. # Total num of elements
  77. total_num_of_elems = 0
  78. # For a module, consider all parameters
  79. if node.op == "call_module":
  80. submodule_dict = dict(fx_module.named_modules())
  81. submodule = submodule_dict[node.target]
  82. parameters = submodule.named_parameters()
  83. # Parameters are named tuples
  84. for _name, p in parameters:
  85. total_num_of_elems += p.numel()
  86. # Don't forget the output size
  87. # node.shape is the shape of this node's output
  88. tensor_meta = get_tensor_meta(node)
  89. output_elem = tensor_meta.shape.numel()
  90. total_num_of_elems += output_elem
  91. # Assume for now if it's quantized then it's qint8 or quint8
  92. if tensor_meta.is_quantized:
  93. size_per_elem_bytes = torch._empty_affine_quantized(
  94. [], dtype=tensor_meta.dtype
  95. ).element_size()
  96. else:
  97. size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size()
  98. total_size = size_per_elem_bytes * total_num_of_elems
  99. output_size = size_per_elem_bytes * output_elem
  100. return size_bytes(output_size, total_size)