Ensure that adaptive only stops once #8807

Merged · 5 commits · Jul 31, 2024
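Summary of the change, with an illustrative sketch. The adaptive machinery now tracks an explicit lifecycle state ("starting", "running", "stopped", "inactive") so that stop() is idempotent: the "Adaptive scaling stopped" message is logged at most once, and repeated stop() calls become no-ops. The snippet below is a minimal standalone sketch of that guard pattern, using hypothetical names rather than the actual distributed API:

import logging
from typing import Literal

logger = logging.getLogger(__name__)

# Hypothetical, simplified lifecycle guard mirroring the pattern in this PR.
State = Literal["starting", "running", "stopped", "inactive"]


class PeriodicTask:
    def __init__(self, periodic: bool = True) -> None:
        # "starting" until the event loop actually starts the callback;
        # "inactive" if no periodic behaviour was requested at all.
        self._state: State = "starting" if periodic else "inactive"

    def _start(self) -> None:
        if self._state != "starting":
            return
        self._state = "running"
        logger.info("scaling started")

    def stop(self) -> None:
        if self._state in ("inactive", "stopped"):
            return  # second and later calls: no work, no duplicate log line
        if self._state == "running":
            logger.info("scaling stopped")
        self._state = "stopped"


task = PeriodicTask()
task._start()
task.stop()
task.stop()  # no-op; "scaling stopped" was logged exactly once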
distributed/deploy/adaptive.py (2 changes: 0 additions & 2 deletions)
@@ -109,8 +109,6 @@ def __init__(

         self.target_duration = parse_timedelta(target_duration)
 
-        logger.info("Adaptive scaling started: minimum=%s maximum=%s", minimum, maximum)
-
         super().__init__(
             minimum=minimum, maximum=maximum, wait_count=wait_count, interval=interval
         )
distributed/deploy/adaptive_core.py (55 changes: 42 additions & 13 deletions)
@@ -5,7 +5,7 @@
 from collections import defaultdict, deque
 from collections.abc import Iterable
 from datetime import timedelta
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Literal, cast
 
 import tlz as toolz
 from tornado.ioloop import IOLoop
@@ -17,12 +17,21 @@
 from distributed.metrics import time
 
 if TYPE_CHECKING:
-    from distributed.scheduler import WorkerState
+    from typing_extensions import TypeAlias
+
+    from distributed.scheduler import WorkerState
 
 logger = logging.getLogger(__name__)
 
 
+AdaptiveStateState: TypeAlias = Literal[
+    "starting",
+    "running",
+    "stopped",
+    "inactive",
+]
+
+
 class AdaptiveCore:
     """
     The core logic for adaptive deployments, with none of the cluster details
@@ -89,6 +98,8 @@ class AdaptiveCore:
     observed: set[WorkerState]
     close_counts: defaultdict[WorkerState, int]
     _adapting: bool
+    #: Whether this adaptive strategy is periodically adapting
+    _state: AdaptiveStateState
     log: deque[tuple[float, dict]]
 
     def __init__(
@@ -107,12 +118,6 @@ def __init__(
         self.interval = parse_timedelta(interval, "seconds")
         self.periodic_callback = None
 
-        def f():
-            try:
-                self.periodic_callback.start()
-            except AttributeError:
-                pass
-
         if self.interval:
             import weakref

@@ -124,8 +129,10 @@ async def _adapt():
                 await core.adapt()
 
             self.periodic_callback = PeriodicCallback(_adapt, self.interval * 1000)
-            self.loop.add_callback(f)
-
+            self.loop.add_callback(self._start)
+            self._state = "starting"
+        else:
+            self._state = "inactive"
         try:
             self.plan = set()
             self.requested = set()
@@ -140,12 +147,34 @@ async def _adapt():
                 maxlen=dask.config.get("distributed.admin.low-level-log-length")
             )
 
+    def _start(self) -> None:
+        if self._state != "starting":
+            return
+
+        assert self.periodic_callback is not None
+        self.periodic_callback.start()
+        self._state = "running"
+        logger.info(
+            "Adaptive scaling started: minimum=%s maximum=%s",
+            self.minimum,
+            self.maximum,
+        )
+
     def stop(self) -> None:
-        logger.info("Adaptive stop")
+        if self._state in ("inactive", "stopped"):
+            return
 
-        if self.periodic_callback:
+        if self._state == "running":
+            assert self.periodic_callback is not None
             self.periodic_callback.stop()
-            self.periodic_callback = None
+            logger.info(
+                "Adaptive scaling stopped: minimum=%s maximum=%s",
+                self.minimum,
+                self.maximum,
+            )
+
+        self.periodic_callback = None
+        self._state = "stopped"
 
     async def target(self) -> int:
         """The target number of workers that should exist"""
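Aside on the AdaptiveStateState alias above: a Literal-based TypeAlias lets a static type checker reject assignments outside the allowed state set. A minimal, hypothetical illustration (not code from this PR):

from typing import Literal

from typing_extensions import TypeAlias

LifecycleState: TypeAlias = Literal["starting", "running", "stopped", "inactive"]


class Example:
    _state: LifecycleState

    def __init__(self) -> None:
        self._state = "inactive"


e = Example()
e._state = "running"  # fine: a member of the Literal
e._state = "paused"   # runs, but mypy reports an incompatible-assignment error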
distributed/deploy/tests/test_adaptive_core.py (19 changes: 16 additions & 3 deletions)
@@ -104,11 +104,12 @@ def safe_target(self):
             raise OSError()
 
     with captured_logger("distributed.deploy.adaptive_core") as log:
-        adapt = BadAdaptive(minimum=1, maximum=4)
-        await adapt.adapt()
+        adapt = BadAdaptive(minimum=1, maximum=4, interval="10ms")
+        while adapt._state != "stopped":
+            await asyncio.sleep(0.01)
         text = log.getvalue()
         assert "Adaptive stopping due to error" in text
-        assert "Adaptive stop" in text
+        assert "Adaptive scaling stopped" in text
         assert not adapt._adapting
         assert not adapt.periodic_callback

@@ -147,6 +148,18 @@ async def scale_down(self, workers=None):
     adapt.stop()
 
 
+@gen_test()
+async def test_adaptive_logs_stopping_once():
+    with captured_logger("distributed.deploy.adaptive_core") as log:
+        adapt = MyAdaptive(interval="100ms")
+        while not adapt.periodic_callback.is_running():
+            await asyncio.sleep(0.01)
+        adapt.stop()
+        adapt.stop()
+        lines = log.getvalue().splitlines()
+        assert sum("Adaptive scaling stopped" in line for line in lines) == 1
+
+
 @gen_test()
 async def test_adapt_stop_del():
     adapt = MyAdaptive(interval="100ms")