bert_test_data.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. # It is a tool to generate test data for a bert model.
  6. # The test data can be used by onnxruntime_perf_test tool to evaluate the inference latency.
  7. import argparse
  8. import os
  9. import random
  10. from pathlib import Path
  11. import numpy as np
  12. from onnx import ModelProto, TensorProto, numpy_helper
  13. from onnx_model import OnnxModel
  14. def fake_input_ids_data(
  15. input_ids: TensorProto, batch_size: int, sequence_length: int, dictionary_size: int
  16. ) -> np.ndarray:
  17. """Create input tensor based on the graph input of input_ids
  18. Args:
  19. input_ids (TensorProto): graph input of the input_ids input tensor
  20. batch_size (int): batch size
  21. sequence_length (int): sequence length
  22. dictionary_size (int): vocabulary size of dictionary
  23. Returns:
  24. np.ndarray: the input tensor created
  25. """
  26. assert input_ids.type.tensor_type.elem_type in [
  27. TensorProto.FLOAT,
  28. TensorProto.INT32,
  29. TensorProto.INT64,
  30. ]
  31. data = np.random.randint(dictionary_size, size=(batch_size, sequence_length), dtype=np.int32)
  32. if input_ids.type.tensor_type.elem_type == TensorProto.FLOAT:
  33. data = np.float32(data)
  34. elif input_ids.type.tensor_type.elem_type == TensorProto.INT64:
  35. data = np.int64(data)
  36. return data
  37. def fake_segment_ids_data(segment_ids: TensorProto, batch_size: int, sequence_length: int) -> np.ndarray:
  38. """Create input tensor based on the graph input of segment_ids
  39. Args:
  40. segment_ids (TensorProto): graph input of the token_type_ids input tensor
  41. batch_size (int): batch size
  42. sequence_length (int): sequence length
  43. Returns:
  44. np.ndarray: the input tensor created
  45. """
  46. assert segment_ids.type.tensor_type.elem_type in [
  47. TensorProto.FLOAT,
  48. TensorProto.INT32,
  49. TensorProto.INT64,
  50. ]
  51. data = np.zeros((batch_size, sequence_length), dtype=np.int32)
  52. if segment_ids.type.tensor_type.elem_type == TensorProto.FLOAT:
  53. data = np.float32(data)
  54. elif segment_ids.type.tensor_type.elem_type == TensorProto.INT64:
  55. data = np.int64(data)
  56. return data
  57. def get_random_length(max_sequence_length: int, average_sequence_length: int):
  58. assert average_sequence_length >= 1 and average_sequence_length <= max_sequence_length
  59. # For uniform distribution, we find proper lower and upper bounds so that the average is in the middle.
  60. if 2 * average_sequence_length > max_sequence_length:
  61. return random.randint(2 * average_sequence_length - max_sequence_length, max_sequence_length)
  62. else:
  63. return random.randint(1, 2 * average_sequence_length - 1)
  64. def fake_input_mask_data(
  65. input_mask: TensorProto,
  66. batch_size: int,
  67. sequence_length: int,
  68. average_sequence_length: int,
  69. random_sequence_length: bool,
  70. mask_type: int = 2,
  71. ) -> np.ndarray:
  72. """Create input tensor based on the graph input of segment_ids.
  73. Args:
  74. input_mask (TensorProto): graph input of the attention mask input tensor
  75. batch_size (int): batch size
  76. sequence_length (int): sequence length
  77. average_sequence_length (int): average sequence length excluding paddings
  78. random_sequence_length (bool): whether use uniform random number for sequence length
  79. mask_type (int): mask type - 1: mask index (sequence length excluding paddings). Shape is (batch_size).
  80. 2: 2D attention mask. Shape is (batch_size, sequence_length).
  81. 3: key len, cumulated lengths of query and key. Shape is (3 * batch_size + 2).
  82. Returns:
  83. np.ndarray: the input tensor created
  84. """
  85. assert input_mask.type.tensor_type.elem_type in [
  86. TensorProto.FLOAT,
  87. TensorProto.INT32,
  88. TensorProto.INT64,
  89. ]
  90. if mask_type == 1: # sequence length excluding paddings
  91. data = np.ones((batch_size), dtype=np.int32)
  92. if random_sequence_length:
  93. for i in range(batch_size):
  94. data[i] = get_random_length(sequence_length, average_sequence_length)
  95. else:
  96. for i in range(batch_size):
  97. data[i] = average_sequence_length
  98. elif mask_type == 2: # 2D attention mask
  99. data = np.zeros((batch_size, sequence_length), dtype=np.int32)
  100. if random_sequence_length:
  101. for i in range(batch_size):
  102. actual_seq_len = get_random_length(sequence_length, average_sequence_length)
  103. for j in range(actual_seq_len):
  104. data[i, j] = 1
  105. else:
  106. temp = np.ones((batch_size, average_sequence_length), dtype=np.int32)
  107. data[: temp.shape[0], : temp.shape[1]] = temp
  108. else:
  109. assert mask_type == 3
  110. data = np.zeros((batch_size * 3 + 2), dtype=np.int32)
  111. if random_sequence_length:
  112. for i in range(batch_size):
  113. data[i] = get_random_length(sequence_length, average_sequence_length)
  114. for i in range(batch_size + 1):
  115. data[batch_size + i] = data[batch_size + i - 1] + data[i - 1] if i > 0 else 0
  116. data[2 * batch_size + 1 + i] = data[batch_size + i - 1] + data[i - 1] if i > 0 else 0
  117. else:
  118. for i in range(batch_size):
  119. data[i] = average_sequence_length
  120. for i in range(batch_size + 1):
  121. data[batch_size + i] = i * average_sequence_length
  122. data[2 * batch_size + 1 + i] = i * average_sequence_length
  123. if input_mask.type.tensor_type.elem_type == TensorProto.FLOAT:
  124. data = np.float32(data)
  125. elif input_mask.type.tensor_type.elem_type == TensorProto.INT64:
  126. data = np.int64(data)
  127. return data
  128. def output_test_data(directory: str, inputs: dict[str, np.ndarray]):
  129. """Output input tensors of test data to a directory
  130. Args:
  131. directory (str): path of a directory
  132. inputs (Dict[str, np.ndarray]): map from input name to value
  133. """
  134. if not os.path.exists(directory):
  135. try:
  136. os.mkdir(directory)
  137. except OSError:
  138. print(f"Creation of the directory {directory} failed")
  139. else:
  140. print(f"Successfully created the directory {directory} ")
  141. else:
  142. print(f"Warning: directory {directory} existed. Files will be overwritten.")
  143. for index, (name, data) in enumerate(inputs.items()):
  144. tensor = numpy_helper.from_array(data, name)
  145. with open(os.path.join(directory, f"input_{index}.pb"), "wb") as file:
  146. file.write(tensor.SerializeToString())
  147. def fake_test_data(
  148. batch_size: int,
  149. sequence_length: int,
  150. test_cases: int,
  151. dictionary_size: int,
  152. verbose: bool,
  153. random_seed: int,
  154. input_ids: TensorProto,
  155. segment_ids: TensorProto,
  156. input_mask: TensorProto,
  157. average_sequence_length: int,
  158. random_sequence_length: bool,
  159. mask_type: int,
  160. ):
  161. """Create given number of input data for testing
  162. Args:
  163. batch_size (int): batch size
  164. sequence_length (int): sequence length
  165. test_cases (int): number of test cases
  166. dictionary_size (int): vocabulary size of dictionary for input_ids
  167. verbose (bool): print more information or not
  168. random_seed (int): random seed
  169. input_ids (TensorProto): graph input of input IDs
  170. segment_ids (TensorProto): graph input of token type IDs
  171. input_mask (TensorProto): graph input of attention mask
  172. average_sequence_length (int): average sequence length excluding paddings
  173. random_sequence_length (bool): whether use uniform random number for sequence length
  174. mask_type (int): mask type 1 is mask index; 2 is 2D mask; 3 is key len, cumulated lengths of query and key
  175. Returns:
  176. List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
  177. with input name as key and a tensor as value
  178. """
  179. assert input_ids is not None
  180. np.random.seed(random_seed)
  181. random.seed(random_seed)
  182. all_inputs = []
  183. for _test_case in range(test_cases):
  184. input_1 = fake_input_ids_data(input_ids, batch_size, sequence_length, dictionary_size)
  185. inputs = {input_ids.name: input_1}
  186. if segment_ids:
  187. inputs[segment_ids.name] = fake_segment_ids_data(segment_ids, batch_size, sequence_length)
  188. if input_mask:
  189. inputs[input_mask.name] = fake_input_mask_data(
  190. input_mask, batch_size, sequence_length, average_sequence_length, random_sequence_length, mask_type
  191. )
  192. if verbose and len(all_inputs) == 0:
  193. print("Example inputs", inputs)
  194. all_inputs.append(inputs)
  195. return all_inputs
  196. def generate_test_data(
  197. batch_size: int,
  198. sequence_length: int,
  199. test_cases: int,
  200. seed: int,
  201. verbose: bool,
  202. input_ids: TensorProto,
  203. segment_ids: TensorProto,
  204. input_mask: TensorProto,
  205. average_sequence_length: int,
  206. random_sequence_length: bool,
  207. mask_type: int,
  208. dictionary_size: int = 10000,
  209. ):
  210. """Create given number of input data for testing
  211. Args:
  212. batch_size (int): batch size
  213. sequence_length (int): sequence length
  214. test_cases (int): number of test cases
  215. seed (int): random seed
  216. verbose (bool): print more information or not
  217. input_ids (TensorProto): graph input of input IDs
  218. segment_ids (TensorProto): graph input of token type IDs
  219. input_mask (TensorProto): graph input of attention mask
  220. average_sequence_length (int): average sequence length excluding paddings
  221. random_sequence_length (bool): whether use uniform random number for sequence length
  222. mask_type (int): mask type 1 is mask index; 2 is 2D mask; 3 is key len, cumulated lengths of query and key
  223. Returns:
  224. List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
  225. with input name as key and a tensor as value
  226. """
  227. all_inputs = fake_test_data(
  228. batch_size,
  229. sequence_length,
  230. test_cases,
  231. dictionary_size,
  232. verbose,
  233. seed,
  234. input_ids,
  235. segment_ids,
  236. input_mask,
  237. average_sequence_length,
  238. random_sequence_length,
  239. mask_type,
  240. )
  241. if len(all_inputs) != test_cases:
  242. print("Failed to create test data for test.")
  243. return all_inputs
  244. def get_graph_input_from_embed_node(onnx_model, embed_node, input_index):
  245. if input_index >= len(embed_node.input):
  246. return None
  247. input = embed_node.input[input_index]
  248. graph_input = onnx_model.find_graph_input(input)
  249. if graph_input is None:
  250. parent_node = onnx_model.get_parent(embed_node, input_index)
  251. if parent_node is not None and parent_node.op_type == "Cast":
  252. graph_input = onnx_model.find_graph_input(parent_node.input[0])
  253. return graph_input
  254. def find_bert_inputs(
  255. onnx_model: OnnxModel,
  256. input_ids_name: str | None = None,
  257. segment_ids_name: str | None = None,
  258. input_mask_name: str | None = None,
  259. ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
  260. """Find graph inputs for BERT model.
  261. First, we will deduce inputs from EmbedLayerNormalization node.
  262. If not found, we will guess the meaning of graph inputs based on naming.
  263. Args:
  264. onnx_model (OnnxModel): onnx model object
  265. input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
  266. segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
  267. input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.
  268. Raises:
  269. ValueError: Graph does not have input named of input_ids_name or segment_ids_name or input_mask_name
  270. ValueError: Expected graph input number does not match with specified input_ids_name, segment_ids_name
  271. and input_mask_name
  272. Returns:
  273. Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
  274. segment_ids and input_mask
  275. """
  276. graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()
  277. if input_ids_name is not None:
  278. input_ids = onnx_model.find_graph_input(input_ids_name)
  279. if input_ids is None:
  280. raise ValueError(f"Graph does not have input named {input_ids_name}")
  281. segment_ids = None
  282. if segment_ids_name:
  283. segment_ids = onnx_model.find_graph_input(segment_ids_name)
  284. if segment_ids is None:
  285. raise ValueError(f"Graph does not have input named {segment_ids_name}")
  286. input_mask = None
  287. if input_mask_name:
  288. input_mask = onnx_model.find_graph_input(input_mask_name)
  289. if input_mask is None:
  290. raise ValueError(f"Graph does not have input named {input_mask_name}")
  291. expected_inputs = 1 + (1 if segment_ids else 0) + (1 if input_mask else 0)
  292. if len(graph_inputs) != expected_inputs:
  293. raise ValueError(f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}")
  294. return input_ids, segment_ids, input_mask
  295. if len(graph_inputs) != 3:
  296. raise ValueError(f"Expect the graph to have 3 inputs. Got {len(graph_inputs)}")
  297. embed_nodes = onnx_model.get_nodes_by_op_type("EmbedLayerNormalization")
  298. if len(embed_nodes) == 1:
  299. embed_node = embed_nodes[0]
  300. input_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 0)
  301. segment_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 1)
  302. input_mask = get_graph_input_from_embed_node(onnx_model, embed_node, 7)
  303. if input_mask is None:
  304. for input in graph_inputs:
  305. input_name_lower = input.name.lower()
  306. if "mask" in input_name_lower:
  307. input_mask = input
  308. if input_mask is None:
  309. raise ValueError("Failed to find attention mask input")
  310. return input_ids, segment_ids, input_mask
  311. # Try guess the inputs based on naming.
  312. input_ids = None
  313. segment_ids = None
  314. input_mask = None
  315. for input in graph_inputs:
  316. input_name_lower = input.name.lower()
  317. if "mask" in input_name_lower: # matches input with name like "attention_mask" or "input_mask"
  318. input_mask = input
  319. elif (
  320. "token" in input_name_lower or "segment" in input_name_lower
  321. ): # matches input with name like "segment_ids" or "token_type_ids"
  322. segment_ids = input
  323. else:
  324. input_ids = input
  325. if input_ids and segment_ids and input_mask:
  326. return input_ids, segment_ids, input_mask
  327. raise ValueError("Fail to assign 3 inputs. You might try rename the graph inputs.")
  328. def get_bert_inputs(
  329. onnx_file: str,
  330. input_ids_name: str | None = None,
  331. segment_ids_name: str | None = None,
  332. input_mask_name: str | None = None,
  333. ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
  334. """Find graph inputs for BERT model.
  335. First, we will deduce inputs from EmbedLayerNormalization node.
  336. If not found, we will guess the meaning of graph inputs based on naming.
  337. Args:
  338. onnx_file (str): onnx model path
  339. input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
  340. segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
  341. input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.
  342. Returns:
  343. Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
  344. segment_ids and input_mask
  345. """
  346. model = ModelProto()
  347. with open(onnx_file, "rb") as file:
  348. model.ParseFromString(file.read())
  349. onnx_model = OnnxModel(model)
  350. return find_bert_inputs(onnx_model, input_ids_name, segment_ids_name, input_mask_name)
  351. def parse_arguments():
  352. parser = argparse.ArgumentParser()
  353. parser.add_argument("--model", required=True, type=str, help="bert onnx model path.")
  354. parser.add_argument(
  355. "--output_dir",
  356. required=False,
  357. type=str,
  358. default=None,
  359. help="output test data path. Default is current directory.",
  360. )
  361. parser.add_argument("--batch_size", required=False, type=int, default=1, help="batch size of input")
  362. parser.add_argument(
  363. "--sequence_length",
  364. required=False,
  365. type=int,
  366. default=128,
  367. help="maximum sequence length of input",
  368. )
  369. parser.add_argument(
  370. "--input_ids_name",
  371. required=False,
  372. type=str,
  373. default=None,
  374. help="input name for input ids",
  375. )
  376. parser.add_argument(
  377. "--segment_ids_name",
  378. required=False,
  379. type=str,
  380. default=None,
  381. help="input name for segment ids",
  382. )
  383. parser.add_argument(
  384. "--input_mask_name",
  385. required=False,
  386. type=str,
  387. default=None,
  388. help="input name for attention mask",
  389. )
  390. parser.add_argument(
  391. "--samples",
  392. required=False,
  393. type=int,
  394. default=1,
  395. help="number of test cases to be generated",
  396. )
  397. parser.add_argument("--seed", required=False, type=int, default=3, help="random seed")
  398. parser.add_argument(
  399. "--verbose",
  400. required=False,
  401. action="store_true",
  402. help="print verbose information",
  403. )
  404. parser.set_defaults(verbose=False)
  405. parser.add_argument(
  406. "--only_input_tensors",
  407. required=False,
  408. action="store_true",
  409. help="only save input tensors and no output tensors",
  410. )
  411. parser.set_defaults(only_input_tensors=False)
  412. parser.add_argument(
  413. "-a",
  414. "--average_sequence_length",
  415. default=-1,
  416. type=int,
  417. help="average sequence length excluding padding",
  418. )
  419. parser.add_argument(
  420. "-r",
  421. "--random_sequence_length",
  422. required=False,
  423. action="store_true",
  424. help="use uniform random instead of fixed sequence length",
  425. )
  426. parser.set_defaults(random_sequence_length=False)
  427. parser.add_argument(
  428. "--mask_type",
  429. required=False,
  430. type=int,
  431. default=2,
  432. help="mask type: (1: mask index, 2: raw 2D mask, 3: key lengths, cumulated lengths of query and key)",
  433. )
  434. args = parser.parse_args()
  435. return args
  436. def create_and_save_test_data(
  437. model: str,
  438. output_dir: str,
  439. batch_size: int,
  440. sequence_length: int,
  441. test_cases: int,
  442. seed: int,
  443. verbose: bool,
  444. input_ids_name: str | None,
  445. segment_ids_name: str | None,
  446. input_mask_name: str | None,
  447. only_input_tensors: bool,
  448. average_sequence_length: int,
  449. random_sequence_length: bool,
  450. mask_type: int,
  451. ):
  452. """Create test data for a model, and save test data to a directory.
  453. Args:
  454. model (str): path of ONNX bert model
  455. output_dir (str): output directory
  456. batch_size (int): batch size
  457. sequence_length (int): sequence length
  458. test_cases (int): number of test cases
  459. seed (int): random seed
  460. verbose (bool): whether print more information
  461. input_ids_name (str): graph input name of input_ids
  462. segment_ids_name (str): graph input name of segment_ids
  463. input_mask_name (str): graph input name of input_mask
  464. only_input_tensors (bool): only save input tensors,
  465. average_sequence_length (int): average sequence length excluding paddings
  466. random_sequence_length (bool): whether use uniform random number for sequence length
  467. mask_type(int): mask type
  468. """
  469. input_ids, segment_ids, input_mask = get_bert_inputs(model, input_ids_name, segment_ids_name, input_mask_name)
  470. all_inputs = generate_test_data(
  471. batch_size,
  472. sequence_length,
  473. test_cases,
  474. seed,
  475. verbose,
  476. input_ids,
  477. segment_ids,
  478. input_mask,
  479. average_sequence_length,
  480. random_sequence_length,
  481. mask_type,
  482. )
  483. for i, inputs in enumerate(all_inputs):
  484. directory = os.path.join(output_dir, "test_data_set_" + str(i))
  485. output_test_data(directory, inputs)
  486. if only_input_tensors:
  487. return
  488. import onnxruntime # noqa: PLC0415
  489. providers = (
  490. ["CUDAExecutionProvider", "CPUExecutionProvider"]
  491. if "CUDAExecutionProvider" in onnxruntime.get_available_providers()
  492. else ["CPUExecutionProvider"]
  493. )
  494. session = onnxruntime.InferenceSession(model, providers=providers)
  495. output_names = [output.name for output in session.get_outputs()]
  496. for i, inputs in enumerate(all_inputs):
  497. directory = os.path.join(output_dir, "test_data_set_" + str(i))
  498. result = session.run(output_names, inputs)
  499. for i, output_name in enumerate(output_names): # noqa: PLW2901
  500. tensor_result = numpy_helper.from_array(np.asarray(result[i]), output_name)
  501. with open(os.path.join(directory, f"output_{i}.pb"), "wb") as file:
  502. file.write(tensor_result.SerializeToString())
  503. def main():
  504. args = parse_arguments()
  505. if args.average_sequence_length <= 0:
  506. args.average_sequence_length = args.sequence_length
  507. output_dir = args.output_dir
  508. if output_dir is None:
  509. # Default output directory is a sub-directory under the directory of model.
  510. p = Path(args.model)
  511. output_dir = os.path.join(p.parent, f"batch_{args.batch_size}_seq_{args.sequence_length}")
  512. if output_dir is not None:
  513. # create the output directory if not existed
  514. path = Path(output_dir)
  515. path.mkdir(parents=True, exist_ok=True)
  516. else:
  517. print("Directory existed. test data files will be overwritten.")
  518. create_and_save_test_data(
  519. args.model,
  520. output_dir,
  521. args.batch_size,
  522. args.sequence_length,
  523. args.samples,
  524. args.seed,
  525. args.verbose,
  526. args.input_ids_name,
  527. args.segment_ids_name,
  528. args.input_mask_name,
  529. args.only_input_tensors,
  530. args.average_sequence_length,
  531. args.random_sequence_length,
  532. args.mask_type,
  533. )
  534. print("Test data is saved to directory:", output_dir)
# Run the command line tool only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()