| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700 |
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
- import argparse
- import datetime
- import gc
- import itertools
- import logging
- import os
- import sys
- import time
- import numpy as np
- import onnx
- import psutil
- import torch
- from benchmark_helper import measure_memory, setup_logger
- from dist_settings import get_rank, get_size
- from llama_inputs import (
- add_io_bindings_as_ortvalues,
- get_merged_sample_with_past_kv_inputs,
- get_msft_sample_inputs,
- get_sample_inputs,
- get_sample_with_past_kv_inputs,
- verify_ort_inputs,
- )
- from optimum.onnxruntime import ORTModelForCausalLM
- from torch.profiler import ProfilerActivity, profile, record_function
- from tqdm import trange
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
- import onnxruntime as ort
# Module-level logger named after this module; the functions below log through it.
logger = logging.getLogger(__name__)
def get_ort_model_inputs_len(args, model):
    """Return how many inputs the ONNX model takes (0 for the PyTorch benchmark types).

    The count is used to determine whether the ONNX model can do both prompt
    generation and token generation or only one of the two.

    Args:
        args: Namespace providing ``benchmark_type``.
        model: The loaded model. For ``hf-ort`` it must expose ``inputs_names`` or
            ``decoder.input_names``; for the remaining ORT types it must expose
            ``get_inputs()``.

    Returns:
        The number of model inputs as an int.
    """
    benchmark_type = args.benchmark_type

    if benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
        return 0

    if benchmark_type != "hf-ort":
        return len(model.get_inputs())

    try:
        # New Optimum export
        # (https://github.com/huggingface/optimum/blob/888332364c2e0091da1fc974737c7e277af168bf/optimum/onnxruntime/modeling_ort.py#L268)
        input_count = len(model.inputs_names)
    except Exception:
        # Old Optimum export
        # (https://github.com/huggingface/optimum/blob/c5ad7f971cb0a494eac03dc0909f146725f999c5/optimum/onnxruntime/base.py#L54)
        input_count = len(model.decoder.input_names)
    return input_count
def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
    """Build the sample inputs for the prompt step and for the token-generation step.

    Which helper from ``llama_inputs`` is used depends on ``args.benchmark_type`` and,
    for ``hf-ort`` and ``ort-msft``, on the number of inputs of the ONNX model.

    Args:
        args: Namespace providing ``benchmark_type``, ``config``, ``batch_size``,
            ``sequence_length`` and, depending on the benchmark type,
            ``target_device``, ``use_fp16``, ``use_buffer_share`` and ``world_size``.
        ort_model_inputs_len: Number of inputs of the ONNX model.

    Returns:
        A ``(init_inputs, iter_inputs)`` tuple: inputs for the first pass over the
        prompt and inputs for a single-token pass with a past of
        ``args.sequence_length``.

    Raises:
        Exception: If ``args.benchmark_type`` is not one of the handled types.
    """
    benchmark_type = args.benchmark_type

    # For past_present_share_buffer:
    # the Microsoft LLaMA-2 model currently supports at most 2048, other models use
    # the value from their config.
    if benchmark_type == "ort-msft":
        max_seq_len = 2048
    else:
        max_seq_len = args.config.max_position_embeddings

    def split_model_inputs():
        # Inputs for separate prompt / with-past models.
        prompt = get_sample_inputs(
            args.config,
            args.target_device,
            args.batch_size,
            args.sequence_length,
            return_dict=True,
        )
        token = get_sample_with_past_kv_inputs(
            args.config,
            args.target_device,
            args.batch_size,
            args.sequence_length,
            use_fp16=args.use_fp16,
            return_dict=True,
        )
        return prompt, token

    def merged_model_inputs(engine, **extra):
        # Inputs for a merged model; `extra` carries engine-specific keyword arguments.
        def build(seq_len, past_seq_len):
            return get_merged_sample_with_past_kv_inputs(
                args.config,
                args.target_device,
                args.batch_size,
                seq_len=seq_len,
                past_seq_len=past_seq_len,
                max_seq_len=max_seq_len,
                use_fp16=args.use_fp16,
                use_buffer_share=args.use_buffer_share,
                engine=engine,
                return_dict=True,
                **extra,
            )

        prompt = build(args.sequence_length, 0)
        token = build(1, args.sequence_length)
        return prompt, token

    def msft_model_inputs():
        # Microsoft export from https://github.com/microsoft/Llama-2-Onnx
        # original inputs: [x, attn_mask, k_cache, v_cache, pos]
        split_kv = ort_model_inputs_len > 5

        def build(seq_len, past_seq_len):
            return get_msft_sample_inputs(
                args.config,
                args.batch_size,
                past_seq_len=past_seq_len,
                seq_len=seq_len,
                max_seq_len=max_seq_len,
                use_fp16=args.use_fp16,
                use_buffer_share=args.use_buffer_share,
                split_kv=split_kv,
            )

        prompt = build(args.sequence_length, 0)
        token = build(1, args.sequence_length)
        return prompt, token

    if benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
        return split_model_inputs()

    if benchmark_type == "hf-ort":
        if ort_model_inputs_len == 3:  # [input_ids, attention_mask, position_ids]
            # Using split models in Optimum (e.g. created by Optimum export)
            return split_model_inputs()
        # Using merged model in Optimum (e.g. created by convert_to_onnx export)
        return merged_model_inputs("pt")

    if benchmark_type == "ort-convert-to-onnx":
        # Microsoft export from convert_to_onnx
        return merged_model_inputs("ort", world_size=args.world_size)

    if benchmark_type == "ort-msft":
        return msft_model_inputs()

    raise Exception("Unable to auto-detect inputs for provided model")
