| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921 |
- # mypy: allow-untyped-defs
- from types import NoneType
- import logging
- import torch
- from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
- from .module_tracker import ModuleTracker
- from typing import Any, TypeVar
- from collections.abc import Callable
- from collections.abc import Iterator
- from typing_extensions import ParamSpec
- from collections import defaultdict
- from torch.utils._python_dispatch import TorchDispatchMode
- from math import prod
- from functools import wraps
- import warnings
# Names exported by ``from <this module> import *``.
__all__ = ["FlopCounterMode", "register_flop_formula"]

# Type placeholders used to annotate ``register_flop_formula`` so a decorated
# formula keeps its parameter and return types.
_T = TypeVar("_T")
_P = ParamSpec("_P")

log = logging.getLogger(__name__)

try:
    from triton.runtime.jit import JITFunction as _JITFunction
except ImportError:
    # Only warn when ``torch.version`` reports a cuda/hip/xpu value; otherwise
    # the missing triton import is handled silently.
    if any(
        getattr(torch.version, backend, None) is not None
        for backend in ["cuda", "hip", "xpu"]
    ):
        log.warning("triton not found; flop counting will not work for triton kernels")
    # Stand-in type so ``isinstance(x, _JITFunction)`` stays a valid expression.
    _JITFunction = NoneType

# Shorthand for the aten operator namespace used by the registrations below.
aten = torch.ops.aten
def get_shape(i):
    """Return ``i.shape`` when ``i`` is a ``torch.Tensor``; otherwise return ``i`` unchanged."""
    return i.shape if isinstance(i, torch.Tensor) else i
# Maps a registered target (operator / kernel) to its flop formula.
flop_registry: dict[Any, Any] = {}


def shape_wrapper(f):
    """Adapt a shape-based flop formula so it can be called with actual values.

    The returned function takes the operator's positional and keyword
    arguments plus an ``out_val`` keyword. ``get_shape`` is mapped over all of
    them with ``tree_map`` (so tensors are replaced by their shapes), and ``f``
    is then called with the converted arguments and the converted output
    passed as ``out_shape``.
    """
    @wraps(f)
    def shape_based(*args, out_val=None, **kwargs):
        converted = tree_map(get_shape, (args, kwargs, out_val))
        shaped_args, shaped_kwargs, out_shape = converted
        return f(*shaped_args, out_shape=out_shape, **shaped_kwargs)

    return shape_based
def register_flop_formula(targets, get_raw=False) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """Build a decorator that registers a flop formula for ``targets``.

    Args:
        targets: A single target or a pytree (for example a list) of targets.
            Each one must be a ``torch._ops.OpOverloadPacket`` or an instance
            of ``_JITFunction``.
        get_raw: When false (the default) the formula is wrapped with
            ``shape_wrapper`` before being stored; when true it is stored
            as-is.

    Returns:
        A decorator. It stores the (possibly wrapped) formula in
        ``flop_registry`` under every target and returns that same formula.

    Raises:
        ValueError: (from the decorator) if a target has an unsupported type.
        RuntimeError: (from the decorator) if a target is already registered.
    """
    def register_fun(flop_formula: Callable[_P, _T]) -> Callable[_P, _T]:
        formula = flop_formula if get_raw else shape_wrapper(flop_formula)

        def register(target) -> None:
            supported = isinstance(target, (torch._ops.OpOverloadPacket, _JITFunction))
            if not supported:
                raise ValueError(
                    f"register_flop_formula(targets): expected each target to be "
                    f"OpOverloadPacket (i.e. torch.ops.mylib.foo), or JitFunction"
                    f", got {target} which is of type {type(target)}")
            if target in flop_registry:
                raise RuntimeError(f"duplicate registrations for {target}")
            flop_registry[target] = formula

        # ``targets`` may be a container of several ops; visit every leaf.
        torch.utils._pytree.tree_map_(register, targets)
        return formula

    return register_fun
