| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779 |
- # Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch CPMAnt"""
- import math
- import torch
- import torch.nn.functional as F
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache
- from ...generation import GenerationMixin
- from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
- from ...modeling_utils import PreTrainedModel
- from ...utils import auto_docstring, logging
- from .configuration_cpmant import CpmAntConfig
# Module-level logger from the library's logging utilities (no uses appear in the visible part of this file).
logger = logging.get_logger(__name__)
class CpmAntLayerNorm(nn.Module):
    """
    Root Mean Square (RMS) layer normalization, see https://huggingface.co/papers/1910.07467 for details.

    There is no mean-centering and no bias. Each feature vector is divided by the square root of its mean square
    (computed in float32) and then scaled by a learned per-channel `weight`.
    """

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.eps = config.eps
        self.dim_norm = config.hidden_size
        # Allocated uninitialized; `CpmAntPreTrainedModel._init_weights` sets it to ones.
        self.weight = nn.Parameter(torch.empty(config.hidden_size))

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`):
                Input to normalize; its last dimension must equal `config.hidden_size`.

        Returns:
            `torch.Tensor` of the same shape: the normalized input (cast back to the input dtype) times `weight`.

        Raises:
            AssertionError: if the last dimension does not match `config.hidden_size`.
        """
        if hidden_states.size(-1) != self.dim_norm:
            raise AssertionError("hidden_states.size(-1) != self.dim_norm")
        input_dtype = hidden_states.dtype
        # The statistic is computed in float32 for numerical stability regardless of the input dtype.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = (hidden_states * inv_rms).to(input_dtype)
        return normalized * self.weight
class CpmAntAttention(nn.Module):
    """
    Multi-head scaled dot-product attention used by CPMAnt.

    Queries are projected from `hidden_q`; keys and values are projected from `hidden_kv`. All projections are
    bias-free. An additive `position_bias` is applied to the raw scores. Positions where `attention_mask` is falsy
    are excluded twice: before the softmax (score set to -inf) and after it (probability reset to 0). The second
    step makes fully masked rows come out as zeros instead of NaN.
    """

    def __init__(self, config: CpmAntConfig, layer_idx=None):
        super().__init__()
        self.dim_model = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.dim_head = config.dim_head
        # Index passed to `past_key_values.update` to select this layer's cache slot.
        self.layer_idx = layer_idx

        inner_dim = self.num_heads * self.dim_head
        self.project_q = nn.Linear(self.dim_model, inner_dim, bias=False)
        self.project_k = nn.Linear(self.dim_model, inner_dim, bias=False)
        self.project_v = nn.Linear(self.dim_model, inner_dim, bias=False)
        self.attention_out = nn.Linear(inner_dim, self.dim_model, bias=False)

        self.softmax = torch.nn.Softmax(dim=-1)
        self.dropout = None if config.dropout_p is None else torch.nn.Dropout(p=config.dropout_p)

    def forward(
        self,
        hidden_q: torch.Tensor,
        hidden_kv: torch.Tensor,
        attention_mask: torch.BoolTensor,
        position_bias: torch.Tensor,
        output_attentions: bool | None = False,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        **kwargs,
    ):
        """
        Args:
            hidden_q (`torch.Tensor` of shape `(batch, len_q, dim_model)`):
                Source of the queries.
            hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`):
                Source of the keys and values.
            attention_mask (`torch.Tensor` of shape `(batch, len_q, len_k)`):
                Falsy entries mark query/key pairs that must not attend to each other. `len_k` includes any
                cached positions.
            position_bias (`torch.Tensor` broadcastable to `(batch, num_heads, len_q, len_k)`):
                Additive bias applied to the attention scores.
            output_attentions (`bool`, *optional*):
                Whether to also return the (pre-dropout) attention probabilities.
            past_key_values (`Cache`, *optional*):
                When given, the new keys/values are appended to it and attention runs over the full cached sequence.
            use_cache (`bool`, *optional*):
                Accepted for interface compatibility; not read in this method.

        Returns:
            `tuple`: the attention output of shape `(batch, len_q, dim_model)`, and the attention probabilities of
            shape `(batch, num_heads, len_q, len_k)` (or `None` when `output_attentions` is falsy).
        """
        batch_size, len_q = hidden_q.size(0), hidden_q.size(1)
        len_k = hidden_kv.size(1)

        def split_heads(states, length):
            # (batch, length, num_heads * dim_head) -> (batch, num_heads, length, dim_head)
            return states.view(batch_size, length, self.num_heads, self.dim_head).permute(0, 2, 1, 3)

        query = split_heads(self.project_q(hidden_q), len_q)
        key = split_heads(self.project_k(hidden_kv), len_k)
        value = split_heads(self.project_v(hidden_kv), len_k)

        if past_key_values is not None:
            key, value = past_key_values.update(key, value, self.layer_idx)
            len_k = key.size(-2)

        # True where attention is NOT allowed; shared by the pre-softmax and post-softmax masking steps.
        blocked = attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False)

        # (batch, heads, len_q, dim_head) @ (batch, heads, dim_head, len_k) -> (batch, heads, len_q, len_k)
        scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head)
        scores = scores + position_bias
        scores = torch.masked_fill(
            scores,
            blocked,
            torch.scalar_tensor(float("-inf"), device=scores.device, dtype=scores.dtype),
        )
        probs = self.softmax(scores)
        probs = torch.masked_fill(
            probs,
            blocked,
            torch.scalar_tensor(0, device=probs.device, dtype=probs.dtype),
        )

        attn_weights = probs if output_attentions else None
        if self.dropout is not None:
            probs = self.dropout(probs)

        # (batch, heads, len_q, len_k) @ (batch, heads, len_k, dim_head) -> (batch, heads, len_q, dim_head)
        context = torch.matmul(probs, value)
        # Merge heads back: -> (batch, len_q, num_heads * dim_head)
        context = context.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
        context = context.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
        return self.attention_out(context), attn_weights
class CpmAntSelfAttentionBlock(nn.Module):
    """
    Pre-norm self-attention sub-layer: `hidden_states + dropout(attention(rms_norm(hidden_states)))`.
    """

    def __init__(self, config: CpmAntConfig, layer_idx=None):
        super().__init__()
        self.layernorm_before_attention = CpmAntLayerNorm(config)
        self.self_attention = CpmAntAttention(config, layer_idx=layer_idx)
        # A falsy dropout_p (None or 0) disables this residual-branch dropout entirely.
        self.dropout = torch.nn.Dropout(config.dropout_p) if config.dropout_p else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: torch.Tensor | None = None,
        output_attentions: bool | None = False,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        **kwargs,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
                Input of the sub-layer; it is also the residual that the attention output is added to.
            attention_mask (`torch.Tensor` of shape `(batch, len_q, len_k)`):
                Falsy entries mark query/key pairs excluded from attention.
            position_bias (`torch.Tensor`, *optional*):
                Additive attention bias, broadcastable to `(batch, num_heads, len_q, len_k)`.
            output_attentions (`bool`, *optional*):
                Whether to also return the attention probabilities.
            past_key_values (`Cache`, *optional*):
                Key/value cache forwarded to the attention layer.
            use_cache (`bool`, *optional*):
                Forwarded to the attention layer.

        Returns:
            `tuple`: the updated hidden states (same shape as the input) and the attention probabilities (or `None`).
        """
        normed = self.layernorm_before_attention(hidden_states)
        attn_output, attn_weights = self.self_attention(
            hidden_q=normed,
            hidden_kv=normed,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )
        if self.dropout is not None:
            attn_output = self.dropout(attn_output)
        return hidden_states + attn_output, attn_weights
