| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608 |
- # Copyright 2025 HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import itertools
- from collections.abc import Callable
- import torch
- import torch.nn.functional as F
- from .cache_utils import Cache
- from .configuration_utils import PreTrainedConfig
- from .utils import is_torch_xpu_available, logging
- from .utils.deprecation import deprecate_kwarg
- from .utils.generic import GeneralInterface, is_flash_attention_requested
- from .utils.import_utils import is_torch_flex_attn_available, is_torch_greater_or_equal, is_tracing
# Flex attention helpers are only imported when this torch build provides them
if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size
    from torch.nn.attention.flex_attention import BlockMask, create_block_mask
else:
    # Register a fake type to avoid crashing for annotations and `isinstance` checks
    BlockMask = torch.Tensor

# Torch version / device capabilities, resolved once at import time; the mask builders below branch on them
_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
_is_torch_xpu_available = is_torch_xpu_available()

if _is_torch_greater_or_equal_than_2_6:
    # Context manager used by the vmap-based mask construction in `sdpa_mask` (imported only on torch>=2.6)
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

logger = logging.get_logger(__name__)
def and_masks(*mask_functions: Callable) -> Callable:
    """
    Combine several mask functions into a single one that is True only where all of them are True (intersection).

    Raises:
        RuntimeError: if any of the provided arguments is not callable.
    """
    if any(not callable(fn) for fn in mask_functions):
        raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}")

    def and_mask(batch_idx, head_idx, q_idx, kv_idx):
        # Start from an all-True scalar and narrow it down with every mask
        combined = q_idx.new_ones((), dtype=torch.bool)
        for fn in mask_functions:
            partial = fn(batch_idx, head_idx, q_idx, kv_idx).to(combined.device)
            combined = combined & partial
        return combined

    return and_mask
def or_masks(*mask_functions: Callable) -> Callable:
    """
    Combine several mask functions into a single one that is True wherever at least one of them is True (union).

    Raises:
        RuntimeError: if any of the provided arguments is not callable.
    """
    if any(not callable(fn) for fn in mask_functions):
        raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}")

    def or_mask(batch_idx, head_idx, q_idx, kv_idx):
        # Start from an all-False scalar and widen it with every mask
        combined = q_idx.new_zeros((), dtype=torch.bool)
        for fn in mask_functions:
            partial = fn(batch_idx, head_idx, q_idx, kv_idx).to(combined.device)
            combined = combined | partial
        return combined

    return or_mask
def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    """
    Lower-triangular (causal) pattern: a query position may attend to its own key position and all earlier ones.
    """
    return q_idx >= kv_idx
def bidirectional_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    """
    Full (bidirectional) pattern: every query position may attend to every key position.

    NOTE: the result is still derived from an index (`q_idx`) rather than a constant, so that the function keeps
    working with the non-vmap broadcasting expansion.
    """
    # Always True for valid (non-negative) query indices
    return q_idx >= 0
def sliding_window_overlay(sliding_window: int) -> Callable:
    """
    Overlay for a sliding-window pattern. Combine it with a causal mask to obtain a proper sliding-window causal
    mask.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        # Keep keys whose distance behind the query is strictly smaller than the window
        return q_idx - kv_idx < sliding_window

    return inner_mask
def chunked_overlay(chunk_size: int, left_padding: torch.Tensor) -> Callable:
    """
    Overlay for a chunked attention pattern. Combine it with a causal mask to obtain a proper chunked attention
    mask.

    Positions are grouped into chunks of `chunk_size`. Chunk boundaries are shifted per batch element by
    `left_padding[batch_idx]`, and a query may only see keys from its own chunk.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        shift = left_padding[batch_idx]
        kv_chunk = (kv_idx - shift) // chunk_size
        q_chunk = (q_idx - shift) // chunk_size
        return kv_chunk == q_chunk

    return inner_mask
def sliding_window_causal_mask_function(sliding_window: int) -> Callable:
    """
    Build the mask function for sliding-window causal attention: the intersection of the sliding-window overlay
    and the causal mask.
    """
    window = sliding_window_overlay(sliding_window)
    return and_masks(window, causal_mask_function)
def sliding_window_bidirectional_overlay(sliding_window: int) -> Callable:
    """
    This is an overlay depicting a bidirectional sliding window pattern. Add it on top of a basic mask function for
    a proper bidirectional sliding window mask.

    The window is exclusive, like `sliding_window_overlay`: a key is kept only when its distance to the query is
    strictly smaller than `sliding_window`. Restricted to keys at or before the query, this overlay therefore keeps
    exactly the same positions as the causal sliding-window overlay.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        """A token can attend to any other token if their absolute distance is within
        the (exclusive) sliding window size (distance < sliding_window)."""
        return abs(q_idx - kv_idx) < sliding_window

    return inner_mask
def sliding_window_bidirectional_mask_function(sliding_window: int) -> Callable:
    """
    Build the mask function for bidirectional sliding-window attention: the intersection of the bidirectional
    sliding-window overlay and the full bidirectional mask.
    """
    window = sliding_window_bidirectional_overlay(sliding_window)
    return and_masks(window, bidirectional_mask_function)
def chunked_causal_mask_function(chunk_size: int, left_padding: torch.Tensor) -> Callable:
    """
    Build the mask function for chunked causal attention: causal masking restricted to the query's own chunk, with
    chunk boundaries shifted per batch element by `left_padding`.
    """
    same_chunk = chunked_overlay(chunk_size, left_padding)
    return and_masks(same_chunk, causal_mask_function)
def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
    """
    Build a mask function from a 2D padding mask indexed as `padding_mask[batch_idx, kv_idx]`. A key position is
    kept wherever that entry is truthy, independently of the head and query indices.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        # The mask must already cover every `kv_idx` that will be requested (dimension 1). It cannot be padded here,
        # because the final size is unknown at this point, and a try/except would not be vectorizable on
        # accelerator devices.
        key_is_valid = padding_mask[batch_idx, kv_idx]
        return key_is_valid

    return inner_mask
def packed_sequence_mask_function(packed_sequence_mask: torch.Tensor) -> Callable:
    """
    Build a mask function from a 2D packed-sequence mask indexed as `[batch_idx, position]`. Entries are compared for
    equality, so they act as per-position sequence labels: a query may only attend to keys that carry the same
    label in the same batch row.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        query_label = packed_sequence_mask[batch_idx, q_idx]
        key_label = packed_sequence_mask[batch_idx, kv_idx]
        return query_label == key_label

    return inner_mask
def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable:
    """
    Wrap `mask_function` so that it receives absolute positions. The wrapper shifts `q_idx` by `q_offset` and
    `kv_idx` by `kv_offset` before delegating, because the torch API only accepts lengths, not start and end indices.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        absolute_q = q_idx + q_offset
        absolute_kv = kv_idx + kv_offset
        return mask_function(batch_idx, head_idx, absolute_q, absolute_kv)

    return inner_mask
- def prepare_padding_mask(attention_mask: torch.Tensor | None, kv_length: int, kv_offset: int) -> torch.Tensor | None:
- """
- From the 2D attention mask, prepare the correct padding mask to use by potentially padding it.
- """
- local_padding_mask = attention_mask
- if attention_mask is not None:
- # Pad it if necessary
- if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
- local_padding_mask = torch.nn.functional.pad(attention_mask, (0, padding_length))
- return local_padding_mask
def _can_skip_causal_mask_xpu(
    padding_mask: torch.Tensor | None,
    query_length: int,
    kv_length: int,
    local_attention_size: int | None,
) -> bool:
    """
    Decide, for XPU devices, whether causal mask creation can be skipped.

    A single query token (`query_length == 1`) follows the same rules as the generic path. Multi-token queries may
    additionally be skipped when a padding mask is given, provided it is all True over the first `query_length`
    positions and all False afterwards.
    """
    # Mask contents must not drive control flow while tracing
    if is_tracing(padding_mask):
        return False

    # Same local-attention constraint as the generic path
    if local_attention_size is not None and kv_length >= local_attention_size:
        return False

    if padding_mask is None:
        # No padding: skippable for a single query token or a square (full causal) attention
        return query_length == 1 or kv_length == query_length

    if query_length == 1:
        # Single query token: skippable only if nothing is padded
        return padding_mask.all()

    # Query window fully valid and nothing valid after it (lets XPU optimize the 1st token in static cache)
    query_window_valid = padding_mask[:, :query_length].all()
    return query_window_valid and not padding_mask[:, query_length:].any()
