| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- from __future__ import annotations
- import collections
- import collections.abc
- import os
- import typing
- import warnings
- from collections.abc import Callable, Sequence
- from typing import Any
- from onnxruntime.capi import _pybind_state as C
- if typing.TYPE_CHECKING:
- import numpy as np
- import numpy.typing as npt
- import onnxruntime
- def get_ort_device_type(device_type: str) -> int:
- if device_type == "cuda":
- return C.OrtDevice.cuda()
- elif device_type == "cann":
- return C.OrtDevice.cann()
- elif device_type == "cpu":
- return C.OrtDevice.cpu()
- elif device_type == "dml":
- return C.OrtDevice.dml()
- elif device_type == "webgpu":
- return C.OrtDevice.webgpu()
- elif device_type == "gpu":
- return C.OrtDevice.gpu()
- elif device_type == "npu":
- return C.OrtDevice.npu()
- else:
- raise Exception("Unsupported device type: " + device_type)
- class AdapterFormat:
- """
- This class is used to create adapter files from python structures
- """
- def __init__(self, adapter=None) -> None:
- if adapter is None:
- self._adapter = C.AdapterFormat()
- else:
- self._adapter = adapter
- @staticmethod
- def read_adapter(file_path: os.PathLike) -> AdapterFormat:
- return AdapterFormat(C.AdapterFormat.read_adapter(file_path))
- def export_adapter(self, file_path: os.PathLike):
- """
- This function writes a file at the specified location
- in onnxrunitme adapter format containing Lora parameters.
- :param file_path: absolute path for the adapter
- """
- self._adapter.export_adapter(file_path)
- def get_format_version(self) -> int:
- return self._adapter.format_version
- def set_adapter_version(self, adapter_version: int) -> None:
- self._adapter.adapter_version = adapter_version
- def get_adapter_version(self) -> int:
- return self._adapter.adapter_version
- def set_model_version(self, model_version: int) -> None:
- self._adapter.model_version = model_version
- def get_model_version(self) -> int:
- return self._adapter.model_version
- def set_parameters(self, params: dict[str, OrtValue]) -> None:
- self._adapter.parameters = {k: v._ortvalue for k, v in params.items()}
- def get_parameters(self) -> dict[str, OrtValue]:
- return {k: OrtValue(v) for k, v in self._adapter.parameters.items()}
- def check_and_normalize_provider_args(
- providers: Sequence[str | tuple[str, dict[Any, Any]]] | None,
- provider_options: Sequence[dict[Any, Any]] | None,
- available_provider_names: Sequence[str],
- ):
- """
- Validates the 'providers' and 'provider_options' arguments and returns a
- normalized version.
- :param providers: Optional sequence of providers in order of decreasing
- precedence. Values can either be provider names or tuples of
- (provider name, options dict).
- :param provider_options: Optional sequence of options dicts corresponding
- to the providers listed in 'providers'.
- :param available_provider_names: The available provider names.
- :return: Tuple of (normalized 'providers' sequence, normalized
- 'provider_options' sequence).
- 'providers' can contain either names or names and options. When any options
- are given in 'providers', 'provider_options' should not be used.
- The normalized result is a tuple of:
- 1. Sequence of provider names in the same order as 'providers'.
- 2. Sequence of corresponding provider options dicts with string keys and
- values. Unspecified provider options yield empty dicts.
- """
- if providers is None:
- return [], []
- provider_name_to_options = collections.OrderedDict()
- def set_provider_options(name, options):
- if name not in available_provider_names:
- warnings.warn(
- "Specified provider '{}' is not in available provider names.Available providers: '{}'".format(
- name, ", ".join(available_provider_names)
- )
- )
- if name in provider_name_to_options:
- warnings.warn(f"Duplicate provider '{name}' encountered, ignoring.")
- return
- normalized_options = {str(key): str(value) for key, value in options.items()}
- provider_name_to_options[name] = normalized_options
- if not isinstance(providers, collections.abc.Sequence):
- raise ValueError("'providers' should be a sequence.")
- if provider_options is not None:
- if not isinstance(provider_options, collections.abc.Sequence):
- raise ValueError("'provider_options' should be a sequence.")
- if len(providers) != len(provider_options):
- raise ValueError("'providers' and 'provider_options' should be the same length if both are given.")
- if not all(isinstance(provider, str) for provider in providers):
- raise ValueError("Only string values for 'providers' are supported if 'provider_options' is given.")
- if not all(isinstance(options_for_provider, dict) for options_for_provider in provider_options):
- raise ValueError("'provider_options' values must be dicts.")
- for name, options in zip(providers, provider_options, strict=False):
- set_provider_options(name, options)
- else:
- for provider in providers:
- if isinstance(provider, str):
- set_provider_options(provider, {})
- elif (
- isinstance(provider, tuple)
- and len(provider) == 2
- and isinstance(provider[0], str)
- and isinstance(provider[1], dict)
- ):
- set_provider_options(provider[0], provider[1])
- else:
- raise ValueError("'providers' values must be either strings or (string, dict) tuples.")
- return list(provider_name_to_options.keys()), list(provider_name_to_options.values())
- class Session:
- """
- This is the main class used to run a model.
- """
- def __init__(self, enable_fallback: bool = True):
- # self._sess is managed by the derived class and relies on bindings from C.InferenceSession
- self._sess = None
- self._enable_fallback = enable_fallback
- def get_session_options(self) -> onnxruntime.SessionOptions:
- "Return the session options. See :class:`onnxruntime.SessionOptions`."
- return self._sess_options
- def get_inputs(self) -> Sequence[onnxruntime.NodeArg]:
- "Return the inputs metadata as a list of :class:`onnxruntime.NodeArg`."
- return self._inputs_meta
- def get_outputs(self) -> Sequence[onnxruntime.NodeArg]:
- "Return the outputs metadata as a list of :class:`onnxruntime.NodeArg`."
- return self._outputs_meta
- def get_overridable_initializers(self) -> Sequence[onnxruntime.NodeArg]:
- "Return the inputs (including initializers) metadata as a list of :class:`onnxruntime.NodeArg`."
- return self._overridable_initializers
- def get_modelmeta(self) -> onnxruntime.ModelMetadata:
- "Return the metadata. See :class:`onnxruntime.ModelMetadata`."
- return self._model_meta
- def get_input_memory_infos(self) -> Sequence[onnxruntime.MemoryInfo]:
- "Return the memory info for the inputs."
- return self._input_meminfos
- def get_output_memory_infos(self) -> Sequence[onnxruntime.MemoryInfo]:
- "Return the memory info for the outputs."
- return self._output_meminfos
- def get_input_epdevices(self) -> Sequence[onnxruntime.OrtEpDevice]:
- "Return the execution providers for the inputs."
- return self._input_epdevices
- def get_providers(self) -> Sequence[str]:
- "Return list of registered execution providers."
- return self._providers
- def get_provider_options(self):
- "Return registered execution providers' configurations."
- return self._provider_options
- def get_provider_graph_assignment_info(self) -> Sequence[onnxruntime.OrtEpAssignedSubgraph]:
- """
- Get information about the subgraphs assigned to each execution provider and the nodes within.
- Application must enable the recording of graph assignment information by setting the session configuration
- for the key "session.record_ep_graph_assignment_info" to "1".
- """
- return self._sess.get_provider_graph_assignment_info()
- def set_providers(self, providers=None, provider_options=None) -> None:
- """
- Register the input list of execution providers. The underlying session is re-created.
- :param providers: Optional sequence of providers in order of decreasing
- precedence. Values can either be provider names or tuples of
- (provider name, options dict). If not provided, then all available
- providers are used with the default precedence.
- :param provider_options: Optional sequence of options dicts corresponding
- to the providers listed in 'providers'.
- 'providers' can contain either names or names and options. When any options
- are given in 'providers', 'provider_options' should not be used.
- The list of providers is ordered by precedence. For example
- `['CUDAExecutionProvider', 'CPUExecutionProvider']`
- means execute a node using CUDAExecutionProvider if capable,
- otherwise execute using CPUExecutionProvider.
- """
- # recreate the underlying C.InferenceSession
- self._reset_session(providers, provider_options)
- def disable_fallback(self) -> None:
- """
- Disable session.run() fallback mechanism.
- """
- self._enable_fallback = False
- def enable_fallback(self) -> None:
- """
- Enable session.Run() fallback mechanism. If session.Run() fails due to an internal Execution Provider failure,
- reset the Execution Providers enabled for this session.
- If GPU is enabled, fall back to CUDAExecutionProvider.
- otherwise fall back to CPUExecutionProvider.
- """
- self._enable_fallback = True
- def _validate_input(self, feed_input_names):
- missing_input_names = []
- for input in self._inputs_meta:
- if input.name not in feed_input_names and not input.type.startswith("optional"):
- missing_input_names.append(input.name)
- if missing_input_names:
- raise ValueError(
- f"Required inputs ({missing_input_names}) are missing from input feed ({feed_input_names})."
- )
- def run(self, output_names, input_feed, run_options=None) -> Sequence[np.ndarray | SparseTensor | list | dict]:
- """
- Compute the predictions.
- :param output_names: name of the outputs
- :param input_feed: dictionary ``{ input_name: input_value }``
- :param run_options: See :class:`onnxruntime.RunOptions`.
- :return: list of results, every result is either a numpy array,
- a sparse tensor, a list or a dictionary.
- ::
- sess.run([output_name], {input_name: x})
- """
- self._validate_input(list(input_feed.keys()))
- if not output_names:
- output_names = [output.name for output in self._outputs_meta]
- try:
- return self._sess.run(output_names, input_feed, run_options)
- except C.EPFail as err:
- if self._enable_fallback:
- print(f"EP Error: {err!s} using {self._providers}")
- print(f"Falling back to {self._fallback_providers} and retrying.")
- self.set_providers(self._fallback_providers)
- # Fallback only once.
- self.disable_fallback()
- return self._sess.run(output_names, input_feed, run_options)
- raise
- def run_async(self, output_names, input_feed, callback, user_data, run_options=None):
- """
- Compute the predictions asynchronously in a separate cxx thread from ort intra-op threadpool.
- :param output_names: name of the outputs
- :param input_feed: dictionary ``{ input_name: input_value }``
- :param callback: python function that accept array of results, and a status string on error.
- The callback will be invoked by a cxx thread from ort intra-op threadpool.
- :param run_options: See :class:`onnxruntime.RunOptions`.
- ::
- class MyData:
- def __init__(self):
- # ...
- def save_results(self, results):
- # ...
- def callback(results: np.ndarray, user_data: MyData, err: str) -> None:
- if err:
- print (err)
- else:
- # save results to user_data
- sess.run_async([output_name], {input_name: x}, callback)
- """
- self._validate_input(list(input_feed.keys()))
- if not output_names:
- output_names = [output.name for output in self._outputs_meta]
- return self._sess.run_async(output_names, input_feed, callback, user_data, run_options)
- def run_with_ort_values(self, output_names, input_dict_ort_values, run_options=None) -> Sequence[OrtValue]:
- """
- Compute the predictions.
- :param output_names: name of the outputs
- :param input_dict_ort_values: dictionary ``{ input_name: input_ort_value }``
- See ``OrtValue`` class how to create `OrtValue`
- from numpy array or `SparseTensor`
- :param run_options: See :class:`onnxruntime.RunOptions`.
- :return: an array of `OrtValue`
- ::
- sess.run([output_name], {input_name: x})
- """
- def invoke(sess, output_names, input_dict_ort_values, run_options):
- input_dict = {}
- for n, v in input_dict_ort_values.items():
- input_dict[n] = v._get_c_value()
- result = sess.run_with_ort_values(input_dict, output_names, run_options)
- if not isinstance(result, C.OrtValueVector):
- raise TypeError("run_with_ort_values() must return a instance of type 'OrtValueVector'.")
- ort_values = [OrtValue(v) for v in result]
- return ort_values
- self._validate_input(list(input_dict_ort_values.keys()))
- if not output_names:
- output_names = [output.name for output in self._outputs_meta]
- try:
- return invoke(self._sess, output_names, input_dict_ort_values, run_options)
- except C.EPFail as err:
- if self._enable_fallback:
- print(f"EP Error: {err!s} using {self._providers}")
- print(f"Falling back to {self._fallback_providers} and retrying.")
- self.set_providers(self._fallback_providers)
- # Fallback only once.
- self.disable_fallback()
- return invoke(self._sess, output_names, input_dict_ort_values, run_options)
- raise
- def end_profiling(self):
- """
- End profiling and return results in a file.
- The results are stored in a filename if the option
- :meth:`onnxruntime.SessionOptions.enable_profiling`.
- """
- return self._sess.end_profiling()
- def get_profiling_start_time_ns(self):
- """
- Return the nanoseconds of profiling's start time
- Comparable to time.monotonic_ns() after Python 3.3
- On some platforms, this timer may not be as precise as nanoseconds
- For instance, on Windows and MacOS, the precision will be ~100ns
- """
- return self._sess.get_profiling_start_time_ns
- def io_binding(self) -> IOBinding:
- "Return an onnxruntime.IOBinding object`."
- return IOBinding(self)
- def run_with_iobinding(self, iobinding, run_options=None):
- """
- Compute the predictions.
- :param iobinding: the iobinding object that has graph inputs/outputs bind.
- :param run_options: See :class:`onnxruntime.RunOptions`.
- """
- self._sess.run_with_iobinding(iobinding._iobinding, run_options)
- def set_ep_dynamic_options(self, options: dict[str, str]):
- """
- Set dynamic options for execution providers.
- :param options: Dictionary of key-value pairs where both keys and values are strings.
- These options will be passed to the execution providers to modify
- their runtime behavior.
- """
- self._sess.set_ep_dynamic_options(options)
- def get_tuning_results(self):
- return self._sess.get_tuning_results()
- def set_tuning_results(self, results, *, error_on_invalid=False):
- return self._sess.set_tuning_results(results, error_on_invalid)
- def run_with_ortvaluevector(self, run_options, feed_names, feeds, fetch_names, fetches, fetch_devices):
- """
- Compute the predictions similar to other run_*() methods but with minimal C++/Python conversion overhead.
- :param run_options: See :class:`onnxruntime.RunOptions`.
- :param feed_names: list of input names.
- :param feeds: list of input OrtValue.
- :param fetch_names: list of output names.
- :param fetches: list of output OrtValue.
- :param fetch_devices: list of output devices.
- """
- self._sess.run_with_ortvaluevector(run_options, feed_names, feeds, fetch_names, fetches, fetch_devices)
- class InferenceSession(Session):
- """
- This is the main class used to run a model.
- """
- def __init__(
- self,
- path_or_bytes: str | bytes | os.PathLike,
- sess_options: onnxruntime.SessionOptions | None = None,
- providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
- provider_options: Sequence[dict[Any, Any]] | None = None,
- **kwargs,
- ) -> None:
- """
- :param path_or_bytes: Filename or serialized ONNX or ORT format model in a byte string.
- :param sess_options: Session options.
- :param providers: Optional sequence of providers in order of decreasing
- precedence. Values can either be provider names or tuples of
- (provider name, options dict). If not provided, then all available
- providers are used with the default precedence.
- :param provider_options: Optional sequence of options dicts corresponding
- to the providers listed in 'providers'.
- The model type will be inferred unless explicitly set in the SessionOptions.
- To explicitly set:
- ::
- so = onnxruntime.SessionOptions()
- # so.add_session_config_entry('session.load_model_format', 'ONNX') or
- so.add_session_config_entry('session.load_model_format', 'ORT')
- A file extension of '.ort' will be inferred as an ORT format model.
- All other filenames are assumed to be ONNX format models.
- 'providers' can contain either names or names and options. When any options
- are given in 'providers', 'provider_options' should not be used.
- The list of providers is ordered by precedence. For example
- `['CUDAExecutionProvider', 'CPUExecutionProvider']`
- means execute a node using `CUDAExecutionProvider`
- if capable, otherwise execute using `CPUExecutionProvider`.
- """
- super().__init__(enable_fallback=int(kwargs.get("enable_fallback", 1)) == 1)
- if isinstance(path_or_bytes, (str, os.PathLike)):
- self._model_path = os.fspath(path_or_bytes)
- self._model_bytes = None
- elif isinstance(path_or_bytes, bytes):
- self._model_path = None
- self._model_bytes = path_or_bytes # TODO: This is bad as we're holding the memory indefinitely
- else:
- raise TypeError(f"Unable to load from type '{type(path_or_bytes)}'")
- self._sess_options = sess_options
- self._sess_options_initial = sess_options
- if "read_config_from_model" in kwargs:
- self._read_config_from_model = int(kwargs["read_config_from_model"]) == 1
- else:
- self._read_config_from_model = os.environ.get("ORT_LOAD_CONFIG_FROM_MODEL") == "1"
- # internal parameters that we don't expect to be used in general so aren't documented
- disabled_optimizers = kwargs.get("disabled_optimizers")
- try:
- self._create_inference_session(providers, provider_options, disabled_optimizers)
- except (ValueError, RuntimeError) as e:
- if self._enable_fallback:
- try:
- print("*************** EP Error ***************")
- print(f"EP Error {e} when using {providers}")
- print(f"Falling back to {self._fallback_providers} and retrying.")
- print("****************************************")
- self._create_inference_session(self._fallback_providers, None)
- # Fallback only once.
- self.disable_fallback()
- return
- except Exception as fallback_error:
- raise fallback_error from e
- # Fallback is disabled. Raise the original error.
- raise e
- def _create_inference_session(self, providers, provider_options, disabled_optimizers=None):
- available_providers = C.get_available_providers()
- # Validate that TensorrtExecutionProvider and NvTensorRTRTXExecutionProvider are not both specified
- if providers:
- has_tensorrt = any(
- provider == "TensorrtExecutionProvider"
- or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider")
- for provider in providers
- )
- has_tensorrt_rtx = any(
- provider == "NvTensorRTRTXExecutionProvider"
- or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider")
- for provider in providers
- )
- if has_tensorrt and has_tensorrt_rtx:
- raise ValueError(
- "Cannot enable both 'TensorrtExecutionProvider' and 'NvTensorRTRTXExecutionProvider' "
- "in the same session."
- )
- # Tensorrt and TensorRT RTX can fall back to CUDA if it's explicitly assigned. All others fall back to CPU.
- if "NvTensorRTRTXExecutionProvider" in available_providers:
- if (
- providers
- and any(
- provider == "CUDAExecutionProvider"
- or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider")
- for provider in providers
- )
- and any(
- provider == "NvTensorRTRTXExecutionProvider"
- or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider")
- for provider in providers
- )
- ):
- self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
- else:
- self._fallback_providers = ["CPUExecutionProvider"]
- elif "TensorrtExecutionProvider" in available_providers:
- if (
- providers
- and any(
- provider == "CUDAExecutionProvider"
- or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider")
- for provider in providers
- )
- and any(
- provider == "TensorrtExecutionProvider"
- or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider")
- for provider in providers
- )
- ):
- self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
- else:
- self._fallback_providers = ["CPUExecutionProvider"]
- else:
- self._fallback_providers = ["CPUExecutionProvider"]
- # validate providers and provider_options before other initialization
- providers, provider_options = check_and_normalize_provider_args(
- providers, provider_options, available_providers
- )
- # Print a warning if user passed providers to InferenceSession() but the SessionOptions instance
- # already has provider information (e.g., via add_provider_for_devices()). The providers specified
- # here will take precedence.
- if self._sess_options is not None and (providers or provider_options) and self._sess_options.has_providers():
- warnings.warn(
- "Specified 'providers'/'provider_options' when creating InferenceSession but SessionOptions has "
- "already been configured with providers. InferenceSession will only use the providers "
- "passed to InferenceSession()."
- )
- session_options = self._sess_options if self._sess_options else C.get_default_session_options()
- self._register_ep_custom_ops(session_options, providers, provider_options, available_providers)
- if self._model_path:
- sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
- else:
- sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
- if disabled_optimizers is None:
- disabled_optimizers = set()
- elif not isinstance(disabled_optimizers, set):
- # convert to set. assumes iterable
- disabled_optimizers = set(disabled_optimizers)
- # initialize the C++ InferenceSession
- sess.initialize_session(providers, provider_options, disabled_optimizers)
- self._sess = sess
- self._sess_options = self._sess.session_options
- self._inputs_meta = self._sess.inputs_meta
- self._outputs_meta = self._sess.outputs_meta
- self._overridable_initializers = self._sess.overridable_initializers
- self._input_meminfos = self._sess.input_meminfos
- self._output_meminfos = self._sess.output_meminfos
- self._input_epdevices = self._sess.input_epdevices
- self._model_meta = self._sess.model_meta
- self._providers = self._sess.get_providers()
- self._provider_options = self._sess.get_provider_options()
- self._profiling_start_time_ns = self._sess.get_profiling_start_time_ns
- def _reset_session(self, providers, provider_options) -> None:
- "release underlying session object."
- # meta data references session internal structures
- # so they must be set to None to decrement _sess reference count.
- self._sess_options = None
- self._inputs_meta = None
- self._outputs_meta = None
- self._overridable_initializers = None
- self._input_meminfos = None
- self._output_meminfos = None
- self._input_epdevices = None
- self._model_meta = None
- self._providers = None
- self._provider_options = None
- self._profiling_start_time_ns = None
- # create a new C.InferenceSession
- self._sess = None
- self._sess_options = self._sess_options_initial
- self._create_inference_session(providers, provider_options)
- def _register_ep_custom_ops(self, session_options, providers, provider_options, available_providers):
- for i in range(len(providers)):
- if providers[i] in available_providers and providers[i] == "TensorrtExecutionProvider":
- C.register_tensorrt_plugins_as_custom_ops(session_options, provider_options[i])
- elif (
- isinstance(providers[i], tuple)
- and providers[i][0] in available_providers
- and providers[i][0] == "TensorrtExecutionProvider"
- ):
- C.register_tensorrt_plugins_as_custom_ops(session_options, providers[i][1])
- if providers[i] in available_providers and providers[i] == "NvTensorRTRTXExecutionProvider":
- C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, provider_options[i])
- elif (
- isinstance(providers[i], tuple)
- and providers[i][0] in available_providers
- and providers[i][0] == "NvTensorrtRTXExecutionProvider"
- ):
- C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, providers[i][1])
- def make_get_initializer_location_func_wrapper(
- get_initializer_location_func: GetInitializerLocationFunc,
- ) -> GetInitializerLocationWrapperFunc:
- """
- Wraps a user's "get initializer location" function. The returned wrapper function adheres to the
- signature expected by ORT.
- Need this wrapper to:
- - Convert the `initializer_value` parameter from `C.OrtValue` to `onnxruntime.OrtValue`, which is more
- convenient for the user's function to use.
- - Allow the user's function to return the original `external_info` parameter (this wrapper makes a copy)
- """
- def get_initializer_location_func_wrapper(
- initializer_name: str,
- initializer_value: C.OrtValue,
- external_info: C.OrtExternalInitializerInfo | None,
- ) -> C.OrtExternalInitializerInfo | None:
- ret_val: C.OrtExternalInitializerInfo | None = get_initializer_location_func(
- initializer_name, OrtValue(initializer_value), external_info
- )
- if ret_val is not None and ret_val == external_info:
- # User returned `external_info` (const and owned by ORT). ORT expects the returned value to be
- # a new instance (that it deletes), so make a copy.
- ret_val = C.OrtExternalInitializerInfo(ret_val.filepath, ret_val.file_offset, ret_val.byte_size)
- return ret_val
- return get_initializer_location_func_wrapper
- class ModelCompiler:
- """
- This class is used to compile an ONNX model. A compiled ONNX model has EPContext nodes that each
- encapsulates a subgraph compiled/optimized for a specific execution provider.
- Refer to the EPContext design document for more information about EPContext models:
- https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html
- ::
- sess_options = onnxruntime.SessionOptions()
- sess_options.add_provider("SomeExecutionProvider", {"option1": "value1"})
- # Alternatively, allow ONNX Runtime to select the provider automatically given a policy:
- # sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_NPU)
- model_compiler = onnxruntime.ModelCompiler(sess_options, "input_model.onnx")
- model_compiler.compile_to_file("output_model.onnx")
- """
- def __init__(
- self,
- sess_options: onnxruntime.SessionOptions,
- input_model_path_or_bytes: str | os.PathLike | bytes,
- embed_compiled_data_into_model: bool = False,
- external_initializers_file_path: str | os.PathLike | None = None,
- external_initializers_size_threshold: int = 1024,
- flags: int = C.OrtCompileApiFlags.NONE,
- graph_optimization_level: C.GraphOptimizationLevel = C.GraphOptimizationLevel.ORT_DISABLE_ALL,
- get_initializer_location_func: GetInitializerLocationFunc | None = None,
- ):
- """
- Creates a ModelCompiler instance.
- :param sess_options: Session options containing the providers for which the model will be compiled.
- Refer to SessionOptions.add_provider() and SessionOptions.set_provider_selection_policy().
- :param input_model_path_or_bytes: The path to the input model file or bytes representing a serialized
- ONNX model.
- :param embed_compiled_data_into_model: Defaults to False. Set to True to embed compiled binary data into
- EPContext nodes in the compiled model.
- :param external_initializers_file_path: Defaults to None. Set to a path for a file that will store the
- initializers for non-compiled nodes.
- :param external_initializers_size_threshold: Defaults to 1024. Ignored if `external_initializers_file_path`
- is None or empty. Initializers larger than this threshold are stored in the external initializers file.
- :param flags: Additional boolean options to enable. Set this parameter to a bitwise OR of
- flags in onnxruntime.OrtCompileApiFlags.
- :param graph_optimization_level: The graph optimization level.
- Defaults to onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL.
- :param get_initializer_location_func: Optional function called for every initializer to allow user to specify
- whether an initializer should be stored within the model or externally. Example:
- ```
- def get_initializer_location(
- initializer_name: str,
- initializer_value: onnxrt.OrtValue,
- external_info: onnxrt.OrtExternalInitializerInfo | None,
- ) -> onnxrt.OrtExternalInitializerInfo | None:
- byte_size = initializer_value.tensor_size_in_bytes()
- if byte_size < 64:
- return None # Store small initializer within compiled model.
- # Else, write initializer to new external file.
- value_np = initializer_value.numpy()
- file_offset = ext_init_file.tell()
- ext_init_file.write(value_np.tobytes())
- return onnxrt.OrtExternalInitializerInfo(initializer_file_path, file_offset, byte_size)
- ```
- """
- input_model_path: str | os.PathLike | None = None
- input_model_bytes: bytes | None = None
- if isinstance(input_model_path_or_bytes, (str, os.PathLike)):
- if not input_model_path_or_bytes:
- raise ValueError("Input model path is empty")
- input_model_path = os.fspath(input_model_path_or_bytes)
- elif isinstance(input_model_path_or_bytes, bytes):
- if len(input_model_path_or_bytes) == 0:
- raise ValueError("Input model bytes array is empty")
- input_model_bytes = input_model_path_or_bytes
- else:
- raise TypeError(f"Unable to load from type '{type(input_model_path_or_bytes)}'")
- if external_initializers_file_path:
- if not isinstance(external_initializers_file_path, (str, os.PathLike)):
- arg_type = type(external_initializers_file_path)
- raise TypeError(f"Output external initializer filepath is of unexpected type '{arg_type}'")
- external_initializers_file_path = os.fspath(external_initializers_file_path)
- else:
- external_initializers_file_path = ""
- if get_initializer_location_func is not None:
- if external_initializers_file_path:
- raise ValueError(
- "Cannot initialize ModelCompiler with both `external_initializers_file_path` "
- "and `get_initializer_location_func`"
- )
- self.get_initializer_location_func_wrapper = make_get_initializer_location_func_wrapper(
- get_initializer_location_func
- )
- else:
- self.get_initializer_location_func_wrapper = None
- if input_model_path:
- self._model_compiler = C.ModelCompiler(
- sess_options,
- input_model_path,
- True, # is path
- embed_compiled_data_into_model,
- external_initializers_file_path,
- external_initializers_size_threshold,
- flags,
- graph_optimization_level,
- self.get_initializer_location_func_wrapper,
- )
- else:
- self._model_compiler = C.ModelCompiler(
- sess_options,
- input_model_bytes,
- False, # is bytes
- embed_compiled_data_into_model,
- external_initializers_file_path,
- external_initializers_size_threshold,
- flags,
- graph_optimization_level,
- self.get_initializer_location_func_wrapper,
- )
- def compile_to_file(self, output_model_path: str | None = None):
- """
- Compiles to an output file. If an output file path is not provided,
- the output file path is generated based on the input model path by replacing
- '.onnx' with '_ctx.onnx'. Ex: The generated output file is 'model_ctx.onnx' for
- an input model with path 'model.onnx'.
- Raises an 'InvalidArgument' exception if the compilation options are invalid.
- :param output_model_path: Defaults to None. The path for the output/compiled model.
- """
- if output_model_path:
- if not isinstance(output_model_path, (str, os.PathLike)):
- raise TypeError(f"Output model's filepath is of unexpected type '{type(output_model_path)}'")
- output_model_path = os.fspath(output_model_path)
- self._model_compiler.compile_to_file(output_model_path)
- def compile_to_bytes(self) -> bytes:
- """
- Compiles to bytes representing the serialized compiled ONNX model.
- Raises an 'InvalidArgument' exception if the compilation options are invalid.
- :return: A bytes object representing the compiled ONNX model.
- """
- return self._model_compiler.compile_to_bytes()
- def compile_to_stream(self, write_function: Callable[[bytes], None]):
- """
- Compiles the input model and writes the serialized ONNX bytes to a stream using the provided write function.
- Raises an 'InvalidArgument' exception if the compilation options are invalid.
- :param write_function: A callable that accepts a bytes buffer to write.
- """
- self._model_compiler.compile_to_stream(write_function)
- class IOBinding:
- """
- This class provides API to bind input/output to a specified device, e.g. GPU.
- """
- def __init__(self, session: Session):
- self._iobinding = C.SessionIOBinding(session._sess)
- self._numpy_obj_references = {}
- def bind_cpu_input(self, name, arr_on_cpu):
- """
- bind an input to array on CPU
- :param name: input name
- :param arr_on_cpu: input values as a python array on CPU
- """
- # Hold a reference to the numpy object as the bound OrtValue is backed
- # directly by the data buffer of the numpy object and so the numpy object
- # must be around until this IOBinding instance is around
- self._numpy_obj_references[name] = arr_on_cpu
- self._iobinding.bind_input(name, arr_on_cpu)
- def bind_input(self, name, device_type, device_id, element_type, shape, buffer_ptr):
- """
- :param name: input name
- :param device_type: e.g. cpu, cuda, cann
- :param device_id: device id, e.g. 0
- :param element_type: input element type. It can be either numpy type (like numpy.float32) or an integer for onnx type (like onnx.TensorProto.BFLOAT16)
- :param shape: input shape
- :param buffer_ptr: memory pointer to input data
- """
- self._iobinding.bind_input(
- name,
- C.OrtDevice(
- get_ort_device_type(device_type),
- C.OrtDevice.default_memory(),
- device_id,
- ),
- element_type,
- shape,
- buffer_ptr,
- )
- def bind_ortvalue_input(self, name, ortvalue):
- """
- :param name: input name
- :param ortvalue: OrtValue instance to bind
- """
- self._iobinding.bind_ortvalue_input(name, ortvalue._ortvalue)
- def synchronize_inputs(self):
- self._iobinding.synchronize_inputs()
- def bind_output(
- self,
- name,
- device_type="cpu",
- device_id=0,
- element_type=None,
- shape=None,
- buffer_ptr=None,
- ):
- """
- :param name: output name
- :param device_type: e.g. cpu, cuda, cann, cpu by default
- :param device_id: device id, e.g. 0
- :param element_type: output element type. It can be either numpy type (like numpy.float32) or an integer for onnx type (like onnx.TensorProto.BFLOAT16)
- :param shape: output shape
- :param buffer_ptr: memory pointer to output data
- """
- # Follow the `if` path when the user has not provided any pre-allocated buffer but still
- # would like to bind an output to a specific device (e.g. cuda).
- # Pre-allocating an output buffer may not be an option for the user as :
- # (1) They may not want to use a custom allocator specific to the device they want to bind the output to,
- # in which case ORT will allocate the memory for the user
- # (2) The output has a dynamic shape and hence the size of the buffer may not be fixed across runs
- if buffer_ptr is None:
- self._iobinding.bind_output(
- name,
- C.OrtDevice(
- get_ort_device_type(device_type),
- C.OrtDevice.default_memory(),
- device_id,
- ),
- )
- else:
- if element_type is None or shape is None:
- raise ValueError("`element_type` and `shape` are to be provided if pre-allocated memory is provided")
- self._iobinding.bind_output(
- name,
- C.OrtDevice(
- get_ort_device_type(device_type),
- C.OrtDevice.default_memory(),
- device_id,
- ),
- element_type,
- shape,
- buffer_ptr,
- )
- def bind_ortvalue_output(self, name, ortvalue):
- """
- :param name: output name
- :param ortvalue: OrtValue instance to bind
- """
- self._iobinding.bind_ortvalue_output(name, ortvalue._ortvalue)
- def synchronize_outputs(self):
- self._iobinding.synchronize_outputs()
- def get_outputs(self):
- """
- Returns the output OrtValues from the Run() that preceded the call.
- The data buffer of the obtained OrtValues may not reside on CPU memory
- """
- outputs = self._iobinding.get_outputs()
- if not isinstance(outputs, C.OrtValueVector):
- raise TypeError("get_outputs() must return an instance of type 'OrtValueVector'.")
- return [OrtValue(ortvalue) for ortvalue in outputs]
- def get_outputs_as_ortvaluevector(self):
- return self._iobinding.get_outputs()
- def copy_outputs_to_cpu(self):
- """Copy output contents to CPU."""
- return self._iobinding.copy_outputs_to_cpu()
- def clear_binding_inputs(self):
- self._iobinding.clear_binding_inputs()
- def clear_binding_outputs(self):
- self._iobinding.clear_binding_outputs()
- class OrtValue:
- """
- A data structure that supports all ONNX data formats (tensors and non-tensors) that allows users
- to place the data backing these on a device, for example, on a CUDA supported device.
- This class provides APIs to construct and deal with OrtValues.
- """
- def __init__(self, ortvalue: C.OrtValue, numpy_obj: np.ndarray | None = None):
- if isinstance(ortvalue, C.OrtValue):
- self._ortvalue = ortvalue
- # Hold a ref count to the numpy object if the OrtValue is backed directly
- # by its data buffer so that it isn't destroyed when the OrtValue is in use
- self._numpy_obj = numpy_obj
- else:
- # An end user won't hit this error
- raise ValueError(
- "`Provided ortvalue` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.OrtValue`"
- )
- def _get_c_value(self) -> C.OrtValue:
- return self._ortvalue
- @classmethod
- def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device_id=0, vendor_id=-1) -> OrtValue:
- """
- Factory method to construct an OrtValue (which holds a Tensor) from a given Numpy object
- A copy of the data in the Numpy object is held by the OrtValue only if the device is NOT cpu
- :param numpy_obj: The Numpy object to construct the OrtValue from
- :param device_type: e.g. cpu, cuda, cann, cpu by default
- :param device_id: device id, e.g. 0
- :param vendor_id: The device's PCI vendor id. If provided, the device_type should be "gpu" or "npu".
- """
- # Hold a reference to the numpy object (if device_type is 'cpu') as the OrtValue
- # is backed directly by the data buffer of the numpy object and so the numpy object
- # must be around until this OrtValue instance is around
- return cls(
- C.OrtValue.ortvalue_from_numpy(
- numpy_obj,
- OrtDevice.make(device_type, device_id, vendor_id)._get_c_device(),
- ),
- numpy_obj if device_type.lower() == "cpu" else None,
- )
- @classmethod
- def ortvalue_from_numpy_with_onnx_type(cls, data: np.ndarray, /, onnx_element_type: int) -> OrtValue:
- """
- This method creates an instance of OrtValue on top of the numpy array.
- No data copy is made and the lifespan of the resulting OrtValue should never
- exceed the lifespan of bytes object. The API attempts to reinterpret
- the data type which is expected to be the same size. This is useful
- when we want to use an ONNX data type that is not supported by numpy.
- :param data: numpy.ndarray.
- :param onnx_element_type: a valid onnx TensorProto::DataType enum value
- """
- return cls(C.OrtValue.ortvalue_from_numpy_with_onnx_type(data, onnx_element_type), data)
- @classmethod
- def ortvalue_from_shape_and_type(
- cls, shape: Sequence[int], element_type, device_type: str = "cpu", device_id: int = 0, vendor_id: int = -1
- ) -> OrtValue:
- """
- Factory method to construct an OrtValue (which holds a Tensor) from given shape and element_type
- :param shape: List of integers indicating the shape of the OrtValue
- :param element_type: The data type of the elements. It can be either numpy type (like numpy.float32) or an integer for onnx type (like onnx.TensorProto.BFLOAT16).
- :param device_type: e.g. cpu, cuda, cann, cpu by default
- :param device_id: device id, e.g. 0
- :param vendor_id: If provided the device type should be "gpu" or "npu".
- """
- device = OrtDevice.make(device_type, device_id, vendor_id)._get_c_device()
- # Integer for onnx element type (see https://onnx.ai/onnx/api/mapping.html).
- # This is helpful for some data type (like TensorProto.BFLOAT16) that is not available in numpy.
- if isinstance(element_type, int):
- return cls(
- C.OrtValue.ortvalue_from_shape_and_onnx_type(
- shape,
- element_type,
- device,
- )
- )
- return cls(
- C.OrtValue.ortvalue_from_shape_and_type(
- shape,
- element_type,
- device,
- )
- )
- @classmethod
- def ort_value_from_sparse_tensor(cls, sparse_tensor: SparseTensor) -> OrtValue:
- """
- The function will construct an OrtValue instance from a valid SparseTensor
- The new instance of OrtValue will assume the ownership of sparse_tensor
- """
- return cls(C.OrtValue.ort_value_from_sparse_tensor(sparse_tensor._get_c_tensor()))
- def as_sparse_tensor(self) -> SparseTensor:
- """
- The function will return SparseTensor contained in this OrtValue
- """
- return SparseTensor(self._ortvalue.as_sparse_tensor())
- def data_ptr(self) -> int:
- """
- Returns the address of the first element in the OrtValue's data buffer
- """
- return self._ortvalue.data_ptr()
- def device_name(self) -> str:
- """
- Returns the name of the device where the OrtValue's data buffer resides e.g. cpu, cuda, cann
- """
- return self._ortvalue.device_name().lower()
- def shape(self) -> Sequence[int]:
- """
- Returns the shape of the data in the OrtValue
- """
- return self._ortvalue.shape()
- def data_type(self) -> str:
- """
- Returns the data type of the data in the OrtValue. E.g. 'tensor(int64)'
- """
- return self._ortvalue.data_type()
- def element_type(self) -> int:
- """
- Returns the proto type of the data in the OrtValue
- if the OrtValue is a tensor.
- """
- return self._ortvalue.element_type()
- def tensor_size_in_bytes(self) -> int:
- """
- Returns the size of the data in the OrtValue in bytes
- if the OrtValue is a tensor.
- """
- return self._ortvalue.tensor_size_in_bytes()
- def has_value(self) -> bool:
- """
- Returns True if the OrtValue corresponding to an
- optional type contains data, else returns False
- """
- return self._ortvalue.has_value()
- def is_tensor(self) -> bool:
- """
- Returns True if the OrtValue contains a Tensor, else returns False
- """
- return self._ortvalue.is_tensor()
- def is_sparse_tensor(self) -> bool:
- """
- Returns True if the OrtValue contains a SparseTensor, else returns False
- """
- return self._ortvalue.is_sparse_tensor()
- def is_tensor_sequence(self) -> bool:
- """
- Returns True if the OrtValue contains a Tensor Sequence, else returns False
- """
- return self._ortvalue.is_tensor_sequence()
- def numpy(self) -> np.ndarray:
- """
- Returns a Numpy object from the OrtValue.
- Valid only for OrtValues holding Tensors. Throws for OrtValues holding non-Tensors.
- Use accessors to gain a reference to non-Tensor objects such as SparseTensor
- """
- return self._ortvalue.numpy()
- def update_inplace(self, np_arr) -> None:
- """
- Update the OrtValue in place with a new Numpy array. The numpy contents
- are copied over to the device memory backing the OrtValue. It can be used
- to update the input valuess for an InferenceSession with CUDA graph
- enabled or other scenarios where the OrtValue needs to be updated while
- the memory address can not be changed.
- """
- self._ortvalue.update_inplace(np_arr)
- def copy_tensors(src: Sequence[OrtValue], dst: Sequence[OrtValue], stream=None) -> None:
- """
- Copy tensor data from source OrtValue sequence to destination OrtValue sequence.
- """
- c_sources = [s._get_c_value() for s in src]
- c_dsts = [d._get_c_value() for d in dst]
- C.copy_tensors(c_sources, c_dsts, stream)
- class OrtDevice:
- """
- A data structure that exposes the underlying C++ OrtDevice
- """
- def __init__(self, c_ort_device):
- """
- Internal constructor
- """
- if isinstance(c_ort_device, C.OrtDevice):
- self._ort_device = c_ort_device
- else:
- # An end user won't hit this error
- raise ValueError(
- "`Provided object` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice`"
- )
- def _get_c_device(self):
- """
- Internal accessor to underlying object
- """
- return self._ort_device
- @staticmethod
- def make(ort_device_name, device_id, vendor_id=-1):
- if vendor_id < 0:
- # backwards compatibility with predefined OrtDevice names
- return OrtDevice(
- C.OrtDevice(
- get_ort_device_type(ort_device_name),
- C.OrtDevice.default_memory(),
- device_id,
- )
- )
- else:
- # generic. use GPU or NPU for ort_device_name and provide a vendor id.
- # vendor id of 0 is valid in some cases (e.g. webgpu is generic and does not have a vendor id)
- return OrtDevice(
- C.OrtDevice(
- get_ort_device_type(ort_device_name),
- C.OrtDevice.default_memory(),
- vendor_id,
- device_id,
- )
- )
- def device_id(self):
- return self._ort_device.device_id()
- def device_type(self):
- return self._ort_device.device_type()
- def device_vendor_id(self):
- return self._ort_device.vendor_id()
- def device_mem_type(self):
- return self._ort_device.mem_type()
- class SparseTensor:
- """
- A data structure that project the C++ SparseTensor object
- The class provides API to work with the object.
- Depending on the format, the class will hold more than one buffer
- depending on the format
- """
- def __init__(self, sparse_tensor: C.SparseTensor):
- """
- Internal constructor
- """
- if isinstance(sparse_tensor, C.SparseTensor):
- self._tensor = sparse_tensor
- else:
- # An end user won't hit this error
- raise ValueError(
- "`Provided object` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.SparseTensor`"
- )
- def _get_c_tensor(self) -> C.SparseTensor:
- return self._tensor
- @classmethod
- def sparse_coo_from_numpy(
- cls,
- dense_shape: npt.NDArray[np.int64],
- values: np.ndarray,
- coo_indices: npt.NDArray[np.int64],
- ort_device: OrtDevice,
- ) -> SparseTensor:
- """
- Factory method to construct a SparseTensor in COO format from given arguments
- :param dense_shape: 1-D numpy array(int64) or a python list that contains a dense_shape of the sparse tensor
- must be on cpu memory
- :param values: a homogeneous, contiguous 1-D numpy array that contains non-zero elements of the tensor
- of a type.
- :param coo_indices: contiguous numpy array(int64) that contains COO indices for the tensor. coo_indices may
- have a 1-D shape when it contains a linear index of non-zero values and its length must be equal to
- that of the values. It can also be of 2-D shape, in which has it contains pairs of coordinates for
- each of the nnz values and its length must be exactly twice of the values length.
- :param ort_device: - describes the backing memory owned by the supplied nummpy arrays. Only CPU memory is
- suppored for non-numeric data types.
- For primitive types, the method will map values and coo_indices arrays into native memory and will use
- them as backing storage. It will increment the reference count for numpy arrays and it will decrement it
- on GC. The buffers may reside in any storage either CPU or GPU.
- For strings and objects, it will create a copy of the arrays in CPU memory as ORT does not support those
- on other devices and their memory can not be mapped.
- """
- return cls(C.SparseTensor.sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device._get_c_device()))
- @classmethod
- def sparse_csr_from_numpy(
- cls,
- dense_shape: npt.NDArray[np.int64],
- values: np.ndarray,
- inner_indices: npt.NDArray[np.int64],
- outer_indices: npt.NDArray[np.int64],
- ort_device: OrtDevice,
- ) -> SparseTensor:
- """
- Factory method to construct a SparseTensor in CSR format from given arguments
- :param dense_shape: 1-D numpy array(int64) or a python list that contains a dense_shape of the
- sparse tensor (rows, cols) must be on cpu memory
- :param values: a contiguous, homogeneous 1-D numpy array that contains non-zero elements of the tensor
- of a type.
- :param inner_indices: contiguous 1-D numpy array(int64) that contains CSR inner indices for the tensor.
- Its length must be equal to that of the values.
- :param outer_indices: contiguous 1-D numpy array(int64) that contains CSR outer indices for the tensor.
- Its length must be equal to the number of rows + 1.
- :param ort_device: - describes the backing memory owned by the supplied nummpy arrays. Only CPU memory is
- suppored for non-numeric data types.
- For primitive types, the method will map values and indices arrays into native memory and will use them as
- backing storage. It will increment the reference count and it will decrement then count when it is GCed.
- The buffers may reside in any storage either CPU or GPU.
- For strings and objects, it will create a copy of the arrays in CPU memory as ORT does not support those
- on other devices and their memory can not be mapped.
- """
- return cls(
- C.SparseTensor.sparse_csr_from_numpy(
- dense_shape,
- values,
- inner_indices,
- outer_indices,
- ort_device._get_c_device(),
- )
- )
- def values(self) -> np.ndarray:
- """
- The method returns a numpy array that is backed by the native memory
- if the data type is numeric. Otherwise, the returned numpy array that contains
- copies of the strings.
- """
- return self._tensor.values()
- def as_coo_view(self):
- """
- The method will return coo representation of the sparse tensor which will enable
- querying COO indices. If the instance did not contain COO format, it would throw.
- You can query coo indices as:
- ::
- coo_indices = sparse_tensor.as_coo_view().indices()
- which will return a numpy array that is backed by the native memory.
- """
- return self._tensor.get_coo_data()
- def as_csrc_view(self):
- """
- The method will return CSR(C) representation of the sparse tensor which will enable
- querying CRS(C) indices. If the instance dit not contain CSR(C) format, it would throw.
- You can query indices as:
- ::
- inner_ndices = sparse_tensor.as_csrc_view().inner()
- outer_ndices = sparse_tensor.as_csrc_view().outer()
- returning numpy arrays backed by the native memory.
- """
- return self._tensor.get_csrc_data()
- def as_blocksparse_view(self):
- """
- The method will return coo representation of the sparse tensor which will enable
- querying BlockSparse indices. If the instance did not contain BlockSparse format, it would throw.
- You can query coo indices as:
- ::
- block_sparse_indices = sparse_tensor.as_blocksparse_view().indices()
- which will return a numpy array that is backed by the native memory
- """
- return self._tensor.get_blocksparse_data()
- def to_cuda(self, ort_device):
- """
- Returns a copy of this instance on the specified cuda device
- :param ort_device: with name 'cuda' and valid gpu device id
- The method will throw if:
- - this instance contains strings
- - this instance is already on GPU. Cross GPU copy is not supported
- - CUDA is not present in this build
- - if the specified device is not valid
- """
- return SparseTensor(self._tensor.to_cuda(ort_device._get_c_device()))
- def format(self):
- """
- Returns a OrtSparseFormat enumeration
- """
- return self._tensor.format
- def dense_shape(self) -> npt.NDArray[np.int64]:
- """
- Returns a numpy array(int64) containing a dense shape of a sparse tensor
- """
- return self._tensor.dense_shape()
- def data_type(self) -> str:
- """
- Returns a string data type of the data in the OrtValue
- """
- return self._tensor.data_type()
- def device_name(self) -> str:
- """
- Returns the name of the device where the SparseTensor data buffers reside e.g. cpu, cuda
- """
- return self._tensor.device_name().lower()
- # Type hint for user-specified function that allows the user to specify initializer locations when compiling a model.
- GetInitializerLocationFunc = Callable[
- [str, OrtValue, C.OrtExternalInitializerInfo | None], C.OrtExternalInitializerInfo | None
- ]
- # Type hint that adheres to the signature expected by ORT.
- GetInitializerLocationWrapperFunc = Callable[
- [str, C.OrtValue, C.OrtExternalInitializerInfo | None], C.OrtExternalInitializerInfo | None
- ]
|