class CpmAntDenseGatedACT(nn.Module):
    """
    Gated GELU projection from `hidden_size` to `dim_ff`: `GELU(w_0(x)) * w_1(x)`, with both projections bias-free.
    """

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.w_0 = nn.Linear(config.hidden_size, config.dim_ff, bias=False)
        self.w_1 = nn.Linear(config.hidden_size, config.dim_ff, bias=False)
        self.act = torch.nn.GELU()

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`):
                Input features; the last dimension must be `config.hidden_size`.

        Returns:
            `torch.Tensor` with the last dimension `config.dim_ff`: the GELU-activated gate times the linear branch.
        """
        gate = self.act(self.w_0(hidden_states))
        linear_branch = self.w_1(hidden_states)
        return gate * linear_branch
class CpmAntFeedForward(nn.Module):
    """
    Position-wise feed-forward network: gated GELU up-projection to `dim_ff`, optional dropout, then a bias-free
    down-projection back to `hidden_size`.
    """

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.w_in = CpmAntDenseGatedACT(config)
        # Created whenever dropout_p is not None; a value of 0 still yields a (no-op) Dropout module.
        self.dropout = None if config.dropout_p is None else torch.nn.Dropout(config.dropout_p)
        self.w_out = nn.Linear(config.dim_ff, config.hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`):
                Input features with last dimension `config.hidden_size`.

        Returns:
            `torch.Tensor` with the same last dimension as the input (`config.hidden_size`).
        """
        intermediate = self.w_in(hidden_states)
        if self.dropout is not None:
            intermediate = self.dropout(intermediate)
        return self.w_out(intermediate)
class CpmAntFFNBlock(nn.Module):
    """
    Pre-norm feed-forward sub-layer: `hidden_states + dropout(ffn(rms_norm(hidden_states)))`.
    """

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.layernorm_before_ffn = CpmAntLayerNorm(config)
        self.ffn = CpmAntFeedForward(config)
        # A falsy dropout_p (None or 0) disables this residual-branch dropout entirely.
        self.dropout = torch.nn.Dropout(config.dropout_p) if config.dropout_p else None

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
                Hidden states before the feed-forward layer; also used as the residual.

        Returns:
            `torch.Tensor` of the same shape: the residual plus the (optionally dropped-out) feed-forward output.
        """
        residual = hidden_states
        ffn_output = self.ffn(self.layernorm_before_ffn(hidden_states))
        if self.dropout is not None:
            ffn_output = self.dropout(ffn_output)
        return residual + ffn_output