@register_flop_formula(aten.mm)
def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
    """Count flops for a matrix multiply of ``(m, k)`` by ``(k, n)``.

    Returns ``m * n * 2 * k``.

    Raises:
        AssertionError: if the inner dimensions of the two shapes differ.
    """
    rows, inner = a_shape
    inner_b, cols = b_shape
    if inner != inner_b:
        raise AssertionError(f"matmul: inner dimensions must match (k == k2), got {inner} and {inner_b}")
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    return rows * cols * 2 * inner
@register_flop_formula(aten.addmm)
def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Count flops for addmm.

    Only the matrix product of ``a_shape`` and ``b_shape`` is counted;
    ``self_shape`` does not contribute.
    """
    return mm_flop(a_shape, b_shape)
@register_flop_formula(aten.bmm)
def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Count flops for a batched matmul of ``(b, m, k)`` by ``(b, k, n)``.

    Returns ``b * m * n * 2 * k``.

    Raises:
        AssertionError: if the batch or the inner dimensions differ.
    """
    batch, rows, inner = a_shape
    batch_b, inner_b, cols = b_shape
    if batch != batch_b:
        raise AssertionError(f"bmm: batch dimensions must match (b == b2), got {batch} and {batch_b}")
    if inner != inner_b:
        raise AssertionError(f"bmm: inner dimensions must match (k == k2), got {inner} and {inner_b}")
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    return batch * rows * cols * 2 * inner
@register_flop_formula(aten.baddbmm)
def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Count flops for the baddbmm operation.

    Only the batched product of ``a_shape`` and ``b_shape`` is counted;
    ``self_shape`` does not contribute.
    """
    return bmm_flop(a_shape, b_shape)
@register_flop_formula(aten._scaled_mm)
def _scaled_mm_flop(
    a_shape,
    b_shape,
    scale_a_shape,
    scale_b_shape,
    bias_shape=None,
    scale_result_shape=None,
    out_dtype=None,
    use_fast_accum=False,
    out_shape=None,
    **kwargs,
) -> int:
    """Count flops for _scaled_mm.

    The count equals that of a plain matmul of ``a_shape`` by ``b_shape``;
    the scale, bias and remaining arguments are accepted but not used.
    """
    return mm_flop(a_shape, b_shape)
def conv_flop_count(
    x_shape: list[int],
    w_shape: list[int],
    out_shape: list[int],
    transposed: bool = False,
) -> int:
    """Count flops for convolution.

    The result is
    ``prod(spatial) * prod(filter_size) * batch_size * c_out * c_in * 2``
    where ``w_shape`` is read as ``(c_out, c_in, *filter_size)``,
    ``batch_size`` is ``x_shape[0]`` and ``spatial`` is ``out_shape[2:]`` for
    a regular convolution or ``x_shape[2:]`` for a transposed one. Bias is
    not taken into account.

    General idea: for a regular conv, for each point in the output spatial
    dimension we convolve the filter with something (hence
    ``prod(spatial) * prod(filter_size)`` ops). This gets multiplied by
    1. batch_size, 2. the cross product of input and weight channels.
    For the transpose, it's not each point in the *output* spatial dimension
    but each point in the *input* spatial dimension.

    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    spatial_source = x_shape if transposed else out_shape
    conv_shape = spatial_source[2:]
    c_out, c_in, *filter_size = w_shape

    # NB(chilli): I don't think this properly accounts for padding :think:
    # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
    return prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2
