__init__.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. __all__ = [
  4. # Modules
  5. "errors",
  6. "ops",
  7. # Public functions
  8. "export",
  9. "is_in_onnx_export",
  10. # Base error
  11. "OnnxExporterError",
  12. "ONNXProgram",
  13. "ExportableModule",
  14. "InputObserver",
  15. ]
  16. from typing import Any, TYPE_CHECKING
  17. import torch
  18. from torch._C import _onnx as _C_onnx
  19. from torch._C._onnx import ( # Deprecated members that are excluded from __all__
  20. OperatorExportTypes as OperatorExportTypes,
  21. TensorProtoDataType as TensorProtoDataType,
  22. TrainingMode as TrainingMode,
  23. )
  24. from . import errors, ops
  25. from ._internal.exporter._exportable_module import ExportableModule
  26. from ._internal.exporter._input_observer import InputObserver
  27. from ._internal.exporter._onnx_program import ONNXProgram
  28. from ._internal.torchscript_exporter import ( # Deprecated members that are excluded from __all__
  29. symbolic_helper,
  30. symbolic_opset10,
  31. symbolic_opset9,
  32. utils,
  33. )
  34. from ._internal.torchscript_exporter._type_utils import (
  35. JitScalarType, # Deprecated members that are excluded from __all__
  36. )
  37. from ._internal.torchscript_exporter.utils import ( # Deprecated members that are excluded from __all__
  38. register_custom_op_symbolic,
  39. select_model_mode_for_export, # pyrefly: ignore # deprecated
  40. unregister_custom_op_symbolic,
  41. )
  42. from .errors import OnnxExporterError
  43. if TYPE_CHECKING:
  44. import os
  45. from collections.abc import Callable, Collection, Mapping, Sequence
  46. # Set namespace for exposed private names
  47. ONNXProgram.__module__ = "torch.onnx"
  48. ExportableModule.__module__ = "torch.onnx"
  49. OnnxExporterError.__module__ = "torch.onnx"
  50. InputObserver.__module__ = "torch.onnx"
  51. # TODO(justinchuby): Remove these two properties
  52. producer_name = "pytorch"
  53. producer_version = _C_onnx.PRODUCER_VERSION
  54. def export(
  55. model: torch.nn.Module
  56. | torch.export.ExportedProgram
  57. | torch.jit.ScriptModule
  58. | torch.jit.ScriptFunction,
  59. args: tuple[Any, ...] = (),
  60. f: str | os.PathLike | None = None,
  61. *,
  62. kwargs: dict[str, Any] | None = None,
  63. verbose: bool | None = None,
  64. input_names: Sequence[str] | None = None,
  65. output_names: Sequence[str] | None = None,
  66. opset_version: int | None = None,
  67. dynamo: bool = True,
  68. # Dynamo only options
  69. external_data: bool = True,
  70. dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
  71. custom_translation_table: dict[Callable, Callable] | None = None,
  72. report: bool = False,
  73. optimize: bool = True,
  74. verify: bool = False,
  75. profile: bool = False,
  76. dump_exported_program: bool = False,
  77. artifacts_dir: str | os.PathLike = ".",
  78. # BC options
  79. export_params: bool = True,
  80. keep_initializers_as_inputs: bool = False,
  81. dynamic_axes: Mapping[str, Mapping[int, str]]
  82. | Mapping[str, Sequence[int]]
  83. | None = None,
  84. # Deprecated options
  85. training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
  86. operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,
  87. do_constant_folding: bool = True,
  88. custom_opsets: Mapping[str, int] | None = None,
  89. export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
  90. autograd_inlining: bool = True,
  91. ) -> ONNXProgram | None:
  92. r"""Exports a model into ONNX format.
  93. Setting ``dynamo=True`` enables the new ONNX export logic
  94. which is based on :class:`torch.export.ExportedProgram` and a more modern
  95. set of translation logic. This is the recommended and default way to export models
  96. to ONNX.
  97. When ``dynamo=True``:
  98. The exporter tries the following strategies to get an ExportedProgram for conversion to ONNX.
  99. #. If the model is already an ExportedProgram, it will be used as-is.
  100. #. Use :func:`torch.export.export` and set ``strict=False``.
  101. #. Use :func:`torch.export.export` and set ``strict=True``.
  102. Args:
  103. model: The model to be exported.
  104. args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the
  105. exported model; any Tensor arguments will become inputs of the exported model,
  106. in the order they occur in the tuple.
  107. f: Path to the output ONNX model file. E.g. "model.onnx". This argument is kept for
  108. backward compatibility. It is recommended to leave unspecified (None)
  109. and use the returned :class:`torch.onnx.ONNXProgram` to serialize the model
  110. to a file instead.
  111. kwargs: Optional example keyword inputs.
  112. verbose: Whether to enable verbose logging.
  113. input_names: names to assign to the input nodes of the graph, in order.
  114. output_names: names to assign to the output nodes of the graph, in order.
  115. opset_version: The version of the
  116. `default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
  117. to target. You should set ``opset_version`` according to the supported opset versions
  118. of the runtime backend or compiler you want to run the exported model with.
  119. Leave as default (``None``) to use the recommended version, or refer to
  120. the ONNX operators documentation for more information.
  121. dynamo: Whether to export the model with ``torch.export`` ExportedProgram instead of TorchScript.
  122. external_data: Whether to save the model weights as an external data file.
  123. This is required for models with large weights that exceed the ONNX file size limit (2GB).
  124. When False, the weights are saved in the ONNX file with the model architecture.
  125. dynamic_shapes: A dictionary or a tuple of dynamic shapes for the model inputs. Refer to
  126. :func:`torch.export.export` for more details. This is only used (and preferred) when dynamo is True.
  127. Note that dynamic_shapes is designed to be used when the model is exported with dynamo=True, while
  128. dynamic_axes is used when dynamo=False.
  129. custom_translation_table: A dictionary of custom decompositions for operators in the model.
  130. The dictionary should have the callable target in the fx Node as the key (e.g. ``torch.ops.aten.stft.default``),
  131. and the value should be a function that builds that graph using ONNX Script. This option
  132. is only valid when dynamo is True.
  133. report: Whether to generate a markdown report for the export process. This option
  134. is only valid when dynamo is True.
  135. optimize: Whether to optimize the exported model. This option
  136. is only valid when dynamo is True. Default is True.
  137. verify: Whether to verify the exported model using ONNX Runtime. This option
  138. is only valid when dynamo is True.
  139. profile: Whether to profile the export process. This option
  140. is only valid when dynamo is True.
  141. dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file.
  142. This is useful for debugging the exporter. This option is only valid when dynamo is True.
  143. artifacts_dir: The directory to save the debugging artifacts like the report and the serialized
  144. exported program. This option is only valid when dynamo is True.
  145. export_params: **When ``f`` is specified**: If false, parameters (weights) will not be exported.
  146. You can also leave it unspecified and use the returned :class:`torch.onnx.ONNXProgram`
  147. to control how initializers are treated when serializing the model.
  148. keep_initializers_as_inputs: **When ``f`` is specified**: If True, all the
  149. initializers (typically corresponding to model weights) in the
  150. exported graph will also be added as inputs to the graph. If False,
  151. then initializers are not added as inputs to the graph, and only
  152. the user inputs are added as inputs.
  153. Set this to True if you intend to supply model weights at runtime.
  154. Set it to False if the weights are static to allow for better optimizations
  155. (e.g. constant folding) by backends/runtimes.
  156. You can also leave it unspecified and use the returned :class:`torch.onnx.ONNXProgram`
  157. to control how initializers are treated when serializing the model.
  158. dynamic_axes:
  159. Deprecated: Prefer specifying ``dynamic_shapes`` when ``dynamo=True``.
  160. By default the exported model will have the shapes of all input and output tensors
  161. set to exactly match those given in ``args``. To specify axes of tensors as
  162. dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema:
  163. * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or
  164. ``output_names``.
  165. * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a
  166. list, each element is an axis index.
  167. For example::
  168. class SumModule(torch.nn.Module):
  169. def forward(self, x):
  170. return torch.sum(x, dim=1)
  171. torch.onnx.export(
  172. SumModule(),
  173. (torch.ones(2, 2),),
  174. "onnx.pb",
  175. input_names=["x"],
  176. output_names=["sum"],
  177. )
  178. Produces::
  179. input {
  180. name: "x"
  181. ...
  182. shape {
  183. dim {
  184. dim_value: 2 # axis 0
  185. }
  186. dim {
  187. dim_value: 2 # axis 1
  188. ...
  189. output {
  190. name: "sum"
  191. ...
  192. shape {
  193. dim {
  194. dim_value: 2 # axis 0
  195. ...
  196. While::
  197. torch.onnx.export(
  198. SumModule(),
  199. (torch.ones(2, 2),),
  200. "onnx.pb",
  201. input_names=["x"],
  202. output_names=["sum"],
  203. dynamic_axes={
  204. # dict value: manually named axes
  205. "x": {0: "my_custom_axis_name"},
  206. # list value: automatic names
  207. "sum": [0],
  208. },
  209. )
  210. Produces::
  211. input {
  212. name: "x"
  213. ...
  214. shape {
  215. dim {
  216. dim_param: "my_custom_axis_name" # axis 0
  217. }
  218. dim {
  219. dim_value: 2 # axis 1
  220. ...
  221. output {
  222. name: "sum"
  223. ...
  224. shape {
  225. dim {
  226. dim_param: "sum_dynamic_axes_1" # axis 0
  227. ...
  228. training: Deprecated option. Instead, set the training mode of the model before exporting.
  229. operator_export_type: Deprecated option. Only ONNX is supported.
  230. do_constant_folding: Deprecated option.
  231. custom_opsets: Deprecated option.
  232. export_modules_as_functions: Deprecated option.
  233. autograd_inlining: Deprecated option.
  234. Returns:
  235. :class:`torch.onnx.ONNXProgram` if dynamo is True, otherwise None.
  236. .. versionchanged:: 2.6
  237. ``training`` is now deprecated. Instead, set the training mode of the model before exporting.
  238. ``operator_export_type`` is now deprecated. Only ONNX is supported.
  239. ``do_constant_folding`` is now deprecated. It is always enabled.
  240. ``export_modules_as_functions`` is now deprecated.
  241. ``autograd_inlining`` is now deprecated.
  242. .. versionchanged:: 2.7
  243. ``optimize`` is now True by default.
  244. .. versionchanged:: 2.9
  245. ``dynamo`` is now True by default.
  246. .. versionchanged:: 2.11
  247. ``fallback`` option has been removed.
  248. """
  249. if dynamo is True or isinstance(
  250. model, (torch.export.ExportedProgram, ExportableModule)
  251. ):
  252. from torch.onnx._internal.exporter import _compat
  253. if isinstance(args, torch.Tensor):
  254. args = (args,)
  255. return _compat.export_compat(
  256. model,
  257. args,
  258. f,
  259. kwargs=kwargs,
  260. export_params=export_params,
  261. verbose=verbose,
  262. input_names=input_names,
  263. output_names=output_names,
  264. opset_version=opset_version,
  265. custom_translation_table=custom_translation_table,
  266. dynamic_axes=dynamic_axes,
  267. keep_initializers_as_inputs=keep_initializers_as_inputs,
  268. external_data=external_data,
  269. dynamic_shapes=dynamic_shapes,
  270. report=report,
  271. optimize=optimize,
  272. verify=verify,
  273. profile=profile,
  274. dump_exported_program=dump_exported_program,
  275. artifacts_dir=artifacts_dir,
  276. )
  277. else:
  278. import warnings
  279. from ._internal.torchscript_exporter.utils import export
  280. warnings.warn(
  281. "You are using the legacy TorchScript-based ONNX export. Starting in PyTorch 2.9, "
  282. "the new torch.export-based ONNX exporter has become the default. "
  283. "Learn more about the new export logic: https://docs.pytorch.org/docs/stable/onnx_export.html. "
  284. "For exporting control flow: "
  285. "https://pytorch.org/tutorials/beginner/onnx/export_control_flow_model_to_onnx_tutorial.html",
  286. category=DeprecationWarning,
  287. stacklevel=2,
  288. )
  289. if dynamic_shapes:
  290. raise ValueError(
  291. "The exporter only supports dynamic shapes "
  292. "through parameter dynamic_axes when dynamo=False."
  293. )
  294. export(
  295. model,
  296. args,
  297. f, # type: ignore[arg-type]
  298. kwargs=kwargs,
  299. export_params=export_params,
  300. verbose=verbose is True,
  301. input_names=input_names,
  302. output_names=output_names,
  303. opset_version=opset_version,
  304. dynamic_axes=dynamic_axes,
  305. keep_initializers_as_inputs=keep_initializers_as_inputs,
  306. training=training,
  307. operator_export_type=operator_export_type,
  308. do_constant_folding=do_constant_folding,
  309. custom_opsets=custom_opsets,
  310. export_modules_as_functions=export_modules_as_functions,
  311. autograd_inlining=autograd_inlining,
  312. )
  313. return None
  314. def is_in_onnx_export() -> bool:
  315. """Returns whether it is in the middle of ONNX export."""
  316. from torch.onnx._internal.exporter import _flags
  317. from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
  318. return GLOBALS.in_onnx_export or _flags._is_onnx_exporting