def get_model(args: argparse.Namespace):
    """Load the model selected by ``args.benchmark_type`` and log how long loading took.

    There are multiple sources that the model could come from:
      1) LLaMA-2 from an unofficial source on Hugging Face
      2) LLaMA-2 from the official source on Hugging Face (requires an authentication token)
      3) LLaMA-2 from a local download of the model
      4) LLaMA-2 from Microsoft (already optimized, https://github.com/microsoft/Llama-2-Onnx)
      5) LLaMA-2 from convert_to_onnx

    Args:
        args: Parsed benchmark arguments. Reads ``benchmark_type`` plus the paths,
            device, precision and profiling/verbosity flags relevant to that type.

    Returns:
        The result of ``AutoModelForCausalLM.from_pretrained`` (passed through
        ``torch.compile`` for ``hf-pt-compile``), of ``ORTModelForCausalLM.from_pretrained``
        (``hf-ort``), or an ``ort.InferenceSession`` (``ort-msft`` / ``ort-convert-to-onnx``).

    Raises:
        Exception: If ``args.benchmark_type`` is not recognized.
    """
    benchmark_type = args.benchmark_type

    if benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
        # A local directory takes precedence over the Hugging Face model name.
        source = args.hf_pt_dir_path or args.model_name
        dtype = torch.float16 if args.use_fp16 else torch.float32

        start_time = time.time()
        model = AutoModelForCausalLM.from_pretrained(
            source,
            torch_dtype=dtype,
            use_auth_token=args.auth,
            trust_remote_code=args.auth,
            use_cache=True,
            cache_dir=args.cache_dir,
        ).to(args.target_device)
        end_time = time.time()

        # Compilation is applied after the timed section.
        if benchmark_type == "hf-pt-compile":
            model = torch.compile(model)

    elif benchmark_type in {"hf-ort", "ort-msft", "ort-convert-to-onnx"}:
        sess_options = ort.SessionOptions()
        sess_options.enable_profiling = args.profile
        if args.verbose:
            sess_options.log_verbosity_level = 1
            sess_options.log_severity_level = 1

        if benchmark_type == "hf-ort":
            # Optimum export or convert_to_onnx.py export
            provider_is_tuple = type(args.execution_provider) is tuple
            if provider_is_tuple:
                provider = args.execution_provider[0]
                provider_options = args.execution_provider[1]
            else:
                provider = args.execution_provider
                provider_options = None

            # Pick the decoder files by name; external-data files are skipped.
            decoder_file_name = None
            decoder_with_past_file_name = None
            for filename in os.listdir(args.hf_ort_dir_path):
                is_model_file = (
                    ".onnx" in filename and ".onnx_data" not in filename and ".onnx.data" not in filename
                )
                if not is_model_file:
                    continue
                if "decoder_model" in filename or filename == "model.onnx":
                    decoder_file_name = filename
                if "decoder_with_past_model" in filename:
                    decoder_with_past_file_name = filename
                if "decoder_merged_model" in filename:
                    decoder_file_name = filename
                    decoder_with_past_file_name = filename

            if decoder_file_name == "model.onnx":
                use_merged = True
            else:
                use_merged = None

            start_time = time.time()
            model = ORTModelForCausalLM.from_pretrained(
                args.hf_ort_dir_path,
                decoder_file_name=decoder_file_name,
                decoder_with_past_file_name=decoder_with_past_file_name,
                use_auth_token=args.auth,
                trust_remote_code=args.auth,
                use_io_binding=True,  # Large perf gain even for cpu due to avoiding output copy.
                use_merged=use_merged,
                provider=provider,
                provider_options=provider_options,
                session_options=sess_options,
            )
            end_time = time.time()
        else:
            # Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx
            # The path is a format string that receives the rank.
            model_path = args.ort_model_path.format(args.rank)
            logger.info(f"Loading model from {model_path}")

            start_time = time.time()
            model = ort.InferenceSession(
                model_path,
                sess_options,
                providers=[args.execution_provider],
            )
            end_time = time.time()

    else:
        raise Exception(f"Cannot recognize {args.benchmark_type}")

    logger.info(f"Loaded model in {end_time - start_time} s")
    return model
