diff --git a/more_executors/_impl/throttle.py b/more_executors/_impl/throttle.py index 3f4c9888..3cc8f6b9 100644 --- a/more_executors/_impl/throttle.py +++ b/more_executors/_impl/throttle.py @@ -59,13 +59,14 @@ class ThrottleExecutor(CanCustomizeBind, Executor): - Where `count` is used to initialize this executor, if there are already `count` futures submitted to the delegate executor and not yet :meth:`~concurrent.futures.Future.done`, additional callables will - be queued and only submitted to the delegate executor once there are - less than `count` futures in progress. + either be queued or will block on submit, and will only be submitted + to the delegate executor once there are less than `count` futures in + progress. .. versionadded:: 1.9.0 """ - def __init__(self, delegate, count, logger=None, name="default"): + def __init__(self, delegate, count, logger=None, name="default", block=False): """ Parameters: delegate (~concurrent.futures.Executor): @@ -84,6 +85,13 @@ def __init__(self, delegate, count, logger=None, name="default"): .. versionadded:: 2.5.0 + block (bool) + If ``True``, calls to ``submit()`` on this executor may block if + there are already ``count`` futures in progress. + + Otherwise, calls to ``submit()`` will always return immediately + and callables will be queued internally. + logger (~logging.Logger): a logger used for messages from this executor @@ -92,10 +100,14 @@ def __init__(self, delegate, count, logger=None, name="default"): .. versionchanged:: 2.7.0 Introduced ``name``. + + .. versionchanged:: 2.11.0 + Introduced ``block``. """ self._log = LogWrapper( logger if logger else logging.getLogger("ThrottleExecutor") ) + self._block = block self._name = name self._delegate = delegate self._to_submit = deque() @@ -120,6 +132,8 @@ def __init__(self, delegate, count, logger=None, name="default"): def submit(self, fn, *args, **kwargs): # pylint: disable=arguments-differ with self._shutdown.ensure_alive(): + self._block_until_ready(self._eval_throttle()) + out = ThrottleFuture(self) track_future(out, type="throttle", executor=self._name) @@ -140,6 +154,13 @@ def shutdown(self, wait=True, **_kwargs): if wait: self._thread.join(MAX_TIMEOUT) + def _block_until_ready(self, throttle_val): + while self._block and not self._shutdown.is_shutdown: + if len(self._to_submit) < throttle_val: + return + self._log.debug("%s: throttling on submit", self._name) + self._event.wait(30.0) + def _eval_throttle(self): try: self._last_throttle = self._throttle() diff --git a/more_executors/_impl/throttle.pyi b/more_executors/_impl/throttle.pyi index d3e8357a..1412257a 100644 --- a/more_executors/_impl/throttle.pyi +++ b/more_executors/_impl/throttle.pyi @@ -15,6 +15,7 @@ class ThrottleExecutor(ExecutorProtocol): count: int | Callable[[], int], logger: logging.Logger | None = ..., name: str = ..., + block: bool = ..., ): ... def __enter__(self) -> ThrottleExecutor: ... @@ -25,5 +26,6 @@ class TypedThrottleExecutor(TypedExecutorProtocol[A, B]): count: int | Callable[[], int], logger: logging.Logger | None = ..., name: str = ..., + block: bool = ..., ): ... def __enter__(self) -> TypedThrottleExecutor[A, B]: ... diff --git a/setup.py b/setup.py index 5370d99e..8e97b450 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ def get_install_requires(): setup( name="more-executors", - version="2.10.1", + version="2.11.0", author="Rohan McGovern", author_email="rohan@mcgovern.id.au", packages=find_packages(exclude=["tests", "tests.*"]), diff --git a/tests/test_throttle.py b/tests/test_throttle.py index 91bab4f3..8c0b1f01 100644 --- a/tests/test_throttle.py +++ b/tests/test_throttle.py @@ -1,3 +1,5 @@ +import pytest + from threading import Lock import time @@ -6,7 +8,8 @@ from more_executors import Executors, ThrottleExecutor -def test_throttle(): +@pytest.mark.parametrize("block", [True, False]) +def test_throttle(block): THREADS = 8 COUNT = 3 samples = [] @@ -27,7 +30,9 @@ def record(x): running_now.remove(x) futures = [] - executor = ThrottleExecutor(Executors.thread_pool(max_workers=THREADS), count=COUNT) + executor = ThrottleExecutor( + Executors.thread_pool(max_workers=THREADS), count=COUNT, block=block + ) with executor: for i in range(0, 1000): future = executor.submit(record, i) @@ -51,6 +56,6 @@ def record(x): def test_with_throttle(): assert_that( - Executors.sync(name="throttle-test").with_throttle(4), + Executors.sync(name="throttle-test").with_throttle(4, block=True), instance_of(ThrottleExecutor), )