# Source code for aiomisc.thread_pool

import asyncio
import contextvars
import inspect
import logging
import threading
import time
import warnings
from concurrent.futures import ThreadPoolExecutor as ThreadPoolExecutorBase
from functools import partial, wraps
from multiprocessing import cpu_count
from queue import SimpleQueue
from types import MappingProxyType
from typing import (
    Any, Awaitable, Callable, Coroutine, Dict, FrozenSet, NamedTuple, Optional,
    Set, Tuple, TypeVar,
)

from ._context_vars import EVENT_LOOP
from .compat import ParamSpec
from .counters import Statistic
from .iterator_wrapper import IteratorWrapper


# Typing helpers shared by the decorators and wrappers in this module.
P = ParamSpec("P")  # parameter spec of a wrapped callable (see ``threaded``)
T = TypeVar("T")  # result type of a wrapped callable / awaited coroutine
F = TypeVar("F", bound=Callable[..., Any])  # any callable
log = logging.getLogger(__name__)  # module-level logger


def context_partial(
    func: F, *args: Any, **kwargs: Any,
) -> Any:
    """
    Bind ``func`` and its arguments to a snapshot of the current context.

    The context is copied now, at call time, with
    :func:`contextvars.copy_context`. The returned :func:`functools.partial`
    later executes ``func(*args, **kwargs)`` through that snapshot's
    ``run`` method, so values set by the caller stay visible even when the
    partial runs in another thread.
    """
    snapshot_run = contextvars.copy_context().run
    return partial(snapshot_run, func, *args, **kwargs)
class ThreadPoolException(RuntimeError):
    """Set on still-pending futures when the thread pool shuts down."""
class WorkItemBase(NamedTuple):
    """Immutable record describing one call submitted to the thread pool."""

    # callable to execute in a worker thread
    func: Callable[..., Any]
    # positional arguments for ``func``
    args: Tuple[Any, ...]
    # keyword arguments for ``func``
    kwargs: Dict[str, Any]
    # future that receives the outcome of the call
    future: asyncio.Future
    # event loop the outcome is delivered to
    loop: asyncio.AbstractEventLoop
class ThreadPoolStatistic(Statistic):
    """Counters updated by the thread pool and its worker threads."""

    threads: int  # worker threads currently alive
    done: int  # calls that finished, successfully or not
    error: int  # calls that raised
    success: int  # calls that returned normally
    submitted: int  # calls handed to the pool
    sum_time: float  # total time spent inside calls (time.monotonic delta)
class WorkItem(WorkItemBase):
    """
    One submitted call, executed by a :class:`ThreadPoolExecutor` worker.

    Calling the item runs ``func(*args, **kwargs)`` in the current (worker)
    thread, updates the pool statistic and schedules :meth:`set_result` on
    ``loop`` so that ``future`` is resolved in the loop's own thread.
    """

    @staticmethod
    def set_result(
        future: asyncio.Future,
        result: Any,
        exception: Optional[BaseException],
    ) -> None:
        """
        Resolve ``future`` with ``exception`` if given, else with ``result``.

        A future that is already done (for example, cancelled by whoever
        awaits it) is left untouched.
        """
        if future.done():
            return

        if exception is None:
            future.set_result(result)
            return

        if isinstance(exception, StopIteration):
            # asyncio futures refuse StopIteration (older Pythons raise
            # TypeError from set_exception, which would leave the future
            # pending forever), so deliver it wrapped in a RuntimeError.
            wrapped = RuntimeError(
                "StopIteration interacts badly with generators "
                "and cannot be raised into a Future",
            )
            wrapped.__cause__ = exception
            exception = wrapped

        future.set_exception(exception)

    def __call__(self, statistic: ThreadPoolStatistic) -> None:
        """
        Run the call in the current thread and deliver its outcome.

        :param statistic: counters updated with the call's outcome and
            duration.
        """
        if self.future.done():
            # Resolved (e.g. cancelled) before a worker picked it up.
            return

        result, exception = None, None

        delta = -time.monotonic()
        try:
            result = self.func(*self.args, **self.kwargs)
            statistic.success += 1
        except BaseException as e:
            # Everything, including BaseException, is forwarded to the
            # awaiting side rather than unwinding the worker thread.
            statistic.error += 1
            exception = e
        finally:
            delta += time.monotonic()
            statistic.sum_time += delta
            statistic.done += 1

        if not self.loop.is_closed():
            try:
                self.loop.call_soon_threadsafe(
                    type(self).set_result, self.future, result, exception,
                )
                return
            except RuntimeError:
                # The loop may close between the check above and the call.
                if not self.loop.is_closed():
                    raise

        # Nobody can await the future of a closed loop. Dropping the result
        # (instead of raising CancelledError) keeps the worker thread alive.
        log.warning("Event loop is closed. Result of %r dropped", self.func)