class CpmAntTransformerBlock(nn.Module):
    """
    One CPMAnt layer: a pre-norm self-attention sub-layer followed by a pre-norm feed-forward sub-layer, each with
    its own residual connection.
    """

    def __init__(self, config: CpmAntConfig, layer_idx=None):
        super().__init__()
        self.self_att = CpmAntSelfAttentionBlock(config, layer_idx=layer_idx)
        self.ffn = CpmAntFFNBlock(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: torch.Tensor | None = None,
        output_attentions: bool | None = False,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        **kwargs,
    ):
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to the layer of shape `(batch, seq_len, dim_model)`.
            attention_mask (`torch.Tensor`):
                Shape `(batch, len_q, len_k)`; falsy entries mark query/key pairs excluded from attention.
            position_bias (`torch.Tensor`, *optional*):
                Additive attention bias, broadcastable to `(batch, num_heads, len_q, len_k)`.
            output_attentions (`bool`, *optional*):
                Whether to also return the attention probabilities.
            past_key_values (`Cache`, *optional*):
                Key/value cache forwarded to the attention layer.
            use_cache (`bool`, *optional*):
                Forwarded to the attention sub-layer.

        Returns:
            `tuple`: the layer output (same shape as `hidden_states`) and the attention probabilities (or `None`).
        """
        attended, attn_weights = self.self_att(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )
        return self.ffn(attended), attn_weights
class CpmAntEncoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` CPMAnt transformer layers followed by a final RMS normalization.
    """

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.num_layers = config.num_hidden_layers
        self.layers = nn.ModuleList(
            [CpmAntTransformerBlock(config, layer_idx=index) for index in range(self.num_layers)]
        )
        self.output_layernorm = CpmAntLayerNorm(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: torch.Tensor,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        **kwargs,
    ):
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to the first layer, of shape `(batch, seq_len, dim_model)`.
            attention_mask (`torch.Tensor`):
                Shape `(batch, len_q, len_k)`; falsy entries mark query/key pairs excluded from attention.
            position_bias (`torch.Tensor`):
                Additive attention bias shared by all layers, broadcastable to `(batch, num_heads, len_q, len_k)`.
            output_attentions (`bool`, *optional*):
                Whether to collect every layer's attention probabilities.
            output_hidden_states (`bool`, *optional*):
                Whether to collect the hidden states.
            past_key_values (`Cache`, *optional*):
                Key/value cache shared by all layers (each layer uses its own `layer_idx`).
            use_cache (`bool`, *optional*):
                Forwarded to every layer.

        Returns:
            `tuple` of three elements:
            - `hidden_states`: the normalized output of the last layer.
            - `all_hidden_states` (or `None`): the input to every layer followed by the final normalized output.
            - `all_self_attns` (or `None`): the attention probabilities of every layer.
        """
        collected_states = () if output_hidden_states else None
        collected_attns = () if output_attentions else None

        for block in self.layers:
            if output_hidden_states:
                collected_states = collected_states + (hidden_states,)
            hidden_states, attn_weights = block(
                hidden_states,
                attention_mask,
                position_bias,
                output_attentions=output_attentions,
                past_key_values=past_key_values,
                use_cache=use_cache,
            )
            if output_attentions:
                collected_attns = collected_attns + (attn_weights,)

        hidden_states = self.output_layernorm(hidden_states)
        if output_hidden_states:
            collected_states = collected_states + (hidden_states,)

        return hidden_states, collected_states, collected_attns
# Dense layer `hidden_size -> intermediate_size` followed by an activation. `config.hidden_act` is either a string
# key into `ACT2FN` or a callable that is used directly.
# NOTE(review): this class is not referenced by any other class in the visible part of this file, and it reads
# BERT-style config fields (`intermediate_size`, `hidden_act`) -- confirm CpmAntConfig provides them before use.
# The documentation lives above the marker below so the copied body stays identical to its source.
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->CPMAnt
class CpmAntIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
class CpmAntSegmentPositionEmbedding(nn.Module):
    """
    Learned additive attention bias that depends on token positions and segment ids.

    The bias for a query/key pair is looked up in one of two ways:
    - Same segment: the signed relative distance `key - query` is bucketed. Distances below a short range each
      get their own bucket; longer ones get log-spaced buckets.
    - Different segments: a slot dedicated to the `(query_segment, key_segment)` combination is used.

    Both kinds of slot live in a single table of shape
    `(position_bias_num_buckets + segment_types**2, num_attention_heads)`.
    """

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_buckets = config.position_bias_num_buckets
        self.max_distance = config.position_bias_max_distance
        self.num_segments = config.segment_types
        # Table layout:
        # - rows [0, num_buckets) are distance buckets;
        # - the following segment_types**2 rows are cross-segment slots.
        # Allocated uninitialized; `CpmAntPreTrainedModel._init_weights` fills it.
        self.relative_attention_bias = nn.Parameter(
            torch.empty(
                config.segment_types * config.segment_types + config.position_bias_num_buckets,
                config.num_attention_heads,
            )
        )

    def forward(
        self,
        key_pos: torch.Tensor,
        query_pos: torch.Tensor,
        key_segment: torch.Tensor,
        query_segment: torch.Tensor,
    ):
        """
        Args:
            key_pos (`torch.Tensor` of shape `(batch, len_k)`):
                Key positions. Only the batch size and `len_k` are read; distances are taken from `arange(len_k)`.
            query_pos (`torch.Tensor` of shape `(batch, len_q)`):
                Query positions. Only `len_q` is read.
            key_segment (`torch.Tensor` of shape `(batch, len_k)`):
                Integer segment id of every key.
            query_segment (`torch.Tensor` of shape `(batch, len_q)`):
                Integer segment id of every query.

        Returns:
            `torch.Tensor` of shape `(batch, num_heads, len_q, len_k)`.

        Raises:
            AssertionError: if batch sizes differ, or a segment tensor's length does not match its position tensor.
        """
        batch = key_pos.size(0)
        keylen = key_pos.size(1)
        querylen = query_pos.size(1)

        if key_pos.size(0) != query_pos.size(0):
            raise AssertionError(
                f"key_pos.size(0) should be equal to query_pos.size(0), but got {key_pos.size(0)} and {query_pos.size(0)}!"
            )
        # Each length is checked separately so the error message names the dimension that actually mismatched.
        if keylen != key_segment.size(1):
            raise AssertionError(
                f"keylen should be equal to key_segment.size(1), but got {keylen} and {key_segment.size(1)}!"
            )
        if querylen != query_segment.size(1):
            raise AssertionError(
                f"querylen should be equal to query_segment.size(1), but got {querylen} and {query_segment.size(1)}!"
            )

        # Bucket indices are integer bookkeeping and need no autograd.
        with torch.no_grad():
            key_segment = key_segment.view(batch, -1, keylen)
            query_segment = query_segment.view(batch, querylen, -1)
            # (batch, len_q, len_k): index of the cross-segment slot, offset past the distance buckets.
            relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
            relative_position_bucket = relative_position_bucket + self.num_buckets

            # (len_q, len_k): distance bucket of `key - query`.
            device = relative_position_bucket.device
            absolute_position_bucket = self._position_bucket(
                torch.arange(keylen, dtype=torch.int32, device=device)[None, :]
                - torch.arange(querylen, dtype=torch.int32, device=device)[:, None],
                num_buckets=self.num_buckets,
                max_distance=self.max_distance,
            )
            # Same segment -> distance bucket; different segments -> cross-segment slot.
            relative_position_bucket = torch.where(
                (key_segment == query_segment),
                absolute_position_bucket[None, :, :],
                relative_position_bucket,
            )

        # The table lookup stays outside `no_grad` so `relative_attention_bias` receives gradients.
        # (batch, len_q, len_k, num_heads)
        embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
        # (batch, num_heads, len_q, len_k)
        return embeds.permute(0, 3, 1, 2).contiguous()

    def _segment_relative_position_bucket(self, query_segment, key_segment):
        """Return a unique slot index for each `(query_segment, key_segment)` pair."""
        return query_segment * self.num_segments + key_segment

    def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
        """
        Map signed relative distances to int32 bucket ids in `[0, num_buckets)`.

        The buckets are split in half (bidirectional):
        - The upper half is for positive distances, the lower half for zero and negative distances.
        - Within each half, the first `num_buckets // 4` buckets hold exact distances.
        - The rest are log-spaced up to `max_distance`; anything beyond lands in the last bucket of its half.
        """
        # Always bidirectional in CPMAnt.
        num_buckets //= 2
        relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
        relative_position = torch.abs(relative_position)
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # Entries below max_exact are discarded by the `where` below. Clamping them first avoids computing
        # log(0) = -inf, whose cast to int32 is not well defined; results for kept entries are unchanged.
        clamped_position = relative_position.float().clamp(min=max_exact)
        relative_position_if_large = max_exact + (
            torch.log(clamped_position / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.int32)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )
        relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_position_if_large)
        return relative_buckets
# Post-norm output sub-layer: dense `intermediate_size -> hidden_size`, then dropout, then LayerNorm applied to
# the sum with the residual `input_tensor`.
# NOTE(review): this class is not referenced by any other class in the visible part of this file, and it reads
# BERT-style config fields (`intermediate_size`, `layer_norm_eps`, `hidden_dropout_prob`) -- confirm CpmAntConfig
# provides them before use.
# The documentation lives above the marker below so the copied body stays identical to its source.
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->CPMAnt
class CpmAntOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
@auto_docstring
class CpmAntPreTrainedModel(PreTrainedModel):
    # Shared base for CPMAnt models: binds the config class, the checkpoint prefix of the base model, and the
    # CPMAnt-specific weight initialization.
    config: CpmAntConfig
    base_model_prefix = "cpmant"

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize the weights of `module`.

        The base-class initialization runs first. Then the CPMAnt-specific parameters are overridden:
        - RMS-norm scales are set to one.
        - The position/segment bias table is drawn from a normal distribution with mean 0 and
          standard deviation `config.init_std`.
        """
        super()._init_weights(module)
        if isinstance(module, CpmAntSegmentPositionEmbedding):
            init.normal_(module.relative_attention_bias, mean=0.0, std=self.config.init_std)
        elif isinstance(module, CpmAntLayerNorm):
            init.ones_(module.weight)
