| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
- #include <algorithm>
- #include <atomic>
- #include <cstddef>
- #include <exception>
- #ifdef _OPENMP
- #define INTRA_OP_PARALLEL
- #include <omp.h>
- #endif
- #ifdef _OPENMP
- namespace at::internal {
- template <typename F>
- inline void invoke_parallel(
- int64_t begin,
- int64_t end,
- int64_t grain_size,
- const F& f) {
- std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
- std::exception_ptr eptr;
- #pragma omp parallel
- {
- // choose number of tasks based on grain size and number of threads
- // can't use num_threads clause due to bugs in GOMP's thread pool (See
- // #32008)
- int64_t num_threads = omp_get_num_threads();
- if (grain_size > 0) {
- num_threads = std::min(num_threads, divup((end - begin), grain_size));
- }
- int64_t tid = omp_get_thread_num();
- int64_t chunk_size = divup((end - begin), num_threads);
- int64_t begin_tid = begin + tid * chunk_size;
- if (begin_tid < end) {
- try {
- internal::ThreadIdGuard tid_guard(tid);
- f(begin_tid, std::min(end, chunk_size + begin_tid));
- } catch (...) {
- if (!err_flag.test_and_set()) {
- eptr = std::current_exception();
- }
- }
- }
- }
- if (eptr) {
- std::rethrow_exception(eptr);
- }
- }
- } // namespace at::internal
- #endif // _OPENMP
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|