modeling_cpmant.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779
  1. # Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch CPMAnt"""
  15. import math
  16. import torch
  17. import torch.nn.functional as F
  18. from torch import nn
  19. from torch.nn import CrossEntropyLoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache
  23. from ...generation import GenerationMixin
  24. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  25. from ...modeling_utils import PreTrainedModel
  26. from ...utils import auto_docstring, logging
  27. from .configuration_cpmant import CpmAntConfig
# Module-level logger obtained from the library's logging helper, keyed by this module's name.
logger = logging.get_logger(__name__)
  29. class CpmAntLayerNorm(nn.Module):
  30. """
  31. We use Root Mean Square (RMS) Layer Normalization, please see https://huggingface.co/papers/1910.07467 for details."
  32. """
  33. def __init__(self, config: CpmAntConfig):
  34. super().__init__()
  35. self.eps = config.eps
  36. self.dim_norm = config.hidden_size
  37. self.weight = nn.Parameter(torch.empty(config.hidden_size))
  38. def forward(self, hidden_states: torch.Tensor):
  39. """
  40. Args:
  41. hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
  42. """
  43. if hidden_states.size(-1) != self.dim_norm:
  44. raise AssertionError("hidden_states.size(-1) != self.dim_norm")
  45. old_dtype = hidden_states.dtype
  46. variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
  47. hidden_states = (hidden_states * torch.rsqrt(variance + self.eps)).to(old_dtype) * self.weight
  48. return hidden_states
  49. class CpmAntAttention(nn.Module):
  50. def __init__(self, config: CpmAntConfig, layer_idx=None):
  51. super().__init__()
  52. self.dim_model = config.hidden_size
  53. self.num_heads = config.num_attention_heads
  54. self.dim_head = config.dim_head
  55. self.layer_idx = layer_idx
  56. self.project_q = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
  57. self.project_k = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
  58. self.project_v = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
  59. self.attention_out = nn.Linear(self.num_heads * self.dim_head, self.dim_model, bias=False)
  60. self.softmax = torch.nn.Softmax(dim=-1)
  61. if config.dropout_p is not None:
  62. self.dropout = torch.nn.Dropout(p=config.dropout_p)
  63. else:
  64. self.dropout = None
  65. def forward(
  66. self,
  67. hidden_q: torch.Tensor,
  68. hidden_kv: torch.Tensor,
  69. attention_mask: torch.BoolTensor,
  70. position_bias: torch.Tensor,
  71. output_attentions: bool | None = False,
  72. past_key_values: Cache | None = None,
  73. use_cache: bool | None = None,
  74. **kwargs,
  75. ):
  76. """
  77. Args:
  78. hidden_q (`torch.Tensor`):
  79. Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
  80. hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`)):
  81. Tensor *key_value* and *query* of shape `(batch, len_k, dim_model)`
  82. attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
  83. Avoid invalid areas to participate in the calculation of self-attention.
  84. position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
  85. Provide positional information to self-attention block.
  86. output_attentions (`bool`, *optional*):
  87. Whether or not to return the attentions tensors of all attention layers.
  88. past_key_values (`Cache`, *optional*):
  89. Cached past key and value projection states.
  90. use_cache (`bool`, *optional*):
  91. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  92. (see `past_key_values`).
  93. """
  94. batch_size = hidden_q.size(0)
  95. len_q = hidden_q.size(1)
  96. len_k = hidden_kv.size(1)
  97. query = self.project_q(hidden_q)
  98. key = self.project_k(hidden_kv)
  99. value = self.project_v(hidden_kv)
  100. query = query.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
  101. key = key.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
  102. value = value.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
  103. if past_key_values is not None:
  104. key, value = past_key_values.update(key, value, self.layer_idx)
  105. len_k = key.size(-2)
  106. # (batch_size, num_heads, len_q, dim_head) @ (batch_size, num_heads, dim_head, len_k) -> (batch_size, num_heads, len_q, len_k)
  107. score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head)
  108. score = score + position_bias
  109. score = torch.masked_fill(
  110. score,
  111. attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
  112. torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
  113. )
  114. score = self.softmax(score)
  115. score = torch.masked_fill(
  116. score,
  117. attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
  118. torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
  119. )
  120. if output_attentions:
  121. attn_weights = score
  122. else:
  123. attn_weights = None
  124. if self.dropout is not None:
  125. score = self.dropout(score)
  126. # (batch_size, num_heads, len_q, len_k) @ (batch_size, num_heads, len_k, dim_head) -> (batch_size, num_heads, len_q, dim_head)
  127. score = torch.matmul(score, value)
  128. score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
  129. score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
  130. score = self.attention_out(score)
  131. return score, attn_weights
  132. class CpmAntSelfAttentionBlock(nn.Module):
  133. def __init__(self, config: CpmAntConfig, layer_idx=None):
  134. super().__init__()
  135. self.layernorm_before_attention = CpmAntLayerNorm(config)
  136. self.self_attention = CpmAntAttention(config, layer_idx=layer_idx)
  137. if config.dropout_p:
  138. self.dropout = torch.nn.Dropout(config.dropout_p)
  139. else:
  140. self.dropout = None
  141. def forward(
  142. self,
  143. hidden_states: torch.Tensor,
  144. attention_mask: torch.Tensor,
  145. position_bias: torch.Tensor | None = None,
  146. output_attentions: bool | None = False,
  147. past_key_values: Cache | None = None,
  148. use_cache: bool | None = None,
  149. **kwargs,
  150. ):
  151. """
  152. Args:
  153. hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
  154. Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
  155. attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
  156. Avoid invalid areas to participate in the calculation of self-attention.
  157. position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
  158. Provide positional information to self-attention block.
  159. output_attentions (`bool`, *optional*):
  160. Whether or not to return the attentions tensors of all attention layers.
  161. past_key_values (`Cache`, *optional*):
  162. Cached past key and value projection states.
  163. use_cache (`bool`, *optional*):
  164. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  165. (see `past_key_values`).
  166. """
  167. outputs = self.layernorm_before_attention(hidden_states)
  168. outputs, attn_weights = self.self_attention(
  169. outputs,
  170. outputs,
  171. attention_mask,
  172. position_bias,
  173. output_attentions,
  174. past_key_values,
  175. use_cache,
  176. )
  177. if self.dropout is not None:
  178. outputs = self.dropout(outputs)
  179. hidden_states = hidden_states + outputs
  180. return hidden_states, attn_weights
  181. class CpmAntDenseGatedACT(nn.Module):
  182. def __init__(self, config: CpmAntConfig):
  183. super().__init__()
  184. self.w_0 = nn.Linear(config.hidden_size, config.dim_ff, bias=False)
  185. self.w_1 = nn.Linear(config.hidden_size, config.dim_ff, bias=False)
  186. self.act = torch.nn.GELU()
  187. def forward(self, hidden_states: torch.Tensor):
  188. """Transform an input tensor from one feature space to another via a nonlinear operation
  189. Args:
  190. hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
  191. """
  192. gate_score = self.act(self.w_0(hidden_states))
  193. hidden_states = self.w_1(hidden_states)
  194. hidden_states = gate_score * hidden_states
  195. return hidden_states
  196. class CpmAntFeedForward(nn.Module):
  197. def __init__(self, config: CpmAntConfig):
  198. super().__init__()
  199. self.w_in = CpmAntDenseGatedACT(config)
  200. if config.dropout_p is not None:
  201. self.dropout = torch.nn.Dropout(config.dropout_p)
  202. else:
  203. self.dropout = None
  204. self.w_out = nn.Linear(config.dim_ff, config.hidden_size, bias=False)
  205. def forward(self, hidden_states: torch.Tensor):
  206. """
  207. Args:
  208. hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
  209. """
  210. hidden_states = self.w_in(hidden_states)
  211. if self.dropout is not None:
  212. hidden_states = self.dropout(hidden_states)
  213. hidden_states = self.w_out(hidden_states)
  214. return hidden_states
  215. class CpmAntFFNBlock(nn.Module):
  216. def __init__(self, config: CpmAntConfig):
  217. super().__init__()
  218. self.layernorm_before_ffn = CpmAntLayerNorm(config)
  219. self.ffn = CpmAntFeedForward(config)
  220. if config.dropout_p:
  221. self.dropout = torch.nn.Dropout(config.dropout_p)
  222. else:
  223. self.dropout = None
  224. def forward(
  225. self,
  226. hidden_states: torch.Tensor,
  227. ):
  228. """
  229. Args:
  230. hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
  231. Hidden states before feed forward layer.
  232. """
  233. ln_outputs = self.layernorm_before_ffn(hidden_states)
  234. outputs = self.ffn(ln_outputs)
  235. if self.dropout is not None:
  236. outputs = self.dropout(outputs)
  237. hidden_states = hidden_states + outputs
  238. return hidden_states
  239. class CpmAntTransformerBlock(nn.Module):
  240. def __init__(self, config: CpmAntConfig, layer_idx=None):
  241. super().__init__()
  242. self.self_att = CpmAntSelfAttentionBlock(config, layer_idx=layer_idx)
  243. self.ffn = CpmAntFFNBlock(config)
  244. def forward(
  245. self,
  246. hidden_states: torch.Tensor,
  247. attention_mask: torch.Tensor,
  248. position_bias: torch.Tensor | None = None,
  249. output_attentions: bool | None = False,
  250. past_key_values: Cache | None = None,
  251. use_cache: bool | None = None,
  252. **kwargs,
  253. ):
  254. """
  255. Args:
  256. hidden_states (`torch.Tensor`):
  257. Input to the layer of shape `(batch, seq_len, dim_model)`
  258. attention_mask (`torch.Tensor`):
  259. Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
  260. position_bias (`torch.Tensor`):
  261. Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
  262. output_attentions (`bool`, *optional*):
  263. Whether or not to return the attentions tensors of all attention layers.
  264. past_key_values (`Cache`, *optional*):
  265. Cached past key and value projection states
  266. use_cache (`bool`, *optional*):
  267. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  268. (see `past_key_values`).
  269. """
  270. hidden_states, attn_weights = self.self_att(
  271. hidden_states,
  272. attention_mask=attention_mask,
  273. position_bias=position_bias,
  274. output_attentions=output_attentions,
  275. past_key_values=past_key_values,
  276. use_cache=use_cache,
  277. )
  278. hidden_states = self.ffn(hidden_states)
  279. return hidden_states, attn_weights
  280. class CpmAntEncoder(nn.Module):
  281. def __init__(self, config: CpmAntConfig):
  282. super().__init__()
  283. self.num_layers = config.num_hidden_layers
  284. self.layers = nn.ModuleList([CpmAntTransformerBlock(config, layer_idx=i) for i in range(self.num_layers)])
  285. self.output_layernorm = CpmAntLayerNorm(config)
  286. def forward(
  287. self,
  288. hidden_states: torch.Tensor,
  289. attention_mask: torch.Tensor,
  290. position_bias: torch.Tensor,
  291. output_attentions: bool | None = None,
  292. output_hidden_states: bool | None = None,
  293. past_key_values: Cache | None = None,
  294. use_cache: bool | None = None,
  295. **kwargs,
  296. ):
  297. """
  298. Args:
  299. hidden_states (`torch.Tensor`):
  300. Input to the layer of shape `(batch, seq_len, dim_model)`
  301. attention_mask (`torch.Tensor`):
  302. Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
  303. position_bias (`torch.Tensor`):
  304. Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
  305. output_attentions (`bool`, *optional*):
  306. Whether or not to return the attentions tensors of all attention layers.
  307. output_hidden_states (`bool`, *optional*):
  308. Whether or not to return the hidden states of all layers.
  309. past_key_values (`Cache`, *optional*):
  310. Cached past key and value projection states
  311. use_cache (`bool`, *optional*):
  312. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  313. (see `past_key_values`).
  314. """
  315. all_hidden_states = () if output_hidden_states else None
  316. all_self_attns = () if output_attentions else None
  317. for i, layer in enumerate(self.layers):
  318. if output_hidden_states:
  319. all_hidden_states += (hidden_states,)
  320. layer_outputs = layer(
  321. hidden_states,
  322. attention_mask,
  323. position_bias,
  324. output_attentions=output_attentions,
  325. past_key_values=past_key_values,
  326. use_cache=use_cache,
  327. )
  328. hidden_states, attn_weights = layer_outputs
  329. if output_attentions:
  330. all_self_attns += (attn_weights,)
  331. hidden_states = self.output_layernorm(hidden_states)
  332. if output_hidden_states:
  333. all_hidden_states += (hidden_states,)
  334. return hidden_states, all_hidden_states, all_self_attns
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->CPMAnt
class CpmAntIntermediate(nn.Module):
    # Linear projection from `config.hidden_size` to `config.intermediate_size` followed by an activation.
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `config.hidden_act` is either a key into `ACT2FN` or is used directly as the activation callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Project, then apply the configured activation.
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
  348. class CpmAntSegmentPositionEmbedding(nn.Module):
  349. def __init__(self, config: CpmAntConfig):
  350. super().__init__()
  351. self.num_heads = config.num_attention_heads
  352. self.num_buckets = config.position_bias_num_buckets
  353. self.max_distance = config.position_bias_max_distance
  354. self.num_segments = config.segment_types
  355. self.relative_attention_bias = nn.Parameter(
  356. torch.empty(
  357. config.segment_types * config.segment_types + config.position_bias_num_buckets,
  358. config.num_attention_heads,
  359. )
  360. )
  361. def forward(
  362. self,
  363. key_pos: torch.Tensor,
  364. query_pos: torch.Tensor,
  365. key_segment: torch.Tensor,
  366. query_segment: torch.Tensor,
  367. ):
  368. with torch.no_grad():
  369. batch = key_pos.size(0)
  370. keylen = key_pos.size(1)
  371. querylen = query_pos.size(1)
  372. if key_pos.size(0) != query_pos.size(0):
  373. raise AssertionError(
  374. f"key_pos.size(0) should be equal to query_pos.size(0), but got {key_pos.size(0)} and {query_pos.size(0)}!"
  375. )
  376. if keylen != key_segment.size(1) or querylen != query_segment.size(1):
  377. raise AssertionError(
  378. f"keylen should be equal to key_segment.size(1), but got {keylen} and {key_segment.size(1)}!"
  379. )
  380. if querylen != query_segment.size(1):
  381. raise AssertionError(
  382. f"querylen should be equal to query_segment.size(1), but got {querylen} and {query_segment.size(1)}!"
  383. )
  384. key_pos = key_pos.view(batch, -1, keylen)
  385. query_pos = query_pos.view(batch, querylen, -1)
  386. key_segment = key_segment.view(batch, -1, keylen)
  387. query_segment = query_segment.view(batch, querylen, -1)
  388. relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
  389. relative_position_bucket = relative_position_bucket + self.num_buckets
  390. # (batch, len_q, len_k)
  391. absolute_position_bucket = self._position_bucket(
  392. torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
  393. - torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
  394. num_buckets=self.num_buckets,
  395. max_distance=self.max_distance,
  396. )
  397. relative_position_bucket = torch.where(
  398. (key_segment == query_segment),
  399. absolute_position_bucket[None, :, :],
  400. relative_position_bucket,
  401. )
  402. # (batch, len_q, len_k, num_heads)
  403. embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
  404. # (batch, num_heads, len_q, len_k)
  405. embeds = embeds.permute(0, 3, 1, 2).contiguous()
  406. return embeds
  407. def _segment_relative_position_bucket(self, query_segment, key_segment):
  408. return query_segment * self.num_segments + key_segment
  409. def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
  410. relative_buckets = 0
  411. # always bidirectional in CPMAnt
  412. num_buckets //= 2
  413. relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
  414. relative_position = torch.abs(relative_position)
  415. max_exact = num_buckets // 2
  416. is_small = relative_position < max_exact
  417. relative_position_if_large = max_exact + (
  418. torch.log(relative_position.float() / max_exact)
  419. / math.log(max_distance / max_exact)
  420. * (num_buckets - max_exact)
  421. ).to(torch.int32)
  422. relative_position_if_large = torch.min(
  423. relative_position_if_large,
  424. torch.full_like(relative_position_if_large, num_buckets - 1),
  425. )
  426. relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_position_if_large)
  427. return relative_buckets
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->CPMAnt
class CpmAntOutput(nn.Module):
    # Projects from `config.intermediate_size` back to `config.hidden_size`, applies dropout, then a
    # residual addition followed by LayerNorm.
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # `input_tensor` is the residual; normalization is applied after the addition.
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
  440. @auto_docstring
  441. class CpmAntPreTrainedModel(PreTrainedModel):
  442. config: CpmAntConfig
  443. base_model_prefix = "cpmant"
  444. @torch.no_grad()
  445. def _init_weights(self, module):
  446. """Initialize the weights"""
  447. super()._init_weights(module)
  448. if isinstance(module, CpmAntLayerNorm):
  449. init.ones_(module.weight)
  450. elif isinstance(module, CpmAntSegmentPositionEmbedding):
  451. init.normal_(module.relative_attention_bias, mean=0.0, std=self.config.init_std)
  452. @auto_docstring
  453. class CpmAntModel(CpmAntPreTrainedModel):
  454. def __init__(self, config: CpmAntConfig):
  455. super().__init__(config)
  456. self.encoder = CpmAntEncoder(config)
  457. self.segment_embedding = nn.Embedding(config.segment_types, config.hidden_size)
  458. self.input_embedding = nn.Embedding(
  459. config.vocab_size + config.prompt_types * config.prompt_length, config.hidden_size
  460. )
  461. self.position_bias = CpmAntSegmentPositionEmbedding(config)
  462. self.prompt_length = config.prompt_length
  463. self.vocab_size = config.vocab_size
  464. self.post_init()
  465. def get_input_embeddings(self):
  466. return self.input_embedding
  467. def set_input_embeddings(self, embeddings, **kwargs):
  468. self.input_embedding = embeddings
  469. def _prepare_attention_mask(self, input_ids, span, context, length):
  470. batch = input_ids.size(0)
  471. seqlen = input_ids.size(1)
  472. device = input_ids.device
  473. directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
  474. attention_mask = context[:, None, :] | (
  475. context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
  476. )
  477. attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
  478. # mask for left padding
  479. mask_1d = (
  480. torch.tensor(list(range(seqlen - self.prompt_length))[::-1], device=device)[None, :].repeat(batch, 1)
  481. < length[:, None]
  482. )
  483. mask_1d = torch.cat((torch.ones(batch, self.prompt_length, device=device).bool(), mask_1d), dim=1)
  484. attention_mask = mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
  485. return attention_mask
  486. @auto_docstring
  487. def forward(
  488. self,
  489. input_ids: torch.Tensor | None = None,
  490. output_attentions: bool | None = None,
  491. output_hidden_states: bool | None = None,
  492. past_key_values: Cache | None = None,
  493. use_cache: bool | None = None,
  494. return_dict: bool | None = None,
  495. **kwargs,
  496. ) -> tuple[torch.Tensor] | BaseModelOutputWithPast:
  497. r"""
  498. input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
  499. Indices of input sequence tokens in the vocabulary.
  500. Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  501. [`PreTrainedTokenizer.__call__`] for details.
  502. [What are input IDs?](../glossary#input-ids)
  503. """
  504. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  505. output_hidden_states = (
  506. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  507. )
  508. return_dict = return_dict if return_dict is not None else self.config.return_dict
  509. use_cache = use_cache if use_cache is not None else self.config.use_cache
  510. # add prompts ahead
  511. if input_ids.dtype != torch.int32:
  512. input_ids = input_ids.to(torch.int32)
  513. dtype, device = input_ids.dtype, input_ids.device
  514. segment = torch.where(input_ids != 0, 2, 0).to(dtype=dtype, device=device)
  515. length = (segment != 0).sum(-1).to(dtype=dtype, device=device)
  516. input_ids = torch.cat(
  517. (
  518. torch.arange(
  519. self.prompt_length * 2 + self.vocab_size,
  520. self.prompt_length * 3 + self.vocab_size,
  521. dtype=dtype,
  522. device=device,
  523. ).repeat(input_ids.size(0), 1),
  524. input_ids,
  525. ),
  526. dim=1,
  527. )
  528. batch, seq_length = input_ids.size()
  529. segment = torch.cat((torch.zeros(batch, self.prompt_length, dtype=dtype, device=device), segment), dim=1)
  530. context = torch.full((batch, seq_length), 1, dtype=dtype, device=device)
  531. position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
  532. span = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
  533. if use_cache and past_key_values is None:
  534. past_key_values = DynamicCache(config=self.config)
  535. past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  536. input_ids = input_ids.contiguous()
  537. hidden_states = self.input_embedding(input_ids)
  538. segment_states = self.segment_embedding(segment)
  539. if past_length != 0:
  540. segment_states = segment_states[:, -1:, :]
  541. hidden_states = hidden_states + segment_states
  542. attention_mask = self._prepare_attention_mask(input_ids, span, context, length)
  543. position_bias = self.position_bias(position, position, segment, segment)
  544. attention_mask = attention_mask[:, past_length:, :]
  545. position_bias = position_bias[:, :, past_length:, :]
  546. hidden_states = hidden_states[:, past_length:, :]
  547. hidden_states, all_hidden_states, all_attentions = self.encoder(
  548. hidden_states,
  549. attention_mask,
  550. position_bias,
  551. output_attentions,
  552. output_hidden_states,
  553. past_key_values,
  554. use_cache,
  555. )
  556. if past_length == 0:
  557. hidden_states = hidden_states[:, self.prompt_length :, :]
  558. # drop the prompt
  559. if all_attentions is not None:
  560. new_attentions = ()
  561. for attention in all_attentions:
  562. new_attentions += (attention[:, :, self.prompt_length :, self.prompt_length :],)
  563. all_attentions = new_attentions
  564. if all_hidden_states is not None:
  565. new_hidden_states = ()
  566. for hidden_state in all_hidden_states:
  567. new_hidden_states += (hidden_state[:, self.prompt_length :, :],)
  568. all_hidden_states = new_hidden_states
  569. if not return_dict:
  570. return tuple(
  571. v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None
  572. )
  573. return BaseModelOutputWithPast(
  574. last_hidden_state=hidden_states,
  575. past_key_values=past_key_values,
  576. hidden_states=all_hidden_states,
  577. attentions=all_attentions,
  578. )
  579. @auto_docstring(
  580. custom_intro="""
  581. The CPMAnt Model with a language modeling head on top (linear layer with weights tied to the input embeddings).
  582. """
  583. )
  584. class CpmAntForCausalLM(CpmAntPreTrainedModel, GenerationMixin):
  585. _tied_weights_keys = {"lm_head.weight": "cpmant.input_embedding.weight"}
  586. def __init__(self, config: CpmAntConfig):
  587. super().__init__(config)
  588. self.cpmant = CpmAntModel(config)
  589. # lm_head.weight is tied to cpmant.input_embedding.weight
  590. self.lm_head = nn.Linear(
  591. config.hidden_size, config.vocab_size + config.prompt_types * config.prompt_length, bias=False
  592. )
  593. self.post_init()
  594. @auto_docstring
  595. def forward(
  596. self,
  597. input_ids: torch.Tensor | None = None,
  598. past_key_values: Cache | None = None,
  599. use_cache: bool | None = None,
  600. output_attentions: bool | None = None,
  601. output_hidden_states: bool | None = None,
  602. labels: torch.Tensor | None = None,
  603. return_dict: bool | None = None,
  604. attention_mask: torch.Tensor | None = None, # dummy parameter for text-generation pipeline
  605. logits_to_keep: int | torch.Tensor = 0,
  606. **kwargs,
  607. ) -> tuple | CausalLMOutputWithPast:
  608. r"""
  609. input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
  610. Indices of input sequence tokens in the vocabulary.
  611. Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  612. [`PreTrainedTokenizer.__call__`] for details.
  613. [What are input IDs?](../glossary#input-ids)
  614. labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  615. Labels for computing the masked language modeling loss.
  616. Example:
  617. Text Generation with CpmAntForCausalLM.
  618. ```python
  619. >>> from transformers import CPMAntTokenizer, CpmAntForCausalLM
  620. >>> texts = "今天天气不错,"
  621. >>> model = CpmAntForCausalLM.from_pretrained("openbmb/cpm-ant-10b")
  622. >>> tokenizer = CPMAntTokenizer.from_pretrained("openbmb/cpm-ant-10b")
  623. >>> input_ids = tokenizer(texts, return_tensors="pt")
  624. >>> outputs = model.generate(**input_ids)
  625. >>> output_texts = tokenizer.batch_decode(outputs)
  626. >>> print(output_texts)
  627. ['今天天气不错,阳光明媚,我和妈妈一起去超市买东西。\n在超市里,我看到了一个很好玩的玩具,它的名字叫“机器人”。它有一个圆圆的脑袋,两只圆圆的眼睛,还有一个圆圆的']
  628. ```
  629. """
  630. return_dict = return_dict if return_dict is not None else self.config.return_dict
  631. model_output = self.cpmant(
  632. input_ids,
  633. output_attentions,
  634. output_hidden_states,
  635. past_key_values,
  636. use_cache,
  637. return_dict,
  638. )
  639. hidden_states = model_output.last_hidden_state if return_dict else model_output[0]
  640. # Only compute necessary logits
  641. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  642. logits = self.lm_head(hidden_states[:, slice_indices, :])
  643. loss = None
  644. if labels is not None:
  645. loss_func = CrossEntropyLoss()
  646. loss = loss_func(logits.view(-1, logits.size(-1)), labels.view(-1))
  647. if not return_dict:
  648. output = (logits,) + model_output[1:]
  649. return ((loss,) + output) if loss is not None else output
  650. return CausalLMOutputWithPast(
  651. loss=loss,
  652. logits=logits,
  653. past_key_values=model_output.past_key_values,
  654. hidden_states=model_output.hidden_states,
  655. attentions=model_output.attentions,
  656. )
  657. def get_input_embeddings(self):
  658. return self.cpmant.input_embedding
  659. def set_input_embeddings(self, embeddings):
  660. self.cpmant.input_embedding = embeddings
# Names exported by `from <module> import *`.
__all__ = ["CpmAntForCausalLM", "CpmAntModel", "CpmAntPreTrainedModel"]