exc.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. from __future__ import annotations
  2. import os
  3. import tempfile
  4. import textwrap
  5. from functools import lru_cache
  6. from typing import Any, Optional, TYPE_CHECKING
  7. from torch._dynamo.exc import BackendCompilerFailed, ShortenTraceback
  8. if TYPE_CHECKING:
  9. import types
  10. from torch.cuda import _CudaDeviceProperties
  11. if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1":
  12. @lru_cache(None)
  13. def _record_missing_op(target: Any) -> None:
  14. with open(f"{tempfile.gettempdir()}/missing_ops.txt", "a") as fd:
  15. fd.write(str(target) + "\n")
  16. else:
  17. def _record_missing_op(target: Any) -> None: # type: ignore[misc]
  18. pass
  19. class OperatorIssue(RuntimeError):
  20. @staticmethod
  21. def operator_str(target: Any, args: list[Any], kwargs: dict[str, Any]) -> str:
  22. lines = [f"target: {target}"] + [
  23. f"args[{i}]: {arg}" for i, arg in enumerate(args)
  24. ]
  25. if kwargs:
  26. lines.append(f"kwargs: {kwargs}")
  27. return textwrap.indent("\n".join(lines), " ")
  28. class MissingOperatorWithoutDecomp(OperatorIssue):
  29. def __init__(self, target: Any, args: list[Any], kwargs: dict[str, Any]) -> None:
  30. _record_missing_op(target)
  31. super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}")
  32. class MissingOperatorWithDecomp(OperatorIssue):
  33. def __init__(self, target: Any, args: list[Any], kwargs: dict[str, Any]) -> None:
  34. _record_missing_op(target)
  35. super().__init__(
  36. f"missing decomposition\n{self.operator_str(target, args, kwargs)}"
  37. + textwrap.dedent(
  38. f"""
  39. There is a decomposition available for {target} in
  40. torch._decomp.get_decompositions(). Please add this operator to the
  41. `decompositions` list in torch._inductor.decomposition
  42. """
  43. )
  44. )
  45. class LoweringException(OperatorIssue):
  46. def __init__(
  47. self,
  48. exc: Exception,
  49. target: Any,
  50. args: list[Any],
  51. kwargs: dict[str, Any],
  52. stack_trace: Optional[str] = None,
  53. ) -> None:
  54. msg = f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}"
  55. if stack_trace:
  56. msg += f"{msg}\nFound from : \n {stack_trace}"
  57. super().__init__(msg)
  58. class SubgraphLoweringException(RuntimeError):
  59. pass
  60. class InvalidCxxCompiler(RuntimeError):
  61. def __init__(self) -> None:
  62. from . import config
  63. super().__init__(
  64. f"No working C++ compiler found in {config.__name__}.cpp.cxx: {config.cpp.cxx}"
  65. )
  66. class CppWrapperCodegenError(RuntimeError):
  67. def __init__(self, msg: str) -> None:
  68. super().__init__(f"C++ wrapper codegen error: {msg}")
  69. class CppCompileError(RuntimeError):
  70. def __init__(self, cmd: list[str], output: str) -> None:
  71. if isinstance(output, bytes):
  72. output = output.decode("utf-8")
  73. self.cmd = cmd
  74. self.output = output
  75. super().__init__(
  76. textwrap.dedent(
  77. """
  78. C++ compile error
  79. Command:
  80. {cmd}
  81. Output:
  82. {output}
  83. """
  84. )
  85. .strip()
  86. .format(cmd=" ".join(cmd), output=output)
  87. )
  88. def __reduce__(self) -> tuple[type, tuple[list[str], str]]:
  89. return (self.__class__, (self.cmd, self.output))
  90. class CUDACompileError(CppCompileError):
  91. pass
  92. class TritonMissing(ShortenTraceback):
  93. def __init__(self, first_useful_frame: Optional[types.FrameType]) -> None:
  94. super().__init__(
  95. "Cannot find a working triton installation. "
  96. "Either the package is not installed or it is too old. "
  97. "More information on installing Triton can be found at: https://github.com/triton-lang/triton",
  98. first_useful_frame=first_useful_frame,
  99. )
  100. class GPUTooOldForTriton(ShortenTraceback):
  101. def __init__(
  102. self,
  103. # pyrefly: ignore [not-a-type]
  104. device_props: _CudaDeviceProperties,
  105. first_useful_frame: Optional[types.FrameType],
  106. ) -> None:
  107. super().__init__(
  108. f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, "
  109. "which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, "
  110. f"but your device is of CUDA capability {device_props.major}.{device_props.minor}",
  111. first_useful_frame=first_useful_frame,
  112. )
  113. class InductorError(BackendCompilerFailed):
  114. backend_name = "inductor"
  115. def __init__(
  116. self,
  117. inner_exception: Exception,
  118. first_useful_frame: Optional[types.FrameType],
  119. ) -> None:
  120. self.inner_exception = inner_exception
  121. ShortenTraceback.__init__(
  122. self,
  123. f"{type(inner_exception).__name__}: {inner_exception}",
  124. first_useful_frame=first_useful_frame,
  125. )