def _ignore_causal_mask_sdpa(
    padding_mask: torch.Tensor | None,
    query_length: int,
    kv_length: int,
    kv_offset: int,
    local_attention_size: int | None = None,
) -> bool:
    """
    Decide whether the causal mask can be dropped when PyTorch's SDPA is used, relying on its `is_causal` argument.

    Outside of tracing, the mask is skipped only when all of the following hold:
    - `query_length == 1` or `kv_length == query_length`;
    - no local-attention constraint applies;
    - no token is masked in the 2D `padding_mask`.
    Skipping lets SDPA dispatch to the flash attention kernel, which cannot be used with a custom `attn_mask`.
    On XPU, the decision is delegated to `_can_skip_causal_mask_xpu`.
    """
    # Restrict the 2D mask to the key/value window actually used in this attention call
    if padding_mask is not None and padding_mask.shape[-1] > kv_length:
        window_indices = torch.arange(kv_length, device=padding_mask.device)
        window_indices += kv_offset
        padding_mask = padding_mask[:, window_indices]

    if _is_torch_xpu_available:
        # XPU can also skip multi-token queries when the padding mask is correctly structured
        return _can_skip_causal_mask_xpu(padding_mask, query_length, kv_length, local_attention_size)

    # With `torch.export` / `torch.onnx.dynamo_export`, `is_causal` gets hard-coded into the exported forward from the
    # example input, which is wrong in general (see https://github.com/pytorch/pytorch/issues/108108). Never skip
    # while tracing.
    if is_tracing(padding_mask):
        return False

    # `is_causal` is only correct when the lower and upper diagonals coincide (same issue as above)
    if not (query_length == 1 or kv_length == query_length):
        return False

    # A reached local window needs extra patterns in the mask
    if local_attention_size is not None and kv_length >= local_attention_size:
        return False

    # Any padding token must be encoded in the mask
    return padding_mask is None or bool(padding_mask.all())
def _can_skip_bidirectional_mask_xpu(
    padding_mask: torch.Tensor | None,
    kv_length: int,
    local_attention_size: int | None,
) -> bool:
    """
    Decide, for XPU devices, whether bidirectional mask creation can be skipped: allowed when no local-attention
    constraint applies and no token is padded.
    """
    # Mask contents must not drive control flow while tracing
    if is_tracing(padding_mask):
        return False

    # Same local-attention constraint as the generic path
    if local_attention_size is not None and kv_length >= local_attention_size:
        return False

    # Without a padding mask, full bidirectional attention is always skippable; otherwise only if nothing is padded
    return True if padding_mask is None else padding_mask.all()
def _ignore_bidirectional_mask_sdpa(
    padding_mask: torch.Tensor | None,
    kv_length: int,
    local_attention_size: int | None = None,
) -> bool:
    """
    Decide whether the bidirectional mask can be dropped when PyTorch's SDPA is used.

    Outside of tracing, the mask is skipped when no token is masked in the 2D `padding_mask` and no local-attention
    constraint applies (`local_attention_size` is None or `kv_length < local_attention_size`). Skipping lets SDPA
    dispatch to the flash attention kernel, which cannot be used with a custom `attn_mask`.
    On XPU, the decision is delegated to `_can_skip_bidirectional_mask_xpu`.
    """
    if _is_torch_xpu_available:
        return _can_skip_bidirectional_mask_xpu(padding_mask, kv_length, local_attention_size)

    # With `torch.export` / `torch.onnx.dynamo_export`, inspecting mask contents would create dynamic control flow
    if is_tracing(padding_mask):
        return False

    # Padding tokens must be encoded in the mask
    if padding_mask is not None and not padding_mask.all():
        return False

    # A reached local window needs extra patterns in the mask
    if local_attention_size is not None and kv_length >= local_attention_size:
        return False

    return True
def _vmap_expansion_sdpa(mask_function: Callable) -> Callable:
    """
    Lift a scalar mask function to the full (b_idx, h_idx, q_idx, kv_idx) grid using nested `torch.vmap`.

    vmap keeps vectorized performance while letting flex and sdpa/eager share the same primitive mask functions
    (FA2 being handled differently). Given 1D index tensors, the returned callable produces an output of shape
    (batch, heads, q_len, kv_len).
    """
    # Innermost map runs over kv_idx, then q_idx, then head_idx; the outermost runs over batch_idx
    axis_specs = (
        (None, None, None, 0),
        (None, None, 0, None),
        (None, 0, None, None),
        (0, None, None, None),
    )
    expanded = mask_function
    for in_dims in axis_specs:
        expanded = torch.vmap(expanded, in_dims=in_dims, out_dims=0)
    return expanded
def _non_vmap_expansion_sdpa(
    batch_indices: torch.Tensor, head_indices: torch.Tensor, q_indices: torch.Tensor, kv_indices: torch.Tensor
):
    """
    Reshape the four index vectors so that each one occupies its own axis of a 4D layout
    (b_idx, h_idx, q_idx, kv_idx). Applying an index-based mask function to them then broadcasts to the full grid
    without vmap.

    NOTE: This only works for index-based mask functions and is not guaranteed to work otherwise.
    Reference:
    - https://github.com/huggingface/optimum-onnx/blob/c123e8f4fab61b54a8e0e31ce74462bcacca576e/optimum/exporters/onnx/model_patcher.py#L362-L365
    """
    return (
        batch_indices[:, None, None, None],
        head_indices[None, :, None, None],
        q_indices[None, None, :, None],
        kv_indices[None, None, None, :],
    )
