modeling_blenderbot.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021
  1. # Copyright 2021 The Facebook, Inc. and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Blenderbot model."""
  15. import math
  16. from collections.abc import Callable
  17. import torch
  18. from torch import nn
  19. from torch.nn import CrossEntropyLoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  25. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import (
  28. BaseModelOutput,
  29. BaseModelOutputWithPastAndCrossAttentions,
  30. CausalLMOutputWithCrossAttentions,
  31. Seq2SeqLMOutput,
  32. Seq2SeqModelOutput,
  33. )
  34. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  35. from ...processing_utils import Unpack
  36. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
  37. from ...utils.generic import merge_with_config_defaults
  38. from ...utils.output_capturing import OutputRecorder, capture_outputs
  39. from .configuration_blenderbot import BlenderbotConfig
# Module-level logger; used below for the one-time `layer_idx` warning in `BlenderbotAttention`.
logger = logging.get_logger(__name__)
  41. # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
  42. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  43. """
  44. Shift input ids one token to the right.
  45. """
  46. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  47. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  48. shifted_input_ids[:, 0] = decoder_start_token_id
  49. if pad_token_id is None:
  50. raise ValueError("self.model.config.pad_token_id has to be defined.")
  51. # replace possible -100 values in labels by `pad_token_id`
  52. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  53. return shifted_input_ids
  54. class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
  55. """
  56. This module learns positional embeddings up to a fixed maximum size.
  57. """
  58. def __init__(self, num_embeddings: int, embedding_dim: int):
  59. super().__init__(num_embeddings, embedding_dim)
  60. def forward(
  61. self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None
  62. ):
  63. """`input_ids_shape` is expected to be [bsz x seqlen]."""
  64. if position_ids is None:
  65. bsz, seq_len = input_ids_shape[:2]
  66. position_ids = torch.arange(
  67. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  68. )
  69. return super().forward(position_ids)
  70. # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->Blenderbot
  71. class BlenderbotScaledWordEmbedding(nn.Embedding):
  72. """
  73. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  74. """
  75. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float | None = 1.0):
  76. super().__init__(num_embeddings, embedding_dim, padding_idx)
  77. self.embed_scale = embed_scale
  78. def forward(self, input_ids: torch.Tensor):
  79. return super().forward(input_ids) * self.embed_scale
  80. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  81. def eager_attention_forward(
  82. module: nn.Module,
  83. query: torch.Tensor,
  84. key: torch.Tensor,
  85. value: torch.Tensor,
  86. attention_mask: torch.Tensor | None,
  87. scaling: float | None = None,
  88. dropout: float = 0.0,
  89. **kwargs: Unpack[TransformersKwargs],
  90. ):
  91. if scaling is None:
  92. scaling = query.size(-1) ** -0.5
  93. # Take the dot product between "query" and "key" to get the raw attention scores.
  94. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  95. if attention_mask is not None:
  96. attn_weights = attn_weights + attention_mask
  97. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  98. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  99. attn_output = torch.matmul(attn_weights, value)
  100. attn_output = attn_output.transpose(1, 2).contiguous()
  101. return attn_output, attn_weights
  102. # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Blenderbot
  103. class BlenderbotAttention(nn.Module):
  104. """Multi-headed attention from 'Attention Is All You Need' paper"""
  105. def __init__(
  106. self,
  107. embed_dim: int,
  108. num_heads: int,
  109. dropout: float = 0.0,
  110. is_decoder: bool = False,
  111. bias: bool = True,
  112. is_causal: bool = False,
  113. config: BlenderbotConfig | None = None,
  114. layer_idx: int | None = None,
  115. ):
  116. super().__init__()
  117. self.embed_dim = embed_dim
  118. self.num_heads = num_heads
  119. self.dropout = dropout
  120. self.head_dim = embed_dim // num_heads
  121. self.config = config
  122. if (self.head_dim * num_heads) != self.embed_dim:
  123. raise ValueError(
  124. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  125. f" and `num_heads`: {num_heads})."
  126. )
  127. self.scaling = self.head_dim**-0.5
  128. self.is_decoder = is_decoder
  129. self.is_causal = is_causal
  130. self.layer_idx = layer_idx
  131. if layer_idx is None and self.is_decoder:
  132. logger.warning_once(
  133. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  134. "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  135. "when creating this class."
  136. )
  137. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  138. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  139. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  140. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  141. def forward(
  142. self,
  143. hidden_states: torch.Tensor,
  144. key_value_states: torch.Tensor | None = None,
  145. past_key_values: Cache | None = None,
  146. attention_mask: torch.Tensor | None = None,
  147. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  148. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  149. **kwargs: Unpack[FlashAttentionKwargs],
  150. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  151. """Input shape: Batch x Time x Channel"""
  152. # if key_value_states are provided this layer is used as a cross-attention layer
  153. # for the decoder
  154. is_cross_attention = key_value_states is not None
  155. # determine input shapes
  156. input_shape = hidden_states.shape[:-1]
  157. hidden_shape = (*input_shape, -1, self.head_dim)
  158. # get query proj
  159. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  160. is_updated = False
  161. if past_key_values is not None:
  162. if isinstance(past_key_values, EncoderDecoderCache):
  163. is_updated = past_key_values.is_updated.get(self.layer_idx)
  164. if is_cross_attention:
  165. # after the first generated id, we can subsequently re-use all key/value_states from cache
  166. curr_past_key_values = past_key_values.cross_attention_cache
  167. else:
  168. curr_past_key_values = past_key_values.self_attention_cache
  169. else:
  170. curr_past_key_values = past_key_values
  171. current_states = key_value_states if is_cross_attention else hidden_states
  172. if is_cross_attention and past_key_values is not None and is_updated:
  173. # reuse k,v, cross_attentions
  174. key_states = curr_past_key_values.layers[self.layer_idx].keys
  175. value_states = curr_past_key_values.layers[self.layer_idx].values
  176. else:
  177. key_states = self.k_proj(current_states)
  178. value_states = self.v_proj(current_states)
  179. kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
  180. key_states = key_states.view(kv_shape).transpose(1, 2)
  181. value_states = value_states.view(kv_shape).transpose(1, 2)
  182. if past_key_values is not None:
  183. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  184. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  185. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  186. past_key_values.is_updated[self.layer_idx] = True
  187. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  188. self.config._attn_implementation, eager_attention_forward
  189. )
  190. attn_output, attn_weights = attention_interface(
  191. self,
  192. query_states,
  193. key_states,
  194. value_states,
  195. attention_mask,
  196. dropout=0.0 if not self.training else self.dropout,
  197. scaling=self.scaling,
  198. **kwargs,
  199. )
  200. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  201. attn_output = self.out_proj(attn_output)
  202. return attn_output, attn_weights
  203. # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
  204. class BlenderbotEncoderLayer(GradientCheckpointingLayer):
  205. def __init__(self, config: BlenderbotConfig):
  206. super().__init__()
  207. self.embed_dim = config.d_model
  208. self.self_attn = BlenderbotAttention(
  209. embed_dim=self.embed_dim,
  210. num_heads=config.encoder_attention_heads,
  211. dropout=config.attention_dropout,
  212. config=config,
  213. )
  214. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  215. self.dropout = config.dropout
  216. self.activation_fn = ACT2FN[config.activation_function]
  217. self.activation_dropout = config.activation_dropout
  218. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  219. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  220. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  221. def forward(
  222. self,
  223. hidden_states: torch.Tensor,
  224. attention_mask: torch.Tensor,
  225. **kwargs: Unpack[TransformersKwargs],
  226. ) -> torch.Tensor:
  227. """
  228. Args:
  229. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  230. attention_mask (`torch.FloatTensor`): attention mask of size
  231. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  232. """
  233. residual = hidden_states
  234. hidden_states = self.self_attn_layer_norm(hidden_states)
  235. hidden_states, _ = self.self_attn(
  236. hidden_states=hidden_states,
  237. attention_mask=attention_mask,
  238. **kwargs,
  239. )
  240. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  241. hidden_states = residual + hidden_states
  242. residual = hidden_states
  243. hidden_states = self.final_layer_norm(hidden_states)
  244. hidden_states = self.activation_fn(self.fc1(hidden_states))
  245. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  246. hidden_states = self.fc2(hidden_states)
  247. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  248. hidden_states = residual + hidden_states
  249. if hidden_states.dtype == torch.float16:
  250. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  251. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  252. return hidden_states
  253. # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
  254. class BlenderbotDecoderLayer(GradientCheckpointingLayer):
  255. def __init__(self, config: BlenderbotConfig, layer_idx: int | None = None):
  256. super().__init__()
  257. self.embed_dim = config.d_model
  258. self.self_attn = BlenderbotAttention(
  259. embed_dim=self.embed_dim,
  260. num_heads=config.decoder_attention_heads,
  261. dropout=config.attention_dropout,
  262. is_decoder=True,
  263. is_causal=True,
  264. config=config,
  265. layer_idx=layer_idx,
  266. )
  267. self.dropout = config.dropout
  268. self.activation_fn = ACT2FN[config.activation_function]
  269. self.activation_dropout = config.activation_dropout
  270. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  271. self.encoder_attn = BlenderbotAttention(
  272. self.embed_dim,
  273. config.decoder_attention_heads,
  274. dropout=config.attention_dropout,
  275. is_decoder=True,
  276. config=config,
  277. layer_idx=layer_idx,
  278. )
  279. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  280. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  281. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  282. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  283. def forward(
  284. self,
  285. hidden_states: torch.Tensor,
  286. attention_mask: torch.Tensor | None = None,
  287. encoder_hidden_states: torch.Tensor | None = None,
  288. encoder_attention_mask: torch.Tensor | None = None,
  289. past_key_values: Cache | None = None,
  290. use_cache: bool | None = True,
  291. **kwargs: Unpack[TransformersKwargs],
  292. ) -> torch.Tensor:
  293. """
  294. Args:
  295. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  296. attention_mask (`torch.FloatTensor`): attention mask of size
  297. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  298. encoder_hidden_states (`torch.FloatTensor`):
  299. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  300. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  301. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  302. past_key_values (`Cache`): cached past key and value projection states
  303. """
  304. residual = hidden_states
  305. hidden_states = self.self_attn_layer_norm(hidden_states)
  306. # Self Attention
  307. hidden_states, _ = self.self_attn(
  308. hidden_states=hidden_states,
  309. past_key_values=past_key_values,
  310. attention_mask=attention_mask,
  311. **kwargs,
  312. )
  313. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  314. hidden_states = residual + hidden_states
  315. # Cross-Attention Block
  316. if encoder_hidden_states is not None:
  317. residual = hidden_states
  318. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  319. hidden_states, _ = self.encoder_attn(
  320. hidden_states=hidden_states,
  321. key_value_states=encoder_hidden_states,
  322. attention_mask=encoder_attention_mask,
  323. past_key_values=past_key_values,
  324. **kwargs,
  325. )
  326. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  327. hidden_states = residual + hidden_states
  328. # Fully Connected
  329. residual = hidden_states
  330. hidden_states = self.final_layer_norm(hidden_states)
  331. hidden_states = self.activation_fn(self.fc1(hidden_states))
  332. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  333. hidden_states = self.fc2(hidden_states)
  334. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  335. hidden_states = residual + hidden_states
  336. return hidden_states
  337. @auto_docstring
  338. class BlenderbotPreTrainedModel(PreTrainedModel):
  339. config: BlenderbotConfig
  340. base_model_prefix = "model"
  341. supports_gradient_checkpointing = True
  342. _supports_flash_attn = True
  343. _supports_sdpa = True
  344. _supports_flex_attn = True
  345. _can_compile_fullgraph = True
  346. def _init_weights(self, module):
  347. super()._init_weights(module)
  348. if isinstance(module, BlenderbotForConditionalGeneration):
  349. init.zeros_(module.final_logits_bias)
  350. @property
  351. def dummy_inputs(self):
  352. pad_token = self.config.pad_token_id
  353. input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
  354. dummy_inputs = {
  355. "attention_mask": input_ids.ne(pad_token),
  356. "input_ids": input_ids,
  357. "decoder_input_ids": input_ids,
  358. }
  359. return dummy_inputs
  360. class BlenderbotEncoder(BlenderbotPreTrainedModel):
  361. """
  362. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  363. [`BlenderbotEncoderLayer`].
  364. Args:
  365. config: BlenderbotConfig
  366. embed_tokens (nn.Embedding): output embedding
  367. """
  368. _can_record_outputs = {
  369. "hidden_states": BlenderbotEncoderLayer,
  370. "attentions": BlenderbotAttention,
  371. }
  372. def __init__(self, config: BlenderbotConfig):
  373. super().__init__(config)
  374. self.dropout = config.dropout
  375. self.layerdrop = config.encoder_layerdrop
  376. embed_dim = config.d_model
  377. self.padding_idx = config.pad_token_id
  378. self.max_source_positions = config.max_position_embeddings
  379. embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  380. self.embed_tokens = BlenderbotScaledWordEmbedding(
  381. config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
  382. )
  383. self.embed_positions = BlenderbotLearnedPositionalEmbedding(
  384. config.max_position_embeddings,
  385. embed_dim,
  386. )
  387. self.layers = nn.ModuleList([BlenderbotEncoderLayer(config) for _ in range(config.encoder_layers)])
  388. self.layer_norm = nn.LayerNorm(config.d_model)
  389. self.gradient_checkpointing = False
  390. # Initialize weights and apply final processing
  391. self.post_init()
  392. @merge_with_config_defaults
  393. @capture_outputs
  394. @auto_docstring
  395. def forward(
  396. self,
  397. input_ids: torch.LongTensor | None = None,
  398. attention_mask: torch.Tensor | None = None,
  399. inputs_embeds: torch.FloatTensor | None = None,
  400. **kwargs: Unpack[TransformersKwargs],
  401. ) -> BaseModelOutput:
  402. if (input_ids is None) ^ (inputs_embeds is not None):
  403. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  404. if inputs_embeds is None:
  405. inputs_embeds = self.embed_tokens(input_ids)
  406. input_shape = inputs_embeds.size()[:-1]
  407. embed_pos = self.embed_positions(input_shape)
  408. hidden_states = inputs_embeds + embed_pos
  409. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  410. attention_mask = create_bidirectional_mask(
  411. config=self.config,
  412. inputs_embeds=inputs_embeds,
  413. attention_mask=attention_mask,
  414. )
  415. for idx, encoder_layer in enumerate(self.layers):
  416. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  417. to_drop = False
  418. if self.training:
  419. dropout_probability = torch.rand([])
  420. if dropout_probability < self.layerdrop: # skip the layer
  421. to_drop = True
  422. if not to_drop:
  423. hidden_states = encoder_layer(
  424. hidden_states,
  425. attention_mask,
  426. **kwargs,
  427. )
  428. # add final layer norm
  429. hidden_states = self.layer_norm(hidden_states)
  430. return BaseModelOutput(
  431. last_hidden_state=hidden_states,
  432. )
  433. class BlenderbotDecoder(BlenderbotPreTrainedModel):
  434. """
  435. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BlenderbotDecoderLayer`]
  436. Args:
  437. config: BlenderbotConfig
  438. embed_tokens (nn.Embedding): output embedding
  439. """
  440. _can_record_outputs = {
  441. "hidden_states": BlenderbotDecoderLayer,
  442. "attentions": OutputRecorder(BlenderbotAttention, index=1, layer_name="self_attn"),
  443. "cross_attentions": OutputRecorder(BlenderbotAttention, index=1, layer_name="encoder_attn"),
  444. }
  445. def __init__(self, config: BlenderbotConfig):
  446. super().__init__(config)
  447. self.dropout = config.dropout
  448. self.layerdrop = config.decoder_layerdrop
  449. self.padding_idx = config.pad_token_id
  450. self.max_target_positions = config.max_position_embeddings
  451. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  452. self.embed_tokens = BlenderbotScaledWordEmbedding(
  453. config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
  454. )
  455. self.embed_positions = BlenderbotLearnedPositionalEmbedding(
  456. config.max_position_embeddings,
  457. config.d_model,
  458. )
  459. self.layers = nn.ModuleList(
  460. [BlenderbotDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]
  461. )
  462. self.layer_norm = nn.LayerNorm(config.d_model)
  463. self.gradient_checkpointing = False
  464. # Initialize weights and apply final processing
  465. self.post_init()
  466. @merge_with_config_defaults
  467. @capture_outputs
  468. @auto_docstring
  469. def forward(
  470. self,
  471. input_ids: torch.LongTensor | None = None,
  472. attention_mask: torch.Tensor | None = None,
  473. encoder_hidden_states: torch.FloatTensor | None = None,
  474. encoder_attention_mask: torch.LongTensor | None = None,
  475. past_key_values: Cache | None = None,
  476. inputs_embeds: torch.FloatTensor | None = None,
  477. use_cache: bool | None = None,
  478. **kwargs: Unpack[TransformersKwargs],
  479. ) -> BaseModelOutputWithPastAndCrossAttentions:
  480. if (input_ids is None) ^ (inputs_embeds is not None):
  481. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  482. if inputs_embeds is None:
  483. inputs_embeds = self.embed_tokens(input_ids)
  484. # initialize `past_key_values`
  485. if use_cache and past_key_values is None:
  486. past_key_values = (
  487. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  488. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  489. else DynamicCache(config=self.config)
  490. )
  491. batch_size, seq_length = inputs_embeds.size()[:-1]
  492. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  493. position_ids = torch.arange(seq_length, device=inputs_embeds.device) + past_key_values_length
  494. if attention_mask is None and not is_torchdynamo_compiling():
  495. # required mask seq length can be calculated via length of past cache
  496. mask_seq_length = past_key_values_length + seq_length
  497. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  498. self_attn_cache = (
  499. past_key_values.self_attention_cache
  500. if isinstance(past_key_values, EncoderDecoderCache)
  501. else past_key_values
  502. )
  503. causal_mask = create_causal_mask(
  504. config=self.config,
  505. inputs_embeds=inputs_embeds,
  506. attention_mask=attention_mask,
  507. past_key_values=self_attn_cache,
  508. )
  509. encoder_attention_mask = create_bidirectional_mask(
  510. config=self.config,
  511. inputs_embeds=inputs_embeds,
  512. attention_mask=encoder_attention_mask,
  513. encoder_hidden_states=encoder_hidden_states,
  514. )
  515. # embed positions
  516. position_ids = self.embed_positions(
  517. (batch_size, seq_length), past_key_values_length, position_ids=position_ids
  518. )
  519. hidden_states = inputs_embeds + position_ids
  520. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  521. for idx, decoder_layer in enumerate(self.layers):
  522. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  523. if self.training:
  524. dropout_probability = torch.rand([])
  525. if dropout_probability < self.layerdrop:
  526. continue
  527. layer_outputs = decoder_layer(
  528. hidden_states,
  529. causal_mask,
  530. encoder_hidden_states, # as a positional argument for gradient checkpointing
  531. encoder_attention_mask=encoder_attention_mask,
  532. past_key_values=past_key_values,
  533. use_cache=use_cache,
  534. **kwargs,
  535. )
  536. hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs
  537. # add final layer norm
  538. hidden_states = self.layer_norm(hidden_states)
  539. return BaseModelOutputWithPastAndCrossAttentions(
  540. last_hidden_state=hidden_states,
  541. past_key_values=past_key_values,
  542. )
  543. @auto_docstring
  544. class BlenderbotModel(BlenderbotPreTrainedModel):
  545. _tied_weights_keys = {
  546. "encoder.embed_tokens.weight": "shared.weight",
  547. "decoder.embed_tokens.weight": "shared.weight",
  548. }
  549. def __init__(self, config: BlenderbotConfig):
  550. super().__init__(config)
  551. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  552. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  553. self.shared = BlenderbotScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
  554. self.encoder = BlenderbotEncoder(config)
  555. self.decoder = BlenderbotDecoder(config)
  556. # Initialize weights and apply final processing
  557. self.post_init()
  558. def get_input_embeddings(self):
  559. return self.shared
  560. def set_input_embeddings(self, value):
  561. self.shared = value
  562. self.encoder.embed_tokens = self.shared
  563. self.decoder.embed_tokens = self.shared
  564. @can_return_tuple
  565. @auto_docstring
  566. def forward(
  567. self,
  568. input_ids: torch.LongTensor | None = None,
  569. attention_mask: torch.Tensor | None = None,
  570. decoder_input_ids: torch.LongTensor | None = None,
  571. decoder_attention_mask: torch.LongTensor | None = None,
  572. encoder_outputs: BaseModelOutput | None = None,
  573. past_key_values: Cache | None = None,
  574. inputs_embeds: torch.FloatTensor | None = None,
  575. decoder_inputs_embeds: torch.FloatTensor | None = None,
  576. use_cache: bool | None = None,
  577. **kwargs: Unpack[TransformersKwargs],
  578. ) -> Seq2SeqModelOutput:
  579. r"""
  580. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  581. Indices of decoder input sequence tokens in the vocabulary.
  582. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  583. [`PreTrainedTokenizer.__call__`] for details.
  584. [What are decoder input IDs?](../glossary#decoder-input-ids)
  585. Blenderbot uses the `bos_token_id` as the starting token for `decoder_input_ids` generation. If
  586. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  587. `past_key_values`).
  588. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  589. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  590. be used by default.
  591. Example:
  592. ```python
  593. >>> from transformers import AutoTokenizer, BlenderbotModel
  594. >>> model = BlenderbotModel.from_pretrained("facebook/blenderbot-400M-distill")
  595. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
  596. >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
  597. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  598. >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_input_ids)
  599. >>> last_hidden_states = outputs.last_hidden_state
  600. >>> list(last_hidden_states.shape)
  601. [1, 6, 1280]
  602. ```"""
  603. if encoder_outputs is None:
  604. encoder_outputs: BaseModelOutput = self.encoder(
  605. input_ids=input_ids,
  606. attention_mask=attention_mask,
  607. inputs_embeds=inputs_embeds,
  608. **kwargs,
  609. )
  610. elif not isinstance(encoder_outputs, BaseModelOutput):
  611. encoder_outputs = BaseModelOutput(
  612. last_hidden_state=encoder_outputs[0],
  613. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  614. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  615. )
  616. decoder_outputs: BaseModelOutputWithPastAndCrossAttentions = self.decoder(
  617. input_ids=decoder_input_ids,
  618. attention_mask=decoder_attention_mask,
  619. encoder_hidden_states=encoder_outputs[0],
  620. encoder_attention_mask=attention_mask,
  621. past_key_values=past_key_values,
  622. inputs_embeds=decoder_inputs_embeds,
  623. use_cache=use_cache,
  624. **kwargs,
  625. )
  626. return Seq2SeqModelOutput(
  627. last_hidden_state=decoder_outputs.last_hidden_state,
  628. past_key_values=decoder_outputs.past_key_values,
  629. decoder_hidden_states=decoder_outputs.hidden_states,
  630. decoder_attentions=decoder_outputs.attentions,
  631. cross_attentions=decoder_outputs.cross_attentions,
  632. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  633. encoder_hidden_states=encoder_outputs.hidden_states,
  634. encoder_attentions=encoder_outputs.attentions,
  635. )
  636. @auto_docstring(
  637. custom_intro="""
  638. The Blenderbot Model with a language modeling head. Can be used for summarization.
  639. """
  640. )
  641. class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMixin):
  642. base_model_prefix = "model"
  643. _keys_to_ignore_on_load_missing = ["final_logits_bias"]
  644. _tied_weights_keys = {
  645. "lm_head.weight": "model.shared.weight",
  646. }
  647. def __init__(self, config: BlenderbotConfig):
  648. super().__init__(config)
  649. self.model = BlenderbotModel(config)
  650. self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
  651. self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
  652. # Initialize weights and apply final processing
  653. self.post_init()
  654. def resize_token_embeddings(
  655. self, new_num_tokens: int, pad_to_multiple_of: int | None = None, mean_resizing: bool = True
  656. ) -> nn.Embedding:
  657. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  658. self._resize_final_logits_bias(new_embeddings.weight.shape[0])
  659. return new_embeddings
  660. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  661. old_num_tokens = self.final_logits_bias.shape[-1]
  662. if new_num_tokens <= old_num_tokens:
  663. new_bias = self.final_logits_bias[:, :new_num_tokens]
  664. else:
  665. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  666. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  667. self.register_buffer("final_logits_bias", new_bias)
  668. @can_return_tuple
  669. @auto_docstring
  670. def forward(
  671. self,
  672. input_ids: torch.LongTensor | None = None,
  673. attention_mask: torch.Tensor | None = None,
  674. decoder_input_ids: torch.LongTensor | None = None,
  675. decoder_attention_mask: torch.LongTensor | None = None,
  676. encoder_outputs: BaseModelOutput | None = None,
  677. past_key_values: Cache | None = None,
  678. inputs_embeds: torch.FloatTensor | None = None,
  679. decoder_inputs_embeds: torch.FloatTensor | None = None,
  680. labels: torch.LongTensor | None = None,
  681. use_cache: bool | None = None,
  682. **kwargs: Unpack[TransformersKwargs],
  683. ) -> Seq2SeqLMOutput:
  684. r"""
  685. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  686. Indices of decoder input sequence tokens in the vocabulary.
  687. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  688. [`PreTrainedTokenizer.__call__`] for details.
  689. [What are decoder input IDs?](../glossary#decoder-input-ids)
  690. Blenderbot uses the `bos_token_id` as the starting token for `decoder_input_ids` generation. If
  691. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  692. `past_key_values`).
  693. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  694. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  695. be used by default.
  696. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  697. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  698. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  699. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  700. Example conversation:
  701. ```python
  702. >>> from transformers import AutoTokenizer, BlenderbotForConditionalGeneration
  703. >>> mname = "facebook/blenderbot-400M-distill"
  704. >>> model = BlenderbotForConditionalGeneration.from_pretrained(mname)
  705. >>> tokenizer = AutoTokenizer.from_pretrained(mname)
  706. >>> UTTERANCE = "My friends are cool but they eat too many carbs."
  707. >>> print("Human: ", UTTERANCE)
  708. Human: My friends are cool but they eat too many carbs.
  709. >>> inputs = tokenizer([UTTERANCE], return_tensors="pt")
  710. >>> reply_ids = model.generate(**inputs)
  711. >>> print("Bot: ", tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0])
  712. Bot: That's unfortunate. Are they trying to lose weight or are they just trying to be healthier?
  713. >>> REPLY = "I'm not sure"
  714. >>> print("Human: ", REPLY)
  715. Human: I'm not sure
  716. >>> NEXT_UTTERANCE = (
  717. ... "My friends are cool but they eat too many carbs.</s> <s>That's unfortunate. "
  718. ... "Are they trying to lose weight or are they just trying to be healthier?</s> "
  719. ... "<s> I'm not sure."
  720. ... )
  721. >>> inputs = tokenizer([NEXT_UTTERANCE], return_tensors="pt")
  722. >>> next_reply_ids = model.generate(**inputs)
  723. >>> print("Bot: ", tokenizer.batch_decode(next_reply_ids, skip_special_tokens=True)[0])
  724. Bot: I see. Well, it's good that they're trying to change their eating habits.
  725. ```
  726. """
  727. if labels is not None:
  728. if use_cache:
  729. logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
  730. use_cache = False
  731. if decoder_input_ids is None and decoder_inputs_embeds is None:
  732. decoder_input_ids = shift_tokens_right(
  733. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  734. )
  735. outputs: Seq2SeqModelOutput = self.model(
  736. input_ids,
  737. attention_mask=attention_mask,
  738. decoder_input_ids=decoder_input_ids,
  739. encoder_outputs=encoder_outputs,
  740. decoder_attention_mask=decoder_attention_mask,
  741. past_key_values=past_key_values,
  742. inputs_embeds=inputs_embeds,
  743. decoder_inputs_embeds=decoder_inputs_embeds,
  744. use_cache=use_cache,
  745. **kwargs,
  746. )
  747. lm_logits = self.lm_head(outputs[0])
  748. lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
  749. masked_lm_loss = None
  750. if labels is not None:
  751. labels = labels.to(lm_logits.device)
  752. loss_fct = CrossEntropyLoss()
  753. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  754. return Seq2SeqLMOutput(
  755. loss=masked_lm_loss,
  756. logits=lm_logits,
  757. past_key_values=outputs.past_key_values,
  758. decoder_hidden_states=outputs.decoder_hidden_states,
  759. decoder_attentions=outputs.decoder_attentions,
  760. cross_attentions=outputs.cross_attentions,
  761. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  762. encoder_hidden_states=outputs.encoder_hidden_states,
  763. encoder_attentions=outputs.encoder_attentions,
  764. )
  765. # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Blenderbot
  766. class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
  767. """
  768. This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
  769. used in combination with the [`EncoderDecoderModel`] framework.
  770. """
  771. def __init__(self, config):
  772. super().__init__(config)
  773. self.decoder = BlenderbotDecoder(config)
  774. self.post_init()
  775. def forward(self, *args, **kwargs):
  776. return self.decoder(*args, **kwargs)
  777. # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
  778. class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin):
  779. _tied_weights_keys = {
  780. "lm_head.weight": "model.decoder.embed_tokens.weight",
  781. }
  782. def __init__(self, config):
  783. config.is_decoder = True
  784. config.is_encoder_decoder = False
  785. super().__init__(config)
  786. self.model = BlenderbotDecoderWrapper(config)
  787. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  788. # Initialize weights and apply final processing
  789. self.post_init()
  790. def get_input_embeddings(self):
  791. return self.model.decoder.embed_tokens
  792. def set_input_embeddings(self, value):
  793. self.model.decoder.embed_tokens = value
  794. @can_return_tuple
  795. @auto_docstring
  796. def forward(
  797. self,
  798. input_ids: torch.LongTensor | None = None,
  799. attention_mask: torch.Tensor | None = None,
  800. encoder_hidden_states: torch.FloatTensor | None = None,
  801. encoder_attention_mask: torch.FloatTensor | None = None,
  802. past_key_values: Cache | None = None,
  803. inputs_embeds: torch.FloatTensor | None = None,
  804. labels: torch.LongTensor | None = None,
  805. use_cache: bool | None = None,
  806. logits_to_keep: int | torch.Tensor = 0,
  807. **kwargs: Unpack[TransformersKwargs],
  808. ) -> tuple | CausalLMOutputWithCrossAttentions:
  809. r"""
  810. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  811. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  812. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  813. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  814. Example:
  815. ```python
  816. >>> from transformers import AutoTokenizer, BlenderbotForCausalLM
  817. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
  818. >>> model = BlenderbotForCausalLM.from_pretrained("facebook/blenderbot-400M-distill")
  819. >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
  820. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  821. >>> outputs = model(**inputs)
  822. >>> logits = outputs.logits
  823. >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
  824. >>> list(logits.shape) == expected_shape
  825. True
  826. ```"""
  827. outputs: BaseModelOutputWithPastAndCrossAttentions = self.model.decoder(
  828. input_ids=input_ids,
  829. attention_mask=attention_mask,
  830. encoder_hidden_states=encoder_hidden_states,
  831. encoder_attention_mask=encoder_attention_mask,
  832. past_key_values=past_key_values,
  833. inputs_embeds=inputs_embeds,
  834. use_cache=use_cache,
  835. **kwargs,
  836. )
  837. hidden_states = outputs[0]
  838. # Only compute necessary logits
  839. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  840. logits = self.lm_head(hidden_states[:, slice_indices, :])
  841. loss = None
  842. if labels is not None:
  843. labels = labels.to(logits.device)
  844. loss_fct = CrossEntropyLoss()
  845. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  846. return CausalLMOutputWithCrossAttentions(
  847. loss=loss,
  848. logits=logits,
  849. past_key_values=outputs.past_key_values,
  850. hidden_states=outputs.hidden_states,
  851. attentions=outputs.attentions,
  852. cross_attentions=outputs.cross_attentions,
  853. )
# Public names of this module: the only ones bound by `from <module> import *`.
__all__ = [
    "BlenderbotForCausalLM",
    "BlenderbotForConditionalGeneration",
    "BlenderbotModel",
    "BlenderbotPreTrainedModel",
]