_runtime_estimation.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import torch
  2. from torch._inductor.utils import get_device_tflops, get_gpu_dram_gbps
  3. from torch.fx.experimental.symbolic_shapes import (
  4. has_hint,
  5. size_hint,
  6. statically_known_true,
  7. )
  8. from torch.utils._ordered_set import OrderedSet
  9. from .flop_counter import flop_registry
  10. aten = torch.ops.aten
  11. _FLOAT_TYPES = OrderedSet(
  12. [
  13. torch.float16,
  14. torch.bfloat16,
  15. torch.float32,
  16. torch.float64,
  17. ]
  18. )
  19. # No fall-back kernel needed/exists for view ops
  20. _VIEW_OPS = OrderedSet(
  21. [
  22. aten.lift_fresh,
  23. aten.t,
  24. aten.transpose,
  25. aten.view,
  26. aten.detach,
  27. aten._unsafe_view,
  28. aten.split,
  29. aten.adjoint,
  30. aten.as_strided,
  31. aten.diagonal,
  32. aten.expand,
  33. aten.expand_as,
  34. aten.movedim,
  35. aten.permute,
  36. aten.select,
  37. aten.squeeze,
  38. aten.mT,
  39. aten.mH,
  40. aten.real,
  41. aten.imag,
  42. aten.view_as,
  43. aten.unflatten,
  44. aten.unfold,
  45. aten.unbind,
  46. aten.unsqueeze,
  47. aten.vsplit,
  48. aten.hsplit,
  49. aten.split_with_sizes,
  50. aten.swapaxes,
  51. aten.swapdims,
  52. aten.chunk,
  53. ]
  54. )
  55. # We can ignore benchmarking tensor create ops
  56. _CREATE_OPS = OrderedSet(
  57. [
  58. aten.randint,
  59. aten.randn,
  60. aten.rand,
  61. aten.randn_like,
  62. aten.rand_like,
  63. aten.randint_like,
  64. aten.arange,
  65. aten.ones_like,
  66. aten.zeros_like,
  67. ]
  68. )
  69. _IGNORE_OPS = _VIEW_OPS | _CREATE_OPS
  70. def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def]
  71. """
  72. Estimates the compute time of an aten operator.
  73. Args:
  74. func_packet: The operator overload packet.
  75. args: The arguments to the operator.
  76. kwargs: The keyword arguments to the operator.
  77. out: The output of the operator.
  78. out_dtypes: The output data types.
  79. Returns:
  80. float: The estimated compute time in nanoseconds.
  81. """
  82. if func_packet in flop_registry:
  83. if len(out_dtypes) != 1:
  84. raise AssertionError(
  85. f"Only support single out dtype got {out_dtypes} for {func_packet}"
  86. )
  87. dtype = out_dtypes.pop()
  88. # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s
  89. peak_gpu_flops = get_device_tflops(dtype) * 1e15
  90. # We can expect to achieve 75% of theoretical peak flops
  91. factor = 0.75
  92. peak_empirical_flops = factor * peak_gpu_flops
  93. flop_count_func = flop_registry[func_packet]
  94. # We divide by a factor of 2 to get the MACs (multiply and accumulate)
  95. flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2
  96. # We multiply by 1e9 to get the time in nano seconds
  97. compute_time = (flop_count / peak_empirical_flops) * 1e9
  98. return compute_time
  99. return 0.0
  100. def get_num_bytes(t: torch.Tensor) -> int:
  101. """
  102. Calculates the memory consumption of a tensor.
  103. Args:
  104. t (torch.Tensor): The input tensor.
  105. Returns:
  106. int: The memory consumption of the tensor in bytes.
  107. """
  108. real_numel = 1
  109. for size, stride in zip(t.shape, t.stride()):
  110. if not has_hint(size) or not has_hint(stride):
  111. return 0
  112. # For dims with stride=0 (expanded/broadcast), only 1 element accessed
  113. if not statically_known_true(stride == 0):
  114. real_numel *= size_hint(size)
  115. return real_numel * t.element_size()
  116. def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def]
  117. """
  118. Estimates the memory transfer time of input and output tensors.
  119. Args:
  120. flat_args_kwargs (List[torch.Tensor]): The flat list of arguments and keyword arguments.
  121. flat_outs (List[torch.Tensor]): The flat list of outputs.
  122. Returns:
  123. float: The estimated memory transfer time in nanoseconds.
  124. """
  125. gpu_memory_bandwidth = get_gpu_dram_gbps()
  126. read_bytes = sum(
  127. get_num_bytes(t) for t in flat_args_kwargs if isinstance(t, torch.Tensor)
  128. )
  129. write_bytes = sum(
  130. get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor)
  131. )
  132. counted_bytes = read_bytes + write_bytes
  133. # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds
  134. transfer_time = counted_bytes / gpu_memory_bandwidth
  135. return transfer_time