def sdpa_mask(
    batch_size: int,
    q_length: int,
    kv_length: int,
    q_offset: int = 0,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: torch.Tensor | None = None,
    local_size: int | None = None,
    allow_is_causal_skip: bool = True,
    allow_is_bidirectional_skip: bool = False,
    allow_torch_fix: bool = True,
    use_vmap: bool = False,
    device: torch.device | str = "cpu",
    **kwargs,
) -> torch.Tensor | None:
    """
    Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
    the element should take part in the attention computation, and False that it should not.
    The default index-based construction works with any torch version. The vmap-based construction
    (`use_vmap=True`) requires torch>=2.6.

    Args:
        batch_size (`int`):
            The batch size of the input sequence.
        q_length (`int`):
            The size that the query states will have during the attention computation. For backward compatibility,
            a `cache_position` tensor is also accepted here (deprecated): its length is used as `q_length` and its
            first value as `q_offset`.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        q_offset (`int`, optional):
            An optional offset to indicate at which first position the query states will refer to.
        kv_offset (`int`, optional):
            An optional offset to indicate at which first position the key and values states will refer to.
        mask_function (`Callable`):
            The mask factory function describing the mask pattern.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It is right-padded with False if it is shorter than `kv_length + kv_offset`.
        local_size (`int`, optional):
            The size of the local attention, if we do not use full attention. It is only used to decide whether mask
            creation can be skipped (see `allow_is_causal_skip` and `allow_is_bidirectional_skip`).
        allow_is_causal_skip (`bool`, optional):
            Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
            `torch.sdpa` instead. Default to `True`.
        allow_is_bidirectional_skip (`bool`, optional):
            Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
            i.e. full attention without any padding. Default to `False`.
        allow_torch_fix (`bool`, optional):
            Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
            versions (< 2.5). We need an arg to skip it when using eager. By default `True`.
        use_vmap (`bool`, optional):
            Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
            index-based (for the cost of speed performance). By default `False`.
        device (`torch.device` or `str`, optional):
            An optional device to create the mask on.

    Returns:
        `torch.Tensor` or `None`: `None` when one of the allowed skips applies, otherwise the boolean mask described
        above.

    Raises:
        ValueError: if `use_vmap=True` and torch<2.6 (only reached when the mask is not skipped).

    ## Creating a simple causal mask:

    To create the following causal mask:

        0 ■ ⬚ ⬚ ⬚ ⬚
        1 ■ ■ ⬚ ⬚ ⬚
        2 ■ ■ ■ ⬚ ⬚
        3 ■ ■ ■ ■ ⬚
        4 ■ ■ ■ ■ ■

    You can do

    ```python
    >>> sdpa_mask(batch_size=1, q_length=5, kv_length=5)
    >>> tensor([[[[ True, False, False, False, False],
                  [ True,  True, False, False, False],
                  [ True,  True,  True, False, False],
                  [ True,  True,  True,  True, False],
                  [ True,  True,  True,  True,  True]]]])
    ```

    ## Creating a sliding window mask:

    To create the following sliding window mask (`sliding_window=3`):

        0 ■ ⬚ ⬚ ⬚ ⬚
        1 ■ ■ ⬚ ⬚ ⬚
        2 ■ ■ ■ ⬚ ⬚
        3 ⬚ ■ ■ ■ ⬚
        4 ⬚ ⬚ ■ ■ ■

    You can do

    ```python
    >>> sdpa_mask(batch_size=1, q_length=5, kv_length=5, mask_function=sliding_window_causal_mask_function(3))
    >>> tensor([[[[ True, False, False, False, False],
                  [ True,  True, False, False, False],
                  [ True,  True,  True, False, False],
                  [False,  True,  True,  True, False],
                  [False, False,  True,  True,  True]]]])
    ```

    ## Creating a chunked attention mask

    To create the following chunked attention mask (`chunk_size=3`):

        0 ■ ⬚ ⬚ ⬚ ⬚
        1 ■ ■ ⬚ ⬚ ⬚
        2 ■ ■ ■ ⬚ ⬚
        3 ⬚ ⬚ ⬚ ■ ⬚
        4 ⬚ ⬚ ⬚ ■ ■

    You can do

    ```python
    >>> sdpa_mask(batch_size=1, q_length=5, kv_length=5, mask_function=chunked_causal_mask_function(3, torch.zeros(1, dtype=int)))
    >>> tensor([[[[ True, False, False, False, False],
                [ True,  True, False, False, False],
                [ True,  True,  True, False, False],
                [False, False, False,  True, False],
                [False, False, False,  True,  True]]]])
    ```

    """
    # For BC on `cache_positions` that used to be an arg at the position of `q_length`
    if isinstance(q_length, torch.Tensor):
        logger.warning_once(
            "`cache_position` is deprecated as an arg, and will be removed in Transformers v5.6. Please use `q_length` and "
            "`q_offset` instead, similarly to `kv_length` and `kv_offset`"
        )
        q_length, q_offset = q_length.shape[0], q_length[0].to(device)

    # Potentially pad the 2D mask
    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)

    # Under specific conditions, we can avoid materializing the mask
    #   1. Causal masks can rely on the `is_causal` argument
    #   2. Bidirectional do not need any further processing (no bias)
    if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
        return None
    if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask, kv_length, local_size):
        return None

    # Potentially add the padding 2D mask
    if padding_mask is not None:
        mask_function = and_masks(mask_function, padding_mask_function(padding_mask))

    # Absolute index ranges for every mask dimension (a single head: the mask is broadcast over heads)
    batch_arange = torch.arange(batch_size, device=device)
    head_arange = torch.arange(1, device=device)
    q_arange = torch.arange(q_length, device=device) + q_offset
    kv_arange = torch.arange(kv_length, device=device) + kv_offset

    # Actual mask creation
    # Option 1: Fast non-vmap mask creation (default)
    if not use_vmap:
        # Apply mask function element-wise through broadcasting
        attention_mask = mask_function(*_non_vmap_expansion_sdpa(batch_arange, head_arange, q_arange, kv_arange))
        # Expand the mask to match batch size and query length if they weren't used in the mask function
        attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)

    # Option 2: Vmap mask creation (torch>=2.6 and custom patterns)
    elif _is_torch_greater_or_equal_than_2_6:
        # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
        # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it
        # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for q and kv indices
        with TransformGetItemToIndex():
            attention_mask = _vmap_expansion_sdpa(mask_function)(batch_arange, head_arange, q_arange, kv_arange)

    # Option 3: Error out since it indicates that the user did something custom, which they shouldn't have (torch<2.6)
    else:
        raise ValueError(
            "The vmap functionality for mask creation is only supported from torch>=2.6. "
            "Please update your torch version or use `use_vmap=False` with index-based masks."
        )

    # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any
    # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
    if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
        attention_mask = attention_mask | torch.all(~attention_mask, dim=-1, keepdim=True)

    return attention_mask
def eager_mask(
    batch_size: int,
    q_length: int,
    kv_length: int,
    q_offset: int = 0,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: torch.Tensor | None = None,
    dtype: torch.dtype = torch.float32,
    allow_is_bidirectional_skip: bool = False,
    use_vmap: bool = False,
    device: torch.device | str = "cpu",
    **kwargs,
) -> torch.Tensor | None:
    """
    Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)`. A value of 0 means the element takes
    part in the attention computation. The minimum finite value of `dtype` (acting as -inf) means it does not.

    Args:
        batch_size (`int`):
            The batch size of the input sequence.
        q_length (`int`):
            The size that the query states will have during the attention computation.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        q_offset (`int`, optional):
            An optional offset to indicate at which first position the query states will refer to.
        kv_offset (`int`, optional):
            An optional offset to indicate at which first position the key and values states will refer to.
        mask_function (`Callable`):
            The mask factory function describing the mask pattern.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
        dtype (`torch.dtype`, optional):
            The dtype to use for the mask. By default, `torch.float32`.
        allow_is_bidirectional_skip (`bool`, optional):
            Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
            i.e. full attention without any padding. Default to `False`.
        use_vmap (`bool`, optional):
            Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
            index-based (for the cost of speed performance). By default `False`.
        device (`torch.device` or `str`, optional):
            An optional device to create the mask on.

    Returns:
        `torch.Tensor` or `None`: the float mask described above. It is `None` only when
        `allow_is_bidirectional_skip=True` and the underlying boolean mask could be skipped.
    """
    # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf.
    # These options are forced below (eager cannot rely on `is_causal`, and the torch<2.5 fix is not needed),
    # so drop them if the caller passed them along.
    _ = kwargs.pop("allow_is_causal_skip", None)
    _ = kwargs.pop("allow_torch_fix", None)
    mask = sdpa_mask(
        batch_size=batch_size,
        q_length=q_length,
        kv_length=kv_length,
        q_offset=q_offset,
        kv_offset=kv_offset,
        mask_function=mask_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=False,
        allow_is_bidirectional_skip=allow_is_bidirectional_skip,
        allow_torch_fix=False,
        use_vmap=use_vmap,
        device=device,
        **kwargs,
    )
    # only bidirectional masks can be skipped, otherwise we convert bool -> float
    if mask is not None:
        min_dtype = torch.finfo(dtype).min
        # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
        mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
    return mask
