modeling_ctrl.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683
  1. # Copyright 2018 Salesforce and HuggingFace Inc. team.
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch CTRL model."""
  16. import numpy as np
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ... import initialization as init
  21. from ...cache_utils import Cache, DynamicCache
  22. from ...generation import GenerationMixin
  23. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
  24. from ...modeling_utils import PreTrainedModel
  25. from ...utils import (
  26. auto_docstring,
  27. logging,
  28. )
  29. from .configuration_ctrl import CTRLConfig
# Module-level logger; `CTRLForSequenceClassification.forward` uses it for a one-time warning
# about padding detection when only `inputs_embeds` are given.
logger = logging.get_logger(__name__)
  31. def angle_defn(pos, i, d_model_size):
  32. angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / d_model_size)
  33. return pos * angle_rates
  34. def positional_encoding(position, d_model_size, dtype):
  35. # create the sinusoidal pattern for the positional encoding
  36. angle_rads = angle_defn(
  37. torch.arange(position, dtype=torch.int64).to(dtype).unsqueeze(1),
  38. torch.arange(d_model_size, dtype=torch.int64).to(dtype).unsqueeze(0),
  39. d_model_size,
  40. )
  41. sines = torch.sin(angle_rads[:, 0::2])
  42. cosines = torch.cos(angle_rads[:, 1::2])
  43. pos_encoding = torch.cat([sines, cosines], dim=-1)
  44. return pos_encoding
  45. def scaled_dot_product_attention(q, k, v, mask, attention_mask=None):
  46. # calculate attention
  47. matmul_qk = torch.matmul(q, k.permute(0, 1, 3, 2))
  48. dk = k.shape[-1]
  49. scaled_attention_logits = matmul_qk / np.sqrt(dk)
  50. if mask is not None:
  51. nd, ns = scaled_attention_logits.size(-2), scaled_attention_logits.size(-1)
  52. scaled_attention_logits += mask[ns - nd : ns, :ns] * -1e4
  53. if attention_mask is not None:
  54. # Apply the attention mask
  55. scaled_attention_logits = scaled_attention_logits + attention_mask
  56. attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
  57. output = torch.matmul(attention_weights, v)
  58. return output, attention_weights
  59. class MultiHeadAttention(nn.Module):
  60. def __init__(self, d_model_size, num_heads, layer_idx=None):
  61. super().__init__()
  62. self.num_heads = num_heads
  63. self.d_model_size = d_model_size
  64. self.layer_idx = layer_idx
  65. self.depth = int(d_model_size / self.num_heads)
  66. self.Wq = nn.Linear(d_model_size, d_model_size)
  67. self.Wk = nn.Linear(d_model_size, d_model_size)
  68. self.Wv = nn.Linear(d_model_size, d_model_size)
  69. self.dense = nn.Linear(d_model_size, d_model_size)
  70. def split_into_heads(self, x, batch_size):
  71. x = x.reshape(batch_size, -1, self.num_heads, self.depth)
  72. return x.permute([0, 2, 1, 3])
  73. def forward(
  74. self,
  75. v,
  76. k,
  77. q,
  78. mask,
  79. layer_past=None,
  80. attention_mask=None,
  81. use_cache=False,
  82. output_attentions=False,
  83. **kwargs,
  84. ):
  85. batch_size = q.shape[0]
  86. q = self.Wq(q)
  87. k = self.Wk(k)
  88. v = self.Wv(v)
  89. q = self.split_into_heads(q, batch_size)
  90. k = self.split_into_heads(k, batch_size)
  91. v = self.split_into_heads(v, batch_size)
  92. if layer_past is not None:
  93. k, v = layer_past.update(k, v, self.layer_idx)
  94. output = scaled_dot_product_attention(q, k, v, mask, attention_mask)
  95. scaled_attention = output[0].permute([0, 2, 1, 3])
  96. attn = output[1]
  97. original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size)
  98. output = self.dense(original_size_attention)
  99. return output, attn
  100. def point_wise_feed_forward_network(d_model_size, dff):
  101. return nn.Sequential(nn.Linear(d_model_size, dff), nn.ReLU(), nn.Linear(dff, d_model_size))
  102. class EncoderLayer(nn.Module):
  103. def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_idx=None):
  104. super().__init__()
  105. self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads, layer_idx=layer_idx)
  106. self.ffn = point_wise_feed_forward_network(d_model_size, dff)
  107. self.layernorm1 = nn.LayerNorm(d_model_size, eps=1e-6)
  108. self.layernorm2 = nn.LayerNorm(d_model_size, eps=1e-6)
  109. self.dropout1 = nn.Dropout(rate)
  110. self.dropout2 = nn.Dropout(rate)
  111. def forward(
  112. self,
  113. x,
  114. mask,
  115. layer_past=None,
  116. attention_mask=None,
  117. use_cache=False,
  118. output_attentions=False,
  119. **kwargs,
  120. ):
  121. normed = self.layernorm1(x)
  122. attn_outputs = self.multi_head_attention(
  123. normed,
  124. normed,
  125. normed,
  126. mask,
  127. layer_past=layer_past,
  128. attention_mask=attention_mask,
  129. use_cache=use_cache,
  130. output_attentions=output_attentions,
  131. )
  132. attn_output = attn_outputs[0]
  133. attn_output = self.dropout1(attn_output)
  134. out1 = x + attn_output
  135. out2 = self.layernorm2(out1)
  136. ffn_output = self.ffn(out2)
  137. ffn_output = self.dropout2(ffn_output)
  138. out2 = out1 + ffn_output
  139. outputs = (out2,) + attn_outputs[1:]
  140. return outputs
  141. @auto_docstring
  142. class CTRLPreTrainedModel(PreTrainedModel):
  143. config: CTRLConfig
  144. base_model_prefix = "transformer"
  145. def _init_weights(self, module):
  146. super()._init_weights(module)
  147. if isinstance(module, CTRLModel):
  148. init.copy_(
  149. module.pos_encoding, positional_encoding(module.config.n_positions, module.d_model_size, torch.float)
  150. )
  151. @auto_docstring
  152. class CTRLModel(CTRLPreTrainedModel):
  153. def __init__(self, config):
  154. super().__init__(config)
  155. self.d_model_size = config.n_embd
  156. self.num_layers = config.n_layer
  157. self.w = nn.Embedding(config.vocab_size, config.n_embd)
  158. self.dropout = nn.Dropout(config.embd_pdrop)
  159. self.h = nn.ModuleList(
  160. [
  161. EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop, layer_idx=i)
  162. for i in range(config.n_layer)
  163. ]
  164. )
  165. self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
  166. self.register_buffer(
  167. "pos_encoding", positional_encoding(config.n_positions, self.d_model_size, torch.float), persistent=False
  168. )
  169. # Initialize weights and apply final processing
  170. self.post_init()
  171. def get_input_embeddings(self):
  172. return self.w
  173. def set_input_embeddings(self, new_embeddings):
  174. self.w = new_embeddings
  175. @auto_docstring
  176. def forward(
  177. self,
  178. input_ids: torch.LongTensor | None = None,
  179. past_key_values: Cache | None = None,
  180. attention_mask: torch.FloatTensor | None = None,
  181. token_type_ids: torch.LongTensor | None = None,
  182. position_ids: torch.LongTensor | None = None,
  183. inputs_embeds: torch.FloatTensor | None = None,
  184. use_cache: bool | None = None,
  185. output_attentions: bool | None = None,
  186. output_hidden_states: bool | None = None,
  187. return_dict: bool | None = None,
  188. **kwargs, # NOOP kwargs, for now
  189. ) -> tuple[torch.Tensor] | BaseModelOutputWithPast:
  190. r"""
  191. Example:
  192. ```python
  193. >>> from transformers import AutoTokenizer, CTRLModel
  194. >>> import torch
  195. >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
  196. >>> model = CTRLModel.from_pretrained("Salesforce/ctrl")
  197. >>> # CTRL was trained with control codes as the first token
  198. >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
  199. >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
  200. >>> outputs = model(**inputs)
  201. >>> last_hidden_states = outputs.last_hidden_state
  202. >>> list(last_hidden_states.shape)
  203. [1, 5, 1280]
  204. ```"""
  205. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  206. use_cache = use_cache if use_cache is not None else self.config.use_cache
  207. output_hidden_states = (
  208. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  209. )
  210. return_dict = return_dict if return_dict is not None else self.config.return_dict
  211. if input_ids is not None and inputs_embeds is not None:
  212. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  213. elif input_ids is not None:
  214. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  215. input_shape = input_ids.size()
  216. input_ids = input_ids.view(-1, input_shape[-1])
  217. batch_size = input_ids.shape[0]
  218. elif inputs_embeds is not None:
  219. input_shape = inputs_embeds.size()[:-1]
  220. batch_size = inputs_embeds.shape[0]
  221. else:
  222. raise ValueError("You have to specify either input_ids or inputs_embeds")
  223. device = input_ids.device if input_ids is not None else inputs_embeds.device
  224. if use_cache and past_key_values is None:
  225. past_key_values = DynamicCache(config=self.config)
  226. past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  227. if position_ids is None:
  228. position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
  229. position_ids = position_ids.unsqueeze(0)
  230. # Attention mask.
  231. if attention_mask is not None:
  232. if batch_size <= 0:
  233. raise ValueError("batch_size has to be defined and > 0")
  234. attention_mask = attention_mask.view(batch_size, -1)
  235. # We create a 3D attention mask from a 2D tensor mask.
  236. # Sizes are [batch_size, 1, 1, to_seq_length]
  237. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  238. # this attention mask is more simple than the triangular masking of causal attention
  239. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  240. attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  241. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  242. # masked positions, this operation will create a tensor which is 0.0 for
  243. # positions we want to attend and the dtype's smallest value for masked positions.
  244. # Since we are adding it to the raw scores before the softmax, this is
  245. # effectively the same as removing these entirely.
  246. attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
  247. attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
  248. if token_type_ids is not None:
  249. token_type_ids = token_type_ids.view(-1, input_shape[-1])
  250. token_type_embeds = self.w(token_type_ids)
  251. token_type_embeds *= np.sqrt(self.d_model_size)
  252. else:
  253. token_type_embeds = 0
  254. if inputs_embeds is None:
  255. inputs_embeds = self.w(input_ids)
  256. # inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
  257. seq_len = input_shape[-1]
  258. mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1).to(device)
  259. inputs_embeds *= np.sqrt(self.d_model_size)
  260. # `self.pos_encoding` won't be sent to the correct device along the model, so we do it manually.
  261. self.pos_encoding = self.pos_encoding.to(device)
  262. pos_embeds = self.pos_encoding[position_ids, :]
  263. hidden_states = inputs_embeds + pos_embeds + token_type_embeds
  264. hidden_states = self.dropout(hidden_states)
  265. all_hidden_states = () if output_hidden_states else None
  266. all_attentions = () if output_attentions else None
  267. for i, h in enumerate(self.h):
  268. if output_hidden_states:
  269. all_hidden_states = all_hidden_states + (hidden_states,)
  270. outputs = h(
  271. hidden_states,
  272. mask,
  273. layer_past=past_key_values,
  274. attention_mask=attention_mask,
  275. use_cache=use_cache,
  276. output_attentions=output_attentions,
  277. )
  278. hidden_states = outputs[0]
  279. if output_attentions:
  280. all_attentions += (outputs[1],)
  281. hidden_states = self.layernorm(hidden_states)
  282. if output_hidden_states:
  283. all_hidden_states = all_hidden_states + (hidden_states,)
  284. if not return_dict:
  285. return tuple(
  286. v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None
  287. )
  288. return BaseModelOutputWithPast(
  289. last_hidden_state=hidden_states,
  290. past_key_values=past_key_values,
  291. hidden_states=all_hidden_states,
  292. attentions=all_attentions,
  293. )
  294. @auto_docstring(
  295. custom_intro="""
  296. The CTRL Model transformer with a language modeling head on top (linear layer with weights tied to the input
  297. embeddings).
  298. """
  299. )
  300. class CTRLLMHeadModel(CTRLPreTrainedModel, GenerationMixin):
  301. _tied_weights_keys = {"lm_head.weight": "transformer.w.weight"}
  302. def __init__(self, config):
  303. super().__init__(config)
  304. self.transformer = CTRLModel(config)
  305. self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=True)
  306. # Initialize weights and apply final processing
  307. self.post_init()
  308. @auto_docstring
  309. def forward(
  310. self,
  311. input_ids: torch.LongTensor | None = None,
  312. past_key_values: Cache | None = None,
  313. attention_mask: torch.FloatTensor | None = None,
  314. token_type_ids: torch.LongTensor | None = None,
  315. position_ids: torch.LongTensor | None = None,
  316. inputs_embeds: torch.FloatTensor | None = None,
  317. labels: torch.LongTensor | None = None,
  318. use_cache: bool | None = None,
  319. output_attentions: bool | None = None,
  320. output_hidden_states: bool | None = None,
  321. return_dict: bool | None = None,
  322. logits_to_keep: int | torch.Tensor = 0,
  323. **kwargs,
  324. ) -> tuple[torch.Tensor] | CausalLMOutputWithPast:
  325. r"""
  326. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  327. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  328. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  329. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  330. Example:
  331. ```python
  332. >>> import torch
  333. >>> from transformers import AutoTokenizer, CTRLLMHeadModel
  334. >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
  335. >>> model = CTRLLMHeadModel.from_pretrained("Salesforce/ctrl")
  336. >>> # CTRL was trained with control codes as the first token
  337. >>> inputs = tokenizer("Wikipedia The llama is", return_tensors="pt")
  338. >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
  339. >>> sequence_ids = model.generate(inputs["input_ids"])
  340. >>> sequences = tokenizer.batch_decode(sequence_ids)
  341. >>> sequences
  342. ['Wikipedia The llama is a member of the family Bovidae. It is native to the Andes of Peru,']
  343. >>> outputs = model(**inputs, labels=inputs["input_ids"])
  344. >>> round(outputs.loss.item(), 2)
  345. 9.21
  346. >>> list(outputs.logits.shape)
  347. [1, 5, 246534]
  348. ```"""
  349. return_dict = return_dict if return_dict is not None else self.config.return_dict
  350. transformer_outputs = self.transformer(
  351. input_ids,
  352. past_key_values=past_key_values,
  353. attention_mask=attention_mask,
  354. token_type_ids=token_type_ids,
  355. position_ids=position_ids,
  356. inputs_embeds=inputs_embeds,
  357. use_cache=use_cache,
  358. output_attentions=output_attentions,
  359. output_hidden_states=output_hidden_states,
  360. return_dict=return_dict,
  361. )
  362. hidden_states = transformer_outputs[0]
  363. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  364. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  365. logits = self.lm_head(hidden_states[:, slice_indices, :])
  366. loss = None
  367. if labels is not None:
  368. loss = self.loss_function(
  369. logits,
  370. labels,
  371. vocab_size=self.config.vocab_size,
  372. **kwargs,
  373. )
  374. if not return_dict:
  375. output = (logits,) + transformer_outputs[1:]
  376. return ((loss,) + output) if loss is not None else output
  377. return CausalLMOutputWithPast(
  378. loss=loss,
  379. logits=logits,
  380. past_key_values=transformer_outputs.past_key_values,
  381. hidden_states=transformer_outputs.hidden_states,
  382. attentions=transformer_outputs.attentions,
  383. )
  384. def prepare_inputs_for_generation(
  385. self, input_ids, past_key_values=None, use_cache=None, is_first_iteration=False, **kwargs
  386. ):
  387. # Overwritten -- `token_type_ids` are created in custom way inside model`
  388. model_inputs = super().prepare_inputs_for_generation(
  389. input_ids,
  390. past_key_values=past_key_values,
  391. use_cache=use_cache,
  392. is_first_iteration=is_first_iteration,
  393. **kwargs,
  394. )
  395. # token_type_ids are computed on CTRLModel.forward()
  396. model_inputs.pop("token_type_ids", None)
  397. return model_inputs
  398. @auto_docstring(
  399. custom_intro="""
  400. The CTRL Model transformer with a sequence classification head on top (linear layer).
  401. [`CTRLForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  402. (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last
  403. token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in
  404. each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
  405. guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last
  406. value in each row of the batch).
  407. """
  408. )
  409. class CTRLForSequenceClassification(CTRLPreTrainedModel):
  410. def __init__(self, config):
  411. super().__init__(config)
  412. self.num_labels = config.num_labels
  413. self.transformer = CTRLModel(config)
  414. self.classifier = nn.Linear(config.n_embd, self.num_labels, bias=False)
  415. # Initialize weights and apply final processing
  416. self.post_init()
  417. @auto_docstring
  418. def forward(
  419. self,
  420. input_ids: torch.LongTensor | None = None,
  421. past_key_values: Cache | None = None,
  422. attention_mask: torch.FloatTensor | None = None,
  423. token_type_ids: torch.LongTensor | None = None,
  424. position_ids: torch.LongTensor | None = None,
  425. inputs_embeds: torch.FloatTensor | None = None,
  426. labels: torch.LongTensor | None = None,
  427. use_cache: bool | None = None,
  428. output_attentions: bool | None = None,
  429. output_hidden_states: bool | None = None,
  430. return_dict: bool | None = None,
  431. **kwargs,
  432. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  433. r"""
  434. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  435. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  436. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  437. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  438. Example of single-label classification:
  439. ```python
  440. >>> import torch
  441. >>> from transformers import AutoTokenizer, CTRLForSequenceClassification
  442. >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
  443. >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl")
  444. >>> # CTRL was trained with control codes as the first token
  445. >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
  446. >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
  447. >>> with torch.no_grad():
  448. ... logits = model(**inputs).logits
  449. >>> predicted_class_id = logits.argmax().item()
  450. >>> model.config.id2label[predicted_class_id]
  451. 'LABEL_0'
  452. ```
  453. ```python
  454. >>> import torch
  455. >>> torch.manual_seed(42) # doctest: +IGNORE_RESULT
  456. >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
  457. >>> num_labels = len(model.config.id2label)
  458. >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)
  459. >>> labels = torch.tensor(1)
  460. >>> loss = model(**inputs, labels=labels).loss
  461. >>> round(loss.item(), 2)
  462. 0.93
  463. ```
  464. Example of multi-label classification:
  465. ```python
  466. >>> import torch
  467. >>> from transformers import AutoTokenizer, CTRLForSequenceClassification
  468. >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
  469. >>> model = CTRLForSequenceClassification.from_pretrained(
  470. ... "Salesforce/ctrl", problem_type="multi_label_classification"
  471. ... )
  472. >>> # CTRL was trained with control codes as the first token
  473. >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
  474. >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
  475. >>> with torch.no_grad():
  476. ... logits = model(**inputs).logits
  477. >>> predicted_class_id = logits.argmax().item()
  478. >>> model.config.id2label[predicted_class_id]
  479. 'LABEL_0'
  480. ```
  481. ```python
  482. >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
  483. >>> num_labels = len(model.config.id2label)
  484. >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)
  485. >>> num_labels = len(model.config.id2label)
  486. >>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
  487. ... torch.float
  488. ... )
  489. >>> loss = model(**inputs, labels=labels).loss
  490. >>> loss.backward() # doctest: +IGNORE_RESULT
  491. ```"""
  492. return_dict = return_dict if return_dict is not None else self.config.return_dict
  493. transformer_outputs = self.transformer(
  494. input_ids,
  495. past_key_values=past_key_values,
  496. attention_mask=attention_mask,
  497. token_type_ids=token_type_ids,
  498. position_ids=position_ids,
  499. inputs_embeds=inputs_embeds,
  500. use_cache=use_cache,
  501. output_attentions=output_attentions,
  502. output_hidden_states=output_hidden_states,
  503. return_dict=return_dict,
  504. )
  505. hidden_states = transformer_outputs[0]
  506. logits = self.classifier(hidden_states)
  507. if input_ids is not None:
  508. batch_size, sequence_length = input_ids.shape[:2]
  509. else:
  510. batch_size, sequence_length = inputs_embeds.shape[:2]
  511. if self.config.pad_token_id is None and batch_size != 1:
  512. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  513. if self.config.pad_token_id is None:
  514. last_non_pad_token = -1
  515. elif input_ids is not None:
  516. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  517. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  518. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  519. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  520. else:
  521. last_non_pad_token = -1
  522. logger.warning_once(
  523. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  524. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  525. )
  526. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  527. loss = None
  528. if labels is not None:
  529. if self.config.problem_type is None:
  530. if self.num_labels == 1:
  531. self.config.problem_type = "regression"
  532. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  533. self.config.problem_type = "single_label_classification"
  534. else:
  535. self.config.problem_type = "multi_label_classification"
  536. if self.config.problem_type == "regression":
  537. loss_fct = MSELoss()
  538. if self.num_labels == 1:
  539. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  540. else:
  541. loss = loss_fct(pooled_logits, labels)
  542. elif self.config.problem_type == "single_label_classification":
  543. loss_fct = CrossEntropyLoss()
  544. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  545. elif self.config.problem_type == "multi_label_classification":
  546. loss_fct = BCEWithLogitsLoss()
  547. loss = loss_fct(pooled_logits, labels)
  548. if not return_dict:
  549. output = (pooled_logits,) + transformer_outputs[2:]
  550. return ((loss,) + output) if loss is not None else output
  551. return SequenceClassifierOutput(
  552. loss=loss,
  553. logits=pooled_logits,
  554. hidden_states=transformer_outputs.hidden_states,
  555. attentions=transformer_outputs.attentions,
  556. )
# Public names exported by this module.
__all__ = ["CTRLForSequenceClassification", "CTRLLMHeadModel", "CTRLModel", "CTRLPreTrainedModel"]