critical_section.h 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. // Copyright (c) 2016-2025 The Pybind Development Team.
  3. // All rights reserved. Use of this source code is governed by a
  4. // BSD-style license that can be found in the LICENSE file.
  5. #pragma once
  6. #include "pytypes.h"
  7. PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
  8. /// This does not do anything if there's a GIL. On free-threaded Python,
  9. /// it locks an object. This uses the CPython API, which has limits
  10. class scoped_critical_section {
  11. public:
  12. #ifdef Py_GIL_DISABLED
  13. explicit scoped_critical_section(handle obj1, handle obj2 = handle{}) {
  14. if (obj1) {
  15. if (obj2) {
  16. PyCriticalSection2_Begin(&section2, obj1.ptr(), obj2.ptr());
  17. rank = 2;
  18. } else {
  19. PyCriticalSection_Begin(&section, obj1.ptr());
  20. rank = 1;
  21. }
  22. } else if (obj2) {
  23. PyCriticalSection_Begin(&section, obj2.ptr());
  24. rank = 1;
  25. }
  26. }
  27. ~scoped_critical_section() {
  28. if (rank == 1) {
  29. PyCriticalSection_End(&section);
  30. } else if (rank == 2) {
  31. PyCriticalSection2_End(&section2);
  32. }
  33. }
  34. #else
  35. explicit scoped_critical_section(handle, handle = handle{}) {};
  36. ~scoped_critical_section() = default;
  37. #endif
  38. scoped_critical_section(const scoped_critical_section &) = delete;
  39. scoped_critical_section &operator=(const scoped_critical_section &) = delete;
  40. private:
  41. #ifdef Py_GIL_DISABLED
  42. int rank{0};
  43. union {
  44. PyCriticalSection section;
  45. PyCriticalSection2 section2;
  46. };
  47. #endif
  48. };
  49. PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
  50. #else
  51. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  52. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)