benchmark.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import argparse
  7. import ast
  8. import datetime
  9. import gc
  10. import logging
  11. import os
  12. import sys
  13. import time
  14. import numpy as np
  15. import psutil
  16. import torch
  17. import whisper
  18. from benchmark_helper import measure_memory, setup_logger
  19. from onnxruntime_extensions import get_library_path
  20. from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
  21. from torch.profiler import ProfilerActivity, profile, record_function
  22. from tqdm import trange
  23. from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor
  24. import onnxruntime as ort
  25. logger = logging.getLogger(__name__)
  26. def get_inputs(args: argparse.Namespace):
  27. if args.benchmark_type not in {"hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"}:
  28. raise Exception("Unable to auto-detect inputs for provided model")
  29. def load_via_ffmpeg():
  30. audio = whisper.load_audio(args.audio_path)
  31. audio = whisper.pad_or_trim(audio)
  32. return audio
  33. def load_via_numpy():
  34. with open(args.audio_path, "rb") as f:
  35. audio = np.asarray(list(f.read()), dtype=np.uint8)
  36. audio = np.array([audio])
  37. return audio
  38. inputs = {
  39. "max_length": args.max_length,
  40. "min_length": args.min_length,
  41. "num_beams": args.num_beams,
  42. "num_return_sequences": args.num_return_sequences,
  43. "length_penalty": args.length_penalty,
  44. "repetition_penalty": args.repetition_penalty,
  45. }
  46. if args.benchmark_type == "ort":
  47. # convert_to_onnx export or ONNX E2E solution created by Olive
  48. for k, v in inputs.items():
  49. inputs[k] = np.array([v], dtype=np.float32 if "penalty" in k else np.int32)
  50. if args.has_decoder_input_ids:
  51. inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32)
  52. if args.has_logits_processor:
  53. inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32)
  54. if args.has_temperature:
  55. inputs["temperature"] = np.array([args.temperature], dtype=np.float32)
  56. # Measure time taken to load audio file
  57. logger.info(f"Load audio: {args.audio_path}")
  58. load_audio_fn = lambda onnx_e2e: load_via_numpy() if onnx_e2e else load_via_ffmpeg() # noqa: E731
  59. time_fn(args, load_audio_fn, args.has_audio_stream)
  60. audio_data = load_audio_fn(args.has_audio_stream)
  61. if args.has_audio_stream:
  62. # ONNX E2E solution created by Olive
  63. inputs["audio_stream"] = audio_data
  64. return inputs
  65. # Measure time taken to get input features
  66. logger.info("Feature extraction: ")
  67. return_type = "np" if args.benchmark_type == "ort" else "pt"
  68. processor_fn = lambda audio: args.processor.feature_extractor( # noqa: E731
  69. [audio], return_tensors=return_type, sampling_rate=args.sampling_rate
  70. ).input_features
  71. time_fn(args, processor_fn, audio_data)
  72. input_features = processor_fn(audio_data)
  73. if args.benchmark_type == "ort":
  74. # convert_to_onnx export
  75. inputs["input_features"] = input_features
  76. return inputs
  77. inputs["inputs"] = input_features.to(
  78. dtype=torch.float16 if args.use_fp16 else torch.float32, device=args.target_device
  79. )
  80. inputs["no_repeat_ngram_size"] = args.no_repeat_ngram_size
  81. inputs["early_stopping"] = True
  82. inputs["use_cache"] = True
  83. if args.decoder_input_ids:
  84. inputs["forced_decoder_ids"] = args.decoder_input_ids
  85. return inputs
  86. def get_model(args: argparse.Namespace):
  87. model, sess_options = None, None
  88. start_time, end_time = None, None
  89. # There are multiple sources that the model could come from:
  90. # 1) Benchmark Whisper from Hugging Face
  91. # 2) Benchmark Whisper ONNX model from Optimum export (without pre/post processing)
  92. # 3) Benchmark Whisper ONNX E2E model from Olive (with pre/post processing)
  93. if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
  94. source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name
  95. start_time = time.time()
  96. model = AutoModelForSpeechSeq2Seq.from_pretrained(
  97. source,
  98. torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
  99. use_cache=True,
  100. ).to(args.target_device)
  101. end_time = time.time()
  102. if args.benchmark_type == "hf-pt-compile":
  103. model = torch.compile(model)
  104. elif args.benchmark_type in {"hf-ort", "ort"}:
  105. sess_options = ort.SessionOptions()
  106. sess_options.enable_profiling = args.profile
  107. sess_options.register_custom_ops_library(get_library_path())
  108. if args.verbose:
  109. sess_options.log_verbosity_level = 1
  110. sess_options.log_severity_level = 1
  111. else:
  112. raise Exception(f"Cannot recognize {args.benchmark_type}")
  113. if args.benchmark_type == "hf-ort":
  114. # Optimum export
  115. provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
  116. provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
  117. start_time = time.time()
  118. model = ORTModelForSpeechSeq2Seq.from_pretrained(
  119. args.hf_ort_dir_path,
  120. provider=provider,
  121. provider_options=provider_options,
  122. session_options=sess_options,
  123. use_io_binding=True, # Avoid memory copy overhead
  124. )
  125. end_time = time.time()
  126. if args.benchmark_type == "ort":
  127. # convert_to_onnx.py export
  128. logger.info(f"Loading model from {args.ort_model_path}")
  129. start_time = time.time()
  130. model = ort.InferenceSession(
  131. args.ort_model_path,
  132. sess_options,
  133. providers=[args.execution_provider],
  134. )
  135. end_time = time.time()
  136. logger.info(f"Loaded model in {end_time - start_time} s")
  137. return model
  138. def time_fn(args, fn, inputs):
  139. warmup_inputs = inputs[0] if type(inputs) is tuple else inputs
  140. benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs
  141. torch_device = torch.device(args.target_device)
  142. # Warm up
  143. warmup_range = (
  144. range(args.warmup_runs)
  145. if args.benchmark_type == "ort"
  146. else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
  147. )
  148. if args.verbose:
  149. outputs = fn(warmup_inputs)
  150. logger.info(outputs)
  151. for _ in warmup_range:
  152. fn(warmup_inputs)
  153. # Benchmark
  154. if args.device != "cpu":
  155. torch.cuda.synchronize(torch_device)
  156. start_time = time.time()
  157. bench_range = (
  158. range(args.num_runs)
  159. if args.benchmark_type == "ort"
  160. else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
  161. )
  162. for _ in bench_range:
  163. fn(benchmark_inputs)
  164. if args.device != "cpu":
  165. torch.cuda.synchronize(torch_device)
  166. end_time = time.time()
  167. # Newline print after trange in order to print metrics on new lines without progress bar on same line
  168. if args.benchmark_type != "ort":
  169. logger.info("")
  170. batch_size = 1
  171. latency = (end_time - start_time) / args.num_runs
  172. throughput = batch_size / latency
  173. logger.info(f"Latency: {latency} s")
  174. logger.info(f"Throughput: {throughput} qps")
  175. return
  176. def profile_fn(args, fn, inputs, inputs_type):
  177. # Filename prefix format:
  178. # "<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
  179. prefix = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
  180. filename = None
  181. if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
  182. # Profile PyTorch kernels
  183. with profile( # noqa: SIM117
  184. activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
  185. ) as prof:
  186. with record_function("model_inference"):
  187. fn(inputs)
  188. prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
  189. filename = os.path.join(args.log_folder, f"{prefix}.log")
  190. with open(filename, "w") as f:
  191. f.write(prof_data)
  192. else:
  193. # Profile ORT kernels
  194. fn(inputs)
  195. # Set new log name for ORT profile log generated
  196. filename = f"{prefix}.json"
  197. return filename
  198. def measure_fn(args, fn, inputs):
  199. # Measure CPU usage
  200. pid = os.getpid()
  201. process = psutil.Process(pid)
  202. process.cpu_percent(interval=0.1)
  203. fn(inputs)
  204. logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%")
  205. # Measure memory usage
  206. gc.collect()
  207. torch.cuda.empty_cache()
  208. measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs), monitor_type=args.monitor_type)
  209. # Flush output so memory usage is printed
  210. sys.stdout.flush()
  211. def run_hf_inference(args, inputs, model):
  212. # Inference steps to measure
  213. def get_pred_ids(inputs):
  214. # Inference pass with predicted token ids generation
  215. predicted_ids = model.generate(**inputs)
  216. return predicted_ids
  217. def gen_and_dec(inputs):
  218. # Inference pass with generation and decoding
  219. predicted_ids = get_pred_ids(inputs)
  220. transcription = []
  221. for _ in range(args.num_return_sequences):
  222. transcription.append(args.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
  223. return predicted_ids, transcription
  224. # Examples of other inference steps that can be measured:
  225. # To use, uncomment the function and assign it to `generate_fn`
  226. # def get_logits(inputs):
  227. # # Inference pass without decoding
  228. # outputs = model(**inputs)
  229. # return outputs
  230. generate_fn = gen_and_dec
  231. if args.benchmark_type == "hf-pt-compile":
  232. # Run forward pass once with each set of inputs to process through Dynamo
  233. generate_fn(inputs)
  234. if args.profile:
  235. new_logname = profile_fn(args, generate_fn, inputs, "gen-and-dec")
  236. if args.benchmark_type == "hf-ort":
  237. # Rename log files per model component and turn profiling off to stop appending to log
  238. new_prefix = new_logname[: -len(".json")]
  239. old_logname = model.encoder.session.end_profiling()
  240. new_logname = new_prefix + "-encoder.json"
  241. if os.path.isfile(old_logname):
  242. logger.warning(f"Renaming {old_logname} to {new_logname}")
  243. os.rename(old_logname, os.path.join(args.log_folder, new_logname))
  244. old_logname = model.decoder.session.end_profiling()
  245. new_logname = new_prefix + "-decoder.json"
  246. if os.path.isfile(old_logname):
  247. logger.warning(f"Renaming {old_logname} to {new_logname}")
  248. os.rename(old_logname, os.path.join(args.log_folder, new_logname))
  249. old_logname = model.decoder_with_past.session.end_profiling()
  250. new_logname = new_prefix + "-decoder-with-past.json"
  251. if os.path.isfile(old_logname):
  252. logger.warning(f"Renaming {old_logname} to {new_logname}")
  253. os.rename(old_logname, os.path.join(args.log_folder, new_logname))
  254. return
  255. # PyTorch evaluations
  256. logger.info("\nEvaluating PyTorch...")
  257. time_fn(args, generate_fn, inputs)
  258. predicted_ids, transcription = generate_fn(inputs)
  259. logger.info(f"Generated token length: {len(predicted_ids[0])} tokens")
  260. logger.info(f"Transcription: {transcription[0]}")
  261. measure_fn(args, generate_fn, inputs)
  262. def run_ort_inference(args, inputs, model):
  263. def prepare_ort_inputs(inputs, warmup=False):
  264. # Check that all model inputs will be provided
  265. model_inputs = {model_input.name for model_input in model.get_inputs()}
  266. user_inputs = set(inputs.keys())
  267. missing_inputs = model_inputs - user_inputs
  268. if len(missing_inputs):
  269. logger.error(f"The following model inputs are missing: {missing_inputs}")
  270. raise Exception("There are missing inputs to the model. Please add them and try again.")
  271. # Remove unnecessary inputs from model inputs
  272. unnecessary_inputs = user_inputs - model_inputs
  273. if len(unnecessary_inputs):
  274. for unnecessary_input in unnecessary_inputs:
  275. logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
  276. del inputs[unnecessary_input]
  277. # Add IO bindings for non-CPU execution providers
  278. if args.device != "cpu":
  279. io_binding = model.io_binding()
  280. for k, v in inputs.items():
  281. io_binding.bind_cpu_input(k, v)
  282. for output in model.get_outputs():
  283. io_binding.bind_output(output.name, device_type=args.device, device_id=args.device_id)
  284. return io_binding
  285. return inputs
  286. def with_io_binding(io_binding):
  287. # Inference pass with IO binding
  288. model.run_with_iobinding(io_binding)
  289. return io_binding
  290. def without_io_binding(inputs):
  291. # Inference pass without IO binding
  292. outputs = model.run(None, inputs)
  293. return outputs
  294. def handle_output(output):
  295. if args.eos_token_id in output:
  296. first_end = np.where(output == args.eos_token_id)[0][0]
  297. return output[: first_end + 1]
  298. return output
  299. generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
  300. ort_inputs = prepare_ort_inputs(inputs)
  301. if args.profile:
  302. new_logname = profile_fn(args, generate_fn, ort_inputs, "e2e")
  303. # Turn profiling off to stop appending to log file
  304. old_logname = model.end_profiling()
  305. logger.warning(f"Renaming {old_logname} to {new_logname}")
  306. os.rename(old_logname, os.path.join(args.log_folder, new_logname))
  307. return
  308. # ORT evaluation
  309. logger.info("\nEvaluating ONNX Runtime...")
  310. ort_evaluate_inputs = ort_inputs
  311. time_fn(args, generate_fn, ort_evaluate_inputs)
  312. ort_outputs = generate_fn(ort_inputs)
  313. if args.device != "cpu":
  314. ort_outputs = ort_outputs.copy_outputs_to_cpu()
  315. ort_outputs = ort_outputs[0]
  316. if args.has_audio_stream:
  317. # ONNX E2E model from Olive produces transcribed output
  318. logger.info(f"Transcription: {ort_outputs[0][0]}")
  319. else:
  320. # convert_to_onnx model produces generated ids
  321. actual_output = handle_output(ort_outputs[0][0])
  322. logger.info(f"Generated token length: {len(actual_output)} tokens")
  323. transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0]
  324. # print to stdout as the output for comparison
  325. print(f"{transcription}")
  326. measure_fn(args, generate_fn, ort_inputs)
  327. def run_inference(args, inputs, model):
  328. if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
  329. run_hf_inference(args, inputs, model)
  330. elif args.benchmark_type == "ort":
  331. run_ort_inference(args, inputs, model)
  332. else:
  333. raise Exception(f"Cannot recognize {args.benchmark_type}")
  334. def parse_args():
  335. parser = argparse.ArgumentParser()
  336. parser.add_argument(
  337. "-bt",
  338. "--benchmark-type",
  339. type=str,
  340. required=True,
  341. choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"],
  342. )
  343. parser.add_argument(
  344. "-m",
  345. "--model-name",
  346. type=str,
  347. required=True,
  348. help="Hugging Face name of model (e.g. 'openai/whisper-large-v2')",
  349. )
  350. parser.add_argument(
  351. "-p",
  352. "--precision",
  353. type=str,
  354. required=True,
  355. default="fp32",
  356. choices=["int4", "int8", "fp16", "fp32"],
  357. help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
  358. )
  359. parser.add_argument(
  360. "--hf-pt-model-path",
  361. type=str,
  362. default="",
  363. help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
  364. )
  365. parser.add_argument(
  366. "--hf-ort-dir-path",
  367. type=str,
  368. default="",
  369. help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)",
  370. )
  371. parser.add_argument(
  372. "--ort-model-path",
  373. type=str,
  374. default="",
  375. help="Path to ONNX model",
  376. )
  377. # Args for running and evaluating the model
  378. parser.add_argument("-a", "--audio-path", type=str, required=True, help="Path to audio file for E2E evaluation")
  379. parser.add_argument(
  380. "-d",
  381. "--device",
  382. type=str,
  383. default="cuda" if torch.cuda.is_available() else "cpu",
  384. choices=["cpu", "cuda"],
  385. )
  386. parser.add_argument("-id", "--device-id", type=int, default=0)
  387. parser.add_argument("-w", "--warmup-runs", type=int, default=5)
  388. parser.add_argument("-n", "--num-runs", type=int, default=10)
  389. parser.add_argument("--seed", type=int, default=2)
  390. # Optional args:
  391. parser.add_argument("--sampling-rate", type=int, default=16000, help="Sampling rate for audio (in Hz)")
  392. # Args for decoding logic
  393. # Required args:
  394. parser.add_argument("--max-length", type=int, default=448)
  395. parser.add_argument("--min-length", type=int, default=0)
  396. parser.add_argument("--num-beams", type=int, default=1)
  397. parser.add_argument("--num-return-sequences", type=int, default=1)
  398. parser.add_argument("--length-penalty", type=float, default=1.0)
  399. parser.add_argument("--repetition-penalty", type=float, default=1.0)
  400. parser.add_argument("--no-repeat-ngram-size", type=int, default=3)
  401. # Optional args for E2E solution:
  402. parser.add_argument(
  403. "--decoder-input-ids",
  404. type=str,
  405. default="[]",
  406. help="The forced decoder ids for generation. Format is [start token, timestamp token, language token, task token]. Default is [start token]. See `decoder_input_ids` in https://github.com/microsoft/Olive/tree/main/examples/whisper for details.",
  407. )
  408. parser.add_argument(
  409. "--logits-processor",
  410. type=int,
  411. default=1,
  412. help="Whether to use timestamps logits processor or not (0 for false, 1 for true).",
  413. )
  414. parser.add_argument(
  415. "--temperature",
  416. type=float,
  417. default=1.0,
  418. help="Temperature value for generation.",
  419. )
  420. # Args for accessing detailed info
  421. parser.add_argument("--profile", default=False, action="store_true")
  422. parser.add_argument(
  423. "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
  424. )
  425. parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
  426. parser.add_argument("--verbose", default=False, action="store_true")
  427. parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
  428. args = parser.parse_args()
  429. # Set seed properties
  430. np.random.seed(args.seed)
  431. torch.manual_seed(args.seed)
  432. args.monitor_type = args.device
  433. # Set runtime properties
  434. if "ort" in args.benchmark_type:
  435. args.execution_provider = f"{args.device.upper()}ExecutionProvider"
  436. if args.execution_provider == "CUDAExecutionProvider":
  437. args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
  438. # Check that model paths have been specified for any benchmarking with ORT
  439. if args.benchmark_type == "hf-ort":
  440. assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
  441. if args.benchmark_type == "ort":
  442. assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
  443. # Convert decoder_input_ids string to list of ids
  444. # (e.g. "[1, 50257]" for Hugging Face or "[50257]" for ORT)
  445. args.decoder_input_ids = ast.literal_eval(args.decoder_input_ids)
  446. return args
  447. def main():
  448. args = parse_args()
  449. setup_logger(args.verbose)
  450. logger.info(args.__dict__)
  451. torch.backends.cudnn.benchmark = True
  452. config = WhisperConfig.from_pretrained(args.model_name)
  453. processor = WhisperProcessor.from_pretrained(args.model_name)
  454. target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
  455. use_fp16 = args.precision == "fp16" or (args.precision in {"int8", "int4"} and args.device != "cpu")
  456. setattr(args, "processor", processor) # noqa: B010
  457. setattr(args, "target_device", target_device) # noqa: B010
  458. setattr(args, "use_fp16", use_fp16) # noqa: B010
  459. setattr(args, "has_audio_stream", False) # noqa: B010
  460. setattr(args, "eos_token_id", config.eos_token_id) # noqa: B010
  461. logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}")
  462. # Measure cost to transcribe audio
  463. model = get_model(args)
  464. if args.benchmark_type == "ort":
  465. # Check for optional inputs that could have been added during export
  466. ort_model_inputs = {model_input.name for model_input in model.get_inputs()}
  467. args.has_audio_stream = "audio_stream" in ort_model_inputs
  468. setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs) # noqa: B010
  469. setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs) # noqa: B010
  470. setattr(args, "has_temperature", "temperature" in ort_model_inputs) # noqa: B010
  471. if args.decoder_input_ids == []:
  472. args.decoder_input_ids = [config.decoder_start_token_id]
  473. inputs = get_inputs(args)
  474. run_inference(args, inputs, model)
  475. if __name__ == "__main__":
  476. main()