errors.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
"""ONNX exporter exceptions."""

from __future__ import annotations

# Public names exported by ``from <this module> import *``.
# NOTE(review): OnnxExporterError (the base class defined in this module) is
# not listed here; presumably it is exposed through another module — confirm
# before adding it, since changing __all__ changes the declared public API.
__all__ = [
    "OnnxExporterWarning",
    "SymbolicValueError",
    "UnsupportedOperatorError",
]

import textwrap
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for type checkers only. In this module `_C` is referenced only in
    # annotations, which `from __future__ import annotations` leaves unevaluated,
    # so torch._C is not imported here at runtime.
    from torch import _C
  12. class OnnxExporterWarning(UserWarning):
  13. """Warnings in the ONNX exporter."""
  14. class OnnxExporterError(RuntimeError):
  15. """Errors raised by the ONNX exporter. This is the base class for all exporter errors."""
  16. class UnsupportedOperatorError(OnnxExporterError):
  17. """Raised when an operator is unsupported by the exporter."""
  18. # NOTE: This is legacy and is only used by the torchscript exporter
  19. # Clean up when the torchscript exporter is removed
  20. def __init__(self, name: str, version: int, supported_version: int | None) -> None:
  21. if supported_version is not None:
  22. msg = (
  23. f"Exporting the operator '{name}' to ONNX opset version {version} "
  24. "is not supported. Support for this operator was added in version "
  25. f"{supported_version}, try exporting with this version"
  26. )
  27. elif name.startswith(("aten::", "prim::", "quantized::")):
  28. msg = (
  29. f"Exporting the operator '{name}' to ONNX opset version {version} "
  30. "is not supported"
  31. )
  32. else:
  33. msg = (
  34. f"ONNX export failed on an operator with unrecognized namespace {name}. "
  35. "If you are trying to export a custom operator, make sure you registered it with "
  36. "the right domain and version."
  37. )
  38. super().__init__(msg)
  39. class SymbolicValueError(OnnxExporterError):
  40. """Errors around TorchScript values and nodes."""
  41. # NOTE: This is legacy and is only used by the torchscript exporter
  42. # Clean up when the torchscript exporter is removed
  43. def __init__(self, msg: str, value: _C.Value) -> None:
  44. message = (
  45. f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the "
  46. f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] "
  47. )
  48. code_location = value.node().sourceRange()
  49. if code_location:
  50. message += f"\n (node defined in {code_location})"
  51. try:
  52. # Add its input and output to the message.
  53. message += "\n\n"
  54. message += textwrap.indent(
  55. (
  56. "Inputs:\n"
  57. + (
  58. "\n".join(
  59. f" #{i}: {input_} (type '{input_.type()}')"
  60. for i, input_ in enumerate(value.node().inputs())
  61. )
  62. or " Empty"
  63. )
  64. + "\n"
  65. + "Outputs:\n"
  66. + (
  67. "\n".join(
  68. f" #{i}: {output} (type '{output.type()}')"
  69. for i, output in enumerate(value.node().outputs())
  70. )
  71. or " Empty"
  72. )
  73. ),
  74. " ",
  75. )
  76. except AttributeError:
  77. message += (
  78. " Failed to obtain its input and output for debugging. "
  79. "Please refer to the TorchScript graph for debugging information."
  80. )
  81. super().__init__(message)