onnxruntime_test.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from __future__ import annotations
  6. import argparse
  7. import os
  8. import sys
  9. from timeit import default_timer as timer
  10. import numpy as np
  11. import onnxruntime as onnxrt
  12. float_dict = {
  13. "tensor(float16)": "float16",
  14. "tensor(float)": "float32",
  15. "tensor(double)": "float64",
  16. }
  17. integer_dict = {
  18. "tensor(int32)": "int32",
  19. "tensor(int8)": "int8",
  20. "tensor(uint8)": "uint8",
  21. "tensor(int16)": "int16",
  22. "tensor(uint16)": "uint16",
  23. "tensor(int64)": "int64",
  24. "tensor(uint64)": "uint64",
  25. }
  26. def generate_feeds(sess, symbolic_dims: dict | None = None):
  27. feeds = {}
  28. symbolic_dims = symbolic_dims or {}
  29. for input_meta in sess.get_inputs():
  30. # replace any symbolic dimensions
  31. shape = []
  32. for dim in input_meta.shape:
  33. if not dim:
  34. # unknown dim
  35. shape.append(1)
  36. elif isinstance(dim, str):
  37. # symbolic dim. see if we have a value otherwise use 1
  38. if dim in symbolic_dims:
  39. shape.append(int(symbolic_dims[dim]))
  40. else:
  41. shape.append(1)
  42. else:
  43. shape.append(dim)
  44. if input_meta.type in float_dict:
  45. feeds[input_meta.name] = np.random.rand(*shape).astype(float_dict[input_meta.type])
  46. elif input_meta.type in integer_dict:
  47. feeds[input_meta.name] = np.random.uniform(high=1000, size=tuple(shape)).astype(
  48. integer_dict[input_meta.type]
  49. )
  50. elif input_meta.type == "tensor(bool)":
  51. feeds[input_meta.name] = np.random.randint(2, size=tuple(shape)).astype("bool")
  52. else:
  53. print(f"unsupported input type {input_meta.type} for input {input_meta.name}")
  54. sys.exit(-1)
  55. return feeds
  56. # simple test program for loading onnx model, feeding all inputs and running the model num_iters times.
  57. def run_model(
  58. model_path,
  59. num_iters=1,
  60. debug=None,
  61. profile=None,
  62. symbolic_dims=None,
  63. feeds=None,
  64. override_initializers=True,
  65. ):
  66. symbolic_dims = symbolic_dims or {}
  67. if debug:
  68. print(f"Pausing execution ready for debugger to attach to pid: {os.getpid()}")
  69. print("Press key to continue.")
  70. sys.stdin.read(1)
  71. sess_options = None
  72. if profile:
  73. sess_options = onnxrt.SessionOptions()
  74. sess_options.enable_profiling = True
  75. sess_options.profile_file_prefix = os.path.basename(model_path)
  76. sess = onnxrt.InferenceSession(
  77. model_path,
  78. sess_options=sess_options,
  79. providers=onnxrt.get_available_providers(),
  80. )
  81. meta = sess.get_modelmeta()
  82. if not feeds:
  83. feeds = generate_feeds(sess, symbolic_dims)
  84. if override_initializers:
  85. # Starting with IR4 some initializers provide default values
  86. # and can be overridden (available in IR4). For IR < 4 models
  87. # the list would be empty
  88. for initializer in sess.get_overridable_initializers():
  89. shape = [dim if dim else 1 for dim in initializer.shape]
  90. if initializer.type in float_dict:
  91. feeds[initializer.name] = np.random.rand(*shape).astype(float_dict[initializer.type])
  92. elif initializer.type in integer_dict:
  93. feeds[initializer.name] = np.random.uniform(high=1000, size=tuple(shape)).astype(
  94. integer_dict[initializer.type]
  95. )
  96. elif initializer.type == "tensor(bool)":
  97. feeds[initializer.name] = np.random.randint(2, size=tuple(shape)).astype("bool")
  98. else:
  99. print(f"unsupported initializer type {initializer.type} for initializer {initializer.name}")
  100. sys.exit(-1)
  101. start = timer()
  102. for _i in range(num_iters):
  103. outputs = sess.run([], feeds) # fetch all outputs
  104. end = timer()
  105. print(f"model: {meta.graph_name}")
  106. print(f"version: {meta.version}")
  107. print(f"iterations: {num_iters}")
  108. print(f"avg latency: {((end - start) * 1000) / num_iters} ms")
  109. if profile:
  110. trace_file = sess.end_profiling()
  111. print(f"trace file written to: {trace_file}")
  112. return 0, feeds, num_iters > 0 and outputs
  113. def main():
  114. parser = argparse.ArgumentParser(description="Simple ONNX Runtime Test Tool.")
  115. parser.add_argument("model_path", help="model path")
  116. parser.add_argument(
  117. "num_iters",
  118. nargs="?",
  119. type=int,
  120. default=1000,
  121. help="model run iterations. default=1000",
  122. )
  123. parser.add_argument(
  124. "--debug",
  125. action="store_true",
  126. help="pause execution to allow attaching a debugger.",
  127. )
  128. parser.add_argument("--profile", action="store_true", help="enable chrome timeline trace profiling.")
  129. parser.add_argument(
  130. "--symbolic_dims",
  131. default={},
  132. type=lambda s: dict(x.split("=") for x in s.split(",")),
  133. help="Comma separated name=value pairs for any symbolic dimensions in the model input. "
  134. "e.g. --symbolic_dims batch=1,seqlen=5. "
  135. "If not provided, the value of 1 will be used for all symbolic dimensions.",
  136. )
  137. args = parser.parse_args()
  138. exit_code, _, _ = run_model(args.model_path, args.num_iters, args.debug, args.profile, args.symbolic_dims)
  139. sys.exit(exit_code)
  140. if __name__ == "__main__":
  141. main()