// PTThreadPool.h

  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/Parallel.h>
  4. #include <c10/core/thread_pool.h>
  5. namespace at {
  6. class TORCH_API PTThreadPool : public c10::ThreadPool {
  7. public:
  8. explicit PTThreadPool(int pool_size, int numa_node_id = -1)
  9. : c10::ThreadPool(pool_size, numa_node_id, []() {
  10. c10::setThreadName("PTThreadPool");
  11. at::init_num_threads();
  12. }) {}
  13. };
  14. } // namespace at
  15. #else
  16. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  17. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)