def flash_attention_mask(
    batch_size: int,
    q_length: int,
    kv_length: int,
    q_offset: int = 0,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: torch.Tensor | None = None,
    **kwargs,
):
    """
    Prepare the attention mask for FA2. FA2 works on un-padded inputs, so no 4D mask is built. Either `None` is
    returned (no padding token in the considered window), or the 2D mask restricted to its last `kv_length` columns
    is returned, from which the sequence lengths will be extracted. Slicing is what handles sliding/chunked windows.

    Args:
        batch_size (`int`):
            The batch size of the input sequence.
        q_length (`int`):
            The size that the query states will have during the attention computation.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        q_offset (`int`, optional):
            An optional offset to indicate at which first position the query states will refer to.
        kv_offset (`int`, optional):
            An optional offset to indicate at which first position the key and values states will refer to.
        mask_function (`Callable`):
            The mask factory function describing the mask pattern.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
    """
    if attention_mask is None:
        return None

    # Keep only the rightmost `kv_length` positions (a no-op for full attention, the window for sliding/chunked)
    window = attention_mask[:, -kv_length:]
    # Without any padding token FA2 can rely on `is_causal` (the mask is boolean here)
    return None if window.all() else window
def flex_attention_mask(
    batch_size: int,
    q_length: int,
    kv_length: int,
    q_offset: int = 0,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: torch.Tensor | None = None,
    device: torch.device | str = "cpu",
    **kwargs,
) -> BlockMask:
    """
    Build the `BlockMask` used by flex attention, i.e. the block-compressed representation of the full 4D mask
    described by `mask_function` (intersected with the 2D padding mask when one is given). A BlockMask is essential
    for performant flex attention, see: https://pytorch.org/blog/flexattention/

    Args:
        batch_size (`int`):
            The batch size of the input sequence.
        q_length (`int`):
            The size that the query states will have during the attention computation.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        q_offset (`int`, optional):
            Absolute position of the first query token. Defaults to 0.
        kv_offset (`int`, optional):
            Absolute position of the first key/value token. Defaults to 0.
        mask_function (`Callable`):
            The mask factory function describing the mask pattern. Defaults to `causal_mask_function`.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
        device (`torch.device` or `str`, optional):
            The device on which the block mask is created.

    Returns:
        `BlockMask`: the result of `create_block_mask` for the final (offset-aware) mask function.
    """
    if isinstance(q_length, torch.Tensor):
        # Backward compatibility: a `cache_position` tensor used to be passed in the slot now taken by `q_length`
        logger.warning_once(
            "`cache_position` is deprecated as an arg, and will be removed in Transformers v5.6. Please use `q_length` and "
            "`q_offset` instead, similarly to `kv_length` and `kv_offset`"
        )
        cache_position = q_length
        q_length = cache_position.shape[0]
        q_offset = cache_position[0].to(device)

    if attention_mask is not None:
        mask_width = attention_mask.shape[1]
        # Older torch (2.5.x) cannot handle sequences that are not a multiple of the default block size, so on those
        # versions the padding mask is right-padded with `0` (masked) positions.
        # NOTE(review): this always yields 1..flex_default_block_size extra positions (a whole extra block when the
        # width is already a multiple), so assuming a positive block size the `pad_len > 0` guard is always true.
        pad_len = flex_default_block_size - mask_width % flex_default_block_size
        if not _is_torch_greater_or_equal_than_2_6 and pad_len > 0:
            attention_mask = torch.nn.functional.pad(attention_mask, value=0, pad=(0, pad_len))
        padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
        mask_function = and_masks(mask_function, padding_mask_function(padding_mask))

    # The flex interface only accepts lengths (no start indices), so the offsets are folded into the mask function
    mask_function = add_offsets_to_mask_function(mask_function, q_offset, kv_offset)

    return create_block_mask(
        mask_mod=mask_function,
        B=batch_size,
        H=None,
        Q_LEN=q_length,
        KV_LEN=kv_length,
        device=device,
        _compile=_is_torch_greater_or_equal_than_2_6,
    )
class AttentionMaskInterface(GeneralInterface):
    """
    Registry mapping an attention implementation name (the value of `config._attn_implementation`) to the function
    that builds the attention mask for that implementation.
    """

    # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
    # a new instance is created (in order to locally override a given function)
    # All flash-attention versions share the same mask builder (`flash_attention_mask`).
    _global_mapping = {
        "sdpa": sdpa_mask,
        "eager": eager_mask,
        "flash_attention_2": flash_attention_mask,
        "flash_attention_3": flash_attention_mask,
        "flash_attention_4": flash_attention_mask,
        "flex_attention": flex_attention_mask,
    }


# Global AttentionMaskInterface shared by all models which do not need to overwrite any of the existing ones.
# The mask-creation functions below index it with `config._attn_implementation`.
ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
def find_packed_sequence_indices(position_ids: torch.Tensor) -> torch.Tensor | None:
    """
    Assign every query token to the sequence it belongs to when several sequences are packed together along the same
    batch row.

    A new sequence starts wherever two consecutive position ids do not increase by exactly 1. A single sequence is
    assumed never to continue from the end of one batch row into the start of the next.

    Args:
        position_ids (`torch.Tensor`):
            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.

    Returns:
        `torch.Tensor` or `None`: a 2D tensor where tokens sharing an integer belong to the same sequence. For example,
        packing 3 sequences of 2, 3 and 1 tokens into a single row gives [[0, 0, 1, 1, 1, 2]]. If every row holds a
        single sequence (and we are not tracing/compiling), `None` is returned to signal that no packing was found;
        this is equivalent to [[0, 0, 0, 0, 0, 0]] for the example above.
    """
    # Prepending `first position - 1` makes the first difference exactly 1, so the first token never opens a new
    # sequence and the indices start at 0
    prepend_value = position_ids[:, :1] - 1
    steps = torch.diff(position_ids, prepend=prepend_value, dim=-1)
    # Every step that is not +1 is a sequence boundary; the running count of boundaries is the sequence index
    sequence_indices = (steps != 1).cumsum(dim=-1)

    # Returning `None` requires data-dependent control flow, which cannot be used while tracing/compiling
    if is_tracing(sequence_indices):
        return sequence_indices
    if (sequence_indices[:, -1] == 0).all():
        return None
    return sequence_indices