@register_flop_formula([
    aten.convolution,
    aten._convolution,
    aten.cudnn_convolution,
    aten._slow_conv2d_forward,
    aten.convolution_overrideable,
])
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
    """Count flops for convolution.

    Delegates to ``conv_flop_count`` using only the input, weight and output
    shapes and the ``transposed`` flag; the other arguments are ignored.
    """
    # pyrefly: ignore [bad-argument-type]
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
@register_flop_formula(aten.convolution_backward)
def conv_backward_flop(
        grad_out_shape,
        x_shape,
        w_shape,
        _bias,
        _stride,
        _padding,
        _dilation,
        transposed,
        _output_padding,
        _groups,
        output_mask,
        out_shape) -> int:
    """Count flops for convolution backward.

    The flops of the input gradient are added when ``output_mask[0]`` is set
    and those of the weight gradient when ``output_mask[1]`` is set; each is
    expressed as a forward convolution and counted with ``conv_flop_count``.

    Let's say we have a regular 1D conv
        {A, B, C} [inp]
        {i, j} [weight]
        => (conv)
        {Ai + Bj, Bi + Cj} [out]

    And as a reminder, the transposed conv of the above is
        => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]

    For the backwards of conv, we now have
        {D, E} [grad_out]
        {A, B, C} [inp]
        {i, j} [weight]

    # grad_inp as conv_transpose(grad_out, weight)
    Let's first compute grad_inp. To do so, we can simply look at all the
    multiplications that each element of inp is involved in. For example, A is
    only involved in the first element of the output (and thus only depends upon
    D in grad_out), and C is only involved in the last element of the output
    (and thus only depends upon E in grad_out)
        {Di, Dj + Ei, Ej} [grad_inp]

    Note that this corresponds to the below conv_transpose. This gives us the
    output_mask[0] branch, which is grad_inp.
        {D, E} [inp (grad_out)]
        {i, j} [weight]
        => (conv_transpose)
        {Di, Dj + Ei, Ej} [out (grad_inp)]

    I leave the fact that grad_inp for a transposed conv is just conv(grad_out,
    weight) as an exercise for the reader.

    # grad_weight as conv(inp, grad_out)
    To compute grad_weight, we again look at the terms in the output, which as
    a reminder is:
        => {Ai + Bj, Bi + Cj} [out]
        => {D, E} [grad_out]
    If we manually compute the gradient for the weights, we see it's
        {AD + BE, BD + CE} [grad_weight]

    This corresponds to the below conv
        {A, B, C} [inp]
        {D, E} [weight (grad_out)]
        => (conv)
        {AD + BE, BD + CE} [out (grad_weight)]

    # grad_weight of transposed conv as conv(grad_out, inp)
    As a reminder, the terms of the output of a transposed conv are:
        => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]
        => {D, E, F, G} [grad_out]
    Manually computing the gradient for the weights, we see it's
        {AD + BE + CF, AE + BF + CG} [grad_weight]

    This corresponds to the below conv
        {D, E, F, G} [inp (grad_out)]
        {A, B, C} [weight (inp)]
        => (conv)
        {AD + BE + CF, AE + BF + CG} [out (grad_weight)]

    For the full backwards formula, there are also some details involving
    transpose of the batch/channel dimensions and groups, but I skip those for
    the sake of brevity (and they're pretty similar to matmul backwards)

    Check [conv backwards decomposition as conv forwards]
    """
    def t(shape):
        # Swap the first two dimensions and keep the remaining ones in order.
        return [shape[1], shape[0], *shape[2:]]

    total = 0

    # grad_inp as conv_transpose(grad_out, weight)
    if output_mask[0]:
        grad_input_shape = get_shape(out_shape[0])
        total += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed)

    if output_mask[1]:
        grad_weight_shape = get_shape(out_shape[1])
        if transposed:
            # grad_weight of transposed conv as conv(grad_out, inp)
            conv_input, conv_weight = grad_out_shape, x_shape
        else:
            # grad_weight as conv(inp, grad_out)
            conv_input, conv_weight = x_shape, grad_out_shape
        total += conv_flop_count(t(conv_input), t(conv_weight), t(grad_weight_shape), transposed=False)

    return total
def sdpa_flop_count(query_shape, key_shape, value_shape):
    """
    Count flops for self-attention.

    All three shapes are read as 4-tuples ``(batch, heads, seq, dim)``. The
    count is the sum of two batched matmuls: ``q @ k^T`` and ``scores @ v``.

    Raises AssertionError when batch or heads differ between the inputs, when
    the query and key head dims differ, or when the key and value sequence
    lengths differ.
    """
    b, h, s_q, d_q = query_shape
    _b2, _h2, s_k, _d2 = key_shape
    _b3, _h3, _s3, d_v = value_shape

    compatible = (
        b == _b2 == _b3
        and h == _h2 == _h3
        and d_q == _d2
        and s_k == _s3
    )
    if not compatible:
        raise AssertionError("sdpa_flop_count: query/key/value shapes are incompatible")

    batch_heads = b * h
    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
    scores_flops = bmm_flop((batch_heads, s_q, d_q), (batch_heads, d_q, s_k))
    # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
    out_flops = bmm_flop((batch_heads, s_q, s_k), (batch_heads, s_k, d_v))
    return 0 + scores_flops + out_flops
@register_flop_formula([
    aten._scaled_dot_product_efficient_attention,
    aten._scaled_dot_product_flash_attention,
    aten._scaled_dot_product_cudnn_attention,
])
def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
    """Count flops for self-attention.

    Only the query/key/value shapes are used; every other argument is ignored.
    """
    # NB: We aren't accounting for causal attention here
    return sdpa_flop_count(query_shape, key_shape, value_shape)
