modeling_codegen.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. # Copyright 2022 Salesforce authors, The EleutherAI, and HuggingFace Teams. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch CodeGen model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from ... import initialization as init
  19. from ...activations import ACT2FN
  20. from ...cache_utils import Cache, DynamicCache
  21. from ...generation import GenerationMixin
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  25. from ...modeling_utils import PreTrainedModel
  26. from ...utils import (
  27. auto_docstring,
  28. logging,
  29. )
  30. from .configuration_codegen import CodeGenConfig
# Module-level logger; the classes below use it through `logger.warning_once(...)`.
logger = logging.get_logger(__name__)
  32. # Copied from transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
  33. def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
  34. inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
  35. sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
  36. return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
  37. # Copied from transformers.models.gptj.modeling_gptj.rotate_every_two
  38. def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
  39. x1 = x[:, :, :, ::2]
  40. x2 = x[:, :, :, 1::2]
  41. x = torch.stack((-x2, x1), dim=-1)
  42. return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
  43. # Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
  44. def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
  45. sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
  46. cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
  47. return (tensor * cos) + (rotate_every_two(tensor) * sin)
class CodeGenAttention(nn.Module):
    """Self-attention layer with rotary position embeddings.

    One bias-free linear layer produces a fused projection that is unpacked into query, value and key (in that
    order). Rotary embeddings are applied to the first `rotary_dim` features of every query/key head, or to all
    features when `config.rotary_dim` is `None`. Attention scores are computed in float32.
    """

    def __init__(self, config, layer_idx=None):
        """Create the projections, dropouts and the non-persistent sinusoidal position table.

        Args:
            config: Configuration object; this constructor reads `max_position_embeddings`, `attn_pdrop`,
                `resid_pdrop`, `hidden_size`, `num_attention_heads` and `rotary_dim` from it.
            layer_idx: Index handed to the cache in `forward`. A warning is logged when it is `None`.

        Raises:
            ValueError: If `hidden_size` is not divisible by `num_attention_heads`.
        """
        super().__init__()

        self.max_positions = config.max_position_embeddings
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.embed_dim = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_attention_heads
        if self.head_dim * self.num_attention_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                f" `num_attention_heads`: {self.num_attention_heads})."
            )
        # Divisor applied to the raw attention scores in `_attn`.
        self.scale_attn = math.sqrt(self.head_dim)
        # Single projection of width 3 * embed_dim; it is split into query/value/key in `forward`.
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.rotary_dim = config.rotary_dim
        # Width of the sinusoidal table: `rotary_dim` when it is truthy, otherwise the full embedding size.
        self.pos_embd_dim = self.rotary_dim or self.embed_dim
        # persistent=False: the table is registered as a buffer but flagged as non-persistent.
        self.register_buffer(
            "embed_positions", create_sinusoidal_positions(self.max_positions, self.pos_embd_dim), persistent=False
        )

    def _split_heads(self, x: torch.Tensor, n_head: int, dim_head: int, mp_num: int) -> torch.Tensor:
        """Split the last axis of `x` into heads and fold the partition axis into the head axis.

        The last axis is first reshaped into `(n_head // mp_num, dim_head)`; the axis that preceded it in `x` is then
        merged with the new head axis, leaving `dim_head` as the last axis.
        """
        reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
        reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
        return reshaped

    def _merge_heads(self, tensor: torch.Tensor, num_attention_heads: int, attn_head_size: int) -> torch.Tensor:
        """
        Moves the head axis next to the per-head feature axis and merges the two into a single trailing axis of size
        `num_attention_heads * attn_head_size`.

        Raises:
            ValueError: If `tensor` is neither rank 4 nor rank 5.
        """
        if len(tensor.shape) == 5:
            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
        elif len(tensor.shape) == 4:
            tensor = tensor.permute(0, 2, 1, 3).contiguous()
        else:
            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
        return tensor.view(new_shape)

    def _attn(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute softmax attention and return `(attn_output, attn_weights)`.

        `attention_mask`, when given, is added to the raw `query @ key^T` scores before they are divided by
        `self.scale_attn`. The returned weights are the post-dropout weights, cast to `value.dtype`.
        """
        # Keep the attention weights computation in fp32 to avoid overflow issues
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if attention_mask is not None:
            # Additive mask: applied before the scaling below.
            attn_weights = attn_weights + attention_mask

        attn_weights = attn_weights / self.scale_attn
        attn_weights = nn.Softmax(dim=-1)(attn_weights)
        # Cast back so the product with `value` runs in the value's dtype.
        attn_weights = attn_weights.to(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def forward(
        self,
        hidden_states: torch.FloatTensor | None,
        layer_past: Cache | None = None,
        attention_mask: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        use_cache: bool | None = False,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run self-attention over `hidden_states`.

        Args:
            hidden_states: Input to the fused projection. The 4-index permutes below require it to be rank 3;
                presumably `(batch, seq_len, embed_dim)` - TODO confirm against the caller.
            layer_past: Optional cache. When given, its `update(key, value, layer_idx)` result replaces the key and
                value used for attention.
            attention_mask: Optional additive mask forwarded to `_attn`.
            position_ids: Indices into the sinusoidal table. Despite the `None` default, this method reads
                `position_ids.device`, so a tensor has to be supplied.
            use_cache: Accepted for signature compatibility; not read in this method.
            output_attentions: Accepted for signature compatibility; not read in this method.

        Returns:
            `(attn_output, attn_weights)`: the output after `out_proj` and residual dropout, and the attention
            weights returned by `_attn`.
        """
        qkv = self.qkv_proj(hidden_states)
        # TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
        mp_num = 4
        # Split the fused projection into `mp_num` partitions along a new axis.
        qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))

        local_dim = self.head_dim * self.num_attention_heads // mp_num
        # NOTE: within each partition the chunks are unpacked in query, value, key order.
        query, value, key = torch.split(qkv_split, local_dim, dim=-1)
        query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
        key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)

        value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
        # Swap axes 1 and 2 of value now; query and key are swapped after the rotary embedding is applied.
        value = value.permute(0, 2, 1, 3)

        embed_positions = self.embed_positions
        if embed_positions.device != position_ids.device:
            # Move the table to the device of `position_ids` and store it back on the module for later calls.
            embed_positions = embed_positions.to(position_ids.device)
            self.embed_positions = embed_positions

        # Look up one table row per position, then split each row into its sine half and cosine half.
        sincos = embed_positions[position_ids]
        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

        if self.rotary_dim is not None:
            # Partial rotary: rotate only the first `rotary_dim` features and pass the rest through unchanged.
            k_rot = key[:, :, :, : self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim :]

            q_rot = query[:, :, :, : self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim :]

            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)
        else:
            key = apply_rotary_pos_emb(key, sin, cos)
            query = apply_rotary_pos_emb(query, sin, cos)

        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)

        # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
        # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
        if layer_past is not None:
            key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx)

        # compute self-attention: V x Softmax(QK^T)
        attn_output, attn_weights = self._attn(query, key, value, attention_mask)

        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        return attn_output, attn_weights
  165. # Copied from transformers.models.gptj.modeling_gptj.GPTJMLP with GPTJ->CodeGen
  166. class CodeGenMLP(nn.Module):
  167. def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
  168. super().__init__()
  169. embed_dim = config.n_embd
  170. self.fc_in = nn.Linear(embed_dim, intermediate_size)
  171. self.fc_out = nn.Linear(intermediate_size, embed_dim)
  172. self.act = ACT2FN[config.activation_function]
  173. self.dropout = nn.Dropout(config.resid_pdrop)
  174. def forward(self, hidden_states: torch.FloatTensor | None) -> torch.FloatTensor:
  175. hidden_states = self.fc_in(hidden_states)
  176. hidden_states = self.act(hidden_states)
  177. hidden_states = self.fc_out(hidden_states)
  178. hidden_states = self.dropout(hidden_states)
  179. return hidden_states
  180. # Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
  181. class CodeGenBlock(GradientCheckpointingLayer):
  182. # Ignore copy
  183. def __init__(self, config, layer_idx=None):
  184. super().__init__()
  185. inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
  186. self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
  187. self.attn = CodeGenAttention(config, layer_idx)
  188. self.mlp = CodeGenMLP(inner_dim, config)
  189. def forward(
  190. self,
  191. hidden_states: torch.FloatTensor | None,
  192. layer_past: Cache | None = None,
  193. attention_mask: torch.FloatTensor | None = None,
  194. position_ids: torch.LongTensor | None = None,
  195. use_cache: bool | None = False,
  196. output_attentions: bool | None = False,
  197. **kwargs,
  198. ) -> tuple[torch.Tensor] | tuple[torch.Tensor, tuple[torch.FloatTensor, ...]] | None:
  199. residual = hidden_states
  200. hidden_states = self.ln_1(hidden_states)
  201. attn_outputs, attn_weights = self.attn(
  202. hidden_states=hidden_states,
  203. layer_past=layer_past,
  204. attention_mask=attention_mask,
  205. position_ids=position_ids,
  206. use_cache=use_cache,
  207. output_attentions=output_attentions,
  208. )
  209. feed_forward_hidden_states = self.mlp(hidden_states)
  210. hidden_states = attn_outputs + feed_forward_hidden_states + residual
  211. return hidden_states, attn_weights
  212. @auto_docstring
  213. class CodeGenPreTrainedModel(PreTrainedModel):
  214. config: CodeGenConfig
  215. base_model_prefix = "transformer"
  216. supports_gradient_checkpointing = True
  217. _no_split_modules = ["CodeGenBlock"]
  218. _skip_keys_device_placement = "past_key_values"
  219. _can_compile_fullgraph = True
  220. def _init_weights(self, module):
  221. super()._init_weights(module)
  222. if isinstance(module, CodeGenAttention):
  223. init.copy_(module.embed_positions, create_sinusoidal_positions(module.max_positions, module.pos_embd_dim))
  224. @auto_docstring
  225. class CodeGenModel(CodeGenPreTrainedModel):
  226. def __init__(self, config):
  227. super().__init__(config)
  228. self.embed_dim = config.n_embd
  229. self.vocab_size = config.vocab_size
  230. self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
  231. self.drop = nn.Dropout(config.embd_pdrop)
  232. self.h = nn.ModuleList([CodeGenBlock(config, layer_idx=i) for i in range(config.n_layer)])
  233. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  234. self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
  235. self.gradient_checkpointing = False
  236. # Initialize weights and apply final processing
  237. self.post_init()
  238. def get_input_embeddings(self):
  239. return self.wte
  240. def set_input_embeddings(self, new_embeddings):
  241. self.wte = new_embeddings
  242. @auto_docstring
  243. def forward(
  244. self,
  245. input_ids: torch.LongTensor | None = None,
  246. past_key_values: Cache | None = None,
  247. attention_mask: torch.FloatTensor | None = None,
  248. token_type_ids: torch.LongTensor | None = None,
  249. position_ids: torch.LongTensor | None = None,
  250. inputs_embeds: torch.FloatTensor | None = None,
  251. use_cache: bool | None = None,
  252. output_attentions: bool | None = None,
  253. output_hidden_states: bool | None = None,
  254. return_dict: bool | None = None,
  255. **kwargs, # NOOP kwargs, for now
  256. ) -> tuple | BaseModelOutputWithPast:
  257. r"""
  258. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
  259. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  260. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  261. model's internal embedding lookup matrix.
  262. """
  263. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  264. output_hidden_states = (
  265. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  266. )
  267. use_cache = use_cache if use_cache is not None else self.config.use_cache
  268. return_dict = return_dict if return_dict is not None else self.config.return_dict
  269. if (input_ids is None) ^ (inputs_embeds is not None):
  270. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  271. if self.gradient_checkpointing and self.training:
  272. if use_cache:
  273. logger.warning_once(
  274. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  275. )
  276. use_cache = False
  277. if inputs_embeds is None:
  278. inputs_embeds = self.wte(input_ids)
  279. if use_cache and past_key_values is None:
  280. past_key_values = DynamicCache(config=self.config)
  281. seq_length = inputs_embeds.shape[1]
  282. if position_ids is None:
  283. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  284. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  285. position_ids = position_ids.unsqueeze(0)
  286. causal_mask = create_causal_mask(
  287. config=self.config,
  288. inputs_embeds=inputs_embeds,
  289. attention_mask=attention_mask,
  290. past_key_values=past_key_values,
  291. position_ids=position_ids,
  292. )
  293. hidden_states = inputs_embeds
  294. if token_type_ids is not None:
  295. token_type_ids = token_type_ids.view(-1, seq_length)
  296. token_type_embeds = self.wte(token_type_ids)
  297. hidden_states = hidden_states + token_type_embeds
  298. hidden_states = self.drop(hidden_states)
  299. output_shape = (-1, seq_length, hidden_states.size(-1))
  300. all_self_attentions = () if output_attentions else None
  301. all_hidden_states = () if output_hidden_states else None
  302. for i, block in enumerate(self.h):
  303. if output_hidden_states:
  304. all_hidden_states = all_hidden_states + (hidden_states,)
  305. outputs = block(
  306. hidden_states,
  307. layer_past=past_key_values,
  308. attention_mask=causal_mask,
  309. position_ids=position_ids,
  310. use_cache=use_cache,
  311. output_attentions=output_attentions,
  312. )
  313. hidden_states = outputs[0]
  314. if output_attentions:
  315. all_self_attentions = all_self_attentions + (outputs[1],)
  316. hidden_states = self.ln_f(hidden_states)
  317. hidden_states = hidden_states.view(output_shape)
  318. # Add last hidden state
  319. if output_hidden_states:
  320. all_hidden_states = all_hidden_states + (hidden_states,)
  321. if not return_dict:
  322. return tuple(
  323. v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
  324. )
  325. return BaseModelOutputWithPast(
  326. last_hidden_state=hidden_states,
  327. past_key_values=past_key_values,
  328. hidden_states=all_hidden_states,
  329. attentions=all_self_attentions,
  330. )
  331. @auto_docstring(
  332. custom_intro="""
  333. The CodeGen Model transformer with a language modeling head on top.
  334. """
  335. )
  336. class CodeGenForCausalLM(CodeGenPreTrainedModel, GenerationMixin):
  337. _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}
  338. def __init__(self, config):
  339. super().__init__(config)
  340. self.transformer = CodeGenModel(config)
  341. self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
  342. # Initialize weights and apply final processing
  343. self.post_init()
  344. @auto_docstring
  345. def forward(
  346. self,
  347. input_ids: torch.LongTensor | None = None,
  348. past_key_values: Cache | None = None,
  349. attention_mask: torch.FloatTensor | None = None,
  350. token_type_ids: torch.LongTensor | None = None,
  351. position_ids: torch.LongTensor | None = None,
  352. inputs_embeds: torch.FloatTensor | None = None,
  353. labels: torch.LongTensor | None = None,
  354. use_cache: bool | None = None,
  355. output_attentions: bool | None = None,
  356. output_hidden_states: bool | None = None,
  357. return_dict: bool | None = None,
  358. logits_to_keep: int | torch.Tensor = 0,
  359. **kwargs,
  360. ) -> tuple | CausalLMOutputWithPast:
  361. r"""
  362. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
  363. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  364. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  365. model's internal embedding lookup matrix.
  366. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  367. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  368. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  369. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  370. """
  371. return_dict = return_dict if return_dict is not None else self.config.return_dict
  372. transformer_outputs = self.transformer(
  373. input_ids,
  374. past_key_values=past_key_values,
  375. attention_mask=attention_mask,
  376. token_type_ids=token_type_ids,
  377. position_ids=position_ids,
  378. inputs_embeds=inputs_embeds,
  379. use_cache=use_cache,
  380. output_attentions=output_attentions,
  381. output_hidden_states=output_hidden_states,
  382. return_dict=return_dict,
  383. )
  384. hidden_states = transformer_outputs[0]
  385. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  386. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  387. logits = self.lm_head(hidden_states[:, slice_indices, :])
  388. loss = None
  389. if labels is not None:
  390. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  391. if not return_dict:
  392. output = (logits,) + transformer_outputs[1:]
  393. return ((loss,) + output) if loss is not None else output
  394. return CausalLMOutputWithPast(
  395. loss=loss,
  396. logits=logits,
  397. past_key_values=transformer_outputs.past_key_values,
  398. hidden_states=transformer_outputs.hidden_states,
  399. attentions=transformer_outputs.attentions,
  400. )
# Public names exported by a star-import of this module.
__all__ = ["CodeGenForCausalLM", "CodeGenModel", "CodeGenPreTrainedModel"]