def time_fn(args, fn, inputs):
    """Time ``fn(inputs)`` and log the average latency and throughput.

    ``fn`` is called ``args.warmup_runs`` times untimed and then ``args.num_runs``
    times timed. Around every call, inputs are synchronized before the clock starts
    and outputs are synchronized before it stops:

    * ORT session benchmarks on a non-CPU device use ``args.io_binding``.
    * Other benchmarks on a non-CPU device use ``torch.cuda.synchronize()`` when
      CUDA is available.
    * Everything else uses a no-op.

    Args:
        args: Namespace providing ``benchmark_type``, ``device``, ``warmup_runs``,
            ``num_runs``, ``verbose``, ``batch_size``, ``sequence_length`` and
            ``rank``; ``io_binding`` is additionally required for the ORT session
            benchmark types on a non-CPU device.
        fn: Callable taking ``inputs`` as its only argument.
        inputs: Object passed unchanged to every ``fn`` call.

    Raises:
        ZeroDivisionError: If ``args.num_runs`` is 0 or no time elapsed at all.
    """
    is_ort_session = args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
    on_accelerator = args.device != "cpu"

    # Warm up (progress bar only for the non-ORT-session benchmark types)
    warmup_range = (
        range(args.warmup_runs)
        if is_ort_session
        else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
    )

    if args.verbose:
        outputs = fn(inputs)
        logger.info(outputs)

    # Pick the synchronization callables once. The previous nested-lambda version
    # returned the inner lambda instead of calling it, so the PyTorch branch never
    # actually ran `torch.cuda.synchronize()`.
    if on_accelerator and is_ort_session:
        # ORT synchronize. `args.io_binding` is resolved at call time.
        def input_sync():
            args.io_binding.synchronize_inputs()

        def output_sync():
            args.io_binding.synchronize_outputs()

    elif on_accelerator and torch.cuda.is_available():
        # PyTorch synchronize
        input_sync = torch.cuda.synchronize
        output_sync = torch.cuda.synchronize
    else:
        # no-op function
        def input_sync():
            return None

        output_sync = input_sync

    for _ in warmup_range:
        input_sync()
        fn(inputs)
        output_sync()

    # Benchmark
    total_time = 0
    bench_range = (
        range(args.num_runs)
        if is_ort_session
        else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
    )
    for _ in bench_range:
        input_sync()
        # perf_counter is monotonic and high resolution, which suits short intervals.
        start_time = time.perf_counter()
        fn(inputs)
        output_sync()
        end_time = time.perf_counter()
        total_time += end_time - start_time

    # Newline print after trange in order to print metrics on new lines without progress bar on same line
    if not is_ort_session:
        logger.info("")

    latency = total_time / args.num_runs
    throughput = args.batch_size / latency

    if args.rank == 0:
        logger.info(f"Batch Size: {args.batch_size}")
        logger.info(f"Sequence Length: {args.sequence_length}")
        logger.info(f"Latency: {latency} s")
        logger.info(f"Throughput: {throughput} tps")
    return