- def _offsets_to_lengths(offsets, max_len):
- """
- If the offsets tensor is fake, then we don't know the actual lengths.
- In that case, we can just assume the worst case; each batch has max length.
- """
- from torch._subclasses.fake_tensor import FakeTensor
- from torch._subclasses.functional_tensor import FunctionalTensor
- if not isinstance(offsets, (FakeTensor, FunctionalTensor)) and offsets.device.type != "meta":
- return offsets.diff().tolist()
- return [max_len] * (offsets.size(0) - 1)
- def _unpack_flash_attention_nested_shapes(
- *,
- query,
- key,
- value,
- grad_out=None,
- cum_seq_q,
- cum_seq_k,
- max_q,
- max_k,
- ) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...] | None]]:
- """
- Given inputs to a flash_attention_(forward|backward) kernel, this will handle behavior for
- NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for
- each batch element.
- In the case that this isn't a NestedTensor kernel, then it just yields the original shapes.
- """
- if cum_seq_q is not None:
- # This means we should be dealing with a Nested Jagged Tensor query.
- # The inputs will have shape (sum(sequence len), heads, dimension)
- # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
- # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
- # So the flops calculation in this case is an overestimate of the actual flops.
- if len(key.shape) != 3:
- raise AssertionError("sdpa_flop_count: expected key.shape to be 3-dimensional")
- if len(value.shape) != 3:
- raise AssertionError("sdpa_flop_count: expected value.shape to be 3-dimensional")
- if grad_out is not None and grad_out.shape != query.shape:
- raise AssertionError("sdpa_flop_count: grad_out.shape must match query.shape when provided")
- _, h_q, d_q = query.shape
- _, h_k, d_k = key.shape
- _, h_v, d_v = value.shape
- if cum_seq_q is None:
- raise AssertionError("sdpa_flop_count: cum_seq_q must not be None")
- if cum_seq_k is None:
- raise AssertionError("sdpa_flop_count: cum_seq_k must not be None")
- if cum_seq_q.shape != cum_seq_k.shape:
- raise AssertionError("sdpa_flop_count: cum_seq_q and cum_seq_k must have the same shape")
- seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q)
- seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k)
- for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths, strict=True):
- new_query_shape = (1, h_q, seq_q_len, d_q)
- new_key_shape = (1, h_k, seq_k_len, d_k)
- new_value_shape = (1, h_v, seq_k_len, d_v)
- new_grad_out_shape = new_query_shape if grad_out is not None else None
- yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape
- return
- yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None
- def _unpack_efficient_attention_nested_shapes(
- *,
- query,
- key,
- value,
- grad_out=None,
- cu_seqlens_q,
- cu_seqlens_k,
- max_seqlen_q,
- max_seqlen_k,
- ) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...] | None]]:
- """
- Given inputs to a efficient_attention_(forward|backward) kernel, this will handle behavior for
- NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for
- each batch element.
- In the case that this isn't a NestedTensor kernel, then it just yields the original shapes.
- """
- if cu_seqlens_q is not None:
- # Unlike flash_attention_forward, we get a 4D tensor instead of a 3D tensor for efficient attention.
- #
- # This means we should be dealing with a Nested Jagged Tensor query.
- # The inputs will have shape (sum(sequence len), heads, dimension)
- # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
- # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
- # So the flops calculation in this case is an overestimate of the actual flops.
- if len(key.shape) != 4:
- raise AssertionError("_unpack_efficient_attention_nested_shapes: expected key.shape to be 4-dimensional")
- if len(value.shape) != 4:
- raise AssertionError("_unpack_efficient_attention_nested_shapes: expected value.shape to be 4-dimensional")
- if grad_out is not None and grad_out.shape != query.shape:
- raise AssertionError("_unpack_efficient_attention_nested_shapes: grad_out.shape must match query.shape when provided")
- _, _, h_q, d_q = query.shape
- _, _, h_k, d_k = key.shape
- _, _, h_v, d_v = value.shape
- if cu_seqlens_q is None:
- raise AssertionError("_unpack_efficient_attention_nested_shapes: cu_seqlens_q must not be None")
- if cu_seqlens_k is None:
- raise AssertionError("_unpack_efficient_attention_nested_shapes: cu_seqlens_k must not be None")
- if cu_seqlens_q.shape != cu_seqlens_k.shape:
- raise AssertionError("_unpack_efficient_attention_nested_shapes: "
- "cu_seqlens_q and cu_seqlens_k must have the same shape")
- seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q)
- seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k)
- for len_q, len_k in zip(seqlens_q, seqlens_k, strict=True):
- new_query_shape = (1, h_q, len_q, d_q)
- new_key_shape = (1, h_k, len_k, d_k)
- new_value_shape = (1, h_v, len_k, d_v)
- new_grad_out_shape = new_query_shape if grad_out is not None else None
- yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape
- return
- yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None
@register_flop_formula(aten._flash_attention_forward, get_raw=True)
def _flash_attention_forward_flop(
    query,
    key,
    value,
    cum_seq_q,
    cum_seq_k,
    max_q,
    max_k,
    *args,
    out_shape=None,
    **kwargs
) -> int:
    """Count flops for self-attention.

    Registered with ``get_raw=True``, so the arguments are the raw values
    rather than shapes. If this is a nested tensor call, the individual batch
    elements are unpacked and the flops are summed per batch element.
    """
    # NB: We aren't accounting for causal attention here
    per_element_shapes = _unpack_flash_attention_nested_shapes(
        query=query,
        key=key,
        value=value,
        cum_seq_q=cum_seq_q,
        cum_seq_k=cum_seq_k,
        max_q=max_q,
        max_k=max_k,
    )
    total = 0
    for query_shape, key_shape, value_shape, _ in per_element_shapes:
        total += sdpa_flop_count(query_shape, key_shape, value_shape)
    return total
