flop_counter.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921
  1. # mypy: allow-untyped-defs
  2. from types import NoneType
  3. import logging
  4. import torch
  5. from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
  6. from .module_tracker import ModuleTracker
  7. from typing import Any, TypeVar
  8. from collections.abc import Callable
  9. from collections.abc import Iterator
  10. from typing_extensions import ParamSpec
  11. from collections import defaultdict
  12. from torch.utils._python_dispatch import TorchDispatchMode
  13. from math import prod
  14. from functools import wraps
  15. import warnings
# Names exported by ``from <this module> import *``.
__all__ = ["FlopCounterMode", "register_flop_formula"]

# Type variables used to annotate the decorator returned by
# ``register_flop_formula``.
_T = TypeVar("_T")
_P = ParamSpec("_P")

log = logging.getLogger(__name__)

try:
    from triton.runtime.jit import JITFunction as _JITFunction
except ImportError:
    # Only warn when at least one of torch.version.{cuda,hip,xpu} is set.
    if any(getattr(torch.version, attr, None) is not None for attr in ["cuda", "hip", "xpu"]):
        log.warning("triton not found; flop counting will not work for triton kernels")
    # Stand-in type so that ``isinstance(target, (..., _JITFunction))`` in
    # ``register_flop_formula`` remains a valid call without triton.
    _JITFunction = NoneType

# Shorthand for the aten operator namespace used by the registrations below.
aten = torch.ops.aten
  27. def get_shape(i):
  28. if isinstance(i, torch.Tensor):
  29. return i.shape
  30. return i
