backend_rep.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. """
  6. Implements ONNX's backend API.
  7. """
  8. from onnx.backend.base import BackendRep
  9. from onnxruntime import RunOptions
  10. class OnnxRuntimeBackendRep(BackendRep):
  11. """
  12. Computes the prediction for a pipeline converted into
  13. an :class:`onnxruntime.InferenceSession` node.
  14. """
  15. def __init__(self, session):
  16. """
  17. :param session: :class:`onnxruntime.InferenceSession`
  18. """
  19. self._session = session
  20. def run(self, inputs, **kwargs): # type: (Any, **Any) -> Tuple[Any, ...]
  21. """
  22. Computes the prediction.
  23. See :meth:`onnxruntime.InferenceSession.run`.
  24. """
  25. options = RunOptions()
  26. for k, v in kwargs.items():
  27. if hasattr(options, k):
  28. setattr(options, k, v)
  29. if isinstance(inputs, list):
  30. inps = {}
  31. for i, inp in enumerate(self._session.get_inputs()):
  32. inps[inp.name] = inputs[i]
  33. outs = self._session.run(None, inps, options)
  34. if isinstance(outs, list):
  35. return outs
  36. else:
  37. output_names = [o.name for o in self._session.get_outputs()]
  38. return [outs[name] for name in output_names]
  39. else:
  40. inp = self._session.get_inputs()
  41. if len(inp) != 1:
  42. raise RuntimeError(f"Model expect {len(inp)} inputs")
  43. inps = {inp[0].name: inputs}
  44. return self._session.run(None, inps, options)