@register_flop_formula(aten._efficient_attention_forward, get_raw=True)
def _efficient_attention_forward_flop(
    query,
    key,
    value,
    bias,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    *args,
    **kwargs
) -> int:
    """Count flops for self-attention.

    Registered with ``get_raw=True``, so the arguments are the raw values
    rather than shapes. If this is a nested tensor call, the individual batch
    elements are unpacked and the flops are summed per batch element.
    ``bias`` is accepted but does not contribute to the count.
    """
    # NB: We aren't accounting for causal attention here
    per_element_shapes = _unpack_efficient_attention_nested_shapes(
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
    )
    total = 0
    for query_shape, key_shape, value_shape, _ in per_element_shapes:
        total += sdpa_flop_count(query_shape, key_shape, value_shape)
    return total
def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape):
    """Count flops for the backward pass of self-attention.

    All four shapes are read as 4-tuples ``(batch, heads, seq, dim)``. The
    count is the sum of five batched matmuls: recomputing the scores, plus the
    two gradient matmuls of each of the two forward matmuls.

    Raises AssertionError when the shapes are inconsistent with one another.
    """
    b, h, s_q, d_q = query_shape
    _b2, _h2, s_k, _d2 = key_shape
    _b3, _h3, _s3, d_v = value_shape
    _b4, _h4, _s4, _d4 = grad_out_shape

    if not (b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2):
        raise AssertionError("sdpa_backward_flop_count: batch/heads/dimension mismatch among tensors")
    if not (d_v == _d4 and s_k == _s3 and s_q == _s4):
        raise AssertionError("sdpa_backward_flop_count: grad_out/value/key/query shapes are incompatible")

    bh = b * h
    matmuls = [
        # Step 1: We recompute the scores matrix.
        # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
        ((bh, s_q, d_q), (bh, d_q, s_k)),
        # Step 2: We propagate the gradients through the scores @ v operation.
        # gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k]
        ((bh, s_q, d_v), (bh, d_v, s_k)),
        # scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v]
        ((bh, s_k, s_q), (bh, s_q, d_v)),
        # Step 3: We propagate the gradients through the q @ k operation.
        # gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q]
        ((bh, s_q, s_k), (bh, s_k, d_q)),
        # q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k]
        ((bh, d_q, s_q), (bh, s_q, s_k)),
    ]
    total_flops = 0
    for lhs_shape, rhs_shape in matmuls:
        total_flops += bmm_flop(lhs_shape, rhs_shape)
    return total_flops
@register_flop_formula([
    aten._scaled_dot_product_efficient_attention_backward,
    aten._scaled_dot_product_flash_attention_backward,
    aten._scaled_dot_product_cudnn_attention_backward,
])
def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
    """Count flops for self-attention backward.

    Only the grad_out/query/key/value shapes are used; every other argument
    is ignored.
    """
    return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
@register_flop_formula(aten._flash_attention_backward, get_raw=True)
def _flash_attention_backward_flop(
    grad_out,
    query,
    key,
    value,
    out,  # forward output (positional); distinct from the wrapper's out_val keyword
    logsumexp,
    cum_seq_q,
    cum_seq_k,
    max_q,
    max_k,
    *args,
    **kwargs,
) -> int:
    """Count flops for the flash-attention backward kernel.

    Registered with ``get_raw=True``, so the arguments are the raw values
    rather than shapes. If this is a nested tensor call, the individual batch
    elements are unpacked and the flops are summed per batch element.
    ``out`` and ``logsumexp`` are accepted but do not contribute to the count.
    """
    per_element_shapes = _unpack_flash_attention_nested_shapes(
        query=query,
        key=key,
        value=value,
        grad_out=grad_out,
        cum_seq_q=cum_seq_q,
        cum_seq_k=cum_seq_k,
        max_q=max_q,
        max_k=max_k,
    )
    total = 0
    # The helper yields grad_out's shape last, while the counter expects it first.
    for query_shape, key_shape, value_shape, grad_out_shape in per_element_shapes:
        total += sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
    return total