def profile_fn(args, fn, inputs, inputs_type):
    """Run ``fn(inputs)`` once for profiling and return the name chosen for its log file.

    Filename prefix format:
    "b<batch-size>_s<sequence-length>_<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"

    Args:
        args: Namespace providing ``batch_size``, ``sequence_length``, ``benchmark_type``,
            ``precision`` and ``device``; the PyTorch branch also reads ``pt_filter_by``,
            ``pt_num_rows`` and ``log_folder``.
        fn: Inference step to run; ``fn.__name__`` (underscores replaced by dashes)
            becomes part of the file name.
        inputs: Argument passed to ``fn``.
        inputs_type: Label inserted into the file name.

    Returns:
        For the PyTorch benchmark types, the path of the ``.log`` file written inside
        ``args.log_folder``. Otherwise only a bare ``<prefix>.json`` name; no file is
        created here.
    """
    step_name = fn.__name__.replace("_", "-")
    run_label = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}"
    timestamp = f"{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
    prefix = f"b{args.batch_size}_s{args.sequence_length}_{run_label}_{step_name}_{inputs_type}_{timestamp}"

    if args.benchmark_type not in {"hf-pt-eager", "hf-pt-compile"}:
        # Profile ORT kernels: just run the step and hand back the new name for the
        # ORT profile log.
        fn(inputs)
        return f"{prefix}.json"

    # Profile PyTorch kernels
    with profile(  # noqa: SIM117
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
    ) as prof:
        with record_function("model_inference"):
            fn(inputs)

    averages = prof.key_averages(group_by_stack_n=5)
    prof_data = averages.table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)

    filename = os.path.join(args.log_folder, f"{prefix}.log")
    with open(filename, "w") as f:
        f.write(prof_data)
    return filename
def measure_fn(args, fn, inputs):
    """Log CPU usage for one ``fn(inputs)`` call, then measure memory for another.

    Args:
        args: Namespace providing ``rank`` and ``device``.
        fn: Inference step to run.
        inputs: Argument passed to ``fn``.
    """
    # Measure CPU usage
    process = psutil.Process(os.getpid())
    # First reading; the later ``interval=None`` reading is taken after ``fn`` ran.
    process.cpu_percent(interval=0.1)

    fn(inputs)

    if args.rank == 0:
        cpu_usage = process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)
        logger.info(f"CPU usage: {cpu_usage}%")

    # Measure memory usage
    gc.collect()
    torch.cuda.empty_cache()
    on_gpu = args.device != "cpu"
    measure_memory(is_gpu=on_gpu, func=lambda: fn(inputs))

    # Flush output so memory usage is printed
    sys.stdout.flush()