@auto_docstring
class CpmAntModel(CpmAntPreTrainedModel):
    # Bare CPMAnt transformer. A fixed block of `prompt_length` prompt tokens is prepended to every input and run
    # through the encoder. On a cache-less call the prompt positions are stripped from the outputs again.
    def __init__(self, config: CpmAntConfig):
        super().__init__(config)
        self.encoder = CpmAntEncoder(config)
        self.segment_embedding = nn.Embedding(config.segment_types, config.hidden_size)
        # The vocabulary is extended with `prompt_types * prompt_length` extra rows reserved for prompt tokens.
        self.input_embedding = nn.Embedding(
            config.vocab_size + config.prompt_types * config.prompt_length, config.hidden_size
        )
        self.position_bias = CpmAntSegmentPositionEmbedding(config)
        self.prompt_length = config.prompt_length
        self.vocab_size = config.vocab_size
        self.post_init()

    def get_input_embeddings(self):
        return self.input_embedding

    def set_input_embeddings(self, embeddings, **kwargs):
        self.input_embedding = embeddings

    def _prepare_attention_mask(self, input_ids, span, context, length):
        """
        Build the `(batch, seq_len, seq_len)` attention mask over the prompt-extended sequence.

        A query position may attend to a key position only when all of the following hold:
        - Context/direction: the key is in context, or the query is not in context and the key is not after it.
        - Span: both positions share the same `span` id.
        - Padding: neither position is left padding. The first `prompt_length` positions are never padding; of the
          remaining positions, only the last `length[b]` count as real tokens.
        """
        batch = input_ids.size(0)
        seqlen = input_ids.size(1)
        device = input_ids.device
        # directional_mask_2d[q, k] is True when k <= q.
        directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
        attention_mask = context[:, None, :] | (
            context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
        )
        attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
        # Mask for left padding: distance of each non-prompt position from the end of the sequence,
        # i.e. [n - 1, ..., 1, 0]. Built on-device instead of from a Python list.
        distance_from_end = torch.arange(seqlen - self.prompt_length, device=device).flip(0)
        mask_1d = distance_from_end[None, :] < length[:, None]
        prompt_mask = torch.ones(batch, self.prompt_length, dtype=torch.bool, device=device)
        mask_1d = torch.cat((prompt_mask, mask_1d), dim=1)
        attention_mask = mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
        return attention_mask

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPast:
        r"""
        input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        """
        if input_ids is None:
            # Segments, lengths and positions are all derived from input_ids, so fail early with a clear message.
            raise ValueError("CpmAntModel.forward requires `input_ids`.")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Work in int32: token ids, segments, positions and spans all share this dtype and device.
        if input_ids.dtype != torch.int32:
            input_ids = input_ids.to(torch.int32)
        dtype, device = input_ids.dtype, input_ids.device
        # Token id 0 is treated as padding (segment 0); every other token is assigned segment 2.
        segment = torch.where(input_ids != 0, 2, 0).to(dtype=dtype, device=device)
        # Number of non-padding tokens per row, used by the left-padding mask.
        length = (segment != 0).sum(-1).to(dtype=dtype, device=device)
        # Add prompts ahead: ids [vocab_size + 2 * prompt_length, vocab_size + 3 * prompt_length) for every row.
        prompt_ids = torch.arange(
            self.prompt_length * 2 + self.vocab_size,
            self.prompt_length * 3 + self.vocab_size,
            dtype=dtype,
            device=device,
        ).repeat(input_ids.size(0), 1)
        input_ids = torch.cat((prompt_ids, input_ids), dim=1)
        batch, seq_length = input_ids.size()
        # Prompt positions get segment 0.
        segment = torch.cat((torch.zeros(batch, self.prompt_length, dtype=dtype, device=device), segment), dim=1)
        # With `context` all ones, the directional term in `_prepare_attention_mask` has no effect.
        context = torch.full((batch, seq_length), 1, dtype=dtype, device=device)
        position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
        span = torch.full((batch, seq_length), 0, dtype=dtype, device=device)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)
        # Positions already stored in the cache; only the positions after them are run through the encoder.
        past_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        input_ids = input_ids.contiguous()
        hidden_states = self.input_embedding(input_ids)
        segment_states = self.segment_embedding(segment)
        if past_length != 0:
            segment_states = segment_states[:, -1:, :]
        hidden_states = hidden_states + segment_states

        # Mask and bias are built for the full sequence, then restricted to the query rows not yet cached.
        attention_mask = self._prepare_attention_mask(input_ids, span, context, length)
        position_bias = self.position_bias(position, position, segment, segment)
        attention_mask = attention_mask[:, past_length:, :]
        position_bias = position_bias[:, :, past_length:, :]
        hidden_states = hidden_states[:, past_length:, :]

        hidden_states, all_hidden_states, all_attentions = self.encoder(
            hidden_states,
            attention_mask,
            position_bias,
            output_attentions,
            output_hidden_states,
            past_key_values,
            use_cache,
        )

        if past_length == 0:
            # Drop the prompt positions. This is only done on a cache-less call: with a cache, the query rows
            # contain no prompt positions, and slicing them would empty the tensors.
            prompt_length = self.prompt_length
            hidden_states = hidden_states[:, prompt_length:, :]
            if all_attentions is not None:
                all_attentions = tuple(
                    attention[:, :, prompt_length:, prompt_length:] for attention in all_attentions
                )
            if all_hidden_states is not None:
                all_hidden_states = tuple(hidden_state[:, prompt_length:, :] for hidden_state in all_hidden_states)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