# Module-level table mapping a registered target to its flop formula.
# ``register_flop_formula`` inserts entries into it.
flop_registry: dict[Any, Any] = {}
  32. def shape_wrapper(f):
  33. @wraps(f)
  34. def nf(*args, out_val=None, **kwargs):
  35. args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out_val))
  36. return f(*args, out_shape=out_shape, **kwargs)
  37. return nf
  38. def register_flop_formula(targets, get_raw=False) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  39. def register_fun(flop_formula: Callable[_P, _T]) -> Callable[_P, _T]:
  40. if not get_raw:
  41. flop_formula = shape_wrapper(flop_formula)
  42. def register(target) -> None:
  43. if not (isinstance(target, (torch._ops.OpOverloadPacket, _JITFunction))):
  44. raise ValueError(
  45. f"register_flop_formula(targets): expected each target to be "
  46. f"OpOverloadPacket (i.e. torch.ops.mylib.foo), or JitFunction"
  47. f", got {target} which is of type {type(target)}")
  48. if target in flop_registry:
  49. raise RuntimeError(f"duplicate registrations for {target}")
  50. flop_registry[target] = flop_formula
  51. # To handle allowing multiple aten_ops at once
  52. torch.utils._pytree.tree_map_(register, targets)
  53. return flop_formula
  54. return register_fun
  55. @register_flop_formula(aten.mm)
  56. def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
  57. """Count flops for matmul."""
  58. # Inputs should be a list of length 2.
  59. # Inputs contains the shapes of two matrices.
  60. m, k = a_shape
  61. k2, n = b_shape
  62. if k != k2:
  63. raise AssertionError(f"matmul: inner dimensions must match (k == k2), got {k} and {k2}")
  64. # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
  65. return m * n * 2 * k
  66. @register_flop_formula(aten.addmm)
  67. def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
  68. """Count flops for addmm."""
  69. return mm_flop(a_shape, b_shape)
  70. @register_flop_formula(aten.bmm)
  71. def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
  72. """Count flops for the bmm operation."""
  73. # Inputs should be a list of length 2.
  74. # Inputs contains the shapes of two tensor.
  75. b, m, k = a_shape
  76. b2, k2, n = b_shape
  77. if b != b2:
  78. raise AssertionError(f"bmm: batch dimensions must match (b == b2), got {b} and {b2}")
  79. if k != k2:
  80. raise AssertionError(f"bmm: inner dimensions must match (k == k2), got {k} and {k2}")
  81. # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
  82. flop = b * m * n * 2 * k
  83. return flop
  84. @register_flop_formula(aten.baddbmm)
  85. def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
  86. """Count flops for the baddbmm operation."""
  87. # Inputs should be a list of length 3.
  88. # Inputs contains the shapes of three tensors.
  89. return bmm_flop(a_shape, b_shape)
  90. @register_flop_formula(aten._scaled_mm)
  91. def _scaled_mm_flop(
  92. a_shape,
  93. b_shape,
  94. scale_a_shape,
  95. scale_b_shape,
  96. bias_shape=None,
  97. scale_result_shape=None,
  98. out_dtype=None,
  99. use_fast_accum=False,
  100. out_shape=None,
  101. **kwargs,
  102. ) -> int:
  103. """Count flops for _scaled_mm."""
  104. return mm_flop(a_shape, b_shape)
  105. def conv_flop_count(
  106. x_shape: list[int],
  107. w_shape: list[int],
  108. out_shape: list[int],
  109. transposed: bool = False,
  110. ) -> int:
  111. """Count flops for convolution.
  112. Note only multiplication is
  113. counted. Computation for bias are ignored.
  114. Flops for a transposed convolution are calculated as
  115. flops = (x_shape[2:] * prod(w_shape) * batch_size).
  116. Args:
  117. x_shape (list(int)): The input shape before convolution.
  118. w_shape (list(int)): The filter shape.
  119. out_shape (list(int)): The output shape after convolution.
  120. transposed (bool): is the convolution transposed
  121. Returns:
  122. int: the number of flops
  123. """
  124. batch_size = x_shape[0]
  125. conv_shape = (x_shape if transposed else out_shape)[2:]
  126. c_out, c_in, *filter_size = w_shape
  127. """
  128. General idea here is that for a regular conv, for each point in the output
  129. spatial dimension we convolve the filter with something (hence
  130. `prod(conv_shape) * prod(filter_size)` ops). Then, this gets multiplied by
  131. 1. batch_size, 2. the cross product of input and weight channels.
  132. For the transpose, it's not each point in the *output* spatial dimension but
  133. each point in the *input* spatial dimension.
  134. """
  135. # NB(chilli): I don't think this properly accounts for padding :think:
  136. # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
  137. flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2
  138. return flop
  139. @register_flop_formula([aten.convolution,
  140. aten._convolution,
  141. aten.cudnn_convolution,
  142. aten._slow_conv2d_forward,
  143. aten.convolution_overrideable])
  144. def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
  145. """Count flops for convolution."""
  146. # pyrefly: ignore [bad-argument-type]
  147. return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
  148. @register_flop_formula(aten.convolution_backward)
  149. def conv_backward_flop(
  150. grad_out_shape,
  151. x_shape,
  152. w_shape,
  153. _bias,
  154. _stride,
  155. _padding,
  156. _dilation,
  157. transposed,
  158. _output_padding,
  159. _groups,
  160. output_mask,
  161. out_shape) -> int:
  162. def t(shape):
  163. return [shape[1], shape[0]] + list(shape[2:])
  164. flop_count = 0
  165. """
  166. Let's say we have a regular 1D conv
  167. {A, B, C} [inp]
  168. {i, j} [weight]
  169. => (conv)
  170. {Ai + Bj, Bi + Cj} [out]
  171. And as a reminder, the transposed conv of the above is
  172. => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]
  173. For the backwards of conv, we now have
  174. {D, E} [grad_out]
  175. {A, B, C} [inp]
  176. {i, j} [weight]
  177. # grad_inp as conv_transpose(grad_out, weight)
  178. Let's first compute grad_inp. To do so, we can simply look at all the
  179. multiplications that each element of inp is involved in. For example, A is
  180. only involved in the first element of the output (and thus only depends upon
  181. D in grad_out), and C is only involved in the last element of the output
  182. (and thus only depends upon E in grad_out)
  183. {Di, Dj + Ei, Ej} [grad_inp]
  184. Note that this corresponds to the below conv_transpose. This gives us the
  185. output_mask[0] branch, which is grad_inp.
  186. {D, E} [inp (grad_out)]
  187. {i, j} [weight]
  188. => (conv_transpose)
  189. {Di, Dj + Ei, Ej} [out (grad_inp)]
  190. I leave the fact that grad_inp for a transposed conv is just conv(grad_out,
  191. weight) as an exercise for the reader.
  192. # grad_weight as conv(inp, grad_out)
  193. To compute grad_weight, we again look at the terms in the output, which as
  194. a reminder is:
  195. => {Ai + Bj, Bi + Cj} [out]
  196. => {D, E} [grad_out]
  197. If we manually compute the gradient for the weights, we see it's
  198. {AD + BE, BD + CE} [grad_weight]
  199. This corresponds to the below conv
  200. {A, B, C} [inp]
  201. {D, E} [weight (grad_out)]
  202. => (conv)
  203. {AD + BE, BD + CE} [out (grad_weight)]
  204. # grad_weight of transposed conv as conv(grad_out, inp)
  205. As a reminder, the terms of the output of a transposed conv are:
  206. => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]
  207. => {D, E, F, G} [grad_out]
  208. Manually computing the gradient for the weights, we see it's
  209. {AD + BE + CF, AE + BF + CG} [grad_weight]
  210. This corresponds to the below conv
  211. {D, E, F, G} [inp (grad_out)]
  212. {A, B, C} [weight (inp)]
  213. => (conv)
  214. {AD + BE + CF, AE + BF + CG} [out (grad_weight)]
  215. For the full backwards formula, there are also some details involving
  216. transpose of the batch/channel dimensions and groups, but I skip those for
  217. the sake of brevity (and they're pretty similar to matmul backwards)
  218. Check [conv backwards decomposition as conv forwards]
  219. """
  220. # grad_inp as conv_transpose(grad_out, weight)
  221. if output_mask[0]:
  222. grad_input_shape = get_shape(out_shape[0])
  223. flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed)
  224. if output_mask[1]:
  225. grad_weight_shape = get_shape(out_shape[1])
  226. if transposed:
  227. # grad_weight of transposed conv as conv(grad_out, inp)
  228. flop_count += conv_flop_count(t(grad_out_shape), t(x_shape), t(grad_weight_shape), transposed=False)
  229. else:
  230. # grad_weight as conv(inp, grad_out)
  231. flop_count += conv_flop_count(t(x_shape), t(grad_out_shape), t(grad_weight_shape), transposed=False)
  232. return flop_count
  233. def sdpa_flop_count(query_shape, key_shape, value_shape):
  234. """
  235. Count flops for self-attention.
  236. NB: We can assume that value_shape == key_shape
  237. """
  238. b, h, s_q, d_q = query_shape
  239. _b2, _h2, s_k, _d2 = key_shape
  240. _b3, _h3, _s3, d_v = value_shape
  241. if not b == _b2 == _b3 or not h == _h2 == _h3 or not d_q == _d2 or not s_k == _s3 or not d_q == _d2:
  242. raise AssertionError("sdpa_flop_count: query/key/value shapes are incompatible")
  243. total_flops = 0
  244. # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
  245. total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
  246. # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
  247. total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v))
  248. return total_flops
  249. @register_flop_formula([aten._scaled_dot_product_efficient_attention,
  250. aten._scaled_dot_product_flash_attention,
  251. aten._scaled_dot_product_cudnn_attention])
  252. def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
  253. """Count flops for self-attention."""
  254. # NB: We aren't accounting for causal attention here
  255. return sdpa_flop_count(query_shape, key_shape, value_shape)
  256. def _offsets_to_lengths(offsets, max_len):
  257. """
  258. If the offsets tensor is fake, then we don't know the actual lengths.
  259. In that case, we can just assume the worst case; each batch has max length.
  260. """
  261. from torch._subclasses.fake_tensor import FakeTensor
  262. from torch._subclasses.functional_tensor import FunctionalTensor
  263. if not isinstance(offsets, (FakeTensor, FunctionalTensor)) and offsets.device.type != "meta":
  264. return offsets.diff().tolist()
  265. return [max_len] * (offsets.size(0) - 1)
  266. def _unpack_flash_attention_nested_shapes(
  267. *,
  268. query,
  269. key,
  270. value,
  271. grad_out=None,
  272. cum_seq_q,
  273. cum_seq_k,
  274. max_q,
  275. max_k,
  276. ) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...] | None]]:
  277. """
  278. Given inputs to a flash_attention_(forward|backward) kernel, this will handle behavior for
  279. NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for
  280. each batch element.
  281. In the case that this isn't a NestedTensor kernel, then it just yields the original shapes.
  282. """
  283. if cum_seq_q is not None:
  284. # This means we should be dealing with a Nested Jagged Tensor query.
  285. # The inputs will have shape (sum(sequence len), heads, dimension)
  286. # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
  287. # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
  288. # So the flops calculation in this case is an overestimate of the actual flops.
  289. if len(key.shape) != 3:
  290. raise AssertionError("sdpa_flop_count: expected key.shape to be 3-dimensional")
  291. if len(value.shape) != 3:
  292. raise AssertionError("sdpa_flop_count: expected value.shape to be 3-dimensional")
  293. if grad_out is not None and grad_out.shape != query.shape:
  294. raise AssertionError("sdpa_flop_count: grad_out.shape must match query.shape when provided")
  295. _, h_q, d_q = query.shape
  296. _, h_k, d_k = key.shape
  297. _, h_v, d_v = value.shape
  298. if cum_seq_q is None:
  299. raise AssertionError("sdpa_flop_count: cum_seq_q must not be None")
  300. if cum_seq_k is None:
  301. raise AssertionError("sdpa_flop_count: cum_seq_k must not be None")
  302. if cum_seq_q.shape != cum_seq_k.shape:
  303. raise AssertionError("sdpa_flop_count: cum_seq_q and cum_seq_k must have the same shape")
  304. seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q)
  305. seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k)
  306. for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths, strict=True):
  307. new_query_shape = (1, h_q, seq_q_len, d_q)
  308. new_key_shape = (1, h_k, seq_k_len, d_k)
  309. new_value_shape = (1, h_v, seq_k_len, d_v)
  310. new_grad_out_shape = new_query_shape if grad_out is not None else None
  311. yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape
  312. return
  313. yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None
  314. def _unpack_efficient_attention_nested_shapes(
  315. *,
  316. query,
  317. key,
  318. value,
  319. grad_out=None,
  320. cu_seqlens_q,
  321. cu_seqlens_k,
  322. max_seqlen_q,
  323. max_seqlen_k,
  324. ) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...] | None]]:
  325. """
  326. Given inputs to a efficient_attention_(forward|backward) kernel, this will handle behavior for
  327. NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for
  328. each batch element.
  329. In the case that this isn't a NestedTensor kernel, then it just yields the original shapes.
  330. """
  331. if cu_seqlens_q is not None:
  332. # Unlike flash_attention_forward, we get a 4D tensor instead of a 3D tensor for efficient attention.
  333. #
  334. # This means we should be dealing with a Nested Jagged Tensor query.
  335. # The inputs will have shape (sum(sequence len), heads, dimension)
  336. # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
  337. # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
  338. # So the flops calculation in this case is an overestimate of the actual flops.
  339. if len(key.shape) != 4:
  340. raise AssertionError("_unpack_efficient_attention_nested_shapes: expected key.shape to be 4-dimensional")
  341. if len(value.shape) != 4:
  342. raise AssertionError("_unpack_efficient_attention_nested_shapes: expected value.shape to be 4-dimensional")
  343. if grad_out is not None and grad_out.shape != query.shape:
  344. raise AssertionError("_unpack_efficient_attention_nested_shapes: grad_out.shape must match query.shape when provided")
  345. _, _, h_q, d_q = query.shape
  346. _, _, h_k, d_k = key.shape
  347. _, _, h_v, d_v = value.shape
  348. if cu_seqlens_q is None:
  349. raise AssertionError("_unpack_efficient_attention_nested_shapes: cu_seqlens_q must not be None")
  350. if cu_seqlens_k is None:
  351. raise AssertionError("_unpack_efficient_attention_nested_shapes: cu_seqlens_k must not be None")
  352. if cu_seqlens_q.shape != cu_seqlens_k.shape:
  353. raise AssertionError("_unpack_efficient_attention_nested_shapes: "
  354. "cu_seqlens_q and cu_seqlens_k must have the same shape")
  355. seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q)
  356. seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k)
  357. for len_q, len_k in zip(seqlens_q, seqlens_k, strict=True):
  358. new_query_shape = (1, h_q, len_q, d_q)
  359. new_key_shape = (1, h_k, len_k, d_k)
  360. new_value_shape = (1, h_v, len_k, d_v)
  361. new_grad_out_shape = new_query_shape if grad_out is not None else None
  362. yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape
  363. return
  364. yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None
  365. @register_flop_formula(aten._flash_attention_forward, get_raw=True)
  366. def _flash_attention_forward_flop(
  367. query,
  368. key,
  369. value,
  370. cum_seq_q,
  371. cum_seq_k,
  372. max_q,
  373. max_k,
  374. *args,
  375. out_shape=None,
  376. **kwargs
  377. ) -> int:
  378. """Count flops for self-attention."""
  379. # NB: We aren't accounting for causal attention here
  380. # in case this is a nested tensor, we unpack the individual batch elements
  381. # and then sum the flops per batch element
  382. sizes = _unpack_flash_attention_nested_shapes(
  383. query=query,
  384. key=key,
  385. value=value,
  386. cum_seq_q=cum_seq_q,
  387. cum_seq_k=cum_seq_k,
  388. max_q=max_q,
  389. max_k=max_k,
  390. )
  391. return sum(
  392. sdpa_flop_count(query_shape, key_shape, value_shape)
  393. for query_shape, key_shape, value_shape, _ in sizes
  394. )
  395. @register_flop_formula(aten._efficient_attention_forward, get_raw=True)
  396. def _efficient_attention_forward_flop(
  397. query,
  398. key,
  399. value,
  400. bias,
  401. cu_seqlens_q,
  402. cu_seqlens_k,
  403. max_seqlen_q,
  404. max_seqlen_k,
  405. *args,
  406. **kwargs
  407. ) -> int:
  408. """Count flops for self-attention."""
  409. # NB: We aren't accounting for causal attention here
  410. # in case this is a nested tensor, we unpack the individual batch elements
  411. # and then sum the flops per batch element
  412. sizes = _unpack_efficient_attention_nested_shapes(
  413. query=query,
  414. key=key,
  415. value=value,
  416. cu_seqlens_q=cu_seqlens_q,
  417. cu_seqlens_k=cu_seqlens_k,
  418. max_seqlen_q=max_seqlen_q,
  419. max_seqlen_k=max_seqlen_k,
  420. )
  421. return sum(
  422. sdpa_flop_count(query_shape, key_shape, value_shape)
  423. for query_shape, key_shape, value_shape, _ in sizes
  424. )
  425. def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape):
  426. total_flops = 0
  427. b, h, s_q, d_q = query_shape
  428. _b2, _h2, s_k, _d2 = key_shape
  429. _b3, _h3, _s3, d_v = value_shape
  430. _b4, _h4, _s4, _d4 = grad_out_shape
  431. if not b == _b2 == _b3 == _b4 or not h == _h2 == _h3 == _h4 or not d_q == _d2:
  432. raise AssertionError("sdpa_backward_flop_count: batch/heads/dimension mismatch among tensors")
  433. if not d_v == _d4 or not s_k == _s3 or not s_q == _s4:
  434. raise AssertionError("sdpa_backward_flop_count: grad_out/value/key/query shapes are incompatible")
  435. total_flops = 0
  436. # Step 1: We recompute the scores matrix.
  437. # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
  438. total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
  439. # Step 2: We propagate the gradients through the score @ v operation.
  440. # gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k]
  441. total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k))
  442. # scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v]
  443. total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v))
  444. # Step 3: We propagate th gradients through the k @ v operation
  445. # gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q]
  446. total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q))
  447. # q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k]
  448. total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k))
  449. return total_flops
  450. @register_flop_formula([aten._scaled_dot_product_efficient_attention_backward,
  451. aten._scaled_dot_product_flash_attention_backward,
  452. aten._scaled_dot_product_cudnn_attention_backward])
  453. def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
  454. """Count flops for self-attention backward."""
  455. return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
  456. @register_flop_formula(aten._flash_attention_backward, get_raw=True)
  457. def _flash_attention_backward_flop(
  458. grad_out,
  459. query,
  460. key,
  461. value,
  462. out, # named _out_shape to avoid kwarg collision with out_shape created in wrapper
  463. logsumexp,
  464. cum_seq_q,
  465. cum_seq_k,
  466. max_q,
  467. max_k,
  468. *args,
  469. **kwargs,
  470. ) -> int:
  471. # in case this is a nested tensor, we unpack the individual batch elements
  472. # and then sum the flops per batch element
  473. shapes = _unpack_flash_attention_nested_shapes(
  474. query=query,
  475. key=key,
  476. value=value,
  477. grad_out=grad_out,
  478. cum_seq_q=cum_seq_q,
  479. cum_seq_k=cum_seq_k,
  480. max_q=max_q,
  481. max_k=max_k,
  482. )
  483. return sum(
  484. sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
  485. for query_shape, key_shape, value_shape, grad_out_shape in shapes
  486. )
  487. @register_flop_formula(aten._efficient_attention_backward, get_raw=True)
  488. def _efficient_attention_backward_flop(
  489. grad_out,
  490. query,
  491. key,
  492. value,
  493. bias,
  494. out, # named _out to avoid kwarg collision with out created in wrapper
  495. cu_seqlens_q,
  496. cu_seqlens_k,
  497. max_seqlen_q,
  498. max_seqlen_k,
  499. *args,
  500. **kwargs,
  501. ) -> int:
  502. # in case this is a nested tensor, we unpack the individual batch elements
  503. # and then sum the flops per batch element
  504. shapes = _unpack_efficient_attention_nested_shapes(
  505. query=query,
  506. key=key,
  507. value=value,
  508. grad_out=grad_out,
  509. cu_seqlens_q=cu_seqlens_q,
  510. cu_seqlens_k=cu_seqlens_k,
  511. max_seqlen_q=max_seqlen_q,
  512. max_seqlen_k=max_seqlen_k,
  513. )
  514. return sum(
  515. sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
  516. for query_shape, key_shape, value_shape, grad_out_shape in shapes
  517. )
