diff --git a/docs/source/conf.py b/docs/source/conf.py index 281baae2a8..29a7f83739 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,6 +21,13 @@ # import sys # sys.path.insert(0, os.path.abspath('.')) +# Warn about all references to unknown targets +nitpicky = True +# Except for these ones, which we expect to point to unknown targets: +nitpick_ignore = [ + ("py:obj", "CapacityLimiter-like object"), +] + # XX hack the RTD theme until # https://github.com/rtfd/sphinx_rtd_theme/pull/382 # is shipped (should be in the release after 0.2.4) diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index ab87dc52c9..fe1cc9989f 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -1358,6 +1358,9 @@ synchronization logic. All of classes discussed in this section are implemented on top of the public APIs in :mod:`trio.hazmat`; they don't have any special access to trio's internals.) +.. autoclass:: CapacityLimiter + :members: + .. autoclass:: Semaphore :members: @@ -1395,6 +1398,8 @@ communicate back with trio, there's the closely related .. autofunction:: run_in_worker_thread +.. autofunction:: current_default_worker_thread_limiter + .. function:: current_run_in_trio_thread current_await_in_trio_thread diff --git a/trio/_sync.py b/trio/_sync.py index 4a69fc83b3..75e86f34f9 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -7,7 +7,9 @@ from ._util import aiter_compat __all__ = [ - "Event", "Semaphore", "Lock", "StrictFIFOLock", "Condition", "Queue"] + "Event", "CapacityLimiter", "Semaphore", "Lock", "StrictFIFOLock", + "Condition", "Queue", +] @attr.s(slots=True, repr=False, cmp=False, hash=False) class Event: @@ -15,7 +17,7 @@ class Event: inspired by :class:`threading.Event`. An event object manages an internal boolean flag, which is initially - False. + False, and tasks can wait for it to become True. """ @@ -81,6 +83,245 @@ async def __aexit__(self, *args): return cls +@attr.s(frozen=True) +class _CapacityLimiterStatistics: + borrowed_tokens = attr.ib() + total_tokens = attr.ib() + borrowers = attr.ib() + tasks_waiting = attr.ib() + + +@async_cm +class CapacityLimiter: + """An object for controlling access to a resource with limited capacity. + + Sometimes you need to put a limit on how many tasks can do something at + the same time. For example, you might want to use some threads to run + multiple blocking I/O operations in parallel... but if you use too many + threads at once, then your system can become overloaded and it'll actually + make things slower. One popular solution is to impose a policy like "run + up to 40 threads at the same time, but no more". But how do you implement + a policy like this? + + That's what :class:`CapacityLimiter` is for. You can think of a + :class:`CapacityLimiter` object as a sack that starts out holding some fixed + number of tokens:: + + limit = trio.CapacityLimiter(40) + + Then tasks can come along and borrow a token out of the sack:: + + # Borrow a token: + async with limit: + # We are holding a token! + await perform_expensive_operation() + # Exiting the 'async with' block puts the token back into the sack + + And crucially, if you try to borrow a token but the sack is empty, then + you have to wait for another task to finish what it's doing and put its + token back first before you can take it and continue. 
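+
+    If you'd rather not wait when the sack is empty, there's also a
+    non-blocking variant (a sketch building on the example above;
+    ``do_something_else`` stands in for whatever fallback work you have)::
+
+        try:
+            limit.acquire_nowait()
+        except trio.WouldBlock:
+            # All the tokens are checked out right now; do something else
+            await do_something_else()
+        else:
+            try:
+                await perform_expensive_operation()
+            finally:
+                limit.release()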
+
+    Another way to think of it: a :class:`CapacityLimiter` is like a sofa with a
+    fixed number of seats, and if they're all taken then you have to wait for
+    someone to get up before you can sit down.
+
+    By default, :func:`run_in_worker_thread` uses a :class:`CapacityLimiter` to
+    limit the number of threads running at once; see
+    :func:`current_default_worker_thread_limiter` for details.
+
+    If you're familiar with semaphores, then you can think of this as a
+    restricted semaphore that's specialized for one common use case, with
+    additional error checking. For a more traditional semaphore, see
+    :class:`Semaphore`.
+
+    .. note::
+
+       Don't confuse this with the `"leaky bucket"
+       <https://en.wikipedia.org/wiki/Leaky_bucket>`__ or `"token bucket"
+       <https://en.wikipedia.org/wiki/Token_bucket>`__ algorithms used to
+       limit bandwidth usage on networks. The basic idea of using tokens to
+       track a resource limit is similar, but this is a very simple sack where
+       tokens aren't automatically created or destroyed over time; they're
+       just borrowed and then put back.
+
+    """
+    def __init__(self, total_tokens):
+        self._lot = _core.ParkingLot()
+        self._borrowers = set()
+        # Maps tasks attempting to acquire -> borrower, to handle on-behalf-of
+        self._pending_borrowers = {}
+        # invoke the property setter for validation
+        self.total_tokens = total_tokens
+        assert self._total_tokens == total_tokens
+
+    def __repr__(self):
+        return ("<trio.CapacityLimiter at {:#x}, {}/{} with {} waiting>"
+                .format(id(self), len(self._borrowers), self._total_tokens,
+                        len(self._lot)))
+
+    @property
+    def total_tokens(self):
+        """The total capacity available.
+
+        You can change :attr:`total_tokens` by assigning to this attribute. If
+        you make it larger, then the appropriate number of waiting tasks will
+        be woken immediately to take the new tokens. If you decrease
+        :attr:`total_tokens` below the number of tasks that are currently
+        using the resource, then all current tasks will be allowed to finish
+        as normal, but no new tasks will be allowed in until the total number
+        of tasks drops below the new :attr:`total_tokens`.
+
+        """
+        return self._total_tokens
+
+    def _wake_waiters(self):
+        available = self._total_tokens - len(self._borrowers)
+        for woken in self._lot.unpark(count=available):
+            self._borrowers.add(self._pending_borrowers.pop(woken))
+
+    @total_tokens.setter
+    def total_tokens(self, new_total_tokens):
+        if not isinstance(new_total_tokens, int):
+            raise TypeError("total_tokens must be an int")
+        if new_total_tokens < 1:
+            raise ValueError("total_tokens must be >= 1")
+        self._total_tokens = new_total_tokens
+        self._wake_waiters()
+
+    @property
+    def borrowed_tokens(self):
+        """The amount of capacity that's currently in use.
+
+        """
+        return len(self._borrowers)
+
+    @property
+    def available_tokens(self):
+        """The amount of capacity that's available to use.
+
+        """
+        return self.total_tokens - self.borrowed_tokens
+
+    @_core.enable_ki_protection
+    def acquire_nowait(self):
+        """Borrow a token from the sack, without blocking.
+
+        Raises:
+          WouldBlock: if no tokens are available.
+
+        """
+        self.acquire_on_behalf_of_nowait(_core.current_task())
+
+    @_core.enable_ki_protection
+    def acquire_on_behalf_of_nowait(self, borrower):
+        """Borrow a token from the sack on behalf of ``borrower``, without
+        blocking.
+
+        Args:
+          borrower: A :class:`Task` or arbitrary opaque object used to record
+              who is borrowing this token. This is used by
+              :func:`run_in_worker_thread` to allow threads to "hold tokens",
+              with the intention in the future of using it to `allow deadlock
+              detection and other useful things
+              <https://github.com/python-trio/trio/issues/182>`__.
+
+        Raises:
+          WouldBlock: if no tokens are available.
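+
+        For example, a token can be borrowed on behalf of something that
+        isn't a trio task at all (a sketch; ``ExternalWorker`` and
+        ``hand_resource_to`` are hypothetical stand-ins for whatever entity
+        is doing the borrowing)::
+
+            worker = ExternalWorker()
+            limit.acquire_on_behalf_of_nowait(worker)
+            try:
+                hand_resource_to(worker)
+            finally:
+                limit.release_on_behalf_of(worker)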
+
+        """
+        if borrower in self._borrowers:
+            raise RuntimeError(
+                "this borrower is already holding one of this "
+                "CapacityLimiter's tokens")
+        if len(self._borrowers) < self._total_tokens and not self._lot:
+            self._borrowers.add(borrower)
+        else:
+            raise _core.WouldBlock
+
+    @_core.enable_ki_protection
+    async def acquire(self):
+        """Borrow a token from the sack, blocking if necessary.
+
+        """
+        await self.acquire_on_behalf_of(_core.current_task())
+
+    @_core.enable_ki_protection
+    async def acquire_on_behalf_of(self, borrower):
+        """Borrow a token from the sack on behalf of ``borrower``, blocking if
+        necessary.
+
+        Args:
+          borrower: A :class:`Task` or arbitrary opaque object used to record
+              who is borrowing this token; see
+              :meth:`acquire_on_behalf_of_nowait` for details.
+
+        """
+        await _core.yield_if_cancelled()
+        try:
+            self.acquire_on_behalf_of_nowait(borrower)
+        except _core.WouldBlock:
+            task = _core.current_task()
+            self._pending_borrowers[task] = borrower
+            try:
+                await self._lot.park()
+            except _core.Cancelled:
+                # If we're cancelled while parked, we never became a
+                # borrower, so drop the pending entry instead of leaking it.
+                self._pending_borrowers.pop(task)
+                raise
+        except:
+            await _core.yield_briefly_no_cancel()
+            raise
+        else:
+            await _core.yield_briefly_no_cancel()
+
+    @_core.enable_ki_protection
+    def release(self):
+        """Put a token back into the sack.
+
+        Raises:
+          RuntimeError: if the current task has not acquired one of this
+              sack's tokens.
+
+        """
+        self.release_on_behalf_of(_core.current_task())
+
+    @_core.enable_ki_protection
+    def release_on_behalf_of(self, borrower):
+        """Put a token back into the sack on behalf of ``borrower``.
+
+        Raises:
+          RuntimeError: if the given borrower has not acquired one of this
+              sack's tokens.
+
+        """
+        if borrower not in self._borrowers:
+            raise RuntimeError(
+                "this borrower isn't holding any of this CapacityLimiter's "
+                "tokens")
+        self._borrowers.remove(borrower)
+        self._wake_waiters()
+
+    def statistics(self):
+        """Return an object containing debugging information.
+
+        Currently the following fields are defined:
+
+        * ``borrowed_tokens``: The number of tokens currently borrowed from
+          the sack.
+        * ``total_tokens``: The total number of tokens in the sack. Usually
+          this will be larger than ``borrowed_tokens``, but it's possible for
+          it to be smaller if :attr:`total_tokens` was recently decreased.
+        * ``borrowers``: A list of all tasks or other entities that currently
+          hold a token.
+        * ``tasks_waiting``: The number of tasks blocked on this
+          :class:`CapacityLimiter`\'s :meth:`acquire` or
+          :meth:`acquire_on_behalf_of` methods.
+
+        """
+        return _CapacityLimiterStatistics(
+            borrowed_tokens=len(self._borrowers),
+            total_tokens=self._total_tokens,
+            # Use a list instead of a frozenset just in case we start to allow
+            # one borrower to hold multiple tokens in the future
+            borrowers=list(self._borrowers),
+            tasks_waiting=len(self._lot),
+        )
+
+
 @async_cm
 class Semaphore:
     """A `semaphore <https://en.wikipedia.org/wiki/Semaphore_(programming)>`__.
@@ -90,20 +331,9 @@ class Semaphore:
     the value is never allowed to drop below zero. If the value is zero, then
     :meth:`acquire` will block until someone calls :meth:`release`.
 
-    This is a very flexible synchronization object, but perhaps the most
-    common use is to represent a resource with some bounded supply. For
-    example, if you want to make sure that there are never more than four
-    tasks simultaneously performing some operation, you could do something
-    like::
-
-        # Allocate a shared Semaphore object, and somehow distribute it to all
-        # your tasks. NB: max_value=4 isn't technically necessary, but can
-        # help catch errors.
-        sem = trio.Semaphore(4, max_value=4)
-
-        # Then when you perform the operation:
-        async with sem:
-            await perform_operation()
+    If you're looking for a :class:`Semaphore` to limit the number of tasks
+    that can access some resource simultaneously, then consider using a
+    :class:`CapacityLimiter` instead.
 
     This object's interface is similar to, but different from, that of
     :class:`threading.Semaphore`.
diff --git a/trio/_threads.py b/trio/_threads.py
index 14dcb6e7f8..83e72c24dc 100644
--- a/trio/_threads.py
+++ b/trio/_threads.py
@@ -2,11 +2,14 @@
 import queue as stdlib_queue
 from itertools import count
 
+import attr
+
 from . import _core
+from ._sync import CapacityLimiter
 
 __all__ = [
     "current_await_in_trio_thread", "current_run_in_trio_thread",
-    "run_in_worker_thread",
+    "run_in_worker_thread", "current_default_worker_thread_limiter",
 ]
 
 def _await_in_trio_thread_cb(q, afn, args):
@@ -140,11 +143,38 @@ def current_await_in_trio_thread():
 # really the only real limit is on stack size actually *used*; how much you
 # *allocate* should be pretty much irrelevant.)
 
+_limiter_local = _core.RunLocal()
+# I pulled this number out of the air; it isn't based on anything. Probably we
+# should make some kind of measurements to pick a good value.
+DEFAULT_LIMIT = 40
 _worker_thread_counter = count()
 
+def current_default_worker_thread_limiter():
+    """Get the default :class:`CapacityLimiter` used by
+    :func:`run_in_worker_thread`.
+
+    The most common reason to call this would be if you want to modify its
+    :attr:`~CapacityLimiter.total_tokens` attribute.
+
+    """
+    try:
+        limiter = _limiter_local.limiter
+    except AttributeError:
+        limiter = _limiter_local.limiter = CapacityLimiter(DEFAULT_LIMIT)
+    return limiter
+
+# Eventually we might build this into a full-fledged deadlock-detection
+# system; see https://github.com/python-trio/trio/issues/182
+# But for now we just need an object to stand in for the thread, so we can
+# keep track of who's holding the CapacityLimiter's token.
+@attr.s(frozen=True, cmp=False, hash=False, slots=True)
+class ThreadPlaceholder:
+    name = attr.ib()
+
 @_core.enable_ki_protection
-async def run_in_worker_thread(sync_fn, *args, cancellable=False):
-    """Convert a blocking operation in an async operation using a thread.
+async def run_in_worker_thread(
+        sync_fn, *args, cancellable=False, limiter=None):
+    """Convert a blocking operation into an async operation using a thread.
 
     These two lines are equivalent::
 
@@ -156,12 +186,31 @@ async def run_in_worker_thread(sync_fn, *args, cancellable=False):
     tasks to continue working while ``sync_fn`` runs. This is accomplished by
     pushing the call to ``sync_fn(*args)`` off into a worker thread.
 
+    Args:
+      sync_fn: An arbitrary synchronous callable.
+      *args: Positional arguments to pass to sync_fn. If you need keyword
+          arguments, use :func:`functools.partial`.
+      cancellable (bool): Whether to allow cancellation of this operation. See
+          discussion below.
+      limiter (None, CapacityLimiter, or CapacityLimiter-like object):
+          An object used to limit the number of simultaneous threads. Most
+          commonly this will be a :class:`CapacityLimiter`, but it could be
+          anything providing compatible
+          :meth:`~trio.CapacityLimiter.acquire_on_behalf_of` and
+          :meth:`~trio.CapacityLimiter.release_on_behalf_of`
+          methods. :func:`run_in_worker_thread` will call
+          ``acquire_on_behalf_of`` before starting the thread, and
+          ``release_on_behalf_of`` after the thread has finished.
+
+          If None (the default), uses the default :class:`CapacityLimiter`, as
+          returned by :func:`current_default_worker_thread_limiter`.
+
     **Cancellation handling**: Cancellation is a tricky issue here, because
     neither Python nor the operating systems it runs on provide any general
-    way to communicate with an arbitrary synchronous function running in a
-    thread and tell it to stop. This function will always check for
-    cancellation on entry, before starting the thread. But once the thread is
-    running, there are two ways it can handle being cancelled:
+    mechanism for cancelling an arbitrary synchronous function running in a
+    thread. :func:`run_in_worker_thread` will always check for cancellation on
+    entry, before starting the thread. But once the thread is running, there
+    are two ways it can handle being cancelled:
 
     * If ``cancellable=False``, the function ignores the cancellation and
       keeps going, just like if we had called ``sync_fn`` synchronously. This
@@ -172,30 +221,30 @@ async def run_in_worker_thread(sync_fn, *args, cancellable=False):
      background** – we just abandon it to do whatever it's going to do, and
      silently discard any return value or errors that it raises. Only use
      this if you know that the operation is safe and side-effect free. (For
-      example: ``trio.socket.getaddrinfo`` is implemented using
+      example: :func:`trio.socket.getaddrinfo` is implemented using
       :func:`run_in_worker_thread`, and it sets ``cancellable=True`` because
-      it doesn't really matter if a stray hostname lookup keeps running in the
-      background.)
+      it doesn't really affect anything if a stray hostname lookup keeps
+      running in the background.)
 
-    .. warning::
+      The ``limiter`` is only released after the thread has *actually*
+      finished – which in the case of cancellation may be some time after
+      :func:`run_in_worker_thread` has returned. (This is why it's crucial
+      that :func:`run_in_worker_thread` takes care of acquiring and releasing
+      the limiter.) If :func:`trio.run` finishes before the thread does, then
+      the limiter release method will never be called at all.
 
-       You should not use :func:`run_in_worker_thread` to call CPU-bound
-       functions! In addition to the usual GIL-related reasons why using
-       threads for CPU-bound work is not very effective in Python, there is an
-       additional problem: on CPython, `CPU-bound threads tend to "starve out"
-       IO-bound threads <https://bugs.python.org/issue7946>`__, so using
-       :func:`run_in_worker_thread` for CPU-bound work is likely to adversely
-       affect the main thread running trio. If you need to do this, you're
-       better off using a worker process, or perhaps PyPy (which still has a
-       GIL, but may do a better job of fairly allocating CPU time between
-       threads).
+    .. warning::
 
-    Args:
-      sync_fn: An arbitrary synchronous callable.
-      *args: Positional arguments to pass to sync_fn. If you need keyword
-          arguments, use :func:`functools.partial`.
-      cancellable (bool): Whether to allow cancellation of this operation. See
-          discussion above.
+       You should not use :func:`run_in_worker_thread` to call long-running
+       CPU-bound functions! In addition to the usual GIL-related reasons why
+       using threads for CPU-bound work is not very effective in Python, there
+       is an additional problem: on CPython, `CPU-bound threads tend to
+       "starve out" IO-bound threads <https://bugs.python.org/issue7946>`__,
+       so using :func:`run_in_worker_thread` for CPU-bound work is likely to
+       adversely affect the main thread running trio.
+       If you need to do this, you're better off using a worker process, or
+       perhaps PyPy (which still has a GIL, but may do a better job of fairly
+       allocating CPU time between threads).
 
     Returns:
       Whatever ``sync_fn(*args)`` returns.
 
@@ -206,22 +255,56 @@
     """
     await _core.yield_if_cancelled()
     call_soon = _core.current_call_soon_thread_and_signal_safe()
+    if limiter is None:
+        limiter = current_default_worker_thread_limiter()
+
+    # Holds a reference to the task that's blocked in this function waiting
+    # for the result – or None if this function was cancelled and we should
+    # discard the result.
     task_register = [_core.current_task()]
-    def trio_thread_fn(result):
+
+    name = "trio-worker-{}".format(next(_worker_thread_counter))
+    placeholder = ThreadPlaceholder(name)
+
+    # This function gets scheduled into the trio run loop to deliver the
+    # thread's result.
+    def report_back_in_trio_thread_fn(result):
+        def do_release_then_return_result():
+            # release_on_behalf_of is an arbitrary user-defined method, so it
+            # might raise an error. If it does, we want that error to
+            # replace the regular return value, and if the regular return was
+            # already an exception then we want them to chain.
+            try:
+                return result.unwrap()
+            finally:
+                limiter.release_on_behalf_of(placeholder)
+        result = _core.Result.capture(do_release_then_return_result)
         if task_register[0] is not None:
             _core.reschedule(task_register[0], result)
+
+    # This is the function that runs in the worker thread to do the actual
+    # work and then schedule the call to report_back_in_trio_thread_fn
     def worker_thread_fn():
         result = _core.Result.capture(sync_fn, *args)
         try:
-            call_soon(trio_thread_fn, result)
+            call_soon(report_back_in_trio_thread_fn, result)
         except _core.RunFinishedError:
             # The entire run finished, so our particular task is certainly
             # long gone -- it must have cancelled.
             pass
-    name = "trio-worker-{}".format(next(_worker_thread_counter))
-    # daemonic because it might get left behind if we cancel
-    thread = threading.Thread(target=worker_thread_fn, name=name, daemon=True)
-    thread.start()
+
+    await limiter.acquire_on_behalf_of(placeholder)
+    try:
+        # daemon=True because it might get left behind if we cancel, and in
+        # this case shouldn't block process exit.
+ thread = threading.Thread( + target=worker_thread_fn, name=name, daemon=True) + thread.start() + except: + limiter.release_on_behalf_of(placeholder) + raise + def abort(_): if cancellable: task_register[0] = None diff --git a/trio/tests/test_sync.py b/trio/tests/test_sync.py index 3f5e622abd..a428f3d618 100644 --- a/trio/tests/test_sync.py +++ b/trio/tests/test_sync.py @@ -36,6 +36,114 @@ async def child(): assert record == ["sleeping", "sleeping", "woken", "woken"] +async def test_CapacityLimiter(): + with pytest.raises(TypeError): + CapacityLimiter(1.0) + with pytest.raises(ValueError): + CapacityLimiter(-1) + c = CapacityLimiter(2) + repr(c) # smoke test + assert c.total_tokens == 2 + assert c.borrowed_tokens == 0 + assert c.available_tokens == 2 + with pytest.raises(RuntimeError): + c.release() + assert c.borrowed_tokens == 0 + c.acquire_nowait() + assert c.borrowed_tokens == 1 + assert c.available_tokens == 1 + + stats = c.statistics() + assert stats.borrowed_tokens == 1 + assert stats.total_tokens == 2 + assert stats.borrowers == [_core.current_task()] + assert stats.tasks_waiting == 0 + + # Can't re-acquire when we already have it + with pytest.raises(RuntimeError): + c.acquire_nowait() + assert c.borrowed_tokens == 1 + with assert_yields(): + with pytest.raises(RuntimeError): + await c.acquire() + assert c.borrowed_tokens == 1 + + # We can acquire on behalf of someone else though + with assert_yields(): + await c.acquire_on_behalf_of("someone") + + # But then we've run out of capacity + assert c.borrowed_tokens == 2 + with pytest.raises(_core.WouldBlock): + c.acquire_on_behalf_of_nowait("third party") + + assert set(c.statistics().borrowers) == {_core.current_task(), "someone"} + + # Until we release one + c.release_on_behalf_of(_core.current_task()) + assert c.statistics().borrowers == ["someone"] + + c.release_on_behalf_of("someone") + assert c.borrowed_tokens == 0 + with assert_yields(): + async with c: + assert c.borrowed_tokens == 1 + + async with _core.open_nursery() as nursery: + await c.acquire_on_behalf_of("value 1") + await c.acquire_on_behalf_of("value 2") + t = nursery.spawn(c.acquire_on_behalf_of, "value 3") + await wait_all_tasks_blocked() + assert t.result is None + assert c.borrowed_tokens == 2 + assert c.statistics().tasks_waiting == 1 + c.release_on_behalf_of("value 2") + # Fairness: + assert c.borrowed_tokens == 2 + with pytest.raises(_core.WouldBlock): + c.acquire_nowait() + await t.wait() + + c.release_on_behalf_of("value 3") + c.release_on_behalf_of("value 1") + + +async def test_CapacityLimiter_change_tokens(): + c = CapacityLimiter(2) + + with pytest.raises(TypeError): + c.total_tokens = 1.0 + + with pytest.raises(ValueError): + c.total_tokens = 0 + + with pytest.raises(ValueError): + c.total_tokens = -10 + + assert c.total_tokens == 2 + + async with _core.open_nursery() as nursery: + for i in range(5): + nursery.spawn(c.acquire_on_behalf_of, i) + await wait_all_tasks_blocked() + assert set(c.statistics().borrowers) == {0, 1} + assert c.statistics().tasks_waiting == 3 + c.total_tokens += 2 + assert set(c.statistics().borrowers) == {0, 1, 2, 3} + assert c.statistics().tasks_waiting == 1 + c.total_tokens -= 3 + assert c.borrowed_tokens == 4 + assert c.total_tokens == 1 + c.release_on_behalf_of(0) + c.release_on_behalf_of(1) + c.release_on_behalf_of(2) + assert set(c.statistics().borrowers) == {3} + assert c.statistics().tasks_waiting == 1 + c.release_on_behalf_of(3) + assert set(c.statistics().borrowers) == {4} + assert c.statistics().tasks_waiting == 0 + + 
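+async def test_CapacityLimiter_cancelled_waiter():
+    # A sketch of an extra check, not part of the original change: a task
+    # cancelled while parked in acquire() should give up its place in line
+    # rather than ending up as a borrower.
+    c = CapacityLimiter(1)
+    c.acquire_nowait()
+    scopes = []
+
+    async def waiter():
+        with _core.open_cancel_scope() as scope:
+            scopes.append(scope)
+            await c.acquire()
+
+    async with _core.open_nursery() as nursery:
+        t = nursery.spawn(waiter)
+        await wait_all_tasks_blocked()
+        assert c.statistics().tasks_waiting == 1
+        scopes[0].cancel()
+        await t.wait()
+    assert c.statistics().tasks_waiting == 0
+    assert c.borrowed_tokens == 1
+    c.release()
+    assert c.borrowed_tokens == 0
+
+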
async def test_Semaphore():
    with pytest.raises(TypeError):
        Semaphore(1.0)
diff --git a/trio/tests/test_threads.py b/trio/tests/test_threads.py
index 149260fc5d..65736018c1 100644
--- a/trio/tests/test_threads.py
+++ b/trio/tests/test_threads.py
@@ -3,16 +3,17 @@
 import time
 import os
 import signal
+from functools import partial
 
 import pytest
 
 from .. import _core
-from .. import Event
+from .. import Event, CapacityLimiter, sleep
 from ..testing import wait_all_tasks_blocked
 from .._threads import *
-from .._timeouts import sleep
 from .._core.tests.test_ki import ki_self
+from .._core.tests.tutil import slow
 
 async def test_do_in_trio_thread():
     trio_thread = threading.current_thread()
@@ -248,3 +249,156 @@ async def child():
 
     assert not out and not err
 
+
+@pytest.mark.parametrize("MAX", [3, 5, 10])
+@pytest.mark.parametrize("cancel", [False, True])
+@pytest.mark.parametrize("use_default_limiter", [False, True])
+async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter):
+    # This test is a bit tricky. The goal is to make sure that if we set
+    # limiter=CapacityLimiter(MAX), then in fact only MAX threads are ever
+    # running at a time, even if there are more concurrent calls to
+    # run_in_worker_thread, and even if some of those are cancelled. And also
+    # to make sure that the default limiter actually limits.
+    COUNT = 2 * MAX
+    gate = threading.Event()
+    lock = threading.Lock()
+    if use_default_limiter:
+        c = current_default_worker_thread_limiter()
+        orig_total_tokens = c.total_tokens
+        c.total_tokens = MAX
+        limiter_arg = None
+    else:
+        c = CapacityLimiter(MAX)
+        orig_total_tokens = MAX
+        limiter_arg = c
+    try:
+        ran = 0
+        high_water = 0
+        running = 0
+        parked = 0
+
+        run_in_trio_thread = current_run_in_trio_thread()
+
+        def thread_fn(cancel_scope):
+            print("thread_fn start")
+            nonlocal ran, running, high_water, parked
+            run_in_trio_thread(cancel_scope.cancel)
+            with lock:
+                ran += 1
+                running += 1
+                high_water = max(high_water, running)
+                # The trio thread below watches this value and uses it as a
+                # signal that all the stats calculations have finished.
+                parked += 1
+            gate.wait()
+            with lock:
+                parked -= 1
+                running -= 1
+            print("thread_fn exiting")
+
+        async def run_thread():
+            with _core.open_cancel_scope() as cancel_scope:
+                await run_in_worker_thread(
+                    thread_fn, cancel_scope,
+                    limiter=limiter_arg, cancellable=cancel)
+            print("run_thread finished, cancelled:",
+                  cancel_scope.cancelled_caught)
+
+        async with _core.open_nursery() as nursery:
+            print("spawning")
+            tasks = []
+            for i in range(COUNT):
+                tasks.append(nursery.spawn(run_thread))
+            await wait_all_tasks_blocked()
+            # In the cancel case, we particularly want to make sure that the
+            # cancelled tasks don't release the limiter. So let's wait until
+            # at least one of them has exited, and that everything has had a
+            # chance to settle down from this, before we check that everyone
+            # who's supposed to be waiting is waiting:
+            if cancel:
+                print("waiting for first cancellation to clear")
+                await tasks[0].wait()
+                await wait_all_tasks_blocked()
+            # Then wait until the first MAX threads are parked in gate.wait(),
+            # and the next MAX threads are parked on the limiter, to make
+            # sure no-one is sneaking past, and to make sure the high_water
+            # check below won't fail due to scheduling issues. (It could still
+            # fail if too many threads are let through here.)
+            while parked != MAX or c.statistics().tasks_waiting != MAX:
+                await sleep(0.01)  # pragma: no cover
+            # Then release the threads
+            gate.set()
+
+        assert high_water == MAX
+
+        if cancel:
+            # Some threads might still be running; need to wait for them to
+            # finish before checking that all threads ran. We can do this
+            # using the CapacityLimiter.
+            while c.borrowed_tokens > 0:
+                await sleep(0.01)  # pragma: no cover
+
+        assert ran == COUNT
+    finally:
+        c.total_tokens = orig_total_tokens
+
+
+async def test_run_in_worker_thread_custom_limiter():
+    # Basically just checking that we only call acquire_on_behalf_of and
+    # release_on_behalf_of, since that's part of our documented API.
+    record = []
+    class CustomLimiter:
+        async def acquire_on_behalf_of(self, borrower):
+            record.append("acquire")
+            self._borrower = borrower
+
+        def release_on_behalf_of(self, borrower):
+            record.append("release")
+            assert borrower == self._borrower
+
+    await run_in_worker_thread(lambda: None, limiter=CustomLimiter())
+    assert record == ["acquire", "release"]
+
+
+async def test_run_in_worker_thread_limiter_error():
+    record = []
+
+    class BadCapacityLimiter:
+        async def acquire_on_behalf_of(self, borrower):
+            record.append("acquire")
+
+        def release_on_behalf_of(self, borrower):
+            record.append("release")
+            raise ValueError
+
+    bs = BadCapacityLimiter()
+
+    with pytest.raises(ValueError) as excinfo:
+        await run_in_worker_thread(lambda: None, limiter=bs)
+    assert excinfo.value.__context__ is None
+    assert record == ["acquire", "release"]
+    record = []
+
+    # If the original function raised an error, then the limiter's error
+    # chains with it
+    d = {}
+    with pytest.raises(ValueError) as excinfo:
+        await run_in_worker_thread(lambda: d["x"], limiter=bs)
+    assert isinstance(excinfo.value.__context__, KeyError)
+    assert record == ["acquire", "release"]
+
+
+async def test_run_in_worker_thread_fail_to_spawn(monkeypatch):
+    # Test the unlikely but possible case where trying to spawn a thread fails
+    def bad_start(self):
+        raise RuntimeError("the engines canna take it captain")
+    monkeypatch.setattr(threading.Thread, "start", bad_start)
+
+    limiter = current_default_worker_thread_limiter()
+    assert limiter.borrowed_tokens == 0
+
+    # We get an appropriate error, and the limiter is cleanly released
+    with pytest.raises(RuntimeError) as excinfo:
+        await run_in_worker_thread(lambda: None)
+    assert "engines" in str(excinfo.value)
+
+    assert limiter.borrowed_tokens == 0
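+
+
+def test_current_default_worker_thread_limiter_is_per_run():
+    # A sketch of an extra check, not part of the original change: the
+    # default limiter lives in RunLocal storage, so each call to trio's
+    # run() should get its own fresh CapacityLimiter holding DEFAULT_LIMIT
+    # (i.e. 40) tokens.
+    limiters = []
+
+    async def main():
+        limiters.append(current_default_worker_thread_limiter())
+
+    _core.run(main)
+    _core.run(main)
+    first, second = limiters
+    assert first is not second
+    assert first.total_tokens == second.total_tokens == 40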