modeling_attn_mask_utils.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. # Copyright 2023 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. IMPORTANT NOTICE: Every class and function in this file is deprecated in favor of using the much more general
  16. `masking_utils.py` primitives. New code should not rely on it, it is only kept for backward compatibility for now,
  17. and will be removed in the future.
  18. """
  19. import warnings
  20. from dataclasses import dataclass
  21. from typing import Union
  22. import torch
  23. from .utils.import_utils import is_torchdynamo_compiling, is_tracing
  # Shared text of the FutureWarning emitted by the legacy mask helpers defined in this module.
  DEPRECATION_MESSAGE = (
      "The attention mask API under `transformers.modeling_attn_mask_utils` "
      "(`AttentionMaskConverter`) is deprecated and will be removed in Transformers v5.10. "
      "Please use the new API in `transformers.masking_utils`."
  )
  @dataclass
  class AttentionMaskConverter:
      """
      Deprecated utility class for building additive 4D attention masks. It allows one to:

          - Create a causal 4d mask
          - Create a causal 4d mask with a sliding window
          - Convert a 2d attention mask (batch_size, key_value_length) to a 4d attention mask (batch_size, 1,
            query_length, key_value_length) that can be added to the attention scores

      Attended positions hold `0` and masked positions hold `torch.finfo(dtype).min`. Instantiating the class emits a
      `FutureWarning` (see `DEPRECATION_MESSAGE`).

      Examples:

      ```python
      >>> import torch
      >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter

      >>> converter = AttentionMaskConverter(True)
      >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
      tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
                [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
                [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
                [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00, -3.4028e+38],
                [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00]]]])
      ```

      Parameters:
          is_causal (`bool`):
              Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
          sliding_window (`int`, *optional*):
              Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.

      Raises:
          ValueError: If `sliding_window` is given but is not strictly positive.
      """

      is_causal: bool
      # `None` means "no sliding window"; this matches the default accepted by `__init__`.
      sliding_window: int | None

      def __init__(self, is_causal: bool, sliding_window: int | None = None):
          warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
          self.is_causal = is_causal
          self.sliding_window = sliding_window

          if self.sliding_window is not None and self.sliding_window <= 0:
              raise ValueError(
                  f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
              )

      def to_causal_4d(
          self,
          batch_size: int,
          query_length: int,
          key_value_length: int,
          dtype: torch.dtype,
          device: Union[torch.device, "str"] = "cpu",
      ) -> torch.Tensor | None:
          """
          Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
          bias to upper right hand triangular matrix (causal mask).

          Returns:
              The 4D causal mask, or `None` when `query_length == 1` and no sliding window is configured (no mask is
              built in that case).

          Raises:
              ValueError: If the converter was created with `is_causal=False`.
          """
          if not self.is_causal:
              raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")

          # Keys beyond the current queries are treated as already-processed (past) positions.
          input_shape = (batch_size, query_length)
          past_key_values_length = key_value_length - query_length

          # create causal mask
          # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
          causal_4d_mask = None
          if input_shape[-1] > 1 or self.sliding_window is not None:
              causal_4d_mask = self._make_causal_mask(
                  input_shape,
                  dtype,
                  device=device,
                  past_key_values_length=past_key_values_length,
                  sliding_window=self.sliding_window,
              )

          return causal_4d_mask

      def to_4d(
          self,
          attention_mask_2d: torch.Tensor,
          query_length: int,
          dtype: torch.dtype,
          key_value_length: int | None = None,
      ) -> torch.Tensor:
          """
          Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
          key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
          causal, a causal mask will be added.

          Raises:
              ValueError: If a causal mask has to be built but `key_value_length` was not provided.
              NotImplementedError: If a sliding window is configured on a non-causal converter.
          """
          input_shape = (attention_mask_2d.shape[0], query_length)

          # create causal mask
          # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
          causal_4d_mask = None
          if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
              if key_value_length is None:
                  raise ValueError(
                      "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
                  )

              past_key_values_length = key_value_length - query_length
              causal_4d_mask = self._make_causal_mask(
                  input_shape,
                  dtype,
                  device=attention_mask_2d.device,
                  past_key_values_length=past_key_values_length,
                  sliding_window=self.sliding_window,
              )
          elif self.sliding_window is not None:
              raise NotImplementedError("Sliding window is currently only implemented for causal masking")

          # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
          expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
              attention_mask_2d.device
          )

          if causal_4d_mask is not None:
              # Overwrite padded positions in the causal mask instead of summing the two masks:
              # expanded_attn_mask + causal_4d_mask can cause some overflow
              expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)

          expanded_4d_mask = expanded_attn_mask

          return expanded_4d_mask

      @staticmethod
      def _make_causal_mask(
          input_ids_shape: torch.Size,
          dtype: torch.dtype,
          device: torch.device,
          past_key_values_length: int = 0,
          sliding_window: int | None = None,
      ):
          """
          Make the causal mask used for uni-directional (causal) self-attention.

          Args:
              input_ids_shape: `(batch_size, target_length)` of the current queries.
              dtype: Floating point dtype of the returned mask.
              device: Device on which the mask is created.
              past_key_values_length: Number of past key positions; they are prepended as fully visible columns.
              sliding_window: If set, keys lying more than `sliding_window` positions behind a query are masked too.

          Returns:
              A tensor expanded to `(batch_size, 1, target_length, target_length + past_key_values_length)` holding
              `0` for visible positions and `torch.finfo(dtype).min` for masked ones.
          """
          warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
          bsz, tgt_len = input_ids_shape
          # Build the mask directly in `dtype`: going through the default float dtype first cannot represent
          # `torch.finfo(torch.float64).min`.
          mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype, device=device)
          mask_cond = torch.arange(mask.size(-1), device=device)
          # Zero out (i.e. make visible) every position whose key index is <= the query index.
          mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

          if past_key_values_length > 0:
              mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)

          # add lower triangular sliding window mask if necessary
          if sliding_window is not None:
              diagonal = past_key_values_length - sliding_window - 1

              context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
              # Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy
              # See https://github.com/pytorch/pytorch/issues/127571
              # The copy is kept as a defensive measure while compiling.
              if is_torchdynamo_compiling():
                  mask = mask.clone()
              mask.masked_fill_(context_mask, torch.finfo(dtype).min)

          return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

      @staticmethod
      def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
          """
          Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

          The mask is inverted in the process: entries equal to 1 in `mask` become `0` and every other entry becomes
          `torch.finfo(dtype).min`. `tgt_len` defaults to the source length of `mask`.
          """
          warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
          bsz, src_len = mask.size()
          tgt_len = tgt_len if tgt_len is not None else src_len

          expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

          inverted_mask = torch.tensor(1.0, dtype=dtype) - expanded_mask

          return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

      @staticmethod
      def _unmask_unattended(
          expanded_mask: torch.FloatTensor,
          min_dtype: float,
      ):
          # fmt: off
          """
          Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
          using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
          Details: https://github.com/pytorch/pytorch/issues/110213

          `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
          `min_dtype` is the value that marks a masked position in `expanded_mask`.

          The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.

          The example below is written with 1 for "attended" and 0 for "masked" for readability; in the float mask
          handled here, masked entries equal `min_dtype` and fully masked rows are reset to 0.

          For example, if `expanded_mask` is (e.g. here left-padding case)
          ```
          [[[[0, 0, 0],
             [0, 0, 0],
             [0, 0, 1]]],
           [[[1, 0, 0],
             [1, 1, 0],
             [1, 1, 1]]],
           [[[0, 0, 0],
             [0, 1, 0],
             [0, 1, 1]]]]
          ```
          then the modified `expanded_mask` will be
          ```
          [[[[1, 1, 1],   <-- modified
             [1, 1, 1],   <-- modified
             [0, 0, 1]]],
           [[[1, 0, 0],
             [1, 1, 0],
             [1, 1, 1]]],
           [[[1, 1, 1],   <-- modified
             [0, 1, 0],
             [0, 1, 1]]]]
          ```

          Raises:
              ValueError: If `expanded_mask` is a boolean tensor.
          """
          warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
          # fmt: on
          if expanded_mask.dtype == torch.bool:
              raise ValueError(
                  "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
              )

          # Rows made only of `min_dtype` are multiplied by False (0); all other rows are multiplied by True (1).
          return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))

      @staticmethod
      def _ignore_causal_mask_sdpa(
          attention_mask: torch.Tensor | None,
          inputs_embeds: torch.Tensor,
          past_key_values_length: int,
          sliding_window: int | None = None,
          is_training: bool = False,
      ) -> bool:
          """
          Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
          ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.

          In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
          `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
          allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
          passed).

          The query length is read from `inputs_embeds.shape[1]`. A 4D `attention_mask` is never ignored.
          """
          warnings.warn(DEPRECATION_MESSAGE, FutureWarning)

          query_length = inputs_embeds.shape[1]
          key_value_length = query_length + past_key_values_length

          is_tracing_ = is_tracing(inputs_embeds)

          ignore_causal_mask = False

          if attention_mask is None:
              # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
              # shape, thus SDPA's `is_causal` argument is rightfully updated
              # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
              # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
              # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
              # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
              # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
              #
              # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
              # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
              if (
                  (is_training or not is_tracing_)
                  and (query_length == 1 or key_value_length == query_length)
                  and (sliding_window is None or key_value_length < sliding_window)
              ):
                  ignore_causal_mask = True
          elif sliding_window is None or key_value_length < sliding_window:
              if len(attention_mask.shape) == 4:
                  return False
              elif not is_tracing_ and torch.all(attention_mask == 1):
                  if query_length == 1 or key_value_length == query_length:
                      # For query_length == 1, causal attention and bi-directional attention are the same.
                      ignore_causal_mask = True

                  # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
                  # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
                  # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
                  # Reference: https://github.com/pytorch/pytorch/issues/108108
                  # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.

          return ignore_causal_mask
  def _prepare_4d_causal_attention_mask(
      attention_mask: torch.Tensor | None,
      input_shape: torch.Size | tuple | list,
      inputs_embeds: torch.Tensor,
      past_key_values_length: int,
      sliding_window: int | None = None,
  ):
      """
      Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
      `(batch_size, key_value_length)`

      A 4D `attention_mask` of the expected shape is not combined with a causal mask: it is only inverted, i.e. entries
      equal to 1 become `0` and all other entries become `torch.finfo(inputs_embeds.dtype).min`. When `attention_mask`
      is `None` (or is neither 2D nor 4D), a plain causal mask is built instead.

      Args:
          attention_mask (`torch.Tensor` or `None`):
              A 2D attention mask of shape `(batch_size, key_value_length)`
          input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
              The input shape should be a tuple that defines `(batch_size, query_length)`.
          inputs_embeds (`torch.Tensor`):
              The embedded inputs as a torch Tensor. Only its dtype and device are used.
          past_key_values_length (`int`):
              The length of the key value cache.
          sliding_window (`int`, *optional*):
              If the model uses windowed attention, a sliding window should be passed.

      Returns:
          `torch.Tensor` or `None`: The 4D mask. `None` can be returned when no `attention_mask` is given, the query
          length is 1 and no sliding window is used.

      Raises:
          ValueError: If a 4D `attention_mask` does not have shape `(batch_size, 1, query_length, key_value_length)`,
              or if `sliding_window` is not strictly positive.
      """
      attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

      key_value_length = input_shape[-1] + past_key_values_length

      # 4d mask is passed through the layers
      if attention_mask is not None and attention_mask.dim() == 2:
          attention_mask = attn_mask_converter.to_4d(
              attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
          )
      elif attention_mask is not None and attention_mask.dim() == 4:
          expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
          if tuple(attention_mask.shape) != expected_shape:
              raise ValueError(
                  f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
              )
          # if the 4D mask has correct shape - invert it and fill with negative infinity
          # Cast first: `1.0 - mask` is rejected for bool masks and, for integer masks, would produce a mask in
          # the default float dtype instead of the dtype of `inputs_embeds`.
          inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
          attention_mask = inverted_mask.masked_fill(
              inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
          )
      else:
          attention_mask = attn_mask_converter.to_causal_4d(
              input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
          )

      return attention_mask
  # Adapted from _prepare_4d_causal_attention_mask
  def _prepare_4d_causal_attention_mask_for_sdpa(
      attention_mask: torch.Tensor | None,
      input_shape: torch.Size | tuple | list,
      inputs_embeds: torch.Tensor,
      past_key_values_length: int,
      sliding_window: int | None = None,
  ):
      """
      Build the `attn_mask` argument expected by `torch.nn.functional.scaled_dot_product_attention`.

      When `AttentionMaskConverter._ignore_causal_mask_sdpa` reports that the mask can be dropped, `None` is returned so
      that the caller can rely on SDPA's `is_causal` argument, which allows dispatching to the flash attention kernel
      (not usable with a custom `attn_mask`). Otherwise a 4D mask is produced: a plain causal mask when
      `attention_mask` is `None`, the given mask itself when it is already 4D, or the 2D mask expanded to 4D.
      """
      mask_builder = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

      query_len = input_shape[-1]
      total_kv_len = query_len + past_key_values_length

      # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow
      # `is_causal=attention_mask is None and q_len > 1` used as an SDPA argument. We keep compatibility with these
      # tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
      # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
      tracing_active = is_tracing(inputs_embeds)

      can_skip_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
          attention_mask=attention_mask,
          inputs_embeds=inputs_embeds,
          past_key_values_length=past_key_values_length,
          sliding_window=sliding_window,
      )
      if can_skip_mask:
          return None

      if attention_mask is None:
          return mask_builder.to_causal_4d(
              input_shape[0], query_len, total_kv_len, dtype=inputs_embeds.dtype, device=inputs_embeds.device
          )

      if attention_mask.dim() == 4:
          mask_4d = attention_mask
      else:
          mask_4d = mask_builder.to_4d(
              attention_mask,
              query_len,
              dtype=inputs_embeds.dtype,
              key_value_length=total_kv_len,
          )

      # Fully masked rows (e.g. the first rows with left padding) must attend to every token, as required by the
      # memory-efficient path of F.scaled_dot_product_attention.
      # Details: https://github.com/pytorch/pytorch/issues/110213
      if not tracing_active and mask_4d.device.type in ["cuda", "xpu"]:
          lowest = torch.finfo(inputs_embeds.dtype).min
          mask_4d = AttentionMaskConverter._unmask_unattended(mask_4d, min_dtype=lowest)

      return mask_4d
  def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
      """
      Expand a 2D mask of shape `(batch_size, key_value_length)` into a non-causal 4D mask of shape
      `(batch_size, 1, query_length, key_value_length)`.

      This is a thin wrapper that delegates to `AttentionMaskConverter._expand_mask`.

      Args:
          mask (`torch.Tensor`):
              A 2D attention mask of shape `(batch_size, key_value_length)`
          dtype (`torch.dtype`):
              The torch dtype the created mask shall have.
          tgt_len (`int`, *optional*):
              The target length or query length the created mask shall have.
      """
      expanded = AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
      return expanded
  def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
      """
      Expand a 2D mask of shape `(batch_size, key_value_length)` into a non-causal 4D mask of shape
      `(batch_size, 1, query_length, key_value_length)`, or return `None` when no mask is needed.

      `None` is returned when the call is not being traced and every entry of `mask` equals 1. A `FutureWarning` is
      emitted on every call.

      Args:
          mask (`torch.Tensor`):
              A 2D attention mask of shape `(batch_size, key_value_length)`
          dtype (`torch.dtype`):
              The torch dtype the created mask shall have.
          tgt_len (`int`, *optional*):
              The target length or query length the created mask shall have. Defaults to `key_value_length`.
      """
      warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
      _batch, source_len = mask.shape
      if tgt_len is None:
          tgt_len = source_len

      # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent
      # controlflows, so the "all ones" shortcut is only taken outside of tracing.
      if not is_tracing(mask) and torch.all(mask == 1):
          return None

      return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
  def _create_4d_causal_attention_mask(
      input_shape: torch.Size | tuple | list,
      dtype: torch.dtype,
      device: torch.device,
      past_key_values_length: int = 0,
      sliding_window: int | None = None,
  ) -> torch.Tensor | None:
      """
      Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`

      Args:
          input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
              The input shape should be a tuple that defines `(batch_size, query_length)`.
          dtype (`torch.dtype`):
              The torch dtype the created mask shall have.
          device (`torch.device`):
              The torch device the created mask shall have.
          past_key_values_length (`int`, *optional*, defaults to 0):
              Number of cached positions; it is added to the query length to obtain `key_value_length`.
          sliding_window (`int`, *optional*):
              If the model uses windowed attention, a sliding window should be passed.

      Returns:
          `torch.Tensor` or `None`: Whatever `AttentionMaskConverter.to_causal_4d` returns for these arguments.
      """
      batch_size, query_length = input_shape[0], input_shape[-1]
      causal_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
      total_length = past_key_values_length + query_length

      return causal_converter.to_causal_4d(batch_size, query_length, total_length, dtype=dtype, device=device)
  422. return attention_mask