class ThreadPoolExecutor(ThreadPoolExecutorBase):
    """
    Fixed-size thread pool whose :meth:`submit` returns asyncio futures.

    All worker threads are started in ``__init__`` as daemon threads. They
    take :class:`WorkItem` objects from a shared queue and run them until a
    ``None`` sentinel (enqueued by :meth:`shutdown`) tells them to exit.
    """

    __slots__ = (
        "__futures", "__pool", "__tasks",
        "__write_lock", "__thread_events",
    )

    # CPU count, clamped to the range [4, 32].
    DEFAULT_POOL_SIZE = min((max((cpu_count() or 1, 4)), 32))

    def __init__(
        self,
        max_workers: int = DEFAULT_POOL_SIZE,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        statistic_name: Optional[str] = None,
    ) -> None:
        """
        Start ``max_workers`` worker threads.

        :param max_workers: number of worker threads to start.
        :param loop: obsolete and ignored; passing it emits a
            DeprecationWarning. :meth:`submit` uses the caller's loop.
        :param statistic_name: name of the statistic counters; also
            appended to the worker thread names.
        """
        if loop:
            warnings.warn(DeprecationWarning("loop argument is obsolete"))

        # Futures not yet resolved, so shutdown() can fail them.
        self.__futures: Set[asyncio.Future[Any]] = set()
        # One event per worker thread, set when that thread exits.
        self.__thread_events: Set[threading.Event] = set()
        self.__tasks: SimpleQueue[Optional[WorkItem]] = SimpleQueue()
        self.__write_lock = threading.RLock()
        self._statistic = ThreadPoolStatistic(statistic_name)

        pools = set()
        for idx in range(max_workers):
            pools.add(self._start_thread(idx))

        self.__pool: FrozenSet[threading.Thread] = frozenset(pools)

    def _start_thread(self, idx: int) -> threading.Thread:
        """Start the ``idx``-th daemon worker thread and return it."""
        event = threading.Event()
        self.__thread_events.add(event)

        thread_name = f"Thread {idx}"
        if self._statistic.name:
            thread_name += f" from pool {self._statistic.name}"

        thread = threading.Thread(
            target=self._in_thread,
            name=thread_name.strip(),
            args=(event,),
        )

        thread.daemon = True
        thread.start()
        return thread

    def _in_thread(self, event: threading.Event) -> None:
        """Worker loop: run queued items until a ``None`` sentinel arrives."""
        self._statistic.threads += 1
        try:
            while True:
                work_item = self.__tasks.get()

                if work_item is None:
                    break

                try:
                    if work_item.loop.is_closed():
                        log.warning(
                            "Event loop is closed. Call %r skipped",
                            work_item.func,
                        )
                        continue

                    work_item(self._statistic)
                except asyncio.CancelledError:
                    # A work item may signal that its loop closed while
                    # the call was running; the outcome cannot be
                    # delivered, but this worker must keep serving.
                    log.warning(
                        "Event loop is closed. Result of %r dropped",
                        work_item.func,
                    )
                finally:
                    # Drop the reference so the item (and its result)
                    # can be collected while waiting for the next one.
                    del work_item
        finally:
            self._statistic.threads -= 1
            event.set()

    def submit(  # type: ignore
        self, fn: F, *args: Any, **kwargs: Any,
    ) -> asyncio.Future:
        """
        Submit blocking function to the pool

        :returns: a future of the current event loop, resolved with the
            result (or exception) of ``fn(*args, **kwargs)``.
        :raises ValueError: if ``fn`` is not callable.
        """
        # None is not callable, so this also rejects ``fn=None``.
        if not callable(fn):
            raise ValueError("First argument must be callable")

        loop = asyncio.get_event_loop()
        future: asyncio.Future = loop.create_future()
        self.__futures.add(future)
        future.add_done_callback(self.__futures.remove)

        with self.__write_lock:
            self.__tasks.put_nowait(
                WorkItem(
                    func=fn,
                    args=args,
                    kwargs=kwargs,
                    future=future,
                    loop=loop,
                ),
            )
            self._statistic.submitted += 1

        return future

    # noinspection PyMethodOverriding
    def shutdown(self, wait: bool = True) -> None:  # type: ignore
        """
        Stop all workers and fail the futures that are still pending.

        :param wait: block until every worker thread has exited.
        """
        # One sentinel per worker thread.
        for _ in self.__pool:
            self.__tasks.put_nowait(None)

        # Snapshot: done-callbacks remove futures from the live set.
        for future in list(self.__futures):
            if future.done():
                continue
            if future.get_loop().is_closed():
                # set_exception would raise RuntimeError while scheduling
                # the done-callbacks, and nobody can await it anyway.
                continue
            future.set_exception(ThreadPoolException("Pool closed"))

        if not wait:
            return

        # Block on each worker's exit event instead of busy-spinning.
        for event in self.__thread_events:
            event.wait()

    def _adjust_thread_count(self) -> None:
        # The pool size is fixed in __init__.
        raise NotImplementedError

    def __del__(self) -> None:
        self.shutdown()
