// Parallel-inl.h (2.5 KB)
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/util/Exception.h>
  4. #include <c10/util/ParallelGuard.h>
  5. #include <c10/util/SmallVector.h>
  6. namespace at {
  7. template <class F>
  8. inline void parallel_for(
  9. const int64_t begin,
  10. const int64_t end,
  11. const int64_t grain_size,
  12. const F& f) {
  13. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
  14. if (begin >= end) {
  15. return;
  16. }
  17. #ifdef INTRA_OP_PARALLEL
  18. at::internal::lazy_init_num_threads();
  19. const auto numiter = end - begin;
  20. const bool use_parallel =
  21. (numiter > grain_size && numiter > 1 && !at::in_parallel_region() &&
  22. at::get_num_threads() > 1);
  23. if (!use_parallel) {
  24. internal::ThreadIdGuard tid_guard(0);
  25. c10::ParallelGuard guard(true);
  26. f(begin, end);
  27. return;
  28. }
  29. internal::invoke_parallel(
  30. begin, end, grain_size, [&](int64_t begin, int64_t end) {
  31. c10::ParallelGuard guard(true);
  32. f(begin, end);
  33. });
  34. #else
  35. internal::ThreadIdGuard tid_guard(0);
  36. c10::ParallelGuard guard(true);
  37. f(begin, end);
  38. #endif
  39. }
  40. template <class scalar_t, class F, class SF>
  41. inline scalar_t parallel_reduce(
  42. const int64_t begin,
  43. const int64_t end,
  44. const int64_t grain_size,
  45. const scalar_t ident,
  46. const F& f,
  47. const SF& sf) {
  48. TORCH_CHECK(grain_size >= 0);
  49. if (begin >= end) {
  50. return ident;
  51. }
  52. #ifdef INTRA_OP_PARALLEL
  53. at::internal::lazy_init_num_threads();
  54. const auto max_threads = at::get_num_threads();
  55. const bool use_parallel =
  56. ((end - begin) > grain_size && !at::in_parallel_region() &&
  57. max_threads > 1);
  58. if (!use_parallel) {
  59. internal::ThreadIdGuard tid_guard(0);
  60. c10::ParallelGuard guard(true);
  61. return f(begin, end, ident);
  62. }
  63. c10::SmallVector<scalar_t, 64> results(max_threads, ident);
  64. internal::invoke_parallel(
  65. begin,
  66. end,
  67. grain_size,
  68. [&](const int64_t my_begin, const int64_t my_end) {
  69. const auto tid = at::get_thread_num();
  70. c10::ParallelGuard guard(true);
  71. results[tid] = f(my_begin, my_end, ident);
  72. });
  73. scalar_t result = ident;
  74. for (auto partial_result : results) {
  75. result = sf(result, partial_result);
  76. }
  77. return result;
  78. #else
  79. internal::ThreadIdGuard tid_guard(0);
  80. c10::ParallelGuard guard(true);
  81. return f(begin, end, ident);
  82. #endif
  83. }
  84. } // namespace at
  85. #else
  86. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  87. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)