def run_hf_inference(args, init_inputs, iter_inputs, model):
    """Profile or benchmark a Hugging Face style model that is called as ``model(**inputs)``.

    With ``args.profile`` set, each of the two input sets is profiled once and the
    function returns; for ``hf-ort`` the ORT profile logs are also renamed into
    ``args.log_folder``. Otherwise both input sets are timed and measured.

    Args:
        args: Namespace providing ``benchmark_type``, ``profile`` and ``log_folder``
            plus whatever ``time_fn``, ``measure_fn`` and ``profile_fn`` read.
        init_inputs: Keyword inputs for the prompt step.
        iter_inputs: Keyword inputs for the step with past key/values.
        model: Callable model.
    """

    # Inference steps to measure.
    # NOTE: `profile_fn` puts `fn.__name__` into the log file name, so the nested
    # function's name is part of the output.
    def get_logits(inputs):
        # Inference pass without decoding
        return model(**inputs)

    # Examples of other inference steps that can be measured:
    # To use, uncomment the function and assign it to `generate_fn`

    # def get_pred_ids(inputs):
    #     # Inference pass with predicted token ids generation
    #     predicted_ids = model.generate(**inputs)
    #     return predicted_ids

    # def gen_and_dec(inputs):
    #     # Inference pass with generation and decoding
    #     predicted_ids = get_pred_ids(inputs)
    #     transcription = []
    #     for bs in range(args.batch_size):
    #         for rs in range(args.num_return_sequences):
    #             transcription.append(
    #                 args.tokenizer.batch_decode(
    #                     predicted_ids[bs * args.num_return_sequences + rs], skip_special_tokens=True
    #                 )[0]
    #             )
    #     return transcription

    generate_fn = get_logits

    def save_ort_profile(session, new_logname):
        # Turn profiling off to stop appending to log, then move the log to its new name.
        old_logname = session.end_profiling()
        logger.warning(f"Renaming {old_logname} to {new_logname}")
        os.rename(old_logname, os.path.join(args.log_folder, new_logname))

    if args.benchmark_type == "hf-pt-compile":
        # Run forward pass once with each set of inputs to process through Dynamo
        for warm_inputs in (init_inputs, iter_inputs):
            generate_fn(warm_inputs)

    if args.profile:
        is_hf_ort = args.benchmark_type == "hf-ort"

        new_logname = profile_fn(args, generate_fn, init_inputs, "prompt")
        if is_hf_ort:
            save_ort_profile(model.decoder.session, new_logname)

        new_logname = profile_fn(args, generate_fn, iter_inputs, "token")
        if is_hf_ort:
            save_ort_profile(model.decoder_with_past.session, new_logname)
        return

    # PyTorch evaluations
    steps = (
        ("\nEvaluating `model(inputs)` step to get past_key_values", init_inputs),
        ("\nEvaluating `model(inputs)` step with past_key_values", iter_inputs),
    )
    for message, step_inputs in steps:
        logger.info(message)
        time_fn(args, generate_fn, step_inputs)
        measure_fn(args, generate_fn, step_inputs)
def run_ort_inference(args, init_inputs, iter_inputs, model):
    """Profile or benchmark an ONNX Runtime session directly.

    On a non-CPU device the inputs are bound through IO binding and the session is
    run with ``run_with_iobinding``; on CPU ``model.run(None, inputs)`` is used.

    With ``args.profile`` set, the prompt step is profiled, the model is re-created
    through ``get_model`` so that the token step gets its own log, and both logs are
    renamed into ``args.log_folder``. Otherwise both steps are timed and measured.

    Args:
        args: Namespace providing ``device``, ``rank``, ``use_buffer_share``,
            ``profile`` and ``log_folder``; ``io_binding`` is set on it as a side
            effect when IO binding is used.
        init_inputs: Inputs for the prompt step.
        iter_inputs: Inputs for the step with past key/values.
        model: Session object exposing ``run``, ``run_with_iobinding`` and
            ``end_profiling``.
    """
    use_io_binding = args.device != "cpu"

    def prepare_ort_inputs(inputs, kv_cache_ortvalues):
        # Verify model inputs
        inputs = verify_ort_inputs(model, inputs)
        if not use_io_binding:
            return inputs, kv_cache_ortvalues

        # Add IO bindings for non-CPU execution providers
        io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
            model, inputs, args.device, int(args.rank), args.use_buffer_share, kv_cache_ortvalues
        )
        # Stored on args so that `time_fn` can synchronize through it.
        args.io_binding = io_binding
        return io_binding, kv_cache_ortvalues

    # NOTE: `profile_fn` puts `fn.__name__` into the log file name, so the names of
    # the two nested functions below are part of the output.
    def with_io_binding(io_binding):
        # Inference pass with IO binding
        model.run_with_iobinding(io_binding)

    def without_io_binding(inputs):
        # Inference pass without IO binding
        return model.run(None, inputs)

    def save_profile(new_logname):
        # Turn profiling off to stop appending to the log file, then rename it.
        # `model` is read when called, so this sees the re-created session as well.
        old_logname = model.end_profiling()
        logger.warning(f"Renaming {old_logname} to {new_logname}")
        os.rename(old_logname, os.path.join(args.log_folder, new_logname))

    if use_io_binding:
        generate_fn = with_io_binding
    else:
        generate_fn = without_io_binding
    kv_cache_ortvalues = {}

    if args.profile:
        ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
        save_profile(profile_fn(args, generate_fn, ort_init_inputs, "prompt"))

        # Re-initialize model for new log file instead of appending to old log file
        model = get_model(args)
        ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
        save_profile(profile_fn(args, generate_fn, ort_iter_inputs, "token"))
        return

    # ORT evaluations
    steps = (
        ("\nEvaluating `model(inputs)` step to get past_key_values", init_inputs),
        ("\nEvaluating `model(inputs)` step with past_key_values", iter_inputs),
    )
    for message, raw_inputs in steps:
        logger.info(message)
        ort_inputs, kv_cache_ortvalues = prepare_ort_inputs(raw_inputs, kv_cache_ortvalues)
        time_fn(args, generate_fn, ort_inputs)
        measure_fn(args, generate_fn, ort_inputs)
