convert_generation.py 138 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # -------------------------------------------------------------------------
  5. """
  6. This converts a GPT2, T5 or MT5 model to ONNX with a beam search, greedy search or sampling operator.
  7. Example 1: convert gpt2 model with beam search:
  8. python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx
  9. Example 2: convert gpt2 model with beam search containing specific cuda optimizations:
  10. python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx --use_gpu \
  11. --past_present_share_buffer --use_decoder_masked_attention
  12. Example 3: convert gpt2 model with beam search with mixed precision and enable SkipLayerNorm strict mode:
  13. python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx --use_gpu -p fp16 --use_sln_strict_mode
  14. Example 4: convert T5 model with beam search in two steps:
  15. python -m models.t5.convert_to_onnx -m t5-small
  16. python convert_generation.py -m t5-small --model_type t5 \
  17. --decoder_onnx ./onnx_models/t5-small_decoder.onnx \
  18. --encoder_decoder_init_onnx ./onnx_models/t5-small_encoder.onnx \
  19. --output ./onnx_models/t5_small_beam_search.onnx
  20. Example 5: convert T5 model with beam search. All in one step:
  21. python convert_generation.py -m t5-small --model_type t5 --output t5_small_beam_search.onnx
  22. Example 6: convert T5 model with beam search containing specific cuda optimizations. All in one step:
  23. python convert_generation.py -m t5-small --model_type t5 --output t5_small_beam_search.onnx \
  24. --use_gpu --past_present_share_buffer --use_decoder_masked_attention
  25. Example 7: convert MT5 model with an external data file (mt5-base-beamsearch.onnx.data in the example below).
  26. python convert_generation.py -m google/mt5-base --model_type mt5 --output mt5-base-beamsearch.onnx -e
  27. Example 8: convert gpt2 model with greedy search:
  28. python convert_generation.py -m gpt2 --output gpt2_greedy_search.onnx --num_beams 1 --num_return_sequences 1
  29. Example 9: convert gpt2 model with sampling:
  30. python convert_generation.py -m gpt2 --output gpt2_sampling.onnx --num_beams 1 --num_return_sequences 1 --top_p 0.6
  31. """
  32. import argparse
  33. import logging
  34. import math
  35. import os
  36. import time
  37. from enum import Enum
  38. from pathlib import Path
  39. from typing import Any
  40. import numpy as np
  41. import onnx
  42. import torch
  43. from benchmark_helper import Precision, setup_logger
  44. from fusion_utils import NumpyHelper
  45. from onnx import GraphProto, ModelProto, TensorProto
  46. from onnx_model import OnnxModel
  47. from transformers import (
  48. GPT2Config,
  49. GPT2LMHeadModel,
  50. GPT2Tokenizer,
  51. MT5Config,
  52. MT5ForConditionalGeneration,
  53. T5Config,
  54. T5ForConditionalGeneration,
  55. T5Tokenizer,
  56. )
  57. from onnxruntime import (
  58. GraphOptimizationLevel,
  59. InferenceSession,
  60. SessionOptions,
  61. get_available_providers,
  62. )
  63. from onnxruntime.transformers.models.gpt2.convert_to_onnx import (
  64. main as convert_gpt2_to_onnx,
  65. )
  66. from onnxruntime.transformers.models.gpt2.gpt2_helper import PRETRAINED_GPT2_MODELS
  67. from onnxruntime.transformers.models.t5.convert_to_onnx import (
  68. export_onnx_models as export_t5_onnx_models,
  69. )
  70. from onnxruntime.transformers.models.t5.t5_helper import (
  71. PRETRAINED_MT5_MODELS,
  72. PRETRAINED_T5_MODELS,
  73. )
  74. logger = logging.getLogger("")
  75. class GenerationType(Enum):
  76. BEAMSEARCH = "beam_search"
  77. GREEDYSEARCH = "greedy_search"
  78. SAMPLING = "sampling"
  79. def __str__(self):
  80. return self.value
  81. def parse_arguments(argv: list[str] | None = None) -> argparse.Namespace:
  82. """Parse arguments
  83. Args:
  84. argv (Optional[List[str]], optional): _description_. Defaults to None.
  85. Returns:
  86. argparse.Namespace: Parsed arguments.
  87. """
  88. parser = argparse.ArgumentParser()
  89. input_group = parser.add_argument_group("Input options")
  90. input_group.add_argument(
  91. "-m",
  92. "--model_name_or_path",
  93. required=True,
  94. type=str,
  95. help="Pytorch model checkpoint path, or pretrained model name in the list: "
  96. + ", ".join(PRETRAINED_GPT2_MODELS + PRETRAINED_T5_MODELS + PRETRAINED_MT5_MODELS),
  97. )
  98. input_group.add_argument(
  99. "--model_type",
  100. required=False,
  101. type=str,
  102. default="gpt2",
  103. choices=["gpt2", "t5", "mt5"],
  104. help="Model type (default is gpt2) in the list: " + ", ".join(["gpt2", "t5", "mt5"]),
  105. )
  106. input_group.add_argument(
  107. "--cache_dir",
  108. required=False,
  109. type=str,
  110. default=os.path.join(".", "cache_models"),
  111. help="Directory to cache pre-trained models",
  112. )
  113. input_group.add_argument(
  114. "--decoder_onnx",
  115. required=False,
  116. type=str,
  117. default="",
  118. help="Path of onnx model for decoder. Specify it when you have exported the model.",
  119. )
  120. input_group.add_argument(
  121. "--encoder_decoder_init_onnx",
  122. required=False,
  123. type=str,
  124. default="",
  125. help="Path of ONNX model for encoder and decoder initialization. Specify it when you have exported the model.",
  126. )
  127. parser.add_argument(
  128. "--verbose",
  129. required=False,
  130. action="store_true",
  131. help="Print more information",
  132. )
  133. parser.set_defaults(verbose=False)
  134. output_group = parser.add_argument_group("Output options")
  135. output_group.add_argument(
  136. "--output",
  137. required=True,
  138. type=str,
  139. help="Output path for onnx model with beam search.",
  140. )
  141. output_group.add_argument(
  142. "-p",
  143. "--precision",
  144. required=False,
  145. type=str,
  146. default=Precision.FLOAT32.value,
  147. choices=[Precision.FLOAT32.value, Precision.FLOAT16.value],
  148. help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision",
  149. )
  150. output_group.add_argument(
  151. "-b",
  152. "--op_block_list",
  153. required=False,
  154. nargs="*",
  155. default=["auto"],
  156. help="Disable certain onnx operators when exporting model to onnx format. When using default"
  157. 'value for gpt2 type of model fp16 precision, it will be set to ["Add", "LayerNormalization",'
  158. ' "SkipLayerNormalization", "FastGelu"]. Other situation, it will be set to []',
  159. )
  160. output_group.add_argument(
  161. "-e",
  162. "--use_external_data_format",
  163. required=False,
  164. action="store_true",
  165. help="save external data for model > 2G",
  166. )
  167. output_group.set_defaults(use_external_data_format=False)
  168. output_group.add_argument(
  169. "-s",
  170. "--run_shape_inference",
  171. required=False,
  172. action="store_true",
  173. help="run shape inference",
  174. )
  175. output_group.set_defaults(run_shape_inference=False)
  176. output_group.add_argument(
  177. "-dpvs",
  178. "--disable_pad_vocab_size",
  179. required=False,
  180. action="store_true",
  181. help="Do not pad logits MatMul weight to be a multiple of 8 along the dimension where dim value is"
  182. " the vocab size. The logits MatMul may hence be of poor performance for fp16 precision.",
  183. )
  184. output_group.set_defaults(disable_pad_vocab_size=False)
  185. output_group.add_argument(
  186. "-dsgd",
  187. "--disable_separate_gpt2_decoder_for_init_run",
  188. required=False,
  189. action="store_true",
  190. help="Do not create separate decoder subgraphs for initial and remaining runs. This does not allow "
  191. "for optimizations based on sequence lengths in each subgraph",
  192. )
  193. output_group.set_defaults(disable_separate_gpt2_decoder_for_init_run=False)
  194. output_group.add_argument(
  195. "-i",
  196. "--disable_shared_initializers",
  197. required=False,
  198. action="store_true",
  199. help="do not share initializers in encoder and decoder for T5 or in the init decoder and decoder for "
  200. "GPT2. It will increase memory usage of t5/mt5/gpt2 models.",
  201. )
  202. output_group.set_defaults(disable_shared_initializers=False)
  203. output_group.add_argument(
  204. "--encoder_decoder_init",
  205. required=False,
  206. action="store_true",
  207. help="Add decoder initialization to encoder for T5 model. This is legacy format that will be deprecated.",
  208. )
  209. output_group.set_defaults(encoder_decoder_init=False)
  210. model_group = parser.add_argument_group("Beam search parameters that stored in the output model")
  211. model_group.add_argument(
  212. "--output_sequences_scores",
  213. required=False,
  214. action="store_true",
  215. help="output sequences scores",
  216. )
  217. model_group.set_defaults(output_sequences_scores=False)
  218. model_group.add_argument(
  219. "--output_token_scores",
  220. required=False,
  221. action="store_true",
  222. help="output token scores",
  223. )
  224. model_group.set_defaults(output_token_scores=False)
  225. model_group.add_argument("--early_stopping", required=False, action="store_true")
  226. model_group.set_defaults(early_stopping=False)
  227. model_group.add_argument(
  228. "--no_repeat_ngram_size",
  229. type=int,
  230. required=False,
  231. default=0,
  232. help="No repeat ngram size",
  233. )
  234. model_group.add_argument(
  235. "--vocab_mask",
  236. required=False,
  237. action="store_true",
  238. help="Enable vocab_mask. This mask applies only to every generated token to filter some bad words.",
  239. )
  240. model_group.set_defaults(vocab_mask=False)
  241. model_group.add_argument(
  242. "--past_present_share_buffer",
  243. required=False,
  244. action="store_true",
  245. help="Use shared buffer for past and present, currently work for gpt2 greedy/sampling search.",
  246. )
  247. model_group.set_defaults(past_present_share_buffer=False)
  248. model_group.add_argument(
  249. "--use_decoder_masked_attention",
  250. required=False,
  251. action="store_true",
  252. help="Uses `DecoderMaskedSelfAttention` or `DecoderMaskedMultiHeadAttention` to optimize the decoding Attention computation. "
  253. "Must be used with `past_present_share_buffer`. Currently, only Attention head sizes of 32, 64 and 128 are supported.",
  254. )
  255. model_group.set_defaults(use_decoder_masked_attention=False)
  256. model_group.add_argument(
  257. "--prefix_vocab_mask",
  258. required=False,
  259. action="store_true",
  260. help="Enable prefix_vocab_mask. This mask can be used to filter bad words in the first generated token only",
  261. )
  262. model_group.set_defaults(prefix_vocab_mask=False)
  263. model_group.add_argument(
  264. "--custom_attention_mask",
  265. required=False,
  266. action="store_true",
  267. help="Enable custom_attention_mask. This mask can be used to replace default encoder attention mask",
  268. )
  269. model_group.set_defaults(custom_attention_mask=False)
  270. model_group.add_argument(
  271. "--presence_mask",
  272. required=False,
  273. action="store_true",
  274. help="Presence mask for custom sampling",
  275. )
  276. model_group.set_defaults(presence_mask=False)
  277. model_group.add_argument(
  278. "--seed",
  279. required=False,
  280. action="store_true",
  281. help="Random seed for sampling op",
  282. )
  283. model_group.set_defaults(seed=False)
  284. beam_parameters_group = parser.add_argument_group(
  285. "Beam search parameters not stored in the output model, for testing parity and performance"
  286. )
  287. beam_parameters_group.add_argument("--min_length", type=int, required=False, default=1, help="Min sequence length")
  288. beam_parameters_group.add_argument("--max_length", type=int, required=False, default=50, help="Max sequence length")
  289. beam_parameters_group.add_argument("--num_beams", type=int, required=False, default=4, help="Beam size")
  290. beam_parameters_group.add_argument(
  291. "--num_return_sequences",
  292. type=int,
  293. required=False,
  294. default=1,
  295. help="Number of return sequence <= num_beams",
  296. )
  297. beam_parameters_group.add_argument(
  298. "--length_penalty",
  299. type=float,
  300. required=False,
  301. default=1,
  302. help="Positive. >1 to penalize and <1 to encourage short sentence.",
  303. )
  304. beam_parameters_group.add_argument(
  305. "--repetition_penalty",
  306. type=float,
  307. required=False,
  308. default=1,
  309. help="Positive. >1 to penalize and <1 to encourage.",
  310. )
  311. beam_parameters_group.add_argument(
  312. "--temperature",
  313. type=float,
  314. required=False,
  315. default=1.0,
  316. help="The value used to module the next token probabilities.",
  317. )
  318. beam_parameters_group.add_argument(
  319. "--top_p",
  320. type=float,
  321. required=False,
  322. default=1.0,
  323. help="Top P for sampling",
  324. )
  325. beam_parameters_group.add_argument(
  326. "--filter_value",
  327. type=float,
  328. required=False,
  329. default=-float("Inf"),
  330. help="Filter value for Top P sampling",
  331. )
  332. beam_parameters_group.add_argument(
  333. "--min_tokens_to_keep",
  334. type=int,
  335. required=False,
  336. default=1,
  337. help="Minimum number of tokens we keep per batch example in the output.",
  338. )
  339. beam_parameters_group.add_argument(
  340. "--presence_penalty",
  341. type=float,
  342. required=False,
  343. default=0.0,
  344. help="presence penalty for custom sampling.",
  345. )
  346. beam_parameters_group.add_argument(
  347. "--custom",
  348. type=int,
  349. required=False,
  350. default=0,
  351. help="If 1 customized top P logic is applied",
  352. )
  353. beam_parameters_group.add_argument(
  354. "--vocab_size",
  355. type=int,
  356. required=False,
  357. default=-1,
  358. help="Vocab_size of the underlying model used to decide the shape of vocab mask",
  359. )
  360. beam_parameters_group.add_argument(
  361. "--eos_token_id",
  362. type=int,
  363. required=False,
  364. default=-1,
  365. help="custom eos_token_id for generating model with existing onnx encoder/decoder",
  366. )
  367. beam_parameters_group.add_argument(
  368. "--pad_token_id",
  369. type=int,
  370. required=False,
  371. default=-1,
  372. help="custom pad_token_id for generating model with existing onnx encoder/decoder",
  373. )
  374. test_group = parser.add_argument_group("Other options for testing parity and performance")
  375. test_group.add_argument(
  376. "--use_sln_strict_mode",
  377. required=False,
  378. action="store_true",
  379. help="Enable strict mode for SLN in CUDA provider. This ensures a better accuracy but will be slower.",
  380. )
  381. test_group.set_defaults(use_sln_strict_mode=False)
  382. test_group.add_argument(
  383. "--use_gpu",
  384. required=False,
  385. action="store_true",
  386. help="use GPU for inference. Required for fp16.",
  387. )
  388. test_group.set_defaults(use_gpu=False)
  389. test_group.add_argument(
  390. "--disable_parity",
  391. required=False,
  392. action="store_true",
  393. help="do not run parity test",
  394. )
  395. test_group.set_defaults(disable_parity=False)
  396. test_group.add_argument(
  397. "--disable_perf_test",
  398. required=False,
  399. action="store_true",
  400. help="do not run perf test",
  401. )
  402. test_group.set_defaults(disable_perf_test=False)
  403. test_group.add_argument(
  404. "--torch_performance",
  405. required=False,
  406. action="store_true",
  407. help="test PyTorch performance",
  408. )
  409. test_group.set_defaults(torch_performance=False)
  410. test_group.add_argument(
  411. "--total_runs",
  412. required=False,
  413. type=int,
  414. default=1,
  415. help="Number of times of inference for latency measurement",
  416. )
  417. test_group.add_argument(
  418. "--save_test_data",
  419. required=False,
  420. action="store_true",
  421. help="save test data for onnxruntime_perf_test tool",
  422. )
  423. test_group.set_defaults(save_test_data=False)
  424. args = parser.parse_args(argv)
  425. return args
  426. def gpt2_to_onnx(args: argparse.Namespace):
  427. """Convert GPT-2 model to onnx
  428. Args:
  429. args (argparse.Namespace): arguments parsed from command line
  430. """
  431. model_name = args.model_name_or_path
  432. arguments = [
  433. "--model_name_or_path",
  434. model_name,
  435. "--output",
  436. args.decoder_onnx,
  437. "--optimize_onnx",
  438. "--precision",
  439. args.precision,
  440. "--test_runs",
  441. "1",
  442. "--test_cases",
  443. "10",
  444. "--overwrite", # Overwrite onnx file if existed
  445. ]
  446. if args.cache_dir:
  447. arguments.extend(["--cache_dir", args.cache_dir])
  448. if args.use_gpu:
  449. arguments.append("--use_gpu")
  450. if args.use_external_data_format:
  451. arguments.append("--use_external_data_format")
  452. if len(args.op_block_list):
  453. arguments.extend(["--op_block_list"])
  454. arguments.extend(args.op_block_list)
  455. if args.precision == Precision.FLOAT16.value:
  456. assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. Please add --use_gpu"
  457. # TODO(tianleiwu): Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision')
  458. # Need change cuda kernel to support a combination of fp32 logits and fp16 past state.
  459. # Currently logits and past state shall be same data type.
  460. if args.verbose:
  461. logger.info(f"arguments for convert_to_onnx:{arguments}")
  462. convert_gpt2_to_onnx(argv=arguments)
  463. def t5_to_onnx(args: argparse.Namespace):
  464. """Convert T5 model to onnx
  465. Args:
  466. args (argparse.Namespace): arguments parsed from command line
  467. """
  468. paths = export_t5_onnx_models(
  469. model_name_or_path=args.model_name_or_path,
  470. cache_dir=args.cache_dir,
  471. output_dir=Path(args.output).parent,
  472. use_gpu=args.use_gpu,
  473. use_external_data_format=args.use_external_data_format,
  474. optimize_onnx=(args.precision != Precision.FLOAT16.value),
  475. precision=args.precision,
  476. verbose=False,
  477. use_decoder_start_token=False,
  478. overwrite=True,
  479. disable_auto_mixed_precision=False,
  480. use_int32_inputs=True,
  481. model_type=args.model_type,
  482. encoder_decoder_init=args.encoder_decoder_init,
  483. force_fp16_io=(args.precision == Precision.FLOAT16.value), # required by BeamSearch op implementation.
  484. )
  485. logger.debug(f"onnx model for encoder: {paths[0]}")
  486. logger.debug(f"onnx model for decoder: {paths[1]}")
  487. args.encoder_decoder_init_onnx = paths[0]
  488. args.decoder_onnx = paths[1]
  489. def shape_inference(onnx_path: str, use_external_data_format: bool = True):
  490. """Shape inference on an onnx file, which will be overwritten.
  491. Args:
  492. onnx_path (str): Path of onnx model
  493. use_external_data_format(bool): output tensors to external data or not.
  494. """
  495. # Run symbolic shape inference to walk around ORT shape inference issue for subgraph.
  496. from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference # noqa: PLC0415
  497. model = onnx.load_model(onnx_path, load_external_data=True)
  498. out = SymbolicShapeInference.infer_shapes(model, auto_merge=True, guess_output_rank=False)
  499. if out:
  500. OnnxModel.save(out, onnx_path, save_as_external_data=use_external_data_format)
  501. else:
  502. logger.warning("Failed to run symbolic shape inference on the model.")
  503. def pad_weights_of_logits_matmul(onnx_path: str, use_external_data_format: bool = True) -> bool:
  504. """Pad the logits MatMul weight in the provided decoder model, which will be overwritten.
  505. Args:
  506. onnx_path (str): Path of onnx model
  507. use_external_data_format(bool): output tensors to external data or not.
  508. """
  509. decoder_model_proto = onnx.load_model(onnx_path, load_external_data=True)
  510. logits_output_name = decoder_model_proto.graph.output[0].name
  511. decoder_model = OnnxModel(decoder_model_proto)
  512. output_name_to_node = decoder_model.output_name_to_node()
  513. assert logits_output_name in output_name_to_node
  514. matmul_node = output_name_to_node[logits_output_name]
  515. # Sanity check - the logits need to be produced by a MatMul node
  516. if matmul_node.op_type != "MatMul":
  517. return False
  518. # The logits MatMul weight MUST be an initializer (or)
  519. # it MUST be flowing through a Transpose whose input is
  520. # an initializer
  521. pad_along_axis_1 = True
  522. logits_weight = decoder_model.get_initializer(matmul_node.input[1])
  523. if logits_weight is None:
  524. transpose_before_matmul = decoder_model.match_parent(matmul_node, "Transpose", 1)
  525. if transpose_before_matmul is None:
  526. return False
  527. logits_weight = decoder_model.get_initializer(transpose_before_matmul.input[0])
  528. if logits_weight is None:
  529. return False
  530. pad_along_axis_1 = False
  531. # The logits MatMul weight MUST be fp16
  532. if logits_weight.data_type != TensorProto.DataType.FLOAT16:
  533. return False
  534. # The logits MatMul weight MUST be 2-dimensional
  535. if len(logits_weight.dims) != 2:
  536. return False
  537. # Pad and over-write the initializer (if needed)
  538. actual_vocab_size = logits_weight.dims[1]
  539. if (actual_vocab_size % 8) == 0:
  540. # Already "padded"
  541. return True
  542. padded_vocab_size = math.ceil(actual_vocab_size / 8) * 8
  543. padding = padded_vocab_size - actual_vocab_size
  544. # TODO(hasesh): Handle cases where the fp16 data is stored in the
  545. # non-raw data field
  546. if logits_weight.raw_data:
  547. if pad_along_axis_1:
  548. padding_data = np.zeros((logits_weight.dims[0], padding), dtype=np.float16)
  549. weight_with_padding = np.concatenate((NumpyHelper.to_array(logits_weight), padding_data), axis=1)
  550. logits_weight.dims[1] = padded_vocab_size
  551. else:
  552. padding_data = np.zeros((padding, logits_weight.dims[1]), dtype=np.float16)
  553. weight_with_padding = np.concatenate((NumpyHelper.to_array(logits_weight), padding_data), axis=0)
  554. logits_weight.dims[0] = padded_vocab_size
  555. logits_weight.raw_data = weight_with_padding.tobytes()
  556. else:
  557. return False
  558. # Save the model
  559. OnnxModel.save(decoder_model_proto, onnx_path, save_as_external_data=use_external_data_format)
  560. return True
  561. def create_ort_session(model_path: str, use_gpu: bool, use_sln_strict_mode: bool) -> InferenceSession:
  562. """Create OnnxRuntime session.
  563. Args:
  564. model_path (str): onnx model path
  565. use_gpu (bool): use GPU or not
  566. use_sln_strict_mode (bool): use strict mode for skip layer normalization or not
  567. Raises:
  568. RuntimeError: CUDAExecutionProvider is not available when --use_gpu is specified.
  569. Returns:
  570. onnxruntime.InferenceSession: The created session.
  571. """
  572. sess_options = SessionOptions()
  573. sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
  574. execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
  575. if use_gpu:
  576. if "CUDAExecutionProvider" not in get_available_providers():
  577. raise RuntimeError("CUDAExecutionProvider is not available for --use_gpu!")
  578. else:
  579. logger.info("use CUDAExecutionProvider")
  580. if use_sln_strict_mode:
  581. cuda_provider_options = {"enable_skip_layer_norm_strict_mode": True}
  582. provider_options = {"CUDAExecutionProvider": cuda_provider_options}
  583. execution_providers = [
  584. (name, provider_options[name]) if name in provider_options else name for name in execution_providers
  585. ]
  586. ort_session = InferenceSession(model_path, sess_options, providers=execution_providers)
  587. return ort_session
  588. def verify_gpt2_subgraph(graph: onnx.GraphProto, precision: Precision):
  589. """Verify GPT-2 subgraph
  590. Args:
  591. graph (onnx.GraphProto): onnx graph of GPT-2
  592. precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
  593. Raises:
  594. ValueError: Number of inputs not expected.
  595. ValueError: Input name is not expected.
  596. ValueError: Input data type is not expected.
  597. ValueError: Number of outputs not expected.
  598. ValueError: Output name is not expected.
  599. ValueError: Output data type is not expected.
  600. """
  601. is_float16 = precision == Precision.FLOAT16.value
  602. input_count = len(graph.input)
  603. layer_count = input_count - 3
  604. assert layer_count >= 1
  605. expected_inputs = ["input_ids", "position_ids", "attention_mask"] + [f"past_{i}" for i in range(layer_count)]
  606. if len(graph.input) != len(expected_inputs):
  607. raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
  608. for i, expected_input in enumerate(expected_inputs):
  609. if graph.input[i].name != expected_input:
  610. raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
  611. expected_type = TensorProto.INT32
  612. if i >= 3:
  613. expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT
  614. input_type = graph.input[i].type.tensor_type.elem_type
  615. if input_type != expected_type:
  616. raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
  617. logger.info("Verifying GPT-2 graph inputs: name and data type are good.")
  618. expected_outputs = ["logits"] + [f"present_{i}" for i in range(layer_count)]
  619. if len(graph.output) != len(expected_outputs):
  620. raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
  621. for i, expected_output in enumerate(expected_outputs):
  622. if graph.output[i].name != expected_output:
  623. raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
  624. expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT
  625. output_type = graph.output[i].type.tensor_type.elem_type
  626. if output_type != expected_type:
  627. raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {output_type}")
  628. logger.info("Verifying GPT-2 graph outputs: name and data type are good.")
  629. # TODO(tianleiwu): verify shapes of inputs and outputs.
  630. return
  631. def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision):
  632. """Verify T5 decoder subgraph
  633. Args:
  634. graph (onnx.GraphProto): onnx graph of T5 decoder
  635. precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
  636. Raises:
  637. ValueError: Number of inputs not expected.
  638. ValueError: Input name is not expected.
  639. ValueError: Input data type is not expected.
  640. ValueError: Number of outputs not expected.
  641. ValueError: Output name is not expected.
  642. ValueError: Output data type is not expected.
  643. """
  644. is_float16 = precision == Precision.FLOAT16.value
  645. float_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT
  646. input_count = len(graph.input)
  647. layer_count = (input_count - 2) // 4
  648. assert layer_count >= 1
  649. # Expect inputs:
  650. # input_ids: int32 (B, 1)
  651. # encoder_attention_mask: int32 (B, encode_sequence_length)
  652. # past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
  653. # past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
  654. # ... (for each self attention layer)
  655. # past_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
  656. # past_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
  657. # ... (for each cross attention layer)
  658. # TODO: encoder_hidden_states is optional
  659. expected_inputs = ["input_ids", "encoder_attention_mask"]
  660. for i in range(layer_count):
  661. expected_inputs.append(f"past_key_self_{i}")
  662. expected_inputs.append(f"past_value_self_{i}")
  663. for i in range(layer_count):
  664. expected_inputs.append(f"past_key_cross_{i}")
  665. expected_inputs.append(f"past_value_cross_{i}")
  666. if len(graph.input) != len(expected_inputs):
  667. raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
  668. for i, expected_input in enumerate(expected_inputs):
  669. if graph.input[i].name != expected_input:
  670. raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
  671. expected_type = TensorProto.INT32 if i < 2 else float_type
  672. input_type = graph.input[i].type.tensor_type.elem_type
  673. if input_type != expected_type:
  674. raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
  675. # Expect outputs:
  676. # logits: (B, 1, vocab_size)
  677. # present_key_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
  678. # present_value_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
  679. # ... (for each self attention layer)
  680. expected_outputs = ["logits"]
  681. for i in range(layer_count):
  682. expected_outputs.append(f"present_key_self_{i}")
  683. expected_outputs.append(f"present_value_self_{i}")
  684. if len(graph.output) != len(expected_outputs):
  685. raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
  686. for i, expected_output in enumerate(expected_outputs):
  687. if graph.output[i].name != expected_output:
  688. raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
  689. output_type = graph.output[i].type.tensor_type.elem_type
  690. if output_type != float_type:
  691. raise ValueError(f"Output {i} is expected to have onnx data type {float_type}. Got {output_type}")
  692. def verify_t5_encoder_decoder_init_subgraph(graph: onnx.GraphProto, precision: Precision):
  693. """Verify T5 decoder subgraph
  694. Args:
  695. graph (onnx.GraphProto): onnx graph of T5 decoder
  696. precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
  697. Raises:
  698. ValueError: Number of inputs not expected.
  699. ValueError: Input name is not expected.
  700. ValueError: Input data type is not expected.
  701. ValueError: Number of outputs not expected.
  702. ValueError: Output name is not expected.
  703. ValueError: Output data type is not expected.
  704. """
  705. is_float16 = precision == Precision.FLOAT16.value
  706. new_format = "cross" in graph.output[0].name
  707. # Expect 3 inputs:
  708. # encoder_input_ids: int32 (B, encode_sequence_length)
  709. # encoder_attention_mask: int32 (B, encode_sequence_length)
  710. # decoder_input_ids: int32 (B, 1)
  711. expected_inputs = [
  712. "encoder_input_ids",
  713. "encoder_attention_mask",
  714. "decoder_input_ids",
  715. ]
  716. if new_format:
  717. expected_inputs = expected_inputs[:2]
  718. if len(graph.input) != len(expected_inputs):
  719. raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
  720. for i, expected_input in enumerate(expected_inputs):
  721. if graph.input[i].name != expected_input:
  722. raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
  723. expected_type = TensorProto.INT32
  724. input_type = graph.input[i].type.tensor_type.elem_type
  725. if input_type != expected_type:
  726. raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
  727. if new_format:
  728. assert len(graph.output) % 2 == 0
  729. layer_count = len(graph.output) // 2
  730. assert layer_count >= 1
  731. # Expected outputs:
  732. # present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
  733. # present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
  734. # ... (for each cross attention layer)
  735. expected_outputs = []
  736. for i in range(layer_count):
  737. expected_outputs.append(f"present_key_cross_{i}")
  738. expected_outputs.append(f"present_value_cross_{i}")
  739. else:
  740. logger.warning("This format is deprecated. Please export T5 encoder in new format with only cross outputs.")
  741. assert (len(graph.output) - 2) % 4 == 0
  742. layer_count = (len(graph.output) - 2) // 4
  743. assert layer_count >= 1
  744. # Expected outputs:
  745. # logits: (B, 1, vocab_size)
  746. # encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
  747. # present_key_self_0: (B, num_heads, 1, head_size)
  748. # present_value_self_0: (B, num_heads, 1, head_size)
  749. # ... (for each self attention layer)
  750. # present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
  751. # present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
  752. # ... (for each cross attention layer)
  753. expected_outputs = ["logits", "encoder_hidden_states"]
  754. for i in range(layer_count):
  755. expected_outputs.append(f"present_key_self_{i}")
  756. expected_outputs.append(f"present_value_self_{i}")
  757. for i in range(layer_count):
  758. expected_outputs.append(f"present_key_cross_{i}")
  759. expected_outputs.append(f"present_value_cross_{i}")
  760. if len(graph.output) != len(expected_outputs):
  761. raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
  762. for i, expected_output in enumerate(expected_outputs):
  763. if graph.output[i].name != expected_output:
  764. raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
  765. expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT
  766. output_type = graph.output[i].type.tensor_type.elem_type
  767. if output_type != expected_type:
  768. raise ValueError(f"Output {i} is expected to have onnx data type {expected_type}. Got {output_type}")
  769. logger.info("T5 encoder graph verified: name and data type of inputs and outputs are good.")
  770. def remove_shared_initializers(
  771. graph1: GraphProto,
  772. graph2: GraphProto,
  773. shared_prefix: str = "shared_",
  774. min_elements: int = 1024,
  775. signature_cache1: dict | None = None,
  776. signature_cache2: dict | None = None,
  777. ):
  778. """Remove initializers with same value from two graphs.
  779. Args:
  780. graph1 (GraphProto): the first graph to process
  781. graph2 (GraphProto): the second graph to process
  782. shared_prefix (str): add prefix to the shared initializers among two graphs
  783. min_elements (int, optional): minimal number of elements for initializers to be considered. Defaults to 1024.
  784. signature_cache1 (dict): Optional dictionary to store data signatures of tensors in graph1 in order to speed up comparison
  785. signature_cache2 (dict): Optional dictionary to store data signatures of tensors in graph2 in order to speed up comparison
  786. """
  787. mapping_initializers_1 = {}
  788. mapping_initializers_2 = {}
  789. shared_initializers_1 = []
  790. shared_initializers_2 = []
  791. shared_initializers_names = []
  792. for initializer1 in graph1.initializer:
  793. if not (initializer1.dims and sum(initializer1.dims) >= min_elements):
  794. continue
  795. for initializer2 in graph2.initializer:
  796. if not (initializer2.dims and sum(initializer2.dims) >= min_elements):
  797. continue
  798. if OnnxModel.has_same_value(initializer1, initializer2, signature_cache1, signature_cache2):
  799. mapping_initializers_1[initializer1.name] = shared_prefix + initializer2.name
  800. shared_initializers_1.append(initializer1)
  801. if initializer2.name not in mapping_initializers_2:
  802. shared_name = shared_prefix + initializer2.name
  803. mapping_initializers_2[initializer2.name] = shared_name
  804. shared_initializers_2.append(initializer2)
  805. shared_initializers_names.append(shared_name)
  806. break
  807. logger.debug(f"shared initializers:{shared_initializers_names}")
  808. # Make sure new name does not exist in graph 1
  809. for node in graph1.node:
  810. for j in range(len(node.input)):
  811. if node.input[j] in shared_initializers_names:
  812. raise RuntimeError(f"name is found in graph 1: {node.input[j]}")
  813. # Make sure new name does not exist in graph 2
  814. for node in graph2.node:
  815. for j in range(len(node.input)):
  816. if node.input[j] in shared_initializers_names:
  817. raise RuntimeError(f"name is found in graph 2: {node.input[j]}")
  818. # Remove shared initializers from graph 2
  819. for initializer in shared_initializers_2:
  820. graph2.initializer.remove(initializer)
  821. # Rename value info for old names in graph 2
  822. for value_info in graph2.value_info:
  823. if value_info.name in mapping_initializers_2:
  824. value_info.name = mapping_initializers_2[value_info.name]
  825. # Rename nodes inputs in graph 2:
  826. for node in graph2.node:
  827. for j in range(len(node.input)):
  828. if node.input[j] in mapping_initializers_2:
  829. new_name = mapping_initializers_2[node.input[j]]
  830. logger.debug(f"graph 2 rename node {node.name} input {j} from {node.input[j]} to {new_name}")
  831. node.input[j] = new_name
  832. # Remove shared initializers from graph 1
  833. for initializer in shared_initializers_1:
  834. graph1.initializer.remove(initializer)
  835. # Rename value info for old names in graph 1
  836. for value_info in graph1.value_info:
  837. if value_info.name in mapping_initializers_1:
  838. value_info.name = mapping_initializers_1[value_info.name]
  839. # Rename nodes inputs in graph 1:
  840. for node in graph1.node:
  841. for j in range(len(node.input)):
  842. if node.input[j] in mapping_initializers_1:
  843. new_name = mapping_initializers_1[node.input[j]]
  844. logger.debug(f"graph 1 rename node {node.name} input {j} from {node.input[j]} to {new_name}")
  845. node.input[j] = new_name
  846. # Rename shared initializers in graph 2
  847. for initializer in shared_initializers_2:
  848. initializer.name = mapping_initializers_2[initializer.name]
  849. for initializer in shared_initializers_2:
  850. shape = onnx.numpy_helper.to_array(initializer).shape
  851. value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
  852. # Need add value_info for initializers moved to parent graph. Otherwise, ORT will fail.
  853. graph1.value_info.append(value_info)
  854. graph2.value_info.append(value_info)
  855. return shared_initializers_2
  856. def get_shared_initializers(encoder_model: ModelProto, decoder_model: ModelProto):
  857. encoder = OnnxModel(encoder_model)
  858. decoder = OnnxModel(decoder_model)
  859. encoder.add_prefix_to_names("e_")
  860. decoder.add_prefix_to_names("d_")
  861. signature_cache1, signature_cache2 = {}, {}
  862. encoder.remove_duplicated_initializer(signature_cache1)
  863. decoder.remove_duplicated_initializer(signature_cache2)
  864. initializers = remove_shared_initializers(
  865. decoder.model.graph,
  866. encoder.model.graph,
  867. shared_prefix="s_",
  868. signature_cache1=signature_cache1,
  869. signature_cache2=signature_cache2,
  870. )
  871. return initializers
  872. def move_initializers(
  873. graph: GraphProto,
  874. min_elements: int = 1024,
  875. ) -> list[TensorProto]:
  876. """Remove initializers of a graph, when they have number of elements larger than a threshold.
  877. Args:
  878. graph (GraphProto): the graph.
  879. min_elements (int, optional): minimal number of elements for initializers to be considered. Defaults to 1024.
  880. Returns:
  881. List[TensorProto]: initializers that are removed from the graph.
  882. """
  883. moved_initializers = []
  884. for tensor in graph.initializer:
  885. if not (tensor.dims and sum(tensor.dims) >= min_elements):
  886. continue
  887. moved_initializers.append(tensor)
  888. for initializer in moved_initializers:
  889. graph.initializer.remove(initializer)
  890. # Add type info, otherwise ORT will raise error: "input arg (*) does not have type information set by parent node."
  891. for initializer in moved_initializers:
  892. shape = onnx.numpy_helper.to_array(initializer).shape
  893. value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
  894. graph.value_info.append(value_info)
  895. return moved_initializers
  896. def _attribute_to_pair(attribute):
  897. """
  898. Convert attribute to kwarg format for use with onnx.helper.make_node.
  899. :parameter attribute: attribute in AttributeProto format.
  900. :return: attribute in {key: value} format.
  901. """
  902. if attribute.type == 0:
  903. raise ValueError(f"attribute {attribute.name} does not have type specified.")
  904. # Based on attribute type definitions from AttributeProto
  905. # definition in https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
  906. if attribute.type == 1:
  907. value = attribute.f
  908. elif attribute.type == 2:
  909. value = attribute.i
  910. elif attribute.type == 3:
  911. value = attribute.s
  912. elif attribute.type == 4:
  913. value = attribute.t
  914. elif attribute.type == 5:
  915. value = attribute.g
  916. elif attribute.type == 6:
  917. value = attribute.floats
  918. elif attribute.type == 7:
  919. value = attribute.ints
  920. elif attribute.type == 8:
  921. value = attribute.strings
  922. elif attribute.type == 9:
  923. value = attribute.tensors
  924. elif attribute.type == 10:
  925. value = attribute.graphs
  926. else:
  927. raise ValueError(f"attribute {attribute.name} has unsupported type {attribute.type}.")
  928. return (attribute.name, value)
  929. def kwargs_of(node):
  930. kwargs = {}
  931. for attr in node.attribute:
  932. (key, value) = _attribute_to_pair(attr)
  933. kwargs.update({key: value})
  934. if node.domain:
  935. kwargs.update({"domain": node.domain})
  936. return kwargs
  937. def shape_of(vi):
  938. return tuple([d.dim_param if (d.dim_param) else d.dim_value for d in vi.type.tensor_type.shape.dim])
  939. def update_decoder_subgraph_past_present_share_buffer(subg: GraphProto):
  940. input_past_0 = 3
  941. output_past_0 = 1
  942. new_inputs = []
  943. for i, vi in enumerate(subg.input):
  944. if i >= input_past_0:
  945. shape = shape_of(vi)
  946. vi = onnx.helper.make_tensor_value_info( # noqa: PLW2901
  947. vi.name,
  948. elem_type=vi.type.tensor_type.elem_type,
  949. shape=[shape[0], shape[1], shape[2], "max_seq_len", shape[4]],
  950. )
  951. new_inputs.extend([vi])
  952. new_inputs.extend([onnx.helper.make_tensor_value_info("past_sequence_length", onnx.TensorProto.INT32, shape=[1])])
  953. subg.ClearField("input")
  954. subg.input.extend(new_inputs)
  955. new_outputs = []
  956. for i, vi in enumerate(subg.output):
  957. if i >= output_past_0:
  958. shape = shape_of(vi)
  959. vi = onnx.helper.make_tensor_value_info( # noqa: PLW2901
  960. vi.name,
  961. elem_type=vi.type.tensor_type.elem_type,
  962. shape=[shape[0], shape[1], shape[2], "max_seq_len", shape[4]],
  963. )
  964. new_outputs.extend([vi])
  965. subg.ClearField("output")
  966. subg.output.extend(new_outputs)
  967. new_nodes = []
  968. for node in subg.node:
  969. new_node = node
  970. if node.op_type == "Attention":
  971. kwargs = kwargs_of(node)
  972. kwargs.update({"past_present_share_buffer": 1})
  973. nis = []
  974. nis.extend(node.input)
  975. while len(nis) < 6:
  976. nis.extend([""])
  977. if len(nis) < 7:
  978. nis.extend(["past_sequence_length"])
  979. new_node = onnx.helper.make_node("Attention", nis, node.output, name=node.name, **kwargs)
  980. new_nodes.extend([new_node])
  981. subg.ClearField("node")
  982. subg.node.extend(new_nodes)
  983. return subg
  984. def update_decoder_subgraph_use_decoder_masked_attention(
  985. subg: GraphProto, is_beam_search: bool, switch_attention: bool
  986. ) -> bool:
  987. """Update the Attention nodes to DecoderMaskedSelfAttention.
  988. Args:
  989. subg (GraphProto): GraphProto of the decoder subgraph
  990. is_beam_search (bool): Boolean specifying if the sampling algo is BeamSearch
  991. switch_attention (bool): Boolean specifying if `Attention` is to be switched with `DecoderMaskedSelfAttention`
  992. """
  993. if is_beam_search:
  994. new_inputs = []
  995. for _i, vi in enumerate(subg.input):
  996. new_inputs.extend([vi])
  997. # Add 2 BeamSearch specific inputs
  998. new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
  999. new_inputs.extend(
  1000. [
  1001. onnx.helper.make_tensor_value_info(
  1002. "cache_indirection",
  1003. onnx.TensorProto.INT32,
  1004. shape=["batch_size", "beam_width", "max_seq_len"],
  1005. )
  1006. ]
  1007. )
  1008. subg.ClearField("input")
  1009. subg.input.extend(new_inputs)
  1010. if switch_attention:
  1011. decoder_masked_attention_supported_attr = [
  1012. "past_present_share_buffer",
  1013. "num_heads",
  1014. "scale",
  1015. "mask_filter_value",
  1016. "domain",
  1017. ]
  1018. new_nodes = []
  1019. for node in subg.node:
  1020. if node.op_type == "Attention":
  1021. kwargs = kwargs_of(node)
  1022. for k in kwargs.copy():
  1023. # The Attention operator does not support different qkv hidden sizes when past/present
  1024. # input/output exists (GPT2 model). Hence, we should never run into this.
  1025. # But, if we do, do not go ahead with the optimization.
  1026. if k == "qkv_hidden_sizes":
  1027. return False
  1028. if k not in decoder_masked_attention_supported_attr:
  1029. # Log the fact that we are removing certain attributes from the node
  1030. # We don't need to log it for "unidirectional" as we are aware that
  1031. # decoding attention kernels are unidirectional by definition.
  1032. if k != "unidirectional":
  1033. logger.warning(
  1034. f"Removing attribute: {k} from Attention node while switching to DecoderMaskedSelfAttention"
  1035. )
  1036. del kwargs[k]
  1037. nis = []
  1038. nis.extend(node.input)
  1039. # Add 2 BeamSearch specific inputs
  1040. if is_beam_search:
  1041. while len(nis) < 7:
  1042. nis.extend([""])
  1043. if len(nis) < 8:
  1044. nis.extend(["beam_width"])
  1045. if len(nis) < 9:
  1046. nis.extend(["cache_indirection"])
  1047. node = onnx.helper.make_node( # noqa: PLW2901
  1048. "DecoderMaskedSelfAttention",
  1049. nis,
  1050. node.output,
  1051. name=node.name,
  1052. **kwargs,
  1053. )
  1054. new_nodes.extend([node])
  1055. subg.ClearField("node")
  1056. subg.node.extend(new_nodes)
  1057. return True
  1058. def find_past_seq_len_usage(subg: GraphProto):
  1059. """Correct graph which originally use dim of past_seq_len from input_ids's shape which is fixed to max_seq_len after
  1060. shared past/present buffer
  1061. Args:
  1062. subg (GraphProto): GraphProto of the decoder subgraph
  1063. return:
  1064. tensor_names_to_rename : set of tensor names which is equal to past_sequence_length
  1065. nodes_to_remove : list of node to remove
  1066. """
  1067. tensor_names_to_rename = set()
  1068. nodes_to_remove = []
  1069. graph_input_names = {inp.name: index for index, inp in enumerate(subg.input)}
  1070. input_name_to_nodes = {}
  1071. output_name_to_node = {}
  1072. for node in subg.node:
  1073. for input_name in node.input:
  1074. if input_name:
  1075. if input_name not in input_name_to_nodes:
  1076. input_name_to_nodes[input_name] = [node]
  1077. else:
  1078. input_name_to_nodes[input_name].append(node)
  1079. for output_name in node.output:
  1080. if output_name:
  1081. output_name_to_node[output_name] = node
  1082. for node in subg.node:
  1083. # find "past_key_self_0 --> [Transpose(past_key_self_0) --> Reshape(past_key_self_0)] --> Shape(past_key_self_0) --> Gather(*, 2)"
  1084. # where [Transpose(past_key_self_0) --> Reshape(past_key_self_0)] may or may not exist
  1085. if node.op_type == "Gather":
  1086. if not node.input[1] or not node.input[0]:
  1087. continue
  1088. # Find Gather node's index value
  1089. shape_tensor_name, shape_index_name = (node.input[0], node.input[1])
  1090. ini_gather_indices = None
  1091. if "Constant_" in shape_index_name:
  1092. # If shape_index_name refers to a Constant node
  1093. for const_node in subg.node:
  1094. if const_node.op_type == "Constant" and const_node.output[0] == shape_index_name:
  1095. ini_gather_indices = const_node.attribute[0].t
  1096. break
  1097. else:
  1098. # If shape_index_name refers to an initializer
  1099. for tensor in subg.initializer:
  1100. if tensor.name == shape_index_name:
  1101. ini_gather_indices = tensor
  1102. break
  1103. if ini_gather_indices is None:
  1104. continue
  1105. gather_indices_arr = onnx.numpy_helper.to_array(ini_gather_indices)
  1106. if (
  1107. gather_indices_arr.size == 1
  1108. and gather_indices_arr.item() in {1, 2}
  1109. and node.input[0] in output_name_to_node
  1110. ):
  1111. shape_node = output_name_to_node[shape_tensor_name]
  1112. if not (shape_node.op_type == "Shape" and shape_node.input[0]):
  1113. continue
  1114. if (
  1115. shape_node.input[0] in graph_input_names
  1116. and (
  1117. shape_node.input[0].startswith("past_key_self_")
  1118. or shape_node.input[0].startswith("past_value_self_")
  1119. )
  1120. and gather_indices_arr.item() == 2
  1121. ):
  1122. # "past_key_self_0 --> Shape(past_key_self_0) --> Gather(*, 2)"
  1123. tensor_names_to_rename.add(node.output[0])
  1124. nodes_to_remove.append(node)
  1125. if len(input_name_to_nodes[shape_node.output[0]]) == 1:
  1126. nodes_to_remove.append(shape_node)
  1127. continue
  1128. if shape_node.input[0] not in output_name_to_node:
  1129. continue
  1130. reshape_node = output_name_to_node[shape_node.input[0]]
  1131. if not (reshape_node.op_type == "Reshape" and reshape_node.input[0]):
  1132. continue
  1133. transpose_node = output_name_to_node[reshape_node.input[0]]
  1134. if not (transpose_node.op_type == "Transpose" and transpose_node.input[0]):
  1135. continue
  1136. if (
  1137. transpose_node.input[0] in graph_input_names
  1138. and (
  1139. transpose_node.input[0].startswith("past_key_self_")
  1140. or transpose_node.input[0].startswith("past_value_self_")
  1141. )
  1142. and gather_indices_arr.item() == 1
  1143. ):
  1144. # "past_key_self_0 --> Transpose(past_key_self_0) --> Reshape(past_key_self_0) --> Shape(past_key_self_0) --> Gather(*, 2)"
  1145. tensor_names_to_rename.add(node.output[0])
  1146. nodes_to_remove.extend([node, shape_node, reshape_node])
  1147. if len(input_name_to_nodes[transpose_node.output[0]]) == 1:
  1148. nodes_to_remove.append(transpose_node)
  1149. continue
  1150. return tensor_names_to_rename, nodes_to_remove
  1151. def add_cache_indirection_to_mha(model: OnnxModel, past_seq_len_name: str):
  1152. # Add past_sequence_length and cache_indirection as inputs to all MultiHeadAttention ops and as inputs to model
  1153. cache_indirection_name = "cache_indirection"
  1154. mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
  1155. for node in mha_nodes:
  1156. # MHA op takes the following potential inputs:
  1157. # query, key, value, bias, key_padding_mask, add_qk, past_key, past_value
  1158. while len(node.input) < 8:
  1159. node.input.append("")
  1160. node.input.append(past_seq_len_name)
  1161. node.input.append(cache_indirection_name)
  1162. model.model.graph.input.append(
  1163. onnx.helper.make_tensor_value_info(
  1164. cache_indirection_name, TensorProto.INT32, shape=["batch_size", "beam_width", "max_sequence_length"]
  1165. ),
  1166. )
  1167. model.topological_sort()
  1168. return model
  1169. def add_output_qk_to_mha(model: OnnxModel, dtype: int = 0, skip_node_idxs: list[int] = []): # noqa: B006
  1170. # Add output_qk as output to MultiHeadAttention ops and as outputs to model
  1171. output_qk_basename = "output_cross_qk"
  1172. output_qks = []
  1173. mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
  1174. for idx, node in enumerate(mha_nodes):
  1175. # Skip MHA nodes where output_qk does not need to be added
  1176. if idx in skip_node_idxs:
  1177. continue
  1178. # Get `num_heads` attribute from MHA
  1179. num_heads = 0
  1180. for att in node.attribute:
  1181. if att.name == "num_heads":
  1182. num_heads = att.i
  1183. break
  1184. # Get dtype for `output_qk` based on MHA bias if not provided
  1185. output_qk_dtype = dtype
  1186. if output_qk_dtype == 0:
  1187. for i in model.model.graph.initializer:
  1188. if i.name == node.input[3]:
  1189. output_qk_dtype = i.data_type
  1190. break
  1191. # Get `target_sequence_length` attribute from 4D input for key if it's a constant
  1192. target_sequence_length = "target_sequence_length"
  1193. for i in model.model.graph.input:
  1194. if i.name == node.input[1]:
  1195. target_sequence_length = i.type.tensor_type.shape.dim[2].dim_value
  1196. break
  1197. # MHA op takes the following potential outputs:
  1198. # output, present_key, present_value
  1199. while len(node.output) < 3:
  1200. node.output.append("")
  1201. output_qk_name = f"{output_qk_basename}_{idx // 2}"
  1202. node.output.append(output_qk_name)
  1203. output_qks.append(
  1204. onnx.helper.make_tensor_value_info(
  1205. output_qk_name,
  1206. output_qk_dtype,
  1207. shape=["batch_size", num_heads, "sequence_length", target_sequence_length],
  1208. ),
  1209. )
  1210. model.model.graph.output.extend(output_qks)
  1211. model.topological_sort()
  1212. return model
  1213. def fix_past_sequence_length(model: OnnxModel):
  1214. # Modify total_sequence_length = past_sequence_length + curr_sequence_length subgraph to calculate
  1215. # past_sequence_length from the new `past_sequence_length` input of size 1D and type int32 instead of
  1216. # from `past_key_self_0` since DecoderMaskedMultiHeadAttention (DMMHA) uses buffer sharing and
  1217. # `past_key_self_0.shape[2] = max_sequence_length` instead of `past_key_self_0.shape[2] = past_sequence_length`
  1218. # when buffer sharing is enabled
  1219. #
  1220. # Before:
  1221. #
  1222. # input_ids past_key_self_0
  1223. # | |
  1224. # Shape Shape
  1225. # | |
  1226. # Gather Gather
  1227. # (idx=1) (idx=2)
  1228. # | | \
  1229. # +--------+--------+ Unsqueeze
  1230. # |
  1231. # Add
  1232. #
  1233. # After:
  1234. #
  1235. # input_ids past_sequence_length (1D)
  1236. # | |
  1237. # Shape Squeeze
  1238. # | |
  1239. # Gather Cast
  1240. # (idx=1) (int64)
  1241. # | | \
  1242. # +--------+--------+ Unsqueeze
  1243. # |
  1244. # Add
  1245. # Constant names to be used
  1246. past_seq_len_name = "past_sequence_length"
  1247. past_seq_len_int32 = "past_seq_len_int32"
  1248. past_seq_len_int64 = "past_seq_len_int64"
  1249. node = list(filter(lambda n: n.op_type == "LayerNormalization", model.model.graph.node))[0] # noqa: RUF015
  1250. base_path_hf = model.match_parent_path(
  1251. node,
  1252. ["Add", "Gather", "Tile", "Expand", "Unsqueeze", "Range"],
  1253. [0, 1, 1, 0, 0, 0],
  1254. )
  1255. base_path_oai = model.match_parent_path(
  1256. node,
  1257. ["Add", "Slice"],
  1258. [0, 1],
  1259. )
  1260. if base_path_hf is not None:
  1261. base_path = base_path_hf
  1262. elif base_path_oai is not None:
  1263. base_path = base_path_oai
  1264. else:
  1265. logger.info("Cannot identify base path for fixing past_sequence_length subgraph")
  1266. return
  1267. base_node = base_path[-1]
  1268. if base_node.op_type == "Range":
  1269. # Hugging Face implementation
  1270. range_node = base_path[-1]
  1271. gather_path = model.match_parent_path(
  1272. range_node,
  1273. ["Gather", "Shape"],
  1274. [0, 0],
  1275. )
  1276. if gather_path is None:
  1277. logger.info("Cannot identify gather path for fixing past_sequence_length subgraph")
  1278. return
  1279. add_path = model.match_parent_path(
  1280. range_node,
  1281. ["Add", "Gather", "Shape"],
  1282. [1, 0, 0],
  1283. )
  1284. if add_path is None:
  1285. logger.info("Cannot identify add path for fixing past_sequence_length subgraph")
  1286. return
  1287. add_node = add_path[0]
  1288. if gather_path != add_path[1:]:
  1289. logger.info("Gather path and add path do not share the same nodes for calculating the past_sequence_length")
  1290. return
  1291. # Remove `past_key_self_0 --> Shape --> Gather` connection
  1292. constant_in_gather = list(filter(lambda n: n.output[0] == gather_path[0].input[1], model.model.graph.node))[0] # noqa: RUF015
  1293. model.model.graph.node.remove(constant_in_gather)
  1294. model.model.graph.node.remove(gather_path[0])
  1295. model.model.graph.node.remove(gather_path[1])
  1296. # Add `past_seq_len_int64` as an input name to existing nodes
  1297. range_node.input[0] = past_seq_len_int64
  1298. add_node.input[0] = past_seq_len_int64
  1299. else:
  1300. # OpenAI implementation
  1301. input_ids_path = model.match_parent_path(
  1302. base_node,
  1303. ["Unsqueeze", "Add", "Gather", "Shape", "Reshape", "Transpose"],
  1304. [2, 0, 0, 0, 0, 0],
  1305. )
  1306. if input_ids_path is None:
  1307. logger.info("Cannot identify input_ids path for fixing past_sequence_length subgraph")
  1308. return
  1309. add_node = input_ids_path[1]
  1310. past_key_path = model.match_parent_path(
  1311. base_node,
  1312. ["Unsqueeze", "Gather", "Shape", "Reshape", "Transpose"],
  1313. [1, 0, 0, 0, 0],
  1314. )
  1315. if past_key_path is None:
  1316. logger.info("Cannot identify past_key path for fixing past_sequence_length subgraph")
  1317. return
  1318. unsqueeze_node = past_key_path[0]
  1319. if input_ids_path[2:] != past_key_path[1:]:
  1320. logger.info(
  1321. "The input_ids path and past_key path do not share the same nodes for calculating the past_sequence_length"
  1322. )
  1323. return
  1324. # Remove `past_key_self_0 --> Transpose --> Reshape --> Shape --> Gather` connection
  1325. constant_in_gather = list(filter(lambda n: n.output[0] == past_key_path[1].input[1], model.model.graph.node))[0] # noqa: RUF015
  1326. model.model.graph.node.remove(constant_in_gather)
  1327. constant_in_reshape = list(filter(lambda n: n.output[0] == past_key_path[-2].input[1], model.model.graph.node))[ # noqa: RUF015
  1328. 0
  1329. ]
  1330. model.model.graph.node.remove(constant_in_reshape)
  1331. model.model.graph.node.remove(past_key_path[1])
  1332. model.model.graph.node.remove(past_key_path[2])
  1333. model.model.graph.node.remove(past_key_path[3])
  1334. model.model.graph.node.remove(past_key_path[4])
  1335. # Add `past_seq_len_int64` as an input name to existing nodes
  1336. unsqueeze_node.input[0] = past_seq_len_int64
  1337. add_node.input[0] = past_seq_len_int64
  1338. # Add `past_sequence_length` as model input
  1339. model.model.graph.input.append(
  1340. onnx.helper.make_tensor_value_info(past_seq_len_name, TensorProto.INT32, shape=[1]),
  1341. )
  1342. # Add `past_sequence_length --> Squeeze --> Cast` connection
  1343. squeeze_node = onnx.helper.make_node(
  1344. "Squeeze",
  1345. inputs=[past_seq_len_name],
  1346. outputs=[past_seq_len_int32],
  1347. name=model.create_node_name("Squeeze"),
  1348. )
  1349. squeeze_output = onnx.helper.make_tensor_value_info(past_seq_len_int32, TensorProto.INT32, shape=[])
  1350. cast_node = onnx.helper.make_node(
  1351. "Cast",
  1352. inputs=[past_seq_len_int32],
  1353. outputs=[past_seq_len_int64],
  1354. name=model.create_node_name("Cast"),
  1355. to=TensorProto.INT64,
  1356. )
  1357. cast_output = onnx.helper.make_tensor_value_info(past_seq_len_int64, TensorProto.INT64, shape=[])
  1358. # Add new nodes to graph
  1359. model.model.graph.node.extend([squeeze_node, cast_node])
  1360. model.model.graph.value_info.extend([squeeze_output, cast_output])
  1361. model.topological_sort()
  1362. return model, past_seq_len_name
  1363. def replace_mha_with_dmmha(model: OnnxModel, past_seq_len_name: str):
  1364. # Add `beam_width` and `cache_indirection` as model inputs
  1365. beam_width = "beam_width"
  1366. cache_indirection = "cache_indirection"
  1367. model.model.graph.input.extend(
  1368. [
  1369. onnx.helper.make_tensor_value_info(beam_width, TensorProto.INT32, shape=[1]),
  1370. onnx.helper.make_tensor_value_info(
  1371. cache_indirection, TensorProto.INT32, shape=["batch_size", "beam_width", "max_sequence_length"]
  1372. ),
  1373. ]
  1374. )
  1375. # Replace all `MultiHeadAttention` nodes with `DecoderMaskedMultiHeadAttention` nodes
  1376. mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
  1377. for idx, node in enumerate(mha_nodes):
  1378. # Get `num_heads` attribute from MHA
  1379. num_heads = 0
  1380. for att in node.attribute:
  1381. if att.name == "num_heads":
  1382. num_heads = att.i
  1383. break
  1384. # Make Q*K outputs for cross-attention layers, which happen every alternative layer
  1385. qk_output_name = f"output_cross_qk_{idx // 2}"
  1386. qk_output = onnx.helper.make_tensor_value_info(
  1387. qk_output_name, TensorProto.FLOAT, shape=["batch_size", num_heads, 1, "encode_sequence_length / 2"]
  1388. )
  1389. if idx % 2 == 1:
  1390. model.model.graph.output.append(qk_output)
  1391. # Make DMMHA node
  1392. dmmha_node = onnx.helper.make_node(
  1393. "DecoderMaskedMultiHeadAttention",
  1394. inputs=[
  1395. node.input[0], # query
  1396. node.input[1], # key
  1397. node.input[2], # value
  1398. "", # mask_index
  1399. "", # relative_position_bias
  1400. node.input[6] if len(node.input) > 4 else "", # past_key
  1401. node.input[7] if len(node.input) > 4 else "", # past_value
  1402. past_seq_len_name, # past_sequence_length
  1403. beam_width, # beam_width
  1404. cache_indirection, # cache_indirection
  1405. node.input[3], # bias
  1406. ],
  1407. outputs=[
  1408. node.output[0], # output
  1409. node.output[1] if len(node.input) > 4 else "", # present_key
  1410. node.output[2] if len(node.input) > 4 else "", # present_value
  1411. qk_output_name if idx % 2 == 1 else "", # output_cross_qk
  1412. ],
  1413. name=node.name.replace("MultiHeadAttention", "DecoderMaskedMultiHeadAttention"),
  1414. domain="com.microsoft",
  1415. num_heads=num_heads,
  1416. output_qk=(idx % 2),
  1417. past_present_share_buffer=1,
  1418. )
  1419. if idx % 2 == 0:
  1420. # Remove empty string for output_cross_qk, which happens every alternative layer
  1421. dmmha_node.output.remove("")
  1422. model.model.graph.node.remove(node)
  1423. model.model.graph.node.extend([dmmha_node])
  1424. model.topological_sort()
  1425. return model
  1426. def replace_mha_with_gqa(
  1427. model: OnnxModel,
  1428. attn_mask: str,
  1429. kv_num_heads: int = 0,
  1430. world_size: int = 1,
  1431. window_size: int = -1,
  1432. ):
  1433. # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes
  1434. #
  1435. # attention_mask
  1436. # / \
  1437. # ReduceSum Shape
  1438. # | |
  1439. # Sub Gather
  1440. # | |
  1441. # seqlens_k total_sequence_length
  1442. # | |
  1443. # Cast to int32 Cast to int32
  1444. model.add_initializer(
  1445. onnx.helper.make_tensor(
  1446. name="one",
  1447. data_type=TensorProto.INT64,
  1448. dims=[1],
  1449. vals=[1],
  1450. )
  1451. )
  1452. reduce_sum_node = onnx.helper.make_node(
  1453. "ReduceSum",
  1454. inputs=[attn_mask, "one"],
  1455. outputs=[attn_mask + "_row_sums"],
  1456. name=model.create_node_name("ReduceSum"),
  1457. )
  1458. sub_node = onnx.helper.make_node(
  1459. "Sub",
  1460. inputs=[attn_mask + "_row_sums", "one"],
  1461. outputs=["seqlens_k_int64"],
  1462. name=model.create_node_name("Sub"),
  1463. )
  1464. seqlen_k_cast_node = onnx.helper.make_node(
  1465. "Cast",
  1466. inputs=["seqlens_k_int64"],
  1467. outputs=["seqlens_k"],
  1468. name=model.create_node_name("Cast"),
  1469. to=TensorProto.INT32,
  1470. )
  1471. shape_node = onnx.helper.make_node(
  1472. "Shape",
  1473. inputs=[attn_mask],
  1474. outputs=[attn_mask + "_shape"],
  1475. name=model.create_node_name("Shape"),
  1476. )
  1477. gather_node = onnx.helper.make_node(
  1478. "Gather",
  1479. inputs=[attn_mask + "_shape", "one"],
  1480. outputs=["total_seq_len_int64"],
  1481. name=model.create_node_name("Gather"),
  1482. axis=0,
  1483. )
  1484. total_seqlen_cast_node = onnx.helper.make_node(
  1485. "Cast",
  1486. inputs=["total_seq_len_int64"],
  1487. outputs=["total_seq_len"],
  1488. name=model.create_node_name("Cast"),
  1489. to=TensorProto.INT32,
  1490. )
  1491. model.model.graph.node.extend(
  1492. [
  1493. reduce_sum_node,
  1494. sub_node,
  1495. seqlen_k_cast_node,
  1496. shape_node,
  1497. gather_node,
  1498. total_seqlen_cast_node,
  1499. ]
  1500. )
  1501. # Replace MultiHeadAttention with GroupQueryAttention
  1502. #
  1503. # When replacing, fuse the following subgraph:
  1504. #
  1505. # root_input
  1506. # / | \
  1507. # MatMul MatMul MatMul
  1508. # | | |
  1509. # Add Add Add (optional Adds)
  1510. # | | |
  1511. # RotEmb RotEmb |
  1512. # \ | /
  1513. # MultiHeadAttention
  1514. #
  1515. # to this new subgraph:
  1516. #
  1517. # root_input
  1518. # |
  1519. # PackedMatMul (if possible)
  1520. # |
  1521. # PackedAdd (if possible)
  1522. # |
  1523. # GroupQueryAttention
  1524. #
  1525. mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
  1526. for idx, node in enumerate(mha_nodes):
  1527. # Detect Q path to MHA
  1528. q_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [0, 0, 0])
  1529. q_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [0, 0])
  1530. q_rotary, q_add, q_matmul = None, None, None
  1531. if q_path_1 is not None:
  1532. q_rotary, q_add, q_matmul = q_path_1
  1533. elif q_path_2 is not None:
  1534. q_rotary, q_matmul = q_path_2
  1535. # Detect K path to MHA
  1536. k_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [1, 0, 0])
  1537. k_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [1, 0])
  1538. k_rotary, k_add, k_matmul = None, None, None
  1539. if k_path_1 is not None:
  1540. k_rotary, k_add, k_matmul = k_path_1
  1541. elif k_path_2 is not None:
  1542. k_rotary, k_matmul = k_path_2
  1543. # Detect V path to MHA
  1544. v_path_1 = model.match_parent_path(node, ["Add", "MatMul"], [2, 0])
  1545. v_path_2 = model.match_parent_path(node, ["MatMul"], [2])
  1546. v_add, v_matmul = None, None
  1547. if v_path_1 is not None:
  1548. v_add, v_matmul = v_path_1
  1549. elif v_path_2 is not None:
  1550. v_matmul = v_path_2[0]
  1551. # Get `interleaved` attribute from RotaryEmbedding
  1552. interleaved = 0
  1553. if q_rotary is not None and k_rotary is not None:
  1554. for att in q_rotary.attribute:
  1555. if att.name == "interleaved":
  1556. interleaved = att.i
  1557. # Get `num_heads` attribute from MHA
  1558. num_heads = 0
  1559. for att in node.attribute:
  1560. if att.name == "num_heads":
  1561. num_heads = att.i
  1562. # Check if root_input to Q/K/V paths is the same
  1563. root_input_is_same = q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]
  1564. # Check if Q/K/V paths all have bias or all don't have bias
  1565. all_paths_have_bias = q_add is not None and k_add is not None and v_add is not None
  1566. all_paths_have_no_bias = q_add is None and k_add is None and v_add is None
  1567. # Make PackedMatMul node if possible
  1568. q_input_to_attention, k_input_to_attention, v_input_to_attention = "", "", ""
  1569. if root_input_is_same and (all_paths_have_bias or all_paths_have_no_bias):
  1570. qw = NumpyHelper.to_array(model.get_initializer(q_matmul.input[1]))
  1571. kw = NumpyHelper.to_array(model.get_initializer(k_matmul.input[1]))
  1572. vw = NumpyHelper.to_array(model.get_initializer(v_matmul.input[1]))
  1573. dim = qw.shape[-1]
  1574. qkv_weight = np.stack((qw, kw, vw), axis=1).reshape(dim, 3 * dim)
  1575. qkv_weight = onnx.numpy_helper.from_array(qkv_weight, name=f"QKV_Weight_{idx}")
  1576. model.add_initializer(qkv_weight)
  1577. packed_matmul_node = onnx.helper.make_node(
  1578. "MatMul",
  1579. inputs=[q_matmul.input[0], qkv_weight.name],
  1580. outputs=[f"{qkv_weight.name}_output"],
  1581. name=model.create_node_name("MatMul"),
  1582. )
  1583. model.model.graph.node.extend([packed_matmul_node])
  1584. model.model.graph.node.remove(q_matmul)
  1585. model.model.graph.node.remove(k_matmul)
  1586. model.model.graph.node.remove(v_matmul)
  1587. q_input_to_attention = packed_matmul_node.output[0]
  1588. # Make PackedAdd node if possible
  1589. if all_paths_have_bias:
  1590. qb = NumpyHelper.to_array(model.get_initializer(q_add.input[1]))
  1591. kb = NumpyHelper.to_array(model.get_initializer(k_add.input[1]))
  1592. vb = NumpyHelper.to_array(model.get_initializer(v_add.input[1]))
  1593. dim = qb.shape[-1]
  1594. qkv_bias = np.stack((qb, kb, vb), axis=0).reshape(3 * dim)
  1595. qkv_bias = onnx.numpy_helper.from_array(qkv_bias, name=f"QKV_Bias_{idx}")
  1596. model.add_initializer(qkv_bias)
  1597. packed_add_node = onnx.helper.make_node(
  1598. "Add",
  1599. inputs=[packed_matmul_node.output[0], qkv_bias.name],
  1600. outputs=[f"{qkv_bias.name}_output"],
  1601. )
  1602. model.model.graph.node.extend([packed_add_node])
  1603. model.model.graph.node.remove(q_add)
  1604. model.model.graph.node.remove(k_add)
  1605. model.model.graph.node.remove(v_add)
  1606. q_input_to_attention = packed_add_node.output[0]
  1607. else:
  1608. q_input_to_attention = q_matmul.output[0]
  1609. k_input_to_attention = k_matmul.output[0]
  1610. v_input_to_attention = v_matmul.output[0]
  1611. # Make GQA node
  1612. gqa_node = onnx.helper.make_node(
  1613. "GroupQueryAttention",
  1614. inputs=[
  1615. q_input_to_attention, # query
  1616. k_input_to_attention, # key
  1617. v_input_to_attention, # value
  1618. node.input[6], # past_key
  1619. node.input[7], # past_value
  1620. seqlen_k_cast_node.output[0], # seqlens_k (for attention mask)
  1621. total_seqlen_cast_node.output[0], # total_seq_len (for attention mask)
  1622. (q_rotary.input[2] if q_rotary is not None else ""), # cos_cache (for rotary embeddings)
  1623. (q_rotary.input[3] if q_rotary is not None else ""), # sin_cache (for rotary embeddings)
  1624. ],
  1625. outputs=node.output,
  1626. name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
  1627. domain="com.microsoft",
  1628. num_heads=num_heads // world_size,
  1629. kv_num_heads=(num_heads // world_size if kv_num_heads == 0 else kv_num_heads // world_size),
  1630. local_window_size=window_size,
  1631. do_rotary=int(q_rotary is not None and k_rotary is not None),
  1632. rotary_interleaved=interleaved,
  1633. )
  1634. model.model.graph.node.remove(node)
  1635. model.model.graph.node.extend([gqa_node])
  1636. if q_rotary is not None:
  1637. model.model.graph.node.remove(q_rotary)
  1638. if k_rotary is not None:
  1639. model.model.graph.node.remove(k_rotary)
  1640. return model
  1641. def update_decoder_subgraph_output_cross_attention(subg: GraphProto):
  1642. input_self_past_0 = 1
  1643. # w/wo attention mask, w/wo hidden_state
  1644. graph_input_names = [gi.name for gi in subg.input]
  1645. while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"):
  1646. input_self_past_0 += 1
  1647. output_self_present_0 = 1
  1648. num_layers = (len(subg.output) - output_self_present_0) // 2
  1649. input_cross_past_0 = 2 * num_layers + input_self_past_0
  1650. past_key_cross_inputs = {subg.input[layer * 2 + input_cross_past_0].name: layer for layer in range(num_layers)}
  1651. print(f" -- past_key_cross_inputs = {past_key_cross_inputs}")
  1652. input_past_key_cross_0_shape = shape_of(subg.input[input_cross_past_0])
  1653. print(f"past_key_cross_0_shape is {input_past_key_cross_0_shape}")
  1654. batch_size_dim = input_past_key_cross_0_shape[0]
  1655. num_heads_dim = input_past_key_cross_0_shape[1]
  1656. cross_seq_len_dim = input_past_key_cross_0_shape[2]
  1657. num_layer_output_qk = 0
  1658. for node in subg.node:
  1659. if (node.op_type == "DecoderMaskedMultiHeadAttention") and (node.input[1] in past_key_cross_inputs):
  1660. print(f" -- add cross QK output from: node: {node.name} with output: {node.output}")
  1661. num_layer_output_qk += 1
  1662. layer = past_key_cross_inputs[node.input[1]]
  1663. cross_attention_out_name = f"output_cross_qk_{layer}"
  1664. appended_names = [""] * (3 - len(node.output))
  1665. appended_names.append(cross_attention_out_name)
  1666. node.output.extend(appended_names)
  1667. node.attribute.extend([onnx.helper.make_attribute("output_qk", 1)])
  1668. cross_attention = onnx.helper.make_tensor_value_info(
  1669. cross_attention_out_name,
  1670. TensorProto.FLOAT,
  1671. [batch_size_dim, num_heads_dim, 1, cross_seq_len_dim],
  1672. )
  1673. subg.output.extend([cross_attention])
  1674. if num_layer_output_qk != num_layers:
  1675. raise ValueError(f"Did not add cross QK for all layers{num_layers} vs {num_layer_output_qk}")
  1676. def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: ModelProto):
  1677. input_self_past_0 = 1
  1678. # w/wo attention mask, w/wo hidden_state
  1679. graph_input_names = [gi.name for gi in subg.input]
  1680. while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"):
  1681. input_self_past_0 += 1
  1682. output_self_past_0 = 1
  1683. num_layers = int((len(subg.input) - input_self_past_0) / 4)
  1684. input_cross_past_0 = 2 * num_layers + input_self_past_0
  1685. new_nodes = []
  1686. old_nodes = []
  1687. for node in subg.node:
  1688. if node.op_type == "MultiHeadAttention":
  1689. old_nodes.extend([node])
  1690. # If not all the MultiHeadAttention nodes are fused, this optimization is not applicable
  1691. if len(old_nodes) < num_layers:
  1692. return False
  1693. # Redirect the RelativePositionBias node's input from past_key_self_0.shape[2] to past_sequence_length.
  1694. # There is only one RelativePositionBias node in T5 decoder subgraph.
  1695. rel_pos_bias_node = None
  1696. for node in subg.node:
  1697. if node.op_type == "RelativePositionBias":
  1698. rel_pos_bias_node = node
  1699. break
  1700. decoder_masked_attention_supported_attr = [
  1701. "past_present_share_buffer",
  1702. "num_heads",
  1703. "scale",
  1704. "mask_filter_value",
  1705. "domain",
  1706. ]
  1707. target_squeezed_past_seq_name = "past_sequence_length_squeezed_int64"
  1708. tensor_names_to_rename, nodes_to_remove = find_past_seq_len_usage(subg)
  1709. if len(tensor_names_to_rename) > 0:
  1710. for name_to_rename in tensor_names_to_rename:
  1711. print(f"Found tensor name `{name_to_rename}` to be renamed to `{target_squeezed_past_seq_name}`")
  1712. for nr in nodes_to_remove:
  1713. print(f"Found node to remove: type = {nr.op_type}, name = {nr.name}")
  1714. squeeze_node = onnx.helper.make_node(
  1715. "Squeeze",
  1716. ["past_sequence_length"],
  1717. ["past_sequence_length_squeezed"],
  1718. name="node_past_sequence_length_squeeze",
  1719. )
  1720. cast_node = onnx.helper.make_node(
  1721. "Cast",
  1722. ["past_sequence_length_squeezed"],
  1723. [target_squeezed_past_seq_name],
  1724. name="node_past_sequence_length_squeeze_cast",
  1725. to=TensorProto.INT64,
  1726. )
  1727. new_nodes.extend([squeeze_node, cast_node])
  1728. for node in subg.node:
  1729. if len(node.output) > 0 and rel_pos_bias_node is not None and node.output[0] == rel_pos_bias_node.input[1]:
  1730. cast_node = onnx.helper.make_node(
  1731. "Cast",
  1732. ["past_sequence_length"],
  1733. ["past_sequence_length_int64"],
  1734. name="past_sequence_length_cast",
  1735. to=TensorProto.INT64,
  1736. )
  1737. node.input[1] = cast_node.output[0]
  1738. new_nodes.extend([cast_node])
  1739. if node.op_type == "MultiHeadAttention":
  1740. kwargs = kwargs_of(node)
  1741. for k in kwargs.copy():
  1742. if k not in decoder_masked_attention_supported_attr:
  1743. del kwargs[k]
  1744. # note: This logic only apply to T5 model where there is no bias in Attention node.
  1745. nis = [
  1746. node.input[0], # query
  1747. node.input[1], # key
  1748. node.input[2], # value
  1749. ]
  1750. nis.extend([node.input[4] if len(node.input) > 4 else ""]) # 2D mask
  1751. nis.extend([node.input[5] if len(node.input) > 5 else ""]) # attention_bias
  1752. nis.extend([node.input[6] if len(node.input) > 6 else ""]) # past_key
  1753. nis.extend([node.input[7] if len(node.input) > 7 else ""]) # past_value
  1754. nis.extend(["past_sequence_length"]) # past_sequence_length
  1755. nis.extend(["beam_width"]) # beam_width
  1756. nis.extend(["cache_indirection"]) # cache_indirection
  1757. nis.extend([node.input[3] if len(node.input) > 3 else ""]) # bias
  1758. kwargs["past_present_share_buffer"] = 1
  1759. node = onnx.helper.make_node( # noqa: PLW2901
  1760. "DecoderMaskedMultiHeadAttention",
  1761. nis,
  1762. node.output,
  1763. name=node.name,
  1764. **kwargs,
  1765. )
  1766. if node not in nodes_to_remove:
  1767. for index, name in enumerate(node.input):
  1768. if name in tensor_names_to_rename:
  1769. node.input[index] = target_squeezed_past_seq_name
  1770. new_nodes.extend([node])
  1771. subg.ClearField("node")
  1772. subg.node.extend(new_nodes)
  1773. orig_input_names = [inp.name for inp in subg.input]
  1774. new_inputs = []
  1775. for i, vi in enumerate(subg.input):
  1776. if i >= input_self_past_0 and i < input_cross_past_0:
  1777. shape = shape_of(vi)
  1778. vi = onnx.helper.make_tensor_value_info( # noqa: PLW2901
  1779. vi.name,
  1780. elem_type=vi.type.tensor_type.elem_type,
  1781. shape=[shape[0], shape[1], "max_seq_len", shape[3]],
  1782. )
  1783. new_inputs.extend([vi])
  1784. if "past_sequence_length" not in orig_input_names:
  1785. new_inputs.extend(
  1786. [onnx.helper.make_tensor_value_info("past_sequence_length", onnx.TensorProto.INT32, shape=[1])]
  1787. )
  1788. if "beam_width" not in orig_input_names:
  1789. new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
  1790. if "cache_indirection" not in orig_input_names:
  1791. new_inputs.extend(
  1792. [
  1793. onnx.helper.make_tensor_value_info(
  1794. "cache_indirection",
  1795. onnx.TensorProto.INT32,
  1796. shape=["batch_size", "beam_width", "max_seq_len"],
  1797. )
  1798. ]
  1799. )
  1800. subg.ClearField("input")
  1801. subg.input.extend(new_inputs)
  1802. new_outputs = []
  1803. for i, vi in enumerate(subg.output):
  1804. if i >= output_self_past_0:
  1805. shape = shape_of(vi)
  1806. vi = onnx.helper.make_tensor_value_info( # noqa: PLW2901
  1807. vi.name,
  1808. elem_type=vi.type.tensor_type.elem_type,
  1809. shape=[shape[0], shape[1], "max_seq_len", shape[3]],
  1810. )
  1811. new_outputs.extend([vi])
  1812. subg.ClearField("output")
  1813. subg.output.extend(new_outputs)
  1814. return True
  1815. def pack_qkv_for_decoder_masked_mha(model_proto: ModelProto):
  1816. onnx_model = OnnxModel(model_proto)
  1817. output_name_to_node = onnx_model.output_name_to_node()
  1818. nodes_to_add = []
  1819. nodes_to_remove = []
  1820. for node in onnx_model.nodes():
  1821. if node.op_type == "DecoderMaskedMultiHeadAttention":
  1822. if "past_key_cross" in node.input[1] and "past_value_cross" in node.input[2]:
  1823. continue
  1824. q_matmul = output_name_to_node[node.input[0]]
  1825. k_matmul = output_name_to_node[node.input[1]]
  1826. v_matmul = output_name_to_node[node.input[2]]
  1827. q_weight = onnx_model.get_initializer(q_matmul.input[1])
  1828. k_weight = onnx_model.get_initializer(k_matmul.input[1])
  1829. v_weight = onnx_model.get_initializer(v_matmul.input[1])
  1830. if not (q_weight and k_weight and v_weight):
  1831. return False
  1832. qw = NumpyHelper.to_array(q_weight)
  1833. kw = NumpyHelper.to_array(k_weight)
  1834. vw = NumpyHelper.to_array(v_weight)
  1835. qkv_weight = np.concatenate([qw, kw, vw], axis=1)
  1836. matmul_node_name = onnx_model.create_node_name("MatMul", name_prefix="MatMul_QKV")
  1837. weight = onnx.helper.make_tensor(
  1838. name=matmul_node_name + "_weight",
  1839. data_type=(TensorProto.FLOAT if q_weight.data_type == 1 else TensorProto.FLOAT16),
  1840. dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
  1841. vals=qkv_weight.flatten().tolist(),
  1842. )
  1843. model_proto.graph.initializer.extend([weight])
  1844. matmul_node = onnx.helper.make_node(
  1845. "MatMul",
  1846. inputs=[q_matmul.input[0], matmul_node_name + "_weight"],
  1847. outputs=[matmul_node_name + "_out"],
  1848. name=matmul_node_name,
  1849. )
  1850. node.input[0] = matmul_node.output[0]
  1851. node.input[1] = ""
  1852. node.input[2] = ""
  1853. nodes_to_add.extend([matmul_node])
  1854. nodes_to_remove.extend([q_matmul, k_matmul, v_matmul])
  1855. onnx_model.add_nodes(nodes_to_add)
  1856. onnx_model.remove_nodes(nodes_to_remove)
  1857. onnx_model.update_graph()
  1858. onnx_model.topological_sort()
  1859. return True
  1860. def update_input_shapes_for_gpt2_decoder_model(decoder_onnx_path: str, use_external_data_format: bool = True):
  1861. """Update the input shapes for the inputs "input_ids" and "position_ids" and make the sequence length dim value 1 for each of them.
  1862. The decoder model will be over-written.
  1863. Args:
  1864. decoder_onnx_path (str): Path of GPT-2 decoder onnx model
  1865. use_external_data_format(bool): output tensors to external data or not.
  1866. """
  1867. decoder_model_proto = onnx.load_model(decoder_onnx_path, load_external_data=True)
  1868. for i in range(len(decoder_model_proto.graph.input)):
  1869. if (
  1870. decoder_model_proto.graph.input[i].name == "input_ids"
  1871. or decoder_model_proto.graph.input[i].name == "position_ids"
  1872. ):
  1873. shape_dim_proto = decoder_model_proto.graph.input[i].type.tensor_type.shape.dim[1]
  1874. # Clear any existing dim_param first
  1875. if shape_dim_proto.HasField("dim_param"):
  1876. shape_dim_proto.Clear()
  1877. # Update dim_value to be 1
  1878. shape_dim_proto.dim_value = 1
  1879. OnnxModel.save(
  1880. decoder_model_proto,
  1881. decoder_onnx_path,
  1882. save_as_external_data=use_external_data_format,
  1883. )
  1884. return True
  1885. def generate_gpt2_init_decoder(
  1886. decoder_onnx_path: str,
  1887. init_decoder_onnx_path: str,
  1888. use_external_data_format: bool = True,
  1889. ) -> bool:
  1890. """Generates the initial decoder GPT2 subgraph and saves it for downstream use.
  1891. The initial decoder model will be saved to init_decoder_onnx_path.
  1892. Args:
  1893. decoder_onnx_path (str): Path of GPT-2 decoder onnx model
  1894. init_decoder_onnx_path (str): Path of GPT-2 init decoder onnx model
  1895. use_external_data_format(bool): output tensors to external data or not.
  1896. """
  1897. init_decoder_model_proto = onnx.load_model(decoder_onnx_path, load_external_data=True)
  1898. logits_output_name = init_decoder_model_proto.graph.output[0].name
  1899. gpt2_init_decoder_model = OnnxModel(init_decoder_model_proto)
  1900. output_name_to_node = gpt2_init_decoder_model.output_name_to_node()
  1901. assert logits_output_name in output_name_to_node
  1902. logits_matmul_node = output_name_to_node[logits_output_name]
  1903. # Sanity check - the logits need to be produced by a MatMul node
  1904. if logits_matmul_node.op_type != "MatMul":
  1905. return False
  1906. # Try to find the last residual Add
  1907. # For fp16, there are Casts along the way
  1908. # Normalization Node is : LayerNormalization
  1909. logits_matmul_to_residual_add_path = gpt2_init_decoder_model.match_parent_path(
  1910. logits_matmul_node,
  1911. [
  1912. "Cast",
  1913. "LayerNormalization",
  1914. "Add",
  1915. "Add",
  1916. "Cast",
  1917. "MatMul",
  1918. "Cast",
  1919. "FastGelu",
  1920. "Cast",
  1921. "MatMul",
  1922. "Cast",
  1923. "LayerNormalization",
  1924. "Add",
  1925. ],
  1926. [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  1927. )
  1928. # Normalization Node is : SkipLayerNormalization
  1929. if logits_matmul_to_residual_add_path is None:
  1930. logits_matmul_to_residual_add_path = gpt2_init_decoder_model.match_parent_path(
  1931. logits_matmul_node,
  1932. [
  1933. "Cast",
  1934. "SkipLayerNormalization",
  1935. "Cast",
  1936. "MatMul",
  1937. "Cast",
  1938. "FastGelu",
  1939. "Cast",
  1940. "MatMul",
  1941. "Cast",
  1942. "SkipLayerNormalization",
  1943. ],
  1944. [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
  1945. )
  1946. # Try without the Casts before and after the MatMuls
  1947. if logits_matmul_to_residual_add_path is None:
  1948. # Normalization Node is : LayerNormalization
  1949. logits_matmul_to_residual_add_path = gpt2_init_decoder_model.match_parent_path(
  1950. logits_matmul_node,
  1951. [
  1952. "LayerNormalization",
  1953. "Add",
  1954. "Add",
  1955. "MatMul",
  1956. "FastGelu",
  1957. "MatMul",
  1958. "LayerNormalization",
  1959. "Add",
  1960. ],
  1961. [0, 0, 1, 0, 0, 0, 0, 0],
  1962. )
  1963. # Normalization Node is : SkipLayerNormalization
  1964. if logits_matmul_to_residual_add_path is None:
  1965. logits_matmul_to_residual_add_path = gpt2_init_decoder_model.match_parent_path(
  1966. logits_matmul_node,
  1967. [
  1968. "SkipLayerNormalization",
  1969. "MatMul",
  1970. "FastGelu",
  1971. "MatMul",
  1972. "SkipLayerNormalization",
  1973. ],
  1974. [0, 1, 0, 0, 0],
  1975. )
  1976. # TODO(hasesh): Are there more permutations to try before returning ?
  1977. if logits_matmul_to_residual_add_path is None:
  1978. return False
  1979. residual_add_node = logits_matmul_to_residual_add_path[-1]
  1980. # If the last node in the pattern is SkipLayerNormalization, we need to adjust our pattern searches accordingly
  1981. is_skiplayernorm_path = residual_add_node.op_type == "SkipLayerNormalization"
  1982. # Regular LayerNormalization path
  1983. if not is_skiplayernorm_path:
  1984. residual_add_to_attention_parent_index = 0
  1985. residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
  1986. residual_add_node,
  1987. ["Add", "Cast", "MatMul", "Attention"],
  1988. [residual_add_to_attention_parent_index, 0, 0, 0],
  1989. )
  1990. # Try other parent index of the residual Add node
  1991. if residual_add_to_attention_path is None:
  1992. residual_add_to_attention_parent_index = 1
  1993. residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
  1994. residual_add_node,
  1995. ["Add", "Cast", "MatMul", "Attention"],
  1996. [residual_add_to_attention_parent_index, 0, 0, 0],
  1997. )
  1998. # Try without the Casts before and after the MatMuls
  1999. if residual_add_to_attention_path is None:
  2000. residual_add_to_attention_parent_index = 0
  2001. residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
  2002. residual_add_node,
  2003. ["Add", "MatMul", "Attention"],
  2004. [residual_add_to_attention_parent_index, 0, 0],
  2005. )
  2006. # Try without the Casts before and after the MatMuls and other parent index of the residual Add node
  2007. if residual_add_to_attention_path is None:
  2008. residual_add_to_attention_parent_index = 1
  2009. residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
  2010. residual_add_node,
  2011. ["Add", "MatMul", "Attention"],
  2012. [residual_add_to_attention_parent_index, 0, 0],
  2013. )
  2014. # SkipLayerNormalization path
  2015. else:
  2016. residual_add_to_attention_parent_index = 0
  2017. residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
  2018. residual_add_node,
  2019. ["Cast", "MatMul", "Attention"],
  2020. [residual_add_to_attention_parent_index, 0, 0],
  2021. )
  2022. # Try other parent index of the residual Add node
  2023. if residual_add_to_attention_path is None:
  2024. residual_add_to_attention_parent_index = 1
  2025. residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
  2026. residual_add_node,
  2027. ["Cast", "MatMul", "Attention"],
  2028. [residual_add_to_attention_parent_index, 0, 0],
  2029. )
  2030. # Try without the Casts before and after the MatMuls
  2031. if residual_add_to_attention_path is None:
  2032. residual_add_to_attention_parent_index = 0
  2033. residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
  2034. residual_add_node,
  2035. ["MatMul", "Attention"],
  2036. [residual_add_to_attention_parent_index, 0],
  2037. )
  2038. # Try without the Casts before and after the MatMuls and other parent index of the residual Add node
  2039. if residual_add_to_attention_path is None:
  2040. residual_add_to_attention_parent_index = 1
  2041. residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path(
  2042. residual_add_node,
  2043. ["MatMul", "Attention"],
  2044. [residual_add_to_attention_parent_index, 0],
  2045. )
  2046. # TODO(hasesh): Are there more permutations to try before returning ?
  2047. if residual_add_to_attention_path is None:
  2048. return False
  2049. residual_add_to_add_parent_index = 0 if residual_add_to_attention_parent_index == 1 else 1
  2050. # Regular LayerNormalization path
  2051. if not is_skiplayernorm_path:
  2052. add_before_residual_add = gpt2_init_decoder_model.match_parent(
  2053. residual_add_node, "Add", residual_add_to_add_parent_index
  2054. )
  2055. # SkipLayerNormalization path
  2056. else:
  2057. add_before_residual_add = gpt2_init_decoder_model.match_parent(
  2058. residual_add_node,
  2059. "SkipLayerNormalization",
  2060. residual_add_to_add_parent_index,
  2061. )
  2062. if add_before_residual_add is None:
  2063. return False
  2064. attention = residual_add_to_attention_path[-1]
  2065. matmul_after_attention = residual_add_to_attention_path[-2]
  2066. slice_starts = onnx.helper.make_tensor(
  2067. name="SliceLastTokenStarts",
  2068. data_type=TensorProto.INT32,
  2069. dims=[1],
  2070. vals=[-1],
  2071. )
  2072. slice_ends = onnx.helper.make_tensor(
  2073. name="SliceLastTokenEnds",
  2074. data_type=TensorProto.INT32,
  2075. dims=[1],
  2076. vals=[-2],
  2077. )
  2078. slice_axes = onnx.helper.make_tensor(
  2079. name="SliceLastTokenAxes",
  2080. data_type=TensorProto.INT32,
  2081. dims=[1],
  2082. vals=[1],
  2083. )
  2084. slice_steps = onnx.helper.make_tensor(
  2085. name="SliceLastTokenSteps",
  2086. data_type=TensorProto.INT32,
  2087. dims=[1],
  2088. vals=[-1],
  2089. )
  2090. gpt2_init_decoder_model.add_initializer(slice_starts)
  2091. gpt2_init_decoder_model.add_initializer(slice_ends)
  2092. gpt2_init_decoder_model.add_initializer(slice_axes)
  2093. gpt2_init_decoder_model.add_initializer(slice_steps)
  2094. # Add Slice node to the graph such that it consumes the output of Attention
  2095. slice_0_output_name = "edge_modified_" + attention.output[0]
  2096. slice_node_0 = onnx.helper.make_node(
  2097. "Slice",
  2098. inputs=[
  2099. attention.output[0],
  2100. "SliceLastTokenStarts",
  2101. "SliceLastTokenEnds",
  2102. "SliceLastTokenAxes",
  2103. "SliceLastTokenSteps",
  2104. ],
  2105. outputs=[slice_0_output_name],
  2106. name=gpt2_init_decoder_model.create_node_name("Slice", "GatherLastToken_0_"),
  2107. )
  2108. # Add Slice node to the graph such that it consumes the output of Add before the residual Add
  2109. # If the 'Add' output is produced by a SkipLayerNormalization node, then adjust its output
  2110. # index appropriately
  2111. add_before_residual_add_output = (
  2112. add_before_residual_add.output[0] if not is_skiplayernorm_path else add_before_residual_add.output[3]
  2113. )
  2114. slice_1_output_name = "edge_modified_" + add_before_residual_add.output[0]
  2115. slice_node_1 = onnx.helper.make_node(
  2116. "Slice",
  2117. inputs=[
  2118. add_before_residual_add_output,
  2119. "SliceLastTokenStarts",
  2120. "SliceLastTokenEnds",
  2121. "SliceLastTokenAxes",
  2122. "SliceLastTokenSteps",
  2123. ],
  2124. outputs=[slice_1_output_name],
  2125. name=gpt2_init_decoder_model.create_node_name("Slice", "GatherLastToken_1_"),
  2126. )
  2127. # Add the 2 Slice nodes
  2128. gpt2_init_decoder_model.add_node(slice_node_0)
  2129. gpt2_init_decoder_model.add_node(slice_node_1)
  2130. # Adjust the input(s) to the nodes consuming the outputs of the added Slice nodes
  2131. gpt2_init_decoder_model.replace_node_input(matmul_after_attention, attention.output[0], slice_0_output_name)
  2132. gpt2_init_decoder_model.replace_node_input(residual_add_node, add_before_residual_add_output, slice_1_output_name)
  2133. # Topologically sort the updated graph
  2134. gpt2_init_decoder_model.topological_sort()
  2135. # Save the init decoder model
  2136. OnnxModel.save(
  2137. init_decoder_model_proto,
  2138. init_decoder_onnx_path,
  2139. save_as_external_data=use_external_data_format,
  2140. )
  2141. return True
  2142. def make_dim_proto_numeric_t5(model, config):
  2143. """Make dim_proto numeric.
  2144. Args:
  2145. model: T5 encoder and decoder model.
  2146. config: T5 config.
  2147. """
  2148. sequence_length = str(1)
  2149. num_heads = str(config.num_heads)
  2150. hidden_size = str(config.d_model)
  2151. head_size = str(config.d_kv)
  2152. for tensor in model.graph.output:
  2153. for dim_proto in tensor.type.tensor_type.shape.dim:
  2154. if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
  2155. sequence_length,
  2156. num_heads,
  2157. hidden_size,
  2158. head_size,
  2159. ]:
  2160. dim_value = int(dim_proto.dim_param)
  2161. dim_proto.Clear()
  2162. dim_proto.dim_value = dim_value
  2163. for tensor in model.graph.input:
  2164. for dim_proto in tensor.type.tensor_type.shape.dim:
  2165. if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
  2166. sequence_length,
  2167. num_heads,
  2168. hidden_size,
  2169. head_size,
  2170. ]:
  2171. dim_value = int(dim_proto.dim_param)
  2172. dim_proto.Clear()
  2173. dim_proto.dim_value = dim_value
  2174. def convert_generation_model(
  2175. args: argparse.Namespace,
  2176. generation_type: GenerationType = GenerationType.BEAMSEARCH,
  2177. ):
  2178. """Convert model according to command line arguments.
  2179. Args:
  2180. args (argparse.Namespace): arguments parsed from command line
  2181. """
  2182. is_gpt2: bool = args.model_type == "gpt2"
  2183. is_beamsearch: bool = generation_type == GenerationType.BEAMSEARCH
  2184. is_greedysearch: bool = generation_type == GenerationType.GREEDYSEARCH
  2185. is_sampling: bool = generation_type == GenerationType.SAMPLING
  2186. past_present_share_buffer: bool = args.past_present_share_buffer
  2187. logger.info(f"**** past_present_share_buffer={past_present_share_buffer}")
  2188. if len(args.op_block_list) == 1 and args.op_block_list[0] == "auto":
  2189. if is_gpt2 and args.precision == Precision.FLOAT16.value:
  2190. args.op_block_list = [
  2191. "Add",
  2192. "LayerNormalization",
  2193. "SkipLayerNormalization",
  2194. "FastGelu",
  2195. ]
  2196. logger.info(f"**** Setting op_block_list to {args.op_block_list}")
  2197. logger.info("**** use --op_block_list if you want to override the block operator list.")
  2198. else:
  2199. args.op_block_list = []
  2200. if is_greedysearch or is_sampling:
  2201. if not is_gpt2:
  2202. raise NotImplementedError("Currently only gpt2 with greedy search/sampling is supported")
  2203. if args.output_sequences_scores:
  2204. raise NotImplementedError("output_sequences_scores currently is not supported in greedy search/sampling")
  2205. if args.output_token_scores:
  2206. raise NotImplementedError("output_token_scores currently is not supported in greedy search/sampling")
  2207. # For BeamSearch, sharing buffers for past and present states is only supported
  2208. # when using `use_decoder_masked_attention`
  2209. if past_present_share_buffer and is_beamsearch and not args.use_decoder_masked_attention:
  2210. raise ValueError(
  2211. "`use_decoder_masked_attention` MUST be turned on to use `past_present_share_buffer` in case of BeamSearch"
  2212. )
  2213. # For any kind of sampling, using decoder masked multihead attention is only supported
  2214. # when using `past_present_share_buffer`
  2215. if args.use_decoder_masked_attention and not past_present_share_buffer:
  2216. raise ValueError("`past_present_share_buffer` MUST be turned on to use `use_decoder_masked_attention`")
  2217. # For any kind of sampling, using decoder masked multihead attention is only supported
  2218. # on GPUs
  2219. if args.use_decoder_masked_attention and not args.use_gpu:
  2220. raise ValueError("`use_decoder_masked_attention` option is only supported on GPUs")
  2221. if is_gpt2:
  2222. if args.decoder_onnx and os.path.exists(args.decoder_onnx):
  2223. logger.info(f"skip convert_to_onnx since path existed: {args.decoder_onnx}")
  2224. else:
  2225. if not args.decoder_onnx:
  2226. onnx_filename = f"{args.model_name_or_path}_past_{args.precision}.onnx"
  2227. args.decoder_onnx = Path(Path(args.output).parent, onnx_filename).as_posix()
  2228. logger.info(f"Convert GPT model {args.model_name_or_path} to onnx {args.decoder_onnx} ...")
  2229. gpt2_to_onnx(args)
  2230. else: # t5 or mt5
  2231. if args.decoder_onnx and args.encoder_decoder_init_onnx:
  2232. logger.info(
  2233. f"skip convert_to_onnx since paths specified: {args.decoder_onnx} and {args.encoder_decoder_init_onnx}"
  2234. )
  2235. else:
  2236. logger.info(f"Convert model {args.model_name_or_path} to onnx ...")
  2237. t5_to_onnx(args)
  2238. # We only want to pad the logits MatMul weight in the decoder for fp16 models.
  2239. # The inherent assumption is that fp16 models run on GPU for which all
  2240. # dims need to be a multiple of 8 to leverage tensor cores.
  2241. # NOTE: We currently only support padding the MatMul logits weight for GPT2 GreedySearch/BeamSearch.
  2242. # This can be expanded to other models/decoding strategies later
  2243. logits_matmul_weight_padded = False
  2244. if (
  2245. not args.disable_pad_vocab_size
  2246. and args.precision == Precision.FLOAT16.value
  2247. and is_gpt2
  2248. and (is_beamsearch or is_greedysearch or is_sampling)
  2249. ):
  2250. logger.info(
  2251. f"Pad logits MatMul weights for optimal MatMul perf in fp16 on {args.decoder_onnx}. "
  2252. "The file will be overwritten."
  2253. )
  2254. logits_matmul_weight_padded = pad_weights_of_logits_matmul(args.decoder_onnx, args.use_external_data_format)
  2255. if not logits_matmul_weight_padded:
  2256. logger.warning(
  2257. "Tried and failed to pad logits MatMul weights. Performance may be sub-optimal for this MatMul"
  2258. )
  2259. gpt2_init_decoder_generated = False
  2260. gpt2_init_decoder_onnx_path = None
  2261. if (
  2262. not args.disable_separate_gpt2_decoder_for_init_run
  2263. and is_gpt2
  2264. and (is_beamsearch or is_greedysearch or is_sampling)
  2265. ):
  2266. logger.info(f"Creating an initial run GPT2 decoder from {args.decoder_onnx}. ")
  2267. gpt2_init_decoder_onnx_filename = f"gpt2_init_past_{args.precision}.onnx"
  2268. gpt2_init_decoder_onnx_path = Path(Path(args.output).parent, gpt2_init_decoder_onnx_filename).as_posix()
  2269. gpt2_init_decoder_generated = generate_gpt2_init_decoder(
  2270. args.decoder_onnx,
  2271. gpt2_init_decoder_onnx_path,
  2272. args.use_external_data_format,
  2273. )
  2274. if not gpt2_init_decoder_generated:
  2275. logger.warning(
  2276. "Tried and failed to generate the init decoder GPT2 model. "
  2277. "Performance may be sub-optimal for the initial decoding run"
  2278. )
  2279. # Update the graph input shapes for the non-initial decoder model to account
  2280. # for the fact that the sequence length will always be 1
  2281. if gpt2_init_decoder_generated and not update_input_shapes_for_gpt2_decoder_model(
  2282. args.decoder_onnx, args.use_external_data_format
  2283. ):
  2284. # Can't proceed further - better to raise an exception
  2285. raise ValueError("Could not update the input shapes for the non-initial decoder subgraph.")
  2286. # If the user explicitly requests running shape inference or if we padded/mutated
  2287. # weight(s)/input shape(s) in the decoder, we want to run shape inference to capture the new
  2288. # shapes
  2289. if logits_matmul_weight_padded or args.run_shape_inference or gpt2_init_decoder_generated:
  2290. logger.info(f"Run symbolic shape inference on {args.decoder_onnx}. The file will be overwritten.")
  2291. shape_inference(args.decoder_onnx, args.use_external_data_format)
  2292. if gpt2_init_decoder_generated:
  2293. logger.info(f"Run symbolic shape inference on {gpt2_init_decoder_onnx_path}. The file will be overwritten.")
  2294. shape_inference(gpt2_init_decoder_onnx_path, args.use_external_data_format)
  2295. if is_gpt2:
  2296. config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
  2297. elif args.model_type == "t5":
  2298. config = T5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
  2299. else:
  2300. config = MT5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
  2301. if args.verbose:
  2302. logger.info(f"Config={config}")
  2303. eos_token_id = config.eos_token_id
  2304. pad_token_id = config.eos_token_id if is_gpt2 else config.pad_token_id
  2305. vocab_size = config.vocab_size
  2306. # if vocab_size is given in parameters use that.
  2307. if args.vocab_size != -1:
  2308. vocab_size = args.vocab_size
  2309. if args.eos_token_id != -1:
  2310. eos_token_id = args.eos_token_id
  2311. if args.pad_token_id != -1:
  2312. pad_token_id = args.pad_token_id
  2313. decoder_model = onnx.load_model(args.decoder_onnx, load_external_data=True)
  2314. decoder_model.graph.name = f"{args.model_type} decoder"
  2315. gpt2_init_decoder_model = None
  2316. if args.model_type == "gpt2":
  2317. verify_gpt2_subgraph(decoder_model.graph, args.precision)
  2318. # If we generated the init decoder model, verify that as well
  2319. if gpt2_init_decoder_generated:
  2320. gpt2_init_decoder_model = onnx.load_model(gpt2_init_decoder_onnx_path, load_external_data=True)
  2321. gpt2_init_decoder_model.graph.name = f"{args.model_type} init decoder"
  2322. verify_gpt2_subgraph(gpt2_init_decoder_model.graph, args.precision)
  2323. else:
  2324. verify_t5_decoder_subgraph(decoder_model.graph, args.precision)
  2325. inputs = None
  2326. if is_beamsearch:
  2327. inputs = [
  2328. "input_ids",
  2329. "max_length",
  2330. "min_length",
  2331. "num_beams",
  2332. "num_return_sequences",
  2333. "length_penalty",
  2334. "repetition_penalty",
  2335. ]
  2336. elif is_greedysearch or is_sampling:
  2337. inputs = [
  2338. "input_ids",
  2339. "max_length",
  2340. "min_length",
  2341. "repetition_penalty",
  2342. ]
  2343. if args.vocab_mask:
  2344. inputs.append("vocab_mask")
  2345. else:
  2346. inputs.append("")
  2347. if args.prefix_vocab_mask:
  2348. inputs.append("prefix_vocab_mask")
  2349. else:
  2350. inputs.append("")
  2351. if args.custom_attention_mask:
  2352. inputs.append("attention_mask")
  2353. else:
  2354. inputs.append("")
  2355. if is_sampling:
  2356. if args.custom and args.presence_mask:
  2357. inputs.append("presence_mask")
  2358. else:
  2359. inputs.append("")
  2360. if args.seed:
  2361. inputs.append("seed")
  2362. outputs = ["sequences"]
  2363. if args.output_sequences_scores:
  2364. outputs.append("sequences_scores")
  2365. if args.output_token_scores:
  2366. assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores"
  2367. outputs.append("scores")
  2368. node = None
  2369. if is_beamsearch:
  2370. node = onnx.helper.make_node(
  2371. "BeamSearch",
  2372. inputs=inputs,
  2373. outputs=outputs,
  2374. name=f"BeamSearch_{args.model_type}",
  2375. )
  2376. elif is_greedysearch:
  2377. node = onnx.helper.make_node(
  2378. "GreedySearch",
  2379. inputs=inputs,
  2380. outputs=outputs,
  2381. name=f"GreedySearch_{args.model_type}",
  2382. )
  2383. elif is_sampling:
  2384. node = onnx.helper.make_node(
  2385. "Sampling",
  2386. inputs=inputs,
  2387. outputs=outputs,
  2388. name=f"Sampling_{args.model_type}",
  2389. )
  2390. node.domain = "com.microsoft"
  2391. attr_to_extend = None
  2392. if is_beamsearch:
  2393. attr_to_extend = [
  2394. onnx.helper.make_attribute("eos_token_id", eos_token_id),
  2395. onnx.helper.make_attribute("pad_token_id", pad_token_id),
  2396. onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
  2397. onnx.helper.make_attribute("early_stopping", 1 if args.early_stopping else 0),
  2398. onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
  2399. ]
  2400. elif is_greedysearch:
  2401. attr_to_extend = [
  2402. onnx.helper.make_attribute("eos_token_id", eos_token_id),
  2403. onnx.helper.make_attribute("pad_token_id", pad_token_id),
  2404. onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
  2405. onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
  2406. ]
  2407. elif is_sampling:
  2408. attr_to_extend = [
  2409. onnx.helper.make_attribute("eos_token_id", eos_token_id),
  2410. onnx.helper.make_attribute("pad_token_id", pad_token_id),
  2411. onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
  2412. onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
  2413. onnx.helper.make_attribute("temperature", args.temperature),
  2414. onnx.helper.make_attribute("top_p", args.top_p),
  2415. onnx.helper.make_attribute("filter_value", args.filter_value),
  2416. onnx.helper.make_attribute("min_tokens_to_keep", args.min_tokens_to_keep),
  2417. onnx.helper.make_attribute("custom", args.custom),
  2418. onnx.helper.make_attribute("presence_penalty", args.presence_penalty),
  2419. ]
  2420. # Explicitly pass in the vocab size via an attribute
  2421. if logits_matmul_weight_padded:
  2422. attr_to_extend.extend([onnx.helper.make_attribute("vocab_size", vocab_size)])
  2423. node.attribute.extend(attr_to_extend)
  2424. initializers = []
  2425. if args.model_type in ["t5", "mt5"]:
  2426. if args.run_shape_inference:
  2427. logger.info(f"Symbolic shape inference on {args.encoder_decoder_init_onnx}. The file will be overwritten.")
  2428. shape_inference(args.encoder_decoder_init_onnx, args.use_external_data_format)
  2429. encoder_model = onnx.load_model(args.encoder_decoder_init_onnx, load_external_data=True)
  2430. suffix = "encoder" if len(encoder_model.graph.input) == 2 else "encoder and decoder init"
  2431. encoder_model.graph.name = f"{args.model_type} {suffix}"
  2432. verify_t5_encoder_decoder_init_subgraph(encoder_model.graph, args.precision)
  2433. make_dim_proto_numeric_t5(encoder_model, config)
  2434. make_dim_proto_numeric_t5(decoder_model, config)
  2435. # Update decoder subgraph in preparation to use past present share buffer
  2436. if past_present_share_buffer:
  2437. if not args.use_decoder_masked_attention:
  2438. raise ValueError("past_present_share_buffer is only supported with use_decoder_masked_attention")
  2439. logger.info(
  2440. "*****update t5 decoder subgraph to share past/present buffer and use decoder_masked_multihead_attention*****"
  2441. )
  2442. if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
  2443. logger.info("*****update t5 decoder subgraph successfully!!!*****")
  2444. else:
  2445. logger.info("*****DecoderMaskedMultiHeadAttention is not applied to T5 decoder*****")
  2446. if pack_qkv_for_decoder_masked_mha(decoder_model):
  2447. logger.info("*****pack qkv for decoder masked mha successfully!!!*****")
  2448. else:
  2449. logger.info("*****pack qkv for decoder masked mha failed!!!*****")
  2450. if not args.disable_shared_initializers:
  2451. # Unique shared initializers from the decoder and decoder_init could reduce memory usage in inference.
  2452. initializers = get_shared_initializers(encoder_model, decoder_model)
  2453. logger.info(
  2454. f"{len(initializers)} shared initializers ({[i.name for i in initializers]}) in encoder and decoder subgraphs are moved to the main graph"
  2455. )
  2456. # TODO(tianleiwu): investigate the following which causes error in inference
  2457. # Move initializer from subgraph to main graph could reduce memory usage in inference.
  2458. # moved_initializers = move_initializers(encoder_model.graph)
  2459. # logger.info(
  2460. # f"{len(moved_initializers)} initializers ({[i.name for i in moved_initializers]}) from the encoder are moved to the main graph"
  2461. # )
  2462. # initializers.extend(moved_initializers)
  2463. assert config.decoder_start_token_id >= 0, "decoder_start_token_id should be >= 0"
  2464. node.attribute.extend(
  2465. [
  2466. onnx.helper.make_attribute("encoder", encoder_model.graph),
  2467. onnx.helper.make_attribute("decoder", decoder_model.graph),
  2468. onnx.helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id),
  2469. ]
  2470. )
  2471. else:
  2472. if gpt2_init_decoder_generated:
  2473. # Move shared initializers (shared between init decoder and decoder models) to the main
  2474. # graph and remove them from these models
  2475. if not args.disable_shared_initializers:
  2476. # Unique shared initializers from the decoder and decoder_init could reduce memory usage in inference.
  2477. initializers = get_shared_initializers(gpt2_init_decoder_model, decoder_model)
  2478. logger.info(
  2479. f"{len(initializers)} shared initializers ({[i.name for i in initializers]}) in decoder and init decoder subgraphs are moved to the main graph"
  2480. )
  2481. # Update init decoder subgraph in preparation to use past present share buffer
  2482. if past_present_share_buffer:
  2483. logger.info("*****update init decoder subgraph to make past and present share buffer******************")
  2484. update_decoder_subgraph_past_present_share_buffer(gpt2_init_decoder_model.graph)
  2485. # Update init decoder subgraph in preparation to use DecoderMaskedSelfAttention
  2486. # NOTE: Even if we will not use DecoderMaskedSelfAttention in the init decoder subgraph
  2487. # it makes the runtime changes cleaner if we keep both the init decoder and decoder subgraphs
  2488. # same in terms of the subgraph inputs.
  2489. if args.use_decoder_masked_attention and not update_decoder_subgraph_use_decoder_masked_attention(
  2490. gpt2_init_decoder_model.graph, is_beamsearch, False
  2491. ):
  2492. raise ValueError("Could not update the init decoder subgraph to use DecoderMaskedSelfAttention")
  2493. node.attribute.append(onnx.helper.make_attribute("init_decoder", gpt2_init_decoder_model.graph))
  2494. else:
  2495. # Move initializer from subgraph to main graph could reduce memory usage in inference.
  2496. initializers = move_initializers(decoder_model.graph)
  2497. logger.info(f"{len(initializers)} initializers from the decoder are moved to the main graph")
  2498. # Update decoder subgraph in preparation to use past present share buffer
  2499. if past_present_share_buffer:
  2500. logger.info("*****update decoder subgraph to make past and present share buffer******************")
  2501. update_decoder_subgraph_past_present_share_buffer(decoder_model.graph)
  2502. # Update decoder subgraph in preparation to use DecoderMaskedSelfAttention
  2503. if args.use_decoder_masked_attention and not update_decoder_subgraph_use_decoder_masked_attention(
  2504. decoder_model.graph, is_beamsearch, True
  2505. ):
  2506. raise ValueError("Could not update the decoder subgraph to use DecoderMaskedSelfAttention")
  2507. node.attribute.append(onnx.helper.make_attribute("decoder", decoder_model.graph))
  2508. # graph inputs
  2509. input_ids = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])
  2510. max_length = onnx.helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
  2511. min_length = onnx.helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
  2512. num_beams = onnx.helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1])
  2513. num_return_sequences = onnx.helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
  2514. length_penalty = onnx.helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
  2515. repetition_penalty = onnx.helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
  2516. graph_inputs = None
  2517. if is_beamsearch:
  2518. graph_inputs = [
  2519. input_ids,
  2520. max_length,
  2521. min_length,
  2522. num_beams,
  2523. num_return_sequences,
  2524. length_penalty,
  2525. repetition_penalty,
  2526. ]
  2527. elif is_greedysearch or is_sampling:
  2528. graph_inputs = [
  2529. input_ids,
  2530. max_length,
  2531. min_length,
  2532. repetition_penalty,
  2533. ]
  2534. if args.vocab_mask:
  2535. vocab_mask = onnx.helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size])
  2536. graph_inputs.append(vocab_mask)
  2537. if args.prefix_vocab_mask:
  2538. prefix_vocab_mask = onnx.helper.make_tensor_value_info(
  2539. "prefix_vocab_mask", TensorProto.INT32, ["batch_size", vocab_size]
  2540. )
  2541. graph_inputs.append(prefix_vocab_mask)
  2542. if args.custom_attention_mask:
  2543. attention_mask = onnx.helper.make_tensor_value_info(
  2544. "attention_mask", TensorProto.INT32, ["batch_size", "sequence_length"]
  2545. )
  2546. graph_inputs.append(attention_mask)
  2547. if args.custom and args.presence_mask:
  2548. presence_mask = onnx.helper.make_tensor_value_info(
  2549. "presence_mask", TensorProto.INT32, ["batch_size", vocab_size]
  2550. )
  2551. graph_inputs.append(presence_mask)
  2552. if is_sampling and args.seed:
  2553. seed = onnx.helper.make_tensor_value_info("seed", TensorProto.INT32, [1])
  2554. graph_inputs.append(seed)
  2555. # graph outputs
  2556. sequences = None
  2557. if is_beamsearch:
  2558. sequences = onnx.helper.make_tensor_value_info(
  2559. "sequences",
  2560. TensorProto.INT32,
  2561. ["batch_size", "num_return_sequences", "max_length"],
  2562. )
  2563. elif is_greedysearch or is_sampling:
  2564. sequences = onnx.helper.make_tensor_value_info(
  2565. "sequences",
  2566. TensorProto.INT32,
  2567. ["batch_size", "max_length"],
  2568. )
  2569. graph_outputs = [sequences]
  2570. if args.output_sequences_scores:
  2571. sequences_scores = onnx.helper.make_tensor_value_info(
  2572. "sequences_scores",
  2573. TensorProto.FLOAT,
  2574. ["batch_size", "num_return_sequences"],
  2575. )
  2576. graph_outputs.append(sequences_scores)
  2577. if args.output_token_scores:
  2578. scores = onnx.helper.make_tensor_value_info(
  2579. "scores",
  2580. TensorProto.FLOAT,
  2581. ["max_length - sequence_length", "batch_size", "num_beams", vocab_size],
  2582. )
  2583. graph_outputs.append(scores)
  2584. new_graph = onnx.helper.make_graph(
  2585. [node],
  2586. (f"{args.model_type} beam search" if not is_greedysearch else f"{args.model_type} greedy search"),
  2587. graph_inputs,
  2588. graph_outputs,
  2589. initializers,
  2590. )
  2591. # Create the model
  2592. new_model = onnx.helper.make_model(
  2593. new_graph,
  2594. producer_name="onnxruntime.transformers",
  2595. opset_imports=decoder_model.opset_import,
  2596. )
  2597. # TODO(tianleiwu): move shared initializers from T5 encoder and decoder subgraphs to parent graph to save memory.
  2598. if args.use_external_data_format:
  2599. from packaging import version # noqa: PLC0415
  2600. if version.parse(onnx.__version__) < version.parse("1.12.0"):
  2601. logger.warning("Require onnx >= 1.12 to save large (>2GB) model!")
  2602. OnnxModel.save(
  2603. new_model,
  2604. args.output,
  2605. save_as_external_data=True,
  2606. all_tensors_to_one_file=True,
  2607. )
  2608. else:
  2609. onnx.save(new_model, args.output)
  2610. logger.info(f"model save to {args.output}")
  2611. def test_torch_performance(
  2612. args: argparse.Namespace,
  2613. model: GPT2LMHeadModel | T5ForConditionalGeneration,
  2614. input_ids: torch.Tensor,
  2615. attention_mask: torch.Tensor,
  2616. eos_token_id: int,
  2617. pad_token_id: int,
  2618. bad_words_ids: list[list[int]],
  2619. ) -> dict[str, Any]:
  2620. """Test PyTorch performance of text generation.
  2621. Args:
  2622. args (argparse.Namespace): arguments parsed from command line
  2623. model (Union[GPT2LMHeadModel, T5ForConditionalGeneration]): PyTorch model
  2624. input_ids (torch.Tensor): input_ids
  2625. attention_mask (torch.Tensor): Attention mask
  2626. eos_token_id (int): EOS token ID
  2627. pad_token_id (int): Padding token ID
  2628. bad_words_ids (List[List[int]]): Words shall not be generated.
  2629. Raises:
  2630. RuntimeError: PyTorch with CUDA is not available for --use_gpu
  2631. Returns:
  2632. Dict[str, Any]: A dictionary with string with metric name, and value can be integer or string.
  2633. """
  2634. if args.use_gpu and not torch.cuda.is_available():
  2635. raise RuntimeError("Please install PyTorch with Cuda for testing gpu performance.")
  2636. if args.precision == Precision.FLOAT16.value:
  2637. model.half()
  2638. device = torch.device("cuda:0" if args.use_gpu else "cpu")
  2639. model.to(device)
  2640. torch.set_grad_enabled(False)
  2641. input_ids = input_ids.to(device)
  2642. attention_mask = attention_mask.to(device)
  2643. torch_latency = []
  2644. for _ in range(args.total_runs):
  2645. start = time.time()
  2646. _ = model.generate(
  2647. input_ids=input_ids,
  2648. attention_mask=attention_mask,
  2649. max_length=args.max_length,
  2650. min_length=args.min_length,
  2651. num_beams=args.num_beams,
  2652. early_stopping=args.early_stopping,
  2653. no_repeat_ngram_size=args.no_repeat_ngram_size,
  2654. eos_token_id=eos_token_id,
  2655. pad_token_id=pad_token_id,
  2656. num_return_sequences=args.num_return_sequences,
  2657. length_penalty=args.length_penalty,
  2658. repetition_penalty=args.repetition_penalty,
  2659. bad_words_ids=bad_words_ids if bad_words_ids else None,
  2660. return_dict_in_generate=True,
  2661. output_scores=args.output_sequences_scores or args.output_token_scores,
  2662. )
  2663. torch_latency.append(time.time() - start)
  2664. batch_size = input_ids.shape[0]
  2665. from benchmark_helper import get_latency_result # noqa: PLC0415
  2666. return get_latency_result(torch_latency, batch_size)
  2667. def create_attention_mask(input_ids, pad_token_id):
  2668. attention_mask = np.ones(input_ids.shape, dtype=np.int32)
  2669. for i in range(input_ids.shape[0]):
  2670. abs_pos = 0
  2671. for j in range(input_ids.shape[1]):
  2672. if input_ids[i][j] == pad_token_id and abs_pos == 0:
  2673. attention_mask[i][j] = 0
  2674. else:
  2675. abs_pos += 1
  2676. return attention_mask
  2677. def test_gpt_model(
  2678. args: argparse.Namespace,
  2679. sentences: list[str] | None = None,
  2680. is_greedy: bool = False,
  2681. ):
  2682. """Test GPT-2 model
  2683. Args:
  2684. args (argparse.Namespace): arguments parsed from command line
  2685. sentences (Optional[List[str]], optional): input text. Defaults to None.
  2686. Returns:
  2687. Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string.
  2688. """
  2689. assert args.model_type == "gpt2"
  2690. tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
  2691. tokenizer.padding_side = "left"
  2692. tokenizer.pad_token = tokenizer.eos_token
  2693. model = GPT2LMHeadModel.from_pretrained(
  2694. args.model_name_or_path,
  2695. cache_dir=args.cache_dir,
  2696. pad_token_id=tokenizer.eos_token_id,
  2697. )
  2698. # Use different length sentences to test batching
  2699. if sentences is None:
  2700. sentences = [
  2701. "The product is released",
  2702. "I enjoy walking in the park",
  2703. "Test best way to invest",
  2704. ]
  2705. inputs = tokenizer(sentences, return_tensors="pt", padding=True)
  2706. input_ids = inputs["input_ids"]
  2707. attention_mask = inputs["attention_mask"]
  2708. bad_words = "walk in park"
  2709. bad_words_ids = tokenizer.encode(bad_words, add_prefix_space=True)
  2710. bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list
  2711. if args.vocab_mask:
  2712. logger.debug("bad_words_ids", bad_words_ids) # noqa: PLE1205
  2713. else:
  2714. bad_words_ids = []
  2715. config = model.config
  2716. eos_token_id = config.eos_token_id
  2717. pad_token_id = config.eos_token_id
  2718. vocab_size = config.vocab_size
  2719. torch_decoded_sequences = []
  2720. beam_outputs = None
  2721. if not args.disable_parity:
  2722. print("-" * 50)
  2723. print("Test PyTorch model and beam search with huggingface transformers...")
  2724. beam_outputs = model.generate(
  2725. input_ids=input_ids,
  2726. attention_mask=attention_mask,
  2727. max_length=args.max_length,
  2728. min_length=args.min_length,
  2729. num_beams=args.num_beams,
  2730. early_stopping=args.early_stopping,
  2731. no_repeat_ngram_size=args.no_repeat_ngram_size,
  2732. eos_token_id=eos_token_id,
  2733. pad_token_id=pad_token_id,
  2734. num_return_sequences=args.num_return_sequences,
  2735. length_penalty=args.length_penalty,
  2736. repetition_penalty=args.repetition_penalty,
  2737. bad_words_ids=bad_words_ids if bad_words_ids else None,
  2738. return_dict_in_generate=True,
  2739. output_scores=args.output_sequences_scores or args.output_token_scores,
  2740. )
  2741. print("input_ids", input_ids)
  2742. print("huggingface transformers outputs:")
  2743. print("sequences", beam_outputs.sequences)
  2744. if args.output_sequences_scores:
  2745. print("sequences_scores", beam_outputs.sequences_scores)
  2746. if args.output_token_scores:
  2747. print("scores", beam_outputs.scores)
  2748. for i, sequence in enumerate(beam_outputs.sequences):
  2749. decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True)
  2750. torch_decoded_sequences.append(decoded_sequence)
  2751. print(f"{i}: {decoded_sequence}")
  2752. print("-" * 50)
  2753. print("Testing beam search with onnxruntime...")
  2754. if is_greedy:
  2755. inputs = {
  2756. "input_ids": input_ids.cpu().numpy().astype(np.int32),
  2757. "max_length": np.array([args.max_length], dtype=np.int32),
  2758. "min_length": np.array([args.min_length], dtype=np.int32),
  2759. "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32),
  2760. }
  2761. else:
  2762. inputs = {
  2763. "input_ids": input_ids.cpu().numpy().astype(np.int32),
  2764. "max_length": np.array([args.max_length], dtype=np.int32),
  2765. "min_length": np.array([args.min_length], dtype=np.int32),
  2766. "num_beams": np.array([args.num_beams], dtype=np.int32),
  2767. "num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32),
  2768. "length_penalty": np.array([args.length_penalty], dtype=np.float32),
  2769. "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32),
  2770. }
  2771. if args.vocab_mask:
  2772. vocab_mask = np.ones((vocab_size), dtype=np.int32)
  2773. if args.vocab_mask:
  2774. for bad_word_id in bad_words_ids:
  2775. vocab_mask[bad_word_id] = 0
  2776. inputs["vocab_mask"] = vocab_mask
  2777. if args.custom_attention_mask:
  2778. inputs["attention_mask"] = create_attention_mask(input_ids, pad_token_id)
  2779. batch_size = input_ids.shape[0]
  2780. if args.prefix_vocab_mask:
  2781. logger.info("Use prefix vocab mask with all ones in ORT, but no corresponding setting for Torch model.")
  2782. prefix_vocab_mask = np.ones((batch_size, vocab_size), dtype=np.int32)
  2783. inputs["prefix_vocab_mask"] = prefix_vocab_mask
  2784. if args.save_test_data:
  2785. test_data_dir = Path(args.output).parent.as_posix()
  2786. logger.debug("test_data_dir", test_data_dir) # noqa: PLE1205
  2787. from bert_test_data import output_test_data # noqa: PLC0415
  2788. logger.info(f"Saving test_data to {test_data_dir}/test_data_set_* ...")
  2789. all_inputs = [inputs]
  2790. for i, inputs in enumerate(all_inputs):
  2791. dir = os.path.join(test_data_dir, "test_data_set_" + str(i))
  2792. output_test_data(dir, inputs)
  2793. logger.debug("ORT inputs", inputs) # noqa: PLE1205
  2794. if args.disable_perf_test:
  2795. return
  2796. logger.debug("Creating ort session......")
  2797. ort_session = create_ort_session(args.output, args.use_gpu, args.use_sln_strict_mode)
  2798. logger.debug("Run ort session......")
  2799. result = ort_session.run(None, inputs)
  2800. # Test performance
  2801. latency = []
  2802. for _ in range(args.total_runs):
  2803. start = time.time()
  2804. _ = ort_session.run(None, inputs)
  2805. latency.append(time.time() - start)
  2806. from benchmark_helper import get_latency_result # noqa: PLC0415
  2807. batch_size = input_ids.shape[0]
  2808. output = get_latency_result(latency, batch_size)
  2809. print("ORT outputs:")
  2810. sequences = result[0]
  2811. print("sequences", sequences)
  2812. if args.output_sequences_scores:
  2813. print("sequences_scores", result[1])
  2814. if args.output_token_scores:
  2815. print("scores", result[2])
  2816. if is_greedy:
  2817. (batch_size, max_length) = sequences.shape
  2818. ort_decoded_sequences = []
  2819. for i in range(batch_size):
  2820. decoded_sequence = tokenizer.decode(sequences[i], skip_special_tokens=True)
  2821. ort_decoded_sequences.append(decoded_sequence)
  2822. print(f"batch {i} sequence: {decoded_sequence}")
  2823. else:
  2824. (batch_size, num_sequences, max_length) = sequences.shape
  2825. ort_decoded_sequences = []
  2826. for i in range(batch_size):
  2827. for j in range(num_sequences):
  2828. decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True)
  2829. ort_decoded_sequences.append(decoded_sequence)
  2830. print(f"batch {i} sequence {j}: {decoded_sequence}")
  2831. if beam_outputs:
  2832. torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1)
  2833. ort_sequences = torch.LongTensor(sequences)
  2834. print("-" * 50)
  2835. print("Torch Sequences:")
  2836. print(torch_sequences)
  2837. print(torch_decoded_sequences)
  2838. print("-" * 50)
  2839. print("ORT Sequences:")
  2840. print(ort_sequences)
  2841. print(ort_decoded_sequences)
  2842. print("-" * 50)
  2843. # Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not.
  2844. is_same = torch_decoded_sequences == ort_decoded_sequences
  2845. print("Torch and ORT result is", "same" if is_same else "different")
  2846. output["parity"] = is_same
  2847. if args.torch_performance:
  2848. torch_latency_output = test_torch_performance(
  2849. args,
  2850. model,
  2851. input_ids,
  2852. attention_mask,
  2853. eos_token_id,
  2854. pad_token_id,
  2855. bad_words_ids,
  2856. )
  2857. print("Torch Latency", torch_latency_output)
  2858. print("ORT", output)
  2859. return output
  2860. def test_t5_model(args: argparse.Namespace, sentences: list[str] | None = None):
  2861. """Test T5 or MT5 model
  2862. Args:
  2863. args (argparse.Namespace): arguments parsed from command line
  2864. sentences (Optional[List[str]], optional): input text. Defaults to None.
  2865. Returns:
  2866. Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string.
  2867. """
  2868. assert args.model_type in ["t5", "mt5"]
  2869. if args.prefix_vocab_mask:
  2870. logger.debug("Skipping parity test as prefix vocab mask is not implemented by Hugging Face")
  2871. return None
  2872. tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
  2873. tokenizer.padding_side = "left"
  2874. if args.model_type == "t5":
  2875. model = T5ForConditionalGeneration.from_pretrained(
  2876. args.model_name_or_path,
  2877. cache_dir=args.cache_dir,
  2878. )
  2879. else:
  2880. model = MT5ForConditionalGeneration.from_pretrained(
  2881. args.model_name_or_path,
  2882. cache_dir=args.cache_dir,
  2883. )
  2884. # Use different length sentences to test batching
  2885. if sentences is None:
  2886. sentences = [
  2887. "translate English to French: The product is released",
  2888. "summarize: research continues to show that pets bring real health benefits to their owners. Having a dog around can lead to lower levels of stress for both adults and kids.",
  2889. # "summarize: I enjoy walking in the park. It makes my mind feel calm and refreshed. "
  2890. # + "I enjoy looking at the trees, flowers, and wildlife around me, and listening to sound from natural.",
  2891. ]
  2892. inputs = tokenizer(sentences, return_tensors="pt", padding=True)
  2893. input_ids = inputs["input_ids"]
  2894. attention_mask = inputs["attention_mask"]
  2895. bad_words = "walk in park"
  2896. bad_words_ids = tokenizer.encode(bad_words)[:-1] # exclude the last token (EOS)
  2897. bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list
  2898. if args.vocab_mask:
  2899. logger.debug("bad_words_ids", bad_words_ids) # noqa: PLE1205
  2900. else:
  2901. bad_words_ids = []
  2902. config = model.config
  2903. eos_token_id = config.eos_token_id
  2904. pad_token_id = config.pad_token_id
  2905. vocab_size = config.vocab_size
  2906. logger.debug(f"eos_token_id:{eos_token_id}, pad_token_id:{pad_token_id}, vocab_size:{vocab_size}")
  2907. torch_decoded_sequences = []
  2908. if not args.disable_parity:
  2909. print("-" * 50)
  2910. print("Test PyTorch model and beam search with huggingface transformers...")
  2911. beam_outputs = model.generate(
  2912. input_ids=input_ids,
  2913. attention_mask=attention_mask,
  2914. max_length=args.max_length,
  2915. min_length=args.min_length,
  2916. num_beams=args.num_beams,
  2917. early_stopping=args.early_stopping,
  2918. no_repeat_ngram_size=args.no_repeat_ngram_size,
  2919. eos_token_id=eos_token_id,
  2920. pad_token_id=pad_token_id,
  2921. num_return_sequences=args.num_return_sequences,
  2922. length_penalty=args.length_penalty,
  2923. repetition_penalty=args.repetition_penalty,
  2924. bad_words_ids=bad_words_ids if bad_words_ids else None,
  2925. return_dict_in_generate=True,
  2926. output_scores=args.output_sequences_scores or args.output_token_scores,
  2927. )
  2928. print("input_ids", input_ids)
  2929. print("huggingface transformers outputs:")
  2930. print("sequences", beam_outputs.sequences)
  2931. if args.output_sequences_scores:
  2932. print("sequences_scores", beam_outputs.sequences_scores)
  2933. if args.output_token_scores:
  2934. print("scores", beam_outputs.scores)
  2935. for i, sequence in enumerate(beam_outputs.sequences):
  2936. decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True)
  2937. torch_decoded_sequences.append(decoded_sequence)
  2938. print(f"{i}: {decoded_sequence}")
  2939. print("-" * 50)
  2940. print("Testing beam search with onnxruntime...")
  2941. vocab_mask = np.ones((vocab_size), dtype=np.int32)
  2942. if args.vocab_mask:
  2943. for bad_word_id in bad_words_ids:
  2944. vocab_mask[bad_word_id] = 0
  2945. inputs = {
  2946. "input_ids": input_ids.cpu().numpy().astype(np.int32),
  2947. "max_length": np.array([args.max_length], dtype=np.int32),
  2948. "min_length": np.array([args.min_length], dtype=np.int32),
  2949. "num_beams": np.array([args.num_beams], dtype=np.int32),
  2950. "num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32),
  2951. "length_penalty": np.array([args.length_penalty], dtype=np.float32),
  2952. "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32),
  2953. }
  2954. if args.vocab_mask:
  2955. inputs["vocab_mask"] = vocab_mask
  2956. if args.custom_attention_mask:
  2957. inputs["attention_mask"] = create_attention_mask(input_ids, pad_token_id)
  2958. if args.save_test_data:
  2959. test_data_dir = Path(args.output).parent.as_posix()
  2960. logger.debug("test_data_dir", test_data_dir) # noqa: PLE1205
  2961. from bert_test_data import output_test_data # noqa: PLC0415
  2962. all_inputs = [inputs]
  2963. for i, inputs in enumerate(all_inputs):
  2964. dir = os.path.join(test_data_dir, "test_data_set_" + str(i))
  2965. output_test_data(dir, inputs)
  2966. logger.debug("ORT inputs", inputs) # noqa: PLE1205
  2967. ort_session = create_ort_session(args.output, args.use_gpu, args.use_sln_strict_mode)
  2968. # Test performance
  2969. latency = []
  2970. for _ in range(args.total_runs):
  2971. start = time.time()
  2972. result = ort_session.run(None, inputs)
  2973. latency.append(time.time() - start)
  2974. batch_size = input_ids.shape[0]
  2975. from benchmark_helper import get_latency_result # noqa: PLC0415
  2976. output = get_latency_result(latency, batch_size)
  2977. print("ORT outputs:")
  2978. sequences = result[0]
  2979. print("sequences", sequences)
  2980. if args.output_sequences_scores:
  2981. print("sequences_scores", result[1])
  2982. if args.output_token_scores:
  2983. print("scores", result[2])
  2984. (batch_size, num_sequences, max_length) = sequences.shape
  2985. ort_decoded_sequences = []
  2986. for i in range(batch_size):
  2987. for j in range(num_sequences):
  2988. decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True)
  2989. ort_decoded_sequences.append(decoded_sequence)
  2990. print(f"batch {i} sequence {j}: {decoded_sequence}")
  2991. if not args.disable_parity:
  2992. torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1)
  2993. ort_sequences = torch.LongTensor(sequences)
  2994. print("-" * 50)
  2995. print("Torch Sequences:")
  2996. print(torch_sequences)
  2997. print(torch_decoded_sequences)
  2998. print("-" * 50)
  2999. print("ORT Sequences:")
  3000. print(ort_sequences)
  3001. print(ort_decoded_sequences)
  3002. print("-" * 50)
  3003. # Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not.
  3004. is_same = torch_decoded_sequences == ort_decoded_sequences
  3005. print("Torch and ORT result is ", "same" if is_same else "different")
  3006. output["parity"] = is_same
  3007. if args.torch_performance:
  3008. torch_latency_output = test_torch_performance(
  3009. args,
  3010. model,
  3011. input_ids,
  3012. attention_mask,
  3013. eos_token_id,
  3014. pad_token_id,
  3015. bad_words_ids,
  3016. )
  3017. print("Torch Latency", torch_latency_output)
  3018. print("ORT", output)
  3019. return output
  3020. def main(argv: list[str] | None = None, sentences: list[str] | None = None):
  3021. """Main entry function
  3022. Args:
  3023. argv (Optional[List[str]], optional): _description_. Defaults to None.
  3024. sentences (Optional[List[str]], optional): input text. Defaults to None.
  3025. Raises:
  3026. ValueError: Path does not exist: --encoder_decoder_init_onnx
  3027. ValueError: Path does not exist: --decoder_onnx
  3028. ValueError: --decoder_onnx and --encoder_decoder_init_onnx are not used together for T5
  3029. Returns:
  3030. Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string.
  3031. """
  3032. args = parse_arguments(argv)
  3033. setup_logger(args.verbose)
  3034. if args.model_type in ["t5", "mt5"]:
  3035. if args.encoder_decoder_init_onnx and not os.path.exists(args.encoder_decoder_init_onnx):
  3036. raise ValueError(f"Path does not exist: --encoder_decoder_init_onnx {args.encoder_decoder_init_onnx}")
  3037. if args.decoder_onnx and not os.path.exists(args.decoder_onnx):
  3038. raise ValueError(f"Path does not exist: --decoder_onnx {args.decoder_onnx}")
  3039. if (args.encoder_decoder_init_onnx and not args.decoder_onnx) or (
  3040. args.decoder_onnx and not args.encoder_decoder_init_onnx
  3041. ):
  3042. raise ValueError("--decoder_onnx shall use together with --encoder_decoder_init_onnx")
  3043. is_greedy = args.num_beams == 1 and args.num_return_sequences == 1
  3044. if args.model_type == "gpt2" and is_greedy:
  3045. if args.top_p > 0.0 and args.top_p < 1.0:
  3046. convert_generation_model(args, GenerationType.SAMPLING)
  3047. logger.info(
  3048. "The test for gpt2_sampling onnx model is limited to non-custom model with small top_p(e.g <=0.01) value. The result should be the same as gpt2 greedy search."
  3049. )
  3050. if args.top_p > 0.01 or args.custom or args.seed:
  3051. return
  3052. else:
  3053. convert_generation_model(args, GenerationType.GREEDYSEARCH)
  3054. else:
  3055. convert_generation_model(args)
  3056. logger.info("start testing model...")
  3057. if args.model_type in ["t5", "mt5"]:
  3058. result = test_t5_model(args, sentences=sentences)
  3059. else:
  3060. result = test_gpt_model(args, sentences=sentences, is_greedy=is_greedy)
  3061. if result:
  3062. if args.use_external_data_format:
  3063. logger.info(f"Output files: {args.output}, {args.output}.data")
  3064. else:
  3065. logger.info(f"Output file: {args.output}")
  3066. return result
  3067. if __name__ == "__main__":
  3068. main()