From fea5515030e3b79475e3555fec84e309177132f8 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 2 Aug 2024 15:02:45 +0200
Subject: [PATCH] Fix exception handling for ``NannyPlugin.setup`` and ``NannyPlugin.teardown`` (#8811)

---
Notes (this section is ignored by ``git am``):
- ``plugin_add`` / ``plugin_remove`` now call ``logger.exception`` before
  returning ``error_message(e)``, so the failure is logged on the nanny.
- ``close`` awaits ``plugin_remove`` for every registered plugin name instead
  of calling ``plugin.teardown`` directly.
- ``plugin_add`` passes ``name`` to ``logger.info`` as a lazy %-argument.

 .../diagnostics/tests/test_nanny_plugin.py    | 59 ++++++++++++++++++-
 distributed/nanny.py                          | 12 ++--
 2 files changed, 62 insertions(+), 9 deletions(-)

diff --git a/distributed/diagnostics/tests/test_nanny_plugin.py b/distributed/diagnostics/tests/test_nanny_plugin.py
index db17fe70b5b..3c481dce26b 100644
--- a/distributed/diagnostics/tests/test_nanny_plugin.py
+++ b/distributed/diagnostics/tests/test_nanny_plugin.py
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
+import logging
+
 import pytest
 
 from distributed import Nanny, NannyPlugin
 from distributed.protocol.pickle import dumps
-from distributed.utils_test import gen_cluster
+from distributed.utils_test import captured_logger, gen_cluster
 
 
 @gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
@@ -160,3 +162,58 @@ def setup(self, nanny):
     await c.register_plugin(second, idempotent=True)
     assert "idempotentplugin" in a.plugins
     assert a.plugins["idempotentplugin"].instance == "first"
+
+
+class BrokenSetupPlugin(NannyPlugin):
+    def setup(self, nanny):
+        raise RuntimeError("test error")
+
+
+@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
+async def test_register_plugin_with_broken_setup_to_existing_nannies_raises(c, s, a):
+    with pytest.raises(RuntimeError, match="test error"):
+        with captured_logger("distributed.nanny", level=logging.ERROR) as caplog:
+            await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1")
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to setup" in logs
+    assert "test error" in logs
+
+
+@gen_cluster(client=True, nthreads=[])
+async def test_plugin_with_broken_setup_on_new_nanny_logs(c, s):
+    await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1")
+
+    with captured_logger("distributed.nanny", level=logging.ERROR) as caplog:
+        async with Nanny(s.address):
+            pass
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to setup" in logs
+    assert "test error" in logs
+
+
+class BrokenTeardownPlugin(NannyPlugin):
+    def teardown(self, nanny):
+        raise RuntimeError("test error")
+
+
+@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
+async def test_unregister_nanny_plugin_with_broken_teardown_raises(c, s, a):
+    await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1")
+    with pytest.raises(RuntimeError, match="test error"):
+        with captured_logger("distributed.nanny", level=logging.ERROR) as caplog:
+            await c.unregister_worker_plugin("TestPlugin1", nanny=True)
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to teardown" in logs
+    assert "test error" in logs
+
+
+@gen_cluster(client=True, nthreads=[])
+async def test_nanny_plugin_with_broken_teardown_logs_on_close(c, s):
+    await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1")
+
+    with captured_logger("distributed.nanny", level=logging.ERROR) as caplog:
+        async with Nanny(s.address):
+            pass
+    logs = caplog.getvalue()
+    assert "TestPlugin1 failed to teardown" in logs
+    assert "test error" in logs
diff --git a/distributed/nanny.py b/distributed/nanny.py
index af0d9a62ad5..7a14ee65760 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -477,13 +477,14 @@ async def plugin_add(
 
         self.plugins[name] = plugin
 
-        logger.info("Starting Nanny plugin %s" % name)
+        logger.info("Starting Nanny plugin %s", name)
         if hasattr(plugin, "setup"):
             try:
                 result = plugin.setup(nanny=self)
                 if isawaitable(result):
                     result = await result
             except Exception as e:
+                logger.exception("Nanny plugin %s failed to setup", name)
                 return error_message(e)
         if getattr(plugin, "restart", False):
             await self.restart(reason=f"nanny-plugin-{name}-restart")
@@ -500,6 +501,7 @@ async def plugin_remove(self, name: str) -> ErrorMessage | OKMessage:
                 if isawaitable(result):
                     result = await result
         except Exception as e:
+            logger.exception("Nanny plugin %s failed to teardown", name)
             msg = error_message(e)
             return msg
 
@@ -610,13 +612,7 @@ async def close( # type:ignore[override]
 
         await self.preloads.teardown()
 
-        teardowns = [
-            plugin.teardown(self)
-            for plugin in self.plugins.values()
-            if hasattr(plugin, "teardown")
-        ]
-
-        await asyncio.gather(*(td for td in teardowns if isawaitable(td)))
+        await asyncio.gather(*(self.plugin_remove(name) for name in self.plugins))
 
         self.stop()
         if self.process is not None: