benchmark_helper.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import csv
  7. import logging
  8. import os
  9. import random
  10. import sys
  11. import time
  12. import timeit
  13. from abc import ABC, abstractmethod
  14. from concurrent.futures import ThreadPoolExecutor
  15. from datetime import datetime
  16. from enum import Enum
  17. from time import sleep
  18. from typing import Any
  19. import numpy
  20. import torch
  21. import transformers
  22. from packaging import version
  23. import onnxruntime
# Module-level logger used by the benchmark helpers in this file.
logger = logging.getLogger(__name__)
  25. class Precision(Enum):
  26. FLOAT32 = "fp32"
  27. FLOAT16 = "fp16"
  28. INT8 = "int8"
  29. INT4 = "int4"
  30. def __str__(self):
  31. return self.value
  32. class OptimizerInfo(Enum):
  33. # no_opt means using the raw ONNX model, but OnnxRuntime might still apply optimization as long as
  34. # graph optimization level is not 0 (disable all).
  35. NOOPT = "no_opt"
  36. BYORT = "by_ort"
  37. BYSCRIPT = "by_script"
  38. def __str__(self):
  39. return self.value
  40. class ConfigModifier:
  41. def __init__(self, num_layers):
  42. self.num_layers = num_layers
  43. def modify(self, config):
  44. if self.num_layers is None:
  45. return
  46. if hasattr(config, "num_hidden_layers"):
  47. config.num_hidden_layers = self.num_layers
  48. logger.info(f"Modifying pytorch model's number of hidden layers to: {self.num_layers}")
  49. if hasattr(config, "encoder_layers"):
  50. config.encoder_layers = self.num_layers
  51. logger.info(f"Modifying pytorch model's number of encoder layers to: {self.num_layers}")
  52. if hasattr(config, "decoder_layers "):
  53. config.decoder_layers = self.num_layers
  54. logger.info(f"Modifying pytorch model's number of decoder layers to: {self.num_layers}")
  55. def get_layer_num(self):
  56. return self.num_layers
# Maps a numpy dtype name (``str(array.dtype)``) to the element type passed to
# ``io_binding.bind_input`` in ``inference_ort_with_io_binding``. Dtypes missing
# from this map fall back to that function's ``data_type`` argument.
IO_BINDING_DATA_TYPE_MAP = {
    "float32": numpy.float32,
    # TODO: Add more.
}
  61. def create_onnxruntime_session(
  62. onnx_model_path,
  63. use_gpu,
  64. provider=None,
  65. enable_all_optimization=True,
  66. num_threads=-1,
  67. enable_profiling=False,
  68. verbose=False,
  69. enable_mlas_gemm_fastmath_arm64_bfloat16=False,
  70. provider_options={}, # map execution provider name to its option # noqa: B006
  71. ):
  72. sess_options = onnxruntime.SessionOptions()
  73. if enable_all_optimization:
  74. sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
  75. else:
  76. sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
  77. if enable_profiling:
  78. sess_options.enable_profiling = True
  79. if num_threads > 0:
  80. sess_options.intra_op_num_threads = num_threads
  81. logger.debug(f"Session option: intra_op_num_threads={sess_options.intra_op_num_threads}")
  82. if verbose:
  83. sess_options.log_severity_level = 0
  84. else:
  85. sess_options.log_severity_level = 4
  86. if provider in onnxruntime.get_available_providers():
  87. providers = [provider]
  88. elif use_gpu:
  89. if provider == "dml":
  90. providers = ["DmlExecutionProvider", "CPUExecutionProvider"]
  91. elif provider == "migraphx":
  92. providers = [
  93. "MIGraphXExecutionProvider",
  94. "CPUExecutionProvider",
  95. ]
  96. elif provider == "cuda" or provider is None:
  97. providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
  98. elif provider == "tensorrt":
  99. providers = [
  100. "TensorrtExecutionProvider",
  101. "CUDAExecutionProvider",
  102. "CPUExecutionProvider",
  103. ]
  104. else:
  105. raise RuntimeError(f"The execution provider is not supported: {provider}")
  106. else:
  107. providers = ["CPUExecutionProvider"]
  108. if provider_options:
  109. providers = [(name, provider_options[name]) if name in provider_options else name for name in providers]
  110. if enable_mlas_gemm_fastmath_arm64_bfloat16:
  111. sess_options.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1")
  112. session = None
  113. try:
  114. session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers)
  115. except Exception:
  116. logger.exception(f"Failed to create session for {onnx_model_path} with providers={providers}")
  117. return session
  118. def setup_logger(verbose=True):
  119. if verbose:
  120. logging.basicConfig(
  121. format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
  122. level=logging.DEBUG,
  123. )
  124. else:
  125. logging.basicConfig(format="%(message)s", level=logging.INFO)
  126. logging.getLogger("transformers").setLevel(logging.WARNING)
  127. def prepare_environment(cache_dir, output_dir, use_gpu, provider=None):
  128. if cache_dir and not os.path.exists(cache_dir):
  129. os.makedirs(cache_dir)
  130. if output_dir and not os.path.exists(output_dir):
  131. os.makedirs(output_dir)
  132. if use_gpu:
  133. if provider == "dml":
  134. assert "DmlExecutionProvider" in onnxruntime.get_available_providers(), (
  135. "Please install onnxruntime-directml package to test GPU inference."
  136. )
  137. else:
  138. assert not set(onnxruntime.get_available_providers()).isdisjoint(
  139. ["CUDAExecutionProvider", "MIGraphXExecutionProvider"]
  140. ), "Please install onnxruntime-gpu package, or install migraphx, to test GPU inference."
  141. logger.info(f"PyTorch Version:{torch.__version__}")
  142. logger.info(f"Transformers Version:{transformers.__version__}")
  143. logger.info(f"OnnxRuntime Version:{onnxruntime.__version__}")
  144. # Support three major versions of PyTorch and OnnxRuntime, and up to 9 months of transformers.
  145. assert version.parse(torch.__version__) >= version.parse("1.10.0")
  146. assert version.parse(transformers.__version__) >= version.parse("4.12.0")
  147. assert version.parse(onnxruntime.__version__) >= version.parse("1.10.0")
  148. def get_latency_result(latency_list, batch_size):
  149. latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
  150. latency_variance = numpy.var(latency_list, dtype=numpy.float64) * 1000.0
  151. throughput = batch_size * (1000.0 / latency_ms)
  152. return {
  153. "test_times": len(latency_list),
  154. "latency_variance": f"{latency_variance:.2f}",
  155. "latency_90_percentile": f"{numpy.percentile(latency_list, 90) * 1000.0:.2f}",
  156. "latency_95_percentile": f"{numpy.percentile(latency_list, 95) * 1000.0:.2f}",
  157. "latency_99_percentile": f"{numpy.percentile(latency_list, 99) * 1000.0:.2f}",
  158. "average_latency_ms": f"{latency_ms:.2f}",
  159. "QPS": f"{throughput:.2f}",
  160. }
  161. def output_details(results, csv_filename):
  162. with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
  163. column_names = [
  164. "engine",
  165. "version",
  166. "providers",
  167. "device",
  168. "precision",
  169. "optimizer",
  170. "io_binding",
  171. "model_name",
  172. "inputs",
  173. "threads",
  174. "batch_size",
  175. "sequence_length",
  176. "custom_layer_num",
  177. "datetime",
  178. "test_times",
  179. "QPS",
  180. "average_latency_ms",
  181. "latency_variance",
  182. "latency_90_percentile",
  183. "latency_95_percentile",
  184. "latency_99_percentile",
  185. ]
  186. csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
  187. csv_writer.writeheader()
  188. for result in results:
  189. csv_writer.writerow(result)
  190. logger.info(f"Detail results are saved to csv file: {csv_filename}")
def output_summary(results, csv_filename, args):
    """Append a pivoted latency summary of ``results`` to ``csv_filename``.

    One row is written for each (model, input count, engine, io_binding, threads)
    combination that has at least one matching result. Its ``b{batch}`` or
    ``b{batch}_s{seq}`` columns hold ``average_latency_ms`` for each batch size
    and sequence length. A header row is written on every call.

    Args:
        results: Result dicts carrying the header fields below plus
            ``batch_size``, ``sequence_length`` and ``average_latency_ms``.
        csv_filename: Destination path, opened in append mode.
        args: Parsed options providing ``batch_sizes``, ``sequence_lengths``,
            ``models``, ``engines`` and ``num_threads``. ``sequence_lengths == [""]``
            selects batch-only column names.

    Note:
        Only input counts 1, 2 and 3 are summarised; results with other counts
        are skipped. A result whose batch/sequence combination has no column
        makes ``csv.DictWriter`` raise ``ValueError``.
    """
    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
        # Fields identifying one benchmark configuration (one summary row each).
        header_names = [
            "model_name",
            "inputs",
            "custom_layer_num",
            "engine",
            "version",
            "providers",
            "device",
            "precision",
            "optimizer",
            "io_binding",
            "threads",
        ]
        # One latency column per batch size, or per (batch size, sequence length) pair.
        data_names = []
        for batch_size in args.batch_sizes:
            if args.sequence_lengths == [""]:
                data_names.append(f"b{batch_size}")
            else:
                for sequence_length in args.sequence_lengths:
                    data_names.append(f"b{batch_size}_s{sequence_length}")

        csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + data_names)
        csv_writer.writeheader()
        # Enumerate every candidate configuration and gather its matching results into one row.
        for model_name in args.models:
            for input_count in [1, 2, 3]:
                for engine_name in args.engines:
                    for io_binding in [True, False, ""]:
                        for threads in args.num_threads:
                            row = {}
                            for result in results:
                                if (
                                    result["model_name"] == model_name
                                    and result["inputs"] == input_count
                                    and result["engine"] == engine_name
                                    and result["io_binding"] == io_binding
                                    and result["threads"] == threads
                                ):
                                    headers = {k: v for k, v in result.items() if k in header_names}
                                    if not row:
                                        # First match seeds the identifying fields and blanks every latency cell.
                                        row.update(headers)
                                        row.update(dict.fromkeys(data_names, ""))
                                    else:
                                        # Later matches must describe exactly the same configuration.
                                        for k in header_names:
                                            assert row[k] == headers[k]
                                    b = result["batch_size"]
                                    s = result["sequence_length"]
                                    if s:
                                        row[f"b{b}_s{s}"] = result["average_latency_ms"]
                                    else:
                                        row[f"b{b}"] = result["average_latency_ms"]
                            if row:
                                csv_writer.writerow(row)

    logger.info(f"Summary results are saved to csv file: {csv_filename}")
  245. def output_fusion_statistics(model_fusion_statistics, csv_filename):
  246. with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
  247. column_names = [
  248. "model_filename",
  249. "datetime",
  250. "transformers",
  251. "torch",
  252. *list(next(iter(model_fusion_statistics.values())).keys()),
  253. ]
  254. csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
  255. csv_writer.writeheader()
  256. for key in model_fusion_statistics:
  257. model_fusion_statistics[key]["datetime"] = str(datetime.now())
  258. model_fusion_statistics[key]["transformers"] = transformers.__version__
  259. model_fusion_statistics[key]["torch"] = torch.__version__
  260. model_fusion_statistics[key]["model_filename"] = key
  261. csv_writer.writerow(model_fusion_statistics[key])
  262. logger.info(f"Fusion statistics is saved to csv file: {csv_filename}")
  263. def inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size, warm_up_repeat=0):
  264. result = {}
  265. timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=warm_up_repeat) # Dry run
  266. latency_list = timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=repeat_times)
  267. result.update(result_template)
  268. result.update({"io_binding": False})
  269. result.update(get_latency_result(latency_list, batch_size))
  270. return result
  271. def inference_ort_with_io_binding(
  272. ort_session,
  273. ort_inputs,
  274. result_template,
  275. repeat_times,
  276. ort_output_names,
  277. ort_outputs,
  278. output_buffers,
  279. output_buffer_max_sizes,
  280. batch_size,
  281. device,
  282. data_type=numpy.longlong,
  283. warm_up_repeat=0,
  284. ):
  285. result = {}
  286. # Bind inputs and outputs to onnxruntime session
  287. io_binding = ort_session.io_binding()
  288. # Bind inputs to device
  289. for name in ort_inputs:
  290. np_input = torch.from_numpy(ort_inputs[name]).to(device)
  291. input_type = IO_BINDING_DATA_TYPE_MAP.get(str(ort_inputs[name].dtype), data_type)
  292. io_binding.bind_input(
  293. name,
  294. np_input.device.type,
  295. 0,
  296. input_type,
  297. np_input.shape,
  298. np_input.data_ptr(),
  299. )
  300. # Bind outputs buffers with the sizes needed if not allocated already
  301. if len(output_buffers) == 0:
  302. allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device)
  303. for i, ort_output_name in enumerate(ort_output_names):
  304. io_binding.bind_output(
  305. ort_output_name,
  306. output_buffers[i].device.type,
  307. 0,
  308. numpy.float32,
  309. ort_outputs[i].shape,
  310. output_buffers[i].data_ptr(),
  311. )
  312. timeit.repeat(
  313. lambda: ort_session.run_with_iobinding(io_binding),
  314. number=1,
  315. repeat=warm_up_repeat,
  316. ) # Dry run
  317. latency_list = timeit.repeat(
  318. lambda: ort_session.run_with_iobinding(io_binding),
  319. number=1,
  320. repeat=repeat_times,
  321. )
  322. result.update(result_template)
  323. result.update({"io_binding": True})
  324. result.update(get_latency_result(latency_list, batch_size))
  325. return result
  326. def allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device): # noqa: N802
  327. # Allocate output tensors with the largest test size needed. So the allocated memory can be reused
  328. # for each test run.
  329. for i in output_buffer_max_sizes:
  330. output_buffers.append(torch.empty(i, dtype=torch.float32, device=device))
  331. def set_random_seed(seed=123):
  332. """Set random seed manually to get deterministic results"""
  333. random.seed(seed)
  334. numpy.random.seed(seed)
  335. torch.manual_seed(seed)
  336. torch.cuda.manual_seed(seed)
  337. torch.cuda.manual_seed_all(seed)
  338. # torch.backends.cudnn.enabled = False
  339. # torch.backends.cudnn.benchmark = False
  340. # torch.backends.cudnn.deterministic = True
  341. def get_gpu_info() -> list[dict[str, Any]] | None:
  342. from py3nvml.py3nvml import ( # noqa: PLC0415
  343. NVMLError,
  344. nvmlDeviceGetCount,
  345. nvmlDeviceGetHandleByIndex,
  346. nvmlDeviceGetMemoryInfo,
  347. nvmlDeviceGetName,
  348. nvmlInit,
  349. nvmlShutdown,
  350. )
  351. try:
  352. nvmlInit()
  353. result = []
  354. device_count = nvmlDeviceGetCount()
  355. if not isinstance(device_count, int):
  356. return None
  357. for i in range(device_count):
  358. info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
  359. if isinstance(info, str):
  360. return None
  361. result.append(
  362. {
  363. "id": i,
  364. "name": nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)),
  365. "total": info.total,
  366. "free": info.free,
  367. "used": info.used,
  368. }
  369. )
  370. nvmlShutdown()
  371. return result
  372. except NVMLError as error:
  373. print("Error fetching GPU information using nvml: %s", error)
  374. return None
  375. class MemoryMonitor(ABC):
  376. def __init__(self, keep_measuring=True):
  377. self.keep_measuring = keep_measuring
  378. def measure_cpu_usage(self):
  379. import psutil # noqa: PLC0415
  380. max_usage = 0
  381. while True:
  382. max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2)
  383. sleep(0.005) # 5ms
  384. if not self.keep_measuring:
  385. break
  386. return max_usage
  387. @abstractmethod
  388. def measure_gpu_usage(self) -> list[dict[str, Any]] | None:
  389. raise NotImplementedError()
  390. class CudaMemoryMonitor(MemoryMonitor):
  391. def __init__(self, keep_measuring=True):
  392. super().__init__(keep_measuring)
  393. def measure_gpu_usage(self) -> list[dict[str, Any]] | None:
  394. from py3nvml.py3nvml import ( # noqa: PLC0415
  395. NVMLError,
  396. nvmlDeviceGetCount,
  397. nvmlDeviceGetHandleByIndex,
  398. nvmlDeviceGetMemoryInfo,
  399. nvmlDeviceGetName,
  400. nvmlInit,
  401. nvmlShutdown,
  402. )
  403. max_gpu_usage = []
  404. gpu_name = []
  405. try:
  406. nvmlInit()
  407. device_count = nvmlDeviceGetCount()
  408. if not isinstance(device_count, int):
  409. logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}")
  410. return None
  411. max_gpu_usage = [0 for i in range(device_count)]
  412. gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)]
  413. while True:
  414. for i in range(device_count):
  415. info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
  416. if isinstance(info, str):
  417. logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}")
  418. return None
  419. max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2)
  420. sleep(0.005) # 5ms
  421. if not self.keep_measuring:
  422. break
  423. nvmlShutdown()
  424. return [
  425. {
  426. "device_id": i,
  427. "name": gpu_name[i],
  428. "max_used_MB": max_gpu_usage[i],
  429. }
  430. for i in range(device_count)
  431. ]
  432. except NVMLError as error:
  433. logger.error("Error fetching GPU information using nvml: %s", error)
  434. return None
  435. class RocmMemoryMonitor(MemoryMonitor):
  436. def __init__(self, keep_measuring=True):
  437. super().__init__(keep_measuring)
  438. rocm_smi_path = "/opt/rocm/libexec/rocm_smi"
  439. if os.path.exists(rocm_smi_path):
  440. if rocm_smi_path not in sys.path:
  441. sys.path.append(rocm_smi_path)
  442. try:
  443. import rocm_smi # noqa: PLC0415
  444. self.rocm_smi = rocm_smi
  445. self.rocm_smi.initializeRsmi()
  446. except ImportError:
  447. self.rocm_smi = None
  448. def get_used_memory(self, dev):
  449. if self.rocm_smi is None:
  450. return -1
  451. return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024
  452. def measure_gpu_usage(self):
  453. if self.rocm_smi is None:
  454. return None
  455. device_count = len(self.rocm_smi.listDevices()) if self.rocm_smi is not None else 0
  456. max_gpu_usage = [0 for i in range(device_count)]
  457. gpu_name = [f"GPU{i}" for i in range(device_count)]
  458. while True:
  459. for i in range(device_count):
  460. max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i))
  461. time.sleep(0.005) # 5ms
  462. if not self.keep_measuring:
  463. break
  464. return [
  465. {
  466. "device_id": i,
  467. "name": gpu_name[i],
  468. "max_used_MB": max_gpu_usage[i],
  469. }
  470. for i in range(device_count)
  471. ]