@auto_docstring(
    custom_intro="""
    The CPMAnt Model with a language modeling head on top (linear layer with weights tied to the input embeddings).
    """
)
class CpmAntForCausalLM(CpmAntPreTrainedModel, GenerationMixin):
    # The head covers the extended vocabulary (regular tokens plus prompt tokens) and shares its weight with the
    # base model's input embedding.
    _tied_weights_keys = {"lm_head.weight": "cpmant.input_embedding.weight"}

    def __init__(self, config: CpmAntConfig):
        super().__init__(config)
        self.cpmant = CpmAntModel(config)
        # lm_head.weight is tied to cpmant.input_embedding.weight
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size + config.prompt_types * config.prompt_length, bias=False
        )
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        attention_mask: torch.Tensor | None = None,  # dummy parameter for text-generation pipeline
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> tuple | CausalLMOutputWithPast:
        r"""
        input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. They are compared position-by-position with the
            returned logits (no shift is applied here).

        Example:

        Text Generation with CpmAntForCausalLM (the example prompt is Chinese for "The weather is nice today,").
        ```python
        >>> from transformers import CPMAntTokenizer, CpmAntForCausalLM

        >>> texts = "今天天气不错,"
        >>> model = CpmAntForCausalLM.from_pretrained("openbmb/cpm-ant-10b")
        >>> tokenizer = CPMAntTokenizer.from_pretrained("openbmb/cpm-ant-10b")
        >>> input_ids = tokenizer(texts, return_tensors="pt")
        >>> outputs = model.generate(**input_ids)
        >>> output_texts = tokenizer.batch_decode(outputs)
        >>> print(output_texts)
        ['今天天气不错,阳光明媚,我和妈妈一起去超市买东西。\n在超市里,我看到了一个很好玩的玩具,它的名字叫“机器人”。它有一个圆圆的脑袋,两只圆圆的眼睛,还有一个圆圆的']
        ```
        """
        return_dict = self.config.return_dict if return_dict is None else return_dict

        # Keyword arguments keep this call independent of CpmAntModel.forward's parameter order.
        model_output = self.cpmant(
            input_ids=input_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            past_key_values=past_key_values,
            use_cache=use_cache,
            return_dict=return_dict,
        )
        hidden_states = model_output.last_hidden_state if return_dict else model_output[0]

        # Only compute necessary logits. An int keeps that many trailing positions (0 keeps all, since
        # slice(-0, None) == slice(0, None)); a tensor is used directly as an index.
        if isinstance(logits_to_keep, int):
            slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (logits,) + model_output[1:]
            return output if loss is None else (loss,) + output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_output.past_key_values,
            hidden_states=model_output.hidden_states,
            attentions=model_output.attentions,
        )

    def get_input_embeddings(self):
        return self.cpmant.input_embedding

    def set_input_embeddings(self, embeddings):
        self.cpmant.input_embedding = embeddings
# Public API of this module: the base model, the causal-LM head model, and their shared pretrained base class.
__all__ = ["CpmAntForCausalLM", "CpmAntModel", "CpmAntPreTrainedModel"]
|