@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
def _preprocess_mask_arguments(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | BlockMask | None,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None,
    layer_idx: int | None,
    encoder_hidden_states: torch.Tensor | None = None,
) -> tuple[
    bool,
    torch.Tensor | BlockMask | None,
    torch.Tensor | None,
    int | None,
    int | None,
    int | torch.Tensor | None,
    int | None,
]:
    """
    Perform some common pre-processing of the mask arguments we get from the modeling code. Mostly determine the
    query/key-value lengths and offsets, and whether we should early exit or not.

    Args:
        config (`PreTrainedConfig`):
            The model config.
        inputs_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and device.
        attention_mask (`torch.Tensor` or `BlockMask`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        position_ids (`torch.Tensor`, optional)
            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
        layer_idx (`int`, optional):
            If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value
            length and offset. Indeed, for hybrid caches, different layers may return different lengths.
        encoder_hidden_states (`torch.Tensor`, optional):
            The input embeddings of shape (batch_size, kv_length, hidden_dim). If provided (and there is neither a
            cache nor a 2D mask), it is used instead of `inputs_embeds` to infer the kv length.

    Returns:
        A 7-tuple `(early_exit, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset)`:

        early_exit (`bool`):
            Whether we should early exit mask creation, and return the mask as-is. This happens when the mask is
            already 4D, or when the attention implementation has no registered mask function. When `True`, the last
            five entries are all `None`.
        attention_mask (`torch.Tensor` or `BlockMask` or `None`):
            The attention mask to either return immediately, or to use in downstream mask creation. A 2D mask is
            moved to the device of `inputs_embeds` and cast to `torch.bool`.
        packed_sequence_mask (`torch.Tensor`, optional):
            Only computed when `position_ids` is given while both `attention_mask` and `past_key_values` are `None`.
            In case we detected packed sequence format, this is a tensor where each similar integer indicates that
            the tokens belong to the same sequence.
        q_length (`int`):
            The size that the query states will have during the attention computation (`inputs_embeds.shape[1]`).
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        q_offset (`int` or `torch.Tensor`):
            Position of the first query token: the cache's `get_seq_length()` (which may be a tensor, moved to the
            device of `inputs_embeds`), or 0 without a cache.
        kv_offset (`int`):
            An offset to indicate at which first position the key and values states will refer to.
    """
    # If the mask is already 4D, simply return as-is (it was already prepared, or it is custom)
    if isinstance(attention_mask, (torch.Tensor, BlockMask)) and len(attention_mask.shape) == 4:
        return True, attention_mask, None, None, None, None, None

    # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask!
    # Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise
    # full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11
    # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
    # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11
    if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping:
        return True, None, None, None, None, None, None

    # Move the mask to correct device, and potentially switch dtype for efficiency
    if attention_mask is not None and attention_mask.ndim == 2:
        attention_mask = attention_mask.to(device=inputs_embeds.device, dtype=torch.bool)

    q_length = inputs_embeds.shape[1]

    # If using a cache, it can give all information about mask sizes based on seen tokens
    if past_key_values is not None:
        q_offset = past_key_values.get_seq_length()
        # To avoid graph breaks, StaticLayer return a tensor instead of int -> this has no impact on the ops, but we
        # need the correct device
        q_offset = q_offset.to(inputs_embeds.device) if isinstance(q_offset, torch.Tensor) else q_offset
        kv_length, kv_offset = past_key_values.get_mask_sizes(q_length, layer_idx)
    # Otherwise, we infer based on our input
    else:
        q_offset = 0
        # 1. Rely on input directly
        if attention_mask is None:
            # For encoder-decoders, use encoder_hidden_states to infer kv_length if provided
            kv_length = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else q_length
            kv_offset = 0
        # 2. Rely on the mask instead - needed for special cases like prefix tuning in PEFT
        #
        # This is a very unique and special case where an encoder utilizes a cache and expects its length
        # to be accounted for (usually, they should never use a cache). In general, the mask should always
        # match with the input sizes nonetheless (i.e. it does not affect others).
        # Conclusion: "prefix tuning is evil"
        else:
            kv_length, kv_offset = attention_mask.shape[-1], 0

    # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
    # and we don't have past_key_values, i.e. generally a training setup)
    packed_sequence_mask = None
    if position_ids is not None and attention_mask is None and past_key_values is None:
        batch_size = inputs_embeds.shape[0]
        # The position ids are sometimes just unsqueezed, without being expanded
        if batch_size != position_ids.shape[0]:
            position_ids = position_ids.expand(batch_size, -1)
        packed_sequence_mask = find_packed_sequence_indices(position_ids)

    return False, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset
@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
def create_causal_mask(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    cache_position: torch.Tensor | None = None,  # not used anymore but kept for BC
    *,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None = None,
    or_mask_function: Callable | None = None,
    and_mask_function: Callable | None = None,
) -> torch.Tensor | BlockMask | None:
    """
    Create a standard causal mask based on the attention implementation used (stored in the config). If
    `past_key_values` has a hybrid cache structure, the mask is sized after the first "full_attention" (non-sliding)
    layer, to align with what is needed in the `modeling_xxx.py` files.

    Args:
        config (`PreTrainedConfig`):
            The model config. When its `is_causal` attribute is `False`, a bidirectional mask is built instead.
        inputs_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length, dtype and device.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        cache_position (`torch.Tensor`):
            Deprecated and unused.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        position_ids (`torch.Tensor`, optional)
            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
            Used to detect packed sequences.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the causal mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the causal mask function (by doing the intersection of both).
            This is useful to easily overlay another mask on top of the causal one, for example for image tokens
            handling.

    Returns:
        `torch.Tensor`, `BlockMask` or `None`: whatever the selected mask interface produces, or the early-exit value.
    """
    # Power feature: `is_causal=False` falls back to a bi-directional mask, so decoder-only models can also run with
    # bi-directional attention
    if not getattr(config, "is_causal", True):
        return create_bidirectional_mask(
            config,
            inputs_embeds,
            attention_mask,
            past_key_values=past_key_values,
            or_mask_function=or_mask_function,
            and_mask_function=and_mask_function,
        )

    # With a hybrid cache, size the mask after the first full-attention (non-sliding) layer
    has_full_layer = hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding
    layer_idx = past_key_values.is_sliding.index(False) if has_full_layer else 0

    early_exit, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset = (
        _preprocess_mask_arguments(config, inputs_embeds, attention_mask, past_key_values, position_ids, layer_idx)
    )
    if early_exit:
        return attention_mask

    batch_size, dtype, device = inputs_embeds.shape[0], inputs_embeds.dtype, inputs_embeds.device
    mask_factory_function = causal_mask_function
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]

    # Mask creation defaults to the non-vmap path; user-supplied or/and functions switch to vmap because we cannot
    # guarantee they are index-based as the non-vmap implementation requires.
    use_vmap = False

    # Do not allow skip if we are compiling (this is to match BC)
    # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
    is_compileable = getattr(past_key_values, "is_compileable", False)
    if _is_torch_xpu_available:
        # On XPU, only forbid the skip when decoding a single token with a compileable cache, so that the prefill
        # (first token generation) can still benefit from it
        allow_is_causal_skip = not (is_compileable and q_length == 1)
    else:
        allow_is_causal_skip = not is_compileable

    # Overlay the optional or/and patterns (`or` first, then `and`). This MUST happen before any other deviation of
    # the mask (packed sequences, padding, ...) or the resulting mask may not be correct.
    for combine, extra_function in ((or_masks, or_mask_function), (and_masks, and_mask_function)):
        if extra_function is None:
            continue
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
        mask_factory_function = combine(mask_factory_function, extra_function)
        allow_is_causal_skip = False
        use_vmap = True

    # Packed sequences: restrict attention to tokens of the same sequence
    if packed_sequence_mask is not None:
        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
        allow_is_causal_skip = False

    return mask_interface(
        batch_size=batch_size,
        q_length=q_length,
        kv_length=kv_length,
        q_offset=q_offset,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
        dtype=dtype,  # Additional kwarg for eager
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
        use_vmap=use_vmap,  # Short-circuit to non-vmap expansions for the mask
        device=device,
    )
