Skip to content

Commit

Permalink
Fix exception handling for WorkerPlugin.setup and ``WorkerPlugin.…
Browse files Browse the repository at this point in the history
…teardown`` (#8810)
  • Loading branch information
hendrikmakait committed Aug 2, 2024
1 parent 2173999 commit 798183d
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 189 deletions.
3 changes: 1 addition & 2 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5409,8 +5409,7 @@ async def _unregister_worker_plugin(self, name, nanny=None):

for response in responses.values():
if response["status"] == "error":
exc = response["exception"]
tb = response["traceback"]
_, exc, tb = clean_exception(**response)
raise exc.with_traceback(tb)
return responses

Expand Down
58 changes: 57 additions & 1 deletion distributed/diagnostics/tests/test_worker_plugin.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import asyncio
import logging
import warnings

import pytest

from distributed import Worker, WorkerPlugin
from distributed.protocol.pickle import dumps
from distributed.utils_test import async_poll_for, gen_cluster, inc
from distributed.utils_test import async_poll_for, captured_logger, gen_cluster, inc


class MyPlugin(WorkerPlugin):
Expand Down Expand Up @@ -423,3 +424,58 @@ def setup(self, worker):
await c.register_plugin(second, idempotent=True)
assert "idempotentplugin" in a.plugins
assert a.plugins["idempotentplugin"].instance == "first"


class BrokenSetupPlugin(WorkerPlugin):
def setup(self, worker):
raise RuntimeError("test error")


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_register_plugin_with_broken_setup_to_existing_workers_raises(c, s, a):
with pytest.raises(RuntimeError, match="test error"):
with captured_logger("distributed.worker", level=logging.ERROR) as caplog:
await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1")
logs = caplog.getvalue()
assert "TestPlugin1 failed to setup" in logs
assert "test error" in logs


@gen_cluster(client=True, nthreads=[])
async def test_plugin_with_broken_setup_on_new_worker_logs(c, s):
await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1")

with captured_logger("distributed.worker", level=logging.ERROR) as caplog:
async with Worker(s.address):
pass
logs = caplog.getvalue()
assert "TestPlugin1 failed to setup" in logs
assert "test error" in logs


class BrokenTeardownPlugin(WorkerPlugin):
def teardown(self, worker):
raise RuntimeError("test error")


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_unregister_worker_plugin_with_broken_teardown_raises(c, s, a):
await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1")
with pytest.raises(RuntimeError, match="test error"):
with captured_logger("distributed.worker", level=logging.ERROR) as caplog:
await c.unregister_worker_plugin("TestPlugin1")
logs = caplog.getvalue()
assert "TestPlugin1 failed to teardown" in logs
assert "test error" in logs


@gen_cluster(client=True, nthreads=[])
async def test_plugin_with_broken_teardown_logs_on_close(c, s):
await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1")

with captured_logger("distributed.worker", level=logging.ERROR) as caplog:
async with Worker(s.address):
pass
logs = caplog.getvalue()
assert "TestPlugin1 failed to teardown" in logs
assert "test error" in logs
Loading

0 comments on commit 798183d

Please sign in to comment.