watermarking.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. # Copyright 2024 The HuggingFace Inc. team and Google DeepMind.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import collections
  15. from dataclasses import dataclass
  16. from functools import lru_cache
  17. from typing import TYPE_CHECKING, Any, Union
  18. import numpy as np
  19. import torch
  20. from torch import nn
  21. from torch.nn import BCELoss
  22. from .. import initialization as init
  23. from ..configuration_utils import PreTrainedConfig
  24. from ..modeling_utils import PreTrainedModel
  25. from ..utils import ModelOutput, logging
  26. from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
  27. if TYPE_CHECKING:
  28. from .configuration_utils import WatermarkingConfig
  29. logger = logging.get_logger(__name__)
  30. @dataclass
  31. class WatermarkDetectorOutput:
  32. """
  33. Outputs of a watermark detector.
  34. Args:
  35. num_tokens_scored (np.ndarray of shape (batch_size)):
  36. Array containing the number of tokens scored for each element in the batch.
  37. num_green_tokens (np.ndarray of shape (batch_size)):
  38. Array containing the number of green tokens for each element in the batch.
  39. green_fraction (np.ndarray of shape (batch_size)):
  40. Array containing the fraction of green tokens for each element in the batch.
  41. z_score (np.ndarray of shape (batch_size)):
  42. Array containing the z-score for each element in the batch. Z-score here shows
  43. how many standard deviations away is the green token count in the input text
  44. from the expected green token count for machine-generated text.
  45. p_value (np.ndarray of shape (batch_size)):
  46. Array containing the p-value for each batch obtained from z-scores.
  47. prediction (np.ndarray of shape (batch_size)), *optional*:
  48. Array containing boolean predictions whether a text is machine-generated for each element in the batch.
  49. confidence (np.ndarray of shape (batch_size)), *optional*:
  50. Array containing confidence scores of a text being machine-generated for each element in the batch.
  51. """
  52. num_tokens_scored: np.ndarray | None = None
  53. num_green_tokens: np.ndarray | None = None
  54. green_fraction: np.ndarray | None = None
  55. z_score: np.ndarray | None = None
  56. p_value: np.ndarray | None = None
  57. prediction: np.ndarray | None = None
  58. confidence: np.ndarray | None = None
  59. class WatermarkDetector:
  60. r"""
  61. Detector for detection of watermark generated text. The detector needs to be given the exact same settings that were
  62. given during text generation to replicate the watermark greenlist generation and so detect the watermark. This includes
  63. the correct device that was used during text generation, the correct watermarking arguments and the correct tokenizer vocab size.
  64. The code was based on the [original repo](https://github.com/jwkirchenbauer/lm-watermarking/tree/main).
  65. See [the paper](https://huggingface.co/papers/2306.04634) for more information.
  66. Args:
  67. model_config (`PreTrainedConfig`):
  68. The model config that will be used to get model specific arguments used when generating.
  69. device (`str`):
  70. The device which was used during watermarked text generation.
  71. watermarking_config (Union[`WatermarkingConfig`, `Dict`]):
  72. The exact same watermarking config and arguments used when generating text.
  73. ignore_repeated_ngrams (`bool`, *optional*, defaults to `False`):
  74. Whether to count every unique ngram only once or not.
  75. max_cache_size (`int`, *optional*, defaults to 128):
  76. The max size to be used for LRU caching of seeding/sampling algorithms called for every token.
  77. Examples:
  78. ```python
  79. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkDetector, WatermarkingConfig
  80. >>> model_id = "openai-community/gpt2"
  81. >>> model = AutoModelForCausalLM.from_pretrained(model_id)
  82. >>> tok = AutoTokenizer.from_pretrained(model_id)
  83. >>> tok.pad_token_id = tok.eos_token_id
  84. >>> tok.padding_side = "left"
  85. >>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt")
  86. >>> input_len = inputs["input_ids"].shape[-1]
  87. >>> # first generate text with watermark and without
  88. >>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
  89. >>> out_watermarked = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_length=20)
  90. >>> out = model.generate(**inputs, do_sample=False, max_length=20)
  91. >>> # now we can instantiate the detector and check the generated text
  92. >>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
  93. >>> detection_out_watermarked = detector(out_watermarked, return_dict=True)
  94. >>> detection_out = detector(out, return_dict=True)
  95. >>> detection_out_watermarked.prediction
  96. array([ True, True])
  97. >>> detection_out.prediction
  98. array([False, False])
  99. ```
  100. """
  101. def __init__(
  102. self,
  103. model_config: "PreTrainedConfig",
  104. device: str,
  105. watermarking_config: Union["WatermarkingConfig", dict],
  106. ignore_repeated_ngrams: bool = False,
  107. max_cache_size: int = 128,
  108. ):
  109. if not isinstance(watermarking_config, dict):
  110. watermarking_config = watermarking_config.to_dict()
  111. self.bos_token_id = (
  112. model_config.bos_token_id if not model_config.is_encoder_decoder else model_config.decoder_start_token_id
  113. )
  114. self.greenlist_ratio = watermarking_config["greenlist_ratio"]
  115. self.ignore_repeated_ngrams = ignore_repeated_ngrams
  116. self.processor = WatermarkLogitsProcessor(
  117. vocab_size=model_config.vocab_size, device=device, **watermarking_config
  118. )
  119. # Expensive re-seeding and sampling is cached.
  120. self._get_ngram_score_cached = lru_cache(maxsize=max_cache_size)(self._get_ngram_score)
  121. def _get_ngram_score(self, prefix: torch.LongTensor, target: int):
  122. greenlist_ids = self.processor._get_greenlist_ids(prefix)
  123. return target in greenlist_ids
  124. def _score_ngrams_in_passage(self, input_ids: torch.LongTensor):
  125. batch_size, seq_length = input_ids.shape
  126. selfhash = int(self.processor.seeding_scheme == "selfhash")
  127. n = self.processor.context_width + 1 - selfhash
  128. indices = torch.arange(n).unsqueeze(0) + torch.arange(seq_length - n + 1).unsqueeze(1)
  129. ngram_tensors = input_ids[:, indices]
  130. num_tokens_scored_batch = np.zeros(batch_size)
  131. green_token_count_batch = np.zeros(batch_size)
  132. for batch_idx in range(ngram_tensors.shape[0]):
  133. frequencies_table = collections.Counter(ngram_tensors[batch_idx])
  134. ngram_to_watermark_lookup = {}
  135. for ngram_example in frequencies_table:
  136. prefix = ngram_example if selfhash else ngram_example[:-1]
  137. target = ngram_example[-1]
  138. ngram_to_watermark_lookup[ngram_example] = self._get_ngram_score_cached(prefix, target)
  139. if self.ignore_repeated_ngrams:
  140. # counts a green/red hit once per unique ngram.
  141. # num total tokens scored becomes the number unique ngrams.
  142. num_tokens_scored_batch[batch_idx] = len(frequencies_table.keys())
  143. green_token_count_batch[batch_idx] = sum(ngram_to_watermark_lookup.values())
  144. else:
  145. num_tokens_scored_batch[batch_idx] = sum(frequencies_table.values())
  146. green_token_count_batch[batch_idx] = sum(
  147. freq * outcome
  148. for freq, outcome in zip(frequencies_table.values(), ngram_to_watermark_lookup.values())
  149. )
  150. return num_tokens_scored_batch, green_token_count_batch
  151. def _compute_z_score(self, green_token_count: np.ndarray, total_num_tokens: np.ndarray) -> np.ndarray:
  152. expected_count = self.greenlist_ratio
  153. numer = green_token_count - expected_count * total_num_tokens
  154. denom = np.sqrt(total_num_tokens * expected_count * (1 - expected_count))
  155. z = numer / denom
  156. return z
  157. def _compute_pval(self, x, loc=0, scale=1):
  158. z = (x - loc) / scale
  159. return 1 - (0.5 * (1 + np.sign(z) * (1 - np.exp(-2 * z**2 / np.pi))))
  160. def __call__(
  161. self,
  162. input_ids: torch.LongTensor,
  163. z_threshold: float = 3.0,
  164. return_dict: bool = False,
  165. ) -> WatermarkDetectorOutput | np.ndarray:
  166. """
  167. Args:
  168. input_ids (`torch.LongTensor`):
  169. The watermark generated text. It is advised to remove the prompt, which can affect the detection.
  170. z_threshold (`Dict`, *optional*, defaults to `3.0`):
  171. Changing this threshold will change the sensitivity of the detector. Higher z threshold gives less
  172. sensitivity and vice versa for lower z threshold.
  173. return_dict (`bool`, *optional*, defaults to `False`):
  174. Whether to return `~generation.WatermarkDetectorOutput` or not. If not it will return boolean predictions,
  175. ma
  176. Return:
  177. [`~generation.WatermarkDetectorOutput`] or `np.ndarray`: A [`~generation.WatermarkDetectorOutput`]
  178. if `return_dict=True` otherwise a `np.ndarray`.
  179. """
  180. # Let's assume that if one batch start with `bos`, all batched also do
  181. if input_ids[0, 0] == self.bos_token_id:
  182. input_ids = input_ids[:, 1:]
  183. if input_ids.shape[-1] - self.processor.context_width < 1:
  184. raise ValueError(
  185. f"Must have at least `1` token to score after the first "
  186. f"min_prefix_len={self.processor.context_width} tokens required by the seeding scheme."
  187. )
  188. num_tokens_scored, green_token_count = self._score_ngrams_in_passage(input_ids)
  189. z_score = self._compute_z_score(green_token_count, num_tokens_scored)
  190. prediction = z_score > z_threshold
  191. if return_dict:
  192. p_value = self._compute_pval(z_score)
  193. confidence = 1 - p_value
  194. return WatermarkDetectorOutput(
  195. num_tokens_scored=num_tokens_scored,
  196. num_green_tokens=green_token_count,
  197. green_fraction=green_token_count / num_tokens_scored,
  198. z_score=z_score,
  199. p_value=p_value,
  200. prediction=prediction,
  201. confidence=confidence,
  202. )
  203. return prediction
  204. class BayesianDetectorConfig(PreTrainedConfig):
  205. """
  206. This is the configuration class to store the configuration of a [`BayesianDetectorModel`]. It is used to
  207. instantiate a Bayesian Detector model according to the specified arguments.
  208. Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
  209. documentation from [`PreTrainedConfig`] for more information.
  210. Args:
  211. watermarking_depth (`int`, *optional*):
  212. The number of tournament layers.
  213. base_rate (`float1`, *optional*, defaults to 0.5):
  214. Prior probability P(w) that a text is watermarked.
  215. """
  216. def __init__(self, watermarking_depth: int | None = None, base_rate: float = 0.5, **kwargs):
  217. self.watermarking_depth = watermarking_depth
  218. self.base_rate = base_rate
  219. # These can be set later to store information about this detector.
  220. self.model_name = None
  221. self.watermarking_config = None
  222. super().__init__(**kwargs)
  223. def set_detector_information(self, model_name, watermarking_config):
  224. self.model_name = model_name
  225. self.watermarking_config = watermarking_config
  226. @dataclass
  227. class BayesianWatermarkDetectorModelOutput(ModelOutput):
  228. """
  229. Base class for outputs of models predicting if the text is watermarked.
  230. Args:
  231. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  232. Language modeling loss.
  233. posterior_probabilities (`torch.FloatTensor` of shape `(1,)`):
  234. Multiple choice classification loss.
  235. """
  236. loss: torch.FloatTensor | None = None
  237. posterior_probabilities: torch.FloatTensor | None = None
  238. class BayesianDetectorWatermarkedLikelihood(nn.Module):
  239. """Watermarked likelihood model for binary-valued g-values.
  240. This takes in g-values and returns p(g_values|watermarked).
  241. """
  242. def __init__(self, watermarking_depth: int):
  243. """Initializes the model parameters."""
  244. super().__init__()
  245. self.watermarking_depth = watermarking_depth
  246. self.beta = torch.nn.Parameter(-2.5 + 0.001 * torch.randn(1, 1, watermarking_depth))
  247. self.delta = torch.nn.Parameter(0.001 * torch.randn(1, 1, self.watermarking_depth, watermarking_depth))
  248. def _compute_latents(self, g_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  249. """Computes the unique token probability distribution given g-values.
  250. Args:
  251. g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth)`):
  252. PRF values.
  253. Returns:
  254. p_one_unique_token and p_two_unique_tokens, both of shape
  255. [batch_size, seq_len, watermarking_depth]. p_one_unique_token[i,t,l]
  256. gives the probability of there being one unique token in a tournament
  257. match on layer l, on timestep t, for batch item i.
  258. p_one_unique_token[i,t,l] + p_two_unique_token[i,t,l] = 1.
  259. """
  260. # Tile g-values to produce feature vectors for predicting the latents
  261. # for each layer in the tournament; our model for the latents psi is a
  262. # logistic regression model psi = sigmoid(delta * x + beta).
  263. # [batch_size, seq_len, watermarking_depth, watermarking_depth]
  264. x = torch.repeat_interleave(torch.unsqueeze(g_values, dim=-2), self.watermarking_depth, axis=-2)
  265. # mask all elements above -1 diagonal for autoregressive factorization
  266. x = torch.tril(x, diagonal=-1)
  267. # [batch_size, seq_len, watermarking_depth]
  268. # (i, j, k, l) x (i, j, k, l) -> (i, j, k) einsum equivalent
  269. logits = (self.delta[..., None, :] @ x.type(self.delta.dtype)[..., None]).squeeze() + self.beta
  270. p_two_unique_tokens = torch.sigmoid(logits)
  271. p_one_unique_token = 1 - p_two_unique_tokens
  272. return p_one_unique_token, p_two_unique_tokens
  273. def forward(self, g_values: torch.Tensor) -> torch.Tensor:
  274. """Computes the likelihoods P(g_values|watermarked).
  275. Args:
  276. g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth)`):
  277. g-values (values 0 or 1)
  278. Returns:
  279. p(g_values|watermarked) of shape [batch_size, seq_len, watermarking_depth].
  280. """
  281. p_one_unique_token, p_two_unique_tokens = self._compute_latents(g_values)
  282. # P(g_tl | watermarked) is equal to
  283. # 0.5 * [ (g_tl+0.5) * p_two_unique_tokens + p_one_unique_token].
  284. return 0.5 * ((g_values + 0.5) * p_two_unique_tokens + p_one_unique_token)
  285. class BayesianDetectorModel(PreTrainedModel):
  286. r"""
  287. Bayesian classifier for watermark detection.
  288. This detector uses Bayes' rule to compute a watermarking score, which is the sigmoid of the log of ratio of the
  289. posterior probabilities P(watermarked|g_values) and P(unwatermarked|g_values). Please see the section on
  290. BayesianScore in the paper for further details.
  291. Paper URL: https://www.nature.com/articles/s41586-024-08025-4
  292. Note that this detector only works with non-distortionary Tournament-based watermarking using the Bernoulli(0.5)
  293. g-value distribution.
  294. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  295. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  296. etc.)
  297. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  298. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  299. and behavior.
  300. Parameters:
  301. config ([`BayesianDetectorConfig`]): Model configuration class with all the parameters of the model.
  302. Initializing with a config file does not load the weights associated with the model, only the
  303. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  304. """
  305. config: BayesianDetectorConfig
  306. base_model_prefix = "model"
  307. def __init__(self, config):
  308. super().__init__(config)
  309. self.watermarking_depth = config.watermarking_depth
  310. self.base_rate = config.base_rate
  311. self.likelihood_model_watermarked = BayesianDetectorWatermarkedLikelihood(
  312. watermarking_depth=self.watermarking_depth
  313. )
  314. self.prior = torch.nn.Parameter(torch.tensor([self.base_rate]))
  315. @torch.no_grad()
  316. def _init_weights(self, module):
  317. """Initialize the weights."""
  318. if isinstance(module, nn.Parameter):
  319. init.normal_(module.weight, mean=0.0, std=0.02)
  320. def _compute_posterior(
  321. self,
  322. likelihoods_watermarked: torch.Tensor,
  323. likelihoods_unwatermarked: torch.Tensor,
  324. mask: torch.Tensor,
  325. prior: float,
  326. ) -> torch.Tensor:
  327. """
  328. Compute posterior P(w|g) given likelihoods, mask and prior.
  329. Args:
  330. likelihoods_watermarked (`torch.Tensor` of shape `(batch, length, depth)`):
  331. Likelihoods P(g_values|watermarked) of g-values under watermarked model.
  332. likelihoods_unwatermarked (`torch.Tensor` of shape `(batch, length, depth)`):
  333. Likelihoods P(g_values|unwatermarked) of g-values under unwatermarked model.
  334. mask (`torch.Tensor` of shape `(batch, length)`):
  335. A binary array indicating which g-values should be used. g-values with mask value 0 are discarded.
  336. prior (`float`):
  337. the prior probability P(w) that the text is watermarked.
  338. Returns:
  339. Posterior probability P(watermarked|g_values), shape [batch].
  340. """
  341. mask = torch.unsqueeze(mask, dim=-1)
  342. prior = torch.clamp(prior, min=1e-5, max=1 - 1e-5)
  343. log_likelihoods_watermarked = torch.log(torch.clamp(likelihoods_watermarked, min=1e-30, max=float("inf")))
  344. log_likelihoods_unwatermarked = torch.log(torch.clamp(likelihoods_unwatermarked, min=1e-30, max=float("inf")))
  345. log_odds = log_likelihoods_watermarked - log_likelihoods_unwatermarked
  346. # Sum relative surprisals (log odds) across all token positions and layers.
  347. relative_surprisal_likelihood = torch.einsum("i...->i", log_odds * mask)
  348. # Compute the relative surprisal prior
  349. relative_surprisal_prior = torch.log(prior) - torch.log(1 - prior)
  350. # Combine prior and likelihood.
  351. # [batch_size]
  352. relative_surprisal = relative_surprisal_prior + relative_surprisal_likelihood
  353. # Compute the posterior probability P(w|g) = sigmoid(relative_surprisal).
  354. return torch.sigmoid(relative_surprisal)
  355. def forward(
  356. self,
  357. g_values: torch.Tensor,
  358. mask: torch.Tensor,
  359. labels: torch.Tensor | None = None,
  360. loss_batch_weight=1,
  361. return_dict=False,
  362. ) -> BayesianWatermarkDetectorModelOutput:
  363. """
  364. Computes the watermarked posterior P(watermarked|g_values).
  365. Args:
  366. g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth, ...)`):
  367. g-values (with values 0 or 1)
  368. mask:
  369. A binary array shape [batch_size, seq_len] indicating which g-values should be used. g-values with mask
  370. value 0 are discarded.
  371. Returns:
  372. p(watermarked | g_values), of shape [batch_size].
  373. """
  374. likelihoods_watermarked = self.likelihood_model_watermarked(g_values)
  375. likelihoods_unwatermarked = 0.5 * torch.ones_like(g_values)
  376. out = self._compute_posterior(
  377. likelihoods_watermarked=likelihoods_watermarked,
  378. likelihoods_unwatermarked=likelihoods_unwatermarked,
  379. mask=mask,
  380. prior=self.prior,
  381. )
  382. loss = None
  383. if labels is not None:
  384. loss_fct = BCELoss()
  385. loss_unwweight = torch.sum(self.likelihood_model_watermarked.delta**2)
  386. loss_weight = loss_unwweight * loss_batch_weight
  387. loss = loss_fct(torch.clamp(out, 1e-5, 1 - 1e-5), labels) + loss_weight
  388. if not return_dict:
  389. return (out,) if loss is None else (out, loss)
  390. return BayesianWatermarkDetectorModelOutput(loss=loss, posterior_probabilities=out)
  391. class SynthIDTextWatermarkDetector:
  392. r"""
  393. SynthID text watermark detector class.
  394. This class has to be initialized with the trained bayesian detector module check script
  395. in examples/synthid_text/detector_training.py for example in training/saving/loading this
  396. detector module. The folder also showcases example use case of this detector.
  397. Parameters:
  398. detector_module ([`BayesianDetectorModel`]):
  399. Bayesian detector module object initialized with parameters.
  400. Check https://github.com/huggingface/transformers-research-projects/tree/main/synthid_text for usage.
  401. logits_processor (`SynthIDTextWatermarkLogitsProcessor`):
  402. The logits processor used for watermarking.
  403. tokenizer (`Any`):
  404. The tokenizer used for the model.
  405. Examples:
  406. ```python
  407. >>> from transformers import (
  408. ... AutoTokenizer, BayesianDetectorModel, SynthIDTextWatermarkLogitsProcessor, SynthIDTextWatermarkDetector
  409. ... )
  410. >>> # Load the detector. See https://github.com/huggingface/transformers-research-projects/tree/main/synthid_text for training a detector.
  411. >>> detector_model = BayesianDetectorModel.from_pretrained("joaogante/dummy_synthid_detector")
  412. >>> logits_processor = SynthIDTextWatermarkLogitsProcessor(
  413. ... **detector_model.config.watermarking_config, device="cpu"
  414. ... )
  415. >>> tokenizer = AutoTokenizer.from_pretrained(detector_model.config.model_name)
  416. >>> detector = SynthIDTextWatermarkDetector(detector_model, logits_processor, tokenizer)
  417. >>> # Test whether a certain string is watermarked
  418. >>> test_input = tokenizer(["This is a test input"], return_tensors="pt")
  419. >>> is_watermarked = detector(test_input.input_ids)
  420. ```
  421. """
  422. def __init__(
  423. self,
  424. detector_module: BayesianDetectorModel,
  425. logits_processor: SynthIDTextWatermarkLogitsProcessor,
  426. tokenizer: Any,
  427. ):
  428. self.detector_module = detector_module
  429. self.logits_processor = logits_processor
  430. self.tokenizer = tokenizer
  431. def __call__(self, tokenized_outputs: torch.Tensor):
  432. # eos mask is computed, skip first ngram_len - 1 tokens
  433. # eos_mask will be of shape [batch_size, output_len]
  434. eos_token_mask = self.logits_processor.compute_eos_token_mask(
  435. input_ids=tokenized_outputs,
  436. eos_token_id=self.tokenizer.eos_token_id,
  437. )[:, self.logits_processor.ngram_len - 1 :]
  438. # context repetition mask is computed
  439. context_repetition_mask = self.logits_processor.compute_context_repetition_mask(
  440. input_ids=tokenized_outputs,
  441. )
  442. # context repetition mask shape [batch_size, output_len - (ngram_len - 1)]
  443. combined_mask = context_repetition_mask * eos_token_mask
  444. g_values = self.logits_processor.compute_g_values(
  445. input_ids=tokenized_outputs,
  446. )
  447. # g values shape [batch_size, output_len - (ngram_len - 1), depth]
  448. return self.detector_module(g_values, combined_mask)