@register_flop_formula(aten._efficient_attention_backward, get_raw=True)
def _efficient_attention_backward_flop(
    grad_out,
    query,
    key,
    value,
    bias,
    out,  # forward output (positional); distinct from the wrapper's out_val keyword
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    *args,
    **kwargs,
) -> int:
    """Count flops for the efficient-attention backward kernel.

    Registered with ``get_raw=True``, so the arguments are the raw values
    rather than shapes. If this is a nested tensor call, the individual batch
    elements are unpacked and the flops are summed per batch element.
    ``bias`` and ``out`` are accepted but do not contribute to the count.
    """
    per_element_shapes = _unpack_efficient_attention_nested_shapes(
        query=query,
        key=key,
        value=value,
        grad_out=grad_out,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
    )
    total = 0
    # The helper yields grad_out's shape last, while the counter expects it first.
    for query_shape, key_shape, value_shape, grad_out_shape in per_element_shapes:
        total += sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
    return total
# Explicit operator -> formula table. Rebinding the name here replaces the
# dict object that the ``register_flop_formula`` decorators above filled in;
# later registrations go into this new dict. Operators sharing a formula are
# grouped, and the groups are listed in the table's iteration order.
flop_registry = {
    # Matrix multiplication family.
    aten.mm: mm_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.baddbmm: baddbmm_flop,
    aten._scaled_mm: _scaled_mm_flop,
    # Convolutions.
    **dict.fromkeys(
        (
            aten.convolution,
            aten._convolution,
            aten.cudnn_convolution,
            aten.convolution_overrideable,
            aten._slow_conv2d_forward,
        ),
        conv_flop,
    ),
    aten.convolution_backward: conv_backward_flop,
    # Scaled dot product attention, forward then backward.
    **dict.fromkeys(
        (
            aten._scaled_dot_product_efficient_attention,
            aten._scaled_dot_product_flash_attention,
            aten._scaled_dot_product_cudnn_attention,
        ),
        sdpa_flop,
    ),
    **dict.fromkeys(
        (
            aten._scaled_dot_product_efficient_attention_backward,
            aten._scaled_dot_product_flash_attention_backward,
            aten._scaled_dot_product_cudnn_attention_backward,
        ),
        sdpa_backward_flop,
    ),
    # Raw attention kernels.
    aten._flash_attention_forward: _flash_attention_forward_flop,
    aten._efficient_attention_forward: _efficient_attention_forward_flop,
    aten._flash_attention_backward: _flash_attention_backward_flop,
    aten._efficient_attention_backward: _efficient_attention_backward_flop,
}
def normalize_tuple(x):
    """Return ``x`` unchanged if it is a tuple, else wrap it in a 1-tuple."""
    return x if isinstance(x, tuple) else (x,)
# Suffixes for successive powers of 1000, starting at 1000**0.
suffixes = ["", "K", "M", "B", "T"]


# Thanks BingChat!
def get_suffix_str(number):
    """Pick the magnitude suffix to display ``number`` with.

    The choice is based on the number of characters in ``str(number)``, with
    some deliberate overflow before moving to the next suffix,
    i.e. 1.01B should be displayed as 1001M, not 1.001B. The result is clamped
    to the available suffixes.
    """
    digit_count = len(str(number))
    raw_index = (digit_count - 2) // 3
    last_index = len(suffixes) - 1
    return suffixes[min(max(raw_index, 0), last_index)]
def convert_num_with_suffix(number, suffix):
    """Format ``number`` scaled to ``suffix``, followed by the suffix.

    ``suffix`` must be one of the entries of ``suffixes``; its position there
    is the power of 1000 the number is divided by. The scaled value is
    rendered with three decimal places.
    """
    power = suffixes.index(suffix)
    scaled = number / 1000 ** power
    return f"{scaled:.3f}" + suffixes[power]
def convert_to_percent_str(num, denom) -> str:
    """Format ``num / denom`` as a percentage with two decimals.

    A zero denominator yields the literal ``"0%"`` instead of dividing.
    """
    return "0%" if denom == 0 else f"{num / denom:.2%}"
