ParallelOpenMP.h 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <algorithm>
  4. #include <atomic>
  5. #include <cstddef>
  6. #include <exception>
  7. #ifdef _OPENMP
  8. #define INTRA_OP_PARALLEL
  9. #include <omp.h>
  10. #endif
  11. #ifdef _OPENMP
  12. namespace at::internal {
  13. template <typename F>
  14. inline void invoke_parallel(
  15. int64_t begin,
  16. int64_t end,
  17. int64_t grain_size,
  18. const F& f) {
  19. std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
  20. std::exception_ptr eptr;
  21. #pragma omp parallel
  22. {
  23. // choose number of tasks based on grain size and number of threads
  24. // can't use num_threads clause due to bugs in GOMP's thread pool (See
  25. // #32008)
  26. int64_t num_threads = omp_get_num_threads();
  27. if (grain_size > 0) {
  28. num_threads = std::min(num_threads, divup((end - begin), grain_size));
  29. }
  30. int64_t tid = omp_get_thread_num();
  31. int64_t chunk_size = divup((end - begin), num_threads);
  32. int64_t begin_tid = begin + tid * chunk_size;
  33. if (begin_tid < end) {
  34. try {
  35. internal::ThreadIdGuard tid_guard(tid);
  36. f(begin_tid, std::min(end, chunk_size + begin_tid));
  37. } catch (...) {
  38. if (!err_flag.test_and_set()) {
  39. eptr = std::current_exception();
  40. }
  41. }
  42. }
  43. }
  44. if (eptr) {
  45. std::rethrow_exception(eptr);
  46. }
  47. }
  48. } // namespace at::internal
  49. #endif // _OPENMP
  50. #else
  51. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  52. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)