onnx.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. from typing import Optional, Tuple, List
  2. import torch
  3. def onnx_forward(onnx_file, example_input):
  4. import onnxruntime
  5. sess_options = onnxruntime.SessionOptions()
  6. session = onnxruntime.InferenceSession(onnx_file, sess_options)
  7. input_name = session.get_inputs()[0].name
  8. output = session.run([], {input_name: example_input.numpy()})
  9. output = output[0]
  10. return output
  11. def onnx_export(
  12. model: torch.nn.Module,
  13. output_file: str,
  14. example_input: Optional[torch.Tensor] = None,
  15. training: bool = False,
  16. verbose: bool = False,
  17. check: bool = True,
  18. check_forward: bool = False,
  19. batch_size: int = 64,
  20. input_size: Tuple[int, int, int] = None,
  21. opset: Optional[int] = None,
  22. dynamic_size: bool = False,
  23. aten_fallback: bool = False,
  24. keep_initializers: Optional[bool] = None,
  25. use_dynamo: bool = False,
  26. input_names: List[str] = None,
  27. output_names: List[str] = None,
  28. ):
  29. import onnx
  30. if training:
  31. training_mode = torch.onnx.TrainingMode.TRAINING
  32. model.train()
  33. else:
  34. training_mode = torch.onnx.TrainingMode.EVAL
  35. model.eval()
  36. if example_input is None:
  37. if not input_size:
  38. assert hasattr(model, 'default_cfg'), 'Cannot file model default config, input size must be provided'
  39. input_size = model.default_cfg.get('input_size')
  40. example_input = torch.randn((batch_size,) + input_size, requires_grad=training)
  41. # Run model once before export trace, sets padding for models with Conv2dSameExport. This means
  42. # that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for
  43. # the input img_size specified in this script.
  44. # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
  45. # issues in the tracing of the dynamic padding or errors attempting to export the model after jit
  46. # scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions...
  47. with torch.inference_mode():
  48. original_out = model(example_input)
  49. input_names = input_names or ["input0"]
  50. output_names = output_names or ["output0"]
  51. dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}}
  52. if dynamic_size:
  53. dynamic_axes['input0'][2] = 'height'
  54. dynamic_axes['input0'][3] = 'width'
  55. if aten_fallback:
  56. export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
  57. else:
  58. export_type = torch.onnx.OperatorExportTypes.ONNX
  59. if use_dynamo:
  60. export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_size)
  61. export_output = torch.onnx.dynamo_export(
  62. model,
  63. example_input,
  64. export_options=export_options,
  65. )
  66. export_output.save(output_file)
  67. else:
  68. torch.onnx.export(
  69. model,
  70. example_input,
  71. output_file,
  72. training=training_mode,
  73. export_params=True,
  74. verbose=verbose,
  75. input_names=input_names,
  76. output_names=output_names,
  77. keep_initializers_as_inputs=keep_initializers,
  78. dynamic_axes=dynamic_axes,
  79. opset_version=opset,
  80. operator_export_type=export_type
  81. )
  82. if check:
  83. onnx_model = onnx.load(output_file)
  84. onnx.checker.check_model(onnx_model, full_check=True) # assuming throw on error
  85. if check_forward and not training:
  86. import numpy as np
  87. onnx_out = onnx_forward(output_file, example_input)
  88. np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)