log_extract.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # mypy: allow-untyped-defs
  2. from contextlib import contextmanager
  3. from typing import Any, cast
  4. import random
  5. import torch
  6. import time
  7. from torch.utils.benchmark import Timer
  8. def extract_ir(filename: str) -> list[str]:
  9. BEGIN = "<GRAPH_EXPORT>"
  10. END = "</GRAPH_EXPORT>"
  11. pfx = None
  12. graphs = []
  13. with open(filename) as f:
  14. split_strs = f.read().split(BEGIN)
  15. for i, split_str in enumerate(split_strs):
  16. if i == 0:
  17. continue
  18. end_loc = split_str.find(END)
  19. if end_loc == -1:
  20. continue
  21. s = split_str[:end_loc]
  22. pfx = split_strs[i - 1].splitlines()[-1]
  23. lines = [x[len(pfx):] for x in s.splitlines(keepends=True)]
  24. graphs.append(''.join(lines))
  25. return graphs
  26. def make_tensor_from_type(inp_type: torch._C.TensorType):
  27. size = inp_type.sizes()
  28. stride = inp_type.strides()
  29. device = inp_type.device()
  30. dtype = inp_type.dtype()
  31. if size is None:
  32. raise AssertionError("make_tensor_from_type: 'size' is None (inp_type.sizes() returned None)")
  33. if stride is None:
  34. raise AssertionError("make_tensor_from_type: 'stride' is None (inp_type.strides() returned None)")
  35. if device is None:
  36. raise AssertionError("make_tensor_from_type: 'device' is None (inp_type.device() returned None)")
  37. if dtype is None:
  38. raise AssertionError("make_tensor_from_type: 'dtype' is None (inp_type.dtype() returned None)")
  39. return torch.empty_strided(size=size, stride=stride, device=device, dtype=dtype)
  40. def load_graph_and_inputs(ir: str) -> tuple[Any, list[Any]]:
  41. graph = torch._C.parse_ir(ir, parse_tensor_constants=True)
  42. graph.makeMultiOutputIntoTuple()
  43. inputs = []
  44. for inp in graph.inputs():
  45. if isinstance(inp.type(), torch._C.FloatType):
  46. inputs.append(random.uniform(.1, 100))
  47. elif isinstance(inp.type(), torch._C.IntType):
  48. inputs.append(random.randint(1, 100))
  49. elif isinstance(inp.type(), torch._C.TensorType):
  50. tensorType = cast(torch._C.TensorType, inp.type())
  51. inputs.append(make_tensor_from_type(tensorType))
  52. elif isinstance(inp.type(), torch._C.BoolType):
  53. inputs.append(random.randint(0, 1) == 1)
  54. else:
  55. raise NotImplementedError(f"A default value is not implemented for type {inp.type()}")
  56. func = torch._C._create_function_from_graph("forward", graph)
  57. torch._C._jit_pass_erase_shape_information(func.graph)
  58. return (func, inputs)
  59. def time_cuda(fn, inputs, test_runs):
  60. t = Timer(stmt="fn(*inputs)", globals={"fn": fn, "inputs" : inputs})
  61. times = t.blocked_autorange()
  62. return times.median * 1000 # time in ms
  63. def time_cpu(fn, inputs, test_runs):
  64. s = time.perf_counter()
  65. for _ in range(test_runs):
  66. fn(*inputs)
  67. e = time.perf_counter()
  68. return (e - s) / test_runs * 1000 # time in ms
  69. def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
  70. graph, _ = load_graph_and_inputs(ir)
  71. for _ in range(warmup_runs):
  72. graph(*inputs)
  73. is_cpu = None
  74. for input in inputs:
  75. if isinstance(input, torch.Tensor):
  76. is_cpu = input.device.type == "cpu"
  77. break
  78. if is_cpu is None:
  79. raise AssertionError("No tensor found in inputs")
  80. out = time_cpu(graph, inputs, test_runs) if is_cpu else time_cuda(graph, inputs, test_runs)
  81. return out
  82. @contextmanager
  83. def no_fuser(*args, **kwargs):
  84. old_optimize = torch._C._get_graph_executor_optimize(False)
  85. try:
  86. yield
  87. finally:
  88. torch._C._get_graph_executor_optimize(old_optimize)
  89. def run_baseline_no_fusion(ir, inputs) -> float:
  90. with no_fuser():
  91. return run_test(ir, inputs)
  92. def run_nnc(ir, inputs, dynamic) -> float:
  93. try:
  94. strat = [("DYNAMIC", 10)] if dynamic else [("STATIC", 10)]
  95. old_strat = torch.jit.set_fusion_strategy(strat)
  96. with torch.jit.fuser("fuser1"):
  97. return run_test(ir, inputs)
  98. finally:
  99. torch.jit.set_fusion_strategy(old_strat)
  100. def run_nvfuser(ir, inputs) -> float:
  101. with torch.jit.fuser("fuser2"):
  102. return run_test(ir, inputs)