make_dynamic_shape_fixed.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. #!/usr/bin/env python3
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. from __future__ import annotations
  5. import argparse
  6. import os
  7. import pathlib
  8. import sys
  9. import onnx
  10. from .onnx_model_utils import fix_output_shapes, make_dim_param_fixed, make_input_shape_fixed
  11. def make_dynamic_shape_fixed_helper():
  12. parser = argparse.ArgumentParser(
  13. f"{os.path.basename(__file__)}:{make_dynamic_shape_fixed_helper.__name__}",
  14. description="""
  15. Assign a fixed value to a dim_param or input shape
  16. Provide either dim_param and dim_value or input_name and input_shape.""",
  17. )
  18. parser.add_argument(
  19. "--dim_param", type=str, required=False, help="Symbolic parameter name. Provide dim_value if specified."
  20. )
  21. parser.add_argument(
  22. "--dim_value", type=int, required=False, help="Value to replace dim_param with in the model. Must be > 0."
  23. )
  24. parser.add_argument(
  25. "--input_name",
  26. type=str,
  27. required=False,
  28. help="Model input name to replace shape of. Provide input_shape if specified.",
  29. )
  30. parser.add_argument(
  31. "--input_shape",
  32. type=lambda x: [int(i) for i in x.split(",")],
  33. required=False,
  34. help="Shape to use for input_shape. Provide comma separated list for the shape. "
  35. "All values must be > 0. e.g. --input_shape 1,3,256,256",
  36. )
  37. parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.")
  38. parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write updated ONNX model to.")
  39. args = parser.parse_args()
  40. if (
  41. (args.dim_param and args.input_name)
  42. or (not args.dim_param and not args.input_name)
  43. or (args.dim_param and (not args.dim_value or args.dim_value < 1))
  44. or (args.input_name and (not args.input_shape or any(value < 1 for value in args.input_shape)))
  45. ):
  46. print("Invalid usage.")
  47. parser.print_help()
  48. sys.exit(-1)
  49. model = onnx.load(str(args.input_model.resolve(strict=True)))
  50. if args.dim_param:
  51. make_dim_param_fixed(model.graph, args.dim_param, args.dim_value)
  52. else:
  53. make_input_shape_fixed(model.graph, args.input_name, args.input_shape)
  54. # update the output shapes to make them fixed if possible.
  55. fix_output_shapes(model)
  56. onnx.save(model, str(args.output_model.resolve()))
  57. if __name__ == "__main__":
  58. make_dynamic_shape_fixed_helper()