- def _pytreeify_preserve_structure(f):
- @wraps(f)
- def nf(args):
- flat_args, spec = tree_flatten(args)
- out = f(*flat_args)
- return tree_unflatten(out, spec)
- return nf
- class FlopCounterMode:
- """
- ``FlopCounterMode`` is a context manager that counts the number of flops within its context.
- It does this using a ``TorchDispatchMode``.
- It also supports hierarchical output by passing a module (or list of
- modules) to FlopCounterMode on construction. If you do not need hierarchical
- output, you do not need to use it with a module.
- Example usage
- .. code-block:: python
- mod = ...
- with FlopCounterMode(mod) as flop_counter:
- mod.sum().backward()
- """
- def __init__(
- self,
- mods: torch.nn.Module | list[torch.nn.Module] | None = None,
- depth: int = 2,
- display: bool = True,
- custom_mapping: dict[Any, Any] | None = None) -> None:
- super().__init__()
- self.flop_counts: dict[str, dict[Any, int]] = defaultdict(lambda: defaultdict(int))
- self.depth = depth
- self.display = display
- self.mode: _FlopCounterMode | None = None
- if custom_mapping is None:
- custom_mapping = {}
- if mods is not None:
- warnings.warn("mods argument is not needed anymore, you can stop passing it", stacklevel=2)
- self.flop_registry = {
- **flop_registry,
- **{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
- }
- self.mod_tracker = ModuleTracker()
- def get_total_flops(self) -> int:
- return sum(self.flop_counts['Global'].values())
- def get_flop_counts(self) -> dict[str, dict[Any, int]]:
- """Return the flop counts as a dictionary of dictionaries.
- The outer
- dictionary is keyed by module name, and the inner dictionary is keyed by
- operation name.
- Returns:
- Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
- """
- return {k: dict(v) for k, v in self.flop_counts.items()}
- def get_table(self, depth=None):
- if depth is None:
- depth = self.depth
- if depth is None:
- depth = 999999
- import tabulate
- tabulate.PRESERVE_WHITESPACE = True
- header = ["Module", "FLOP", "% Total"]
- values = []
- global_flops = self.get_total_flops()
- global_suffix = get_suffix_str(global_flops)
- is_global_subsumed = False
- def process_mod(mod_name, depth):
- nonlocal is_global_subsumed
- total_flops = sum(self.flop_counts[mod_name].values())
- is_global_subsumed |= total_flops >= global_flops
- padding = " " * depth
- values = []
- values.append([
- padding + mod_name,
- convert_num_with_suffix(total_flops, global_suffix),
- convert_to_percent_str(total_flops, global_flops)
- ])
- for k, v in self.flop_counts[mod_name].items():
- values.append([
- padding + " - " + str(k),
- convert_num_with_suffix(v, global_suffix),
- convert_to_percent_str(v, global_flops)
- ])
- return values
- for mod in sorted(self.flop_counts.keys()):
- if mod == 'Global':
- continue
- mod_depth = mod.count(".") + 1
- if mod_depth > depth:
- continue
- cur_values = process_mod(mod, mod_depth - 1)
- values.extend(cur_values)
- # We do a bit of messing around here to only output the "Global" value
- # if there are any FLOPs in there that aren't already fully contained by
- # a module.
- if 'Global' in self.flop_counts and not is_global_subsumed:
- for value in values:
- value[0] = " " + value[0]
- values = process_mod('Global', 0) + values
- if len(values) == 0:
- values = [["Global", "0", "0%"]]
- return tabulate.tabulate(values, headers=header, colalign=("left", "right", "right"))
- # NB: This context manager is NOT reentrant
- def __enter__(self):
- self.flop_counts.clear()
- self.mod_tracker.__enter__()
- self.mode = _FlopCounterMode(self)
- self.mode.__enter__()
- return self
- def __exit__(self, *args):
- if self.mode is None:
- raise AssertionError("Internal error: FlopCounter.__exit__ called but mode is None")
- b = self.mode.__exit__(*args)
- self.mode = None # break cycles
- self.mod_tracker.__exit__()
- if self.display:
- print(self.get_table(self.depth))
- return b
- def _count_flops(self, func_packet, out, args, kwargs):
- if func_packet in self.flop_registry:
- flop_count_func = self.flop_registry[func_packet]
- flop_count = flop_count_func(*args, **kwargs, out_val=out) # type: ignore[operator]
- for par in set(self.mod_tracker.parents):
- self.flop_counts[par][func_packet] += flop_count
- return out
- class _FlopCounterMode(TorchDispatchMode):
- supports_higher_order_operators = True
- def __init__(self, counter: FlopCounterMode) -> None:
- self.counter = counter
- def _execute_with_isolated_flop_counting(self, branch_fn, operands):
- """Execute a branch function and capture its FLOP counts without
- affecting self.counter.flop_counts
- Args:
- branch_fn: The branch function to execute
- operands: Arguments to pass to the branch function
- Returns:
- Tuple of (result, flop_counts) where result is the branch output
- and flop_counts is a copy of the FLOP counts after execution
- """
- import copy
- checkpointed_flop_counts = copy.copy(self.counter.flop_counts)
- with self:
- result = branch_fn(*operands)
- flop_counts = copy.copy(self.counter.flop_counts)
- self.counter.flop_counts = checkpointed_flop_counts
- return result, flop_counts
- def _handle_higher_order_ops(self, func, types, args, kwargs):
- is_triton = func in {torch.ops.higher_order.triton_kernel_wrapper_mutation,
- torch.ops.higher_order.triton_kernel_wrapper_functional}
- if is_triton:
- from torch._higher_order_ops.triton_kernel_wrap import get_kernel
- # Special case - look in the triton flop registry for the kernel
- from triton.runtime.jit import JITFunction
- kernel_name = get_kernel(kwargs["kernel_idx"])
- # Unwrap heuristics if they are present
- while not isinstance(kernel_name, JITFunction):
- if hasattr(kernel_name, "fn"):
- kernel_name = kernel_name.fn
- else:
- break
- return self.counter._count_flops(kernel_name, None, args, kwargs)
- elif func is torch.ops.higher_order.cond:
- # The flop counter for cond counts the upper bound of flops.
- # For example, if a matmul is executed 2 times in true branch
- # but only 1 time in the false branch, the flop counter will
- # record the larger number of flops, i.e. 2 times.
- pred, true_branch, false_branch, operands = args
- # Step 1: Count flops for true branch and false branch separately
- true_out, true_flop_counts = self._execute_with_isolated_flop_counting(
- true_branch, operands
- )
- if true_out is NotImplemented:
- return NotImplemented
- false_out, false_flop_counts = self._execute_with_isolated_flop_counting(
- false_branch, operands
- )
- if false_out is NotImplemented:
- return NotImplemented
- # Step 2: merge flop counts
- all_mod_keys = set(true_flop_counts.keys()) | set(false_flop_counts.keys())
- merged_flop_counts = {}
- for outer_key in all_mod_keys:
- true_func_counts = true_flop_counts[outer_key]
- false_func_counts = false_flop_counts[outer_key]
- merged_func_counts = {}
- all_func_keys = set(true_func_counts.keys()) | set(false_func_counts.keys())
- for func_key in all_func_keys:
- true_val = true_func_counts.get(func_key, 0)
- false_val = false_func_counts.get(func_key, 0)
- merged_func_counts[func_key] = max(true_val, false_val)
- merged_flop_counts[outer_key] = merged_func_counts
- # Step 3: update the counter with merged counts
- for outer_key, inner_dict in merged_flop_counts.items():
- self.counter.flop_counts[outer_key].update(inner_dict)
- # It doesn't matter which one we return since true_fn and false_fn return
- # output with the same structure.
- return true_out
- else:
- return NotImplemented
- def __torch_dispatch__(self, func, types, args=(), kwargs=None):
- kwargs = kwargs if kwargs else {}
- # Skip ops from non-standard dispatch_sizes_strides_policy such as NJT
- if func in {torch.ops.aten.sym_is_contiguous.default,
- torch.ops.aten.is_contiguous.default,
- torch.ops.aten.is_contiguous.memory_format,
- torch.ops.aten.is_strides_like_format.default,
- torch.ops.aten.is_non_overlapping_and_dense.default,
- torch.ops.aten.size.default,
- torch.ops.aten.sym_size.default,
- torch.ops.aten.stride.default,
- torch.ops.aten.sym_stride.default,
- torch.ops.aten.storage_offset.default,
- torch.ops.aten.sym_storage_offset.default,
- torch.ops.aten.numel.default,
- torch.ops.aten.sym_numel.default,
- torch.ops.aten.dim.default,
- torch.ops.prim.layout.default}:
- return NotImplemented
- if isinstance(func, torch._ops.HigherOrderOperator):
- return self._handle_higher_order_ops(func, types, args, kwargs)
- # If we don't have func in flop_registry, see if it can decompose
- if func not in self.counter.flop_registry and func is not torch.ops.prim.device.default:
- with self:
- r = func.decompose(*args, **kwargs)
- if r is not NotImplemented:
- return r
- # no further decomposition; execute & count flops
- out = func(*args, **kwargs)
- return self.counter._count_flops(func._overloadpacket, out, args, kwargs)
|