Exception.h 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #ifndef C10_UTIL_EXCEPTION_H_
  3. #define C10_UTIL_EXCEPTION_H_
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Backtrace.h>
#include <c10/util/Lazy.h>
#include <c10/util/StringUtil.h>
#include <cstdint>
#include <exception>
#include <memory>
#include <string>
#include <utility>
#include <variant>
#include <vector>
  15. #if defined(_MSC_VER) && _MSC_VER <= 1900
  16. #define __func__ __FUNCTION__
  17. #endif
  18. namespace c10 {
/// The primary ATen error class.
/// Provides a complete error message with source location information via
/// `what()`, and a more concise message via `what_without_backtrace()`.
/// Don't throw this directly; use TORCH_CHECK/TORCH_INTERNAL_ASSERT instead.
///
/// NB: c10::Error is handled specially by the default torch to suppress the
/// backtrace, see torch/csrc/Exceptions.h
class C10_API Error : public std::exception {
 private:
  // The actual error message.
  std::string msg_;

  // Context for the message (in order of decreasing specificity). Context will
  // be automatically formatted appropriately, so it is not necessary to add
  // extra leading/trailing newlines to strings inside this vector. Entries are
  // added through add_context().
  std::vector<std::string> context_;

  // The C++ backtrace at the point when this exception was raised. This
  // may be empty if there is no valid backtrace. (We don't use optional
  // here to reduce the dependencies this file has.)
  Backtrace backtrace_;

  // These two are derived fields: the message (msg_ plus its context_) and,
  // for the full what() text, backtrace_. We need fields for the strings so
  // that we can return a const char* (as the signature of std::exception
  // requires).
  //
  // what_ is `mutable` and wrapped in OptimisticLazy, so it can be filled in
  // from the const what() rather than being stored eagerly.
  // NOTE(review): the exact refresh policy lives in refresh_what() /
  // compute_what(bool include_backtrace), implemented out of line; confirm
  // there (and in c10/util/Lazy.h) before relying on when each is computed.
  mutable OptimisticLazy<std::string> what_;
  std::string what_without_backtrace_;

  // This is a little debugging trick: you can stash a relevant pointer
  // in caller, and then when you catch the exception, you can compare
  // against pointers you have on hand to get more information about
  // where the exception came from. In Caffe2, this is used to figure
  // out which operator raised an exception.
  const void* caller_;

 public:
  // PyTorch-style Error constructor. NB: the implementation of this
  // is actually in Logging.cpp
  Error(SourceLocation source_location, std::string msg);

  // Caffe2-style error message
  Error(
      const char* file,
      const uint32_t line,
      const char* condition,
      const std::string& msg,
      Backtrace backtrace,
      const void* caller = nullptr);

  // Base constructor
  Error(
      std::string msg,
      Backtrace backtrace = nullptr,
      const void* caller = nullptr);

  // Add some new context to the message stack (the context_ list). The last
  // added context will be formatted at the end of the context list upon
  // printing.
  // WARNING: This method is O(n) in the size of the stack, so don't go
  // wild adding a ridiculous amount of context to error messages.
  void add_context(std::string msg);

  // The stored error message (msg_). Context strings are kept separately and
  // are available through context().
  const std::string& msg() const {
    return msg_;
  }

  // The context strings attached to this error so far (see add_context()).
  const std::vector<std::string>& context() const {
    return context_;
  }

  // The C++ backtrace captured for this error; may be empty (see backtrace_).
  const Backtrace& backtrace() const;

  /// Returns the complete error message, including the source location.
  /// The returned pointer is invalidated if you call add_context() on
  /// this object.
  const char* what() const noexcept override;

  // The pointer stashed at construction time (see caller_); nullptr when none
  // was given.
  const void* caller() const noexcept {
    return caller_;
  }

  /// Returns only the error message string, without source location.
  /// The returned pointer is invalidated if you call add_context() on
  /// this object.
  virtual const char* what_without_backtrace() const noexcept {
    return what_without_backtrace_.c_str();
  }

 private:
  // Out-of-line helpers that (re)build the derived what_ /
  // what_without_backtrace_ strings.
  void refresh_what();
  std::string compute_what(bool include_backtrace) const;
};
// A single warning: its category, where it was raised, its message and
// whether it should be shown verbatim. It is issued with c10::warn(), which
// dispatches it to the current warning handler.
class C10_API Warning {
 public:
  // Warning categories. A Warning carries exactly one of them, held in
  // warning_variant_t.
  class C10_API UserWarning {};
  class C10_API DeprecationWarning {};

  using warning_variant_t = std::variant<UserWarning, DeprecationWarning>;

  // There is one constructor per message type that the warning macros can
  // produce: a std::string (from ::c10::str), a plain C string, or the
  // ::c10::detail::CompileTimeEmptyString used when STRIP_ERROR_MESSAGES is
  // defined.
  Warning(
      warning_variant_t type,
      const SourceLocation& source_location,
      std::string msg,
      bool verbatim);

  Warning(
      warning_variant_t type,
      SourceLocation source_location,
      const char* msg,
      bool verbatim);

  Warning(
      warning_variant_t type,
      SourceLocation source_location,
      ::c10::detail::CompileTimeEmptyString msg,
      bool verbatim);

  // Getters for members
  warning_variant_t type() const;
  const SourceLocation& source_location() const;
  const std::string& msg() const;
  bool verbatim() const;

 private:
  // The type of warning
  warning_variant_t type_;
  // Where the warning happened.
  SourceLocation source_location_;
  // The actual warning message.
  std::string msg_;
  // See note: [Verbatim Warnings]
  bool verbatim_;
};
  132. using UserWarning = Warning::UserWarning;
  133. using DeprecationWarning = Warning::DeprecationWarning;
  134. // Issue a warning with a given message. Dispatched to the current
  135. // warning handler.
  136. void C10_API warn(const Warning& warning);
// Receives warnings issued through c10::warn() while it is the current
// handler. Subclass it and install the subclass with
// WarningUtils::set_warning_handler() (or a WarningHandlerGuard) to customize
// how warnings are reported.
class C10_API WarningHandler {
 public:
  virtual ~WarningHandler() = default;

  /// The default warning handler. Prints the message to stderr.
  virtual void process(const Warning& warning);
};
  143. namespace WarningUtils {
  144. // Note: [Verbatim Warnings]
  145. // Warnings originating in C++ code can appear out-of-place to Python users:
  146. // a user runs a line in Python, but the warning references a line in C++.
  147. // Some parts of PyTorch, like the JIT, are cognizant of this mismatch
  148. // and take care to map warnings back to the user's program, but most
  149. // of PyTorch simply throws a context-free warning. To allow warning
  150. // handlers to add context where appropriate, warn takes the
  151. // "verbatim" flag. When this is false a warning handler might append
  152. // the C++ warning to a Python warning message that relates the warning
  153. // back to the user's program. Callers who have already accounted for
  154. // context in their warnings should set verbatim to true so their warnings
  155. // appear without modification.
  156. /// Sets the global warning handler. This is not thread-safe, so it should
  157. /// generally be called once during initialization or while holding the GIL
  158. /// for programs that use python.
  159. /// User is responsible for keeping the WarningHandler alive until
  160. /// it is not needed.
  161. C10_API void set_warning_handler(WarningHandler* handler) noexcept(true);
  162. /// Gets the global warning handler.
  163. C10_API WarningHandler* get_warning_handler() noexcept(true);
  164. class C10_API WarningHandlerGuard {
  165. WarningHandler* prev_handler_;
  166. public:
  167. WarningHandlerGuard(WarningHandler* new_handler)
  168. : prev_handler_(c10::WarningUtils::get_warning_handler()) {
  169. c10::WarningUtils::set_warning_handler(new_handler);
  170. }
  171. WarningHandlerGuard(WarningHandlerGuard&& other) = delete;
  172. WarningHandlerGuard(const WarningHandlerGuard&) = delete;
  173. WarningHandlerGuard& operator=(const WarningHandlerGuard&) = delete;
  174. WarningHandlerGuard& operator=(WarningHandlerGuard&&) = delete;
  175. ~WarningHandlerGuard() {
  176. c10::WarningUtils::set_warning_handler(prev_handler_);
  177. }
  178. };
  179. /// The TORCH_WARN_ONCE macro is difficult to test for. Use
  180. /// setWarnAlways(true) to turn it into TORCH_WARN, which can be
  181. /// tested for more easily.
  182. C10_API void set_warnAlways(bool /*setting*/) noexcept(true);
  183. C10_API bool get_warnAlways() noexcept(true);
  184. // A RAII guard that sets warn_always (not thread-local) on
  185. // construction, and sets it back to the original value upon destruction.
  186. struct C10_API WarnAlways {
  187. public:
  188. explicit WarnAlways(bool setting = true);
  189. ~WarnAlways();
  190. private:
  191. bool prev_setting;
  192. };
  193. } // namespace WarningUtils
  194. // Like Error, but we always report the C++ backtrace, instead of only
  195. // reporting when TORCH_SHOW_CPP_STACKTRACES
  196. class C10_API ErrorAlwaysShowCppStacktrace : public Error {
  197. using Error::Error;
  198. const char* what_without_backtrace() const noexcept override {
  199. return what();
  200. }
  201. };
  202. // Used in ATen for out-of-bound indices that can reasonably only be detected
  203. // lazily inside a kernel (See: advanced indexing). These turn into
  204. // IndexError when they cross to Python.
  205. class C10_API IndexError : public Error {
  206. using Error::Error;
  207. };
  208. // Used in ATen for invalid values. These turn into
  209. // ValueError when they cross to Python.
  210. class C10_API ValueError : public Error {
  211. using Error::Error;
  212. };
  213. // Used in ATen for invalid types. These turn into
  214. // TypeError when they cross to Python.
  215. class C10_API TypeError : public Error {
  216. using Error::Error;
  217. };
  218. // Used in ATen for functionality that is not implemented. These turn into
  219. // NotImplementedError when they cross to Python.
  220. class C10_API NotImplementedError : public Error {
  221. using Error::Error;
  222. };
  223. // Used in ATen for buffer-related errors, e.g. trying to create a DLPack of
  224. // an unsupported device. These turn into BufferError when they cross to
  225. // Python.
  226. class C10_API BufferError : public Error {
  227. using Error::Error;
  228. };
  229. // Used in ATen for non finite indices. These turn into
  230. // ExitException when they cross to Python.
  231. class C10_API EnforceFiniteError : public Error {
  232. using Error::Error;
  233. };
  234. // Used in Onnxifi backend lowering. These turn into
  235. // ExitException when they cross to Python.
  236. class C10_API OnnxfiBackendSystemError : public Error {
  237. using Error::Error;
  238. };
  239. // Used for numerical errors from the linalg module. These
  240. // turn into LinAlgError when they cross into Python.
  241. class C10_API LinAlgError : public Error {
  242. using Error::Error;
  243. };
  244. class C10_API OutOfMemoryError : public Error {
  245. using Error::Error;
  246. };
  247. // Used for handling syntactic errors in input arguments.
  248. // These turn into SyntaxError when the cross into Python.
  249. class C10_API SyntaxError : public Error {
  250. using Error::Error;
  251. };
  252. // Raised when accelerator API call hits an error.
  253. // These turn into AcceleratorError when the cross into Python
  254. class C10_API AcceleratorError : public Error {
  255. int32_t error_code;
  256. public:
  257. AcceleratorError(SourceLocation loc, int32_t code, const std::string& msg)
  258. : Error(loc, msg), error_code(code) {}
  259. int32_t get_error_code() const {
  260. return error_code;
  261. }
  262. };
  263. // Base error type for all distributed errors.
  264. // These turn into DistError when they cross into Python.
  265. class C10_API DistError : public Error {
  266. using Error::Error;
  267. };
  268. // Used for collective communication library errors from the distributed module.
  269. // These turn into DistBackendError when they cross into Python.
  270. class C10_API DistBackendError : public DistError {
  271. using DistError::DistError;
  272. };
  273. // Used for errors originating from the store.
  274. // These turn into DistStoreError when they cross into Python.
  275. class C10_API DistStoreError : public DistError {
  276. using DistError::DistError;
  277. };
  278. // Used for errors originating from the TCP/IP stack and not from collective
  279. // libraries. These turn into DistNetworkError when they cross into Python.
  280. class C10_API DistNetworkError : public DistError {
  281. using DistError::DistError;
  282. };
  283. // Raised when a queue is empty and a non-blocking pop is called.
  284. // Translated to torch.distributed.QueueEmptyError in Python
  285. class C10_API DistQueueEmptyError : public DistStoreError {
  286. using DistStoreError::DistStoreError;
  287. };
  288. // A utility function to return an exception std::string by prepending its
  289. // exception type before its what() content
  290. C10_API std::string GetExceptionString(const std::exception& e);
  291. } // namespace c10
  292. // Private helper macro for implementing TORCH_INTERNAL_ASSERT and TORCH_CHECK
  293. //
  294. // Note: In the debug build With MSVC, __LINE__ might be of long type (a.k.a
  295. // int32_t), which is different from the definition of `SourceLocation` that
  296. // requires unsigned int (a.k.a uint32_t) and may cause a compile error with the
  297. // message: error C2397: conversion from 'long' to 'uint32_t' requires a
  298. // narrowing conversion Here the static cast is used to pass the build. if this
  299. // is used inside a lambda the __func__ macro expands to operator(), which isn't
  300. // very useful, but hard to fix in a macro so suppressing the warning.
  301. #define C10_THROW_ERROR(err_type, msg) \
  302. throw ::c10::err_type( \
  303. {__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, msg)
  304. #define C10_BUILD_ERROR(err_type, msg) \
  305. ::c10::err_type({__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, msg)
  306. // Private helper macro for workaround MSVC misexpansion of nested macro
  307. // invocations involving __VA_ARGS__. See
  308. // https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly
  309. #define C10_EXPAND_MSVC_WORKAROUND(x) x
  310. #include <torch/headeronly/util/Exception.h>
  311. // ----------------------------------------------------------------------------
  312. // Error reporting macros
  313. // ----------------------------------------------------------------------------
  314. #ifdef STRIP_ERROR_MESSAGES
  315. #define TORCH_RETHROW(e, ...) \
  316. do { \
  317. (void)e; /* Suppress unused variable warning */ \
  318. throw; \
  319. } while (false)
  320. #else
  321. #define TORCH_RETHROW(e, ...) \
  322. do { \
  323. e.add_context(::c10::str(__VA_ARGS__)); \
  324. throw; \
  325. } while (false)
  326. #endif
// A utility macro to provide assert()-like functionality; that is, enforcement
// of internal invariants in code. It supports an arbitrary number of extra
// arguments (evaluated only on failure), which will be printed in the assert
// failure message using operator<< (this is useful to print some variables
// which may be useful for debugging.)
//
// Usage:
//    TORCH_INTERNAL_ASSERT(should_be_true);
//    TORCH_INTERNAL_ASSERT(x == 0, "x = ", x);
//
// Assuming no bugs in PyTorch, the conditions tested by this macro should
// always be true; e.g., it should be possible to disable all of these
// conditions without changing observable user behavior. If you would like to
// do error reporting for user input, please use TORCH_CHECK instead.
//
// NOTE: It is SAFE to use this macro in production code; on failure, this
// simply raises an exception, it does NOT unceremoniously quit the process
// (unlike assert()).
//
// NOTE: both variants expand to a bare `if (...) { ... }` with no `else`, so
// avoid using the macro as the unbraced body of an if/else; brace the
// enclosing statement instead.
//
#ifdef STRIP_ERROR_MESSAGES
// Stripped variant: the extra arguments are dropped (never evaluated) and
// the failure goes to torchCheckFail with a fixed literal built from the
// stringized condition and the file name (no line number in the text).
#define TORCH_INTERNAL_ASSERT(cond, ...)                              \
  if (C10_UNLIKELY_OR_CONST(!(cond))) {                               \
    ::c10::detail::torchCheckFail(                                    \
        __func__,                                                     \
        __FILE__,                                                     \
        static_cast<uint32_t>(__LINE__),                              \
        #cond " INTERNAL ASSERT FAILED at " C10_STRINGIZE(__FILE__)); \
  }
#else
// It would be nice if we could build a combined string literal out of
// the TORCH_INTERNAL_ASSERT prefix and a user-provided string literal
// as the first argument, but there doesn't seem to be any good way to
// do that while still supporting having a first argument that isn't a
// string literal.
//
// The condition/location text is therefore passed as one literal (condMsg)
// and the user arguments separately, as c10::str(__VA_ARGS__) (userMsg);
// torchInternalAssertFail has one overload per possible type of userMsg.
#define TORCH_INTERNAL_ASSERT(cond, ...)                                         \
  if (C10_UNLIKELY_OR_CONST(!(cond))) {                                          \
    ::c10::detail::torchInternalAssertFail(                                      \
        __func__,                                                                \
        __FILE__,                                                                \
        static_cast<uint32_t>(__LINE__),                                         \
        #cond                                                                    \
        " INTERNAL ASSERT FAILED at " C10_STRINGIZE(__FILE__) ":" C10_STRINGIZE( \
            __LINE__) ", please report a bug to PyTorch. ",                      \
        c10::str(__VA_ARGS__));                                                  \
  }
#endif
  373. // A utility macro to make it easier to test for error conditions from user
  374. // input. Like TORCH_INTERNAL_ASSERT, it supports an arbitrary number of extra
  375. // arguments (evaluated only on failure), which will be printed in the error
  376. // message using operator<< (e.g., you can pass any object which has
  377. // operator<< defined. Most objects in PyTorch have these definitions!)
  378. //
  379. // Usage:
  380. // TORCH_CHECK(should_be_true); // A default error message will be provided
  381. // // in this case; but we recommend writing an
  382. // // explicit error message, as it is more
  383. // // user friendly.
  384. // TORCH_CHECK(x == 0, "Expected x to be 0, but got ", x);
  385. //
  386. // On failure, this macro will raise an exception. If this exception propagates
  387. // to Python, it will convert into a Python RuntimeError.
  388. //
  389. // NOTE: It is SAFE to use this macro in production code; on failure, this
  390. // simply raises an exception, it does NOT unceremoniously quit the process
  391. // (unlike CHECK() from glog.)
  392. //
  393. #define TORCH_CHECK_WITH(error_t, cond, ...) \
  394. TORCH_CHECK_WITH_MSG(error_t, cond, "", __VA_ARGS__)
  395. #ifdef STRIP_ERROR_MESSAGES
  396. #define TORCH_CHECK_MSG(cond, type, ...) \
  397. (#cond #type " CHECK FAILED at " C10_STRINGIZE(__FILE__))
  398. #define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \
  399. if (C10_UNLIKELY_OR_CONST(!(cond))) { \
  400. C10_THROW_ERROR(Error, TORCH_CHECK_MSG(cond, type, __VA_ARGS__)); \
  401. }
  402. #else
  403. namespace c10::detail {
  404. template <typename... Args>
  405. auto torchCheckMsgImpl(const char* /*msg*/, const Args&... args) {
  406. return ::c10::str(args...);
  407. }
  408. inline C10_API const char* torchCheckMsgImpl(const char* msg) {
  409. return msg;
  410. }
  411. // If there is just 1 user-provided C-string argument, use it.
  412. inline C10_API const char* torchCheckMsgImpl(
  413. const char* /*msg*/,
  414. const char* args) {
  415. return args;
  416. }
  417. } // namespace c10::detail
  418. #define TORCH_CHECK_MSG(cond, type, ...) \
  419. (::c10::detail::torchCheckMsgImpl( \
  420. "Expected " #cond \
  421. " to be true, but got false. " \
  422. "(Could this error message be improved? If so, " \
  423. "please report an enhancement request to PyTorch.)", \
  424. ##__VA_ARGS__))
  425. #define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \
  426. if (C10_UNLIKELY_OR_CONST(!(cond))) { \
  427. C10_THROW_ERROR(error_t, TORCH_CHECK_MSG(cond, type, __VA_ARGS__)); \
  428. }
  429. #endif
  430. namespace c10::detail {
  431. [[noreturn]] C10_API void torchCheckFail(
  432. const char* func,
  433. const char* file,
  434. uint32_t line,
  435. const std::string& msg);
  436. [[noreturn]] C10_API void torchCheckFail(
  437. const char* func,
  438. const char* file,
  439. uint32_t line,
  440. const char* msg);
  441. // The c10::str() call that creates userMsg can have 1 of 3 return
  442. // types depending on the number and types of arguments passed to
  443. // TORCH_INTERNAL_ASSERT. 0 arguments will get a
  444. // CompileTimeEmptyString, 1 const char * will be passed straight
  445. // through, and anything else will get converted to std::string.
  446. [[noreturn]] C10_API void torchInternalAssertFail(
  447. const char* func,
  448. const char* file,
  449. uint32_t line,
  450. const char* condMsg,
  451. const char* userMsg);
  452. [[noreturn]] inline C10_API void torchInternalAssertFail(
  453. const char* func,
  454. const char* file,
  455. uint32_t line,
  456. const char* condMsg,
  457. ::c10::detail::CompileTimeEmptyString /*userMsg*/) {
  458. torchCheckFail(func, file, line, condMsg);
  459. }
  460. [[noreturn]] C10_API void torchInternalAssertFail(
  461. const char* func,
  462. const char* file,
  463. uint32_t line,
  464. const char* condMsg,
  465. const std::string& userMsg);
  466. } // namespace c10::detail
// TORCH_CHECK(cond, ...): when `cond` is false, fail with a message built
// from the remaining arguments, or a default message quoting `cond` when
// there are none. One of four definitions is chosen by
// STANDALONE_TORCH_HEADER and STRIP_ERROR_MESSAGES.
#ifdef STANDALONE_TORCH_HEADER
// TORCH_CHECK throws std::runtime_error instead of c10::Error which is
// useful when certain headers are used in a libtorch-independent way,
// e.g. when Vectorized<T> is used in AOTInductor generated code.
//
// NOTE(review): std::runtime_error is declared in <stdexcept>, which this
// header does not include directly; presumably it arrives transitively.
// Confirm.
#ifdef STRIP_ERROR_MESSAGES
// Stripped: TORCH_CHECK_MSG ignores everything after `type`, so the
// location prefix and user arguments below are discarded.
#define TORCH_CHECK(cond, ...)                \
  if (C10_UNLIKELY_OR_CONST(!(cond))) {       \
    throw std::runtime_error(TORCH_CHECK_MSG( \
        cond,                                 \
        "",                                   \
        __func__,                             \
        ", ",                                 \
        __FILE__,                             \
        ":",                                  \
        __LINE__,                             \
        ", ",                                 \
        __VA_ARGS__));                        \
  }
#else
// The message is "<function>, <file>:<line>, " followed by the user
// arguments.
// NOTE(review): the location prefix is always passed as extra arguments to
// TORCH_CHECK_MSG, so torchCheckMsgImpl always takes its concatenating
// overload. The default "Expected <cond> to be true..." text is therefore
// never used here, and TORCH_CHECK(cond) with no message yields only the
// location prefix. Confirm whether that is intended.
#define TORCH_CHECK(cond, ...)                \
  if (C10_UNLIKELY_OR_CONST(!(cond))) {       \
    throw std::runtime_error(TORCH_CHECK_MSG( \
        cond,                                 \
        "",                                   \
        __func__,                             \
        ", ",                                 \
        __FILE__,                             \
        ":",                                  \
        __LINE__,                             \
        ", ",                                 \
        ##__VA_ARGS__));                      \
  }
#endif
#else
// Regular build: failures are reported through ::c10::detail::torchCheckFail
// together with the caller's function, file and line.
#ifdef STRIP_ERROR_MESSAGES
#define TORCH_CHECK(cond, ...)                   \
  if (C10_UNLIKELY_OR_CONST(!(cond))) {          \
    ::c10::detail::torchCheckFail(               \
        __func__,                                \
        __FILE__,                                \
        static_cast<uint32_t>(__LINE__),         \
        TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \
  }
#else
// `##__VA_ARGS__` drops the preceding comma when no message arguments are
// given, so TORCH_CHECK(cond) falls back to the default message.
#define TORCH_CHECK(cond, ...)                     \
  if (C10_UNLIKELY_OR_CONST(!(cond))) {            \
    ::c10::detail::torchCheckFail(                 \
        __func__,                                  \
        __FILE__,                                  \
        static_cast<uint32_t>(__LINE__),           \
        TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \
  }
#endif
#endif
  521. // An utility macro that does what `TORCH_CHECK` does if compiled in the host
  522. // code, otherwise does nothing. Supposed to be used in the code shared between
  523. // host and device code as an alternative for `TORCH_CHECK`.
  524. #if defined(__CUDACC__) || defined(__HIPCC__)
  525. #define TORCH_CHECK_IF_NOT_ON_CUDA(cond, ...)
  526. #else
  527. #define TORCH_CHECK_IF_NOT_ON_CUDA(cond, ...) TORCH_CHECK(cond, ##__VA_ARGS__)
  528. #endif
  529. // Debug only version of TORCH_INTERNAL_ASSERT. This macro only checks in debug
  530. // build, and does nothing in release build. It is appropriate to use
  531. // in situations where you want to add an assert to a hotpath, but it is
  532. // too expensive to run this assert on production builds.
  533. #ifdef NDEBUG
  534. // Optimized version - generates no code.
  535. #define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...) \
  536. while (false) \
  537. C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__))
  538. #else
  539. #define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...) \
  540. C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__))
  541. #endif
  542. // TODO: We're going to get a lot of similar looking string literals
  543. // this way; check if this actually affects binary size.
  544. // Like TORCH_CHECK, but raises LinAlgError instead of Error.
  545. #define TORCH_CHECK_LINALG(cond, ...) \
  546. TORCH_CHECK_WITH_MSG(LinAlgError, cond, "LINALG", __VA_ARGS__)
  547. // Like TORCH_CHECK, but raises IndexErrors instead of Errors.
  548. #define TORCH_CHECK_INDEX(cond, ...) \
  549. TORCH_CHECK_WITH_MSG(IndexError, cond, "INDEX", __VA_ARGS__)
  550. // Like TORCH_CHECK, but raises ValueErrors instead of Errors.
  551. #define TORCH_CHECK_VALUE(cond, ...) \
  552. TORCH_CHECK_WITH_MSG(ValueError, cond, "VALUE", __VA_ARGS__)
  553. // Like TORCH_CHECK, but raises TypeErrors instead of Errors.
  554. #define TORCH_CHECK_TYPE(cond, ...) \
  555. TORCH_CHECK_WITH_MSG(TypeError, cond, "TYPE", __VA_ARGS__)
  556. // Like TORCH_CHECK, but raises NotImplementedErrors instead of Errors.
  557. #define TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
  558. TORCH_CHECK_WITH_MSG(NotImplementedError, cond, "TYPE", __VA_ARGS__)
  559. // Like TORCH_CHECK, but raises BufferError instead of Errors.
  560. #define TORCH_CHECK_BUFFER(cond, ...) \
  561. TORCH_CHECK_WITH_MSG(BufferError, cond, "TYPE", __VA_ARGS__)
  562. #define TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(cond, ...) \
  563. TORCH_CHECK_WITH_MSG( \
  564. ErrorAlwaysShowCppStacktrace, cond, "TYPE", ##__VA_ARGS__)
  565. #ifdef STRIP_ERROR_MESSAGES
  566. #define WARNING_MESSAGE_STRING(...) \
  567. ::c10::detail::CompileTimeEmptyString {}
  568. #else
  569. #define WARNING_MESSAGE_STRING(...) ::c10::str(__VA_ARGS__)
  570. #endif
// Report a warning to the user. Accepts an arbitrary number of extra
// arguments which are concatenated into the warning message using operator<<
//
// _TORCH_WARN_WITH(warning_t, ...) constructs a ::c10::Warning of category
// warning_t, tagged with the current function/file/line and verbatim = false,
// and passes it to ::c10::warn(). With DISABLE_WARN it expands to a no-op and
// its arguments are not evaluated.
#ifdef DISABLE_WARN
#define _TORCH_WARN_WITH(...) ((void)0);
#else
#define _TORCH_WARN_WITH(warning_t, ...)                     \
  ::c10::warn(::c10::Warning(                                \
      warning_t(),                                           \
      {__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, \
      WARNING_MESSAGE_STRING(__VA_ARGS__),                   \
      false));
#endif

// User-facing entry points: a UserWarning or a DeprecationWarning.
#define TORCH_WARN(...) _TORCH_WARN_WITH(::c10::UserWarning, __VA_ARGS__);
#define TORCH_WARN_DEPRECATION(...) \
  _TORCH_WARN_WITH(::c10::DeprecationWarning, __VA_ARGS__);

// Report a warning to the user only once. Accepts an arbitrary number of extra
// arguments which are concatenated into the warning message using operator<<
//
// _TORCH_WARN_ONCE declares a block-scope static whose initializer (an
// immediately-invoked lambda) issues the warning. A static is initialized
// only once, so each expansion site warns at most once per run of the program.
#define _TORCH_WARN_ONCE(...)                                \
  [[maybe_unused]] static const auto C10_ANONYMOUS_VARIABLE( \
      torch_warn_once_) = [&] {                              \
    TORCH_WARN(__VA_ARGS__);                                 \
    return true;                                             \
  }()

// When WarningUtils::get_warnAlways() returns true, TORCH_WARN_ONCE warns on
// every execution, like TORCH_WARN (this keeps it testable); otherwise it
// warns at most once per call site.
#ifdef DISABLE_WARN
#define TORCH_WARN_ONCE(...) ((void)0);
#else
#define TORCH_WARN_ONCE(...)                   \
  if (::c10::WarningUtils::get_warnAlways()) { \
    TORCH_WARN(__VA_ARGS__);                   \
  } else {                                     \
    _TORCH_WARN_ONCE(__VA_ARGS__);             \
  }
#endif
  606. // Report an error with a specific argument
  607. // NOTE: using the argument name in TORCH_CHECK's message is preferred
  608. #define TORCH_CHECK_ARG(cond, argN, ...) \
  609. TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__)
  610. #ifndef FATAL_IF
  611. #ifdef C10_USE_GLOG
  612. #define FATAL_IF(condition) \
  613. condition ? (void)0 \
  614. : ::c10::LoggerVoidify() & \
  615. ::c10::MessageLogger( \
  616. ::c10::SourceLocation::current(), ::google::GLOG_FATAL) \
  617. .stream()
  618. #else
  619. #define FATAL_IF(condition) \
  620. condition ? (void)0 \
  621. : ::c10::LoggerVoidify() & \
  622. ::c10::MessageLogger( \
  623. ::c10::SourceLocation::current(), ::c10::GLOG_FATAL) \
  624. .stream()
  625. #endif
  626. #endif
  627. #ifndef NON_FATAL_IF
  628. #ifdef C10_USE_GLOG
  629. #define NON_FATAL_IF(condition) \
  630. condition ? (void)0 \
  631. : ::c10::LoggerVoidify() & \
  632. ::c10::MessageLogger( \
  633. ::c10::SourceLocation::current(), ::google::GLOG_FATAL, false) \
  634. .stream()
  635. #else
  636. #define NON_FATAL_IF(condition) \
  637. condition ? (void)0 \
  638. : ::c10::LoggerVoidify() & \
  639. ::c10::MessageLogger( \
  640. ::c10::SourceLocation::current(), ::c10::GLOG_FATAL, false) \
  641. .stream()
  642. #endif
  643. #endif
  644. // Binary comparison check macros
  645. #define TORCH_CHECK_OP(val1, val2, op) \
  646. NON_FATAL_IF(((val1)op(val2))) \
  647. << "Check failed: " #val1 " " #op " " #val2 " (" << (val1) << " vs. " \
  648. << (val2) << "). "
  649. #define TORCH_DCHECK_OP(val1, val2, op) \
  650. FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
  651. << (val1) << " vs. " << (val2) << "). "
  652. #define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
  653. #define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
  654. #define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
  655. #define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
  656. #define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
  657. #define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
  // Debug-only comparison checks (TORCH_DCHECK_EQ/NE/LE/LT/GE/GT).
  //
  // Without NDEBUG, each one forwards to TORCH_DCHECK_OP with the matching
  // operator. With NDEBUG it expands to `while (false) TORCH_DCHECK_OP(...)`.
  // The check, including any `<<` operands the caller appends, must still
  // compile, but the loop body never executes. No operand is therefore
  // evaluated at runtime.
  #ifdef NDEBUG
  #define TORCH_DCHECK_EQ(lhs, rhs) while (false) TORCH_DCHECK_OP(lhs, rhs, ==)
  #define TORCH_DCHECK_NE(lhs, rhs) while (false) TORCH_DCHECK_OP(lhs, rhs, !=)
  #define TORCH_DCHECK_LE(lhs, rhs) while (false) TORCH_DCHECK_OP(lhs, rhs, <=)
  #define TORCH_DCHECK_LT(lhs, rhs) while (false) TORCH_DCHECK_OP(lhs, rhs, <)
  #define TORCH_DCHECK_GE(lhs, rhs) while (false) TORCH_DCHECK_OP(lhs, rhs, >=)
  #define TORCH_DCHECK_GT(lhs, rhs) while (false) TORCH_DCHECK_OP(lhs, rhs, >)
  #else // NDEBUG not defined: debug checks are active
  #define TORCH_DCHECK_EQ(lhs, rhs) TORCH_DCHECK_OP(lhs, rhs, ==)
  #define TORCH_DCHECK_NE(lhs, rhs) TORCH_DCHECK_OP(lhs, rhs, !=)
  #define TORCH_DCHECK_LE(lhs, rhs) TORCH_DCHECK_OP(lhs, rhs, <=)
  #define TORCH_DCHECK_LT(lhs, rhs) TORCH_DCHECK_OP(lhs, rhs, <)
  #define TORCH_DCHECK_GE(lhs, rhs) TORCH_DCHECK_OP(lhs, rhs, >=)
  #define TORCH_DCHECK_GT(lhs, rhs) TORCH_DCHECK_OP(lhs, rhs, >)
  #endif // NDEBUG
  // Null-pointer checks.
  //
  // TORCH_CHECK_NOTNULL(ptr) calls ::c10::CheckNotNull with the call site
  // (__FILE__, __LINE__), the stringified expression, the value `(ptr)` and
  // `false`. The debug variant passes `true` as the last argument instead.
  // Under NDEBUG, TORCH_DCHECK_NOTNULL becomes
  // `while (false) TORCH_CHECK_NOTNULL(ptr)`, which still type-checks `ptr`
  // but never evaluates it.
  //
  // NOTE(review): the NDEBUG expansion is a statement, not an expression. Any
  // code that uses the value of TORCH_DCHECK_NOTNULL(...) (e.g. in an
  // initializer) therefore compiles only in debug builds. Confirm whether
  // callers rely on CheckNotNull's return value.
  #define TORCH_CHECK_NOTNULL(ptr) \
    ::c10::CheckNotNull(__FILE__, __LINE__, #ptr, (ptr), false)
  #ifdef NDEBUG
  #define TORCH_DCHECK_NOTNULL(ptr) \
    while (false)                   \
    TORCH_CHECK_NOTNULL(ptr)
  #else // NDEBUG not defined: the debug check is active
  #define TORCH_DCHECK_NOTNULL(ptr) \
    ::c10::CheckNotNull(__FILE__, __LINE__, #ptr, (ptr), true)
  #endif // NDEBUG
  // ----------------------------------------------------------------------------
  // Deprecated macros
  // ----------------------------------------------------------------------------
  namespace c10 {
  namespace detail {

  // Empty hook functions called by the deprecated AT_ERROR, AT_ASSERT and
  // AT_ASSERTM macros. Their purpose is to carry a [[deprecated]] attribute,
  // so that each use of those macros triggers a deprecation warning. The
  // attributes stay commented out until the remaining call sites in the
  // codebase are fixed. To enable one, uncomment the attribute directly above
  // its function.

  // [[deprecated("AT_ERROR(msg) is deprecated, use TORCH_CHECK(false, msg) "
  //              "instead.")]]
  inline void deprecated_AT_ERROR() {}

  // [[deprecated("AT_ASSERT is deprecated, if you mean to indicate an "
  //              "internal invariant failure, use TORCH_INTERNAL_ASSERT "
  //              "instead; if you mean to do user error checking, use "
  //              "TORCH_CHECK. See "
  //              "https://github.com/pytorch/pytorch/issues/20287 for more "
  //              "details.")]]
  inline void deprecated_AT_ASSERT() {}

  // [[deprecated("AT_ASSERTM is deprecated, if you mean to indicate an "
  //              "internal invariant failure, use TORCH_INTERNAL_ASSERT "
  //              "instead; if you mean to do user error checking, use "
  //              "TORCH_CHECK. See "
  //              "https://github.com/pytorch/pytorch/issues/20287 for more "
  //              "details.")]]
  inline void deprecated_AT_ASSERTM() {}

  } // namespace detail
  } // namespace c10
  // AT_ASSERT(...): deprecated alias for TORCH_INTERNAL_ASSERT(...).
  //
  // It was deprecated because people kept using it for user error checking.
  // Use TORCH_INTERNAL_ASSERT for internal invariants or TORCH_CHECK for
  // validating user input; see https://github.com/pytorch/pytorch/issues/20287.
  //
  // The expansion first calls the empty ::c10::detail::deprecated_AT_ASSERT()
  // hook, which can carry a deprecation attribute, and then forwards every
  // argument unchanged. The do/while(false) wrapper makes the whole expansion
  // behave as a single statement.
  #define AT_ASSERT(...)                                              \
    do {                                                              \
      ::c10::detail::deprecated_AT_ASSERT();                          \
      C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__)); \
    } while (false)
  // AT_ASSERTM(condition, ...): deprecated alias, like AT_ASSERT.
  //
  // TORCH_INTERNAL_ASSERT already accepts both a bare condition and a condition
  // followed by message arguments, so a separate message-taking macro is
  // unnecessary. Call TORCH_INTERNAL_ASSERT directly instead.
  //
  // The condition is a named parameter rather than part of `...` on purpose:
  // MSVC otherwise miscompiles the expansion and passes all of __VA_ARGS__ as
  // the condition. An alternative workaround is described at
  // https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly
  #define AT_ASSERTM(condition, ...)                                           \
    do {                                                                       \
      ::c10::detail::deprecated_AT_ASSERTM();                                  \
      C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(condition, __VA_ARGS__)); \
    } while (false)
  // AT_ERROR(...): deprecated; fails unconditionally at the current line.
  //
  // It was deprecated because it is extra API surface that makes it harder to
  // know which macro to use. Write TORCH_CHECK(false, ...) for user-facing
  // errors, or TORCH_INTERNAL_ASSERT(false, ...) for internal ones.
  //
  // The expansion is a single statement. It calls the empty
  // ::c10::detail::deprecated_AT_ERROR() hook, then TORCH_CHECK(false, msg),
  // where msg is ::c10::str applied to all of the macro's arguments
  // (presumably concatenating them into one message).
  #define AT_ERROR(...)                                                         \
    do {                                                                        \
      ::c10::detail::deprecated_AT_ERROR();                                     \
      C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__)));  \
    } while (false)
  758. #endif // C10_UTIL_EXCEPTION_H_
  759. #else
  760. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  761. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)