onnx_utils.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from fusion_utils import NumpyHelper
  6. from onnx import ModelProto, TensorProto
  7. from onnx.external_data_helper import set_external_data
  8. from onnx_model import OnnxModel
  9. from onnxruntime import OrtValue
  10. def extract_raw_data_from_model(model: ModelProto):
  11. """
  12. Extract external data from model and return the external data as a list of tuples (name, value).
  13. Note this function does not handle external data that is not loaded into the model as raw data.
  14. Args:
  15. model (ModelProto): the model proto to extract external data from.
  16. Returns:
  17. (external_names, external_values): a tuple of two lists of external data names and values.
  18. """
  19. external_data = []
  20. onnx_model = OnnxModel(model)
  21. for graph in onnx_model.graphs():
  22. for initializer in graph.initializer:
  23. name = initializer.name
  24. if initializer.HasField("raw_data"):
  25. numpy_tensor = NumpyHelper.to_array(initializer)
  26. ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor)
  27. external_data.append((name, ort_value))
  28. # mimic set_external_data
  29. set_external_data(initializer, location="foo.bin")
  30. initializer.name = name
  31. initializer.ClearField("raw_data")
  32. return zip(*external_data, strict=False)
  33. def has_external_data(model: ModelProto):
  34. """
  35. Check if the model has external data.
  36. Args:
  37. model (ModelProto): the model proto to check for external data.
  38. Returns:
  39. bool: True if the model has external data, False otherwise.
  40. """
  41. onnx_model = OnnxModel(model)
  42. for graph in onnx_model.graphs():
  43. for initializer in graph.initializer:
  44. if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL:
  45. return True
  46. return False