def run_inference(args, init_inputs, iter_inputs, model):
    """Dispatch to the Hugging Face or the ONNX Runtime runner based on the benchmark type.

    Args:
        args: Namespace providing ``benchmark_type``.
        init_inputs: Inputs for the prompt step.
        iter_inputs: Inputs for the step with past key/values.
        model: The loaded model, passed through to the runner.

    Raises:
        Exception: If ``args.benchmark_type`` is not recognized.
    """
    benchmark_type = args.benchmark_type

    if benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
        runner = run_ort_inference
    elif benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
        runner = run_hf_inference
    else:
        raise Exception(f"Cannot recognize {args.benchmark_type}")

    runner(args, init_inputs, iter_inputs, model)
def get_args(rank=0):
    """Parse the command line and derive the runtime settings for the benchmark.

    Besides plain parsing this function:

    * seeds NumPy and PyTorch with ``--seed``;
    * for benchmark types containing "ort", sets ``args.execution_provider`` to
      ``"<DEVICE>ExecutionProvider"``, or to a ``(name, {"device_id": rank})`` tuple
      for CUDA;
    * splits ``--batch-sizes`` and ``--sequence-lengths`` on whitespace into lists
      of strings;
    * collapses ``--precision`` to "fp32" (fp32, int8, and int4 on CPU) or "fp16"
      (fp16, and int4 on a non-CPU device).

    Invalid combinations are reported through ``parser.error`` (usage message and
    exit status 2) instead of ``assert``, so the checks also run under ``python -O``.

    Args:
        rank: Used as the CUDA ``device_id`` in the execution provider options.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-bt",
        "--benchmark-type",
        type=str,
        required=True,
        choices=[
            "hf-pt-eager",
            "hf-pt-compile",
            "hf-ort",
            "ort-msft",
            "ort-convert-to-onnx",
        ],
    )
    parser.add_argument(
        "-m",
        "--model-name",
        type=str,
        required=True,
        help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
    )
    parser.add_argument(
        "-a", "--auth", default=False, action="store_true", help="Use Hugging Face authentication token to access model"
    )

    # Args for choosing the model
    parser.add_argument(
        "-p",
        "--precision",
        required=True,
        type=str,
        default="fp32",
        choices=["int4", "int8", "fp16", "fp32"],
        help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
    )
    parser.add_argument(
        "--hf-pt-dir-path",
        type=str,
        default="",
        help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
    )
    parser.add_argument(
        "--hf-ort-dir-path",
        type=str,
        default="",
        help="Path to directory containing all ONNX files (e.g. tokenizer, decoder_merged, decoder, decoder_with_past)",
    )
    parser.add_argument(
        "--ort-model-path",
        type=str,
        default="",
        help="Path to ONNX model",
    )

    # Args for running and evaluating the model
    parser.add_argument(
        "-b",
        "--batch-sizes",
        default="1 2",
    )
    parser.add_argument(
        "-s",
        "--sequence-lengths",
        default="32 64 128 256 512",
    )
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cpu", "cuda"],
    )
    parser.add_argument("-id", "--device-id", type=int, default=0)
    parser.add_argument("-w", "--warmup-runs", type=int, default=5)
    parser.add_argument("-n", "--num-runs", type=int, default=10)
    parser.add_argument("--seed", type=int, default=2)

    # Args for decoding logic
    parser.add_argument("--max-length", type=int, default=32)
    parser.add_argument("--num-return-sequences", type=int, default=1)

    # Args for accessing detailed info
    parser.add_argument("--profile", default=False, action="store_true")
    parser.add_argument(
        "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
    )
    parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
    parser.add_argument("--verbose", default=False, action="store_true")
    parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
    parser.add_argument(
        "--cache-dir",
        type=str,
        required=True,
        default="./model_cache",
        help="Cache dir where Hugging Face files are stored",
    )

    args = parser.parse_args()

    # Set seed properties
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Set runtime properties
    if "ort" in args.benchmark_type:
        args.execution_provider = f"{args.device.upper()}ExecutionProvider"
        if args.execution_provider == "CUDAExecutionProvider":
            args.execution_provider = (args.execution_provider, {"device_id": rank})

    # Check that paths have been specified for any benchmarking with ORT
    if args.benchmark_type == "hf-ort" and not args.hf_ort_dir_path:
        parser.error("Please specify a path to `--hf-ort-dir-path`")
    if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} and not args.ort_model_path:
        parser.error("Please specify a path to `--ort-model-path`")

    # Split on any run of whitespace so that repeated or surrounding spaces do not
    # produce empty entries (which would fail later when converted with int()).
    args.batch_sizes = args.batch_sizes.split()
    args.sequence_lengths = args.sequence_lengths.split()
    if not args.batch_sizes:
        parser.error("Please provide at least one value for `--batch-sizes`")
    if not args.sequence_lengths:
        parser.error("Please provide at least one value for `--sequence-lengths`")

    # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
    args.precision = (
        "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
    )

    # Check that only one (batch_size, sequence_length) combination is set for profiling
    if args.profile and not (len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1):
        parser.error("Please provide only one (batch_size, sequence_length) combination for profiling")

    return args
def main():
    """Entry point: parse arguments, load the model and benchmark every size combination.

    For each (batch size, sequence length) pair, the prompt cost (init_inputs) and
    the generated token cost (iter_inputs) are measured through ``run_inference``.
    """
    rank = get_rank()
    world_size = get_size()

    args = get_args(rank)
    setup_logger(args.verbose)
    logger.info(args.__dict__)
    torch.backends.cudnn.benchmark = True

    args.rank = rank
    args.world_size = world_size

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
    )
    config = AutoConfig.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
    )
    on_cpu = args.device == "cpu"
    use_fp16 = args.precision == "fp16"

    # Everything downstream reads these from `args`.
    args.tokenizer = tokenizer
    args.config = config
    args.target_device = args.device if on_cpu else f"cuda:{args.rank}"
    args.use_fp16 = use_fp16

    # Get model and model info
    model = get_model(args)
    ort_model_inputs_len = get_ort_model_inputs_len(args, model)

    # Check if past_present_share_buffer can be enabled (only for FP16 models with GQA)
    if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}:
        onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False)
        has_gqa = any(node.op_type == "GroupQueryAttention" for node in onnx_model.graph.node)
        args.use_buffer_share = use_fp16 and has_gqa and args.device != "cpu"
    else:
        args.use_buffer_share = False

    # Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
    for batch_size in args.batch_sizes:
        for sequence_length in args.sequence_lengths:
            if args.rank == 0:
                logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
            # The lists hold strings; convert them for the current run.
            args.batch_size = int(batch_size)
            args.sequence_length = int(sequence_length)

            init_inputs, iter_inputs = get_inputs(args, ort_model_inputs_len)
            run_inference(args, init_inputs, iter_inputs, model)
# Run the benchmark only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|