From 798183deaa7de3ae663314af584905a634c59e55 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 2 Aug 2024 14:59:38 +0200
Subject: [PATCH] Fix exception handling for ``WorkerPlugin.setup`` and
 ``WorkerPlugin.teardown`` (#8810)

---
 distributed/client.py                       |   3 +-
 .../diagnostics/tests/test_worker_plugin.py |  58 ++-
 distributed/shuffle/tests/test_shuffle.py   | 343 +++++++++---------
 distributed/worker.py                       |  13 +-
 4 files changed, 228 insertions(+), 189 deletions(-)

diff --git a/distributed/client.py b/distributed/client.py
index ad283a352ac..0601b0db5fb 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -5409,8 +5409,7 @@ async def _unregister_worker_plugin(self, name, nanny=None):
 
         for response in responses.values():
             if response["status"] == "error":
-                exc = response["exception"]
-                tb = response["traceback"]
+                _, exc, tb = clean_exception(**response)
                 raise exc.with_traceback(tb)
         return responses
 
diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py
index 0f206512b8a..001576afe33 100644
--- a/distributed/diagnostics/tests/test_worker_plugin.py
+++ b/distributed/diagnostics/tests/test_worker_plugin.py
@@ -1,13 +1,14 @@
 from __future__ import annotations
 
 import asyncio
+import logging
 import warnings
 
 import pytest
 
 from distributed import Worker, WorkerPlugin
 from distributed.protocol.pickle import dumps
-from distributed.utils_test import async_poll_for, gen_cluster, inc
+from distributed.utils_test import async_poll_for, captured_logger, gen_cluster, inc
 
 
 class MyPlugin(WorkerPlugin):
@@ -423,3 +424,58 @@ def setup(self, worker):
     await c.register_plugin(second, idempotent=True)
     assert "idempotentplugin" in a.plugins
     assert a.plugins["idempotentplugin"].instance == "first"
+
+
+class BrokenSetupPlugin(WorkerPlugin):
+    def setup(self, worker):
+        raise RuntimeError("test error")
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_register_plugin_with_broken_setup_to_existing_workers_raises(c, s, a):
+    with pytest.raises(RuntimeError, match="test error"):
+        with captured_logger("distributed.worker", level=logging.ERROR) as caplog:
+            await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1")
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to setup" in logs
+    assert "test error" in logs
+
+
+@gen_cluster(client=True, nthreads=[])
+async def test_plugin_with_broken_setup_on_new_worker_logs(c, s):
+    await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1")
+
+    with captured_logger("distributed.worker", level=logging.ERROR) as caplog:
+        async with Worker(s.address):
+            pass
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to setup" in logs
+    assert "test error" in logs
+
+
+class BrokenTeardownPlugin(WorkerPlugin):
+    def teardown(self, worker):
+        raise RuntimeError("test error")
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_unregister_worker_plugin_with_broken_teardown_raises(c, s, a):
+    await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1")
+    with pytest.raises(RuntimeError, match="test error"):
+        with captured_logger("distributed.worker", level=logging.ERROR) as caplog:
+            await c.unregister_worker_plugin("TestPlugin1")
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to teardown" in logs
+    assert "test error" in logs
+
+
+@gen_cluster(client=True, nthreads=[])
+async def test_plugin_with_broken_teardown_logs_on_close(c, s):
+    await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1")
+
+    with captured_logger("distributed.worker", level=logging.ERROR) as caplog:
+        async with Worker(s.address):
+            pass
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to teardown" in logs
+    assert "test error" in logs
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 39f53528fd4..443595b0bbd 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -87,29 +87,32 @@ def lose_annotations(request):
     return request.param
 
 
-async def check_worker_cleanup(
+async def assert_worker_cleanup(
     worker: Worker,
-    closed: bool = False,
+    close: bool = False,
     interval: float = 0.01,
     timeout: int | None = 5,
 ) -> None:
     """Assert that the worker has no shuffle state"""
-    deadline = Deadline.after(timeout)
     plugin = worker.plugins["shuffle"]
     assert isinstance(plugin, ShuffleWorkerPlugin)
-    while plugin.shuffle_runs._runs and not deadline.expired:
-        await asyncio.sleep(interval)
-    assert not plugin.shuffle_runs._runs
-    if closed:
+    deadline = Deadline.after(timeout)
+    if close:
+        await worker.close()
+        assert "shuffle" not in worker.plugins
         assert plugin.closed
+    else:
+        while plugin.shuffle_runs._runs and not deadline.expired:
+            await asyncio.sleep(interval)
+        assert not plugin.shuffle_runs._runs
     for dirpath, dirnames, filenames in os.walk(worker.local_directory):
         assert "shuffle" not in dirpath
         for fn in dirnames + filenames:
             assert "shuffle" not in fn
 
 
-async def check_scheduler_cleanup(
+async def assert_scheduler_cleanup(
     scheduler: Scheduler, interval: float = 0.01, timeout: int | None = 5
 ) -> None:
     """Assert that the scheduler has no shuffle state"""
@@ -175,9 +178,9 @@ async def test_basic_cudf_support(c, s, a, b):
     result, expected = await c.compute([shuffled, df], sync=True)
     dd.assert_eq(result, expected)
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 def get_active_shuffle_run(shuffle_id: ShuffleId, worker: Worker) -> ShuffleRun:
@@ -213,9 +216,9 @@ async def test_basic_integration(c, s, a, b, npartitions, disk):
     result, expected = await c.compute([shuffled, df], sync=True)
     dd.assert_eq(result, expected)
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @pytest.mark.parametrize("processes", [True, False])
@@ -260,9 +263,9 @@ async def test_shuffle_with_array_conversion(c, s, a, b, npartitions):
     else:
         await c.compute(out)
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 def test_shuffle_before_categorize(loop_in_thread):
@@ -295,9 +298,9 @@ async def test_concurrent(c, s, a, b):
     dd.assert_eq(x, df, check_index=False)
     dd.assert_eq(y, df, check_index=False)
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True)
@@ -323,9 +326,9 @@ async def test_bad_disk(c, s, a, b):
     out = await c.compute(out)
     await c.close()
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 async def wait_until_worker_has_tasks(
@@ -401,15 +404,14 @@ async def test_closed_worker_during_transfer(c, s, a, b):
         shuffled = df.shuffle("x")
     fut = c.compute([shuffled, df], sync=True)
     await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b)
-    await b.close()
+    await assert_worker_cleanup(b, close=True)
 
     result, expected = await fut
    dd.assert_eq(result, expected)
 
     await c.close()
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b, closed=True)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(
@@ -428,16 +430,15 @@ async def test_restarting_during_transfer_raises_killed_worker(c, s, a, b):
         out = df.shuffle("x")
     out = c.compute(out.x.size)
     await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b)
-    await b.close()
+    await assert_worker_cleanup(b, close=True)
 
     with pytest.raises(KilledWorker):
         await out
     assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1
 
     await c.close()
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b, closed=True)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(
@@ -491,14 +492,13 @@ async def test_restarting_does_not_log_p2p_failed(c, s, a, b):
         out = df.shuffle("x")
     out = c.compute(out.x.size)
     await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b)
-    await b.close()
+    await assert_worker_cleanup(b, close=True)
 
     await out
     assert not s.get_events("p2p")
     await c.close()
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b, closed=True)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_scheduler_cleanup(s)
 
 
 class BlockedGetOrCreateShuffleRunManager(_ShuffleRunManager):
@@ -538,7 +538,7 @@ async def test_get_or_create_from_dangling_transfer(c, s, a, b):
 
     shuffle_extB.shuffle_runs.block_get_or_create.set()
     await shuffle_extA.shuffle_runs.in_get_or_create.wait()
-    await b.close()
+    await assert_worker_cleanup(b, close=True)
     await async_poll_for(
         lambda: not any(ws.processing for ws in s.workers.values()), timeout=5
     )
@@ -552,10 +552,9 @@ async def test_get_or_create_from_dangling_transfer(c, s, a, b):
     await async_poll_for(lambda: not a.state.tasks, timeout=10)
     assert not s.plugins["shuffle"].active_shuffles
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b, closed=True)
+    await assert_worker_cleanup(a)
     await c.close()
-    await check_scheduler_cleanup(s)
+    await assert_scheduler_cleanup(s)
 
 
 @pytest.mark.slow
@@ -581,8 +580,8 @@ async def test_crashed_worker_during_transfer(c, s, a):
         dd.assert_eq(result, expected)
 
         await c.close()
-        await check_worker_cleanup(a)
-        await check_scheduler_cleanup(s)
+        await assert_worker_cleanup(a)
+        await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(
@@ -648,15 +647,14 @@ def mock_get_worker_for_range_sharding(
         shuffled = df.shuffle("x")
     fut = c.compute([shuffled, df], sync=True)
     await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b, 0.001)
-    await b.close()
+    await assert_worker_cleanup(b, close=True)
 
     result, expected = await fut
     dd.assert_eq(result, expected)
 
     await c.close()
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b, closed=True)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_scheduler_cleanup(s)
 
 
 @pytest.mark.slow
@@ -691,8 +689,8 @@ def mock_mock_get_worker_for_range_sharding(
         dd.assert_eq(result, expected)
 
         await c.close()
-        await check_worker_cleanup(a)
-        await check_scheduler_cleanup(s)
+        await assert_worker_cleanup(a)
+        await assert_scheduler_cleanup(s)
 
 
 # @pytest.mark.slow
 @gen_cluster(client=True, nthreads=[("", 1)] * 3)
@@ -714,15 +712,14 @@ async def test_closed_bystanding_worker_during_shuffle(c, s, w1, w2, w3):
     )
     await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, w1)
     await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, w2)
-    await w3.close()
+    await assert_worker_cleanup(w3, close=True)
 
     result, expected = await fut
     dd.assert_eq(result, expected)
 
-    await check_worker_cleanup(w1)
-    await check_worker_cleanup(w2)
-    await check_worker_cleanup(w3, closed=True)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(w1)
+    await assert_worker_cleanup(w2)
+    await assert_scheduler_cleanup(s)
 
 
 class RaiseOnCloseShuffleRun(DataFrameShuffleRun):
@@ -749,9 +746,8 @@ async def test_exception_on_close_cleans_up(c, s, caplog):
     with dask.config.set({"dataframe.shuffle.method": "p2p"}):
         shuffled = df.shuffle("x")
         await c.compute([shuffled, df], sync=True)
-
+    await assert_worker_cleanup(w, close=True)
     assert any("test-exception-on-close" in record.message for record in caplog.records)
-    await check_worker_cleanup(w, closed=True)
 
 
 class BlockedInputsDoneShuffle(DataFrameShuffleRun):
@@ -798,7 +794,7 @@ async def test_closed_worker_during_barrier(c, s, a, b):
     else:
         close_worker, alive_worker = b, a
         alive_shuffle = shuffleA
-    await close_worker.close()
+    await assert_worker_cleanup(close_worker, close=True)
 
     alive_shuffle.block_inputs_done.set()
     alive_shuffles = get_active_shuffle_runs(alive_worker)
@@ -820,9 +816,8 @@ def shuffle_restarted():
     dd.assert_eq(result, expected)
 
     await c.close()
-    await check_worker_cleanup(close_worker, closed=True)
-    await check_worker_cleanup(alive_worker)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(alive_worker)
+    await assert_scheduler_cleanup(s)
 
 
 @mock.patch(
@@ -861,7 +856,7 @@ async def test_restarting_during_barrier_raises_killed_worker(c, s, a, b):
     else:
         close_worker, alive_worker = b, a
         alive_shuffle = shuffleA
-    await close_worker.close()
+    await assert_worker_cleanup(close_worker, close=True)
 
     with pytest.raises(KilledWorker):
         await out
@@ -870,9 +865,8 @@ async def test_restarting_during_barrier_raises_killed_worker(c, s, a, b):
     alive_shuffle.block_inputs_done.set()
 
     await c.close()
-    await check_worker_cleanup(close_worker, closed=True)
-    await check_worker_cleanup(alive_worker)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(alive_worker)
+    await assert_scheduler_cleanup(s)
 
 
 @mock.patch(
@@ -909,7 +903,7 @@ async def test_closed_other_worker_during_barrier(c, s, a, b):
     else:
        close_worker, alive_worker = a, b
        alive_shuffle = shuffleB
-    await close_worker.close()
+    await assert_worker_cleanup(close_worker, close=True)
 
     alive_shuffle.block_inputs_done.set()
     alive_shuffles = get_active_shuffle_runs(alive_worker)
@@ -931,9 +925,8 @@ def shuffle_restarted():
     dd.assert_eq(result, expected)
 
     await c.close()
-    await check_worker_cleanup(close_worker, closed=True)
-    await check_worker_cleanup(alive_worker)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(alive_worker)
+    await assert_scheduler_cleanup(s)
 
 
 @pytest.mark.slow
@@ -981,8 +974,8 @@ def shuffle_restarted():
     dd.assert_eq(result, expected)
 
     await c.close()
-    await check_worker_cleanup(a)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True, nthreads=[("", 1)] * 2)
@@ -997,15 +990,14 @@ async def test_closed_worker_during_unpack(c, s, a, b):
         shuffled = df.shuffle("x")
     fut = c.compute([shuffled, df], sync=True)
     await wait_for_tasks_in_state(UNPACK_PREFIX, "memory", 1, b)
-    await b.close()
+    await assert_worker_cleanup(b, close=True)
 
     result, expected = await fut
     dd.assert_eq(result, expected)
 
     await c.close()
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b, closed=True)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(
@@ -1024,16 +1016,15 @@ async def test_restarting_during_unpack_raises_killed_worker(c, s, a, b):
         out = df.shuffle("x")
     out = c.compute(out.x.size)
     await wait_for_tasks_in_state(UNPACK_PREFIX, "memory", 1, b)
-    await b.close()
+    await assert_worker_cleanup(b, close=True)
 
     with pytest.raises(KilledWorker):
         await out
     assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1
 
     await c.close()
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b, closed=True)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_scheduler_cleanup(s)
 
 
 @pytest.mark.slow
@@ -1059,14 +1050,14 @@ async def test_crashed_worker_during_unpack(c, s, a):
         dd.assert_eq(result, expected)
 
         await c.close()
-        await check_worker_cleanup(a)
-        await check_scheduler_cleanup(s)
+        await assert_worker_cleanup(a)
+        await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True)
 async def test_heartbeat(c, s, a, b):
     await a.heartbeat()
-    await check_scheduler_cleanup(s)
+    await assert_scheduler_cleanup(s)
     df = dask.datasets.timeseries(
         start="2000-01-01",
         end="2000-01-10",
@@ -1084,10 +1075,10 @@ async def test_heartbeat(c, s, a, b):
     assert s.plugins["shuffle"].heartbeats.values()
     await out
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
     del out
-    await check_scheduler_cleanup(s)
+    await assert_scheduler_cleanup(s)
 
 
 @pytest.mark.skipif("not pa", reason="Requires PyArrow")
@@ -1292,10 +1283,10 @@ async def test_head(c, s, a, b):
     assert list(os.walk(a.local_directory)) == a_files  # cleaned up files?
     assert list(os.walk(b.local_directory)) == b_files
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
     del out
-    await check_scheduler_cleanup(s)
+    await assert_scheduler_cleanup(s)
 
 
 def test_split_by_worker():
@@ -1399,9 +1390,9 @@ async def test_clean_after_forgotten_early(c, s, a, b):
     await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, a)
     await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b)
     del out
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True)
@@ -1424,9 +1415,9 @@ async def test_tail(c, s, a, b):
     assert len(s.tasks) < ntasks_full
     del partial
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @pytest.mark.parametrize("wait_until_forgotten", [True, False])
@@ -1454,9 +1445,9 @@ async def test_repeat_shuffle_instance(c, s, a, b, wait_until_forgotten):
 
     await c.compute(out)
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @pytest.mark.parametrize("wait_until_forgotten", [True, False])
@@ -1485,9 +1476,9 @@ async def test_repeat_shuffle_operation(c, s, a, b, wait_until_forgotten):
     with dask.config.set({"dataframe.shuffle.method": "p2p"}):
         await c.compute(df.shuffle("x"))
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True, nthreads=[("", 1)])
@@ -1532,8 +1523,8 @@ def block(df, in_event, block_event):
         assert result == expected
 
         await c.close()
-        await check_worker_cleanup(a)
-        await check_scheduler_cleanup(s)
+        await assert_worker_cleanup(a)
+        await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True, nthreads=[("", 1)])
@@ -1561,8 +1552,8 @@ async def test_crashed_worker_after_shuffle_persisted(c, s, a):
         assert result == expected
 
         await c.close()
-        await check_worker_cleanup(a)
-        await check_scheduler_cleanup(s)
+        await assert_worker_cleanup(a)
+        await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True, nthreads=[("", 1)] * 3)
@@ -1578,25 +1569,22 @@ async def test_closed_worker_between_repeats(c, s, w1, w2, w3):
         out = df.shuffle("x")
     await c.compute(out.head(compute=False))
 
-    await check_worker_cleanup(w1)
-    await check_worker_cleanup(w2)
-    await check_worker_cleanup(w3)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(w1)
+    await assert_worker_cleanup(w2)
+    await assert_worker_cleanup(w3)
+    await assert_scheduler_cleanup(s)
 
-    await w3.close()
+    await assert_worker_cleanup(w3, close=True)
     await c.compute(out.tail(compute=False))
 
-    await check_worker_cleanup(w1)
-    await check_worker_cleanup(w2)
-    await check_worker_cleanup(w3, closed=True)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(w1)
+    await assert_worker_cleanup(w2)
+    await assert_scheduler_cleanup(s)
 
-    await w2.close()
+    await assert_worker_cleanup(w2, close=True)
     await c.compute(out.head(compute=False))
 
-    await check_worker_cleanup(w1)
-    await check_worker_cleanup(w2, closed=True)
-    await check_worker_cleanup(w3, closed=True)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(w1)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True)
@@ -1616,11 +1604,11 @@ async def test_new_worker(c, s, a, b):
     async with Worker(s.address) as w:
         await c.compute(persisted)
 
-        await check_worker_cleanup(a)
-        await check_worker_cleanup(b)
-        await check_worker_cleanup(w)
+        await assert_worker_cleanup(a)
+        await assert_worker_cleanup(b)
+        await assert_worker_cleanup(w)
         del persisted
-        await check_scheduler_cleanup(s)
+        await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True)
@@ -1644,9 +1632,9 @@ async def test_multi(c, s, a, b):
     out = await c.compute(out.size)
     assert out
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @pytest.mark.skipif(
@@ -1694,10 +1682,10 @@ async def test_delete_some_results(c, s, a, b):
     x = x.partitions[: x.npartitions // 2]
     x = await c.compute(x.size)
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
     del x
-    await check_scheduler_cleanup(s)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True)
@@ -1719,11 +1707,11 @@ async def test_add_some_results(c, s, a, b):
 
     await c.compute(x.size)
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
     del x
     del y
-    await check_scheduler_cleanup(s)
+    await assert_scheduler_cleanup(s)
 
 
 @pytest.mark.slow
@@ -1743,12 +1731,11 @@ async def test_clean_after_close(c, s, a, b):
     await wait_for_tasks_in_state("shuffle-transfer", "executing", 1, a)
     await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b)
 
-    await a.close()
-    await check_worker_cleanup(a, closed=True)
+    await assert_worker_cleanup(a, close=True)
 
     del out
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 class DataFrameShuffleTestPool(AbstractShuffleTestPool):
@@ -2115,9 +2102,9 @@ async def test_deduplicate_stale_transfer(c, s, a, b, wait_until_forgotten):
     expected = await c.compute(df)
     dd.assert_eq(result, expected)
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 class BlockedBarrierShuffleWorkerPlugin(ShuffleWorkerPlugin):
@@ -2172,9 +2159,9 @@ async def test_handle_stale_barrier(c, s, a, b, wait_until_forgotten):
     result, expected = await fut
     dd.assert_eq(result, expected)
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True, nthreads=[("", 1)])
@@ -2270,8 +2257,8 @@ async def test_shuffle_run_consistency(c, s, a):
     await out
     del out
 
-    await check_worker_cleanup(a)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True, nthreads=[("", 1)])
@@ -2317,8 +2304,8 @@ async def test_fail_fetch_race(c, s, a):
     worker_plugin.block_barrier.set()
     del out
 
-    await check_worker_cleanup(a)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_scheduler_cleanup(s)
 
 
 class BlockedShuffleAccessAndFailShuffleRunManager(_ShuffleRunManager):
@@ -2393,7 +2380,7 @@ async def test_replace_stale_shuffle(c, s, a, b):
         await asyncio.sleep(0)
 
     # A is cleaned
-    await check_worker_cleanup(a)
+    await assert_worker_cleanup(a)
     # B is not cleaned
     assert shuffle_id in get_active_shuffle_runs(b)
 
@@ -2424,9 +2411,9 @@ async def test_replace_stale_shuffle(c, s, a, b):
     await out
     del out
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True)
@@ -2444,9 +2431,9 @@ async def test_handle_null_partitions(c, s, a, b):
     result = await c.compute(ddf)
     dd.assert_eq(result, df)
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True)
@@ -2467,9 +2454,9 @@ def make_partition(i):
     expected = await expected
     dd.assert_eq(result, expected)
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True)
@@ -2496,9 +2483,9 @@ async def test_handle_object_columns(c, s, a, b):
     result = await c.compute(shuffled)
     dd.assert_eq(result, df)
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True)
@@ -2529,9 +2516,9 @@ def make_partition(i):
     await c.close()
     del out
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True)
@@ -2555,9 +2542,9 @@ def make_partition(i):
         await c.compute(out)
     await c.close()
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True)
@@ -2582,9 +2569,9 @@ async def test_handle_categorical_data(c, s, a, b):
     result, expected = await c.compute([shuffled, df], sync=True)
     dd.assert_eq(result, expected, check_categorical=False)
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(client=True)
@@ -2620,8 +2607,8 @@ async def test_set_index(c, s, *workers):
     dd.assert_eq(result, df.set_index("a"))
     await c.close()
 
-    await asyncio.gather(*[check_worker_cleanup(w) for w in workers])
-    await check_scheduler_cleanup(s)
+    await asyncio.gather(*[assert_worker_cleanup(w) for w in workers])
+    await assert_scheduler_cleanup(s)
 
 
 def test_shuffle_with_existing_index(client):
@@ -2741,9 +2728,9 @@ async def test_unpack_is_non_rootish(c, s, a, b):
     scheduler_plugin.block_barrier.set()
     result = await result
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 class FlakyConnectionPool(ConnectionPool):
@@ -2791,10 +2778,10 @@ async def test_flaky_connect_fails_without_retry(c, s, a, b):
     ):
         await c.compute(x)
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
     await c.close()
-    await check_scheduler_cleanup(s)
+    await assert_scheduler_cleanup(s)
 
 
 @gen_cluster(
@@ -2823,9 +2810,9 @@ async def test_flaky_connect_recover_with_retry(c, s, a, b):
         assert len(line) < 250
         assert not line or line.startswith("Retrying")
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 class BlockedAfterGatherDep(Worker):
@@ -2900,9 +2887,9 @@ def make_partition(partition_id, size):
     for _, group in result.groupby("b"):
         assert group["a"].is_monotonic_increasing
 
-    await check_worker_cleanup(a)
-    await check_worker_cleanup(b)
-    await check_scheduler_cleanup(s)
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
 
 
 @pytest.mark.parametrize("disk", [True, False])
diff --git a/distributed/worker.py b/distributed/worker.py
index 5778d60dac4..18ef0aca86c 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1229,7 +1229,7 @@ async def _register_with_scheduler(self) -> None:
             *(
                 self.plugin_add(name=name, plugin=plugin)
                 for name, plugin in response["worker-plugins"].items()
-            )
+            ),
         )
 
         logger.info("        Registered to: %26s", self.scheduler.address)
@@ -1560,12 +1560,7 @@ async def close(  # type: ignore
         # Cancel async instructions
         await BaseWorker.close(self, timeout=timeout)
 
-        teardowns = [
-            plugin.teardown(self)
-            for plugin in self.plugins.values()
-            if hasattr(plugin, "teardown")
-        ]
-        await asyncio.gather(*(td for td in teardowns if isawaitable(td)))
+        await asyncio.gather(*(self.plugin_remove(name) for name in self.plugins))
 
         for extension in self.extensions.values():
             if hasattr(extension, "close"):
@@ -1870,13 +1865,14 @@ async def plugin_add(
 
         self.plugins[name] = plugin
 
-        logger.info("Starting Worker plugin %s" % name)
+        logger.info("Starting Worker plugin %s", name)
         if hasattr(plugin, "setup"):
             try:
                 result = plugin.setup(worker=self)
                 if isawaitable(result):
                     result = await result
             except Exception as e:
+                logger.exception("Worker plugin %s failed to setup", name)
                 if not catch_errors:
                     raise
                 return error_message(e)
@@ -1893,6 +1889,7 @@ async def plugin_remove(self, name: str) -> ErrorMessage | OKMessage:
                 if isawaitable(result):
                     result = await result
             except Exception as e:
+                logger.exception("Worker plugin %s failed to teardown", name)
                 return error_message(e)
 
         return {"status": "OK"}
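
A minimal usage sketch of the behaviour the new tests assert (not part of the patch; the local in-process cluster, the ``main`` helper, and the plugin/registration names are illustrative assumptions): a ``WorkerPlugin`` whose ``setup`` raises now surfaces the worker-side exception on ``Client.register_plugin`` and the worker logs a "failed to setup" message. Unregistering a plugin whose ``teardown`` raises behaves symmetrically via ``Client.unregister_worker_plugin``.

# Hypothetical illustration, assuming a local in-process cluster; names mirror
# the new tests above rather than any public API.
import asyncio

from distributed import Client, WorkerPlugin


class BrokenSetupPlugin(WorkerPlugin):
    def setup(self, worker):
        raise RuntimeError("test error")


async def main():
    async with Client(processes=False, asynchronous=True) as client:
        try:
            await client.register_plugin(BrokenSetupPlugin(), name="broken-plugin")
        except RuntimeError as e:
            # The worker-side exception is re-raised here with its original
            # traceback; the worker additionally logs
            # "Worker plugin broken-plugin failed to setup".
            print(f"registration failed: {e!r}")


if __name__ == "__main__":
    asyncio.run(main())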