@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
def create_bidirectional_mask(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    encoder_hidden_states: torch.Tensor | None = None,
    past_key_values: Cache | None = None,
    or_mask_function: Callable | None = None,
    and_mask_function: Callable | None = None,
) -> torch.Tensor | BlockMask | None:
    """
    Create a standard bidirectional mask based on the attention implementation used (stored in the config). If
    `past_key_values` has a hybrid cache structure, the mask sizes are taken from the first non-sliding layer of the
    cache, since this mask has no sliding window (its windowed counterpart is
    `create_bidirectional_sliding_window_mask`).

    Args:
        config (`PreTrainedConfig`):
            The model config.
        inputs_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is only used to infer metadata
            such as the batch size, query length, dtype, and device.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, kv_length).
            It can also be an already prepared 4D mask of shape (batch_size, 1, query_length, kv_length),
            in which case it is returned as-is.
        encoder_hidden_states (`torch.Tensor`, optional):
            The input embeddings of shape (batch_size, kv_length, hidden_dim). If provided, it is used instead of
            `inputs_embeds` to infer the batch size, kv length, dtype and device.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the base mask function (by doing the union of both). This is
            useful to easily overlay another mask on top, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the base mask function (by doing the intersection of both). This is
            useful to easily overlay another mask on top, for example for image tokens handling.

    Returns:
        `torch.Tensor`, `BlockMask` or `None`: whatever the selected mask interface produces, or the early-exit value.
    """
    # With a hybrid cache, size this (window-less) mask after the first full-attention layer, exactly like
    # `create_causal_mask` does. Always using layer 0 would pick a sliding layer's sizes whenever the cache starts with
    # a sliding layer (reachable e.g. via `create_causal_mask` with `is_causal=False`).
    if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding:
        layer_idx = past_key_values.is_sliding.index(False)
    else:
        layer_idx = 0

    # Packed-sequence detection is not used here (no `position_ids`), hence the ignored third output
    early_exit, attention_mask, _, q_length, kv_length, q_offset, kv_offset = _preprocess_mask_arguments(
        config, inputs_embeds, attention_mask, past_key_values, None, layer_idx, encoder_hidden_states
    )
    if early_exit:
        return attention_mask

    embeds = encoder_hidden_states if encoder_hidden_states is not None else inputs_embeds
    batch_size, dtype, device = embeds.shape[0], embeds.dtype, embeds.device
    mask_factory_function = bidirectional_mask_function
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]

    # Allow skipping the mask creation except we have additional masking operators (and/or masks)
    allow_is_bidirectional_skip = True
    # Defaulting to using non-vmap based mask creations except when detecting
    # users passing custom mask functions (as we cannot guarantee that they
    # are properly index-based as required by our implementation).
    use_vmap = False

    # Allow slight deviations from the base mask
    # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
    # padding mask, etc) as the resulting mask may otherwise not be correct!
    if or_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
        allow_is_bidirectional_skip = False
        use_vmap = True
    if and_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)
        allow_is_bidirectional_skip = False
        use_vmap = True

    # We now create the mask
    attention_mask = mask_interface(
        batch_size=batch_size,
        q_length=q_length,
        kv_length=kv_length,
        q_offset=q_offset,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        # Additional kwargs for sdpa
        allow_is_causal_skip=False,
        allow_is_bidirectional_skip=allow_is_bidirectional_skip,
        dtype=dtype,  # Additional kwarg for eager
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
        use_vmap=use_vmap,  # Short-circuit to non-vmap expansions for the mask
        device=device,
    )
    return attention_mask
@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
def create_sliding_window_causal_mask(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    cache_position: torch.Tensor | None = None,  # not used anymore but kept for BC
    *,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None = None,
    or_mask_function: Callable | None = None,
    and_mask_function: Callable | None = None,
) -> torch.Tensor | BlockMask | None:
    """
    Create a sliding window causal mask based on the attention implementation used (stored in the config). This type
    of attention pattern was mostly democratized by Mistral. If `past_key_values` has a hybrid cache structure, the
    mask is sized after the first "sliding_attention" layer, to align with what is needed in the `modeling_xxx.py`
    files.

    Args:
        config (`PreTrainedConfig`):
            The model config. It must define `sliding_window`; when its `is_causal` attribute is `False`, a
            bidirectional sliding window mask is built instead.
        inputs_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length, dtype and device.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        cache_position (`torch.Tensor`):
            Deprecated and unused.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        position_ids (`torch.Tensor`, optional)
            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
            Used to detect packed sequences.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the sliding causal mask function (by doing the union of both).
            This is useful to easily overlay another mask on top of the sliding causal one, for example for image
            tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the sliding causal mask function (by doing the intersection of
            both). This is useful to easily overlay another mask on top of the sliding causal one, for example for
            image tokens handling.

    Returns:
        `torch.Tensor`, `BlockMask` or `None`: whatever the selected mask interface produces, or the early-exit value.

    Raises:
        ValueError: if `config.sliding_window` is missing or `None`, or if or/and mask functions are used with
            torch<2.6.
    """
    # Power feature: `is_causal=False` falls back to a bi-directional sliding window mask, so decoder-only models can
    # also run with bi-directional attention
    if not getattr(config, "is_causal", True):
        return create_bidirectional_sliding_window_mask(
            config,
            inputs_embeds,
            attention_mask,
            past_key_values=past_key_values,
            or_mask_function=or_mask_function,
            and_mask_function=and_mask_function,
        )

    # With a hybrid cache, size the mask after the first sliding layer
    has_sliding_layer = hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding
    layer_idx = past_key_values.is_sliding.index(True) if has_sliding_layer else 0

    early_exit, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset = (
        _preprocess_mask_arguments(config, inputs_embeds, attention_mask, past_key_values, position_ids, layer_idx)
    )
    if early_exit:
        return attention_mask

    sliding_window = getattr(config, "sliding_window", None)
    if sliding_window is None:
        raise ValueError("Could not find a `sliding_window` argument in the config, or it is not set")

    batch_size, dtype, device = inputs_embeds.shape[0], inputs_embeds.dtype, inputs_embeds.device
    mask_factory_function = sliding_window_causal_mask_function(sliding_window)
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]

    # Mask creation defaults to the non-vmap path; user-supplied or/and functions switch to vmap because we cannot
    # guarantee they are index-based as the non-vmap implementation requires.
    use_vmap = False
    # Do not allow skip if we are compiling (this is to match BC)
    # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
    allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)

    # Overlay the optional or/and patterns (`or` first, then `and`). This MUST happen before any other deviation of
    # the mask (packed sequences, padding, ...) or the resulting mask may not be correct.
    for combine, extra_function in ((or_masks, or_mask_function), (and_masks, and_mask_function)):
        if extra_function is None:
            continue
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
        mask_factory_function = combine(mask_factory_function, extra_function)
        allow_is_causal_skip = False
        use_vmap = True

    # Packed sequences: restrict attention to tokens of the same sequence
    if packed_sequence_mask is not None:
        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
        allow_is_causal_skip = False

    return mask_interface(
        batch_size=batch_size,
        q_length=q_length,
        kv_length=kv_length,
        q_offset=q_offset,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
        local_size=sliding_window,  # Additional kwarg for sdpa
        dtype=dtype,  # Additional kwarg for eager
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
        use_vmap=use_vmap,  # Short-circuit to non-vmap expansions for the mask
        device=device,
    )
@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
def create_bidirectional_sliding_window_mask(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    past_key_values: Cache | None = None,
    or_mask_function: Callable | None = None,
    and_mask_function: Callable | None = None,
) -> torch.Tensor | BlockMask | None:
    """
    Create a standard bidirectional sliding window mask based on the attention implementation used (stored in the
    config). If `past_key_values` has a hybrid cache structure, the mask sizes are taken from the first sliding layer
    of the cache, matching `create_sliding_window_causal_mask`.

    Args:
        config (`PreTrainedConfig`):
            The model config. It must define `sliding_window`.
        inputs_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is only used to infer metadata
            such as the batch size, query length, dtype, and device.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, kv_length).
            It can also be an already prepared 4D mask of shape (batch_size, 1, query_length, kv_length),
            in which case it is returned as-is.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the base mask function (by doing the union of both). This is
            useful to easily overlay another mask on top, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the base mask function (by doing the intersection of both). This is
            useful to easily overlay another mask on top, for example for image tokens handling.

    Returns:
        `torch.Tensor`, `BlockMask` or `None`: whatever the selected mask interface produces, or the early-exit value.

    Raises:
        ValueError: if `config.sliding_window` is missing or `None`, or if or/and mask functions are used with
            torch<2.6.
    """
    # With a hybrid cache, size this sliding-window mask after the first sliding layer, exactly like
    # `create_sliding_window_causal_mask` does. Always using layer 0 would pick a full-attention layer's sizes whenever
    # the cache starts with a full layer.
    if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
        layer_idx = past_key_values.is_sliding.index(True)
    else:
        layer_idx = 0

    # Packed-sequence detection is not used here (no `position_ids`), hence the ignored third output
    early_exit, attention_mask, _, q_length, kv_length, q_offset, kv_offset = _preprocess_mask_arguments(
        config, inputs_embeds, attention_mask, past_key_values, None, layer_idx
    )
    if early_exit:
        return attention_mask

    sliding_window = getattr(config, "sliding_window", None)
    if sliding_window is None:
        raise ValueError("Could not find a `sliding_window` argument in the config, or it is not set")

    batch_size, dtype, device = inputs_embeds.shape[0], inputs_embeds.dtype, inputs_embeds.device
    mask_factory_function = sliding_window_bidirectional_mask_function(sliding_window)
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]

    # Defaulting to using non-vmap based mask creations except when detecting users passing custom mask functions (as
    # we cannot guarantee that they are properly index-based as required by our implementation).
    use_vmap = False
    # Allow skipping the mask creation except when we have additional masking operators (and/or masks)
    allow_is_bidirectional_skip = True

    # Allow slight deviations from the base mask. This must happen before any other deviation of the mask (such as
    # padding) as the resulting mask may otherwise not be correct!
    if or_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
        allow_is_bidirectional_skip = False
        use_vmap = True
    if and_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)
        allow_is_bidirectional_skip = False
        use_vmap = True

    # We now create the mask
    attention_mask = mask_interface(
        batch_size=batch_size,
        q_length=q_length,
        kv_length=kv_length,
        q_offset=q_offset,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=False,
        allow_is_bidirectional_skip=allow_is_bidirectional_skip,
        local_size=sliding_window,  # Additional kwarg for sdpa
        dtype=dtype,  # Additional kwarg for eager
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
        use_vmap=use_vmap,  # Short-circuit to non-vmap expansions for the mask
        device=device,
    )
    return attention_mask
