sequential.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import Any, Optional, Union
  19. from kornia.core.external import onnx
  20. from kornia.core.external import onnxruntime as ort
  21. from kornia.core.mixin import ONNXMixin, ONNXRuntimeMixin
  22. __all__ = ["ONNXSequential"]
  23. class ONNXSequential(ONNXMixin, ONNXRuntimeMixin):
  24. """ONNXSequential to chain multiple ONNX operators together.
  25. Args:
  26. *args: A variable number of ONNX models (either ONNX ModelProto objects or file paths).
  27. For Hugging Face-hosted models, use the format 'hf://model_name'. Valid `model_name` can be found on
  28. https://huggingface.co/kornia/ONNX_models. Or a URL to the ONNX model.
  29. providers: A list of execution providers for ONNXRuntime
  30. (e.g., ['CUDAExecutionProvider', 'CPUExecutionProvider']).
  31. session_options: Optional ONNXRuntime session options for optimizing the session.
  32. io_maps: An optional list of list of tuples specifying input-output mappings for combining models.
  33. If None, we assume the default input name and output name are "input" and "output" accordingly, and
  34. only one input and output node for each graph.
  35. If not None, `io_maps[0]` shall represent the `io_map` for combining the first and second ONNX models.
  36. cache_dir: The directory where ONNX models are cached locally (only for downloading from HuggingFace).
  37. Defaults to None, which will use a default `kornia.config.hub_onnx_dir` directory.
  38. auto_ir_version_conversion: If True, automatically convert the model's IR version to 9, and OPSET version to 17.
  39. Other versions may be pointed to by `target_ir_version` and `target_opset_version`.
  40. target_ir_version: The target IR version to convert to.
  41. target_opset_version: The target OPSET version to convert to.
  42. """
  43. def __init__(
  44. self,
  45. *args: Union[onnx.ModelProto, str], # type:ignore
  46. providers: Optional[list[str]] = None,
  47. session_options: Optional[ort.SessionOptions] = None, # type:ignore
  48. io_maps: Optional[list[tuple[str, str]]] = None,
  49. cache_dir: Optional[str] = None,
  50. auto_ir_version_conversion: bool = False,
  51. target_ir_version: Optional[int] = None,
  52. target_opset_version: Optional[int] = None,
  53. ) -> None:
  54. self.operators = self._load_ops(*args, cache_dir=cache_dir)
  55. if auto_ir_version_conversion:
  56. self.operators = self._auto_version_conversion(
  57. *self.operators, target_ir_version=target_ir_version, target_opset_version=target_opset_version
  58. )
  59. self._combined_op = self.combine(io_maps=io_maps)
  60. session = self.create_session(providers=providers, session_options=session_options)
  61. self.set_session(session=session)
  62. def _auto_version_conversion(
  63. self,
  64. *args: list[onnx.ModelProto], # type:ignore
  65. target_ir_version: Optional[int] = None,
  66. target_opset_version: Optional[int] = None,
  67. ) -> list[onnx.ModelProto]: # type:ignore
  68. """Automatic conversion of the model's IR/OPSET version to the given target version.
  69. If `target_ir_version` is not provided, the model is converted to 9 by default.
  70. If `target_opset_version` is not provided, the model is converted to 17 by default.
  71. Args:
  72. args: List of operations to convert.
  73. target_ir_version: The target IR version to convert to.
  74. target_opset_version: The target OPSET version to convert to.
  75. """
  76. # TODO: maybe another logic for versioning.
  77. if target_ir_version is None:
  78. target_ir_version = 9
  79. if target_opset_version is None:
  80. target_opset_version = 17
  81. op_list = []
  82. for op in args:
  83. op = super()._onnx_version_conversion(
  84. op, target_ir_version=target_ir_version, target_opset_version=target_opset_version
  85. )
  86. op_list.append(op)
  87. return op_list
  88. def combine(self, io_maps: list[tuple[str, str]] | None = None) -> onnx.ModelProto: # type: ignore
  89. return super()._combine(*self.operators, io_maps=io_maps)
  90. def create_session(
  91. self, providers: list[str] | None = None, session_options: Any | None = None
  92. ) -> ort.InferenceSession: # type: ignore
  93. return super()._create_session(self._combined_op, providers, session_options)
  94. def export(self, file_path: str, **kwargs: Any) -> None:
  95. return super()._export(self._combined_op, file_path, **kwargs)
  96. def add_metadata(self, additional_metadata: Optional[list[tuple[str, str]]] = None) -> onnx.ModelProto: # type:ignore
  97. return super()._add_metadata(self._combined_op, additional_metadata)