# remove_initializer_from_input.py
  1. import argparse
  2. import onnx
  3. def get_args():
  4. parser = argparse.ArgumentParser()
  5. parser.add_argument("--input", required=True, help="input model")
  6. parser.add_argument("--output", required=True, help="output model")
  7. args = parser.parse_args()
  8. return args
  9. def remove_initializer_from_input(model: onnx.ModelProto) -> bool:
  10. if model.ir_version < 4:
  11. print("Model with ir_version below 4 requires to include initializer in graph input")
  12. return False
  13. inputs = model.graph.input
  14. name_to_input = {}
  15. for input in inputs:
  16. name_to_input[input.name] = input
  17. modified = False
  18. for initializer in model.graph.initializer:
  19. if initializer.name in name_to_input:
  20. modified = True
  21. inputs.remove(name_to_input[initializer.name])
  22. return modified
  23. if __name__ == "__main__":
  24. args = get_args()
  25. model = onnx.load(args.input)
  26. remove_initializer_from_input(model)
  27. onnx.save(model, args.output)