@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
def create_chunked_causal_mask(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    cache_position: torch.Tensor | None = None,  # not used anymore but kept for BC
    *,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None = None,
    or_mask_function: Callable | None = None,
    and_mask_function: Callable | None = None,
) -> torch.Tensor | BlockMask | None:
    """
    Create a chunked attention causal mask based on the attention implementation used (stored in the config). This type
    of attention pattern was mostly democratized by Llama4. If `past_key_values` has a hybrid cache structure, this
    function will return the mask corresponding to one of the "chunked_attention" layers (to align to what is needed in the
    `modeling_xxx.py` files).

    Args:
        config (`PreTrainedConfig`):
            The model config. It must define `attention_chunk_size`.
        inputs_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length, dtype and device.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        cache_position (`torch.Tensor`, optional):
            Deprecated and unused; only kept for backward compatibility.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        position_ids (`torch.Tensor`, optional):
            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the chunked causal mask function (by doing the intersection of both). This is
            useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.

    Returns:
        `torch.Tensor | BlockMask | None`: whatever the attention-implementation-specific mask interface returns, or
        the (possibly pre-processed) `attention_mask` when `_preprocess_mask_arguments` requests an early exit.

    Raises:
        ValueError: if `config.attention_chunk_size` is missing or `None`; if flash attention is requested while
            `kv_length + kv_offset` exceeds the chunk size; or if `or_mask_function` / `and_mask_function` is passed
            with torch < 2.6.
    """
    # If we have a hybrid cache structure, build the mask for the first layer flagged in `is_sliding`.
    # NOTE(review): chunked layers are presumably flagged through the same `is_sliding` attribute as sliding ones,
    # which is why it is used here to locate a chunked layer -- confirm against the Cache implementation.
    if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
        layer_idx = past_key_values.is_sliding.index(True)
    else:
        layer_idx = 0

    # Shared preprocessing: derives query/key-value lengths and offsets, detects packed sequences, and tells us
    # whether the mask can be returned as-is (e.g. an already prepared 4D mask)
    early_exit, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset = (
        _preprocess_mask_arguments(config, inputs_embeds, attention_mask, past_key_values, position_ids, layer_idx)
    )
    if early_exit:
        return attention_mask

    chunk_size = getattr(config, "attention_chunk_size", None)
    if chunk_size is None:
        raise ValueError("Could not find an `attention_chunk_size` argument in the config, or it is not set")

    # Raise if using chunked attention on context too large with FA: flash attention cannot express the chunked
    # pattern, which only matters once the attended span is longer than one chunk
    if is_flash_attention_requested(config) and kv_length + kv_offset > chunk_size:
        raise ValueError(
            "Flash attention cannot handle chunked attention, and the key-value length is larger than the chunk size so the "
            "chunked pattern cannot be respected. You should use another `attn_implementation` when instantiating the model"
        )

    batch_size, dtype, device = inputs_embeds.shape[0], inputs_embeds.dtype, inputs_embeds.device
    # For chunked attention and batched inputs, we need to take the number of left padding tokens into account
    # to start the chunk from the actual start of the sequence for the padded sequence
    if attention_mask is not None:
        # Only count the left padding tokens, not all of them: the running sum is still 0 only before the first
        # attended token (assuming a 0/1 padding mask here -- a 4D mask would have exited early above)
        left_padding_tokens = (attention_mask.cumsum(dim=-1) == torch.zeros_like(attention_mask)).sum(dim=-1)
    else:
        left_padding_tokens = torch.zeros(batch_size, device=device, dtype=int)
    mask_factory_function = chunked_causal_mask_function(chunk_size, left_padding_tokens)
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]

    # Defaulting to using non-vmap based mask creations except when detecting
    # users passing custom mask functions (as we cannot guarantee that they
    # are properly index-based as required by our implementation).
    use_vmap = False
    # Do not allow skip if we are compiling (this is to match BC)
    # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
    allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)

    # Allow slight deviations from causal mask
    # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
    # padding mask, etc) as the resulting mask may otherwise not be correct!
    if or_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
        # A custom pattern is no longer plain causal, so the "is_causal" shortcut must not be taken
        allow_is_causal_skip = False
        use_vmap = True
    if and_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)
        allow_is_causal_skip = False
        use_vmap = True

    # If we detected packing format, additionally restrict attention to tokens of the same packed sequence
    if packed_sequence_mask is not None:
        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
        allow_is_causal_skip = False

    # We now create the mask
    causal_mask = mask_interface(
        batch_size=batch_size,
        q_length=q_length,
        kv_length=kv_length,
        q_offset=q_offset,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
        local_size=chunk_size,  # Additional kwarg for sdpa
        dtype=dtype,  # Additional kwarg for eager
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
        use_vmap=use_vmap,  # Short-circuit to non-vmap expansions for the mask
        device=device,
    )
    return causal_mask
# Maps each value that may appear in `config.layer_types` to the function building the mask for that kind of layer.
# `create_masks_for_generate` looks layer types up here, so a layer type missing from this mapping raises a KeyError there.
LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING = {
    "full_attention": create_causal_mask,
    "sliding_attention": create_sliding_window_causal_mask,
    "chunked_attention": create_chunked_causal_mask,
}
@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
def create_masks_for_generate(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None = None,
    or_mask_function: Callable | None = None,
    and_mask_function: Callable | None = None,
    **kwargs,
):
    """
    This function mimics how we create the masks in the `modeling_xxx.py` files, and is used in places like `generate`
    in order to easily create the masks in advance, when we compile the forwards with Static caches.

    Args:
        config (`PreTrainedConfig`):
            The model config. For composite models, the text sub-config (`config.get_text_config()`) is used.
        inputs_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and dtype.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        position_ids (`torch.Tensor`, optional):
            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the other mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the other mask function (by doing the intersection of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
        **kwargs:
            Accepted so callers may pass extra arguments; they are ignored.

    Returns:
        If the text config defines a non-`None` `layer_types`, a `dict` mapping each distinct layer type to its mask.
        Otherwise a single mask: a sliding-window mask if `sliding_window` is set, else a chunked mask if
        `attention_chunk_size` is set, else a standard causal mask.

    Raises:
        KeyError: if `layer_types` contains a type that has no entry in `LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING`.
    """
    # The attributes reside in the text config for composite models
    effective_config = config.get_text_config()

    # Arguments shared by every mask-creation function
    mask_kwargs = {
        "config": effective_config,
        "inputs_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "position_ids": position_ids,
        "or_mask_function": or_mask_function,
        "and_mask_function": and_mask_function,
    }

    # If layers have declared types, we need one mask per distinct type. `getattr(..., None)` (rather than `hasattr`)
    # lets a config that exposes `layer_types = None` fall back to the single-mask logic instead of failing on `set(None)`.
    layer_types = getattr(effective_config, "layer_types", None)
    if layer_types is not None:
        return {
            layer_pattern: LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING[layer_pattern](**mask_kwargs)
            for layer_pattern in set(layer_types)
        }
    # In this case, all layers are sliding
    if getattr(effective_config, "sliding_window", None) is not None:
        return create_sliding_window_causal_mask(**mask_kwargs)
    # In this case, all layers are chunked
    if getattr(effective_config, "attention_chunk_size", None) is not None:
        return create_chunked_causal_mask(**mask_kwargs)
    # All layers use standard causal attention
    return create_causal_mask(**mask_kwargs)
