# benchmark.py (26 KB)
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import argparse
  7. import datetime
  8. import gc
  9. import itertools
  10. import logging
  11. import os
  12. import sys
  13. import time
  14. import numpy as np
  15. import onnx
  16. import psutil
  17. import torch
  18. from benchmark_helper import measure_memory, setup_logger
  19. from dist_settings import get_rank, get_size
  20. from llama_inputs import (
  21. add_io_bindings_as_ortvalues,
  22. get_merged_sample_with_past_kv_inputs,
  23. get_msft_sample_inputs,
  24. get_sample_inputs,
  25. get_sample_with_past_kv_inputs,
  26. verify_ort_inputs,
  27. )
  28. from optimum.onnxruntime import ORTModelForCausalLM
  29. from torch.profiler import ProfilerActivity, profile, record_function
  30. from tqdm import trange
  31. from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
  32. import onnxruntime as ort
# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
  34. # For determining whether the ONNX model can do both prompt generation and token generation or only one of the two
  35. def get_ort_model_inputs_len(args, model):
  36. if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
  37. return 0
  38. if args.benchmark_type == "hf-ort":
  39. try:
  40. # New Optimum export (https://github.com/huggingface/optimum/blob/888332364c2e0091da1fc974737c7e277af168bf/optimum/onnxruntime/modeling_ort.py#L268)
  41. return len(model.inputs_names)
  42. except Exception:
  43. # Old Optimum export (https://github.com/huggingface/optimum/blob/c5ad7f971cb0a494eac03dc0909f146725f999c5/optimum/onnxruntime/base.py#L54)
  44. return len(model.decoder.input_names)
  45. return len(model.get_inputs())
  46. def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
  47. init_inputs, iter_inputs = None, None
  48. # For past_present_share_buffer:
  49. # Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported
  50. # Set max_seq_len to config value for other models
  51. max_seq_len = 2048 if args.benchmark_type == "ort-msft" else args.config.max_position_embeddings
  52. if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
  53. init_inputs = get_sample_inputs(
  54. args.config,
  55. args.target_device,
  56. args.batch_size,
  57. args.sequence_length,
  58. return_dict=True,
  59. )
  60. iter_inputs = get_sample_with_past_kv_inputs(
  61. args.config,
  62. args.target_device,
  63. args.batch_size,
  64. args.sequence_length,
  65. use_fp16=args.use_fp16,
  66. return_dict=True,
  67. )
  68. elif args.benchmark_type in {"hf-ort"}:
  69. if ort_model_inputs_len == 3: # [input_ids, attention_mask, position_ids]
  70. # Using split models in Optimum (e.g. created by Optimum export)
  71. init_inputs = get_sample_inputs(
  72. args.config,
  73. args.target_device,
  74. args.batch_size,
  75. args.sequence_length,
  76. return_dict=True,
  77. )
  78. iter_inputs = get_sample_with_past_kv_inputs(
  79. args.config,
  80. args.target_device,
  81. args.batch_size,
  82. args.sequence_length,
  83. use_fp16=args.use_fp16,
  84. return_dict=True,
  85. )
  86. else:
  87. # Using merged model in Optimum (e.g. created by convert_to_onnx export)
  88. init_inputs = get_merged_sample_with_past_kv_inputs(
  89. args.config,
  90. args.target_device,
  91. args.batch_size,
  92. seq_len=args.sequence_length,
  93. past_seq_len=0,
  94. max_seq_len=max_seq_len,
  95. use_fp16=args.use_fp16,
  96. use_buffer_share=args.use_buffer_share,
  97. engine="pt",
  98. return_dict=True,
  99. )
  100. iter_inputs = get_merged_sample_with_past_kv_inputs(
  101. args.config,
  102. args.target_device,
  103. args.batch_size,
  104. seq_len=1,
  105. past_seq_len=args.sequence_length,
  106. max_seq_len=max_seq_len,
  107. use_fp16=args.use_fp16,
  108. use_buffer_share=args.use_buffer_share,
  109. engine="pt",
  110. return_dict=True,
  111. )
  112. elif args.benchmark_type == "ort-convert-to-onnx":
  113. # Microsoft export from convert_to_onnx
  114. init_inputs = get_merged_sample_with_past_kv_inputs(
  115. args.config,
  116. args.target_device,
  117. args.batch_size,
  118. seq_len=args.sequence_length,
  119. past_seq_len=0,
  120. max_seq_len=max_seq_len,
  121. use_fp16=args.use_fp16,
  122. use_buffer_share=args.use_buffer_share,
  123. engine="ort",
  124. return_dict=True,
  125. world_size=args.world_size,
  126. )
  127. iter_inputs = get_merged_sample_with_past_kv_inputs(
  128. args.config,
  129. args.target_device,
  130. args.batch_size,
  131. seq_len=1,
  132. past_seq_len=args.sequence_length,
  133. max_seq_len=max_seq_len,
  134. use_fp16=args.use_fp16,
  135. use_buffer_share=args.use_buffer_share,
  136. engine="ort",
  137. return_dict=True,
  138. world_size=args.world_size,
  139. )
  140. elif args.benchmark_type == "ort-msft":
  141. # Microsoft export from https://github.com/microsoft/Llama-2-Onnx
  142. split_kv = ort_model_inputs_len > 5 # original inputs: [x, attn_mask, k_cache, v_cache, pos]
  143. init_inputs = get_msft_sample_inputs(
  144. args.config,
  145. args.batch_size,
  146. past_seq_len=0,
  147. seq_len=args.sequence_length,
  148. max_seq_len=max_seq_len,
  149. use_fp16=args.use_fp16,
  150. use_buffer_share=args.use_buffer_share,
  151. split_kv=split_kv,
  152. )
  153. iter_inputs = get_msft_sample_inputs(
  154. args.config,
  155. args.batch_size,
  156. past_seq_len=args.sequence_length,
  157. seq_len=1,
  158. max_seq_len=max_seq_len,
  159. use_fp16=args.use_fp16,
  160. use_buffer_share=args.use_buffer_share,
  161. split_kv=split_kv,
  162. )
  163. else:
  164. raise Exception("Unable to auto-detect inputs for provided model")
  165. return init_inputs, iter_inputs
  166. def get_model(args: argparse.Namespace):
  167. model, sess_options = None, None
  168. start_time, end_time = None, None
  169. # There are multiple sources that the model could come from:
  170. # 1) Benchmark LLaMA-2 from unofficial source on Hugging Face
  171. # 2) Benchmark LLaMA-2 from official source on Hugging Face, which requires an authentication token
  172. # 3) Benchmark LLaMA-2 from local download of model
  173. # 4) Benchmark LLaMA-2 from Microsoft (already optimized, available at https://github.com/microsoft/Llama-2-Onnx)
  174. # 5) Benchmark LLaMA-2 from convert_to_onnx
  175. if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
  176. source = args.hf_pt_dir_path if args.hf_pt_dir_path else args.model_name
  177. start_time = time.time()
  178. model = AutoModelForCausalLM.from_pretrained(
  179. source,
  180. torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
  181. use_auth_token=args.auth,
  182. trust_remote_code=args.auth,
  183. use_cache=True,
  184. cache_dir=args.cache_dir,
  185. ).to(args.target_device)
  186. end_time = time.time()
  187. if args.benchmark_type == "hf-pt-compile":
  188. model = torch.compile(model)
  189. elif args.benchmark_type in {"hf-ort", "ort-msft", "ort-convert-to-onnx"}:
  190. sess_options = ort.SessionOptions()
  191. sess_options.enable_profiling = args.profile
  192. if args.verbose:
  193. sess_options.log_verbosity_level = 1
  194. sess_options.log_severity_level = 1
  195. else:
  196. raise Exception(f"Cannot recognize {args.benchmark_type}")
  197. if args.benchmark_type == "hf-ort":
  198. # Optimum export or convert_to_onnx.py export
  199. provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
  200. provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
  201. decoder_file_name = None
  202. decoder_with_past_file_name = None
  203. for filename in os.listdir(args.hf_ort_dir_path):
  204. if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename:
  205. continue
  206. if "decoder_model" in filename or filename == "model.onnx":
  207. decoder_file_name = filename
  208. if "decoder_with_past_model" in filename:
  209. decoder_with_past_file_name = filename
  210. if "decoder_merged_model" in filename:
  211. decoder_file_name = filename
  212. decoder_with_past_file_name = filename
  213. start_time = time.time()
  214. model = ORTModelForCausalLM.from_pretrained(
  215. args.hf_ort_dir_path,
  216. decoder_file_name=decoder_file_name,
  217. decoder_with_past_file_name=decoder_with_past_file_name,
  218. use_auth_token=args.auth,
  219. trust_remote_code=args.auth,
  220. use_io_binding=True, # Large perf gain even for cpu due to avoiding output copy.
  221. use_merged=(True if decoder_file_name == "model.onnx" else None),
  222. provider=provider,
  223. provider_options=provider_options,
  224. session_options=sess_options,
  225. )
  226. end_time = time.time()
  227. if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
  228. # Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx
  229. logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}")
  230. start_time = time.time()
  231. model = ort.InferenceSession(
  232. args.ort_model_path.format(args.rank),
  233. sess_options,
  234. providers=[args.execution_provider],
  235. )
  236. end_time = time.time()
  237. logger.info(f"Loaded model in {end_time - start_time} s")
  238. return model
  239. def time_fn(args, fn, inputs):
  240. # Warm up
  241. warmup_range = (
  242. range(args.warmup_runs)
  243. if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
  244. else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
  245. )
  246. if args.verbose:
  247. outputs = fn(inputs)
  248. logger.info(outputs)
  249. input_sync = lambda *kwargs: ( # noqa: E731
  250. args.io_binding.synchronize_inputs()
  251. if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize
  252. else lambda *kwargs: (
  253. torch.cuda.synchronize()
  254. if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize
  255. else lambda *kwargs: None
  256. )
  257. ) # no-op function
  258. output_sync = lambda *kwargs: ( # noqa: E731
  259. args.io_binding.synchronize_outputs()
  260. if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize
  261. else lambda *kwargs: (
  262. torch.cuda.synchronize()
  263. if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize
  264. else lambda *kwargs: None
  265. )
  266. ) # no-op function
  267. for _ in warmup_range:
  268. input_sync()
  269. fn(inputs)
  270. output_sync()
  271. # Benchmark
  272. total_time = 0
  273. bench_range = (
  274. range(args.num_runs)
  275. if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
  276. else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
  277. )
  278. for _ in bench_range:
  279. input_sync()
  280. start_time = time.time()
  281. fn(inputs)
  282. output_sync()
  283. end_time = time.time()
  284. total_time += end_time - start_time
  285. # Newline print after trange in order to print metrics on new lines without progress bar on same line
  286. if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}:
  287. logger.info("")
  288. latency = total_time / args.num_runs
  289. throughput = args.batch_size / latency
  290. if args.rank == 0:
  291. logger.info(f"Batch Size: {args.batch_size}")
  292. logger.info(f"Sequence Length: {args.sequence_length}")
  293. logger.info(f"Latency: {latency} s")
  294. logger.info(f"Throughput: {throughput} tps")
  295. return
  296. def profile_fn(args, fn, inputs, inputs_type):
  297. # Filename prefix format:
  298. # "b<batch-size>_s<sequence-length>_<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
  299. prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
  300. filename = None
  301. if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
  302. # Profile PyTorch kernels
  303. with profile( # noqa: SIM117
  304. activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
  305. ) as prof:
  306. with record_function("model_inference"):
  307. fn(inputs)
  308. prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
  309. filename = os.path.join(args.log_folder, f"{prefix}.log")
  310. with open(filename, "w") as f:
  311. f.write(prof_data)
  312. else:
  313. # Profile ORT kernels
  314. fn(inputs)
  315. # Set new log name for ORT profile log generated
  316. filename = f"{prefix}.json"
  317. return filename
  318. def measure_fn(args, fn, inputs):
  319. # Measure CPU usage
  320. pid = os.getpid()
  321. process = psutil.Process(pid)
  322. process.cpu_percent(interval=0.1)
  323. fn(inputs)
  324. if args.rank == 0:
  325. logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%")
  326. # Measure memory usage
  327. gc.collect()
  328. torch.cuda.empty_cache()
  329. measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs))
  330. # Flush output so memory usage is printed
  331. sys.stdout.flush()
  332. def run_hf_inference(args, init_inputs, iter_inputs, model):
  333. # Inference steps to measure
  334. def get_logits(inputs):
  335. # Inference pass without decoding
  336. outputs = model(**inputs)
  337. return outputs
  338. # Examples of other inference steps that can be measured:
  339. # To use, uncomment the function and assign it to `generate_fn`
  340. # def get_pred_ids(inputs):
  341. # # Inference pass with predicted token ids generation
  342. # predicted_ids = model.generate(**inputs)
  343. # return predicted_ids
  344. # def gen_and_dec(inputs):
  345. # # Inference pass with generation and decoding
  346. # predicted_ids = get_pred_ids(inputs)
  347. # transcription = []
  348. # for bs in range(args.batch_size):
  349. # for rs in range(args.num_return_sequences):
  350. # transcription.append(
  351. # args.tokenizer.batch_decode(
  352. # predicted_ids[bs * args.num_return_sequences + rs], skip_special_tokens=True
  353. # )[0]
  354. # )
  355. # return transcription
  356. generate_fn = get_logits
  357. if args.benchmark_type == "hf-pt-compile":
  358. # Run forward pass once with each set of inputs to process through Dynamo
  359. generate_fn(init_inputs)
  360. generate_fn(iter_inputs)
  361. if args.profile:
  362. new_logname = profile_fn(args, generate_fn, init_inputs, "prompt")
  363. if args.benchmark_type == "hf-ort":
  364. # Turn profiling off to stop appending to log
  365. old_logname = model.decoder.session.end_profiling()
  366. logger.warning(f"Renaming {old_logname} to {new_logname}")
  367. os.rename(old_logname, os.path.join(args.log_folder, new_logname))
  368. new_logname = profile_fn(args, generate_fn, iter_inputs, "token")
  369. if args.benchmark_type == "hf-ort":
  370. # Turn profiling off to stop appending to log
  371. old_logname = model.decoder_with_past.session.end_profiling()
  372. logger.warning(f"Renaming {old_logname} to {new_logname}")
  373. os.rename(old_logname, os.path.join(args.log_folder, new_logname))
  374. return
  375. # PyTorch evaluations
  376. logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
  377. time_fn(args, generate_fn, init_inputs)
  378. measure_fn(args, generate_fn, init_inputs)
  379. logger.info("\nEvaluating `model(inputs)` step with past_key_values")
  380. time_fn(args, generate_fn, iter_inputs)
  381. measure_fn(args, generate_fn, iter_inputs)
  382. def run_ort_inference(args, init_inputs, iter_inputs, model):
  383. def prepare_ort_inputs(inputs, kv_cache_ortvalues):
  384. # Verify model inputs
  385. inputs = verify_ort_inputs(model, inputs)
  386. # Add IO bindings for non-CPU execution providers
  387. if args.device != "cpu":
  388. io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
  389. model, inputs, args.device, int(args.rank), args.use_buffer_share, kv_cache_ortvalues
  390. )
  391. setattr(args, "io_binding", io_binding) # noqa: B010
  392. return io_binding, kv_cache_ortvalues
  393. return inputs, kv_cache_ortvalues
  394. def with_io_binding(io_binding):
  395. # Inference pass with IO binding
  396. model.run_with_iobinding(io_binding)
  397. def without_io_binding(inputs):
  398. # Inference pass without IO binding
  399. outputs = model.run(None, inputs)
  400. return outputs
  401. generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
  402. kv_cache_ortvalues = {}
  403. if args.profile:
  404. ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
  405. new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")
  406. # Turn profiling off to stop appending to log file
  407. old_logname = model.end_profiling()
  408. logger.warning(f"Renaming {old_logname} to {new_logname}")
  409. os.rename(old_logname, os.path.join(args.log_folder, new_logname))
  410. # Re-initialize model for new log file instead of appending to old log file
  411. model = get_model(args)
  412. ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
  413. new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token")
  414. # Turn profiling off to stop appending to log
  415. old_logname = model.end_profiling()
  416. logger.warning(f"Renaming {old_logname} to {new_logname}")
  417. os.rename(old_logname, os.path.join(args.log_folder, new_logname))
  418. return
  419. # ORT evaluations
  420. logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
  421. ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
  422. time_fn(args, generate_fn, ort_init_inputs)
  423. measure_fn(args, generate_fn, ort_init_inputs)
  424. logger.info("\nEvaluating `model(inputs)` step with past_key_values")
  425. ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
  426. time_fn(args, generate_fn, ort_iter_inputs)
  427. measure_fn(args, generate_fn, ort_iter_inputs)
  428. def run_inference(args, init_inputs, iter_inputs, model):
  429. if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
  430. run_hf_inference(args, init_inputs, iter_inputs, model)
  431. elif args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
  432. run_ort_inference(args, init_inputs, iter_inputs, model)
  433. else:
  434. raise Exception(f"Cannot recognize {args.benchmark_type}")
  435. def get_args(rank=0):
  436. parser = argparse.ArgumentParser()
  437. parser.add_argument(
  438. "-bt",
  439. "--benchmark-type",
  440. type=str,
  441. required=True,
  442. choices=[
  443. "hf-pt-eager",
  444. "hf-pt-compile",
  445. "hf-ort",
  446. "ort-msft",
  447. "ort-convert-to-onnx",
  448. ],
  449. )
  450. parser.add_argument(
  451. "-m",
  452. "--model-name",
  453. type=str,
  454. required=True,
  455. help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
  456. )
  457. parser.add_argument(
  458. "-a", "--auth", default=False, action="store_true", help="Use Hugging Face authentication token to access model"
  459. )
  460. # Args for choosing the model
  461. parser.add_argument(
  462. "-p",
  463. "--precision",
  464. required=True,
  465. type=str,
  466. default="fp32",
  467. choices=["int4", "int8", "fp16", "fp32"],
  468. help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
  469. )
  470. parser.add_argument(
  471. "--hf-pt-dir-path",
  472. type=str,
  473. default="",
  474. help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
  475. )
  476. parser.add_argument(
  477. "--hf-ort-dir-path",
  478. type=str,
  479. default="",
  480. help="Path to directory containing all ONNX files (e.g. tokenizer, decoder_merged, decoder, decoder_with_past)",
  481. )
  482. parser.add_argument(
  483. "--ort-model-path",
  484. type=str,
  485. default="",
  486. help="Path to ONNX model",
  487. )
  488. # Args for running and evaluating the model
  489. parser.add_argument(
  490. "-b",
  491. "--batch-sizes",
  492. default="1 2",
  493. )
  494. parser.add_argument(
  495. "-s",
  496. "--sequence-lengths",
  497. default="32 64 128 256 512",
  498. )
  499. parser.add_argument(
  500. "-d",
  501. "--device",
  502. type=str,
  503. default="cuda" if torch.cuda.is_available() else "cpu",
  504. choices=["cpu", "cuda"],
  505. )
  506. parser.add_argument("-id", "--device-id", type=int, default=0)
  507. parser.add_argument("-w", "--warmup-runs", type=int, default=5)
  508. parser.add_argument("-n", "--num-runs", type=int, default=10)
  509. parser.add_argument("--seed", type=int, default=2)
  510. # Args for decoding logic
  511. parser.add_argument("--max-length", type=int, default=32)
  512. parser.add_argument("--num-return-sequences", type=int, default=1)
  513. # Args for accessing detailed info
  514. parser.add_argument("--profile", default=False, action="store_true")
  515. parser.add_argument(
  516. "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
  517. )
  518. parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
  519. parser.add_argument("--verbose", default=False, action="store_true")
  520. parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
  521. parser.add_argument(
  522. "--cache-dir",
  523. type=str,
  524. required=True,
  525. default="./model_cache",
  526. help="Cache dir where Hugging Face files are stored",
  527. )
  528. args = parser.parse_args()
  529. # Set seed properties
  530. np.random.seed(args.seed)
  531. torch.manual_seed(args.seed)
  532. # Set runtime properties
  533. if "ort" in args.benchmark_type:
  534. setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010
  535. if args.execution_provider == "CUDAExecutionProvider":
  536. args.execution_provider = (args.execution_provider, {"device_id": rank})
  537. # Check that paths have been specified for any benchmarking with ORT
  538. if args.benchmark_type == "hf-ort":
  539. assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
  540. if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
  541. assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
  542. args.batch_sizes = args.batch_sizes.split(" ")
  543. args.sequence_lengths = args.sequence_lengths.split(" ")
  544. # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
  545. args.precision = (
  546. "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
  547. )
  548. # Check that only one (batch_size, sequence_length) combination is set for profiling
  549. if args.profile:
  550. assert len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1, (
  551. "Please provide only one (batch_size, sequence_length) combination for profiling"
  552. )
  553. return args
  554. def main():
  555. rank = get_rank()
  556. world_size = get_size()
  557. args = get_args(rank)
  558. setup_logger(args.verbose)
  559. logger.info(args.__dict__)
  560. torch.backends.cudnn.benchmark = True
  561. args.rank = rank
  562. args.world_size = world_size
  563. tokenizer = AutoTokenizer.from_pretrained(
  564. args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
  565. )
  566. config = AutoConfig.from_pretrained(
  567. args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
  568. )
  569. target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device
  570. use_fp16 = args.precision == "fp16"
  571. setattr(args, "tokenizer", tokenizer) # noqa: B010
  572. setattr(args, "config", config) # noqa: B010
  573. setattr(args, "target_device", target_device) # noqa: B010
  574. setattr(args, "use_fp16", use_fp16) # noqa: B010
  575. # Get model and model info
  576. model = get_model(args)
  577. ort_model_inputs_len = get_ort_model_inputs_len(args, model)
  578. # Check if past_present_share_buffer can be enabled (only for FP16 models with GQA)
  579. if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}:
  580. onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False)
  581. gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node))
  582. use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu"
  583. setattr(args, "use_buffer_share", use_buffer_share) # noqa: B010
  584. else:
  585. setattr(args, "use_buffer_share", False) # noqa: B010
  586. # Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
  587. for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths):
  588. if args.rank == 0:
  589. logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
  590. setattr(args, "batch_size", int(batch_size)) # noqa: B010
  591. setattr(args, "sequence_length", int(sequence_length)) # noqa: B010
  592. init_inputs, iter_inputs = get_inputs(args, ort_model_inputs_len)
  593. run_inference(args, init_inputs, iter_inputs, model)
# Run the benchmark only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()