# Rebind the module-level registry to an explicit table of the built-in
# formulas. The names on the right are the callables produced by the
# ``register_flop_formula`` decorators above. ``FlopCounterMode.__init__``
# copies this table and layers any ``custom_mapping`` on top of it.
flop_registry = {
    # Matrix multiplication family.
    aten.mm: mm_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.baddbmm: baddbmm_flop,
    aten._scaled_mm: _scaled_mm_flop,
    # Convolution forward variants and convolution backward.
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.cudnn_convolution: conv_flop,
    aten.convolution_overrideable: conv_flop,
    aten._slow_conv2d_forward: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    # Scaled-dot-product attention, forward and backward.
    aten._scaled_dot_product_efficient_attention: sdpa_flop,
    aten._scaled_dot_product_flash_attention: sdpa_flop,
    aten._scaled_dot_product_cudnn_attention: sdpa_flop,
    aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop,
    aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop,
    aten._scaled_dot_product_cudnn_attention_backward: sdpa_backward_flop,
    # Flash / efficient attention kernels (formulas registered with get_raw=True).
    aten._flash_attention_forward: _flash_attention_forward_flop,
    aten._efficient_attention_forward: _efficient_attention_forward_flop,
    aten._flash_attention_backward: _flash_attention_backward_flop,
    aten._efficient_attention_backward: _efficient_attention_backward_flop,
}
  541. def normalize_tuple(x):
  542. if not isinstance(x, tuple):
  543. return (x,)
  544. return x
  545. # Define the suffixes for different orders of magnitude
  546. suffixes = ["", "K", "M", "B", "T"]
  547. # Thanks BingChat!
  548. def get_suffix_str(number):
  549. # Find the index of the appropriate suffix based on the number of digits
  550. # with some additional overflow.
  551. # i.e. 1.01B should be displayed as 1001M, not 1.001B
  552. index = max(0, min(len(suffixes) - 1, (len(str(number)) - 2) // 3))
  553. return suffixes[index]
  554. def convert_num_with_suffix(number, suffix):
  555. index = suffixes.index(suffix)
  556. # Divide the number by 1000^index and format it to two decimal places
  557. value = f"{number / 1000 ** index:.3f}"
  558. # Return the value and the suffix as a string
  559. return value + suffixes[index]
  560. def convert_to_percent_str(num, denom) -> str:
  561. if denom == 0:
  562. return "0%"
  563. return f"{num / denom:.2%}"
  564. def _pytreeify_preserve_structure(f):
  565. @wraps(f)
  566. def nf(args):
  567. flat_args, spec = tree_flatten(args)
  568. out = f(*flat_args)
  569. return tree_unflatten(out, spec)
  570. return nf
  571. class FlopCounterMode:
  572. """
  573. ``FlopCounterMode`` is a context manager that counts the number of flops within its context.
  574. It does this using a ``TorchDispatchMode``.
  575. It also supports hierarchical output by passing a module (or list of
  576. modules) to FlopCounterMode on construction. If you do not need hierarchical
  577. output, you do not need to use it with a module.
  578. Example usage
  579. .. code-block:: python
  580. mod = ...
  581. with FlopCounterMode(mod) as flop_counter:
  582. mod.sum().backward()
  583. """
  584. def __init__(
  585. self,
  586. mods: torch.nn.Module | list[torch.nn.Module] | None = None,
  587. depth: int = 2,
  588. display: bool = True,
  589. custom_mapping: dict[Any, Any] | None = None) -> None:
  590. super().__init__()
  591. self.flop_counts: dict[str, dict[Any, int]] = defaultdict(lambda: defaultdict(int))
  592. self.depth = depth
  593. self.display = display
  594. self.mode: _FlopCounterMode | None = None
  595. if custom_mapping is None:
  596. custom_mapping = {}
  597. if mods is not None:
  598. warnings.warn("mods argument is not needed anymore, you can stop passing it", stacklevel=2)
  599. self.flop_registry = {
  600. **flop_registry,
  601. **{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
  602. }
  603. self.mod_tracker = ModuleTracker()
  604. def get_total_flops(self) -> int:
  605. return sum(self.flop_counts['Global'].values())
  606. def get_flop_counts(self) -> dict[str, dict[Any, int]]:
  607. """Return the flop counts as a dictionary of dictionaries.
  608. The outer
  609. dictionary is keyed by module name, and the inner dictionary is keyed by
  610. operation name.
  611. Returns:
  612. Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
  613. """
  614. return {k: dict(v) for k, v in self.flop_counts.items()}
  615. def get_table(self, depth=None):
  616. if depth is None:
  617. depth = self.depth
  618. if depth is None:
  619. depth = 999999
  620. import tabulate
  621. tabulate.PRESERVE_WHITESPACE = True
  622. header = ["Module", "FLOP", "% Total"]
  623. values = []
  624. global_flops = self.get_total_flops()
  625. global_suffix = get_suffix_str(global_flops)
  626. is_global_subsumed = False
  627. def process_mod(mod_name, depth):
  628. nonlocal is_global_subsumed
  629. total_flops = sum(self.flop_counts[mod_name].values())
  630. is_global_subsumed |= total_flops >= global_flops
  631. padding = " " * depth
  632. values = []
  633. values.append([
  634. padding + mod_name,
  635. convert_num_with_suffix(total_flops, global_suffix),
  636. convert_to_percent_str(total_flops, global_flops)
  637. ])
  638. for k, v in self.flop_counts[mod_name].items():
  639. values.append([
  640. padding + " - " + str(k),
  641. convert_num_with_suffix(v, global_suffix),
  642. convert_to_percent_str(v, global_flops)
  643. ])
  644. return values
  645. for mod in sorted(self.flop_counts.keys()):
  646. if mod == 'Global':
  647. continue
  648. mod_depth = mod.count(".") + 1
  649. if mod_depth > depth:
  650. continue
  651. cur_values = process_mod(mod, mod_depth - 1)
  652. values.extend(cur_values)
  653. # We do a bit of messing around here to only output the "Global" value
  654. # if there are any FLOPs in there that aren't already fully contained by
  655. # a module.
  656. if 'Global' in self.flop_counts and not is_global_subsumed:
  657. for value in values:
  658. value[0] = " " + value[0]
  659. values = process_mod('Global', 0) + values
  660. if len(values) == 0:
  661. values = [["Global", "0", "0%"]]
  662. return tabulate.tabulate(values, headers=header, colalign=("left", "right", "right"))
  663. # NB: This context manager is NOT reentrant
  def __enter__(self):
      """Reset the recorded counts and start counting FLOPs.

      Clears ``self.flop_counts``, enters the module tracker, then creates
      and enters a fresh ``_FlopCounterMode`` bound to this counter.

      Returns:
          This object, so it can be used as the ``with`` target.
      """
      self.flop_counts.clear()
      self.mod_tracker.__enter__()

      dispatch_mode = _FlopCounterMode(self)
      # Keep the mode on the instance so __exit__ can find and leave it.
      self.mode = dispatch_mode
      dispatch_mode.__enter__()
      return self
  def __exit__(self, *args):
      """Stop counting FLOPs and, if requested, print the summary table.

      Args:
          *args: Exception details supplied by the ``with`` statement; they
              are forwarded to the dispatch mode's ``__exit__``.

      Returns:
          Whatever the dispatch mode's ``__exit__`` returned.

      Raises:
          AssertionError: If ``self.mode`` is ``None``, i.e. there is no
              dispatch mode to leave.
      """
      if self.mode is None:
          raise AssertionError("Internal error: FlopCounter.__exit__ called but mode is None")
      try:
          b = self.mode.__exit__(*args)
      finally:
          # Run the cleanup even when the mode's __exit__ raises; otherwise
          # the module tracker would stay entered and the mode reference
          # (and its cycle back to this counter) would be kept alive.
          self.mode = None  # break cycles
          self.mod_tracker.__exit__()
      if self.display:
          print(self.get_table(self.depth))
      return b
  def _count_flops(self, func_packet, out, args, kwargs):
      """Record the FLOPs of one executed op against every active module.

      Args:
          func_packet: Key used to look the op up in ``self.flop_registry``.
          out: Value produced by the op; handed to the FLOP formula as
              ``out_val``.
          args: Positional arguments the op was called with.
          kwargs: Keyword arguments the op was called with.

      Returns:
          ``out``, unchanged.
      """
      if func_packet not in self.flop_registry:
          # No formula registered for this op: nothing to record.
          return out

      formula = self.flop_registry[func_packet]
      flops = formula(*args, **kwargs, out_val=out)  # type: ignore[operator]
      # De-duplicate the parents so each one is credited exactly once.
      for parent in set(self.mod_tracker.parents):
          self.flop_counts[parent][func_packet] += flops
      return out
  686. class _FlopCounterMode(TorchDispatchMode):
  687. supports_higher_order_operators = True
  688. def __init__(self, counter: FlopCounterMode) -> None:
  689. self.counter = counter
  690. def _execute_with_isolated_flop_counting(self, branch_fn, operands):
  691. """Execute a branch function and capture its FLOP counts without
  692. affecting self.counter.flop_counts
  693. Args:
  694. branch_fn: The branch function to execute
  695. operands: Arguments to pass to the branch function
  696. Returns:
  697. Tuple of (result, flop_counts) where result is the branch output
  698. and flop_counts is a copy of the FLOP counts after execution
  699. """
  700. import copy
  701. checkpointed_flop_counts = copy.copy(self.counter.flop_counts)
  702. with self:
  703. result = branch_fn(*operands)
  704. flop_counts = copy.copy(self.counter.flop_counts)
  705. self.counter.flop_counts = checkpointed_flop_counts
  706. return result, flop_counts
  707. def _handle_higher_order_ops(self, func, types, args, kwargs):
  708. is_triton = func in {torch.ops.higher_order.triton_kernel_wrapper_mutation,
  709. torch.ops.higher_order.triton_kernel_wrapper_functional}
  710. if is_triton:
  711. from torch._higher_order_ops.triton_kernel_wrap import get_kernel
  712. # Special case - look in the triton flop registry for the kernel
  713. from triton.runtime.jit import JITFunction
  714. kernel_name = get_kernel(kwargs["kernel_idx"])
  715. # Unwrap heuristics if they are present
  716. while not isinstance(kernel_name, JITFunction):
  717. if hasattr(kernel_name, "fn"):
  718. kernel_name = kernel_name.fn
  719. else:
  720. break
  721. return self.counter._count_flops(kernel_name, None, args, kwargs)
  722. elif func is torch.ops.higher_order.cond:
  723. # The flop counter for cond counts the upper bound of flops.
  724. # For example, if a matmul is executed 2 times in true branch
  725. # but only 1 time in the false branch, the flop counter will
  726. # record the larger number of flops, i.e. 2 times.
  727. pred, true_branch, false_branch, operands = args
  728. # Step 1: Count flops for true branch and false branch separately
  729. true_out, true_flop_counts = self._execute_with_isolated_flop_counting(
  730. true_branch, operands
  731. )
  732. if true_out is NotImplemented:
  733. return NotImplemented
  734. false_out, false_flop_counts = self._execute_with_isolated_flop_counting(
  735. false_branch, operands
  736. )
  737. if false_out is NotImplemented:
  738. return NotImplemented
  739. # Step 2: merge flop counts
  740. all_mod_keys = set(true_flop_counts.keys()) | set(false_flop_counts.keys())
  741. merged_flop_counts = {}
  742. for outer_key in all_mod_keys:
  743. true_func_counts = true_flop_counts[outer_key]
  744. false_func_counts = false_flop_counts[outer_key]
  745. merged_func_counts = {}
  746. all_func_keys = set(true_func_counts.keys()) | set(false_func_counts.keys())
  747. for func_key in all_func_keys:
  748. true_val = true_func_counts.get(func_key, 0)
  749. false_val = false_func_counts.get(func_key, 0)
  750. merged_func_counts[func_key] = max(true_val, false_val)
  751. merged_flop_counts[outer_key] = merged_func_counts
  752. # Step 3: update the counter with merged counts
  753. for outer_key, inner_dict in merged_flop_counts.items():
  754. self.counter.flop_counts[outer_key].update(inner_dict)
  755. # It doesn't matter which one we return since true_fn and false_fn return
  756. # output with the same structure.
  757. return true_out
  758. else:
  759. return NotImplemented
  760. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  761. kwargs = kwargs if kwargs else {}
  762. # Skip ops from non-standard dispatch_sizes_strides_policy such as NJT
  763. if func in {torch.ops.aten.sym_is_contiguous.default,
  764. torch.ops.aten.is_contiguous.default,
  765. torch.ops.aten.is_contiguous.memory_format,
  766. torch.ops.aten.is_strides_like_format.default,
  767. torch.ops.aten.is_non_overlapping_and_dense.default,
  768. torch.ops.aten.size.default,
  769. torch.ops.aten.sym_size.default,
  770. torch.ops.aten.stride.default,
  771. torch.ops.aten.sym_stride.default,
  772. torch.ops.aten.storage_offset.default,
  773. torch.ops.aten.sym_storage_offset.default,
  774. torch.ops.aten.numel.default,
  775. torch.ops.aten.sym_numel.default,
  776. torch.ops.aten.dim.default,
  777. torch.ops.prim.layout.default}:
  778. return NotImplemented
  779. if isinstance(func, torch._ops.HigherOrderOperator):
  780. return self._handle_higher_order_ops(func, types, args, kwargs)
  781. # If we don't have func in flop_registry, see if it can decompose
  782. if func not in self.counter.flop_registry and func is not torch.ops.prim.device.default:
  783. with self:
  784. r = func.decompose(*args, **kwargs)
  785. if r is not NotImplemented:
  786. return r
  787. # no further decomposition; execute & count flops
  788. out = func(*args, **kwargs)
  789. return self.counter._count_flops(func._overloadpacket, out, args, kwargs)