contexts.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import contextlib
  2. import io
  3. import os
  4. import shutil
  5. import site
  6. import sys
  7. import tempfile
  8. from filelock import FileLock
  9. @contextlib.contextmanager
  10. def tempdir(cd=lambda dir: None, **kwargs):
  11. temp_dir = tempfile.mkdtemp(**kwargs)
  12. orig_dir = os.getcwd()
  13. try:
  14. cd(temp_dir)
  15. yield temp_dir
  16. finally:
  17. cd(orig_dir)
  18. shutil.rmtree(temp_dir)
  19. @contextlib.contextmanager
  20. def environment(**replacements):
  21. """
  22. In a context, patch the environment with replacements. Pass None values
  23. to clear the values.
  24. """
  25. saved = dict((key, os.environ[key]) for key in replacements if key in os.environ)
  26. # remove values that are null
  27. remove = (key for (key, value) in replacements.items() if value is None)
  28. for key in list(remove):
  29. os.environ.pop(key, None)
  30. replacements.pop(key)
  31. os.environ.update(replacements)
  32. try:
  33. yield saved
  34. finally:
  35. for key in replacements:
  36. os.environ.pop(key, None)
  37. os.environ.update(saved)
  38. @contextlib.contextmanager
  39. def quiet():
  40. """
  41. Redirect stdout/stderr to StringIO objects to prevent console output from
  42. distutils commands.
  43. """
  44. old_stdout = sys.stdout
  45. old_stderr = sys.stderr
  46. new_stdout = sys.stdout = io.StringIO()
  47. new_stderr = sys.stderr = io.StringIO()
  48. try:
  49. yield new_stdout, new_stderr
  50. finally:
  51. new_stdout.seek(0)
  52. new_stderr.seek(0)
  53. sys.stdout = old_stdout
  54. sys.stderr = old_stderr
  55. @contextlib.contextmanager
  56. def save_user_site_setting():
  57. saved = site.ENABLE_USER_SITE
  58. try:
  59. yield saved
  60. finally:
  61. site.ENABLE_USER_SITE = saved
  62. @contextlib.contextmanager
  63. def suppress_exceptions(*excs):
  64. try:
  65. yield
  66. except excs:
  67. pass
  68. def multiproc(request):
  69. """
  70. Return True if running under xdist and multiple
  71. workers are used.
  72. """
  73. try:
  74. worker_id = request.getfixturevalue('worker_id')
  75. except Exception:
  76. return False
  77. return worker_id != 'master'
  78. @contextlib.contextmanager
  79. def session_locked_tmp_dir(request, tmp_path_factory, name):
  80. """Uses a file lock to guarantee only one worker can access a temp dir"""
  81. # get the temp directory shared by all workers
  82. base = tmp_path_factory.getbasetemp()
  83. shared_dir = base.parent if multiproc(request) else base
  84. locked_dir = shared_dir / name
  85. with FileLock(locked_dir.with_suffix(".lock")):
  86. # ^-- prevent multiple workers to access the directory at once
  87. locked_dir.mkdir(exist_ok=True, parents=True)
  88. yield locked_dir
  89. @contextlib.contextmanager
  90. def save_paths():
  91. """Make sure ``sys.path``, ``sys.meta_path`` and ``sys.path_hooks`` are preserved"""
  92. prev = sys.path[:], sys.meta_path[:], sys.path_hooks[:]
  93. try:
  94. yield
  95. finally:
  96. sys.path, sys.meta_path, sys.path_hooks = prev
  97. @contextlib.contextmanager
  98. def save_sys_modules():
  99. """Make sure initial ``sys.modules`` is preserved"""
  100. prev_modules = sys.modules
  101. try:
  102. sys.modules = sys.modules.copy()
  103. yield
  104. finally:
  105. sys.modules = prev_modules