_async_compile.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from __future__ import annotations
  2. from typing import Callable, Optional
  3. from concurrent.futures import Executor, as_completed, Future
  4. from contextvars import ContextVar
  5. active_mode: ContextVar[Optional[AsyncCompileMode]] = ContextVar("async_compile_active_mode", default=None)
  6. class FutureKernel:
  7. def __init__(self, finalize_compile: Callable, future: Future):
  8. self.finalize_compile = finalize_compile
  9. self.kernel = None
  10. self.future = future
  11. def result(self, ignore_errors: bool = False):
  12. if self.kernel is not None:
  13. return self.kernel
  14. try:
  15. kernel = self.future.result()
  16. except Exception:
  17. if ignore_errors:
  18. return
  19. else:
  20. raise
  21. self.finalize_compile(kernel)
  22. self.kernel = kernel
  23. return kernel
  24. class AsyncCompileMode:
  25. def __init__(self, executor: Executor, *, ignore_errors=False):
  26. self.executor = executor
  27. self.ignore_errors = ignore_errors
  28. self.raw_futures = []
  29. self.future_kernels = {}
  30. def submit(self, key, compile_fn, finalize_fn):
  31. future = self.future_kernels.get(key)
  32. if future is not None:
  33. return future
  34. future = self.executor.submit(compile_fn)
  35. future._key = key
  36. self.raw_futures.append(future)
  37. future_kernel = FutureKernel(finalize_fn, future)
  38. self.future_kernels[key] = future_kernel
  39. return future_kernel
  40. def __enter__(self):
  41. if active_mode.get() is not None:
  42. raise RuntimeError("Another AsyncCompileMode is already active")
  43. active_mode.set(self)
  44. return self
  45. def __exit__(self, exc_type, exc_value, traceback):
  46. # Finalize any outstanding compiles
  47. for future in as_completed(self.raw_futures):
  48. self.future_kernels[future._key].result(self.ignore_errors)
  49. active_mode.set(None)