optimize_onnx_model.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. #!/usr/bin/env python3
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. from __future__ import annotations
  5. import argparse
  6. import os
  7. import pathlib
  8. from .onnx_model_utils import get_optimization_level, optimize_model
  9. def optimize_model_helper():
  10. parser = argparse.ArgumentParser(
  11. f"{os.path.basename(__file__)}:{optimize_model_helper.__name__}",
  12. description="""
  13. Optimize an ONNX model using ONNX Runtime to the specified level.
  14. See https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html for more
  15. details of the optimization levels.""",
  16. )
  17. parser.add_argument(
  18. "--opt_level",
  19. default="basic",
  20. choices=["disable", "basic", "extended", "layout", "all"],
  21. help="Optimization level to use.",
  22. )
  23. parser.add_argument(
  24. "--log_level",
  25. choices=["debug", "info", "warning", "error"],
  26. type=str,
  27. required=False,
  28. default="error",
  29. help="Log level. Defaults to Error so we don't get output about unused initializers "
  30. "being removed. Warning or Info may be desirable in some scenarios.",
  31. )
  32. parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.")
  33. parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write optimized ONNX model to.")
  34. args = parser.parse_args()
  35. if args.log_level == "error":
  36. log_level = 3
  37. elif args.log_level == "debug":
  38. log_level = 0 # ORT verbose level
  39. elif args.log_level == "info":
  40. log_level = 1
  41. elif args.log_level == "warning":
  42. log_level = 2
  43. optimize_model(args.input_model, args.output_model, get_optimization_level(args.opt_level), log_level)
  44. if __name__ == "__main__":
  45. optimize_model_helper()