# Below are utilities to pretty-print the different masks

# ANSI terminal color escape codes (used to build the colored squares further down)
GREEN = "\033[92m"
YELLOW = "\033[93m"
RESET = "\033[0m"
# Module-level default glyphs. Note that `tensor_to_mask_visual` does not use these: it takes its glyphs from
# `get_style`, which returns its own style-dependent set.
BLACK_SQUARE = "■"
WHITE_SQUARE = "⬚"
GREY_SQUARE = "∙"
LOW_TRIANGLE = "⬕"
UPPER_TRIANGLE = "⬔"
def get_style(style):
    """
    Return the glyphs used to draw a mask, as `(black_square, white_square, low_triangle, upper_triangle)`.

    Args:
        style (`str` or `None`):
            `"majong"` selects mahjong-tile glyphs; any other value (including `None`) selects Unicode block elements.

    Returns:
        `tuple[str, str, str, str]`: the glyph for a fully attended cell, for a fully masked cell, and for the two
        kinds of partially attended cells.
    """
    if style == "majong":
        black_square = "🀙"  # "on" / attended cell
        white_square = "🀆"  # "off" / masked cell
        # This style uses the same tile for both kinds of partial cells
        low_triangle = upper_triangle = "🀛"
    else:
        black_square = "█"  # Full block: "on" / attended cell
        white_square = "░"  # Light shade: "off" / masked cell
        low_triangle = "▙"  # Quadrant block, open at the upper right
        upper_triangle = "▜"  # Quadrant block, open at the lower left
    return black_square, white_square, low_triangle, upper_triangle
# Colored variants of the module-level `BLACK_SQUARE` glyph, wrapped in ANSI color codes and reset afterwards.
# They are built once at import time and do not depend on `get_style`.
YELLOW_SQUARE = f"{YELLOW}{BLACK_SQUARE}{RESET}"
GREEN_SQUARE = f"{GREEN}{BLACK_SQUARE}{RESET}"
def tensor_to_mask_visual(original_tensor: torch.Tensor, grid_size=(20, 40), style="majong") -> str:
    """
    Render a 2D mask as a grid of characters, one text line per row.

    Cells equal to 1 are drawn with the "on" glyph and cells equal to 0 with the "off" glyph (glyphs come from
    `get_style(style)`). Any other value -- e.g. a fractional average produced when a large mask is downsampled -- is
    drawn as a triangle glyph chosen from its left neighbour (or its right neighbour in the first column).

    Args:
        original_tensor (`torch.Tensor`):
            A 2D tensor of shape (rows, cols). Boolean, integer and floating point masks are accepted.
        grid_size (`tuple[int, int]`, optional, defaults to `(20, 40)`):
            The `(max_rows, max_cols)` of the rendered grid. Tensors that do not fit are downsampled with average
            pooling, approximately keeping their aspect ratio.
        style (`str`, optional, defaults to `"majong"`):
            The glyph style, forwarded to `get_style`.

    Returns:
        `str`: The rendered grid with rows separated by newlines (an empty string for an empty tensor).
    """
    BLACK_SQUARE, WHITE_SQUARE, LOW_TRIANGLE, UPPER_TRIANGLE = get_style(style)
    h, w = original_tensor.shape
    max_h, max_w = grid_size
    if h == 0 or w == 0:
        # Nothing to draw (and the aspect ratio below would divide by zero)
        return ""

    if not (h < max_h and w < max_w):
        # Preserve aspect ratio within max grid size. Width is weighted by 2, presumably because terminal character
        # cells are roughly twice as tall as they are wide.
        aspect_ratio = 2 * w / h
        if aspect_ratio > 1:
            w = max_w
            h = min(max_h, max(1, round(max_w / aspect_ratio)))
        else:
            h = max_h
            # Clamp so that a custom `grid_size` cannot produce a grid wider than allowed
            w = min(max_w, max(1, round(max_h * aspect_ratio)))
        # Rescale by average pooling. Pooling is only implemented for floating point dtypes, so boolean/integer masks
        # are converted first; floating point masks are pooled in their own dtype.
        tensor = original_tensor if original_tensor.is_floating_point() else original_tensor.float()
        tensor = F.adaptive_avg_pool2d(tensor[None, None], output_size=(h, w))[0, 0]
    else:
        tensor = original_tensor

    # Convert once to nested Python lists instead of indexing the tensor cell by cell
    rows = tensor.tolist()
    lines = []
    for row_values in rows:
        chars = []
        for j, value in enumerate(row_values):
            if value == 1:
                chars.append(BLACK_SQUARE)
            elif value == 0:
                chars.append(WHITE_SQUARE)
            elif j > 0:
                # Partial cell: orient the triangle according to the left neighbour
                left = row_values[j - 1]
                if left == 1:
                    chars.append(LOW_TRIANGLE)
                elif left == 0:
                    chars.append(UPPER_TRIANGLE)
                else:
                    chars.append(WHITE_SQUARE)
            elif j + 1 < len(row_values) and row_values[j + 1] == 1:
                # Partial cell in the first column: look at the right neighbour instead, when there is one
                chars.append(UPPER_TRIANGLE)
            else:
                chars.append(LOW_TRIANGLE)
        lines.append("".join(chars))
    return "\n".join(lines)
- class AttentionMask(torch.Tensor):
- def __new__(cls, data, style=None):
- # Create a new instance of AttentionMask as a Tensor
- cls.style = style
- return torch.Tensor._make_subclass(cls, data, require_grad=False)
- def __init__(self, data):
- # You can initialize any additional metadata here if needed
- pass
- def to_string(self, grid_size=(20, 40), limit=4):
- """Returns a string representation of the block mask."""
- dense_mask = self
- *batch_dims, num_rows, num_cols = dense_mask.shape
- total_vis = []
- for idx, batch_idx in enumerate(itertools.product(*[range(i) for i in batch_dims])):
- if idx == limit:
- total_vis.append("...")
- total_vis.append("To print out more, set AttentionMask.to_string(limit=N)")
- total_vis.append("You can also index (AttentionMask[batch, head]) to choose a specific batch or head")
- break
- block_vis = tensor_to_mask_visual(dense_mask[batch_idx], grid_size=grid_size, style=self.style)
- total_vis.append(block_vis)
- total_vis.append(f"torch.Tensor(shape={tuple(self.shape)}, dtype={self.dtype})")
- return "\n".join(total_vis)
- def __repr__(self):
- return self.to_string()
- def __str__(self):
- return self.to_string()
- @classmethod
- def from_tensor(cls, tensor: torch.Tensor, style: str | None = None) -> "AttentionMask":
- res = cls(tensor)
- res.style = style
- return res
|