mingpt.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. # @OldAPIStack
  2. # LICENSE: MIT
  3. """
  4. Adapted from https://github.com/karpathy/minGPT
  5. Full definition of a GPT Language Model, all of it in this single file.
  6. References:
  7. 1) the official GPT-2 TensorFlow implementation released by OpenAI:
  8. https://github.com/openai/gpt-2/blob/master/src/model.py
  9. 2) huggingface/transformers PyTorch implementation:
  10. https://github.com/huggingface/transformers/blob/main/src/transformers
  11. /models/gpt2/modeling_gpt2.py
  12. """
  13. import math
  14. from dataclasses import dataclass
  15. from typing import Tuple
  16. import torch
  17. import torch.nn as nn
  18. from torch.nn import functional as F
  19. from ray._common.deprecation import Deprecated
  20. from ray.rllib.utils.annotations import DeveloperAPI
@DeveloperAPI
@dataclass
class GPTConfig:
    """Hyperparameters for the GPT transformer stack defined in this file.

    ``block_size`` has no default and must be supplied; every other field
    falls back to the default shown below.
    """

    # block size must be provided; it sizes the causal mask buffer and is the
    # upper bound on the sequence length accepted by GPT.forward
    block_size: int
    # transformer config
    # number of stacked Transformer blocks
    n_layer: int = 12
    # number of attention heads; n_embed must be divisible by it
    n_head: int = 12
    # embedding width of the model
    n_embed: int = 768
    # dropout config
    # dropout applied to the input embeddings
    embed_pdrop: float = 0.1
    # dropout applied to the attention / MLP outputs before the residual add
    resid_pdrop: float = 0.1
    # dropout applied to the attention probabilities
    attn_pdrop: float = 0.1
  34. @Deprecated(error=False)
  35. class NewGELU(nn.Module):
  36. """
  37. Implementation of the GELU activation function currently in Google BERT
  38. repo (identical to OpenAI GPT).
  39. Reference: Gaussian Error Linear Units (GELU) paper:
  40. https://arxiv.org/abs/1606.08415
  41. """
  42. def forward(self, x):
  43. return (
  44. 0.5
  45. * x
  46. * (
  47. 1.0
  48. + torch.tanh(
  49. math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
  50. )
  51. )
  52. )
  53. @Deprecated(error=False)
  54. class CausalSelfAttention(nn.Module):
  55. """
  56. Vanilla multi-head masked self-attention layer with a projection at the end.
  57. It is possible to use torch.nn.MultiheadAttention here but I am including an
  58. explicit implementation here to show that there is nothing too scary here.
  59. """
  60. def __init__(self, config: GPTConfig):
  61. super().__init__()
  62. assert config.n_embed % config.n_head == 0
  63. # key, query, value projections for all heads, but in a batch
  64. self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)
  65. # output projection
  66. self.c_proj = nn.Linear(config.n_embed, config.n_embed)
  67. # regularization
  68. self.attn_dropout = nn.Dropout(config.attn_pdrop)
  69. self.resid_dropout = nn.Dropout(config.resid_pdrop)
  70. # causal mask to ensure that attention is only applied to the left
  71. # in the input sequence
  72. self.register_buffer(
  73. "bias",
  74. torch.tril(torch.ones(config.block_size, config.block_size)).view(
  75. 1, 1, config.block_size, config.block_size
  76. ),
  77. )
  78. self.n_head = config.n_head
  79. self.n_embed = config.n_embed
  80. def forward(self, x, attention_masks=None):
  81. # batch size, sequence length, embedding dimensionality (n_embed)
  82. B, T, C = x.size()
  83. # calculate query, key, values for all heads in batch and move head
  84. # forward to be the batch dim
  85. q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
  86. # (B, nh, T, hs)
  87. k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
  88. # (B, nh, T, hs)
  89. q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
  90. # (B, nh, T, hs)
  91. v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
  92. # causal self-attention; Self-attend:
  93. # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
  94. att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
  95. att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
  96. if attention_masks is not None:
  97. att = att + attention_masks
  98. att = F.softmax(att, dim=-1)
  99. att = self.attn_dropout(att)
  100. y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
  101. # re-assemble all head outputs side by side
  102. y = y.transpose(1, 2).contiguous().view(B, T, C)
  103. # output projection
  104. y = self.resid_dropout(self.c_proj(y))
  105. return y, att
  106. @Deprecated(error=False)
  107. class Block(nn.Module):
  108. """an unassuming Transformer block"""
  109. def __init__(self, config: GPTConfig):
  110. super().__init__()
  111. self.ln_1 = nn.LayerNorm(config.n_embed)
  112. self.attn = CausalSelfAttention(config)
  113. self.ln_2 = nn.LayerNorm(config.n_embed)
  114. self.mlp = nn.ModuleDict(
  115. dict(
  116. c_fc=nn.Linear(config.n_embed, 4 * config.n_embed),
  117. c_proj=nn.Linear(4 * config.n_embed, config.n_embed),
  118. act=NewGELU(),
  119. dropout=nn.Dropout(config.resid_pdrop),
  120. )
  121. )
  122. def forward(self, x, attention_masks=None):
  123. # Multi-head attention sub-layer.
  124. x_att, att = self.attn(self.ln_1(x), attention_masks=attention_masks)
  125. # Residual of multi-head attention sub-layer.
  126. x = x + x_att
  127. # Position-wise FFN sub-layer: fc + activation + fc + dropout
  128. x_ffn = self.mlp.dropout(self.mlp.c_proj(self.mlp.act(self.mlp.c_fc(x))))
  129. # Residual of position-wise FFN sub-layer.
  130. x = x + x_ffn
  131. return x, att
  132. @Deprecated(error=False)
  133. def configure_gpt_optimizer(
  134. model: nn.Module,
  135. learning_rate: float,
  136. weight_decay: float,
  137. betas: Tuple[float, float] = (0.9, 0.95),
  138. **kwargs,
  139. ) -> torch.optim.Optimizer:
  140. """
  141. This long function is unfortunately doing something very simple and is
  142. being very defensive: We are separating out all parameters of the model
  143. into two buckets: those that will experience weight decay for regularization
  144. and those that won't (biases, and layernorm/embedding weights). We are then
  145. returning the PyTorch optimizer object.
  146. """
  147. # separate out all parameters to those that will and won't experience
  148. # regularizing weight decay
  149. decay = set()
  150. no_decay = set()
  151. whitelist_w_modules = (torch.nn.Linear,)
  152. blacklist_w_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
  153. for mn, m in model.named_modules():
  154. for pn, p in m.named_parameters():
  155. fpn = "%s.%s" % (mn, pn) if mn else pn # full param name
  156. # random note: because named_modules and named_parameters are
  157. # recursive we will see the same tensors p many many times. but
  158. # doing it this way allows us to know which parent module any
  159. # tensor p belongs to...
  160. if pn.endswith("bias"):
  161. # all biases will not be decayed
  162. no_decay.add(fpn)
  163. elif pn.endswith("weight") and isinstance(m, whitelist_w_modules):
  164. # weights of whitelist modules will be weight decayed
  165. decay.add(fpn)
  166. elif pn.endswith("weight") and isinstance(m, blacklist_w_modules):
  167. # weights of blacklist modules will NOT be weight decayed
  168. no_decay.add(fpn)
  169. # validate that we considered every parameter
  170. param_dict = dict(model.named_parameters())
  171. inter_params = decay & no_decay
  172. union_params = decay | no_decay
  173. assert (
  174. len(inter_params) == 0
  175. ), f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
  176. assert len(param_dict.keys() - union_params) == 0, (
  177. f"parameters {str(param_dict.keys() - union_params)} were not "
  178. f"separated into either decay/no_decay set!"
  179. )
  180. # create the pytorch optimizer object
  181. optim_groups = [
  182. {
  183. "params": [param_dict[pn] for pn in sorted(decay)],
  184. "weight_decay": weight_decay,
  185. },
  186. {
  187. "params": [param_dict[pn] for pn in sorted(no_decay)],
  188. "weight_decay": 0.0,
  189. },
  190. ]
  191. optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **kwargs)
  192. return optimizer
  193. @Deprecated(error=False)
  194. class GPT(nn.Module):
  195. """GPT Transformer Model"""
  196. def __init__(self, config: GPTConfig):
  197. super().__init__()
  198. assert config.block_size is not None
  199. self.block_size = config.block_size
  200. self.transformer = nn.ModuleDict(
  201. dict(
  202. drop=nn.Dropout(config.embed_pdrop),
  203. h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
  204. ln_f=nn.LayerNorm(config.n_embed),
  205. )
  206. )
  207. # init all weights, and apply a special scaled init to the residual
  208. # projections, per GPT-2 paper
  209. self.apply(self._init_weights)
  210. for pn, p in self.named_parameters():
  211. if pn.endswith("c_proj.weight"):
  212. torch.nn.init.normal_(
  213. p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
  214. )
  215. def _init_weights(self, module):
  216. if isinstance(module, nn.Linear):
  217. torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
  218. if module.bias is not None:
  219. torch.nn.init.zeros_(module.bias)
  220. elif isinstance(module, nn.Embedding):
  221. torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
  222. elif isinstance(module, nn.LayerNorm):
  223. torch.nn.init.zeros_(module.bias)
  224. torch.nn.init.ones_(module.weight)
  225. def forward(self, input_embeds, attention_masks=None, return_attentions=False):
  226. """
  227. input_embeds: [batch_size x seq_len x n_embed]
  228. attention_masks: [batch_size x seq_len], 0 don't attend, 1 attend
  229. """
  230. B, T, C = input_embeds.size()
  231. assert T <= self.block_size, (
  232. f"Cannot forward sequence of length {T}, "
  233. f"block size is only {self.block_size}"
  234. )
  235. if attention_masks is not None:
  236. _B, _T = attention_masks.size()
  237. assert _B == B and _T == T
  238. # We create a 3D attention mask from a 2D tensor mask.
  239. # Sizes are [batch_size, 1, 1, to_seq_len]
  240. # So we can broadcast to
  241. # [batch_size, num_heads, from_seq_length, to_seq_length]
  242. # this attention mask is more simple than the triangular
  243. # masking of causal attention used in OpenAI GPT, we just need
  244. # to prepare the broadcast dimension here.
  245. attention_masks = attention_masks[:, None, None, :]
  246. # Since attention_mask is 1.0 for positions we want to attend
  247. # and 0.0 for masked positions, this operation will create a
  248. # tensor which is 0.0 for positions we want to attend and -inf
  249. # for masked positions. Since we are adding it to the raw scores
  250. # before the softmax, this is effectively the same as removing
  251. # these entirely.
  252. attention_masks = attention_masks.to(dtype=input_embeds.dtype)
  253. attention_masks = (1.0 - attention_masks) * -1e9
  254. # forward the GPT model itself
  255. x = self.transformer.drop(input_embeds)
  256. atts = []
  257. for block in self.transformer.h:
  258. x, att = block(x, attention_masks=attention_masks)
  259. atts.append(att)
  260. x = self.transformer.ln_f(x)
  261. if return_attentions:
  262. return x, atts
  263. else:
  264. return x