compare_bert_results.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. # It is a tool to compare the inference results of the original model and optimized model.
  6. import argparse
  7. import statistics
  8. from pathlib import Path
  9. import numpy as np
  10. import psutil
  11. from bert_perf_test import create_session, onnxruntime_inference
  12. from bert_test_data import generate_test_data, get_bert_inputs, output_test_data
  13. def run_model(model_path, all_inputs, use_gpu, disable_optimization):
  14. import onnxruntime # noqa: PLC0415
  15. graph_optimization_level = None
  16. if disable_optimization:
  17. graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
  18. intra_op_num_threads = psutil.cpu_count(logical=False)
  19. session = create_session(
  20. model_path, use_gpu, "cuda" if use_gpu else "cpu", intra_op_num_threads, graph_optimization_level
  21. )
  22. output_names = [output.name for output in session.get_outputs()]
  23. results, latency_list = onnxruntime_inference(session, all_inputs, output_names)
  24. return results, latency_list, output_names
  25. def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3):
  26. # Validate the output of baseline and treatment, to make sure the results are similar.
  27. diff_count = 0
  28. max_abs_diff = 0
  29. max_diff_percentage = 0
  30. case_passed = True
  31. for test_case_id, results in enumerate(baseline_results):
  32. for i in range(len(results)):
  33. treatment_output = treatment_results[test_case_id][i]
  34. abs_diff_tensor = np.abs(treatment_output - results[i])
  35. abs_diff = np.amax(abs_diff_tensor)
  36. if verbose and abs_diff > atol:
  37. print("abs_diff", abs_diff)
  38. print("treatment", treatment_output)
  39. print("baseline", results[i])
  40. count_exceeding = np.sum(abs_diff_tensor > atol)
  41. total_elements = abs_diff_tensor.size
  42. percentage_exceeding = (count_exceeding / total_elements) * 100
  43. max_diff_percentage = max(max_diff_percentage, percentage_exceeding)
  44. max_abs_diff = max(max_abs_diff, abs_diff)
  45. if not np.allclose(results[i].tolist(), treatment_output.tolist(), rtol=rtol, atol=atol):
  46. if case_passed:
  47. case_passed = False
  48. diff_count += 1
  49. if verbose:
  50. print(f"case {test_case_id} output {i}")
  51. print(f"baseline={results[i].tolist()}\ntreatment={treatment_output}")
  52. print(f"abs_diff={abs_diff}")
  53. if diff_count == 0:
  54. print(f"100% passed for {len(baseline_results)} random inputs given thresholds (rtol={rtol}, atol={atol}).")
  55. else:
  56. print(
  57. f"WARNING: {diff_count} out of {len(baseline_results)} results NOT passed for thresholds (rtol={rtol}, atol={atol})."
  58. )
  59. print(f"maximum absolute difference={max_abs_diff}")
  60. print(f"maximum percentage of elements that exceeds atol={atol} is {max_diff_percentage:.3f}%")
  61. return max_abs_diff, case_passed
  62. def run_test(
  63. baseline_model,
  64. optimized_model,
  65. output_dir,
  66. batch_size,
  67. sequence_length,
  68. use_gpu,
  69. test_cases,
  70. seed,
  71. verbose,
  72. rtol,
  73. atol,
  74. input_ids_name,
  75. segment_ids_name,
  76. input_mask_name,
  77. mask_type,
  78. dictionary_size: int = 1024,
  79. ):
  80. # Try deduce input names from optimized model.
  81. input_ids, segment_ids, input_mask = get_bert_inputs(
  82. optimized_model, input_ids_name, segment_ids_name, input_mask_name
  83. )
  84. # Use random mask length for accuracy test. It might introduce slight inflation in latency reported in this script.
  85. average_sequence_length = int(sequence_length / 2) if sequence_length >= 2 else sequence_length
  86. all_inputs = generate_test_data(
  87. batch_size,
  88. sequence_length,
  89. test_cases,
  90. seed,
  91. verbose,
  92. input_ids,
  93. segment_ids,
  94. input_mask,
  95. average_sequence_length,
  96. True, # random sequence length
  97. mask_type,
  98. dictionary_size=dictionary_size,
  99. )
  100. baseline_results, baseline_latency, output_names = run_model(
  101. baseline_model, all_inputs, use_gpu, disable_optimization=True
  102. )
  103. if verbose:
  104. print(f"baseline average latency (all optimizations disabled): {statistics.mean(baseline_latency) * 1000} ms")
  105. if output_dir is not None:
  106. for i, inputs in enumerate(all_inputs):
  107. output_test_data(output_dir, i, inputs)
  108. treatment_results, treatment_latency, treatment_output_names = run_model(
  109. optimized_model, all_inputs, use_gpu, disable_optimization=False
  110. )
  111. if verbose:
  112. print(f"treatment average latency: {statistics.mean(treatment_latency) * 1000} ms")
  113. # Validate the output of baseline and treatment, to make sure the results are similar.
  114. return compare(baseline_results, treatment_results, verbose, rtol, atol)
  115. def parse_arguments():
  116. parser = argparse.ArgumentParser()
  117. parser.add_argument("--baseline_model", required=True, type=str, help="baseline onnx model path.")
  118. parser.add_argument(
  119. "--optimized_model",
  120. required=True,
  121. type=str,
  122. default=None,
  123. help="path of the optimized model. It shall have same inputs as the baseline model.",
  124. )
  125. parser.add_argument(
  126. "--output_dir",
  127. required=False,
  128. type=str,
  129. default=None,
  130. help="output test data path. If not specified, test data will not be saved.",
  131. )
  132. parser.add_argument("--batch_size", required=True, type=int, help="batch size of input")
  133. parser.add_argument(
  134. "--sequence_length",
  135. required=True,
  136. type=int,
  137. help="maximum sequence length of input",
  138. )
  139. parser.add_argument("--rtol", required=False, type=float, default=1e-3, help="relative tolerance")
  140. parser.add_argument("--atol", required=False, type=float, default=1e-4, help="absolute tolerance")
  141. parser.add_argument(
  142. "--samples",
  143. required=False,
  144. type=int,
  145. default=100,
  146. help="number of test cases to be generated",
  147. )
  148. parser.add_argument("--seed", required=False, type=int, default=3, help="random seed")
  149. parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU")
  150. parser.set_defaults(use_gpu=False)
  151. parser.add_argument(
  152. "--verbose",
  153. required=False,
  154. action="store_true",
  155. help="print verbose information",
  156. )
  157. parser.set_defaults(verbose=False)
  158. parser.add_argument(
  159. "--input_ids",
  160. required=False,
  161. type=str,
  162. default=None,
  163. help="input name for input ids",
  164. )
  165. parser.add_argument(
  166. "--segment_ids",
  167. required=False,
  168. type=str,
  169. default=None,
  170. help="input name for segment ids",
  171. )
  172. parser.add_argument(
  173. "--input_mask",
  174. required=False,
  175. type=str,
  176. default=None,
  177. help="input name for attention mask",
  178. )
  179. parser.add_argument(
  180. "--mask_type",
  181. required=False,
  182. type=int,
  183. default=2,
  184. help="mask type: (1: mask index or sequence length, 2: raw 2D mask, 3: key len, cumulated lengths of query and key)",
  185. )
  186. args = parser.parse_args()
  187. return args
  188. def main():
  189. args = parse_arguments()
  190. if args.output_dir is not None:
  191. # create the output directory if not existed
  192. path = Path(args.output_dir)
  193. path.mkdir(parents=True, exist_ok=True)
  194. run_test(
  195. args.baseline_model,
  196. args.optimized_model,
  197. args.output_dir,
  198. args.batch_size,
  199. args.sequence_length,
  200. args.use_gpu,
  201. args.samples,
  202. args.seed,
  203. args.verbose,
  204. args.rtol,
  205. args.atol,
  206. args.input_ids,
  207. args.segment_ids,
  208. args.input_mask,
  209. args.mask_type,
  210. )
  211. if __name__ == "__main__":
  212. main()