| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- from __future__ import annotations
- from typing import Callable, Optional
- from concurrent.futures import Executor, as_completed, Future
- from contextvars import ContextVar
# Context-local slot recording which AsyncCompileMode (if any) is currently
# in effect; ``None`` means no asynchronous-compile scope is open.
active_mode: ContextVar[Optional[AsyncCompileMode]] = ContextVar(
    "async_compile_active_mode",
    default=None,
)
class FutureKernel:
    """Handle to a kernel whose raw compile is running in a ``Future``.

    :meth:`result` waits for the future, passes the compiled object to
    ``finalize_compile`` once, caches it on ``self.kernel`` and returns it.
    """

    def __init__(self, finalize_compile: Callable, future: Future):
        self.finalize_compile = finalize_compile
        self.kernel = None
        self.future = future
        # Completion is tracked separately from ``kernel`` so that a compile
        # which legitimately yields ``None`` is not finalized on every call.
        self._finalized = False

    def result(self, ignore_errors: bool = False):
        """Return the finalized kernel, waiting for the compile if needed.

        Args:
            ignore_errors: If true, an exception raised by the compile future
                is swallowed and ``None`` is returned instead. Exceptions
                raised by ``finalize_compile`` itself are not suppressed.

        Returns:
            The compile result after ``finalize_compile`` has been applied to
            it, or ``None`` if the compile failed and ``ignore_errors`` is set.

        Raises:
            Exception: Whatever the compile future raised, when
                ``ignore_errors`` is false.
        """
        # Fast path: already finalized (the ``kernel`` check keeps the
        # original behaviour for a pre-populated ``kernel`` attribute).
        if self.kernel is not None or self._finalized:
            return self.kernel
        try:
            kernel = self.future.result()
        except Exception:
            if ignore_errors:
                return None
            raise
        self.finalize_compile(kernel)
        self.kernel = kernel
        self._finalized = True
        return kernel
class AsyncCompileMode:
    """Context manager that runs compile requests on an executor.

    On entry this instance is stored in ``active_mode``; compile work is
    scheduled through :meth:`submit`. On exit every outstanding compile is
    awaited and finalized, and ``active_mode`` is cleared.

    Args:
        executor: Executor that runs the submitted compile functions.
        ignore_errors: If true, compile failures are swallowed when the
            outstanding work is drained in ``__exit__``.
    """

    def __init__(self, executor: Executor, *, ignore_errors=False):
        self.executor = executor
        self.ignore_errors = ignore_errors
        # Raw executor futures in submission order; drained in ``__exit__``.
        self.raw_futures = []
        # Submission key -> FutureKernel; used to deduplicate requests.
        self.future_kernels = {}

    def submit(self, key, compile_fn, finalize_fn):
        """Schedule ``compile_fn`` on the executor unless ``key`` was seen.

        Args:
            key: Hashable identifier used to deduplicate compile requests.
            compile_fn: Zero-argument callable submitted to the executor.
            finalize_fn: Passed the compile result when the returned
                FutureKernel first produces it.

        Returns:
            The FutureKernel for ``key``; the existing one on a repeat call.
        """
        existing = self.future_kernels.get(key)
        if existing is not None:
            return existing

        future = self.executor.submit(compile_fn)
        # Tag the raw future so completions yielded by ``as_completed`` in
        # ``__exit__`` can be mapped back to their FutureKernel.
        future._key = key
        self.raw_futures.append(future)
        future_kernel = FutureKernel(finalize_fn, future)
        self.future_kernels[key] = future_kernel
        return future_kernel

    def __enter__(self):
        """Activate this mode.

        Raises:
            RuntimeError: If another mode is already active in this context.
        """
        if active_mode.get() is not None:
            raise RuntimeError("Another AsyncCompileMode is already active")
        active_mode.set(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Finalize outstanding compiles, then deactivate this mode.

        ``active_mode`` is cleared even if a compile or finalize step raises,
        so a failed scope does not block later ``AsyncCompileMode`` scopes.
        If the ``with`` body itself raised, compile errors are suppressed
        while draining so they do not replace the body's exception.
        """
        # An exception raised here would replace the one already in flight,
        # so only surface compile errors when the body completed normally.
        ignore_errors = self.ignore_errors or exc_type is not None
        try:
            for future in as_completed(self.raw_futures):
                self.future_kernels[future._key].result(ignore_errors)
        finally:
            active_mode.set(None)
|