profiler.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. import argparse
  2. import os
  3. import numpy
  4. import psutil
  5. from onnx import TensorProto
  6. """
  7. This profiler tool runs a transformer model and prints out the kernel time spent on each node of the model.
  8. Example of profiling a Longformer model:
  9. python profiler.py --model longformer-base-4096_fp32.onnx --batch_size 1 --sequence_length 4096 --global_length 8 --samples 1000 --thread_num 8 --dummy_inputs longformer --use_gpu
  10. Example of importing profile result file from onnxruntime_perf_test:
  11. python profiler.py --input profile_2021-10-25_12-02-41.json
  12. """
  def parse_arguments(argv=None):
      """Parse the command-line options of the profiler.

      Args:
          argv (List[str], optional): argument strings to parse. Defaults to None, in which case
              argparse reads the arguments from sys.argv.

      Returns:
          argparse.Namespace: the parsed options. Every option is optional and has a default.
      """
      parser = argparse.ArgumentParser()

      def add_option(*flags, **settings):
          # No option of this tool is mandatory, so required=False is applied in one place.
          parser.add_argument(*flags, required=False, **settings)

      # Where the profile comes from: an existing result file, or a model to run.
      add_option("-i", "--input", type=str, help="Set the input file for reading the profile results")
      add_option(
          "-m",
          "--model",
          type=str,
          help="onnx model path to run profiling. Required when --input is not specified.",
      )

      # Sizes used to build the model inputs.
      add_option("-b", "--batch_size", type=int, default=1, help="batch size of input")
      add_option("-s", "--sequence_length", type=int, default=32, help="sequence length of input")
      add_option("--past_sequence_length", type=int, default=1, help="past sequence length for gpt2")
      add_option("--global_length", type=int, default=1, help="number of global tokens for longformer")

      # Measurement and reporting settings.
      add_option(
          "--samples",
          type=int,
          default=1000,
          help="number of samples to test. Set it large enough to reduce the variance of performance result.",
      )
      add_option(
          "--threshold",
          type=float,
          default=0.01,
          help="Threshold of run time ratio among all nodes. Nodes with larger ratio will show in top expensive nodes.",
      )
      add_option("--thread_num", type=int, default=-1, help="number of threads to use")

      # Graph input names, used for bert models.
      add_option("--input_ids_name", type=str, default=None, help="input name for input IDs, for bert")
      add_option("--segment_ids_name", type=str, default=None, help="input name for segment IDs, for bert")
      add_option("--input_mask_name", type=str, default=None, help="input name for attention mask, for bert")

      add_option(
          "--dummy_inputs",
          default="default",
          choices=["bert", "gpt2", "longformer", "default"],
          help="Type of model inputs. The default will create dummy inputs with ones.",
      )

      # Execution settings.
      add_option("-g", "--use_gpu", action="store_true", help="use GPU")
      add_option("--provider", type=str, default="cuda", help="Execution provider to use")
      add_option(
          "--basic_optimization",
          action="store_true",
          help="Enable only basic graph optimizations. By default, all optimizations are enabled in OnnxRuntime",
      )
      add_option(
          "--kernel_time_only",
          action="store_true",
          help="Only include the kernel time and no fence time",
      )
      add_option("-v", "--verbose", action="store_true")

      # All switches are off unless given on the command line.
      parser.set_defaults(
          use_gpu=False,
          basic_optimization=False,
          kernel_time_only=False,
          verbose=False,
      )

      return parser.parse_args(argv)
  def run_profile(onnx_model_path, use_gpu, provider, basic_optimization, thread_num, all_inputs):
      """Run the model once per input dictionary in a session created with profiling enabled.

      Args:
          onnx_model_path (str): path of the ONNX model
          use_gpu (bool): forwarded to create_onnxruntime_session
          provider (str): execution provider, forwarded to create_onnxruntime_session
          basic_optimization (bool): when True, the session is created with enable_all_optimization=False
          thread_num (int): forwarded to create_onnxruntime_session as num_threads
          all_inputs (List[Dict]): one input dictionary per inference run

      Returns:
          The value returned by session.end_profiling() (presumably the path of the profile file).
      """
      from benchmark_helper import create_onnxruntime_session  # noqa: PLC0415

      session_settings = {
          "enable_all_optimization": not basic_optimization,
          "num_threads": thread_num,
          "enable_profiling": True,
      }
      session = create_onnxruntime_session(onnx_model_path, use_gpu, provider, **session_settings)

      # The outputs are discarded: the runs only serve to populate the profile.
      for feed in all_inputs:
          session.run(None, feed)

      return session.end_profiling()
  def get_dim_from_type_proto(dim):
      """Return the content of the "value" oneof of a dimension message.

      Args:
          dim: object exposing WhichOneof("value") and the field that call names
              (presumably an onnx TensorShapeProto.Dimension).

      Returns:
          The content of the field that is set, or None when WhichOneof reports that no field is set.
      """
      # Query the oneof a single time and reuse the answer for both the check and the lookup.
      field_name = dim.WhichOneof("value")
      return getattr(dim, field_name) if isinstance(field_name, str) else None
  def get_shape_from_type_proto(type_proto):
      """Return the tensor shape of a type proto as a list with one entry per dimension.

      Each entry is whatever get_dim_from_type_proto returns for that dimension.
      """
      dimensions = type_proto.tensor_type.shape.dim
      return list(map(get_dim_from_type_proto, dimensions))
  def create_dummy_inputs(onnx_model, batch_size, sequence_length, samples):
      """Create dummy inputs filled with ones for an ONNX model.

      A graph input may have at most two symbolic dimensions: the first one is replaced by
      batch_size and the second one by sequence_length.

      Args:
          onnx_model (OnnxModel): ONNX model
          batch_size (int): batch size
          sequence_length (int): sequence length
          samples (int): number of samples

      Raises:
          RuntimeError: a graph input has an element type other than float, int32 or int64.

      Returns:
          List[Dict]: list of inputs; all `samples` entries refer to the same dictionary.
          None is returned when a graph input has more than two symbolic dimensions.
      """
      numpy_types = {
          TensorProto.FLOAT: numpy.float32,
          TensorProto.INT32: numpy.int32,
          TensorProto.INT64: numpy.int64,
      }

      dummy_inputs = {}
      for graph_input in onnx_model.get_graph_inputs_excluding_initializers():
          shape = get_shape_from_type_proto(graph_input.type)
          symbol_dims = [i for i, dim in enumerate(shape) if isinstance(dim, str)]

          # allowed symbolic dimensions: batch_size and sequence_length
          if len(symbol_dims) > 2:
              return None

          # zip stops at the shorter sequence, so only the symbolic dimensions present are replaced.
          for position, size in zip(symbol_dims, (batch_size, sequence_length)):
              shape[position] = size

          elem_type = graph_input.type.tensor_type.elem_type
          # Explicit check instead of assert: assert statements are removed under "python -O".
          if elem_type not in numpy_types:
              raise RuntimeError(f"element type is not supported: {elem_type}")

          dummy_inputs[graph_input.name] = numpy.ones(shape, dtype=numpy_types[elem_type])

      # Every sample reuses the same dictionary; the arrays are not copied.
      all_inputs = [dummy_inputs for _ in range(samples)]
      return all_inputs
  def create_bert_inputs(
      onnx_model,
      batch_size,
      sequence_length,
      samples,
      input_ids_name=None,
      segment_ids_name=None,
      input_mask_name=None,
  ):
      """Create test inputs for a BERT model.

      The three graph inputs are located with find_bert_inputs and the data is produced by
      generate_test_data with a fixed seed (123) and random_mask_length disabled.

      Args:
          onnx_model (OnnxModel): ONNX model
          batch_size (int): batch size
          sequence_length (int): sequence length
          samples (int): number of samples
          input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
          segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
          input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.

      Returns:
          List[Dict]: list of inputs
      """
      from bert_test_data import find_bert_inputs, generate_test_data  # noqa: PLC0415

      bert_graph_inputs = find_bert_inputs(onnx_model, input_ids_name, segment_ids_name, input_mask_name)
      input_ids, segment_ids, input_mask = bert_graph_inputs

      generator_settings = {
          "test_cases": samples,
          "seed": 123,
          "verbose": False,
          "input_ids": input_ids,
          "segment_ids": segment_ids,
          "input_mask": input_mask,
          "random_mask_length": False,
      }
      return generate_test_data(batch_size, sequence_length, **generator_settings)
  def create_gpt2_inputs(onnx_model, batch_size, sequence_length, past_sequence_length, samples):
      """Create dummy inputs filled with ones for a GPT-2 model.

      Args:
          onnx_model (OnnxModel): ONNX model
          batch_size (int): batch size
          sequence_length (int): sequence length
          past_sequence_length (int): past sequence length
          samples (int): number of samples

      Raises:
          RuntimeError: a symbolic dimension is not supported (use the tool convert_to_onnx.py to
              export the ONNX model instead), or a graph input has an element type other than
              float, int32 or int64.

      Returns:
          List[Dict]: list of inputs; all `samples` entries refer to the same dictionary.
      """
      # The symbolic names shall be same as those used in Gpt2Helper.export_onnx(...) function.
      symbols = {
          "batch_size": batch_size,
          "seq_len": sequence_length,
          "past_seq_len": past_sequence_length,
          "total_seq_len": sequence_length + past_sequence_length,
      }
      numpy_types = {
          TensorProto.FLOAT: numpy.float32,
          TensorProto.INT32: numpy.int32,
          TensorProto.INT64: numpy.int64,
      }

      dummy_inputs = {}
      for graph_input in onnx_model.get_graph_inputs_excluding_initializers():
          shape = get_shape_from_type_proto(graph_input.type)
          for i, dim in enumerate(shape):
              if not isinstance(dim, str):
                  continue
              if dim not in symbols:
                  raise RuntimeError(f"symbol is not supported: {dim}")
              shape[i] = symbols[dim]

          elem_type = graph_input.type.tensor_type.elem_type
          # Explicit check instead of assert: assert statements are removed under "python -O".
          if elem_type not in numpy_types:
              raise RuntimeError(f"element type is not supported: {elem_type}")

          dummy_inputs[graph_input.name] = numpy.ones(shape, dtype=numpy_types[elem_type])

      # Every sample reuses the same dictionary; the arrays are not copied.
      all_inputs = [dummy_inputs for _ in range(samples)]
      return all_inputs
  def create_longformer_inputs(onnx_model, batch_size, sequence_length, global_length, samples):
      """Create dummy inputs for a Longformer model.

      Graph inputs whose name contains "global" are zero-filled with ones in the first
      global_length positions of the second axis; all other inputs are filled with ones.

      Args:
          onnx_model (OnnxModel): ONNX model
          batch_size (int): batch size
          sequence_length (int): sequence length
          global_length (int): number of global tokens
          samples (int): number of samples

      Raises:
          RuntimeError: a symbolic dimension is not supported (use the tool
              convert_longformer_to_onnx.py to export the ONNX model instead), or a graph input
              has an element type other than float, int32 or int64.

      Returns:
          List[Dict]: list of inputs; all `samples` entries refer to the same dictionary.
      """
      symbols = {"batch_size": batch_size, "sequence_length": sequence_length}
      numpy_types = {
          TensorProto.FLOAT: numpy.float32,
          TensorProto.INT32: numpy.int32,
          TensorProto.INT64: numpy.int64,
      }

      dummy_inputs = {}
      for graph_input in onnx_model.get_graph_inputs_excluding_initializers():
          shape = get_shape_from_type_proto(graph_input.type)
          for i, dim in enumerate(shape):
              if not isinstance(dim, str):
                  continue
              if dim not in symbols:
                  raise RuntimeError(f"symbol is not supported: {dim}")
              shape[i] = symbols[dim]

          elem_type = graph_input.type.tensor_type.elem_type
          # Explicit check instead of assert: assert statements are removed under "python -O".
          if elem_type not in numpy_types:
              raise RuntimeError(f"element type is not supported: {elem_type}")
          data_type = numpy_types[elem_type]

          if "global" in graph_input.name:
              # The two-axis indexing assumes this input has at least two dimensions.
              data = numpy.zeros(shape, dtype=data_type)
              data[:, :global_length] = 1
          else:
              data = numpy.ones(shape, dtype=data_type)
          dummy_inputs[graph_input.name] = data

      # Every sample reuses the same dictionary; the arrays are not copied.
      all_inputs = [dummy_inputs for _ in range(samples)]
      return all_inputs
  def run(args):
      """Load the model, build its inputs and run it with profiling enabled.

      Args:
          args (argparse.Namespace): options as produced by parse_arguments; args.model must be set.

      Raises:
          RuntimeError: no inputs could be created for the model.

      Returns:
          The value returned by run_profile.
      """
      num_threads = args.thread_num if args.thread_num > 0 else psutil.cpu_count(logical=False)

      # Set OMP environment variable before importing onnxruntime. Needed for cpu only, and no impact for onnxruntime-gpu package.
      # psutil.cpu_count can return None when the count is undetermined; leave the variable
      # unset in that case rather than exporting the string "None".
      if num_threads and "OMP_NUM_THREADS" not in os.environ:
          os.environ["OMP_NUM_THREADS"] = str(num_threads)

      from onnx import load  # noqa: PLC0415
      from onnx_model import OnnxModel  # noqa: PLC0415

      onnx_model = OnnxModel(load(args.model))

      all_inputs = None
      if args.dummy_inputs == "bert":
          all_inputs = create_bert_inputs(
              onnx_model,
              args.batch_size,
              args.sequence_length,
              args.samples,
              args.input_ids_name,
              args.segment_ids_name,
              args.input_mask_name,
          )
      elif args.dummy_inputs == "gpt2":
          all_inputs = create_gpt2_inputs(
              onnx_model,
              args.batch_size,
              args.sequence_length,
              args.past_sequence_length,
              args.samples,
          )
      elif args.dummy_inputs == "longformer":
          all_inputs = create_longformer_inputs(
              onnx_model,
              args.batch_size,
              args.sequence_length,
              args.global_length,
              args.samples,
          )
      else:  # default
          all_inputs = create_dummy_inputs(onnx_model, args.batch_size, args.sequence_length, args.samples)

      # create_dummy_inputs returns None for a graph input with more than two symbolic dimensions.
      # Fail here with a clear message instead of a TypeError when run_profile iterates over None.
      if all_inputs is None:
          raise RuntimeError(
              f"failed to create inputs with --dummy_inputs {args.dummy_inputs}: "
              "the default mode supports at most two symbolic dimensions per graph input"
          )

      # NOTE(review): the raw --thread_num value (possibly -1) is forwarded, not num_threads;
      # presumably create_onnxruntime_session treats non-positive values as "use default" - confirm.
      return run_profile(
          args.model,
          args.use_gpu,
          args.provider,
          args.basic_optimization,
          args.thread_num,
          all_inputs,
      )
  if __name__ == "__main__":
      # Script entry point: profile a model (--model) or read an existing profile (--input),
      # then print the processed results line by line.
      arguments = parse_arguments()
      print("Arguments", arguments)

      from benchmark_helper import setup_logger

      setup_logger(arguments.verbose)

      if arguments.input:
          # An existing profile result file was given: skip running the model.
          profile_file = arguments.input
      else:
          assert arguments.model, "requires either --model to run profiling or --input to read profiling results"
          profile_file = run(arguments)

      from profile_result_processor import process_results

      results = process_results(profile_file, arguments)
      for line in results:
          print(line)