diff --git a/newsfragments/2696.feature.rst b/newsfragments/2696.feature.rst new file mode 100644 index 000000000..560cf3b36 --- /dev/null +++ b/newsfragments/2696.feature.rst @@ -0,0 +1,4 @@ +:func:`trio.lowlevel.start_guest_run` now does a bit more setup of the guest run +before it returns to its caller, so that the caller can immediately make calls to +:func:`trio.current_time`, :func:`trio.lowlevel.spawn_system_task`, +:func:`trio.lowlevel.current_trio_token`, etc. diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 0b6d32654..50c9cf258 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -2118,6 +2118,16 @@ def start_guest_run( the host loop and then immediately starts the guest run, and then shuts down the host when the guest run completes. + Once :func:`start_guest_run` returns successfully, the guest run + has been set up enough that you can invoke sync-colored Trio + functions such as :func:`current_time`, :func:`spawn_system_task`, + and :func:`current_trio_token`. If a `TrioInternalError` occurs + during this early setup of the guest run, it will be raised out of + :func:`start_guest_run`. All other errors, including all errors + raised by the *async_fn*, will be delivered to your + *done_callback* at some point after :func:`start_guest_run` returns + successfully. + Args: run_sync_soon_threadsafe: An arbitrary callable, which will be passed a @@ -2178,6 +2188,39 @@ def my_done_callback(run_outcome): host_uses_signal_set_wakeup_fd=host_uses_signal_set_wakeup_fd, ), ) + + # Run a few ticks of the guest run synchronously, so that by the + # time we return, the system nursery exists and callers can use + # spawn_system_task. We don't actually run any user code during + # this time, so it shouldn't be possible to get an exception here, + # except for a TrioInternalError. + next_send = None + for tick in range(5): # expected need is 2 iterations + leave some wiggle room + if runner.system_nursery is not None: + # We're initialized enough to switch to async guest ticks + break + try: + timeout = guest_state.unrolled_run_gen.send(next_send) + except StopIteration: # pragma: no cover + raise TrioInternalError( + "Guest runner exited before system nursery was initialized" + ) + if timeout != 0: # pragma: no cover + guest_state.unrolled_run_gen.throw( + TrioInternalError( + "Guest runner blocked before system nursery was initialized" + ) + ) + next_send = () + else: # pragma: no cover + guest_state.unrolled_run_gen.throw( + TrioInternalError( + "Guest runner yielded too many times before " + "system nursery was initialized" + ) + ) + + guest_state.unrolled_run_next_send = Value(next_send) run_sync_soon_not_threadsafe(guest_state.guest_tick) diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py index 7b004cf04..7aef3e437 100644 --- a/trio/_core/_tests/test_guest_mode.py +++ b/trio/_core/_tests/test_guest_mode.py @@ -26,7 +26,7 @@ # our main # - final result is returned # - any unhandled exceptions cause an immediate crash -def trivial_guest_run(trio_fn, **start_guest_run_kwargs): +def trivial_guest_run(trio_fn, *, in_host_after_start=None, **start_guest_run_kwargs): todo = queue.Queue() host_thread = threading.current_thread() @@ -58,6 +58,8 @@ def done_callback(outcome): done_callback=done_callback, **start_guest_run_kwargs, ) + if in_host_after_start is not None: + in_host_after_start() try: while True: @@ -109,6 +111,48 @@ async def do_receive(): trivial_guest_run(trio_main) +def test_guest_is_initialized_when_start_returns(): + trio_token = None + record = [] + + async def trio_main(in_host): + record.append("main task ran") + await trio.sleep(0) + assert trio.lowlevel.current_trio_token() is trio_token + return "ok" + + def after_start(): + # We should get control back before the main task executes any code + assert record == [] + + nonlocal trio_token + trio_token = trio.lowlevel.current_trio_token() + trio_token.run_sync_soon(record.append, "run_sync_soon cb ran") + + @trio.lowlevel.spawn_system_task + async def early_task(): + record.append("system task ran") + await trio.sleep(0) + + res = trivial_guest_run(trio_main, in_host_after_start=after_start) + assert res == "ok" + assert set(record) == {"system task ran", "main task ran", "run_sync_soon cb ran"} + + # Errors during initialization (which can only be TrioInternalErrors) + # are raised out of start_guest_run, not out of the done_callback + with pytest.raises(trio.TrioInternalError): + class BadClock: + def start_clock(self): + raise ValueError("whoops") + + def after_start_never_runs(): # pragma: no cover + pytest.fail("shouldn't get here") + + trivial_guest_run( + trio_main, clock=BadClock(), in_host_after_start=after_start_never_runs + ) + + def test_host_can_directly_wake_trio_task(): async def trio_main(in_host): ev = trio.Event()