// CUDADeviceAssertion.h
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/cuda/CUDAException.h>
  4. #include <c10/macros/Macros.h>
  5. namespace c10::cuda {
  6. #ifdef TORCH_USE_CUDA_DSA
  7. C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")
  8. // Copy string from `src` to `dst`
  9. static __device__ void dstrcpy(char* dst, const char* src) {
  10. int i = 0;
  11. // Copy string from source to destination, ensuring that it
  12. // isn't longer than `C10_CUDA_DSA_MAX_STR_LEN-1`
  13. while (*src != '\0' && i++ < C10_CUDA_DSA_MAX_STR_LEN - 1) {
  14. *dst++ = *src++;
  15. }
  16. *dst = '\0';
  17. }
  18. static __device__ void dsa_add_new_assertion_failure(
  19. DeviceAssertionsData* assertions_data,
  20. const char* assertion_msg,
  21. const char* filename,
  22. const char* function_name,
  23. const int line_number,
  24. const uint32_t caller,
  25. const dim3 block_id,
  26. const dim3 thread_id) {
  27. // `assertions_data` may be nullptr if device-side assertion checking
  28. // is disabled at run-time. If it is disabled at compile time this
  29. // function will never be called
  30. if (!assertions_data) {
  31. return;
  32. }
  33. // Atomically increment so other threads can fail at the same time
  34. // Note that incrementing this means that the CPU can observe that
  35. // a failure has happened and can begin to respond before we've
  36. // written information about that failure out to the buffer.
  37. const auto nid = atomicAdd(&(assertions_data->assertion_count), 1);
  38. if (nid >= C10_CUDA_DSA_ASSERTION_COUNT) {
  39. // At this point we're ran out of assertion buffer space.
  40. // We could print a message about this, but that'd get
  41. // spammy if a lot of threads did it, so we just silently
  42. // ignore any other assertion failures. In most cases the
  43. // failures will all probably be analogous anyway.
  44. return;
  45. }
  46. // Write information about the assertion failure to memory.
  47. // Note that this occurs only after the `assertion_count`
  48. // increment broadcasts that there's been a problem.
  49. auto& self = assertions_data->assertions[nid];
  50. dstrcpy(self.assertion_msg, assertion_msg);
  51. dstrcpy(self.filename, filename);
  52. dstrcpy(self.function_name, function_name);
  53. self.line_number = line_number;
  54. self.caller = caller;
  55. self.block_id[0] = block_id.x;
  56. self.block_id[1] = block_id.y;
  57. self.block_id[2] = block_id.z;
  58. self.thread_id[0] = thread_id.x;
  59. self.thread_id[1] = thread_id.y;
  60. self.thread_id[2] = thread_id.z;
  61. }
  62. C10_CLANG_DIAGNOSTIC_POP()
  63. // Emulates a kernel assertion. The assertion won't stop the kernel's progress,
  64. // so you should assume everything the kernel produces is garbage if there's an
  65. // assertion failure.
  66. // NOTE: This assumes that `assertions_data` and `assertion_caller_id` are
  67. // arguments of the kernel and therefore accessible.
  68. #define CUDA_KERNEL_ASSERT2(condition) \
  69. do { \
  70. if (C10_UNLIKELY(!(condition))) { \
  71. /* Has an atomic element so threads can fail at the same time */ \
  72. c10::cuda::dsa_add_new_assertion_failure( \
  73. assertions_data, \
  74. C10_STRINGIZE(condition), \
  75. __FILE__, \
  76. __FUNCTION__, \
  77. __LINE__, \
  78. assertion_caller_id, \
  79. blockIdx, \
  80. threadIdx); \
  81. /* Now that the kernel has failed we early exit the kernel, but */ \
  82. /* otherwise keep going and rely on the host to check UVM and */ \
  83. /* determine we've had a problem */ \
  84. return; \
  85. } \
  86. } while (false)
  87. #else
  88. #define CUDA_KERNEL_ASSERT2(condition) assert(condition)
  89. #endif
  90. } // namespace c10::cuda
  91. #else
  92. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  93. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)