logits_process.py 147 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185318631873188318931903191319231933194319531963197319831993200320132023203320432053206320732083209321032113212321332143215321632173218
  1. # Copyright 2024 The HuggingFace Inc. team and Google DeepMind.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. import math
  16. from collections.abc import Callable, Iterable
  17. from typing import TYPE_CHECKING, Any, cast
  18. import numpy as np
  19. import torch
  20. from .._typing import WhisperGenerationConfigLike
  21. from ..utils import add_start_docstrings
  22. from ..utils.logging import get_logger
  23. # TODO (joao): We shouldn't need this, but there would be a circular import
  24. if TYPE_CHECKING:
  25. from ..generation.configuration_utils import GenerationConfig
  26. logger = get_logger(__name__)
  # Shared `Args`/`Return` documentation that `add_start_docstrings` attaches to each processor's `__call__`.
  LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
      Args:
          input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
              Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
          scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
              Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
              search or log softmax for each vocabulary token when using beam search

      Return:
          `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.

  """
  class LogitsProcessor:
      """Abstract base class for all logit processors that can be applied during generation."""

      @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
      def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
          # The base class carries no processing logic; subclasses must override `__call__`.
          raise NotImplementedError(
              f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
          )
  class LogitsProcessorList(list):
      """
      This class can be used to create a list of [`LogitsProcessor`] to subsequently process a `scores` input tensor.
      This class inherits from list and adds a specific *__call__* method to apply each [`LogitsProcessor`] to the
      inputs.
      """

      def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
          r"""
          Args:
              input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                  Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
              scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
                  Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
                  beam search or log softmax for each vocabulary token when using beam search
              kwargs (`dict[str, Any]`, *optional*):
                  Additional kwargs that are specific to a logits processor.

          Return:
              `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
                  The processed prediction scores.

          Raises:
              ValueError: If a processor's `__call__` declares parameters beyond `(input_ids, scores)` and one of
                  them is missing from `kwargs`.
          """
          for processor in self:
              # Signature of the bound `__call__`, i.e. without `self`.
              function_args = inspect.signature(processor.__call__).parameters
              if len(function_args) > 2:
                  # Every parameter after (input_ids, scores) has to be provided by name through `kwargs`.
                  if not all(arg in kwargs for arg in list(function_args)[2:]):
                      raise ValueError(
                          f"Make sure that all the required parameters: {list(function_args)} for "
                          f"{processor.__class__} are passed to the logits processor."
                      )
                  scores = processor(input_ids, scores, **kwargs)
              else:
                  scores = processor(input_ids, scores)

          return scores

      def set_continuous_batching_context(self, logits_indices: torch.Tensor, cu_seq_lens_q: torch.Tensor) -> None:
          """Forwards the continuous batching metadata to all logit processors that need it."""
          for processor in self:
              # Only processors that expose the hook receive the metadata; the others are left untouched.
              if hasattr(processor, "set_continuous_batching_context"):
                  processor.set_continuous_batching_context(logits_indices, cu_seq_lens_q)
  class MinLengthLogitsProcessor(LogitsProcessor):
      r"""
      [`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0. Note that, for decoder-only models
      like most LLMs, the length includes the prompt.

      Args:
          min_length (`int`):
              The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
          eos_token_id (`Union[int, list[int], torch.Tensor]`):
              The id(s) of the *end-of-sequence* token.
          device (`str`, *optional*, defaults to `"cpu"`):
              The device to allocate the tensors.

      Examples:

      ```python
      >>> from transformers import AutoModelForCausalLM, AutoTokenizer

      >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
      >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")

      >>> inputs = tokenizer("A number:", return_tensors="pt")
      >>> gen_out = model.generate(**inputs)
      >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
      A number: one

      >>> # setting `min_length` to a value smaller than the uncontrolled output length has no impact
      >>> gen_out = model.generate(**inputs, min_length=3)
      >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
      A number: one

      >>> # setting a larger `min_length` will force the model to generate beyond its natural ending point, which is not
      >>> # necessarily incorrect
      >>> gen_out = model.generate(**inputs, min_length=10)
      >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
      A number: one thousand, nine hundred and ninety-four
      ```
      """

      def __init__(self, min_length: int, eos_token_id: int | list[int] | torch.Tensor, device: str = "cpu"):
          if not isinstance(min_length, int) or min_length < 0:
              raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")

          # Normalize the EOS id(s) to a tensor so `__call__` can test vocabulary membership with `torch.isin`.
          if not isinstance(eos_token_id, torch.Tensor):
              if isinstance(eos_token_id, int):
                  eos_token_id = [eos_token_id]
              eos_token_id = torch.tensor(eos_token_id, device=device)

          self.min_length = min_length
          self.eos_token_id = eos_token_id

      @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
      def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
          # Once the sequence is long enough, nothing is masked; a copy of the scores is returned.
          if input_ids.shape[-1] >= self.min_length:
              return scores.clone()

          # Boolean mask over the vocabulary axis that is True at every EOS id.
          vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
          eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
          # `torch.where` builds a new tensor, so the input `scores` is not modified.
          scores_processed = torch.where(eos_token_mask, -math.inf, scores)
          return scores_processed
  class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
      r"""
      [`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0.
      Contrarily to [`MinLengthLogitsProcessor`], this processor ignores the prompt.

      Args:
          prompt_length_to_skip (`int`):
              The input tokens length. Not a valid argument when used with `generate` as it will automatically assign the
              input length.
          min_new_tokens (`int`):
              The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
          eos_token_id (`Union[int, list[int], torch.Tensor]`):
              The id(s) of the *end-of-sequence* token.
          device (`str`, *optional*, defaults to `"cpu"`):
              The device to allocate the tensors.

      Examples:

      ```python
      >>> from transformers import AutoModelForCausalLM, AutoTokenizer

      >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
      >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")

      >>> inputs = tokenizer(["A number:"], return_tensors="pt")
      >>> gen_out = model.generate(**inputs)
      >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
      A number: one

      >>> # setting `min_new_tokens` will force the model to generate beyond its natural ending point, which is not
      >>> # necessarily incorrect
      >>> gen_out = model.generate(**inputs, min_new_tokens=2)
      >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
      A number: one thousand
      ```
      """

      def __init__(
          self,
          prompt_length_to_skip: int,
          min_new_tokens: int,
          eos_token_id: int | list[int] | torch.Tensor,
          device: str = "cpu",
      ):
          for arg_name, arg_value in [
              ("prompt_length_to_skip", prompt_length_to_skip),
              ("min_new_tokens", min_new_tokens),
          ]:
              # Zero is accepted, so the message says "non-negative" to match the check.
              if not isinstance(arg_value, int) or arg_value < 0:
                  raise ValueError(f"`{arg_name}` has to be a non-negative integer, but is {arg_value}")

          # Normalize the EOS id(s) to a tensor so `__call__` can test vocabulary membership with `torch.isin`.
          if not isinstance(eos_token_id, torch.Tensor):
              if isinstance(eos_token_id, int):
                  eos_token_id = [eos_token_id]
              eos_token_id = torch.tensor(eos_token_id, device=device)

          self.prompt_length_to_skip = prompt_length_to_skip
          self.min_new_tokens = min_new_tokens
          self.eos_token_id = eos_token_id

      @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
      def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
          # Only tokens after the configured prompt length count towards the minimum.
          new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
          if new_tokens_length >= self.min_new_tokens:
              # Enough new tokens already: nothing is masked, a copy of the scores is returned.
              return scores.clone()

          # Boolean mask over the vocabulary axis that is True at every EOS id.
          vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
          eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
          # `torch.where` builds a new tensor, so the input `scores` is not modified.
          scores_processed = torch.where(eos_token_mask, -math.inf, scores)
          return scores_processed
  class TemperatureLogitsWarper(LogitsProcessor):
      r"""
      [`LogitsProcessor`] for temperature (exponential scaling output probability distribution), which effectively means
      that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and
      [`TopKLogitsWarper`].

      <Tip>

      Make sure that `do_sample=True` is included in the `generate` arguments otherwise the temperature value won't have
      any effect.

      </Tip>

      Args:
          temperature (`float`):
              Strictly positive float value used to modulate the logits distribution. A value smaller than `1` decreases
              randomness (and vice versa), with `0` being equivalent to shifting all probability mass to the most likely
              token.

      Examples:

      ```python
      >>> import torch
      >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

      >>> set_seed(0)  # for reproducibility

      >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
      >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
      >>> model.config.pad_token_id = model.config.eos_token_id
      >>> inputs = tokenizer(["Hugging Face Company is"], return_tensors="pt")

      >>> # With temperature=1.0, the default, we consistently get random outputs due to random sampling.
      >>> generate_kwargs = {"max_new_tokens": 10, "do_sample": True, "temperature": 1.0, "num_return_sequences": 2}
      >>> outputs = model.generate(**inputs, **generate_kwargs)
      >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
      ['Hugging Face Company is one of these companies that is going to take a',
      "Hugging Face Company is a brand created by Brian A. O'Neil"]

      >>> # However, with temperature close to 0, it approximates greedy decoding strategies (invariant)
      >>> generate_kwargs["temperature"] = 0.0001
      >>> outputs = model.generate(**inputs, **generate_kwargs)
      >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
      ['Hugging Face Company is a company that has been around for over 20 years',
      'Hugging Face Company is a company that has been around for over 20 years']
      ```
      """

      def __init__(self, temperature: float):
          # `not (temperature > 0)` also rejects NaN, which a plain `temperature <= 0` test would let through.
          if not isinstance(temperature, float) or not (temperature > 0):
              except_msg = (
                  f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
                  "scores will be invalid."
              )
              if isinstance(temperature, float) and temperature == 0.0:
                  except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`."
              raise ValueError(except_msg)

          self.temperature = temperature

      @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
      def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
          # Division produces a new tensor; `input_ids` is not used by this processor.
          scores_processed = scores / self.temperature
          return scores_processed
  class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
      r"""
      [`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at
      most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt
      by default.

      In the original [paper](https://huggingface.co/papers/1909.05858), the authors suggest the use of a penalty of around
      1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce
      repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
      repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.

      Args:
          penalty (`float`):
              The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
              tokens. Between 0.0 and 1.0 rewards previously generated tokens.
          prompt_ignore_length (`int`, *optional*):
              The original input ids sequence length, which if provided, will not be used in the penalty calculation.

      Examples:

      ```py
      >>> from transformers import AutoTokenizer, AutoModelForCausalLM, RepetitionPenaltyLogitsProcessor

      >>> # Initializing the model and tokenizer for it
      >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
      >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
      >>> inputs = tokenizer(["I'm not going to"], return_tensors="pt")

      >>> # This shows a normal generate without any specific parameters
      >>> summary_ids = model.generate(**inputs)
      >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
      I'm not going to be able to do that. I'm going to be able to do that

      >>> # This generates a penalty for repeated tokens
      >>> penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
      >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
      I'm not going to be able to do that. I'll just have to go out and play

      >>> # We can also exclude the input prompt by creating an instance of this class
      >>> # with a `prompt_ignore_length` and passing it as a custom logit processor
      >>> rep_pen_processor = RepetitionPenaltyLogitsProcessor(
      ...     penalty=1.1,
      ...     prompt_ignore_length=inputs["input_ids"].shape[-1]
      ... )
      >>> penalized_ids = model.generate(**inputs, logits_processor=[rep_pen_processor])
      >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
      I'm not going to be able to do that. I'm going to have to go through a lot of things, and
      ```
      """

      def __init__(self, penalty: float, prompt_ignore_length: int | None = None):
          if not isinstance(penalty, float) or not (penalty > 0):
              raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

          # Zero is accepted, so the message says "non-negative" to match the check.
          if prompt_ignore_length is not None and (
              not isinstance(prompt_ignore_length, int) or prompt_ignore_length < 0
          ):
              raise ValueError(
                  f"`prompt_ignore_length` has to be a non-negative integer, but is {prompt_ignore_length}"
              )

          self.penalty = penalty
          self.prompt_ignore_length = prompt_ignore_length
          # Continuous-batching metadata; stays `None` until `set_continuous_batching_context` is called.
          self.logits_indices = None
          self.cu_seq_lens_q = None

      def set_continuous_batching_context(self, logits_indices: torch.Tensor, cu_seq_lens_q: torch.Tensor):
          """Stores the continuous batching metadata read by the 3D branch of `__call__`."""
          self.logits_indices = logits_indices
          self.cu_seq_lens_q = cu_seq_lens_q

      @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
      def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
          # A `prompt_ignore_length` of `None` or 0 leaves `input_ids` untouched.
          if self.prompt_ignore_length:
              input_ids = input_ids[:, self.prompt_ignore_length :]

          if scores.dim() == 3:
              # NOTE: both 3D branches write the penalized values back into `scores` in place.
              if self.logits_indices is not None and self.cu_seq_lens_q is not None:
                  # NOTE(review): presumably `scores` is (1, total_tokens, vocab_size) here and `input_ids` lines up
                  # element-wise with `seq_indices` -- confirm against the continuous batching caller.
                  last_positions = self.logits_indices
                  last_scores = scores[0, last_positions, :]

                  # Prepare token mask
                  token_mask = torch.zeros_like(last_scores, dtype=torch.bool)
                  cu_seq_lens = self.cu_seq_lens_q
                  # Differences of consecutive cumulative lengths give one length per sequence.
                  lengths = cu_seq_lens[1:] - cu_seq_lens[:-1]
                  # Repeat each sequence index once per token of that sequence.
                  seq_indices = torch.repeat_interleave(torch.arange(len(lengths), device=input_ids.device), lengths)
                  token_mask[seq_indices, input_ids] = True

                  # Apply penalty
                  penalty_scores = torch.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty)
                  scores[0, last_positions, :] = torch.where(token_mask, penalty_scores, last_scores)
              else:
                  # Without batching metadata only the last position along dim 1 is penalized.
                  last_scores = scores[:, -1, :]
                  token_mask = torch.zeros_like(last_scores, dtype=torch.bool)
                  if input_ids.dim() == 1:
                      unique_tokens = torch.unique(input_ids)
                      token_mask.scatter_(1, unique_tokens.unsqueeze(0), True)
                  else:
                      token_mask.scatter_(1, input_ids, True)
                  # if last_scores < 0 then repetition penalty has to be multiplied to reduce the token probabilities
                  penalty_scores = torch.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty)
                  scores[:, -1, :] = torch.where(token_mask, penalty_scores, last_scores)
              return scores

          if input_ids.dim() == 1:
              input_ids = input_ids.unsqueeze(1)

          score = torch.gather(scores, 1, input_ids)
          # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
          score = torch.where(score < 0, score * self.penalty, score / self.penalty)
          # Out-of-place `scatter`: this path returns a new tensor and leaves `scores` unchanged.
          scores_processed = scores.scatter(1, input_ids, score)
          return scores_processed
  class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
      r"""
      [`LogitsProcessor`] that works similarly to [`RepetitionPenaltyLogitsProcessor`], but with an *inverse* penalty
      that is applied to the tokens present in the prompt. In other words, a penalty above 1.0 increases the odds of
      selecting tokens that were present in the prompt.

      It was designed to avoid hallucination in input-grounded tasks, like summarization. Although originally intended
      for encoder-decoder models, it can also be used with decoder-only models like LLMs.

      Args:
          penalty (`float`):
              The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 rewards prompt tokens. Between 0.0
              and 1.0 penalizes prompt tokens.
          encoder_input_ids (`torch.LongTensor`):
              The encoder_input_ids that should be repeated within the decoder ids.

      Examples:

      ```python
      >>> from transformers import AutoModelForCausalLM, AutoTokenizer

      >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
      >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")

      >>> inputs = tokenizer(["Alice and Bob. The third member's name was"], return_tensors="pt")
      >>> gen_out = model.generate(**inputs)
      >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
      Alice and Bob. The third member's name was not mentioned.

      >>> # With the `encoder_repetition_penalty` argument we can trigger this logits processor in `generate`, which can
      >>> # promote the use of prompt tokens ("Bob" in this example)
      >>> gen_out = model.generate(**inputs, encoder_repetition_penalty=1.2)
      >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
      Alice and Bob. The third member's name was Bob. The third member's name was Bob.
      ```
      """

      def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor):
          if not isinstance(penalty, float) or not (penalty > 0):
              raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

          # The reciprocal is stored, so `__call__` applies the inverse of the user-supplied penalty.
          self.penalty = 1 / penalty
          self.encoder_input_ids = encoder_input_ids

      @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
      def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
          # Only the stored `encoder_input_ids` select the affected scores; `input_ids` is not read.
          score = torch.gather(scores, 1, self.encoder_input_ids)

          # if score < 0 then hallucination penalty has to be multiplied to increase the token probabilities
          score = torch.where(score < 0, score * self.penalty, score / self.penalty)

          # Out-of-place `scatter`: returns a new tensor and leaves `scores` unchanged.
          scores_processed = scores.scatter(1, self.encoder_input_ids, score)
          return scores_processed
  372. class TopPLogitsWarper(LogitsProcessor):
  373. """
  374. [`LogitsProcessor`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
  375. Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
  376. Args:
  377. top_p (`float`):
  378. If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
  379. higher are kept for generation.
  380. filter_value (`float`, *optional*, defaults to -inf):
  381. All filtered values will be set to this float value.
  382. min_tokens_to_keep (`int`, *optional*, defaults to 1):
  383. Minimum number of tokens that cannot be filtered.
  384. Examples:
  385. ```python
  386. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
  387. >>> set_seed(1)
  388. >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
  389. >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
  390. >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")
  391. >>> # With sampling, the output is unexpected -- sometimes too unexpected.
  392. >>> outputs = model.generate(**inputs, do_sample=True)
  393. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  394. A sequence: 1, 2, 3 | < 4 (left-hand pointer) ;
  395. <BLANKLINE>
  396. <BLANKLINE>
  397. >>> # With `top_p` sampling, the output gets restricted to high-probability tokens.
  398. >>> # Pro tip: In practice, LLMs use `top_p` in the 0.9-0.95 range.
  399. >>> outputs = model.generate(**inputs, do_sample=True, top_p=0.1)
  400. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  401. A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
  402. ```
  403. """
  404. def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
  405. top_p = float(top_p)
  406. if top_p < 0 or top_p > 1.0:
  407. raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
  408. if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
  409. raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
  410. self.top_p = top_p
  411. self.filter_value = filter_value
  412. self.min_tokens_to_keep = min_tokens_to_keep
  413. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  414. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  415. sorted_logits, sorted_indices = torch.sort(scores, descending=False)
  416. cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
  417. # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
  418. sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
  419. # Keep at least min_tokens_to_keep
  420. sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
  421. # scatter sorted tensors to original indexing
  422. indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
  423. scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
  424. return scores_processed
  425. class TopKLogitsWarper(LogitsProcessor):
  426. r"""
  427. [`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements. Often used
  428. together with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
  429. Args:
  430. top_k (`int`):
  431. The number of highest probability vocabulary tokens to keep for top-k-filtering.
  432. filter_value (`float`, *optional*, defaults to -inf):
  433. All filtered values will be set to this float value.
  434. min_tokens_to_keep (`int`, *optional*, defaults to 1):
  435. Minimum number of tokens that cannot be filtered.
  436. Examples:
  437. ```python
  438. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
  439. >>> set_seed(1)
  440. >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
  441. >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
  442. >>> inputs = tokenizer("A sequence: A, B, C, D", return_tensors="pt")
  443. >>> # With sampling, the output is unexpected -- sometimes too unexpected.
  444. >>> outputs = model.generate(**inputs, do_sample=True)
  445. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  446. A sequence: A, B, C, D, E — S — O, P — R
  447. >>> # With `top_k` sampling, the output gets restricted the k most likely tokens.
  448. >>> # Pro tip: In practice, LLMs use `top_k` in the 5-50 range.
  449. >>> outputs = model.generate(**inputs, do_sample=True, top_k=2)
  450. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  451. A sequence: A, B, C, D, E, F, G, H, I
  452. ```
  453. """
  454. def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
  455. if not isinstance(top_k, int) or top_k <= 0:
  456. raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
  457. self.top_k = max(top_k, min_tokens_to_keep)
  458. self.filter_value = filter_value
  459. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  460. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  461. top_k = min(self.top_k, scores.size(-1)) # Safety check
  462. # Remove all tokens with a probability less than the last token of the top-k
  463. indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
  464. scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
  465. return scores_processed
  466. class TopHLogitsWarper(LogitsProcessor):
  467. """
  468. [`LogitsProcessor`] that implements Top-H sampling, a decoding method which adaptively selects a subset of
  469. high-probability tokens based on entropy and cumulative probability constraints.
  470. This method dynamically determines how many tokens to keep by analyzing the entropy difference of the selected
  471. distribution, thereby balancing exploration and exploitation. It ensures that generated text maintains both
  472. diversity and coherence.
  473. Reference:
  474. For details, see *Top-H Decoding: Adapting the Creativity and Coherence with Bounded Entropy in Text Generation*
  475. (NeurIPS 2025): https://arxiv.org/abs/2509.02510
  476. Args:
  477. top_h (`float`):
  478. Scaling coefficient for the entropy-based threshold (`tau`). Must be in the range `(0, 1]`.
  479. filter_value (`float`, *optional*, defaults to -inf):
  480. All filtered values will be set to this float value.
  481. Example:
  482. ```python
  483. >>> from transformers import AutoTokenizer, AutoModelForCausalLM
  484. >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
  485. >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
  486. >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")
  487. >>> outputs = model.generate(**inputs, do_sample=True, top_h=0.4)
  488. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  489. A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
  490. ```
  491. """
  492. def __init__(self, top_h: float, filter_value: float = -float("Inf")):
  493. super().__init__()
  494. # input checks
  495. if not (0 < top_h <= 1):
  496. raise ValueError("`top_h` must be in the range (0, 1].")
  497. # Maximum number of top tokens to consider before applying the entropy-based filter.
  498. # Acts as a cap for efficiency and numerical stability — increasing this allows more
  499. # tokens to be evaluated but may slow down generation. Default is 100.
  500. self.top_n = 100
  501. self.top_h = top_h
  502. self.filter_value = filter_value
  503. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  504. """
  505. Filters logits using Top-H sampling.
  506. Args:
  507. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  508. Input token IDs.
  509. scores (`torch.FloatTensor` of shape `(batch_size, vocab_size)`):
  510. Raw logits from the model.
  511. Return:
  512. `torch.FloatTensor` of shape `(batch_size, vocab_size)`:
  513. Processed logits where invalid tokens are masked with `-inf`.
  514. """
  515. batch_size, vocab_size = scores.shape
  516. device = scores.device
  517. keep_mask = torch.zeros((batch_size, vocab_size), dtype=torch.bool, device=device)
  518. top_n = min(self.top_n, vocab_size)
  519. # 1. Get top-k logits and indices for the whole batch
  520. top_logits, top_idx = torch.topk(scores, top_n, dim=-1, largest=True, sorted=True)
  521. # 2. Create a batch of categorical distributions
  522. dist = torch.distributions.Categorical(logits=top_logits)
  523. probs = dist.probs
  524. log_probs = torch.log(probs) # dist.log_prob(idx)
  525. # 3. Calculate the entropy-based threshold tau for the whole batch
  526. # We unsqueeze tau to enable broadcasting against the cumulative entropy tensor.
  527. tau = (dist.entropy() * self.top_h).unsqueeze(-1)
  528. # 4. Calculate cumulative entropy using torch.cumsum
  529. # The individual entropy terms (-p * log(p)) are calculated for all top_n tokens at once.
  530. entropy_terms = -probs * log_probs
  531. cumulative_entropy = torch.cumsum(entropy_terms, dim=-1)
  532. # 5. Determine which tokens to keep based on the stopping condition
  533. # Create a boolean mask for the top_n tokens.
  534. # Stopping rule: keep adding tokens in order of probability until the cumulative entropy
  535. # exceeds the threshold τ = H(p) * top_h. This ensures diversity (via entropy) while
  536. # guaranteeing at least the most probable token is always included.
  537. selection_mask = cumulative_entropy <= tau
  538. selection_mask[:, 0] = True
  539. # 6. Update the final keep_mask for the entire batch in one operation
  540. # The scatter_ operation efficiently updates the keep_mask at the indices
  541. # specified by top_idx with the boolean values from selection_mask.
  542. keep_mask.scatter_(dim=1, index=top_idx, src=selection_mask)
  543. # apply filtering
  544. scores_processed = scores.clone()
  545. scores_processed[~keep_mask] = self.filter_value
  546. return scores_processed
  547. class MinPLogitsWarper(LogitsProcessor):
  548. """
  549. [`LogitsProcessor`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the
  550. probability of the most likely token. As a result, the filter becomes more aggressive in the presence of
  551. high-probability tokens, which is a sign of a confident output that we shouldn't deviate from.
  552. Often used together with [`TemperatureLogitsWarper`]. Used as an alternative to [`TopPLogitsWarper`] and
  553. [`TopKLogitsWarper`].
  554. Created by @menhguin and @kalomaze (github handles). Code adapted from [this external PR](https://github.com/oobabooga/text-generation-webui/pull/4449/files)
  555. Args:
  556. min_p (`float`):
  557. Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
  558. value between 0 and 1. Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in
  559. the 0.99-0.8 range (use the opposite of normal `top_p` values).
  560. filter_value (`float`, *optional*, defaults to -inf):
  561. All filtered values will be set to this float value.
  562. min_tokens_to_keep (`int`, *optional*, defaults to 1):
  563. Minimum number of tokens that cannot be filtered.
  564. Examples:
  565. ```python
  566. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
  567. >>> set_seed(1)
  568. >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
  569. >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
  570. >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")
  571. >>> # With sampling, the output is unexpected -- sometimes too unexpected.
  572. >>> outputs = model.generate(**inputs, do_sample=True)
  573. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  574. A sequence: 1, 2, 3 | < 4 (left-hand pointer) ;
  575. <BLANKLINE>
  576. <BLANKLINE>
  577. >>> # With `min_p` sampling, the output gets restricted to high-probability tokens.
  578. >>> # Pro tip: In practice, LLMs use `min_p` in the 0.01-0.2 range.
  579. >>> outputs = model.generate(**inputs, do_sample=True, min_p=0.1)
  580. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  581. A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
  582. ```
  583. """
  584. def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
  585. if not (0 <= min_p <= 1.0):
  586. raise ValueError(f"`min_p` has to be a float in the [0, 1] interval, but is {min_p}")
  587. if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
  588. raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
  589. self.min_p = min_p
  590. self.filter_value = filter_value
  591. self.min_tokens_to_keep = min_tokens_to_keep
  592. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  593. # Convert logits to probabilities
  594. probs = torch.softmax(scores, dim=-1)
  595. # Get the probability of the top token for each sequence in the batch
  596. top_probs = probs.amax(dim=-1, keepdim=True)
  597. # Calculate the actual min_p threshold by scaling min_p with the top token's probability
  598. scaled_min_p = self.min_p * top_probs
  599. # Create a mask for tokens that have a probability less than the scaled min_p
  600. tokens_to_remove = probs < scaled_min_p
  601. # Keep at least min_tokens_to_keep tokens (clip k to vocab size if needed, avoids index out of range)
  602. k = min(self.min_tokens_to_keep, probs.shape[-1])
  603. sorted_indices = torch.topk(probs, k, dim=-1).indices
  604. tokens_to_remove.scatter_(-1, sorted_indices, False)
  605. scores_processed = scores.masked_fill(tokens_to_remove, self.filter_value)
  606. return scores_processed
  607. class TypicalLogitsWarper(LogitsProcessor):
  608. r"""
  609. [`LogitsProcessor`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens
  610. whose log probability is close to the entropy of the token probability distribution. This means that the most
  611. likely tokens may be discarded in the process.
  612. See [Typical Decoding for Natural Language Generation](https://huggingface.co/papers/2202.00666) for more information.
  613. Args:
  614. mass (`float`, *optional*, defaults to 0.9):
  615. Value of typical_p between 0 and 1 inclusive, defaults to 0.9.
  616. filter_value (`float`, *optional*, defaults to -inf):
  617. All filtered values will be set to this float value.
  618. min_tokens_to_keep (`int`, *optional*, defaults to 1):
  619. Minimum number of tokens that cannot be filtered.
  620. Examples:
  621. ```python
  622. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
  623. >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
  624. >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
  625. >>> inputs = tokenizer("1, 2, 3", return_tensors="pt")
  626. >>> # We can see that greedy decoding produces a sequence of numbers
  627. >>> outputs = model.generate(**inputs)
  628. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  629. 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
  630. >>> # For this particular seed, we can see that sampling produces nearly the same low-information (= low entropy)
  631. >>> # sequence
  632. >>> set_seed(18)
  633. >>> outputs = model.generate(**inputs, do_sample=True)
  634. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  635. 1, 2, 3, 4, 5, 6, 7, 8, 9 and 10
  636. >>> # With `typical_p` set, the most obvious sequence is no longer produced, which may be good for your problem
  637. >>> set_seed(18)
  638. >>> outputs = model.generate(
  639. ... **inputs, do_sample=True, typical_p=0.1, return_dict_in_generate=True, output_scores=True
  640. ... )
  641. >>> print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0])
  642. 1, 2, 3 and 5
  643. >>> # We can see that the token corresponding to "4" (token 934) in the second position, the most likely token
  644. >>> # as seen with greedy decoding, was entirely blocked out
  645. >>> print(outputs.scores[1][0, 934])
  646. tensor(-inf)
  647. ```
  648. """
  649. def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
  650. mass = float(mass)
  651. if not (mass > 0 and mass < 1):
  652. raise ValueError(f"`typical_p` has to be a float > 0 and < 1, but is {mass}")
  653. if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
  654. raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
  655. self.filter_value = filter_value
  656. self.mass = mass
  657. self.min_tokens_to_keep = min_tokens_to_keep
  658. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  659. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  660. # calculate entropy
  661. normalized = torch.nn.functional.log_softmax(scores, dim=-1)
  662. p = torch.exp(normalized)
  663. ent = -(normalized * p).nansum(-1, keepdim=True)
  664. # shift and sort
  665. shifted_scores = torch.abs((-normalized) - ent)
  666. sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
  667. sorted_logits = scores.gather(-1, sorted_indices)
  668. cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
  669. # Remove tokens with cumulative mass above the threshold
  670. last_ind = (cumulative_probs < self.mass).sum(dim=1)
  671. last_ind.clamp_(max=sorted_scores.shape[-1] - 1)
  672. sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
  673. sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
  674. indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
  675. scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
  676. return scores_processed
  677. class EpsilonLogitsWarper(LogitsProcessor):
  678. r"""
  679. [`LogitsProcessor`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
  680. largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
  681. Desmoothing](https://huggingface.co/papers/2210.15191) for more information.
  682. Args:
  683. epsilon (`float`):
  684. If set to > 0, only the most tokens with probabilities `epsilon` or higher are kept for generation.
  685. filter_value (`float`, *optional*, defaults to -inf):
  686. All filtered values will be set to this float value.
  687. min_tokens_to_keep (`int`, *optional*, defaults to 1):
  688. Minimum number of tokens that cannot be filtered.
  689. Examples:
  690. ```python
  691. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
  692. >>> set_seed(1)
  693. >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
  694. >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
  695. >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")
  696. >>> # With sampling, the output is unexpected -- sometimes too unexpected.
  697. >>> outputs = model.generate(**inputs, do_sample=True)
  698. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  699. A sequence: 1, 2, 3 | < 4 (left-hand pointer) ;
  700. <BLANKLINE>
  701. <BLANKLINE>
  702. >>> # With epsilon sampling, the output gets restricted to high-probability tokens. Note that this is similar to
  703. >>> # Top P sampling, which restricts tokens based on their cumulative probability.
  704. >>> # Pro tip: The paper recommends using `epsilon_cutoff` values between 3e-4 and 9e-4
  705. >>> outputs = model.generate(**inputs, do_sample=True, epsilon_cutoff=0.1)
  706. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  707. A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
  708. ```
  709. """
  710. def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
  711. epsilon = float(epsilon)
  712. if epsilon <= 0 or epsilon >= 1:
  713. raise ValueError(f"`epsilon_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
  714. min_tokens_to_keep = int(min_tokens_to_keep)
  715. if min_tokens_to_keep < 1:
  716. raise ValueError(
  717. f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
  718. )
  719. self.epsilon = epsilon
  720. self.filter_value = filter_value
  721. self.min_tokens_to_keep = min_tokens_to_keep
  722. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  723. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  724. # Determine which indices to remove
  725. probabilities = scores.softmax(dim=-1)
  726. indices_to_remove = probabilities < self.epsilon
  727. # Keep the words with the 'min_tokens_to_keep'-highest probabilities
  728. top_k = min(self.min_tokens_to_keep, scores.size(-1)) # Safety check
  729. indices_to_remove = indices_to_remove & (scores < torch.topk(scores, top_k)[0][..., -1, None])
  730. scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
  731. return scores_processed
  732. class EtaLogitsWarper(LogitsProcessor):
  733. r"""
  734. [`LogitsProcessor`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
  735. cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of
  736. the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest
  737. min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long
  738. samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation
  739. Sampling as Language Model Desmoothing](https://huggingface.co/papers/2210.15191) for more information. Note: `do_sample`
  740. must be set to `True` for this `LogitsProcessor` to work.
  741. Args:
  742. epsilon (`float`):
  743. A float value in the range (0, 1). Hyperparameter used to calculate the dynamic cutoff value, `eta`. The
  744. suggested values from the paper ranges from 3e-4 to 4e-3 depending on the size of the model.
  745. filter_value (`float`, *optional*, defaults to -inf):
  746. All values that are found to be below the dynamic cutoff value, `eta`, are set to this float value. This
  747. parameter is useful when logits need to be modified for very low probability tokens that should be excluded
  748. from generation entirely.
  749. min_tokens_to_keep (`int`, *optional*, defaults to 1):
  750. Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities.
  751. For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation,
  752. even if all tokens have probabilities below the cutoff `eta`.
  753. device (`str`, *optional*, defaults to `"cpu"`):
  754. The device to allocate the tensors.
  755. Examples:
  756. ```python
  757. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
  758. >>> set_seed(1)
  759. >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
  760. >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
  761. >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")
  762. >>> # With sampling, the output is unexpected -- sometimes too unexpected.
  763. >>> outputs = model.generate(**inputs, do_sample=True)
  764. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  765. A sequence: 1, 2, 3 | < 4 (left-hand pointer) ;
  766. <BLANKLINE>
  767. <BLANKLINE>
  768. >>> # With eta sampling, the output gets restricted to high-probability tokens. You can see it as a dynamic form of
  769. >>> # epsilon sampling that adapts its cutoff probability based on the entropy (high entropy = lower cutoff).
  770. >>> # Pro tip: The paper recommends using `eta_cutoff` values between 3e-4 to 4e-3
  771. >>> outputs = model.generate(**inputs, do_sample=True, eta_cutoff=0.1)
  772. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  773. A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
  774. ```
  775. """
  776. def __init__(
  777. self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1, device: str = "cpu"
  778. ):
  779. epsilon = float(epsilon)
  780. if epsilon <= 0 or epsilon >= 1:
  781. raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
  782. min_tokens_to_keep = int(min_tokens_to_keep)
  783. if min_tokens_to_keep < 1:
  784. raise ValueError(
  785. f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
  786. )
  787. self.epsilon = torch.tensor(epsilon, device=device)
  788. self.filter_value = filter_value
  789. self.min_tokens_to_keep = min_tokens_to_keep
  790. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  791. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  792. probabilities = scores.softmax(dim=-1)
  793. entropy = torch.distributions.Categorical(logits=scores).entropy()
  794. eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None]
  795. indices_to_remove = probabilities < eta
  796. # Keep the words with the 'min_tokens_to_keep'-highest probabilities
  797. top_k = min(self.min_tokens_to_keep, scores.size(-1)) # Safety check
  798. indices_to_remove = indices_to_remove & (scores < torch.topk(scores, top_k)[0][..., -1, None])
  799. scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
  800. return scores_processed
  801. def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
  802. """
  803. Assume ngram_size=2 and prev_input_ids=tensor([[40, 2883, 2712, 4346]]). The output of generated ngrams look like
  804. this {(40,): [2883], (2883,): [2712], (2712,): [4346]}.
  805. Args:
  806. ngram_size (`int`):
  807. The number sequential tokens taken as a group which may only occur once before being banned.
  808. prev_input_ids (`torch.Tensor`):
  809. Generated token ids for the current hypothesis.
  810. num_hypos (`int`):
  811. The number of hypotheses for which n-grams need to be generated.
  812. Returns:
  813. generated_ngrams (`dict`):
  814. Dictionary of generated ngrams.
  815. """
  816. # Initialize an empty list of dictionaries, one for each hypothesis (index) in the range of num_hypos
  817. generated_ngrams = [{} for _ in range(num_hypos)]
  818. for idx in range(num_hypos):
  819. gen_tokens = prev_input_ids[idx].tolist()
  820. generated_ngram = generated_ngrams[idx]
  821. # Loop through each n-gram of size ngram_size in the list of tokens (gen_tokens)
  822. for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
  823. prev_ngram_tuple = tuple(ngram[:-1])
  824. generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
  825. return generated_ngrams
  826. def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
  827. """
  828. Determines the banned tokens for the current hypothesis based on previously generated n-grams.
  829. Args:
  830. banned_ngrams (`dict`):
  831. A dictionary containing previously generated n-grams for each hypothesis.
  832. prev_input_ids (`torch.Tensor`):
  833. Generated token ids for the current hypothesis.
  834. ngram_size (`int`):
  835. The number sequential tokens taken as a group which may only occur once before being banned.
  836. cur_len (`int`):
  837. The current length of the token sequences for which the n-grams are being checked.
  838. Returns:
  839. List of tokens that are banned.
  840. """
  841. # Before decoding the next token, prevent decoding of ngrams that have already appeared
  842. start_idx = cur_len + 1 - ngram_size
  843. ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
  844. return banned_ngrams.get(ngram_idx, [])
  845. def _calc_banned_ngram_tokens(
  846. ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
  847. ) -> list[Iterable[int]]:
  848. """Copied from fairseq for no_repeat_ngram in beam_search"""
  849. if cur_len + 1 < ngram_size:
  850. # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
  851. return [[] for _ in range(num_hypos)]
  852. generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
  853. banned_tokens = [
  854. _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
  855. for hypo_idx in range(num_hypos)
  856. ]
  857. return banned_tokens
  858. class NoRepeatNGramLogitsProcessor(LogitsProcessor):
  859. r"""
  860. N-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
  861. sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). In text generation,
  862. avoiding repetitions of word sequences provides a more diverse output. This [`LogitsProcessor`] enforces no
  863. repetition of n-grams by setting the scores of banned tokens to negative infinity which eliminates those tokens
  864. from consideration when further processing the scores. Note that, for decoder-only models like most LLMs, the
  865. prompt is also considered to obtain the n-grams.
  866. [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
  867. <Tip>
  868. Use n-gram penalties with care. For instance, penalizing 2-grams (bigrams) in an article about the city of New York
  869. might lead to undesirable outcomes where the city's name appears only once in the entire text.
  870. [Reference](https://huggingface.co/blog/how-to-generate)
  871. </Tip>
  872. Args:
  873. ngram_size (`int`):
  874. All ngrams of size `ngram_size` can only occur once.
  875. Examples:
  876. ```py
  877. >>> from transformers import AutoTokenizer, AutoModelForCausalLM
  878. >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
  879. >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
  880. >>> inputs = tokenizer(["Today I"], return_tensors="pt")
  881. >>> output = model.generate(**inputs)
  882. >>> print(tokenizer.decode(output[0], skip_special_tokens=True))
  883. Today I'm not sure if I'm going to be able to do it.
  884. >>> # Now let's add ngram size using `no_repeat_ngram_size`. This stops the repetitions ("I'm") in the output.
  885. >>> output = model.generate(**inputs, no_repeat_ngram_size=2)
  886. >>> print(tokenizer.decode(output[0], skip_special_tokens=True))
  887. Today I'm not sure if I can get a better understanding of the nature of this issue
  888. ```
  889. """
  890. def __init__(self, ngram_size: int):
  891. if not isinstance(ngram_size, int) or ngram_size <= 0:
  892. raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
  893. self.ngram_size = ngram_size
  894. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  895. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  896. num_batch_hypotheses = scores.shape[0]
  897. cur_len = input_ids.shape[-1]
  898. scores_processed = scores.clone()
  899. banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
  900. for i, banned_tokens in enumerate(banned_batch_tokens):
  901. scores_processed[i, banned_tokens] = -float("inf")
  902. return scores_processed
  903. class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
  904. r"""
  905. [`LogitsProcessor`] that works similarly to [`NoRepeatNGramLogitsProcessor`], but applied exclusively to prevent
  906. the repetition of n-grams present in the prompt.
  907. It was designed to promote chattiness in a language model, by preventing the generation of n-grams present in
  908. previous conversation rounds.
  909. Args:
  910. encoder_ngram_size (`int`):
  911. All ngrams of size `ngram_size` can only occur within the encoder input ids.
  912. encoder_input_ids (`int`):
  913. The encoder_input_ids that should not be repeated within the decoder ids.
  914. Examples:
  915. ```py
  916. >>> from transformers import AutoTokenizer, AutoModelForCausalLM
  917. >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
  918. >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
  919. >>> inputs = tokenizer("Alice: I love cats. What do you love?\nBob:", return_tensors="pt")
  920. >>> # With greedy decoding, we see Bob repeating Alice's opinion. If Bob was a chatbot, it would be a poor one.
  921. >>> outputs = model.generate(**inputs)
  922. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  923. Alice: I love cats. What do you love?
  924. Bob: I love cats. What do you
  925. >>> # With this logits processor, we can prevent Bob from repeating Alice's opinion.
  926. >>> outputs = model.generate(**inputs, encoder_no_repeat_ngram_size=2)
  927. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  928. Alice: I love cats. What do you love?
  929. Bob: My cats are very cute.
  930. ```
  931. """
  932. def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor):
  933. if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0:
  934. raise ValueError(
  935. f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}"
  936. )
  937. self.ngram_size = encoder_ngram_size
  938. if len(encoder_input_ids.shape) == 1:
  939. encoder_input_ids = encoder_input_ids.unsqueeze(0)
  940. self.batch_size = encoder_input_ids.shape[0]
  941. self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)
  942. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  943. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  944. # B x num_beams
  945. num_hypos = scores.shape[0]
  946. num_beams = num_hypos // self.batch_size
  947. cur_len = input_ids.shape[-1]
  948. scores_processed = scores.clone()
  949. banned_batch_tokens = [
  950. _get_generated_ngrams(
  951. self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len
  952. )
  953. for hypo_idx in range(num_hypos)
  954. ]
  955. for i, banned_tokens in enumerate(banned_batch_tokens):
  956. scores_processed[i, banned_tokens] = -float("inf")
  957. return scores_processed
  958. class SequenceBiasLogitsProcessor(LogitsProcessor):
  959. """
  960. [`LogitsProcessor`] that applies an additive bias on sequences. The bias is applied to the last token of a sequence
  961. when the next generated token can complete it. Consequently, to take the most of biasing sequences with more than
  962. one token, consider using beam methods (to gracefully work around partially completed sequences that have a
  963. negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier).
  964. <Tip>
  965. At a token-level, biasing a word is different from biasing a word with a space before it. If you want to bias
  966. "foo" mid-sentence, you'll likely want to add a prefix space and bias " foo" instead. Check the tokenizer section
  967. of our NLP course to find out why: https://huggingface.co/learn/nlp-course/chapter2/4?fw=pt
  968. </Tip>
  969. Args:
  970. sequence_bias (`list[list[Union[list[int], float]]]`):
  971. List of lists that maps a sequence of tokens to its bias term (e.g. `[[[10, 45], -2.0],
  972. [[64], -7.5]]`). Positive biases increase the odds of the
  973. sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
  974. will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be
  975. completed (in the token selection step after this processor is applied).
  976. Examples:
  977. ```python
  978. >>> from transformers import AutoTokenizer, AutoModelForCausalLM
  979. >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
  980. >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
  981. >>> inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")
  982. >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4, do_sample=False)
  983. >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
  984. The full name of Donald is Donald John Trump Sr.
  985. >>> def get_tokens(word):
  986. ... return tokenizer([word], add_special_tokens=False).input_ids[0]
  987. >>> # IMPORTANT: Remember our tip about adding spaces before words to bias them correctly.
  988. >>> sequence_bias = [[get_tokens("Trump"), -10.0],] # will fail to apply bias
  989. >>> biased_ids = model.generate(
  990. ... inputs["input_ids"], max_new_tokens=4, do_sample=False, sequence_bias=sequence_bias
  991. ... )
  992. >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
  993. The full name of Donald is Donald John Trump Sr.
  994. >>> sequence_bias = [[get_tokens(" Trump"), -10.0],] # will work
  995. >>> biased_ids = model.generate(
  996. ... inputs["input_ids"], max_new_tokens=4, do_sample=False, sequence_bias=sequence_bias
  997. ... )
  998. >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
  999. The full name of Donald is Donald John Harper. He
  1000. >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations. This technique
  1001. >>> # is also more effective when paired up with beam search.
  1002. >>> sequence_bias = [[get_tokens(" Donald Duck"), 10.0],]
  1003. >>> biased_ids = model.generate(
  1004. ... inputs["input_ids"], max_new_tokens=4, num_beams=4, do_sample=False, sequence_bias=sequence_bias
  1005. ... )
  1006. >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
  1007. The full name of Donald is Donald Duck. He is
  1008. ```
  1009. """
  1010. def __init__(self, sequence_bias: list[list[list[int] | float]]):
  1011. # After _convert_list_arguments_into_dict(), becomes dict[tuple[int, ...], float]
  1012. self.sequence_bias: Any = sequence_bias
  1013. self._validate_arguments()
  1014. self._convert_list_arguments_into_dict()
  1015. # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
  1016. # is inferred in the first usage, which inhibits initializing here)
  1017. self.length_1_bias = None
  1018. self.prepared_bias_variables = False
  1019. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  1020. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  1021. # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called.
  1022. if not self.prepared_bias_variables:
  1023. self._prepare_bias_variables(scores)
  1024. # 2 - prepares an empty bias to add
  1025. bias = torch.zeros_like(scores)
  1026. # 3 - include the bias from length = 1
  1027. bias += self.length_1_bias
  1028. # 4 - include the bias from length > 1, after determining which biased sequences may be completed.
  1029. for sequence_ids, sequence_bias in self.sequence_bias.items():
  1030. if len(sequence_ids) == 1: # the sequence is of length 1, already applied
  1031. continue
  1032. if len(sequence_ids) > input_ids.shape[1]: # the sequence is longer than the context, ignore
  1033. continue
  1034. prefix_length = len(sequence_ids) - 1
  1035. last_token = sequence_ids[-1]
  1036. matching_rows = torch.eq(
  1037. input_ids[:, -prefix_length:],
  1038. torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
  1039. ).prod(dim=1)
  1040. bias[:, last_token] += torch.where(
  1041. matching_rows.bool(),
  1042. torch.tensor(sequence_bias, device=input_ids.device),
  1043. torch.tensor(0.0, device=input_ids.device),
  1044. )
  1045. # 5 - apply the bias to the scores
  1046. scores_processed = scores + bias
  1047. return scores_processed
  1048. def _prepare_bias_variables(self, scores: torch.FloatTensor):
  1049. vocabulary_size = scores.shape[-1]
  1050. # Check biased tokens out of bounds
  1051. invalid_biases = []
  1052. for sequence_ids in self.sequence_bias:
  1053. for token_id in sequence_ids:
  1054. if token_id >= vocabulary_size:
  1055. invalid_biases.append(token_id)
  1056. if len(invalid_biases) > 0:
  1057. raise ValueError(
  1058. f"The model vocabulary size is {vocabulary_size}, but the following tokens were being biased: "
  1059. f"{invalid_biases}"
  1060. )
  1061. # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
  1062. # with simpler logic.
  1063. self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float, device=scores.device)
  1064. # Extract single-token sequences and their biases
  1065. single_token_ids = []
  1066. single_token_biases = []
  1067. for sequence_ids, bias in self.sequence_bias.items():
  1068. if len(sequence_ids) == 1:
  1069. single_token_ids.append(sequence_ids[0])
  1070. single_token_biases.append(bias)
  1071. if single_token_ids: # Only if we have any single-token sequences
  1072. self.length_1_bias[single_token_ids] = torch.tensor(single_token_biases, device=scores.device)
  1073. self.prepared_bias_variables = True
  1074. def _validate_arguments(self):
  1075. sequence_bias = self.sequence_bias
  1076. if not isinstance(sequence_bias, dict) and not isinstance(sequence_bias, list) or len(sequence_bias) == 0:
  1077. raise ValueError(
  1078. f"`sequence_bias` has to be a non-empty dictionary, or non-empty list of lists but is {sequence_bias}."
  1079. )
  1080. if isinstance(sequence_bias, dict) and any(
  1081. not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias
  1082. ):
  1083. raise ValueError(f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}.")
  1084. if isinstance(sequence_bias, dict) and any(
  1085. any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in sequence_ids)
  1086. or len(sequence_ids) == 0
  1087. for sequence_ids in sequence_bias
  1088. ):
  1089. raise ValueError(
  1090. f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is "
  1091. f"{sequence_bias}."
  1092. )
  1093. def all_token_bias_pairs_are_valid(sequence):
  1094. return (
  1095. isinstance(sequence[0], list)
  1096. and all(isinstance(token_id, (int, np.integer)) and token_id > 0 for token_id in sequence[0])
  1097. and isinstance(sequence[1], float)
  1098. )
  1099. if isinstance(sequence_bias, list) and any(
  1100. (not all_token_bias_pairs_are_valid(sequence)) or len(sequence) == 0 for sequence in sequence_bias
  1101. ):
  1102. raise ValueError(
  1103. f"Each element in `sequence_bias` has to be a non-empty list of lists of positive integers and float, but is "
  1104. f"{sequence_bias}."
  1105. )
  1106. if isinstance(sequence_bias, dict) and any(not isinstance(bias, float) for bias in sequence_bias.values()):
  1107. raise ValueError(f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}.")
  1108. def _convert_list_arguments_into_dict(self):
  1109. """BC: we used to accept `dict{tuple of tokens: float}` directly, now we expect a list"""
  1110. if isinstance(self.sequence_bias, list):
  1111. temp_sequence = self.sequence_bias
  1112. self.sequence_bias = {tuple(sublist[0]): sublist[1] for sublist in temp_sequence}
  1113. class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
  1114. """
  1115. [`LogitsProcessor`] that enforces that specified sequences will never be selected.
  1116. <Tip>
  1117. In order to get the token ids of the words that should not appear in the generated text, make sure to set
  1118. `add_prefix_space=True` when initializing the tokenizer, and use `tokenizer(bad_words,
  1119. add_special_tokens=False).input_ids`. The `add_prefix_space` argument is only supported for some slow tokenizers,
  1120. as fast tokenizers' prefixing behaviours come from `pre tokenizers`. Read more
  1121. [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
  1122. </Tip>
  1123. Args:
  1124. bad_words_ids (`list[list[int]]`):
  1125. List of list of token ids that are not allowed to be generated.
  1126. eos_token_id (`Union[int, list[int], torch.Tensor]`, *optional*):
  1127. The id(s) of the *end-of-sequence* token.
  1128. Examples:
  1129. ```python
  1130. >>> from transformers import AutoTokenizer, AutoModelForCausalLM
  1131. >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
  1132. >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
  1133. >>> inputs = tokenizer(["In a word, the cake is a"], return_tensors="pt")
  1134. >>> output_ids = model.generate(inputs["input_ids"], max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)
  1135. >>> print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
  1136. In a word, the cake is a bit of a mess.
  1137. >>> # Now let's take the bad words out. Please note that the tokenizer is initialized differently
  1138. >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("openai-community/gpt2", add_prefix_space=True)
  1139. >>> def get_tokens_as_list(word_list):
  1140. ... "Converts a sequence of words into a list of tokens"
  1141. ... tokens_list = []
  1142. ... for word in word_list:
  1143. ... tokenized_word = tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]
  1144. ... tokens_list.append(tokenized_word)
  1145. ... return tokens_list
  1146. >>> bad_words_ids = get_tokens_as_list(word_list=["mess"])
  1147. >>> output_ids = model.generate(
  1148. ... inputs["input_ids"], max_new_tokens=5, bad_words_ids=bad_words_ids, pad_token_id=tokenizer.eos_token_id
  1149. ... )
  1150. >>> print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
  1151. In a word, the cake is a bit of a surprise.
  1152. ```
  1153. """
  1154. def __init__(self, bad_words_ids: list[list[int]], eos_token_id: int | list[int] | torch.Tensor | None = None):
  1155. self.bad_word_ids = bad_words_ids
  1156. self._validate_arguments()
  1157. # Filter EOS token from bad_words_ids
  1158. if eos_token_id is not None:
  1159. if not isinstance(eos_token_id, torch.Tensor):
  1160. if isinstance(eos_token_id, int):
  1161. eos_token_id = [eos_token_id]
  1162. eos_token_id = torch.tensor(eos_token_id)
  1163. eos_token_id_list = eos_token_id.tolist() # convert to python list before
  1164. bad_words_ids = list(
  1165. filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id_list), bad_words_ids)
  1166. )
  1167. # Forbidding a sequence is equivalent to setting its bias to -inf
  1168. sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
  1169. super().__init__(sequence_bias=sequence_bias)
  1170. def _validate_arguments(self):
  1171. bad_words_ids = self.bad_word_ids
  1172. if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
  1173. raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
  1174. if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
  1175. raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
  1176. if any(
  1177. any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
  1178. for bad_word_ids in bad_words_ids
  1179. ):
  1180. raise ValueError(
  1181. f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
  1182. )
  1183. class PrefixConstrainedLogitsProcessor(LogitsProcessor):
  1184. r"""
  1185. [`LogitsProcessor`] that enforces constrained generation and is useful for prefix-conditioned constrained
  1186. generation. See [Autoregressive Entity Retrieval](https://huggingface.co/papers/2010.00904) for more information.
  1187. Args:
  1188. prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`):
  1189. This function constraints the beam search to allowed tokens only at each step. This function takes 2
  1190. arguments `inputs_ids` and the batch ID `batch_id`. It has to return a list with the allowed tokens for the
  1191. next generation step conditioned on the previously generated tokens `inputs_ids` and the batch ID
  1192. `batch_id`.
  1193. Examples:
  1194. ```py
  1195. >>> from transformers import AutoTokenizer, AutoModelForCausalLM
  1196. >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
  1197. >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
  1198. >>> inputs = tokenizer("Alice and Bob", return_tensors="pt")
  1199. >>> # By default, it continues generating according to the model's logits
  1200. >>> outputs = model.generate(**inputs, max_new_tokens=5)
  1201. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  1202. Alice and Bob are friends
  1203. >>> # We can constrain it with `prefix_allowed_tokens_fn` to force a certain behavior based on a prefix.
  1204. >>> # For instance, we can force an entire entity to be generated when its beginning is detected.
  1205. >>> entity = tokenizer(" Bob Marley", return_tensors="pt").input_ids[0] # 3 tokens
  1206. >>> def prefix_allowed_tokens_fn(batch_id, input_ids):
  1207. ... '''
  1208. ... Attempts to generate 'Bob Marley' when 'Bob' is detected.
  1209. ... In this case, `batch_id` is not used, but you can set rules for each batch member.
  1210. ... '''
  1211. ... if input_ids[-1] == entity[0]:
  1212. ... return [entity[1].item()]
  1213. ... elif input_ids[-2] == entity[0] and input_ids[-1] == entity[1]:
  1214. ... return [entity[2].item()]
  1215. ... return list(range(tokenizer.vocab_size)) # If no match, allow all tokens
  1216. >>> outputs = model.generate(**inputs, max_new_tokens=5, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)
  1217. >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
  1218. Alice and Bob Marley
  1219. ```
  1220. """
  1221. def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]], num_beams: int):
  1222. self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
  1223. self._num_beams = num_beams
  1224. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  1225. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  1226. mask = torch.full_like(scores, -math.inf)
  1227. batch_size = input_ids.shape[0] // self._num_beams
  1228. for batch_id in range(batch_size):
  1229. for beam_id in range(self._num_beams):
  1230. sent = input_ids[batch_id * self._num_beams + beam_id]
  1231. prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
  1232. if len(prefix_allowed_tokens) == 0:
  1233. raise ValueError(
  1234. f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
  1235. f"This means that the constraint is unsatisfiable. Please check your implementation"
  1236. f"of `prefix_allowed_tokens_fn` "
  1237. )
  1238. mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0
  1239. scores_processed = scores + mask
  1240. return scores_processed
  1241. class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
  1242. r"""
  1243. [`LogitsProcessor`] that enforces the specified token as the first generated token. Used with encoder-decoder
  1244. models.
  1245. Args:
  1246. bos_token_id (`int`):
  1247. The id of the token to force as the first generated token.
  1248. Examples:
  1249. ```python
  1250. >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  1251. >>> model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
  1252. >>> tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
  1253. >>> inputs = tokenizer("Translate from English to German: I love cats.", return_tensors="pt")
  1254. >>> # By default, it continues generating according to the model's logits
  1255. >>> outputs = model.generate(**inputs, max_new_tokens=10)
  1256. >>> print(tokenizer.batch_decode(outputs)[0])
  1257. <pad> Ich liebe Kitty.</s>
  1258. >>> # We can use `forced_bos_token_id` to force the start of generation with an encoder-decoder model
  1259. >>> # (including forcing it to end straight away with an EOS token)
  1260. >>> outputs = model.generate(**inputs, max_new_tokens=10, forced_bos_token_id=tokenizer.eos_token_id)
  1261. >>> print(tokenizer.batch_decode(outputs)[0])
  1262. <pad></s>
  1263. ```
  1264. """
  1265. def __init__(self, bos_token_id: int):
  1266. self.bos_token_id = bos_token_id
  1267. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  1268. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  1269. cur_len = input_ids.shape[-1]
  1270. scores_processed = scores
  1271. if cur_len == 1:
  1272. scores_processed = torch.full_like(scores, -math.inf)
  1273. scores_processed[:, self.bos_token_id] = 0
  1274. return scores_processed
  1275. class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
  1276. r"""
  1277. [`LogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.
  1278. Args:
  1279. max_length (`int`):
  1280. The maximum length of the sequence to be generated.
  1281. eos_token_id (`Union[int, list[int], torch.Tensor]`):
  1282. The id(s) of the *end-of-sequence* token.
  1283. device (`str`, *optional*, defaults to `"cpu"`):
  1284. The device to allocate the tensors.
  1285. Examples:
  1286. ```python
  1287. >>> from transformers import AutoTokenizer, AutoModelForCausalLM
  1288. >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
  1289. >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
  1290. >>> inputs = tokenizer("A sequence: 1, 2, 3", return_tensors="pt")
  1291. >>> # By default, it continues generating according to the model's logits
  1292. >>> outputs = model.generate(**inputs, max_new_tokens=10)
  1293. >>> print(tokenizer.batch_decode(outputs)[0])
  1294. A sequence: 1, 2, 3, 4, 5, 6, 7, 8
  1295. >>> # `forced_eos_token_id` ensures the generation ends with a EOS token
  1296. >>> outputs = model.generate(**inputs, max_new_tokens=10, forced_eos_token_id=tokenizer.eos_token_id)
  1297. >>> print(tokenizer.batch_decode(outputs)[0])
  1298. A sequence: 1, 2, 3, 4, 5, 6, 7,<|endoftext|>
  1299. ```
  1300. """
  1301. def __init__(self, max_length: int, eos_token_id: int | list[int] | torch.Tensor, device: str = "cpu"):
  1302. self.max_length = max_length
  1303. if not isinstance(eos_token_id, torch.Tensor):
  1304. if isinstance(eos_token_id, int):
  1305. eos_token_id = [eos_token_id]
  1306. eos_token_id = torch.tensor(eos_token_id, device=device)
  1307. self.eos_token_id = eos_token_id
  1308. if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
  1309. raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
  1310. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  1311. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  1312. cur_len = input_ids.shape[-1]
  1313. scores_processed = scores
  1314. if cur_len == self.max_length - 1:
  1315. scores_processed = torch.full_like(scores, -math.inf)
  1316. scores_processed[:, self.eos_token_id] = 0
  1317. return scores_processed
  1318. class InfNanRemoveLogitsProcessor(LogitsProcessor):
  1319. r"""
  1320. [`LogitsProcessor`] that removes all `nan` and `inf` values to avoid the generation method to fail. Note that using
  1321. the logits processor should only be used if necessary since it can slow down the generation method.
  1322. This logits processor has no `generate` example, as there shouldn't be a correct combination of flags that warrants
  1323. its use.
  1324. """
  1325. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  1326. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  1327. # set all nan values to 0.0
  1328. scores_processed = torch.where(scores != scores, 0.0, scores)
  1329. # set all +/-inf values to max/min possible value
  1330. scores_processed = torch.where(scores == float("inf"), torch.finfo(scores.dtype).max, scores_processed)
  1331. scores_processed = torch.where(scores == -float("inf"), torch.finfo(scores.dtype).min, scores_processed)
  1332. return scores_processed
  1333. class ExponentialDecayLengthPenalty(LogitsProcessor):
  1334. r"""
  1335. [`LogitsProcessor`] that exponentially increases the score of the `eos_token_id` after `start_index` has been
  1336. reached. This allows generating shorter sequences without having a hard cutoff, allowing the `eos_token` to be
  1337. predicted in a meaningful position.
  1338. Args:
  1339. exponential_decay_length_penalty (`tuple(int, float)`):
  1340. This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
  1341. starts and `decay_factor` represents the factor of exponential decay
  1342. eos_token_id (`Union[int, list[int], torch.Tensor]`):
  1343. The id(s) of the *end-of-sequence* token.
  1344. input_ids_seq_length (`int`):
  1345. The length of the input sequence.
  1346. Examples:
  1347. ```python
  1348. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
  1349. >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
  1350. >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
  1351. >>> text = "Just wanted to let you know, I"
  1352. >>> inputs = tokenizer(text, return_tensors="pt")
  1353. >>> # Let's consider that we want short sentences, so we limit `max_length=30`. However, we observe that the answer
  1354. >>> # tends to end abruptly.
  1355. >>> set_seed(1)
  1356. >>> outputs = model.generate(**inputs, do_sample=True, temperature=0.9, max_length=30, pad_token_id=50256)
  1357. >>> print(tokenizer.batch_decode(outputs)[0])
  1358. Just wanted to let you know, I received a link to an ebook, the book How To Start A Social Network which was
  1359. published in 2010. Although
  1360. >>> # To promote the appearance of the EOS token at the right time, we add the `exponential_decay_length_penalty =
  1361. >>> # (start_index, decay_factor)`. Instead of cutting at max_tokens, the output comes to an end before and usually
  1362. >>> # with more meaning. What happens is that starting from `start_index` the EOS token score will be increased
  1363. >>> # by `decay_factor` exponentially. However, if you set a high decay factor, you may also end up with abruptly
  1364. >>> # ending sequences.
  1365. >>> set_seed(1)
  1366. >>> outputs = model.generate(
  1367. ... **inputs,
  1368. ... do_sample=True,
  1369. ... temperature=0.9,
  1370. ... max_length=30,
  1371. ... pad_token_id=50256,
  1372. ... exponential_decay_length_penalty=(15, 1.6),
  1373. ... )
  1374. >>> print(tokenizer.batch_decode(outputs)[0])
  1375. Just wanted to let you know, I received a link to an ebook, the book How To Start A Social Network
  1376. which<|endoftext|>
  1377. >>> # With a small decay factor, you will have a higher chance of getting a meaningful sequence.
  1378. >>> set_seed(1)
  1379. >>> outputs = model.generate(
  1380. ... **inputs,
  1381. ... do_sample=True,
  1382. ... temperature=0.9,
  1383. ... max_length=30,
  1384. ... pad_token_id=50256,
  1385. ... exponential_decay_length_penalty=(15, 1.01),
  1386. ... )
  1387. >>> print(tokenizer.batch_decode(outputs)[0])
  1388. Just wanted to let you know, I received a link to an ebook, the book How To Start A Social Network which was
  1389. published in 2010.<|endoftext|>
  1390. ```
  1391. """
  1392. def __init__(
  1393. self,
  1394. exponential_decay_length_penalty: tuple[int, float],
  1395. eos_token_id: int | list[int] | torch.Tensor,
  1396. input_ids_seq_length: int,
  1397. ):
  1398. self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
  1399. self.regulation_factor = exponential_decay_length_penalty[1]
  1400. if not isinstance(eos_token_id, torch.Tensor):
  1401. if isinstance(eos_token_id, int):
  1402. eos_token_id = [eos_token_id]
  1403. eos_token_id = torch.tensor(eos_token_id)
  1404. self.eos_token_id = eos_token_id
  1405. if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
  1406. raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
  1407. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  1408. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  1409. cur_len = input_ids.shape[-1]
  1410. self.eos_token_id = self.eos_token_id.to(scores.device)
  1411. penalties = torch.zeros_like(scores)
  1412. scores_processed = scores
  1413. if cur_len > self.regulation_start:
  1414. penalty_idx = cur_len - self.regulation_start
  1415. # To support negative logits we compute the penalty of the absolute value and add to the original logit
  1416. penalty = torch.abs(scores[:, self.eos_token_id]) * (pow(self.regulation_factor, penalty_idx) - 1)
  1417. penalties[:, self.eos_token_id] = penalty
  1418. scores_processed = scores + penalties
  1419. return scores_processed
  1420. class LogitNormalization(LogitsProcessor):
  1421. r"""
  1422. [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
  1423. the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in
  1424. this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that
  1425. the scores are normalized when comparing the hypotheses.
  1426. Examples:
  1427. ```python
  1428. >>> from transformers import AutoTokenizer, AutoModelForCausalLM
  1429. >>> import torch
  1430. >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
  1431. >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
  1432. >>> inputs = tokenizer("A sequence: 1, 2, 3", return_tensors="pt")
  1433. >>> # By default, the scores are not normalized -- the sum of their exponentials is NOT a normalized probability
  1434. >>> # distribution, summing to 1
  1435. >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
  1436. >>> print(torch.allclose(torch.sum(torch.exp(outputs.scores[-1])), torch.Tensor((1.000,)), rtol=1e-4))
  1437. False
  1438. >>> # Normalizing them may have a positive impact on beam methods, or when using the scores on your application
  1439. >>> outputs = model.generate(**inputs, renormalize_logits=True, return_dict_in_generate=True, output_scores=True)
  1440. >>> print(torch.allclose(torch.sum(torch.exp(outputs.scores[-1])), torch.Tensor((1.000,)), rtol=1e-4))
  1441. True
  1442. ```
  1443. """
  1444. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  1445. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  1446. scores_processed = scores.log_softmax(dim=-1)
  1447. return scores_processed
  1448. class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
  1449. r"""
  1450. [`SuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
  1451. generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are
  1452. not generated at the beginning. Originally created for
  1453. [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
  1454. Examples:
  1455. ```python
  1456. >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
  1457. >>> from datasets import load_dataset
  1458. >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
  1459. >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
  1460. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  1461. >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
  1462. >>> # Whisper has `begin_suppress_tokens` set by default (= `[220, 50256]`). 50256 is the EOS token, so this means
  1463. >>> # it can't generate and EOS token in the first iteration, but it can in the others.
  1464. >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
  1465. >>> print(outputs.scores[0][0, 50256])
  1466. tensor(-inf)
  1467. >>> print(outputs.scores[-1][0, 50256]) # in other places we can see some probability mass for EOS
  1468. tensor(29.9010)
  1469. >>> # If we disable `begin_suppress_tokens`, we can generate EOS in the first iteration.
  1470. >>> outputs = model.generate(
  1471. ... **inputs, return_dict_in_generate=True, output_scores=True, begin_suppress_tokens=None
  1472. ... )
  1473. >>> print(outputs.scores[0][0, 50256])
  1474. tensor(11.2027)
  1475. ```
  1476. """
  1477. def __init__(self, begin_suppress_tokens, begin_index, device: str = "cpu"):
  1478. self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens), device=device)
  1479. self.begin_index = begin_index
  1480. def set_begin_index(self, begin_index):
  1481. self.begin_index = begin_index
  1482. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  1483. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  1484. vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
  1485. suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
  1486. scores_processed = scores
  1487. if input_ids.shape[-1] == self.begin_index:
  1488. scores_processed = torch.where(suppress_token_mask, -float("inf"), scores)
  1489. return scores_processed
  1490. class SuppressTokensLogitsProcessor(LogitsProcessor):
  1491. r"""
  1492. This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so
  1493. that they are not generated. Originally created for
  1494. [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
  1495. Examples:
  1496. ```python
  1497. >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
  1498. >>> from datasets import load_dataset
  1499. >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
  1500. >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
  1501. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  1502. >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
  1503. >>> # Whisper has a long list of suppressed tokens. For instance, in this case, the token 1 is suppressed by default.
  1504. >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
  1505. >>> print(outputs.scores[1][0, 1]) # 1 (and not 0) is the first freely generated token
  1506. tensor(-inf)
  1507. >>> # If we disable `suppress_tokens`, we can generate it.
  1508. >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, suppress_tokens=None)
  1509. >>> print(outputs.scores[1][0, 1])
  1510. tensor(6.0678)
  1511. ```
  1512. """
  1513. def __init__(self, suppress_tokens, device: str = "cpu"):
  1514. self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device)
  1515. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  1516. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  1517. vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
  1518. suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens.to(scores.device))
  1519. scores = torch.where(suppress_token_mask, -float("inf"), scores)
  1520. return scores
class WhisperTimeStampLogitsProcessor(LogitsProcessor):
    r"""

    [`LogitsProcessor`] that modifies the logits for the generation of timestamps in the transcription. When the input
    tokens are at a specific threshold, the processor sets the scores to negative infinity. The processor makes sure
    that timestamp tokens appear in pairs, by masking out the logits that would break this pairing pattern. This is
    done to maintain the consistency and structure of generated timestamps. It also ensures that when the predicted
    probability of sampling any of the timestamp token is greater than any individual non-timestamp token, those
    non-timestamp logits are set to negative infinity. This is done to ensure the generation of timestamps over other
    potential tokens.

    See [the paper](https://huggingface.co/papers/2212.04356) for more information.

    Args:
        generate_config (`GenerateConfig`):
            The generate config used to generate the output. The following parameters are required:
                eos_token_id (`int`, *optional*, defaults to 50257):
                    The id of the *end-of-sequence* token.
                no_timestamps_token_id (`int`, *optional*, defaults to 50363):
                    The id of the `"<|notimestamps|>"` token.
                max_initial_timestamp_index (`int`, *optional*, defaults to 1):
                    Used to set the maximum value of the initial timestamp. This is used to prevent the model from
                    predicting timestamps that are too far in the future.
        begin_index (`int`):
            Token index of the first token that is generated by the model.
        _detect_timestamp_from_logprob (`bool`, *optional*):
            Whether timestamps can be predicted from logprobs over all timestamps.

    Examples:
    ``` python
    >>> import torch
    >>> from transformers import AutoProcessor, WhisperForConditionalGeneration, GenerationConfig
    >>> from datasets import load_dataset

    >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
    >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> inputs = processor(ds[3]["audio"]["array"], return_tensors="pt")
    >>> input_features = inputs.input_features

    >>> #Displaying timestamps
    >>> generated_ids = model.generate(inputs=input_features, return_timestamps=True)
    >>> transcription = processor.batch_decode(generated_ids, decode_with_timestamps=True)[0]
    >>> print("Transcription:", transcription)
    Transcription: <|startoftranscript|><|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can<|6.44|><|6.44|> discover in it but little of rocky Ithaca.<|9.44|><|endoftext|>


    >>> #No timestamps & change EOS:
    >>> #This allows the user to select a specific token to terminate the sequence on, in this case it's the word "can"(460)
    >>> model.generation_config.eos_token_id = 460
    >>> generated_ids = model.generate(inputs=input_features,return_timestamps=False)
    >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    >>> print("Transcription:", transcription)
    Transcription: He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can
    ```
    """

    def __init__(
        self,
        generate_config: "GenerationConfig",
        begin_index: int,
        _detect_timestamp_from_logprob: bool | None = None,
    ):  # support for the kwargs
        # `cast` only narrows the static type so the Whisper-specific attribute can be read below.
        whisper_generate_config = cast(WhisperGenerationConfigLike, generate_config)
        self.no_timestamps_token_id = whisper_generate_config.no_timestamps_token_id
        # Every token id from `no_timestamps_token_id + 1` upwards is treated as a timestamp token.
        self.timestamp_begin = whisper_generate_config.no_timestamps_token_id + 1
        # Falls back to `bos_token_id` whenever `eos_token_id` is falsy (e.g. `None`).
        self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id

        # this variable is mostly just used for testing
        # An explicit argument wins; otherwise the config attribute is used, defaulting to `True`.
        self._detect_timestamp_from_logprob = (
            _detect_timestamp_from_logprob
            if _detect_timestamp_from_logprob is not None
            else getattr(generate_config, "_detect_timestamp_from_logprob", True)
        )

        self.begin_index = begin_index
        if begin_index is None:
            raise ValueError(
                "`forced_decoder_ids` is deprecated in favor of `task` and `language` and, as such, `begin_index` "
                "must be provided to `WhisperTimeStampLogitsProcessor`. The previous default value of `begin_index` "
                "was `len(generate_config.forced_decoder_ids)`"
            )

        # `None` (attribute absent from the config) disables the upper bound on the first timestamp.
        self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
        # TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50
        # self.max_initial_timestamp_index = 50

    def set_begin_index(self, begin_index):
        """Update the index of the first generated token."""
        self.begin_index = begin_index

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # suppress <|notimestamps|> which is handled by without_timestamps
        # Work on a copy so the caller's `scores` tensor is never modified in place.
        scores_processed = scores.clone()
        scores_processed[:, self.no_timestamps_token_id] = -float("inf")

        # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
        for k in range(input_ids.shape[0]):
            # Only the tokens from `begin_index` onwards are inspected.
            sampled_tokens = input_ids[k, self.begin_index :]
            seq = list(sampled_tokens.tolist())

            last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
            # With fewer than two sampled tokens the penultimate token counts as a timestamp.
            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin

            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    scores_processed[k, self.timestamp_begin :] = -float("inf")
                else:  # cannot be normal text tokens
                    # Masks every token id strictly below `eos_token_id`.
                    scores_processed[k, : self.eos_token_id] = -float("inf")

            # All timestamp tokens sampled so far for this row, in order of appearance.
            timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
            if timestamps.numel() > 0:
                # `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
                # The following lines of code are copied from: https://github.com/openai/whisper/pull/914/files#r1137085090
                if last_was_timestamp and not penultimate_was_timestamp:
                    # The last timestamp itself stays allowed (the slice below excludes it).
                    timestamp_last = timestamps[-1]
                else:
                    # Avoid to emit <|0.00|> again
                    timestamp_last = timestamps[-1] + 1

                scores_processed[k, self.timestamp_begin : timestamp_last] = -float("inf")

        # apply the `max_initial_timestamp` option
        if input_ids.shape[1] == self.begin_index:
            # On the first generated position only timestamp tokens remain allowed.
            scores_processed[:, : self.timestamp_begin] = -float("inf")

            if self.max_initial_timestamp_index is not None:
                last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
                scores_processed[:, last_allowed + 1 :] = -float("inf")

        # if sum of probability over timestamps is above any other token, sample timestamp
        # The log-softmax is computed in float32 on the already-masked scores.
        logprobs = torch.nn.functional.log_softmax(scores_processed.float(), dim=-1)
        for k in range(input_ids.shape[0]):
            # Total log-probability of all timestamp tokens vs. the single best non-timestamp token.
            timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(dim=-1)
            max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
            if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
                scores_processed[k, : self.timestamp_begin] = -float("inf")

        return scores_processed
  1637. class WhisperNoSpeechDetection(LogitsProcessor):
  1638. """
  1639. This processor can be used to detect silence when using Whisper. It should take as input unprocessed logits
  1640. to follow the original implementation
  1641. """
  1642. def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: bool = False):
  1643. self.no_speech_token = no_speech_token
  1644. # offset between <start-of-transcription> token, <SOT>, in paper and first generated token
  1645. # is equal to the position of the first generated token index
  1646. self.start_of_trans_offset = begin_index
  1647. # `self.begin_index` is a running value that is changed on the fly
  1648. self.begin_index = begin_index
  1649. self._no_speech_prob = [0.0]
  1650. self.is_scores_logprobs = scores_is_logprobs
  1651. # overwritten dynamically via set_model()
  1652. self.model: Any = None
  1653. self.inputs: dict[str, Any] | None = None
  1654. def set_model(self, model):
  1655. self.model = model
  1656. def set_inputs(self, inputs):
  1657. # prepare other inputs
  1658. self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs}
  1659. self.inputs["input_features"] = self.inputs.pop("inputs")
  1660. # Whisper encoder-decoder does not accept the input_ids as input
  1661. if "input_ids" not in inspect.signature(self.model.forward).parameters:
  1662. self.inputs.pop("input_ids", None)
  1663. @property
  1664. def no_speech_prob(self):
  1665. return self._no_speech_prob
  1666. def set_begin_index(self, begin_index):
  1667. self.begin_index = begin_index
  1668. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  1669. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  1670. is_scores_logprobs = self.is_scores_logprobs
  1671. if input_ids.shape[1] == self.begin_index:
  1672. if self.start_of_trans_offset > 1:
  1673. with torch.no_grad():
  1674. logits = self.model(**self.inputs).logits
  1675. no_speech_index = self.begin_index - self.start_of_trans_offset
  1676. no_speech_scores = logits[:, no_speech_index]
  1677. is_scores_logprobs = False
  1678. else:
  1679. no_speech_scores = scores
  1680. if is_scores_logprobs:
  1681. probs = no_speech_scores.exp()
  1682. else:
  1683. probs = no_speech_scores.float().softmax(dim=-1)
  1684. self._no_speech_prob = probs[:, self.no_speech_token]
  1685. return scores
  1686. class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
  1687. r"""
  1688. [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
  1689. where the first half correspond to the conditional logits (predicted from the input prompt) and the second half
  1690. correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a
  1691. weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`.
  1692. See [the paper](https://huggingface.co/papers/2306.05284) for more information.
  1693. <Tip warning={true}>
  1694. This logits processor is exclusively compatible with
  1695. [MusicGen](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen)
  1696. </Tip>
  1697. Args:
  1698. guidance_scale (float):
  1699. The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
  1700. Higher guidance scale encourages the model to generate samples that are more closely linked to the input
  1701. prompt, usually at the expense of poorer quality.
  1702. Examples:
  1703. ```python
  1704. >>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
  1705. >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
  1706. >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
  1707. >>> inputs = processor(
  1708. ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
  1709. ... padding=True,
  1710. ... return_tensors="pt",
  1711. ... )
  1712. >>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)
  1713. ```
  1714. """
  1715. def __init__(self, guidance_scale):
  1716. if guidance_scale > 1:
  1717. self.guidance_scale = guidance_scale
  1718. else:
  1719. raise ValueError(
  1720. "Require guidance scale >1 to use the classifier free guidance processor, got guidance scale "
  1721. f"{guidance_scale}."
  1722. )
  1723. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  1724. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  1725. # simple check to make sure we have compatible batch sizes between our
  1726. # logits scores (cond + uncond) and input ids (cond only)
  1727. if scores.shape[0] != 2 * input_ids.shape[0]:
  1728. raise ValueError(
  1729. f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to "
  1730. f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got "
  1731. f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids."
  1732. )
  1733. unguided_bsz = scores.shape[0] // 2
  1734. cond_logits, uncond_logits = scores.split(unguided_bsz, dim=0)
  1735. scores_processed = uncond_logits + (cond_logits - uncond_logits) * self.guidance_scale
  1736. return scores_processed
  1737. class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
  1738. r"""
  1739. [`LogitsProcessor`] enforcing alternated generation between the two codebooks of Bark.
  1740. <Tip warning={true}>
  1741. This logits processor is exclusively compatible with
  1742. [Bark](https://huggingface.co/docs/transformers/en/model_doc/bark)'s fine submodel. See the model documentation
  1743. for examples.
  1744. </Tip>
  1745. Args:
  1746. input_start_len (`int`):
  1747. The length of the initial input sequence.
  1748. semantic_vocab_size (`int`):
  1749. Vocabulary size of the semantic part, i.e number of tokens associated to the semantic vocabulary.
  1750. codebook_size (`int`):
  1751. Number of tokens associated to the codebook.
  1752. """
  1753. def __init__(self, input_start_len: int, semantic_vocab_size: int, codebook_size: int):
  1754. if not isinstance(input_start_len, int) or input_start_len < 0:
  1755. raise ValueError(f"`input_starting_length` has to be a non-negative integer, but is {input_start_len}")
  1756. self.input_start_len = input_start_len
  1757. self.semantic_vocab_size = semantic_vocab_size
  1758. self.codebook_size = codebook_size
  1759. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  1760. curr_len = input_ids.shape[-1]
  1761. # even -> first codebook, odd -> second codebook
  1762. is_first_codebook = ((curr_len - self.input_start_len) % 2) == 0
  1763. scores_processed = scores.clone()
  1764. if is_first_codebook:
  1765. scores_processed[:, : self.semantic_vocab_size] = -float("inf")
  1766. scores_processed[:, self.semantic_vocab_size + self.codebook_size :] = -float("inf")
  1767. else:
  1768. scores_processed[:, : self.semantic_vocab_size + self.codebook_size] = -float("inf")
  1769. return scores_processed
  1770. class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
  1771. r"""
  1772. Logits processor for Classifier-Free Guidance (CFG). The processors computes a weighted average across scores
  1773. from prompt conditional and prompt unconditional (or negative) logits, parameterized by the `guidance_scale`.
  1774. The unconditional scores are computed internally by prompting `model` with the `unconditional_ids` branch.
  1775. See [the paper](https://huggingface.co/papers/2306.17806) for more information.
  1776. Args:
  1777. guidance_scale (`float`):
  1778. The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale != 1`.
  1779. Higher guidance scale encourages the model to generate samples that are more closely linked to the input
  1780. prompt, usually at the expense of poorer quality. A value smaller than 1 has the opposite effect, while
  1781. making the negative prompt provided with negative_prompt_ids (if any) act as a positive prompt.
  1782. model (`PreTrainedModel`):
  1783. The model computing the unconditional scores. Supposedly the same as the one computing the conditional
  1784. scores. Both models must use the same tokenizer.
  1785. unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1786. Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default to
  1787. the last token of the prompt.
  1788. unconditional_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1789. Attention mask for unconditional_ids.
  1790. use_cache (`bool`, *optional*, defaults to `True`):
  1791. Whether to cache key/values during the negative prompt forward pass.
  1792. Examples:
  1793. ```python
  1794. >>> from transformers import AutoTokenizer, AutoModelForCausalLM
  1795. >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
  1796. >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
  1797. >>> inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")
  1798. >>> out = model.generate(inputs["input_ids"], guidance_scale=1.5)
  1799. >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
  1800. 'Today, a dragon flew over Paris, France, killing at least 50 people and injuring more than 100'
  1801. >>> # with a negative prompt
  1802. >>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
  1803. >>> out = model.generate(inputs["input_ids"], guidance_scale=2, negative_prompt_ids=neg_inputs["input_ids"])
  1804. >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
  1805. 'Today, a dragon flew over Paris, France, killing at least 130 people. French media reported that'
  1806. >>> # with a positive prompt
  1807. >>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
  1808. >>> out = model.generate(inputs["input_ids"], guidance_scale=0, negative_prompt_ids=neg_inputs["input_ids"])
  1809. >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
  1810. "Today, a dragon flew over Paris, France, and I'm very happy to be here. I"
  1811. ```
  1812. """
  1813. def __init__(
  1814. self,
  1815. guidance_scale: float,
  1816. model,
  1817. unconditional_ids: torch.LongTensor | None = None,
  1818. unconditional_attention_mask: torch.LongTensor | None = None,
  1819. use_cache: bool = True,
  1820. ):
  1821. self.guidance_scale = guidance_scale
  1822. self.model = model
  1823. self.unconditional_context = {
  1824. "input_ids": unconditional_ids,
  1825. "attention_mask": unconditional_attention_mask,
  1826. "use_cache": use_cache,
  1827. "past_key_values": None,
  1828. "first_pass": True,
  1829. }
  1830. def get_unconditional_logits(self, input_ids):
  1831. if self.unconditional_context["first_pass"]:
  1832. if self.unconditional_context["input_ids"] is None:
  1833. self.unconditional_context["input_ids"] = input_ids[:, -1:]
  1834. if self.unconditional_context["attention_mask"] is None:
  1835. self.unconditional_context["attention_mask"] = torch.ones_like(
  1836. self.unconditional_context["input_ids"], dtype=torch.long
  1837. )
  1838. input_ids = self.unconditional_context["input_ids"]
  1839. attention_mask = self.unconditional_context["attention_mask"]
  1840. self.unconditional_context["first_pass"] = False
  1841. else:
  1842. attention_mask = torch.cat(
  1843. [
  1844. self.unconditional_context["attention_mask"],
  1845. torch.ones_like(input_ids[:, -1:], dtype=torch.long),
  1846. ],
  1847. dim=1,
  1848. )
  1849. if not self.unconditional_context["use_cache"]:
  1850. input_ids = torch.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], dim=1)
  1851. else:
  1852. input_ids = input_ids[:, -1:]
  1853. self.unconditional_context["input_ids"] = input_ids
  1854. self.unconditional_context["attention_mask"] = attention_mask
  1855. out = self.model(
  1856. input_ids,
  1857. attention_mask=attention_mask,
  1858. use_cache=self.unconditional_context["use_cache"],
  1859. past_key_values=self.unconditional_context["past_key_values"],
  1860. )
  1861. self.unconditional_context["past_key_values"] = out.get("past_key_values", None)
  1862. return out.logits
  1863. def __call__(self, input_ids, scores):
  1864. scores = torch.nn.functional.log_softmax(scores, dim=-1)
  1865. if self.guidance_scale == 1:
  1866. return scores
  1867. logits = self.get_unconditional_logits(input_ids)
  1868. unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
  1869. scores_processed = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
  1870. return scores_processed
  1871. class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
  1872. r"""This processor ensures that the EOS token is selected if its probability is greater than the `min_eos_p`.
  1873. <Tip warning={true}>
  1874. This logits processor is exclusively compatible with
  1875. [Bark](https://huggingface.co/docs/transformers/en/model_doc/bark). See the model documentation for examples.
  1876. </Tip>
  1877. Args:
  1878. eos_token_id (`Union[int, list[int], torch.Tensor]`):
  1879. The id(s) of the *end-of-sequence* token.
  1880. min_eos_p (`float`, *optional*):
  1881. Minimum end of speech threshold.
  1882. """
  1883. def __init__(self, eos_token_id: int | list[int] | torch.Tensor, min_eos_p: float, device: str = "cpu"):
  1884. if not isinstance(eos_token_id, torch.Tensor):
  1885. if isinstance(eos_token_id, int):
  1886. eos_token_id = [eos_token_id]
  1887. eos_token_id = torch.tensor(eos_token_id, device=device)
  1888. self.eos_token_id = eos_token_id
  1889. if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
  1890. raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
  1891. if min_eos_p is not None and min_eos_p <= 0:
  1892. raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
  1893. self.min_eos_p = min_eos_p
  1894. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  1895. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  1896. scores_processed = scores
  1897. if self.min_eos_p:
  1898. probs = torch.nn.functional.softmax(scores.float(), dim=-1)
  1899. # create scores full of -inf except for the eos_token_id
  1900. early_stop_scores = torch.ones_like(scores) * -float("inf")
  1901. early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]
  1902. do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
  1903. do_early_stop = torch.any(do_early_stop, dim=1, keepdim=True)
  1904. scores_processed = torch.where(do_early_stop, early_stop_scores, scores)
  1905. return scores_processed
  1906. class WatermarkLogitsProcessor(LogitsProcessor):
  1907. r"""
  1908. Logits processor for watermarking generated text. The processor modifies model output scores by adding a small bias to
  1909. randomized set of "green" tokens before generating the next token. "Green" tokens selection process depends on the
  1910. `seeding_scheme` used. The code was based on the [original repo](https://github.com/jwkirchenbauer/lm-watermarking/tree/main).
  1911. The text generated by this `LogitsProcessor` can be detected using `WatermarkDetector`. See [`~WatermarkDetector.__call__`] for details,
  1912. See [the paper](https://huggingface.co/papers/2306.04634) for more information.
  1913. Args:
  1914. vocab_size (`int`):
  1915. The model tokenizer's vocab_size. Used to calculate "green" tokens ratio.
  1916. device (`str`):
  1917. The device where model is allocated.
  1918. greenlist_ratio (`float`, optional, *optional*, defaults to 0.25):
  1919. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
  1920. bias (`float`, optional, *optional*, defaults to 2.0):
  1921. The bias added to the selected "green" tokens' logits. Consider lowering the
  1922. `bias` if the text generation quality degrades. Recommended values are in the
  1923. range of [0.5, 2.0]. Defaults to 2.0.
  1924. hashing_key (`int`, optional, *optional*, defaults to 15485863):
  1925. Key used for hashing. If you deploy this watermark, we advise using another private key.
  1926. Defaults to 15485863 (the millionth prime).
  1927. seeding_scheme (`str`, optional, *optional*, defaults to `"lefthash"`):
  1928. The seeding scheme used for selecting "green" tokens. Accepts values:
  1929. - "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from paper)
  1930. - "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from paper)
  1931. The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
  1932. The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
  1933. context_width (`int`, *optional*, defaults to 1):
  1934. The number of previous tokens to use when setting the seed.
  1935. Examples:
  1936. ```python
  1937. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkingConfig
  1938. >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
  1939. >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
  1940. >>> inputs = tokenizer(["Alice and Bob are"], return_tensors="pt")
  1941. >>> # normal generation
  1942. >>> out = model.generate(inputs["input_ids"], max_length=20, do_sample=False)
  1943. >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
  1944. 'Alice and Bob are both in the same room.\n\n"I\'m not sure if you\'re'
  1945. >>> # watermarked generation
  1946. >>> watermarking_config = WatermarkingConfig(bias=2.5, context_width=2, seeding_scheme="selfhash")
  1947. >>> out = model.generate(inputs["input_ids"], watermarking_config=watermarking_config, max_length=20, do_sample=False)
  1948. >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
  1949. 'Alice and Bob are both still alive and well and the story is pretty much a one-hour adventure'
  1950. >>> # to detect watermarked text use the WatermarkDetector class
  1951. >>> from transformers import WatermarkDetector
  1952. >>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config= watermarking_config)
  1953. >>> detection_preds = detector(out)
  1954. >>> detection_preds
  1955. array([ True])
  1956. ```
  1957. """
  1958. def __init__(
  1959. self,
  1960. vocab_size,
  1961. device,
  1962. greenlist_ratio: float = 0.25,
  1963. bias: float = 2.0,
  1964. hashing_key: int = 15485863,
  1965. seeding_scheme: str = "lefthash",
  1966. context_width: int = 1,
  1967. ):
  1968. if seeding_scheme not in ["selfhash", "lefthash"]:
  1969. raise ValueError(f"seeding_scheme has to be one of [`selfhash`, `lefthash`], but found {seeding_scheme}")
  1970. if greenlist_ratio >= 1.0 or greenlist_ratio <= 0.0:
  1971. raise ValueError(
  1972. f"greenlist_ratio has be in range between 0.0 and 1.0, exclusively. but found {greenlist_ratio}"
  1973. )
  1974. self.vocab_size = vocab_size
  1975. self.greenlist_size = int(self.vocab_size * greenlist_ratio)
  1976. self.bias = bias
  1977. self.seeding_scheme = seeding_scheme
  1978. self.rng = torch.Generator(device=device)
  1979. self.hash_key = hashing_key
  1980. self.context_width = context_width
  1981. self.rng.manual_seed(hashing_key)
  1982. self.table_size = 1_000_003
  1983. self.fixed_table = torch.randperm(self.table_size, generator=self.rng, device=device)
  1984. def set_seed(self, input_seq: torch.LongTensor):
  1985. input_seq = input_seq[-self.context_width :]
  1986. if self.seeding_scheme == "selfhash":
  1987. a = self.fixed_table[input_seq % self.table_size] + 1
  1988. b = self.fixed_table[input_seq[-1] % self.table_size] + 1
  1989. seed = (self.hash_key * a * b).min().item()
  1990. else:
  1991. seed = self.hash_key * input_seq[-1].item()
  1992. self.rng.manual_seed(seed % (2**64 - 1))
  1993. def _get_greenlist_ids(self, input_seq: torch.LongTensor) -> torch.LongTensor:
  1994. self.set_seed(input_seq)
  1995. vocab_permutation = torch.randperm(self.vocab_size, device=input_seq.device, generator=self.rng)
  1996. greenlist_ids = vocab_permutation[: self.greenlist_size]
  1997. return greenlist_ids
  1998. def _score_rejection_sampling(self, input_seq: torch.LongTensor, scores: torch.FloatTensor) -> torch.LongTensor:
  1999. """
  2000. Generate greenlist based on current candidate next token. Reject and move on if necessary.
  2001. Runs for a fixed number of steps only for efficiency, since the methods is not batched.
  2002. """
  2003. final_greenlist = []
  2004. _, greedy_predictions = scores.sort(dim=-1, descending=True)
  2005. # 40 is an arbitrary number chosen to save compute and not run for long (taken from orig repo)
  2006. for i in range(40):
  2007. greenlist_ids = self._get_greenlist_ids(torch.cat([input_seq, greedy_predictions[i, None]], dim=-1))
  2008. if greedy_predictions[i] in greenlist_ids:
  2009. final_greenlist.append(greedy_predictions[i])
  2010. return torch.tensor(final_greenlist, device=input_seq.device)
  2011. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  2012. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  2013. if input_ids.shape[-1] < self.context_width:
  2014. logger.warning(
  2015. f"`input_ids` should have at least `{self.context_width}` tokens but has {input_ids.shape[-1]}. "
  2016. "The seeding will be skipped for this generation step!"
  2017. )
  2018. return scores
  2019. scores_processed = scores.clone()
  2020. for b_idx, input_seq in enumerate(input_ids):
  2021. if self.seeding_scheme == "selfhash":
  2022. greenlist_ids = self._score_rejection_sampling(input_seq, scores[b_idx])
  2023. else:
  2024. greenlist_ids = self._get_greenlist_ids(input_seq)
  2025. scores_processed[b_idx, greenlist_ids] = scores_processed[b_idx, greenlist_ids] + self.bias
  2026. return scores_processed
  2027. class SynthIDTextWatermarkState:
  2028. """SynthID watermarking state."""
  2029. def __init__(
  2030. self,
  2031. batch_size: int,
  2032. ngram_len: int,
  2033. context_history_size: int,
  2034. device: torch.device,
  2035. ):
  2036. """Initializes the state.
  2037. Args:
  2038. batch_size (`int`): Batch size.
  2039. ngram_len (`int`): Ngram length.
  2040. context_history_size (`int`): Size of the tensor to keep track of seen contexts.
  2041. device (`int`): Device to use.
  2042. """
  2043. self.context = torch.zeros(
  2044. (batch_size, ngram_len - 1),
  2045. dtype=torch.int64,
  2046. device=device,
  2047. )
  2048. self.context_history = torch.zeros(
  2049. (batch_size, context_history_size),
  2050. dtype=torch.int64,
  2051. device=device,
  2052. )
  2053. self.num_calls = 0
  2054. class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor):
  2055. r"""
  2056. Logits processor that implements watermarking techniques for text generation models.
  2057. This class facilitates the application of SynthID text watermarking, a method for embedding imperceptible signals
  2058. into generated text to aid in detecting synthetic content. It operates by subtly manipulating the probabilities of
  2059. token selection during text generation in a manner that can be reliably recovered later for verification.
  2060. Key Features:
  2061. * **State Management:** Maintains internal state to track token sequences and generate watermarking keys
  2062. dynamically.
  2063. * **Key Generation:** Computes hashes based on token sequences and watermarking parameters to create unique keys
  2064. for each position.
  2065. * **G-Value Sampling:** Employs a pre-computed sampling table to sample watermarking values (g-values) based on
  2066. the generated keys.
  2067. * **Score Adjustment:** Applies calculated g-values to modify token probabilities during generation, embedding the
  2068. watermark.
  2069. * **Context Repetition Handling:** Incorporates logic to avoid watermarking tokens in repeated contexts,
  2070. preserving naturalness.
  2071. * **EOS Token Masking:** Supports masking end-of-sentence tokens to prevent their inclusion in watermarking
  2072. calculations.
  2073. * **Utility Functions:** Provides functions to compute g-values directly, check for context repetition, create
  2074. EOS token masks, and estimate expected mean g-values.
  2075. Refer to paper url: https://www.nature.com/articles/s41586-024-08025-4 for more details around this.
  2076. Args:
  2077. ngram_len (`int`):
  2078. Ngram length.
  2079. keys (`list[int]`):
  2080. A sequence of watermarking keys, one for each depth.
  2081. sampling_table_size (`int`):
  2082. Size of the sampling table.
  2083. sampling_table_seed (`int`):
  2084. Random seed to generate the sampling table.
  2085. context_history_size (`int`):
  2086. Size of the tensor to keep track of seen contexts.
  2087. device (`torch.device`):
  2088. Device to use.
  2089. skip_first_ngram_calls (`bool`, *optional*, defaults to `False`):
  2090. Whether to skip first ngram calls.
  2091. debug_mode (`bool`, optional, *optional*, defaults to `False`):
  2092. Logits are modified to uniform one got before watermarking modification is applied. This is to test the
  2093. implementation.
  2094. Examples:
  2095. ```python
  2096. >>> from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig
  2097. >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b', padding_side="left")
  2098. >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b')
  2099. >>> # SynthID Text configuration
  2100. >>> watermarking_config = SynthIDTextWatermarkingConfig(
  2101. ... keys=[654, 400, 836, 123, 340, 443, 597, 160, 57],
  2102. ... ngram_len=5,
  2103. ... )
  2104. >>> # Generation with watermarking
  2105. >>> tokenized_prompts = tokenizer(["Once upon a time, "], return_tensors="pt", padding=True)
  2106. >>> output_sequences = model.generate(
  2107. ... **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True, max_new_tokens=10
  2108. ... )
  2109. >>> watermarked_text = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
  2110. ```
  2111. """
  2112. def __init__(
  2113. self,
  2114. ngram_len: int,
  2115. keys: list[int],
  2116. sampling_table_size: int,
  2117. sampling_table_seed: int,
  2118. context_history_size: int,
  2119. device: torch.device,
  2120. skip_first_ngram_calls: bool = False,
  2121. debug_mode: bool = False,
  2122. ):
  2123. self.ngram_len = ngram_len
  2124. self.keys = torch.tensor(keys, device=device)
  2125. generator = torch.Generator(device=device).manual_seed(sampling_table_seed)
  2126. # A random sampling table is pre-computed and modulo table size is applied to map from a hash of ngram keys to
  2127. # g values, this is similar to the hashtable implementation used in
  2128. # https://github.com/facebookresearch/three_bricks. We note that the hashing employed in this repository is
  2129. # different from that used to watermark the Gemini App, and hence the detectors trained based on the
  2130. # hashing in this repository will not transfer to text generated by the Gemini App.
  2131. self.sampling_table = torch.randint(
  2132. low=0,
  2133. high=2,
  2134. size=(sampling_table_size,),
  2135. generator=generator,
  2136. device=device,
  2137. )
  2138. self.context_history_size = context_history_size
  2139. self.device = device
  2140. self.state = None
  2141. self.skip_first_ngram_calls = skip_first_ngram_calls
  2142. self.debug_mode = debug_mode
  2143. def _init_state(self, batch_size: int):
  2144. """Initializes the state."""
  2145. self.state = SynthIDTextWatermarkState(
  2146. batch_size=batch_size,
  2147. ngram_len=self.ngram_len,
  2148. context_history_size=self.context_history_size,
  2149. device=self.device,
  2150. )
  2151. def update_scores(self, scores: torch.FloatTensor, g_values: torch.FloatTensor) -> torch.FloatTensor:
  2152. """Updates scores using the g values.
  2153. We assume that the scores are in the log space.
  2154. Args:
  2155. scores (`torch.FloatTensor`): Scores (batch_size, vocab_size).
  2156. g_values (`torch.FloatTensor`): G values (batch_size, vocab_size, depth).
  2157. Returns:
  2158. Updated scores (batch_size, vocab_size).
  2159. """
  2160. _, _, depth = g_values.shape
  2161. probs = torch.softmax(scores, dim=1)
  2162. for i in range(depth):
  2163. g_values_at_depth = g_values[:, :, i]
  2164. g_mass_at_depth = (g_values_at_depth * probs).sum(axis=1, keepdims=True)
  2165. probs = probs * (1 + g_values_at_depth - g_mass_at_depth)
  2166. log_probs = torch.log(probs)
  2167. log_probs = torch.where(torch.isfinite(log_probs), log_probs, torch.finfo(log_probs.dtype).min)
  2168. return log_probs
  2169. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  2170. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  2171. self._check_input_ids_shape(input_ids)
  2172. batch_size, vocab_size = scores.shape
  2173. if self.debug_mode:
  2174. scores = torch.ones_like(scores)
  2175. # Build continuation indices once and broadcast across batch instead of creating one arange per row.
  2176. all_indices = torch.arange(vocab_size, device=self.device).unsqueeze(0).expand(batch_size, -1)
  2177. if self.state is None:
  2178. # Initialize watermarking state if it does not exist.
  2179. self._init_state(batch_size)
  2180. else:
  2181. # Append last input id (which is the input id added in last call) to the
  2182. # previous context so we have the context to be used for current
  2183. # watermarking.
  2184. self.state.context = torch.concat(
  2185. (self.state.context, input_ids[:, -1:]),
  2186. dim=1,
  2187. )
  2188. self.state.context = self.state.context[:, 1:]
  2189. if self.state is None:
  2190. raise ValueError("self.state can't be None! Call `self._init_state` to initialize the state.")
  2191. self.state.num_calls += 1
  2192. # Don't watermark the first ngram_len - 1 tokens if set.
  2193. if self.skip_first_ngram_calls and self.state.num_calls < self.ngram_len:
  2194. return scores
  2195. # 2. Generate random keys for each ngram key combination.
  2196. ngram_keys, hash_result_with_just_context = self._compute_keys(self.state.context, all_indices)
  2197. # ngram_keys shape [batch_size, top_k, depth]
  2198. # 3. Sample g values.
  2199. g_values = self.sample_g_values(ngram_keys)
  2200. # g_values shape [batch_size, top_k, depth]
  2201. # 4. Modify scores.
  2202. updated_scores = self.update_scores(scores, g_values)
  2203. # updated scores shape [batch_size, top_k]
  2204. # 5. Check if the current watermarking context was previously used, if yes skip watermarking.
  2205. hash_result_with_just_context = hash_result_with_just_context[:, None]
  2206. is_repeated_context = (self.state.context_history == hash_result_with_just_context).any(
  2207. dim=1,
  2208. keepdim=True,
  2209. )
  2210. self.state.context_history = torch.concat(
  2211. (hash_result_with_just_context, self.state.context_history),
  2212. dim=1,
  2213. )[:, :-1]
  2214. updated_watermarked_scores = torch.where(
  2215. is_repeated_context,
  2216. input=scores,
  2217. other=updated_scores,
  2218. )
  2219. return updated_watermarked_scores
  2220. def accumulate_hash(
  2221. self,
  2222. current_hash: torch.LongTensor,
  2223. data: torch.LongTensor,
  2224. multiplier: int = 6364136223846793005,
  2225. increment: int = 1,
  2226. ) -> torch.LongTensor:
  2227. """
  2228. Accumulate hash of data on current hash.
  2229. Method uses adapted linear congruential generator with newlib/musl parameters.
  2230. This function has following property -
  2231. f(x, data[T]) = f(f(x, data[:T - 1]), data[T])
  2232. This function expects current_hash.shape and data.shape[:-1] to
  2233. match/broadcastable.
  2234. Args:
  2235. current_hash (`torch.LongTensor`):
  2236. (shape,)
  2237. data (`torch.LongTensor`):
  2238. (shape, tensor_len)
  2239. multiplier (`int`, optional, *optional*, defaults to 6364136223846793005):
  2240. multiplier of linear congruential generator
  2241. increment (`int`, optional, *optional*, defaults to 1):
  2242. increment of linear congruential generator
  2243. Returns:
  2244. updated hash (shape,)
  2245. """
  2246. for i in range(data.shape[-1]):
  2247. current_hash = torch.add(current_hash, data[..., i])
  2248. current_hash = torch.mul(current_hash, multiplier)
  2249. current_hash = torch.add(current_hash, increment)
  2250. return current_hash
  2251. def compute_ngram_keys(self, ngrams: torch.LongTensor) -> torch.LongTensor:
  2252. """Computes random keys for each ngram and depth.
  2253. Args:
  2254. ngrams (`torch.LongTensor`):
  2255. Ngrams (batch_size, num_ngrams, ngram_len).
  2256. Returns:
  2257. ngram keys (batch_size, num_ngrams, depth).
  2258. """
  2259. if len(ngrams.shape) != 3:
  2260. raise ValueError(f"Ngrams should be of shape (batch_size, num_ngrams, ngram_len), but is {ngrams.shape}")
  2261. if ngrams.shape[2] != self.ngram_len:
  2262. raise ValueError(
  2263. "Ngrams should be of shape (batch_size, num_ngrams, ngram_len),"
  2264. f" where ngram_len is {self.ngram_len}, but is {ngrams.shape}"
  2265. )
  2266. batch_size, _, _ = ngrams.shape
  2267. hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
  2268. # hash_result shape [batch_size,]
  2269. # ngrams shape [batch_size, num_ngrams, ngram_len]
  2270. hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 1), out_dims=1)(hash_result, ngrams)
  2271. # hash_result shape [batch_size, num_ngrams]
  2272. keys = self.keys[None, None, :, None]
  2273. # hash_result shape [batch_size, num_ngrams]
  2274. # keys shape [1, 1, depth, 1]
  2275. hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 2), out_dims=2)(hash_result, keys)
  2276. # hash_result shape [batch_size, num_ngrams, depth]
  2277. return hash_result
  2278. def _compute_keys(
  2279. self, n_minus_1_grams: torch.LongTensor, indices: torch.LongTensor
  2280. ) -> tuple[torch.LongTensor, torch.LongTensor]:
  2281. """Computes random keys for each ngram and depth.
  2282. Args:
  2283. n_minus_1_grams (`torch.LongTensor`):
  2284. Ngrams (batch_size, ngram_len - 1).
  2285. indices (`torch.LongTensor`):
  2286. indices of the continuations (batch_size, num_indices)
  2287. Returns:
  2288. Ngram keys (batch_size, num_indices, depth).
  2289. """
  2290. batch_size, _ = n_minus_1_grams.shape
  2291. hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
  2292. # First hash n_minus_1 gram, for each batch entry we have a single
  2293. # n_minus_1 gram context.
  2294. # hash_result shape [batch_size]
  2295. # n_minus_1_gram shape [batch_size, ngram_len - 1]
  2296. hash_result_with_just_context = self.accumulate_hash(hash_result, n_minus_1_grams)
  2297. # hash_result shape [batch_size,]
  2298. # Indices is of shape [batch_size, num_indices], so we make it
  2299. # [batch_size, num_indices, 1] so we can vmap over num_indices dim.
  2300. hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 1), out_dims=1)(
  2301. hash_result_with_just_context, indices[:, :, None]
  2302. )
  2303. # hash_result shape [batch_size, num_indices]
  2304. # Basically we have a hash for each batch entry and each indices
  2305. # Now we add watermarking keys to this hash.
  2306. # keys are of shape [depth,]
  2307. # We add batch, num_indices and data dimension to this making it
  2308. # [1, 1, depth, 1].
  2309. # So we can vmap over the depth dimension for compute_hash
  2310. keys = self.keys[None, None, :, None]
  2311. hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 2), out_dims=2)(hash_result, keys)
  2312. # hash_result shape should be [batch_size, num_indices, depth]
  2313. return hash_result, hash_result_with_just_context
  2314. def sample_g_values(self, ngram_keys: torch.LongTensor) -> torch.LongTensor:
  2315. """
  2316. Samples g values from Bernoulli distribution.
  2317. It is not possible to pass random keys in a vectorized way in torch. Instead
  2318. we pre-compute a random sampling table, and use apply modulo table size to
  2319. map from ngram keys (int64) to g values.
  2320. Args:
  2321. ngram_keys (`torch.LongTensor`):
  2322. Random keys (batch_size, num_ngrams, depth).
  2323. Returns:
  2324. G values (batch_size, num_ngrams, depth).
  2325. """
  2326. (sampling_table_size,) = self.sampling_table.shape
  2327. sampling_table = self.sampling_table.reshape((1, 1, sampling_table_size))
  2328. ngram_keys = ngram_keys % sampling_table_size
  2329. return torch.take_along_dim(sampling_table, indices=ngram_keys, dim=2)
  2330. def _check_input_ids_shape(self, input_ids: torch.LongTensor):
  2331. """Checks the shape of input ids."""
  2332. if len(input_ids.shape) != 2:
  2333. raise ValueError(f"Input ids should be of shape (batch_size, input_len), but is {input_ids.shape}")
  2334. def compute_g_values(self, input_ids: torch.LongTensor) -> torch.LongTensor:
  2335. """
  2336. Computes g values for each ngram from the given sequence of tokens.
  2337. Args:
  2338. input_ids (`torch.LongTensor`):
  2339. Input token ids (batch_size, input_len).
  2340. Returns:
  2341. G values (batch_size, input_len - (ngram_len - 1), depth).
  2342. """
  2343. self._check_input_ids_shape(input_ids)
  2344. ngrams = input_ids.unfold(dimension=1, size=self.ngram_len, step=1)
  2345. ngram_keys = self.compute_ngram_keys(ngrams)
  2346. return self.sample_g_values(ngram_keys)
  2347. def compute_context_repetition_mask(self, input_ids: torch.LongTensor) -> torch.LongTensor:
  2348. """
  2349. Computes repetition mask.
  2350. 0 and 1 stand for repeated and not repeated context n-1 grams respectively.
  2351. Args:
  2352. input_ids (`torch.LongTensor`):
  2353. Input token ids (batch_size, input_len).
  2354. Returns:
  2355. Repetitions mask (batch_size, input_len - (ngram_len - 1)).
  2356. """
  2357. self._check_input_ids_shape(input_ids)
  2358. batch_size, _ = input_ids.shape
  2359. state = SynthIDTextWatermarkState(
  2360. batch_size=batch_size,
  2361. ngram_len=self.ngram_len,
  2362. context_history_size=self.context_history_size,
  2363. device=self.device,
  2364. )
  2365. contexts = input_ids[:, :-1].unfold(
  2366. dimension=1,
  2367. size=self.ngram_len - 1,
  2368. step=1,
  2369. )
  2370. _, num_contexts, _ = contexts.shape
  2371. are_repeated_contexts = []
  2372. for i in range(num_contexts):
  2373. context = contexts[:, i, :]
  2374. hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
  2375. context_hash = self.accumulate_hash(hash_result, context)[:, None]
  2376. is_repeated_context = (state.context_history == context_hash).any(
  2377. dim=1,
  2378. keepdim=True,
  2379. )
  2380. are_repeated_contexts.append(is_repeated_context)
  2381. state.context_history = torch.concat(
  2382. (context_hash, state.context_history),
  2383. dim=1,
  2384. )[:, :-1]
  2385. are_repeated_contexts = torch.concat(are_repeated_contexts, dim=1)
  2386. return torch.logical_not(are_repeated_contexts)
  2387. def compute_eos_token_mask(self, input_ids: torch.LongTensor, eos_token_id: int) -> torch.LongTensor:
  2388. """
  2389. Computes repetitions mask.
  2390. 1 stands for ngrams that don't contain EOS tokens and vice versa.
  2391. Args:
  2392. input_ids (`torch.LongTensor`):
  2393. Input token ids (batch_size, input_len).
  2394. eos_token_id (`int`):
  2395. EOS token ID.
  2396. Returns:
  2397. EOS token mask (batch_size, input_len).
  2398. """
  2399. self._check_input_ids_shape(input_ids)
  2400. noneos_masks = []
  2401. all_eos_equated = input_ids == eos_token_id
  2402. for eos_equated in all_eos_equated:
  2403. nonzero_idx = torch.nonzero(eos_equated)
  2404. noneos_mask = torch.ones_like(eos_equated)
  2405. if nonzero_idx.shape[0] != 0:
  2406. noneos_mask[nonzero_idx[0][0] :] = 0
  2407. noneos_masks.append(noneos_mask)
  2408. return torch.stack(noneos_masks, dim=0)
  2409. def expected_mean_g_value(self, vocab_size: int, coinflip_prob: float = 0.5) -> float:
  2410. """
  2411. Compute expected mean g-value after watermarking, assuming uniform LM dist.
  2412. This is the theoretical expected value for single-layer watermarking.
  2413. Args:
  2414. vocab_size (`int`):
  2415. The size of the vocabulary.
  2416. coinflip_prob arg_name (`float`, *optional*, defaults to 0.5):
  2417. Probability of 1 in boolean prf.
  2418. Returns:
  2419. The expected mean g-value for watermarked text.
  2420. """
  2421. return coinflip_prob + coinflip_prob * (1 - coinflip_prob) * (1 - (1 / vocab_size))
  2422. class DiaClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
  2423. r"""
  2424. [`LogitsProcessor`] for classifier free guidance (CFG). Similar to the original
  2425. `ClassifierFreeGuidanceLogitsProcessor` with some modifications on the overall
  2426. calculation, e.g. conditioned logits centered, and an additional top k selection
  2427. option.
  2428. <Tip warning={true}>
  2429. This logits processor is exclusively compatible with
  2430. [Dia](https://huggingface.co/docs/transformers/main/en/model_doc/dia)
  2431. </Tip>
  2432. Args:
  2433. guidance_scale (float):
  2434. The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
  2435. Higher guidance scale encourages the model to generate samples that are more closely linked to the input
  2436. prompt, usually at the expense of poorer quality.
  2437. guidance_top_k (int, *optional*):
  2438. The number of highest probability vocabulary tokens to keep for top-k-filtering. However, we do not keep
  2439. the logits of the combined CFG output, but the conditioned output only.
  2440. """
  2441. def __init__(self, guidance_scale: float, guidance_top_k: int | None = None):
  2442. if guidance_scale > 1:
  2443. self.guidance_scale = guidance_scale
  2444. else:
  2445. raise ValueError(
  2446. "Require guidance scale >1 to use the classifier free guidance processor, got guidance scale "
  2447. f"{guidance_scale}."
  2448. )
  2449. self.guidance_top_k = guidance_top_k
  2450. if self.guidance_top_k is not None and self.guidance_top_k < 1:
  2451. raise ValueError(
  2452. f"`guidance_top_k` has to be a strictly positive integer if given, but is {self.guidance_top_k}"
  2453. )
  2454. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  2455. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  2456. # simple check to make sure we have compatible batch sizes between our
  2457. # logits scores (cond + uncond) and input ids (cond only)
  2458. if scores.shape[0] != 2 * input_ids.shape[0]:
  2459. raise ValueError(
  2460. f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to "
  2461. f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got "
  2462. f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids."
  2463. )
  2464. # Base CFG with center on cond_logits
  2465. unguided_bsz = scores.shape[0] // 2
  2466. cond_logits, uncond_logits = scores.split(unguided_bsz, dim=0)
  2467. scores_processed = cond_logits + (cond_logits - uncond_logits) * self.guidance_scale
  2468. # Optional CFG top k filtering
  2469. if self.guidance_top_k is not None:
  2470. # Create top k based on the combined CFG output
  2471. _, top_k_indices = torch.topk(scores_processed, k=self.guidance_top_k, dim=-1)
  2472. top_k_mask = torch.ones_like(scores_processed, dtype=torch.bool)
  2473. top_k_mask = top_k_mask.scatter(dim=-1, index=top_k_indices, value=False)
  2474. # Only return conditioned logits with top k
  2475. scores_processed = cond_logits.masked_fill(top_k_mask, -float("inf"))
  2476. return scores_processed
  2477. class DiaEOSChannelFilterLogitsProcessor(LogitsProcessor):
  2478. r"""Specialized processor that ensures certain properties around EOS sampling:
  2479. 1. Only channel 0 can generate EOS
  2480. 2. If channel 0 has EOS with highest logit, it will be the only candidate
  2481. 3. If channel 0 has EOS not with highest logit, it will be suppressed
  2482. 2. and 3. are especially important in contexts where we allow sampling to guarantee the
  2483. respective tokens to be (not) sampled.
  2484. <Tip warning={true}>
  2485. This logits processor is exclusively compatible with
  2486. [Dia](https://huggingface.co/docs/transformers/en/model_doc/dia).
  2487. </Tip>
  2488. Args:
  2489. num_channels (`int`):
  2490. Number of audio codebooks. Simplifies access to the first channel on the logits.
  2491. eos_token_id (`int`):
  2492. The id of *end-of-sequence* token.
  2493. """
  2494. def __init__(self, num_channels: int, eos_token_id: int):
  2495. if num_channels < 1:
  2496. raise ValueError(f"Audio codebooks need at least one channel, but found {num_channels} channels.")
  2497. if eos_token_id < 1:
  2498. raise ValueError(f"Expected `eos_token_id` to be a positive integer, found {eos_token_id} instead.")
  2499. self.num_channels = num_channels
  2500. self.eos_id = eos_token_id
  2501. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  2502. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  2503. # Reshape for easier channel indexing [B, C, V]
  2504. scores = scores.reshape(-1, self.num_channels, scores.shape[-1])
  2505. # EOS filter
  2506. # 1. Condition: Only the first channel can generate the EOS token
  2507. # Side condition of disabling generation of special tokens (e.g. audio pad, bos, ...)
  2508. # (Assumes them to be greater than audio eos token position)
  2509. scores[:, 1:, self.eos_id :] = torch.full_like(
  2510. scores[:, 1:, self.eos_id :],
  2511. fill_value=-float("inf"),
  2512. )
  2513. scores[:, 0, self.eos_id + 1 :] = torch.full_like(
  2514. scores[:, 0, self.eos_id + 1 :],
  2515. fill_value=-float("inf"),
  2516. )
  2517. # 2+3 Conditions: Force/Suppress EOS if (not) highest logit
  2518. # Reshape back to original shape
  2519. scores = scores.view(-1, scores.shape[-1])
  2520. # Sample highest tokens
  2521. top_logit_indices = torch.argmax(scores, dim=-1)
  2522. # 2. Force EOS
  2523. eos_highest_mask = top_logit_indices == self.eos_id
  2524. mask_eos_highest = torch.zeros_like(scores, dtype=torch.bool)
  2525. mask_eos_highest[eos_highest_mask, : self.eos_id] = True
  2526. scores = scores.masked_fill(mask_eos_highest, -float("inf"))
  2527. # 3. Suppress EOS
  2528. eos_not_highest_mask = top_logit_indices != self.eos_id
  2529. mask_eos_unless_highest = torch.zeros_like(scores, dtype=torch.bool)
  2530. mask_eos_unless_highest[eos_not_highest_mask, self.eos_id] = True
  2531. scores = scores.masked_fill(mask_eos_unless_highest, -float("inf"))
  2532. return scores
  2533. class DiaEOSDelayPatternLogitsProcessor(LogitsProcessor):
  2534. r"""Special logits processor to handle the generation of the EOS token in Dia.
  2535. This is due to the fact that Dia does not allow the generation of EOS in all
  2536. channels except the first channel (C0).
  2537. Hence, based on the delay pattern, an EOS is forced after the respective delays
  2538. in the channels. For example, if the delay pattern is [0, 2, 3, 4]:
  2539. s s+1 s+2 s+3 s+4 s+5 ...
  2540. | | | | | |
  2541. C0: EOS PAD PAD PAD PAD PAD ...
  2542. C1: x x EOS PAD PAD PAD ...
  2543. C2: x x x EOS PAD PAD ...
  2544. C3: x x x x EOS PAD ...
  2545. If the first channel generated EOS at step s, channels Cx are forced to generate
  2546. theirs at the respective delays (s+2, s+3, s+4). Subsequent padding tokens are
  2547. handled by the `EosTokenCriteria` when an EOS has been detected.
  2548. <Tip warning={true}>
  2549. This logits processor is exclusively compatible with
  2550. [Dia](https://huggingface.co/docs/transformers/en/model_doc/dia).
  2551. </Tip>
  2552. Args:
  2553. delay_pattern (`List[int]`):
  2554. The delays per channel in the audio codebooks.
  2555. eos_token_id (`int`):
  2556. The id of *end-of-sequence* token.
  2557. max_generation_len (`int`):
  2558. The max sequence length that can be generated.
  2559. device (`str`, *optional*, defaults to `"cpu"`):
  2560. The device to allocate the tensors on.
  2561. """
  2562. def __init__(self, delay_pattern: list[int], eos_token_id: int, max_generation_len: int, device: str = "cpu"):
  2563. self.num_channels = len(delay_pattern)
  2564. # Update during first iteration
  2565. self.active_batches = None
  2566. self.delay_pattern = torch.tensor(delay_pattern, device=device, dtype=torch.int)[None, :]
  2567. self.eos_token_id = eos_token_id
  2568. self.max_generation_len = max_generation_len - max(delay_pattern) - 1
  2569. self.device = device
  2570. @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
  2571. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  2572. # Reshape for easier channel indexing [B, C, V]
  2573. scores = scores.reshape(-1, self.num_channels, scores.shape[-1])
  2574. # Initialize / expand values on first iteration
  2575. if self.active_batches is None:
  2576. self.delay_pattern = self.delay_pattern.repeat(scores.shape[0], 1)
  2577. self.active_batches = torch.zeros(size=(scores.shape[0],), device=self.device, dtype=torch.bool)
  2578. # Check if eos has been generated in any batch
  2579. channel_generated_eos = torch.argmax(scores, dim=-1)[:, 0] == self.eos_token_id
  2580. # Check if max len has been reached
  2581. reached_max_len = input_ids.shape[1] == self.max_generation_len
  2582. # Update active batches
  2583. self.active_batches |= channel_generated_eos
  2584. self.active_batches |= reached_max_len
  2585. # Find channels that need to force eos
  2586. forced_eos_channels = self.active_batches[:, None] & (self.delay_pattern == 0)
  2587. # Use indexing to avoid issues on all `False` by having empty tensors in that case
  2588. idx_bsz, idx_channel = forced_eos_channels.nonzero(as_tuple=True)
  2589. # Force eos if delay is kicking in
  2590. scores[idx_bsz, idx_channel, :] = -float("inf")
  2591. scores[idx_bsz, idx_channel, self.eos_token_id] = 0.0
  2592. # Reshape back to [B * C, V]
  2593. scores = scores.reshape(-1, scores.shape[-1])
  2594. # Update amount of delay left for each channel
  2595. self.delay_pattern -= self.active_batches[:, None].int()
  2596. return scores