def run_in_executor(
    func: Callable[..., T],
    executor: Optional[ThreadPoolExecutorBase] = None,
    args: Any = (),
    kwargs: Any = MappingProxyType({}),
) -> Awaitable[T]:
    """
    Run ``func(*args, **kwargs)`` in ``executor`` (loop default if None).

    The call runs inside a copy of the current contextvars context.

    :returns: the loop's ``run_in_executor`` future when a loop is running.
        Otherwise, a coroutine that schedules the call once it is awaited.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # In case the event loop is not running right now is
        # returning coroutine to avoid DeprecationWarning in Python 3.10
        async def lazy_wrapper() -> T:
            running_loop = asyncio.get_running_loop()
            return await running_loop.run_in_executor(
                executor, context_partial(func, *args, **kwargs),
            )

        return lazy_wrapper()

    # Outside the try: a RuntimeError raised here (e.g. the executor was
    # shut down) is a real error, not a sign that no loop is running.
    return loop.run_in_executor(
        executor, context_partial(func, *args, **kwargs),
    )
async def _awaiter(future: asyncio.Future) -> T:
    """
    Await ``future`` and return its result.

    If the awaiting task is cancelled while ``future`` is still pending,
    the cancellation is stored in ``future`` as its exception before being
    re-raised.
    """
    try:
        return await future
    except asyncio.CancelledError as cancel_exc:
        if not future.done():
            future.set_exception(cancel_exc)
        raise
def threaded(
    func: Callable[P, T],
) -> Callable[P, Awaitable[T]]:
    """
    Decorate a blocking function so that calling it returns an awaitable.

    Each call is forwarded to :func:`run_in_executor` with its default
    executor. Generator functions are handed to :func:`threaded_iterable`
    instead.

    :raises TypeError: if ``func`` is a coroutine function.
    """
    if asyncio.iscoroutinefunction(func):
        raise TypeError("Can not wrap coroutine")

    if inspect.isgeneratorfunction(func):
        return threaded_iterable(func)

    def wrap(*args: P.args, **kwargs: P.kwargs) -> Awaitable[T]:
        return run_in_executor(func=func, args=args, kwargs=kwargs)

    return wraps(func)(wrap)
def run_in_new_thread(
    func: F,
    args: Any = (),
    kwargs: Any = MappingProxyType({}),
    detach: bool = True,
    no_return: bool = False,
    statistic_name: Optional[str] = None,
) -> asyncio.Future:
    """
    Run ``func(*args, **kwargs)`` in a dedicated, newly started thread.

    :param func: blocking callable; runs inside a copy of the caller's
        contextvars context.
    :param args: positional arguments for ``func``.
    :param kwargs: keyword arguments for ``func``.
    :param detach: start the thread as a daemon thread.
    :param no_return: if ``func`` raises after the event loop was closed,
        drop the exception silently instead of logging it.
    :param statistic_name: name for this call's statistic counters.
    :returns: a future of the current event loop, resolved with the
        return value or exception of ``func``.
    """
    loop = asyncio.get_event_loop()
    future = loop.create_future()
    statistic = ThreadPoolStatistic(statistic_name)

    def set_result(result: Any) -> None:
        if future.done() or loop.is_closed():
            return
        future.set_result(result)

    def set_exception(exc: Exception) -> None:
        if future.done() or loop.is_closed():
            return
        future.set_exception(exc)

    def deliver(callback: Callable[[Any], None], value: Any) -> None:
        # Hand ``value`` to the loop thread. If the loop has closed in the
        # meantime, nobody can await the future, so the value is dropped.
        try:
            loop.call_soon_threadsafe(callback, value)
        except RuntimeError:
            if not loop.is_closed():
                raise

    @wraps(func)
    def in_thread(target: F) -> None:
        statistic.threads += 1
        statistic.submitted += 1
        try:
            try:
                result = target()
            except Exception as exc:
                statistic.error += 1
                if loop.is_closed():
                    if not no_return:
                        log.exception("Uncaught exception from separate thread")
                    return
                deliver(set_exception, exc)
                return

            # Delivery is kept out of the except above, so a loop closed
            # after a successful call is not counted as an error.
            statistic.success += 1
            deliver(set_result, result)
        finally:
            statistic.done += 1
            statistic.threads -= 1

    thread = threading.Thread(
        target=in_thread,
        # Callables such as functools.partial have no __name__; None lets
        # threading pick a default name.
        name=getattr(func, "__name__", None),
        args=(
            context_partial(func, *args, **kwargs),
        ),
        daemon=detach,
    )
    thread.start()
    return future
def threaded_separate(
    func: Optional[Callable[..., Any]] = None,
    detach: bool = True,
) -> Callable[..., Awaitable[Any]]:
    """
    Decorate a blocking function so that every call runs in a new thread.

    Can be used bare (``@threaded_separate``) or parametrized, either
    positionally (``@threaded_separate(False)``) or by keyword
    (``@threaded_separate(detach=False)``).

    :param func: function to wrap. A bool here is taken as ``detach``
        (positional parametrized form).
    :param detach: forwarded to :func:`run_in_new_thread`, which starts the
        thread as a daemon when it is true.
    :raises TypeError: if ``func`` is a coroutine function.
    """
    if func is None:
        return partial(threaded_separate, detach=detach)

    if isinstance(func, bool):
        # ``@threaded_separate(False)``: the positional bool is the flag.
        return partial(threaded_separate, detach=func)

    if asyncio.iscoroutinefunction(func):
        raise TypeError("Can not wrap coroutine")

    @wraps(func)
    def wrap(*args: Any, **kwargs: Any) -> Any:
        return run_in_new_thread(
            func, args=args, kwargs=kwargs, detach=detach,
        )

    return wrap
def threaded_iterable(
    func: Optional[F] = None,
    max_size: int = 0,
) -> Any:
    """
    Decorate a generator function so that calling it returns an
    :class:`IteratorWrapper` around the generator.

    Can be used bare, as ``@threaded_iterable(max_size=N)`` or as
    ``@threaded_iterable(N)``. ``max_size`` is passed to the wrapper
    (presumably it bounds its buffer; confirm in IteratorWrapper).
    """
    if func is None:
        return partial(threaded_iterable, max_size=max_size)

    if isinstance(func, int):
        return partial(threaded_iterable, max_size=func)

    def wrap(*args: Any, **kwargs: Any) -> Any:
        bound_call = context_partial(func, *args, **kwargs)  # type: ignore
        return IteratorWrapper(bound_call, max_size=max_size)

    return wraps(func)(wrap)
class IteratorWrapperSeparate(IteratorWrapper):
    """IteratorWrapper whose ``_run`` uses a dedicated new thread."""

    def _run(self) -> Any:
        # Start ``_in_thread`` in its own thread via run_in_new_thread.
        worker = self._in_thread
        return run_in_new_thread(worker)
def threaded_iterable_separate(
    func: Optional[F] = None,
    max_size: int = 0,
) -> Any:
    """
    Like :func:`threaded_iterable`, but each call returns an
    :class:`IteratorWrapperSeparate`, whose generator is driven from a
    dedicated new thread.

    Accepts the same forms: bare, ``(max_size=N)`` or a positional ``(N)``.
    """
    if func is None:
        return partial(threaded_iterable_separate, max_size=max_size)

    if isinstance(func, int):
        return partial(threaded_iterable_separate, max_size=func)

    def wrap(*args: Any, **kwargs: Any) -> Any:
        bound_call = context_partial(func, *args, **kwargs)  # type: ignore
        return IteratorWrapperSeparate(bound_call, max_size=max_size)

    return wraps(func)(wrap)
class CoroutineWaiter:
    """
    Run a coroutine on an event loop and block another thread until it ends.

    :meth:`start` schedules the coroutine as a task on the loop (thread
    safe). :meth:`wait` blocks the calling thread until that task finishes,
    then returns its result or raises its exception.
    """

    def __init__(
        self,
        coroutine: Coroutine[Any, Any, T],
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        self.__coro: Coroutine[Any, Any, T] = coroutine
        # Falls back to the loop stored in the EVENT_LOOP context variable.
        self.__loop = loop or EVENT_LOOP.get()
        self.__event = threading.Event()
        self.__result: Optional[T] = None
        self.__exception: Optional[BaseException] = None

    def _on_result(self, task: asyncio.Future) -> None:
        # Runs in the loop thread. The event must be set in every case,
        # or wait() would block forever.
        try:
            if task.cancelled():
                # task.exception() would raise CancelledError here.
                self.__exception = asyncio.CancelledError()
                return
            self.__exception = task.exception()
            if self.__exception is None:
                self.__result = task.result()
        finally:
            self.__event.set()

    def _awaiter(self) -> None:
        try:
            task: asyncio.Future = self.__loop.create_task(self.__coro)
        except Exception as exc:
            # e.g. TypeError for a non-coroutine. Report it to the waiting
            # thread instead of leaving it blocked.
            self.__exception = exc
            self.__event.set()
            return
        task.add_done_callback(self._on_result)

    def start(self) -> None:
        """Schedule the coroutine on the loop; safe to call from any thread."""
        self.__loop.call_soon_threadsafe(self._awaiter)

    def wait(self) -> Any:
        """Block until the coroutine finishes; return or raise its outcome."""
        self.__event.wait()
        if self.__exception is not None:
            raise self.__exception
        return self.__result
def wait_coroutine(
    coro: Coroutine[Any, Any, T],
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> T:
    """
    Run ``coro`` on ``loop`` and block the calling thread until it finishes.

    Must be called from a thread other than the loop's own. Otherwise the
    loop could never run the task and the call would block forever.
    """
    coroutine_waiter = CoroutineWaiter(coro, loop)
    coroutine_waiter.start()
    outcome = coroutine_waiter.wait()
    return outcome
def sync_wait_coroutine(
    loop: Optional[asyncio.AbstractEventLoop],
    coro_func: Callable[..., Coroutine[Any, Any, T]],
    *args: Any,
    **kwargs: Any,
) -> T:
    """Create ``coro_func(*args, **kwargs)`` and block until it completes."""
    coroutine = coro_func(*args, **kwargs)
    return wait_coroutine(coroutine, loop=loop)
def sync_await(
    func: Callable[..., Awaitable[T]],
    *args: Any,
    **kwargs: Any,
) -> T:
    """
    Block until ``await func(*args, **kwargs)`` completes on the event loop.

    ``func`` itself is called inside the scheduled coroutine, so it runs in
    the loop's thread rather than the caller's.
    """
    async def _runner() -> T:
        awaitable = func(*args, **kwargs)
        return await awaitable

    pending = _runner()
    return wait_coroutine(pending)