onnx_exporter.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import logging
  7. import os
  8. from pathlib import Path
  9. import numpy
  10. import torch
  11. from affinity_helper import AffinitySetting
  12. from benchmark_helper import OptimizerInfo, Precision, create_onnxruntime_session
  13. from huggingface_models import MODEL_CLASSES
  14. from quantize_helper import QuantizeHelper
  15. from torch_onnx_export_helper import torch_onnx_export
  16. from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, LxmertConfig, TransfoXLConfig
  17. from onnxruntime.transformers.models.gpt2.gpt2_helper import (
  18. PRETRAINED_GPT2_MODELS,
  19. GPT2ModelNoPastState,
  20. TFGPT2ModelNoPastState,
  21. )
  22. os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
  23. logger = logging.getLogger(__name__)
  24. # Workaround by replacing torch.triu using self-defined op
  25. # Since torch.triu cannot be exported to ONNX. See https://github.com/pytorch/pytorch/issues/32968
  26. torch_func = {"triu": torch.triu}
  27. def triu_onnx(x, diagonal=0, out=None):
  28. assert out is None
  29. assert len(x.shape) == 2 and x.size(0) == x.size(1)
  30. torch_triu = torch_func["triu"]
  31. template = torch_triu(torch.ones((1024, 1024), dtype=torch.uint8), diagonal)
  32. mask = template[: x.size(0), : x.size(1)]
  33. return torch.where(mask.bool(), x, torch.zeros_like(x))
  34. def replace_torch_functions():
  35. torch.triu = triu_onnx
  36. def restore_torch_functions():
  37. torch.triu = torch_func["triu"]
  38. def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names, config, data_type=numpy.int64):
  39. if config.model_type in ["vit", "swin"]:
  40. input_ids = numpy.random.rand(batch_size, 3, config.image_size, config.image_size).astype(numpy.float32)
  41. inputs = {"pixel_values": input_ids}
  42. return inputs
  43. input_ids = numpy.random.randint(low=0, high=vocab_size - 1, size=(batch_size, sequence_length), dtype=data_type)
  44. inputs = {"input_ids": input_ids}
  45. if "attention_mask" in input_names:
  46. attention_mask = numpy.ones([batch_size, sequence_length], dtype=data_type)
  47. inputs["attention_mask"] = attention_mask
  48. if "token_type_ids" in input_names:
  49. segment_ids = numpy.zeros([batch_size, sequence_length], dtype=data_type)
  50. inputs["token_type_ids"] = segment_ids
  51. if config.is_encoder_decoder:
  52. inputs["decoder_input_ids"] = input_ids
  53. if isinstance(config, LxmertConfig):
  54. inputs["visual_feats"] = numpy.random.randn(1, 1, config.visual_feat_dim).astype(numpy.float32)
  55. inputs["visual_pos"] = numpy.random.randn(1, 1, config.visual_pos_dim).astype(numpy.float32)
  56. if isinstance(config, TransfoXLConfig):
  57. inputs["tf_transfo_xl_model/transformer/pos_emb/einsum/Einsum/inputs_1:0"] = numpy.zeros(
  58. [config.hidden_size], dtype=numpy.float32
  59. )
  60. return inputs
  61. def filter_inputs(inputs, input_names):
  62. remaining_model_inputs = {}
  63. for input_name in input_names:
  64. if input_name in inputs:
  65. remaining_model_inputs[input_name] = inputs[input_name]
  66. return remaining_model_inputs
  67. def flatten(inputs):
  68. return [[flatten(i) for i in inputs] if isinstance(inputs, (list, tuple)) else inputs]
  69. def update_flatten_list(inputs, res_list):
  70. for i in inputs:
  71. res_list.append(i) if not isinstance(i, (list, tuple)) else update_flatten_list(i, res_list)
  72. return res_list
  73. def build_dynamic_axes(example_inputs, outputs_flatten):
  74. sequence_length = example_inputs["input_ids"].shape[-1]
  75. dynamic_axes = {key: {0: "batch_size", 1: "seq_len"} for key in example_inputs}
  76. output_names = ["output_" + str(i + 1) for i in range(len(outputs_flatten))]
  77. for i, output_name in enumerate(output_names):
  78. dynamic_axes[output_name] = {0: "batch_size"}
  79. dims = outputs_flatten[i].shape
  80. for j, dim in enumerate(dims):
  81. if dim == sequence_length:
  82. dynamic_axes[output_name].update({j: "seq_len"})
  83. return dynamic_axes, output_names
  84. def validate_onnx_model(
  85. onnx_model_path,
  86. example_inputs,
  87. example_outputs_flatten,
  88. use_gpu,
  89. fp16,
  90. output_names=None,
  91. ):
  92. test_session = create_onnxruntime_session(onnx_model_path, use_gpu, enable_all_optimization=False)
  93. if test_session is None:
  94. logger.error(f"{onnx_model_path} is an invalid ONNX model")
  95. return False
  96. logger.info(f"{onnx_model_path} is a valid ONNX model")
  97. # Compare the inference result with PyTorch or Tensorflow
  98. example_ort_inputs = {k: t.numpy() for k, t in example_inputs.items()}
  99. example_ort_outputs = test_session.run(output_names, example_ort_inputs)
  100. if len(example_outputs_flatten) != len(example_ort_outputs):
  101. logger.error(
  102. f"Number of output tensors expected {len(example_outputs_flatten)}, got {len(example_ort_outputs)}"
  103. )
  104. return False
  105. for i in range(len(example_outputs_flatten)):
  106. abs_diff = numpy.amax(numpy.abs(example_ort_outputs[i] - example_outputs_flatten[i].cpu().numpy()))
  107. if abs_diff > 1e-4:
  108. logger.info(f"Max absolute diff={abs_diff} for output tensor {i}")
  109. rtol = 5e-02 if fp16 else 1e-4
  110. atol = 1e-01 if fp16 else 1e-4
  111. if not numpy.allclose(
  112. example_ort_outputs[i],
  113. example_outputs_flatten[i].cpu().numpy(),
  114. rtol=rtol,
  115. atol=atol,
  116. ):
  117. logger.error(f"Output tensor {i} is not close: rtol={rtol}, atol={atol}")
  118. return False
  119. logger.info(f"inference result of onnxruntime is validated on {onnx_model_path}")
  120. return True
  121. def get_onnx_file_path(
  122. onnx_dir: str,
  123. model_name: str,
  124. input_count: int,
  125. optimized_by_script: bool,
  126. use_gpu: bool,
  127. precision: Precision,
  128. optimized_by_onnxruntime: bool,
  129. use_external_data: bool,
  130. ):
  131. from re import sub # noqa: PLC0415
  132. normalized_model_name = sub(r"[^a-zA-Z0-9_]", "_", model_name)
  133. if not optimized_by_script:
  134. filename = f"{normalized_model_name}_{input_count}"
  135. else:
  136. device = "gpu" if use_gpu else "cpu"
  137. filename = f"{normalized_model_name}_{input_count}_{precision}_{device}"
  138. if optimized_by_onnxruntime:
  139. filename += "_ort"
  140. directory = onnx_dir
  141. # ONNXRuntime will not write external data so the raw and optimized models shall be in same directory.
  142. if use_external_data and not optimized_by_onnxruntime:
  143. directory = os.path.join(onnx_dir, filename)
  144. if not os.path.exists(directory):
  145. os.makedirs(directory)
  146. return os.path.join(directory, f"{filename}.onnx")
  147. def add_filename_suffix(file_path: str, suffix: str) -> str:
  148. """
  149. Append a suffix at the filename (before the extension).
  150. Args:
  151. path: pathlib.Path The actual path object we would like to add a suffix
  152. suffix: The suffix to add
  153. Returns: path with suffix appended at the end of the filename and before extension
  154. """
  155. path = Path(file_path)
  156. return str(path.parent.joinpath(path.stem + suffix).with_suffix(path.suffix))
  157. def optimize_onnx_model_by_ort(onnx_model_path, ort_model_path, use_gpu, overwrite, model_fusion_statistics):
  158. if overwrite or not os.path.exists(ort_model_path):
  159. Path(ort_model_path).parent.mkdir(parents=True, exist_ok=True)
  160. from optimizer import get_fusion_statistics, optimize_by_onnxruntime # noqa: PLC0415
  161. # Use onnxruntime to optimize model, which will be saved to *_ort.onnx
  162. _ = optimize_by_onnxruntime(
  163. onnx_model_path,
  164. use_gpu=use_gpu,
  165. optimized_model_path=ort_model_path,
  166. opt_level=99,
  167. )
  168. model_fusion_statistics[ort_model_path] = get_fusion_statistics(ort_model_path)
  169. else:
  170. logger.info(f"Skip optimization since model existed: {ort_model_path}")
  171. def optimize_onnx_model(
  172. onnx_model_path,
  173. optimized_model_path,
  174. model_type,
  175. num_attention_heads,
  176. hidden_size,
  177. use_gpu,
  178. precision,
  179. use_raw_attention_mask,
  180. overwrite,
  181. model_fusion_statistics,
  182. use_external_data_format,
  183. optimization_options=None,
  184. ):
  185. if overwrite or not os.path.exists(optimized_model_path):
  186. Path(optimized_model_path).parent.mkdir(parents=True, exist_ok=True)
  187. from fusion_options import FusionOptions # noqa: PLC0415
  188. from optimizer import optimize_model # noqa: PLC0415
  189. if optimization_options is None:
  190. optimization_options = FusionOptions(model_type)
  191. optimization_options.use_raw_attention_mask(use_raw_attention_mask)
  192. if precision == Precision.FLOAT16:
  193. optimization_options.enable_gelu_approximation = True
  194. if precision == Precision.INT8:
  195. optimization_options.enable_embed_layer_norm = False
  196. # For swin models, the num_attention_heads is a list, which isn't supported yet, so set to 0 for now
  197. if model_type == "swin":
  198. num_attention_heads = 0
  199. hidden_size = 0
  200. # Use script to optimize model.
  201. # Use opt_level <= 1 for models to be converted to fp16, because some fused op (like FusedGemm) has only fp32 and no fp16.
  202. # It is better to be conservative so we use opt_level=0 here, in case MemcpyFromHost is added to the graph by OnnxRuntime.
  203. opt_model = optimize_model(
  204. onnx_model_path,
  205. model_type,
  206. num_heads=num_attention_heads,
  207. hidden_size=hidden_size,
  208. opt_level=0,
  209. optimization_options=optimization_options,
  210. use_gpu=use_gpu,
  211. only_onnxruntime=False,
  212. )
  213. if model_type == "bert_keras" or model_type == "bert_tf":
  214. opt_model.use_dynamic_axes()
  215. model_fusion_statistics[optimized_model_path] = opt_model.get_fused_operator_statistics()
  216. if precision == Precision.FLOAT16:
  217. opt_model.convert_float_to_float16(keep_io_types=True)
  218. opt_model.save_model_to_file(optimized_model_path, use_external_data_format)
  219. else:
  220. logger.info(f"Skip optimization since model existed: {optimized_model_path}")
  221. def modelclass_dispatcher(model_name, custom_model_class):
  222. if custom_model_class is not None:
  223. if custom_model_class in MODEL_CLASSES:
  224. return custom_model_class
  225. else:
  226. raise Exception("Valid model class: " + " ".join(MODEL_CLASSES))
  227. if model_name in PRETRAINED_GPT2_MODELS:
  228. return "GPT2ModelNoPastState"
  229. import re # noqa: PLC0415
  230. if re.search("-squad$", model_name) is not None:
  231. return "AutoModelForQuestionAnswering"
  232. elif re.search("-mprc$", model_name) is not None:
  233. return "AutoModelForSequenceClassification"
  234. elif re.search("gpt2", model_name) is not None:
  235. return "AutoModelWithLMHead"
  236. return "AutoModel"
  237. def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_tf_model=False):
  238. model_class_name = modelclass_dispatcher(model_name, custom_model_class)
  239. if model_class_name == "GPT2ModelNoPastState":
  240. if is_tf_model:
  241. return TFGPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
  242. else:
  243. return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
  244. if is_tf_model:
  245. model_class_name = "TF" + model_class_name
  246. transformers_module = __import__("transformers", fromlist=[model_class_name])
  247. logger.info(f"Model class name: {model_class_name}")
  248. model_class = getattr(transformers_module, model_class_name)
  249. return model_class.from_pretrained(model_name, config=config, cache_dir=cache_dir)
  250. def load_pt_model(model_name, model_class, cache_dir, config_modifier):
  251. config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
  252. if hasattr(config, "return_dict"):
  253. config.return_dict = False
  254. config_modifier.modify(config)
  255. model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class)
  256. return config, model
  257. def load_tf_model(model_name, model_class, cache_dir, config_modifier):
  258. config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
  259. config_modifier.modify(config)
  260. # Loading tf model from transformers limits the cpu affinity to {0} when KMP_AFFINITY is set
  261. # Restore the affinity after model loading for expected ORT performance
  262. affinity_setting = AffinitySetting()
  263. affinity_setting.get_affinity()
  264. model = load_pretrained_model(
  265. model_name,
  266. config=config,
  267. cache_dir=cache_dir,
  268. custom_model_class=model_class,
  269. is_tf_model=True,
  270. )
  271. affinity_setting.set_affinity()
  272. return config, model
  273. # For test only
  274. def load_pt_model_from_tf(model_name):
  275. # Note that we could get pt model from tf, but model source and its structure in this case is different from directly using
  276. # load_pt_model() and load_tf_model() even with the same name. Therefore it should not be used for comparing with them
  277. from convert_tf_models_to_pytorch import tf2pt_pipeline # noqa: PLC0415
  278. config, model = tf2pt_pipeline(model_name)
  279. return config, model
  280. def validate_and_optimize_onnx(
  281. model_name,
  282. use_external_data_format,
  283. model_type,
  284. onnx_dir,
  285. input_names,
  286. use_gpu,
  287. precision,
  288. optimize_info,
  289. validate_onnx,
  290. use_raw_attention_mask,
  291. overwrite,
  292. config,
  293. model_fusion_statistics,
  294. onnx_model_path,
  295. example_inputs,
  296. example_outputs_flatten,
  297. output_names,
  298. fusion_options,
  299. ):
  300. is_valid_onnx_model = True
  301. if validate_onnx:
  302. is_valid_onnx_model = validate_onnx_model(
  303. onnx_model_path,
  304. example_inputs,
  305. example_outputs_flatten,
  306. use_gpu,
  307. False,
  308. output_names,
  309. )
  310. if optimize_info.name == OptimizerInfo.NOOPT.name:
  311. return onnx_model_path, is_valid_onnx_model, config.vocab_size
  312. if (
  313. optimize_info.name == OptimizerInfo.BYSCRIPT.name
  314. or precision == Precision.FLOAT16
  315. or precision == Precision.INT8
  316. ): # Use script (optimizer.py) to optimize
  317. optimized_model_path = get_onnx_file_path(
  318. onnx_dir,
  319. model_name,
  320. len(input_names),
  321. True,
  322. use_gpu,
  323. precision,
  324. False,
  325. use_external_data_format,
  326. )
  327. optimize_onnx_model(
  328. onnx_model_path,
  329. optimized_model_path,
  330. model_type,
  331. config.num_attention_heads,
  332. config.hidden_size,
  333. use_gpu,
  334. precision,
  335. use_raw_attention_mask,
  336. overwrite,
  337. model_fusion_statistics,
  338. use_external_data_format,
  339. fusion_options,
  340. )
  341. onnx_model_path = optimized_model_path
  342. if validate_onnx:
  343. is_valid_onnx_model = validate_onnx_model(
  344. onnx_model_path,
  345. example_inputs,
  346. example_outputs_flatten,
  347. use_gpu,
  348. precision == Precision.FLOAT16,
  349. output_names,
  350. )
  351. if precision == Precision.INT8:
  352. logger.info(f"Quantizing model: {onnx_model_path}")
  353. QuantizeHelper.quantize_onnx_model(onnx_model_path, onnx_model_path, use_external_data_format)
  354. logger.info(f"Finished quantizing model: {onnx_model_path}")
  355. if optimize_info.name == OptimizerInfo.BYORT.name: # Use OnnxRuntime to optimize
  356. if is_valid_onnx_model:
  357. ort_model_path = add_filename_suffix(onnx_model_path, "_ort")
  358. optimize_onnx_model_by_ort(
  359. onnx_model_path,
  360. ort_model_path,
  361. use_gpu,
  362. overwrite,
  363. model_fusion_statistics,
  364. )
  365. return (
  366. onnx_model_path,
  367. is_valid_onnx_model,
  368. config.num_labels if model_type in ["vit", "swin"] else config.vocab_size,
  369. )
  370. def export_onnx_model_from_pt(
  371. model_name,
  372. opset_version,
  373. use_external_data_format,
  374. model_type,
  375. model_class,
  376. config_modifier,
  377. cache_dir,
  378. onnx_dir,
  379. input_names,
  380. use_gpu,
  381. precision,
  382. optimizer_info,
  383. validate_onnx,
  384. use_raw_attention_mask,
  385. overwrite,
  386. model_fusion_statistics,
  387. fusion_options,
  388. ):
  389. config, model = load_pt_model(model_name, model_class, cache_dir, config_modifier)
  390. # config, model = load_pt_model_from_tf(model_name)
  391. model.cpu()
  392. example_inputs = None
  393. max_input_size = None
  394. if model_type in ["vit", "swin"]:
  395. image_processor = AutoFeatureExtractor.from_pretrained(model_name, cache_dir=cache_dir)
  396. data = numpy.random.randint(
  397. low=0, high=256, size=config.image_size * config.image_size * 3, dtype=numpy.uint8
  398. ).reshape(config.image_size, config.image_size, 3)
  399. example_inputs = image_processor(data, return_tensors="pt")
  400. else:
  401. tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
  402. max_input_size = tokenizer.model_max_length
  403. example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt")
  404. example_inputs = filter_inputs(example_inputs, input_names)
  405. example_outputs = model(**example_inputs)
  406. assert isinstance(example_outputs, (list, tuple)), f"type of output is not list or tuple: {type(example_outputs)}"
  407. # Flatten is needed for gpt2 and distilgpt2.
  408. example_outputs_flatten = flatten(example_outputs)
  409. example_outputs_flatten = update_flatten_list(example_outputs_flatten, [])
  410. onnx_model_path = get_onnx_file_path(
  411. onnx_dir,
  412. model_name,
  413. len(input_names),
  414. False,
  415. use_gpu,
  416. precision,
  417. False,
  418. use_external_data_format,
  419. )
  420. if overwrite or not os.path.exists(onnx_model_path):
  421. logger.info(f"Exporting ONNX model to {onnx_model_path}")
  422. Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
  423. dynamic_axes = None
  424. output_names = None
  425. if model_type in ["vit", "swin"]:
  426. dynamic_axes, output_names = {key: {0: "pixel_values"} for key in example_inputs}, ["logits"]
  427. else:
  428. dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten)
  429. replace_torch_functions()
  430. torch_onnx_export(
  431. model=model,
  432. args=tuple(example_inputs.values()),
  433. f=onnx_model_path,
  434. input_names=list(example_inputs.keys()),
  435. output_names=output_names,
  436. dynamic_axes=dynamic_axes,
  437. do_constant_folding=True,
  438. opset_version=opset_version,
  439. use_external_data_format=use_external_data_format,
  440. )
  441. restore_torch_functions()
  442. else:
  443. logger.info(f"Skip export since model existed: {onnx_model_path}")
  444. onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx(
  445. model_name,
  446. use_external_data_format,
  447. model_type,
  448. onnx_dir,
  449. input_names,
  450. use_gpu,
  451. precision,
  452. optimizer_info,
  453. validate_onnx,
  454. use_raw_attention_mask,
  455. overwrite,
  456. config,
  457. model_fusion_statistics,
  458. onnx_model_path,
  459. example_inputs,
  460. example_outputs_flatten,
  461. None,
  462. fusion_options,
  463. )
  464. return onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size
  465. def export_onnx_model_from_tf(
  466. model_name,
  467. opset_version,
  468. use_external_data_format,
  469. model_type,
  470. model_class,
  471. config_modifier,
  472. cache_dir,
  473. onnx_dir,
  474. input_names,
  475. use_gpu,
  476. precision,
  477. optimizer_info,
  478. validate_onnx,
  479. use_raw_attention_mask,
  480. overwrite,
  481. model_fusion_statistics,
  482. fusion_options,
  483. ):
  484. # Use CPU to export
  485. import tensorflow as tf # noqa: PLC0415
  486. tf.config.set_visible_devices([], "GPU")
  487. tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
  488. # Fix "Using pad_token, but it is not set yet" error.
  489. if tokenizer.pad_token is None:
  490. tokenizer.add_special_tokens({"pad_token": "[PAD]"})
  491. max_input_size = tokenizer.model_max_length
  492. config, model = load_tf_model(model_name, model_class, cache_dir, config_modifier)
  493. model.resize_token_embeddings(len(tokenizer))
  494. example_inputs = tokenizer.encode_plus(
  495. "This is a sample input",
  496. return_tensors="tf",
  497. max_length=max_input_size,
  498. padding="max_length",
  499. truncation=True,
  500. )
  501. example_inputs = filter_inputs(example_inputs, input_names)
  502. if config.is_encoder_decoder:
  503. example_inputs["decoder_input_ids"] = tokenizer.encode_plus(
  504. "This is a sample input",
  505. return_tensors="tf",
  506. max_length=max_input_size,
  507. padding="max_length",
  508. truncation=True,
  509. ).input_ids
  510. if model_name == "unc-nlp/lxmert-base-uncased":
  511. example_inputs["visual_feats"] = tf.random.normal([1, 1, config.visual_feat_dim])
  512. example_inputs["visual_pos"] = tf.random.normal([1, 1, config.visual_pos_dim])
  513. try:
  514. # Use no past state for these models
  515. if config.use_cache:
  516. config.use_cache = False
  517. except Exception:
  518. pass
  519. example_outputs = model(example_inputs, training=False)
  520. output_names = None
  521. # For xlnet models, only compare the last_hidden_state output.
  522. if model_name == "xlnet-base-cased" or model_name == "xlnet-large-cased":
  523. output_names = ["last_hidden_state"]
  524. example_outputs = example_outputs["last_hidden_state"]
  525. # Flatten is needed for gpt2 and distilgpt2. Output name sorting is needed for tf2onnx outputs to match onnx outputs.
  526. from tensorflow.python.util import nest # noqa: PLC0415
  527. example_outputs_flatten = nest.flatten(example_outputs)
  528. onnx_model_path = get_onnx_file_path(
  529. onnx_dir,
  530. model_name,
  531. len(input_names),
  532. False,
  533. use_gpu,
  534. precision,
  535. False,
  536. use_external_data_format,
  537. )
  538. tf_internal_model_path = onnx_model_path[:-5] if use_external_data_format else onnx_model_path
  539. if overwrite or not os.path.exists(tf_internal_model_path):
  540. logger.info(f"Exporting ONNX model to {onnx_model_path}")
  541. if not use_external_data_format:
  542. Path(tf_internal_model_path).parent.mkdir(parents=True, exist_ok=True)
  543. import zipfile # noqa: PLC0415
  544. import tf2onnx # noqa: PLC0415
  545. tf2onnx.logging.set_level(tf2onnx.logging.ERROR)
  546. specs = []
  547. for name, value in example_inputs.items():
  548. dims = [None] * len(value.shape)
  549. specs.append(tf.TensorSpec(tuple(dims), value.dtype, name=name))
  550. _, _ = tf2onnx.convert.from_keras(
  551. model,
  552. input_signature=tuple(specs),
  553. opset=opset_version,
  554. large_model=use_external_data_format,
  555. output_path=tf_internal_model_path,
  556. )
  557. if use_external_data_format:
  558. # need to unpack the zip for run_onnxruntime()
  559. with zipfile.ZipFile(tf_internal_model_path, "r") as z:
  560. z.extractall(os.path.dirname(tf_internal_model_path))
  561. tf_internal_model_path = os.path.join(os.path.dirname(tf_internal_model_path), "__MODEL_PROTO.onnx")
  562. if os.path.exists(onnx_model_path):
  563. os.remove(onnx_model_path)
  564. os.rename(tf_internal_model_path, onnx_model_path)
  565. else:
  566. logger.info(f"Skip export since model existed: {onnx_model_path}")
  567. model_type = model_type + "_tf"
  568. optimized_onnx_path, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx(
  569. model_name,
  570. use_external_data_format,
  571. model_type,
  572. onnx_dir,
  573. input_names,
  574. use_gpu,
  575. precision,
  576. optimizer_info,
  577. validate_onnx,
  578. use_raw_attention_mask,
  579. overwrite,
  580. config,
  581. model_fusion_statistics,
  582. onnx_model_path,
  583. example_inputs,
  584. example_outputs_flatten,
  585. output_names,
  586. fusion_options,
  587. )
  588. return (
  589. optimized_onnx_path,
  590. is_valid_onnx_model,
  591. vocab_size,
  592. max_input_size,
  593. )