def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None):
    """Measure how much memory ``func`` uses above a baseline.

    A monitor samples memory in one worker thread while ``func`` runs in another.
    The sampled peak is then compared with the baseline taken before the call.

    Args:
        is_gpu: Measure GPU memory (the monitor's ``measure_gpu_usage``) instead of
            this process's RSS (``measure_cpu_usage``).
        func: Zero-argument callable to measure. If ``None``, only the baseline is
            returned. Exceptions raised by ``func`` propagate after sampling stops.
        monitor_type: ``"rocm"`` selects ``RocmMemoryMonitor``; any other value
            selects ``CudaMemoryMonitor``.
        start_memory: Baseline to use instead of sampling one now. For GPU it is
            compared entry by entry with ``measure_gpu_usage()`` output (a list of
            dicts with ``max_used_MB``). For CPU it is a number in MB.

    Returns:
        GPU: the largest per-device increase in ``max_used_MB``, never below 0.
        Returns ``None`` if the monitor fails or the baseline and peak lists are
        empty or differ in length. When ``func`` is ``None``, returns the
        baseline list.

        CPU: peak RSS minus baseline, in MB. When ``func`` is ``None``, returns
        the baseline in MB.
    """
    memory_monitor_type = None
    if monitor_type == "rocm":
        memory_monitor_type = RocmMemoryMonitor
    else:
        memory_monitor_type = CudaMemoryMonitor

    # keep_measuring=False: this instance takes a single sample, used as the baseline.
    monitor = memory_monitor_type(False)

    if is_gpu:
        if start_memory is not None:
            memory_before_test = start_memory
        else:
            memory_before_test = monitor.measure_gpu_usage()
        if memory_before_test is None:
            return None

        if func is None:
            return memory_before_test

        with ThreadPoolExecutor() as executor:
            # A fresh monitor keeps sampling until keep_measuring is cleared below.
            monitor = memory_monitor_type()
            mem_thread = executor.submit(monitor.measure_gpu_usage)
            try:
                fn_thread = executor.submit(func)
                _ = fn_thread.result()  # re-raises any exception from func
            finally:
                # Stop the sampler even if func failed, then collect its peak.
                monitor.keep_measuring = False
                max_usage = mem_thread.result()

            if max_usage is None:
                return None

            logger.info(f"GPU memory usage: before={memory_before_test} peak={max_usage}")
            if len(memory_before_test) >= 1 and len(max_usage) >= 1 and len(memory_before_test) == len(max_usage):
                # When there are multiple GPUs, we will check the one with maximum usage.
                max_used = 0
                for i, memory_before in enumerate(memory_before_test):
                    before = memory_before["max_used_MB"]
                    after = max_usage[i]["max_used_MB"]
                    used = after - before
                    max_used = max(max_used, used)
                return max_used
        return None

    # CPU memory
    if start_memory is not None:
        memory_before_test = start_memory
    else:
        memory_before_test = monitor.measure_cpu_usage()

    if func is None:
        return memory_before_test

    with ThreadPoolExecutor() as executor:
        # Same pattern as the GPU path: sample in one worker while func runs in another.
        monitor = memory_monitor_type()
        mem_thread = executor.submit(monitor.measure_cpu_usage)
        try:
            fn_thread = executor.submit(func)
            _ = fn_thread.result()
        finally:
            monitor.keep_measuring = False
            max_usage = mem_thread.result()

        logger.info(f"CPU memory usage: before={memory_before_test:.1f} MB, peak={max_usage:.1f} MB")
        return max_usage - memory_before_test
  528. def get_ort_environment_variables():
  529. # Environment variables might impact ORT performance on transformer models. Note that they are for testing only.
  530. env_names = [
  531. "ORT_DISABLE_FUSED_ATTENTION",
  532. "ORT_ENABLE_FUSED_CAUSAL_ATTENTION",
  533. "ORT_DISABLE_FUSED_CROSS_ATTENTION",
  534. "ORT_DISABLE_TRT_FLASH_ATTENTION",
  535. "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION",
  536. "ORT_TRANSFORMER_OPTIONS",
  537. "ORT_CUDA_GEMM_OPTIONS",
  538. ]
  539. env = ""
  540. for name in env_names:
  541. value = os.getenv(name)
  542. if value is None:
  543. continue
  544. if env:
  545. env += ","
  546. env += f"{name}={value}"
  547. return env