diff --git a/src/spdl/pipeline/_builder.py b/src/spdl/pipeline/_builder.py index 4d1ba792..a0a1c609 100644 --- a/src/spdl/pipeline/_builder.py +++ b/src/spdl/pipeline/_builder.py @@ -19,7 +19,7 @@ Iterable, Sequence, ) -from concurrent.futures import Executor +from concurrent.futures import Executor, ThreadPoolExecutor from contextlib import asynccontextmanager, contextmanager from functools import partial from typing import TypeVar @@ -893,4 +893,8 @@ def build(self, *, num_threads: int | None = None) -> Pipeline: ] num_threads = max(concurrencies) if concurrencies else 4 assert num_threads is not None - return Pipeline(coro, queues, num_threads, desc=self._get_desc()) + executor = ThreadPoolExecutor( + max_workers=num_threads, + thread_name_prefix="spdl_", + ) + return Pipeline(coro, queues, executor, desc=self._get_desc()) diff --git a/src/spdl/pipeline/_pipeline.py b/src/spdl/pipeline/_pipeline.py index 290d76f5..52cf2af7 100644 --- a/src/spdl/pipeline/_pipeline.py +++ b/src/spdl/pipeline/_pipeline.py @@ -13,6 +13,7 @@ import warnings from asyncio import AbstractEventLoop, Queue as AsyncQueue from collections.abc import Coroutine, Iterator +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from enum import IntEnum from threading import Event as SyncEvent, Thread @@ -36,9 +37,13 @@ # This class has a bit exessive debug logs, because it is tricky to debug # it from the outside. class _EventLoop: - def __init__(self, coro: Coroutine[None, None, None], num_threads: int) -> None: + def __init__( + self, + coro: Coroutine[None, None, None], + executor: ThreadPoolExecutor, + ) -> None: self._coro = coro - self._num_threads = num_threads + self._executor = executor self._loop: AbstractEventLoop | None = None @@ -62,14 +67,8 @@ def __str__(self) -> str: async def _execute_task(self) -> None: _LG.debug("The event loop thread coroutine is started.") - _LG.debug("Initializing the thread pool of size=%d.", self._num_threads) self._loop = asyncio.get_running_loop() - self._loop.set_default_executor( - concurrent.futures.ThreadPoolExecutor( - max_workers=self._num_threads, - thread_name_prefix="spdl_", - ) - ) + self._loop.set_default_executor(self._executor) _LG.debug("Starting the task.") @@ -263,7 +262,7 @@ def __init__( self, coro: Coroutine[None, None, None], queues: list[AsyncQueue], - num_threads: int, + executor: ThreadPoolExecutor, *, desc: list[str], ) -> None: @@ -272,7 +271,7 @@ def __init__( self._str: str = "\n".join([repr(self), *desc]) self._output_queue: AsyncQueue = queues[-1] - self._event_loop = _EventLoop(coro, num_threads) + self._event_loop = _EventLoop(coro, executor) self._event_loop_state: _EventLoopState = _EventLoopState.NOT_STARTED try: