From 6ea9cbc5fba58a2f397e38c6f8219f6e547fc633 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 17 Jul 2023 12:39:05 +0100 Subject: [PATCH 01/21] Log stimulus_id in retire_worker (#8003) --- distributed/scheduler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b72712c6c4..8cbfaa1363 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4985,7 +4985,7 @@ async def remove_worker( host = get_address_host(address) - ws: WorkerState = self.workers[address] + ws = self.workers[address] event_msg = { "action": "remove-worker", @@ -4995,7 +4995,7 @@ async def remove_worker( event_msg["worker"] = address self.log_event("all", event_msg) - logger.info("Remove worker %s", ws) + logger.info(f"Remove worker {ws} ({stimulus_id=})") if close: with suppress(AttributeError, CommClosedError): self.stream_comms[address].send( @@ -5004,7 +5004,7 @@ async def remove_worker( self.remove_resources(address) - dh: dict = self.host_info[host] + dh = self.host_info[host] dh_addresses: set = dh["addresses"] dh_addresses.remove(address) dh["nthreads"] -= ws.nthreads @@ -5025,7 +5025,6 @@ async def remove_worker( recommendations: Recs = {} - ts: TaskState for ts in list(ws.processing): k = ts.key recommendations[k] = "released" From 9d516dad15c13c4f063eeaa1feff305cecedf7fc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 17 Jul 2023 19:21:00 -0500 Subject: [PATCH 02/21] Bump JamesIves/github-pages-deploy-action from 4.4.2 to 4.4.3 (#8008) --- .github/workflows/test-report.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-report.yaml b/.github/workflows/test-report.yaml index 3de10b867b..afe57d4b48 100644 --- a/.github/workflows/test-report.yaml +++ b/.github/workflows/test-report.yaml @@ -56,7 +56,7 @@ jobs: mv test_report.html test_short_report.html deploy/ - name: Deploy 🚀 - uses: JamesIves/github-pages-deploy-action@v4.4.2 + uses: JamesIves/github-pages-deploy-action@v4.4.3 with: branch: gh-pages folder: deploy From 2be7f35ee9fa47fde3341bd261f50f823134c296 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 18 Jul 2023 10:09:38 +0100 Subject: [PATCH 03/21] Add some top-level exposition to the p2p rechunking code (#7978) --- distributed/shuffle/_rechunk.py | 96 +++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 34be8103c7..884eb924fd 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -1,3 +1,99 @@ +""" +Utilities for rechunking arrays through p2p shuffles +==================================================== + +Tensors (or n-D arrays) in dask are split up across the workers as +regular n-D "chunks" or bricks. These bricks are stacked up to form +the global array. + +A key algorithm for these tensors is to "rechunk" them. That is to +reassemble the same global representation using differently shaped n-D +bricks. + +For example, to take an FFT of an n-D array, one uses a sequence of 1D +FFTs along each axis. The implementation in dask (and indeed almost +all distributed array frameworks) requires that 1D +axis along which the FFT is taken is local to a single brick. So to +perform the global FFT we need to arrange that each axis in turn is +local to bricks. 
+ +This can be achieved through all-to-all communication between the +workers to exchange sub-pieces of their individual bricks, given a +"rechunking" scheme. + +To perform the redistribution, each input brick is cut up into some +number of smaller pieces, each of which contributes to one of the +output bricks. The mapping from input brick to output bricks +decomposes into the Cartesian product of axis-by-axis mappings. To +see this, consider first a 1D example. + +Suppose our array is split up into three equally sized bricks:: + + |----0----|----1----|----2----| + +And the requested output chunks are:: + + |--A--|--B--|----C----|---D---| + +So brick 0 contributes to output bricks A and B; brick 1 contributes +to B and C; and brick 2 contributes to C and D. + +Now consider a 2D example of the same problem:: + + +----0----+----1----+----2----+ + | | | | + α | | | + | | | | + +---------+---------+---------+ + | | | | + β | | | + | | | | + +---------+---------+---------+ + | | | | + γ | | | + | | | | + +---------+---------+---------+ + +Each brick can be described as the ordered pair of row and column +1D bricks, (0, α), (0, β), ..., (2, γ). Since the rechunking does +not also reshape the array, axes do not "interfere" with one another +when determining output bricks:: + + +--A--+--B--+----C----+---D---+ + | | | | | + Σ | | | | + | | | | | + +-----+ ----+---------+-------+ + | | | | | + | | | | | + | | | | | + Π | | | | + | | | | | + | | | | | + | | | | | + +-----+-----+---------+-------+ + +Consider the output (B, Σ) brick. This is contributed to by the +input (0, α) and (1, α) bricks. Determination of the subslices is +just done by slicing the the axes separately and combining them. + +The key thing to note here is that we never need to create, and +store, the dense 2D mapping, we can instead construct it on the fly +for each output brick in turn as necessary. + +The implementation here uses :func:`split_axes` to construct these +1D rechunkings. The output partitioning in +:meth:`~.ArrayRechunkRun.add_partition` then lazily constructs the +subsection of the Cartesian product it needs to determine the slices +of the current input brick. + +This approach relies on the generic p2p buffering machinery to +ensure that there are not too many small messages exchanged, since +no special effort is made to minimise messages between workers when +a worker might have two adjacent input bricks that are sliced into +the same output brick. 
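+
+For illustration only, the per-axis mapping sketched above can be
+computed from the chunk boundaries alone. The snippet below is a
+rough, self-contained sketch (``axis_overlaps`` is a hypothetical
+helper used for exposition, not the actual :func:`split_axes`
+implementation)::
+
+    import itertools
+
+    def axis_overlaps(old, new):
+        # Cumulative chunk boundaries, e.g. (3, 3, 3) -> [0, 3, 6, 9].
+        old_bounds = [0, *itertools.accumulate(old)]
+        new_bounds = [0, *itertools.accumulate(new)]
+        out = []
+        for i, (ostart, ostop) in enumerate(zip(old_bounds, old_bounds[1:])):
+            for j, (nstart, nstop) in enumerate(zip(new_bounds, new_bounds[1:])):
+                lo, hi = max(ostart, nstart), min(ostop, nstop)
+                if lo < hi:
+                    # Input brick i contributes this slice of itself
+                    # to output brick j.
+                    out.append((i, j, slice(lo - ostart, hi - ostart)))
+        return out
+
+    # With old=(3, 3, 3) and new=(2, 2, 3, 2), brick 0 feeds outputs
+    # A and B, brick 1 feeds B and C, and brick 2 feeds C and D,
+    # matching the 1D example above.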
+""" + from __future__ import annotations from typing import TYPE_CHECKING, NamedTuple From b7e5f8f97ef0eb95368122f29ac8915ae692f94e Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 18 Jul 2023 15:15:55 +0200 Subject: [PATCH 04/21] Fix shuffle code to work with pyarrow 13 (#8009) --- distributed/shuffle/_worker_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index 097aa2237e..c634867349 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -951,7 +951,7 @@ def split_by_worker( # bytestream such that it cannot be deserialized anymore t = pa.Table.from_pandas(df, preserve_index=True) t = t.sort_by("_worker") - codes = np.asarray(t.select(["_worker"]))[0] + codes = np.asarray(t["_worker"]) t = t.drop(["_worker"]) del df @@ -983,7 +983,7 @@ def split_by_partition(t: pa.Table, column: str) -> dict[Any, pa.Table]: partitions.sort() t = t.sort_by(column) - partition = np.asarray(t.select([column]))[0] + partition = np.asarray(t[column]) splits = np.where(partition[1:] != partition[:-1])[0] + 1 splits = np.concatenate([[0], splits]) From d9a3457dc4b6019e428964e58c551cf0b3c9786f Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 19 Jul 2023 12:15:16 +0200 Subject: [PATCH 05/21] Pass `stimulus_id` to `SchedulerPlugin.remove_worker` and `SchedulerPlugin.transition` (#7974) Co-authored-by: crusaderky --- distributed/diagnostics/plugin.py | 5 +- distributed/diagnostics/progress.py | 8 +- .../tests/test_scheduler_plugin.py | 80 +++++++++++++++++-- .../diagnostics/tests/test_task_stream.py | 10 ++- distributed/diagnostics/websocket.py | 2 + distributed/scheduler.py | 29 ++++++- distributed/shuffle/_scheduler_plugin.py | 5 +- distributed/stealing.py | 16 ++-- 8 files changed, 132 insertions(+), 23 deletions(-) diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 93fbf2a73e..02ae91a2ae 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -123,6 +123,7 @@ def transition( start: TaskStateState, finish: TaskStateState, *args: Any, + stimulus_id: str, **kwargs: Any, ) -> None: """Run whenever a task changes state @@ -143,6 +144,8 @@ def transition( One of released, waiting, processing, memory, error. finish : string Final state of the transition. + stimulus_id: string + ID of stimulus causing the transition. *args, **kwargs : More options passed when transitioning This may include worker ID, compute time, etc. 
@@ -164,7 +167,7 @@ def add_worker(self, scheduler: Scheduler, worker: str) -> None | Awaitable[None """ def remove_worker( - self, scheduler: Scheduler, worker: str + self, scheduler: Scheduler, worker: str, *, stimulus_id: str, **kwargs: Any ) -> None | Awaitable[None]: """Run when a worker leaves the cluster diff --git a/distributed/diagnostics/progress.py b/distributed/diagnostics/progress.py index 21532fcf07..11520d6a0c 100644 --- a/distributed/diagnostics/progress.py +++ b/distributed/diagnostics/progress.py @@ -104,7 +104,9 @@ async def setup(self): logger.debug("Set up Progress keys") for k in errors: - self.transition(k, None, "erred", exception=True) + self.transition( + k, None, "erred", stimulus_id="progress-setup", exception=True + ) def transition(self, key, start, finish, *args, **kwargs): if key in self.keys and start == "processing" and finish == "memory": @@ -240,7 +242,9 @@ def group_key(k): self.keys[k] = set() for k in errors: - self.transition(k, None, "erred", exception=True) + self.transition( + k, None, "erred", stimulus_id="multiprogress-setup", exception=True + ) logger.debug("Set up Progress keys") def transition(self, key, start, finish, *args, **kwargs): diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index 92f72310da..7ab9e82514 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -16,7 +16,8 @@ def start(self, scheduler): scheduler.add_plugin(self, name="counter") self.count = 0 - def transition(self, key, start, finish, *args, **kwargs): + def transition(self, key, start, finish, *args, stimulus_id, **kwargs): + assert stimulus_id is not None if start == "processing" and finish == "memory": self.count += 1 @@ -51,8 +52,9 @@ def add_worker(self, worker, scheduler): assert scheduler is s events.append(("add_worker", worker)) - def remove_worker(self, worker, scheduler): + def remove_worker(self, worker, scheduler, *, stimulus_id, **kwargs): assert scheduler is s + assert stimulus_id is not None events.append(("remove_worker", worker)) plugin = MyPlugin() @@ -80,6 +82,70 @@ def remove_worker(self, worker, scheduler): assert events == [] +@gen_cluster(nthreads=[]) +async def test_remove_worker_renamed_kwargs_allowed(s): + events = [] + + class MyPlugin(SchedulerPlugin): + name = "MyPlugin" + + def remove_worker(self, worker, scheduler, **kwds): + assert scheduler is s + events.append(("remove_worker", worker)) + + plugin = MyPlugin() + s.add_plugin(plugin) + assert events == [] + + a = Worker(s.address) + await a + await a.close() + + assert events == [ + ("remove_worker", a.address), + ] + + events[:] = [] + s.remove_plugin(plugin.name) + async with Worker(s.address): + pass + assert events == [] + + +@gen_cluster(nthreads=[]) +async def test_remove_worker_without_kwargs_deprecated(s): + events = [] + + class DeprecatedPlugin(SchedulerPlugin): + name = "DeprecatedPlugin" + + def remove_worker(self, worker, scheduler): + assert scheduler is s + events.append(("remove_worker", worker)) + + plugin = DeprecatedPlugin() + with pytest.warns( + FutureWarning, + match="The signature of `SchedulerPlugin.remove_worker` now requires `\\*\\*kwargs`", + ): + s.add_plugin(plugin) + assert events == [] + + a = Worker(s.address) + await a + await a.close() + + assert events == [ + ("remove_worker", a.address), + ] + + events[:] = [] + s.remove_plugin(plugin.name) + async with Worker(s.address): + pass + assert events == [] 
+ + @gen_cluster(nthreads=[]) async def test_async_add_remove_worker(s): events = [] @@ -91,7 +157,7 @@ async def add_worker(self, worker, scheduler): assert scheduler is s events.append(("add_worker", worker)) - async def remove_worker(self, worker, scheduler): + async def remove_worker(self, worker, scheduler, **kwargs): assert scheduler is s events.append(("remove_worker", worker)) @@ -135,7 +201,7 @@ async def add_worker(self, scheduler, worker): await asyncio.sleep(0) events.append((self.name, "add_worker", worker)) - async def remove_worker(self, scheduler, worker): + async def remove_worker(self, scheduler, worker, **kwargs): assert scheduler is s self.in_remove_worker.set() await self.block_remove_worker.wait() @@ -149,7 +215,7 @@ def add_worker(self, worker, scheduler): assert scheduler is s events.append((self.name, "add_worker", worker)) - def remove_worker(self, worker, scheduler): + def remove_worker(self, worker, scheduler, **kwargs): assert scheduler is s events.append((self.name, "remove_worker", worker)) @@ -229,7 +295,7 @@ async def add_worker(self, scheduler, worker): await asyncio.sleep(0) raise RuntimeError("Async add_worker failed") - async def remove_worker(self, scheduler, worker): + async def remove_worker(self, scheduler, worker, **kwargs): assert scheduler is s await asyncio.sleep(0) raise RuntimeError("Async remove_worker failed") @@ -257,7 +323,7 @@ def add_worker(self, scheduler, worker): assert scheduler is s raise RuntimeError("Async add_worker failed") - def remove_worker(self, scheduler, worker): + def remove_worker(self, scheduler, worker, **kwargs): assert scheduler is s raise RuntimeError("Async remove_worker failed") diff --git a/distributed/diagnostics/tests/test_task_stream.py b/distributed/diagnostics/tests/test_task_stream.py index 3c152a14d6..0e0206952e 100644 --- a/distributed/diagnostics/tests/test_task_stream.py +++ b/distributed/diagnostics/tests/test_task_stream.py @@ -94,16 +94,20 @@ async def test_no_startstops(c, s, a, b): await wait(future) assert len(tasks.buffer) == 1 - tasks.transition(future.key, "processing", "erred") + tasks.transition(future.key, "processing", "erred", stimulus_id="s1") # Transition was not recorded because it didn't contain `startstops` assert len(tasks.buffer) == 1 - tasks.transition(future.key, "processing", "erred", startstops=[]) + tasks.transition(future.key, "processing", "erred", stimulus_id="s2", startstops=[]) # Transition was not recorded because `startstops` was empty assert len(tasks.buffer) == 1 tasks.transition( - future.key, "processing", "erred", startstops=[dict(start=time(), stop=time())] + future.key, + "processing", + "erred", + stimulus_id="s3", + startstops=[dict(start=time(), stop=time())], ) assert len(tasks.buffer) == 2 diff --git a/distributed/diagnostics/websocket.py b/distributed/diagnostics/websocket.py index 5e77e1aa24..4df5de6d66 100644 --- a/distributed/diagnostics/websocket.py +++ b/distributed/diagnostics/websocket.py @@ -48,6 +48,8 @@ def transition(self, key, start, finish, *args, **kwargs): One of released, waiting, processing, memory, error. finish : string Final state of the transition. + stimulus_id: string + ID of stimulus causing the transition. *args, **kwargs : More options passed when transitioning This may include worker ID, compute time, etc. 
""" diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8cbfaa1363..109a9e6729 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1978,7 +1978,9 @@ def _transition( self.tasks[ts.key] = ts for plugin in list(self.plugins.values()): try: - plugin.transition(key, start, actual_finish, **kwargs) + plugin.transition( + key, start, actual_finish, stimulus_id=stimulus_id, **kwargs + ) except Exception: logger.info("Plugin failed with exception", exc_info=True) if ts.state == "forgotten": @@ -5069,7 +5071,19 @@ async def remove_worker( awaitables = [] for plugin in list(self.plugins.values()): try: - result = plugin.remove_worker(scheduler=self, worker=address) + try: + result = plugin.remove_worker( + scheduler=self, worker=address, stimulus_id=stimulus_id + ) + except TypeError: + parameters = inspect.signature(plugin.remove_worker).parameters + if "stimulus_id" not in parameters and not any( + p.kind is p.VAR_KEYWORD for p in parameters.values() + ): + # Deprecated (see add_plugin) + result = plugin.remove_worker(scheduler=self, worker=address) # type: ignore + else: + raise if inspect.isawaitable(result): awaitables.append(result) except Exception as e: @@ -5724,6 +5738,15 @@ def add_plugin( category=UserWarning, ) + parameters = inspect.signature(plugin.remove_worker).parameters + if not any(p.kind is p.VAR_KEYWORD for p in parameters.values()): + warnings.warn( + "The signature of `SchedulerPlugin.remove_worker` now requires `**kwargs` " + "to ensure that plugins remain forward-compatible. Not including " + "`**kwargs` in the signature will no longer be supported in future versions.", + FutureWarning, + ) + self.plugins[name] = plugin def remove_plugin( @@ -8420,7 +8443,7 @@ def add_worker(self, scheduler: Scheduler, worker: str) -> None: except CommClosedError: scheduler.remove_plugin(name=self.name) - def remove_worker(self, scheduler: Scheduler, worker: str) -> None: + def remove_worker(self, scheduler: Scheduler, worker: str, **kwargs: Any) -> None: try: self.bcomm.send(["remove", worker]) except CommClosedError: diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index 8802917342..911f4f89b9 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -309,7 +309,9 @@ def _unset_restriction(self, ts: TaskState) -> None: original_restrictions = ts.annotations.pop("shuffle_original_restrictions") self.scheduler.set_restrictions({ts.key: original_restrictions}) - def remove_worker(self, scheduler: Scheduler, worker: str) -> None: + def remove_worker( + self, scheduler: Scheduler, worker: str, *, stimulus_id: str, **kwargs: Any + ) -> None: from time import time stimulus_id = f"shuffle-failed-worker-left-{time()}" @@ -342,6 +344,7 @@ def transition( start: TaskStateState, finish: TaskStateState, *args: Any, + stimulus_id: str, **kwargs: Any, ) -> None: if finish not in ("released", "forgotten"): diff --git a/distributed/stealing.py b/distributed/stealing.py index afc76ef515..45b26101a7 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -21,7 +21,13 @@ if TYPE_CHECKING: # Recursive imports - from distributed.scheduler import Scheduler, SchedulerState, TaskState, WorkerState + from distributed.scheduler import ( + Scheduler, + SchedulerState, + TaskState, + TaskStateState, + WorkerState, + ) # Stealing requires multiple network bounces and if successful also task # submission which may include code serialization. 
Therefore, be very @@ -155,7 +161,7 @@ def log(self, msg: Any) -> None: def add_worker(self, scheduler: Any = None, worker: Any = None) -> None: self.stealable[worker] = tuple(set() for _ in range(15)) - def remove_worker(self, scheduler: Scheduler, worker: str) -> None: + def remove_worker(self, scheduler: Scheduler, worker: str, **kwargs: Any) -> None: del self.stealable[worker] def teardown(self) -> None: @@ -167,10 +173,8 @@ def teardown(self) -> None: def transition( self, key: str, - start: str, - finish: str, - compute_start: Any = None, - compute_stop: Any = None, + start: TaskStateState, + finish: TaskStateState, *args: Any, **kwargs: Any, ) -> None: From 19c9f4b8d1c10ef917d60aec4357d846d3241f7e Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 20 Jul 2023 16:31:19 +0100 Subject: [PATCH 06/21] gather_dep should handle CancelledError (#8013) --- distributed/tests/test_worker.py | 72 ++++++++++++++++++++++++++++++++ distributed/worker.py | 7 +++- 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index cf902c0b41..29b789618c 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3351,6 +3351,78 @@ async def test_gather_dep_no_longer_in_flight_tasks(c, s, a): assert not any("missing-dep" in msg for msg in f2_story) +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_gather_dep_cancelled_error(c, s, a): + """Something somewhere in the networking stack raises CancelledError while + gather_dep is running + + See Also + -------- + test_get_data_cancelled_error + https://github.com/dask/distributed/issues/8006 + """ + async with BlockedGetData(s.address) as b: + x = c.submit(inc, 1, key="x", workers=[b.address]) + y = c.submit(inc, x, key="y", workers=[a.address]) + await b.in_get_data.wait() + tasks = { + task for task in asyncio.all_tasks() if "gather_dep" in task.get_name() + } + assert tasks + # There should be only one task but cope with finding more just in case a + # previous test didn't properly clean up + for task in tasks: + task.cancel() + + b.block_get_data.set() + assert await y == 3 + + assert_story( + a.state.story("x"), + [ + ("x", "fetch", "flight", "flight", {}), + ("x", "flight", "missing", "missing", {}), + ("x", "missing", "fetch", "fetch", {}), + ("x", "fetch", "flight", "flight", {}), + ("x", "flight", "memory", "memory", {"y": "ready"}), + ], + ) + + +@gen_cluster(client=True, nthreads=[("", 1)], timeout=5) +async def test_get_data_cancelled_error(c, s, a): + """Something somewhere in the networking stack raises CancelledError while + get_data is running + + See Also + -------- + test_gather_dep_cancelled_error + https://github.com/dask/distributed/issues/8006 + """ + + class FlakyInboundRPC(Worker): + flake = 0 + + def handle_comm(self, comm): + if self.flake: + self.flake -= 1 + + async def write(*args, **kwargs): + raise asyncio.CancelledError() + + comm.write = write + + return super().handle_comm(comm) + + async with FlakyInboundRPC(s.address) as b: + x = c.submit(inc, 1, key="x", workers=[b.address]) + await wait(x) + b.flake = 2 + y = c.submit(inc, x, key="y", workers=[a.address]) + assert await y == 3 + assert b.flake == 0 + + @gen_cluster(client=True, nthreads=[("", 1)]) async def test_Worker__to_dict(c, s, a): x = c.submit(inc, 1, key="x") diff --git a/distributed/worker.py b/distributed/worker.py index fd4941a22b..fe52f7960b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2081,7 +2081,10 @@ 
async def gather_dep( stimulus_id=f"gather-dep-success-{time()}", ) - except OSError: + # Note: CancelledError and asyncio.TimeoutError are rare conditions + # that can be raised by the network stack. + # See https://github.com/dask/distributed/issues/8006 + except (OSError, asyncio.CancelledError, asyncio.TimeoutError): logger.exception("Worker stream died during communication: %s", worker) self.state.log.append( ("gather-dep-failed", worker, to_gather, stimulus_id, time()) @@ -2094,6 +2097,8 @@ async def gather_dep( except Exception as e: # e.g. data failed to deserialize + # FIXME this will deadlock the cluster + # https://github.com/dask/distributed/issues/6705 logger.exception(e) self.state.log.append( ("gather-dep-failed", worker, to_gather, stimulus_id, time()) From fc41f4cd249078d934d9e92bee0588bb0f8081dc Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Thu, 20 Jul 2023 17:41:16 +0200 Subject: [PATCH 07/21] Add test when not ``repartitioning`` for ``p2p`` in ``set_index`` (#8016) --- distributed/shuffle/tests/test_shuffle.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 45688c61ec..664b31d011 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -1853,3 +1853,17 @@ async def test_handle_null_partitions_p2p_shuffling(c, s, *workers): await c.close() await asyncio.gather(*[check_worker_cleanup(w) for w in workers]) await check_scheduler_cleanup(s) + + +@gen_cluster(client=True) +async def test_set_index_p2p(c, s, *workers): + df = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6, 7, 8], "b": 1}) + ddf = dd.from_pandas(df, npartitions=3) + ddf = ddf.set_index("a", shuffle="p2p", divisions=(1, 3, 8)) + assert ddf.npartitions == 2 + result = await c.compute(ddf) + dd.assert_eq(result, df.set_index("a")) + + await c.close() + await asyncio.gather(*[check_worker_cleanup(w) for w in workers]) + await check_scheduler_cleanup(s) From 70c9eb85217407e6a0faa7534b6c88b04e2cab45 Mon Sep 17 00:00:00 2001 From: Irina Truong Date: Thu, 20 Jul 2023 10:42:14 -0700 Subject: [PATCH 08/21] Fix for ``TypeError: '<' not supported`` in graph dashboard (#8017) --- distributed/diagnostics/graph_layout.py | 11 ++++-- .../diagnostics/tests/test_graph_layout.py | 11 ++++++ distributed/tests/test_utils.py | 36 +++++++++++++++++++ distributed/utils.py | 35 ++++++++++++++++++ 4 files changed, 91 insertions(+), 2 deletions(-) diff --git a/distributed/diagnostics/graph_layout.py b/distributed/diagnostics/graph_layout.py index ba8003467a..966e78ea01 100644 --- a/distributed/diagnostics/graph_layout.py +++ b/distributed/diagnostics/graph_layout.py @@ -3,6 +3,7 @@ import uuid from distributed.diagnostics.plugin import SchedulerPlugin +from distributed.utils import TupleComparable class GraphLayout(SchedulerPlugin): @@ -48,7 +49,9 @@ def __init__(self, scheduler): def update_graph( self, scheduler, *, dependencies=None, priority=None, tasks=None, **kwargs ): - stack = sorted(tasks, key=lambda k: priority.get(k, 0), reverse=True) + stack = sorted( + tasks, key=lambda k: TupleComparable(priority.get(k, 0)), reverse=True + ) while stack: key = stack.pop() if key in self.x or key not in scheduler.tasks: @@ -58,7 +61,11 @@ def update_graph( if not all(dep in self.y for dep in deps): stack.append(key) stack.extend( - sorted(deps, key=lambda k: priority.get(k, 0), reverse=True) + sorted( + deps, + 
key=lambda k: TupleComparable(priority.get(k, 0)), + reverse=True, + ) ) continue else: diff --git a/distributed/diagnostics/tests/test_graph_layout.py b/distributed/diagnostics/tests/test_graph_layout.py index 64152bca17..cb99cfdd9d 100644 --- a/distributed/diagnostics/tests/test_graph_layout.py +++ b/distributed/diagnostics/tests/test_graph_layout.py @@ -99,3 +99,14 @@ async def test_unique_positions(c, s, a, b): y_positions = [(gl.x[k], gl.y[k]) for k in gl.x] assert len(y_positions) == len(set(y_positions)) + + +@gen_cluster(client=True) +async def test_layout_with_priorities(c, s, a, b): + gl = GraphLayout(s) + s.add_plugin(gl) + + low = c.submit(inc, 1, key="low", priority=1) + mid = c.submit(inc, 2, key="mid", priority=2) + high = c.submit(inc, 3, key="high", priority=3) + await wait([high, mid, low]) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 313137135a..2275227529 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -33,6 +33,7 @@ LoopRunner, RateLimiterFilter, TimeoutError, + TupleComparable, _maybe_complex, ensure_ip, ensure_memoryview, @@ -1043,3 +1044,38 @@ def test_rate_limiter_filter(caplog): "Hello again!", "Hello once more!", ] + + +@pytest.mark.parametrize( + "obj1,obj2,expected", + [ + [(1, 2), (1, 2), False], + [(1, 2), (1, 3), True], + [1, 1, False], + [1, 2, True], + [None, 0, False], + [None, (1, 2), True], + ], +) +def test_tuple_comparable_lt(obj1, obj2, expected): + assert (TupleComparable(obj1) < TupleComparable(obj2)) == expected + + +@pytest.mark.parametrize( + "obj1,obj2,expected", + [ + [(1, 2), (1, 2), True], + [(1, 2), (1, 3), False], + [1, 1, True], + [1, 2, False], + [None, 0, True], + [None, (1, 2), False], + ], +) +def test_tuple_comparable_eq(obj1, obj2, expected): + assert (TupleComparable(obj1) == TupleComparable(obj2)) == expected + + +def test_tuple_comparable_error(): + with pytest.raises(ValueError): + TupleComparable("string") diff --git a/distributed/utils.py b/distributed/utils.py index 5ad983e706..ab8feb72f3 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1925,3 +1925,38 @@ async def wait_for(fut: Awaitable[T], timeout: float) -> T: async def wait_for(fut: Awaitable[T], timeout: float) -> T: return await asyncio.wait_for(fut, timeout) + + +class TupleComparable: + """Wrap object so that we can compare tuple, int or None + + When comparing two objects of different types Python fails + + >>> (1, 2) < 1 + Traceback (most recent call last): + ... 
+ TypeError: '<' not supported between instances of 'tuple' and 'int' + + This class replaces None with 0, and wraps ints with tuples + + >>> TupleComparable((1, 2)) < TupleComparable(1) + False + """ + + __slots__ = ("obj",) + + def __init__(self, obj): + if obj is None: + self.obj = (0,) + elif isinstance(obj, tuple): + self.obj = obj + elif isinstance(obj, (int, float)): + self.obj = (obj,) + else: + raise ValueError(f"Object must be tuple, int, float or None, got {obj}") + + def __eq__(self, other): + return self.obj == other.obj + + def __lt__(self, other): + return self.obj < other.obj From 52279792a56cf92982d31763b4e99af1d4affa49 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 20 Jul 2023 16:06:03 -0500 Subject: [PATCH 09/21] bump version to 2023.7.1 --- docs/source/changelog.rst | 38 ++++++++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 45c84e1266..41d6ef03f7 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,43 @@ Changelog ========= +.. _v2023.7.1: + +2023.7.1 +-------- + +Released on July 20, 2023 + +Enhancements +^^^^^^^^^^^^ +- ``gather_dep`` should handle ``CancelledError`` (:pr:`8013`) `crusaderky`_ +- Pass ``stimulus_id`` to ``SchedulerPlugin.remove_worker`` and ``SchedulerPlugin.transition`` (:pr:`7974`) `Hendrik Makait`_ +- Log ``stimulus_id`` in ``retire_worker`` (:pr:`8003`) `crusaderky`_ +- Use ``BufferOutputStream`` in P2P (:pr:`7991`) `Florian Jetter`_ +- Add Coiled to ignored modules for code sniffing (:pr:`7986`) `Matthew Rocklin`_ +- Progress bar can group tasks by span (:pr:`7952`) `Irina Truong`_ +- Improved error messages for P2P shuffling (:pr:`7979`) `Hendrik Makait`_ +- Reduce removing comms log to debug level (:pr:`7972`) `Florian Jetter`_ + +Bug Fixes +^^^^^^^^^ +- Fix for ``TypeError: '<' not supported`` in graph dashboard (:pr:`8017`) `Irina Truong`_ +- Fix shuffle code to work with ``pyarrow`` 13 (:pr:`8009`) `Joris Van den Bossche`_ + +Documentation +^^^^^^^^^^^^^ +- Add some top-level exposition to the p2p rechunking code (:pr:`7978`) `Lawrence Mitchell`_ + +Maintenance +^^^^^^^^^^^ +- Add test when not ``repartitioning`` for ``p2p`` in ``set_index`` (:pr:`8016`) `Patrick Hoefler`_ +- Bump ``JamesIves/github-pages-deploy-action`` from 4.4.2 to 4.4.3 (:pr:`8008`) +- Configure asyncio loop using ``loop_factory`` kwarg rather than using the ``set_event_loop_policy`` (:pr:`7969`) `Thomas Grainger`_ +- Fix P2P worker cleanup (:pr:`7981`) `Hendrik Makait`_ +- Skip ``click`` v8.1.4 in mypy ``pre-commit`` hook (:pr:`7989`) `Thomas Grainger`_ +- Remove accidental duplicated conversion of ``pyarrow`` ``Table`` to pandas (:pr:`7983`) `Joris Van den Bossche`_ + + .. _v2023.7.0: 2023.7.0 @@ -5106,3 +5143,4 @@ significantly without many new features. .. _`ypogorelova`: https://github.com/ypogorelova .. _`Patrick Hoefler`: https://github.com/phofl .. _`Irina Truong`: https://github.com/j-bennet +.. 
_`Joris Van den Bossche`: https://github.com/jorisvandenbossche diff --git a/pyproject.toml b/pyproject.toml index b8a367cd5f..1457aead70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ requires-python = ">=3.9" dependencies = [ "click >= 8.0", "cloudpickle >= 1.5.0", - "dask == 2023.7.0", + "dask == 2023.7.1", "jinja2 >= 2.10.3", "locket >= 1.0.0", "msgpack >= 1.0.0", From c8ec28361164920d780ffd9e1d7827ea8ec136a5 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 24 Jul 2023 13:13:18 +0200 Subject: [PATCH 10/21] Fix log message (#8029) --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 109a9e6729..76bafb54bf 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5055,7 +5055,7 @@ async def remove_worker( "Task %s marked as failed because %d workers died" " while trying to run it", ts.key, - self.allowed_failures, + ts.suspicious, ) for ts in list(ws.has_what): From 472df846799cdc8258cf100f6d69a4ffb0bd6da6 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 24 Jul 2023 14:34:43 +0200 Subject: [PATCH 11/21] Send shards grouped by input chunk in P2P rechunking (#8010) Co-authored-by: Lawrence Mitchell --- distributed/shuffle/_buffer.py | 24 +++--- distributed/shuffle/_rechunk.py | 25 ------- distributed/shuffle/_worker_plugin.py | 75 +++++++++++-------- distributed/shuffle/tests/test_buffer.py | 26 ++++--- distributed/shuffle/tests/test_comm_buffer.py | 18 ++--- distributed/shuffle/tests/test_disk_buffer.py | 8 +- 6 files changed, 77 insertions(+), 99 deletions(-) diff --git a/distributed/shuffle/_buffer.py b/distributed/shuffle/_buffer.py index 1a3afc6b31..e329d54dcc 100644 --- a/distributed/shuffle/_buffer.py +++ b/distributed/shuffle/_buffer.py @@ -166,14 +166,14 @@ def _continue() -> bool: self._shards_available.notify_all() await self.process(part_id, shards, size) - async def write(self, data: dict[str, list[ShardType]]) -> None: + async def write(self, data: dict[str, ShardType]) -> None: """ - Writes many objects into the local buffers, blocks until ready for more + Writes objects into the local buffers, blocks until ready for more Parameters ---------- data: dict - A dictionary mapping destinations to lists of objects that should + A dictionary mapping destinations to the object that should be written to that destination Notes @@ -193,13 +193,7 @@ async def write(self, data: dict[str, list[ShardType]]) -> None: if not data: return - shards = None - size = 0 - - sizes = {} - for id_, shards in data.items(): - size = sum(map(sizeof, shards)) - sizes[id_] = size + sizes = {worker: sizeof(shard) for worker, shard in data.items()} total_batch_size = sum(sizes.values()) self.bytes_memory += total_batch_size self.bytes_total += total_batch_size @@ -207,14 +201,14 @@ async def write(self, data: dict[str, list[ShardType]]) -> None: if self.memory_limiter: self.memory_limiter.increase(total_batch_size) async with self._shards_available: - for id_, shards in data.items(): - self.shards[id_].extend(shards) - self.sizes[id_] += sizes[id_] + for worker, shard in data.items(): + self.shards[worker].append(shard) + self.sizes[worker] += sizes[worker] self._shards_available.notify() if self.memory_limiter: await self.memory_limiter.wait_for_available() - del data, shards - assert size + del data + assert total_batch_size def raise_on_exception(self) -> None: """Raises an exception if something went wrong during writing""" diff --git 
a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 884eb924fd..a35b75d10a 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -196,31 +196,6 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array: return da.Array(graph, name, chunks, meta=x) -class ShardID(NamedTuple): - """Unique identifier of an individual shard within an array rechunk - - When rechunking a 1d-array with two chunks into a 1d-array with a single chunk - >>> old = ((2, 2),) # doctest: +SKIP - >>> new = ((4),) # doctest: +SKIP - >>> rechunk_slicing(old, new) # doctest: +SKIP - { - # The first chunk of the old array belongs to the first - # chunk of the new array at the first sub-index - (0,): [(ShardID((0,), (0,)), (slice(0, 2, None),))], - - # The second chunk of the old array belongs to the first - # chunk of the new array at the second sub-index - (1,): [(ShardID((0,), (1,)), (slice(0, 2, None),))], - } - """ - - #: Index of the new output chunk to which the shard belongs - chunk_index: NDIndex - #: Index of the shard within the n-dimensional array of shards that will be - # concatenated into the new chunk - shard_index: NDIndex - - class Split(NamedTuple): """Slice of a chunk that is concatenated with other splits to create a new chunk diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index c634867349..8a3ab7ba72 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -30,9 +30,7 @@ from distributed.shuffle._comms import CommShardsBuffer from distributed.shuffle._disk import DiskShardsBuffer from distributed.shuffle._limiter import ResourceLimiter -from distributed.shuffle._rechunk import ChunkedAxes, NDIndex -from distributed.shuffle._rechunk import ShardID as ArrayRechunkShardID -from distributed.shuffle._rechunk import split_axes +from distributed.shuffle._rechunk import ChunkedAxes, NDIndex, split_axes from distributed.shuffle._shuffle import ShuffleId, ShuffleType from distributed.sizeof import sizeof from distributed.utils import log_errors, sync @@ -45,7 +43,6 @@ from distributed.worker import Worker -T_transfer_shard_id = TypeVar("T_transfer_shard_id") T_partition_id = TypeVar("T_partition_id") T_partition_type = TypeVar("T_partition_type") T = TypeVar("T") @@ -57,7 +54,7 @@ class ShuffleClosedError(RuntimeError): pass -class ShuffleRun(Generic[T_transfer_shard_id, T_partition_id, T_partition_type]): +class ShuffleRun(Generic[T_partition_id, T_partition_type]): def __init__( self, id: ShuffleId, @@ -93,7 +90,7 @@ def __init__( self.diagnostics: dict[str, float] = defaultdict(float) self.transferred = False - self.received: set[T_transfer_shard_id] = set() + self.received: set[T_partition_id] = set() self.total_recvd = 0 self.start_time = time.time() self._exception: Exception | None = None @@ -122,7 +119,7 @@ async def barrier(self) -> None: await self.scheduler.shuffle_barrier(id=self.id, run_id=self.run_id) async def send( - self, address: str, shards: list[tuple[T_transfer_shard_id, bytes]] + self, address: str, shards: list[tuple[T_partition_id, bytes]] ) -> None: self.raise_if_closed() return await self.rpc(address).shuffle_receive( @@ -151,12 +148,12 @@ def heartbeat(self) -> dict[str, Any]: } async def _write_to_comm( - self, data: dict[str, list[tuple[T_transfer_shard_id, bytes]]] + self, data: dict[str, tuple[T_partition_id, bytes]] ) -> None: self.raise_if_closed() await self._comm_buffer.write(data) - async def _write_to_disk(self, data: dict[NDIndex, 
list[bytes]]) -> None: + async def _write_to_disk(self, data: dict[NDIndex, bytes]) -> None: self.raise_if_closed() await self._disk_buffer.write( {"_".join(str(i) for i in k): v for k, v in data.items()} @@ -205,7 +202,7 @@ def _read_from_disk(self, id: NDIndex) -> bytes: data: bytes = self._disk_buffer.read("_".join(str(i) for i in id)) return data - async def receive(self, data: list[tuple[T_transfer_shard_id, bytes]]) -> None: + async def receive(self, data: list[tuple[T_partition_id, bytes]]) -> None: await self._receive(data) async def _ensure_output_worker(self, i: T_partition_id, key: str) -> None: @@ -225,7 +222,7 @@ def _get_assigned_worker(self, i: T_partition_id) -> str: """Get the address of the worker assigned to the output partition""" @abc.abstractmethod - async def _receive(self, data: list[tuple[T_transfer_shard_id, bytes]]) -> None: + async def _receive(self, data: list[tuple[T_partition_id, bytes]]) -> None: """Receive shards belonging to output partitions of this shuffle run""" @abc.abstractmethod @@ -241,7 +238,7 @@ async def get_output_partition( """Get an output partition to the shuffle run""" -class ArrayRechunkRun(ShuffleRun[ArrayRechunkShardID, NDIndex, "np.ndarray"]): +class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]): """State for a single active rechunk execution This object is responsible for splitting, sending, receiving and combining @@ -323,44 +320,57 @@ def __init__( self.worker_for = worker_for self.split_axes = split_axes(old, new) - async def _receive(self, data: list[tuple[ArrayRechunkShardID, bytes]]) -> None: + async def _receive(self, data: list[tuple[NDIndex, bytes]]) -> None: self.raise_if_closed() - buffers = defaultdict(list) + filtered = [] for d in data: id, payload = d if id in self.received: continue + filtered.append(payload) self.received.add(id) self.total_recvd += sizeof(d) - - buffers[id.chunk_index].append(payload) - del data - if not buffers: + if not filtered: return try: - await self._write_to_disk(buffers) + shards = await self.offload(self._repartition_shards, filtered) + del filtered + await self._write_to_disk(shards) except Exception as e: self._exception = e raise + def _repartition_shards(self, data: list[bytes]) -> dict[NDIndex, bytes]: + repartitioned: defaultdict[ + NDIndex, list[tuple[NDIndex, np.ndarray]] + ] = defaultdict(list) + for buffer in data: + for id, shard in pickle.loads(buffer): + repartitioned[id].append(shard) + return {k: pickle.dumps(v) for k, v in repartitioned.items()} + async def add_partition(self, data: np.ndarray, partition_id: NDIndex) -> int: self.raise_if_closed() if self.transferred: raise RuntimeError(f"Cannot add more partitions to {self}") - def _() -> dict[str, list[tuple[ArrayRechunkShardID, bytes]]]: - """Return a mapping of worker addresses to a list of tuples of shard IDs - and shard data. + def _() -> dict[str, tuple[NDIndex, bytes]]: + """Return a mapping of worker addresses to a tuple of input partition + IDs and shard data. + + TODO: Overhaul! As shard data, we serialize the payload together with the sub-index of the slice within the new chunk. To assemble the new chunk from its shards, it needs the sub-index to know where each shard belongs within the chunk. Adding the sub-index into the serialized payload on the sender allows us to write the serialized payload directly to disk on the receiver. 
""" - out: dict[str, list[tuple[ArrayRechunkShardID, bytes]]] = defaultdict(list) + out: dict[ + str, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]] + ] = defaultdict(list) from itertools import product ndsplits = product( @@ -369,11 +379,10 @@ def _() -> dict[str, list[tuple[ArrayRechunkShardID, bytes]]]: for ndsplit in ndsplits: chunk_index, shard_index, ndslice = zip(*ndsplit) - id = ArrayRechunkShardID(chunk_index, shard_index) out[self.worker_for[chunk_index]].append( - (id, pickle.dumps((shard_index, data[ndslice]))) + (chunk_index, (shard_index, data[ndslice])) ) - return out + return {k: (partition_id, pickle.dumps(v)) for k, v in out.items()} out = await self.offload(_) await self._write_to_comm(out) @@ -401,7 +410,7 @@ def _get_assigned_worker(self, id: NDIndex) -> str: return self.worker_for[id] -class DataFrameShuffleRun(ShuffleRun[int, int, "pd.DataFrame"]): +class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): """State for a single active shuffle execution This object is responsible for splitting, sending, receiving and combining @@ -503,25 +512,25 @@ async def _receive(self, data: list[tuple[int, bytes]]) -> None: self._exception = e raise - def _repartition_buffers(self, data: list[bytes]) -> dict[NDIndex, list[bytes]]: + def _repartition_buffers(self, data: list[bytes]) -> dict[NDIndex, bytes]: table = list_of_buffers_to_table(data) groups = split_by_partition(table, self.column) assert len(table) == sum(map(len, groups.values())) del data - return {(k,): [serialize_table(v)] for k, v in groups.items()} + return {(k,): serialize_table(v) for k, v in groups.items()} async def add_partition(self, data: pd.DataFrame, partition_id: int) -> int: self.raise_if_closed() if self.transferred: raise RuntimeError(f"Cannot add more partitions to {self}") - def _() -> dict[str, list[tuple[int, bytes]]]: + def _() -> dict[str, tuple[int, bytes]]: out = split_by_worker( data, self.column, self.worker_for, ) - out = {k: [(partition_id, serialize_table(t))] for k, t in out.items()} + out = {k: (partition_id, serialize_table(t)) for k, t in out.items()} return out out = await self.offload(_) @@ -1005,8 +1014,8 @@ def convert_chunk(data: bytes) -> np.ndarray: shards: dict[NDIndex, np.ndarray] = {} while file.tell() < len(data): - index, shard = pickle.load(file) - shards[index] = shard + for index, shard in pickle.load(file): + shards[index] = shard subshape = [max(dim) + 1 for dim in zip(*shards.keys())] assert len(shards) == np.prod(subshape) diff --git a/distributed/shuffle/tests/test_buffer.py b/distributed/shuffle/tests/test_buffer.py index 616fb733a1..53e862d829 100644 --- a/distributed/shuffle/tests/test_buffer.py +++ b/distributed/shuffle/tests/test_buffer.py @@ -39,17 +39,17 @@ def read(self, id: str) -> bytes: @pytest.mark.parametrize( - "big_payload", + "big_payloads", [ - {"big": [gen_bytes(2, limit)]}, - {"big": [gen_bytes(0.5, limit)] * 4}, - {f"big-{ix}": [gen_bytes(0.5, limit)] for ix in range(4)}, - {f"big-{ix}": [gen_bytes(0.5, limit)] * 2 for ix in range(2)}, + [{"big": gen_bytes(2, limit)}], + [{"big": gen_bytes(0.5, limit)}] * 4, + [{f"big-{ix}": gen_bytes(0.5, limit)} for ix in range(4)], + [{f"big-{ix}": gen_bytes(0.5, limit)} for ix in range(2)] * 2, ], ) @gen_test() -async def test_memory_limit(big_payload): - small_payload = {"small": [gen_bytes(0.1, limit)]} +async def test_memory_limit(big_payloads): + small_payload = {"small": gen_bytes(0.1, limit)} limiter = ResourceLimiter(limit) @@ -83,15 +83,17 @@ async def test_memory_limit(big_payload): while not 
buf.memory_limiter.free(): await asyncio.sleep(0.1) buf.allow_process.clear() - big = asyncio.create_task(buf.write(big_payload)) + big_tasks = [ + asyncio.create_task(buf.write(big_payload)) for big_payload in big_payloads + ] small = asyncio.create_task(buf.write(small_payload)) with pytest.raises(asyncio.TimeoutError): - await wait_for(asyncio.shield(big), 0.1) + await wait_for(asyncio.shield(asyncio.gather(*big_tasks)), 0.1) with pytest.raises(asyncio.TimeoutError): await wait_for(asyncio.shield(small), 0.1) # Puts only return once we're below memory limit buf.allow_process.set() - await big + await asyncio.gather(*big_tasks) await small # Once the big write is through, we can write without blocking again before = buf.memory_limiter.time_blocked_total @@ -120,10 +122,10 @@ async def test_memory_limit_blocked_exception(): limit = parse_bytes("10.0 MiB") big_payload = { - "shard-1": [gen_bytes(2, limit)], + "shard-1": gen_bytes(2, limit), } broken_payload = { - "error": ["not-bytes"], + "error": "not-bytes", } limiter = ResourceLimiter(limit) async with BufferShardsBroken( diff --git a/distributed/shuffle/tests/test_comm_buffer.py b/distributed/shuffle/tests/test_comm_buffer.py index ee1441e9ac..f10d945a14 100644 --- a/distributed/shuffle/tests/test_comm_buffer.py +++ b/distributed/shuffle/tests/test_comm_buffer.py @@ -21,8 +21,8 @@ async def send(address, shards): d[address].extend(shards) mc = CommShardsBuffer(send=send) - await mc.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) - await mc.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) + await mc.write({"x": b"0" * 1000, "y": b"1" * 500}) + await mc.write({"x": b"0" * 1000, "y": b"1" * 500}) await mc.flush() @@ -38,13 +38,13 @@ async def send(address, shards): raise Exception(123) mc = CommShardsBuffer(send=send) - await mc.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) + await mc.write({"x": b"0" * 1000, "y": b"1" * 500}) while not mc._exception: await asyncio.sleep(0.1) with pytest.raises(Exception, match="123"): - await mc.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) + await mc.write({"x": b"0" * 1000, "y": b"1" * 500}) await mc.flush() @@ -64,8 +64,8 @@ async def send(address, shards): sending_first.set() mc = CommShardsBuffer(send=send, concurrency_limit=1) - await mc.write({"x": [b"0"], "y": [b"1"]}) - await mc.write({"x": [b"0"], "y": [b"1"]}) + await mc.write({"x": b"0", "y": b"1"}) + await mc.write({"x": b"0", "y": b"1"}) flush_task = asyncio.create_task(mc.flush()) await sending_first.wait() block_send.clear() @@ -96,8 +96,7 @@ async def send(address, shards): send=send, memory_limiter=ResourceLimiter(parse_bytes("100 MiB")) ) payload = { - x: [gen_bytes(frac, comm_buffer.memory_limiter._maxvalue)] - for x in range(nshards) + x: gen_bytes(frac, comm_buffer.memory_limiter._maxvalue) for x in range(nshards) } async with comm_buffer as mc: @@ -138,8 +137,7 @@ async def send(address, shards): send=send, memory_limiter=ResourceLimiter(parse_bytes("100 MiB")) ) payload = { - x: [gen_bytes(frac, comm_buffer.memory_limiter._maxvalue)] - for x in range(nshards) + x: gen_bytes(frac, comm_buffer.memory_limiter._maxvalue) for x in range(nshards) } async with comm_buffer as mc: diff --git a/distributed/shuffle/tests/test_disk_buffer.py b/distributed/shuffle/tests/test_disk_buffer.py index 39cced8ec7..76a4a1e70c 100644 --- a/distributed/shuffle/tests/test_disk_buffer.py +++ b/distributed/shuffle/tests/test_disk_buffer.py @@ -13,8 +13,8 @@ @gen_test() async def test_basic(tmp_path): async with DiskShardsBuffer(directory=tmp_path) as mf: - 
await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) - await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) + await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) + await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) await mf.flush() @@ -32,7 +32,7 @@ async def test_basic(tmp_path): @gen_test() async def test_read_before_flush(tmp_path): - payload = {"1": [b"foo"]} + payload = {"1": b"foo"} async with DiskShardsBuffer(directory=tmp_path) as mf: with pytest.raises(RuntimeError): mf.read(1) @@ -52,7 +52,7 @@ async def test_read_before_flush(tmp_path): @gen_test() async def test_many(tmp_path, count): async with DiskShardsBuffer(directory=tmp_path) as mf: - d = {i: [str(i).encode() * 100] for i in range(count)} + d = {i: str(i).encode() * 100 for i in range(count)} for _ in range(10): await mf.write(d) From efc7eeb64934286e1ec820cf94580ab72eb072fd Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 24 Jul 2023 14:56:25 +0200 Subject: [PATCH 12/21] Fix compatibility variable naming (#8030) --- distributed/protocol/pickle.py | 4 ++-- distributed/protocol/tests/test_pickle.py | 4 ++-- distributed/shuffle/tests/test_merge.py | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index 90f0fa3f88..8b4b7328e5 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -10,7 +10,7 @@ from distributed.protocol.serialize import dask_deserialize, dask_serialize -CLOUDPICKLE_GTE_20 = parse_version(cloudpickle.__version__) >= parse_version("2.0.0") +CLOUDPICKLE_GE_20 = parse_version(cloudpickle.__version__) >= parse_version("2.0.0") HIGHEST_PROTOCOL = pickle.HIGHEST_PROTOCOL @@ -68,7 +68,7 @@ def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL): pickler.dump(x) result = f.getvalue() if b"__main__" in result or ( - CLOUDPICKLE_GTE_20 + CLOUDPICKLE_GE_20 and getattr(inspect.getmodule(x), "__name__", None) in cloudpickle.list_registry_pickle_by_value() ): diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index 52b5649cae..611c95a07f 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -14,7 +14,7 @@ from distributed import profile from distributed.protocol import deserialize, serialize from distributed.protocol.pickle import ( - CLOUDPICKLE_GTE_20, + CLOUDPICKLE_GE_20, HIGHEST_PROTOCOL, dumps, loads, @@ -201,7 +201,7 @@ def funcs(): @pytest.mark.skipif( - not CLOUDPICKLE_GTE_20, reason="Pickle by value registration not supported" + not CLOUDPICKLE_GE_20, reason="Pickle by value registration not supported" ) def test_pickle_by_value_when_registered(): with save_sys_modules(): diff --git a/distributed/shuffle/tests/test_merge.py b/distributed/shuffle/tests/test_merge.py index 4ea17c6df1..99e4c32d65 100644 --- a/distributed/shuffle/tests/test_merge.py +++ b/distributed/shuffle/tests/test_merge.py @@ -9,7 +9,7 @@ dd = pytest.importorskip("dask.dataframe") import pandas as pd -from dask.dataframe._compat import PANDAS_GT_200, tm +from dask.dataframe._compat import PANDAS_GE_200, tm from dask.dataframe.utils import assert_eq from dask.utils_test import hlg_layer_topological @@ -249,7 +249,7 @@ async def test_merge_by_multiple_columns(c, s, a, b, how): # FIXME: There's an discrepancy with an empty index for # pandas=2.0 (xref https://github.com/dask/dask/issues/9957). # Temporarily avoid index check until the discrepancy is fixed. 
- check_index=not (PANDAS_GT_200 and expected.index.empty), + check_index=not (PANDAS_GE_200 and expected.index.empty), ) expected = pdr.join(pdl, how=how) @@ -259,7 +259,7 @@ async def test_merge_by_multiple_columns(c, s, a, b, how): # FIXME: There's an discrepancy with an empty index for # pandas=2.0 (xref https://github.com/dask/dask/issues/9957). # Temporarily avoid index check until the discrepancy is fixed. - check_index=not (PANDAS_GT_200 and expected.index.empty), + check_index=not (PANDAS_GE_200 and expected.index.empty), ) expected = pd.merge(pdl, pdr, how=how, left_index=True, right_index=True) @@ -278,7 +278,7 @@ async def test_merge_by_multiple_columns(c, s, a, b, how): # FIXME: There's an discrepancy with an empty index for # pandas=2.0 (xref https://github.com/dask/dask/issues/9957). # Temporarily avoid index check until the discrepancy is fixed. - check_index=not (PANDAS_GT_200 and expected.index.empty), + check_index=not (PANDAS_GE_200 and expected.index.empty), ) expected = pd.merge(pdr, pdl, how=how, left_index=True, right_index=True) @@ -297,7 +297,7 @@ async def test_merge_by_multiple_columns(c, s, a, b, how): # FIXME: There's an discrepancy with an empty index for # pandas=2.0 (xref https://github.com/dask/dask/issues/9957). # Temporarily avoid index check until the discrepancy is fixed. - check_index=not (PANDAS_GT_200 and expected.index.empty), + check_index=not (PANDAS_GE_200 and expected.index.empty), ) # hash join From 16b0c51aeec14e4a761e42fa0b79c6aac36d34af Mon Sep 17 00:00:00 2001 From: Irina Truong Date: Mon, 24 Jul 2023 07:34:43 -0700 Subject: [PATCH 13/21] Add a test for GraphLayout with scatter (#8025) --- distributed/diagnostics/tests/test_graph_layout.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/distributed/diagnostics/tests/test_graph_layout.py b/distributed/diagnostics/tests/test_graph_layout.py index cb99cfdd9d..d8d3c725b6 100644 --- a/distributed/diagnostics/tests/test_graph_layout.py +++ b/distributed/diagnostics/tests/test_graph_layout.py @@ -102,11 +102,11 @@ async def test_unique_positions(c, s, a, b): @gen_cluster(client=True) -async def test_layout_with_priorities(c, s, a, b): +async def test_layout_scatter(c, s, a, b): gl = GraphLayout(s) s.add_plugin(gl) - low = c.submit(inc, 1, key="low", priority=1) - mid = c.submit(inc, 2, key="mid", priority=2) - high = c.submit(inc, 3, key="high", priority=3) - await wait([high, mid, low]) + data = await c.scatter([1, 2, 3], broadcast=True) + futures = [c.submit(sum, data) for _ in range(5)] + await wait(futures) + assert len(gl.state_updates) > 0 From ffb1b27126402b063c2d2ba8c762b3a33c8e043c Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 24 Jul 2023 09:37:27 -0500 Subject: [PATCH 14/21] Test against more recent pyarrow versions (#8021) --- continuous_integration/environment-3.10.yaml | 2 +- continuous_integration/environment-3.11.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/continuous_integration/environment-3.10.yaml b/continuous_integration/environment-3.10.yaml index 86563390f9..014084f01b 100644 --- a/continuous_integration/environment-3.10.yaml +++ b/continuous_integration/environment-3.10.yaml @@ -25,7 +25,7 @@ dependencies: - pre-commit - prometheus_client - psutil - - pyarrow=7 + - pyarrow>=7 - pytest - pytest-cov - pytest-faulthandler diff --git a/continuous_integration/environment-3.11.yaml b/continuous_integration/environment-3.11.yaml index 5c6eef9982..ffe3d4f100 100644 --- 
a/continuous_integration/environment-3.11.yaml +++ b/continuous_integration/environment-3.11.yaml @@ -25,7 +25,7 @@ dependencies: - pre-commit - prometheus_client - psutil - - pyarrow=7 + - pyarrow>=7 - pytest - pytest-cov - pytest-faulthandler From 7b0aca797074965631abfa8a1f0e163ae9f7958c Mon Sep 17 00:00:00 2001 From: Brian Phillips <56549508+bphillips-exos@users.noreply.github.com> Date: Mon, 24 Jul 2023 10:40:10 -0400 Subject: [PATCH 15/21] Add `Client.unregister_scheduler_plugin` method (#7968) --- distributed/client.py | 38 +++++++++++++++++++ .../tests/test_scheduler_plugin.py | 34 +++++++++++++++++ distributed/scheduler.py | 11 +++--- 3 files changed, 78 insertions(+), 5 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index ec55b7f2c0..f01eb31190 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4840,6 +4840,44 @@ def register_scheduler_plugin(self, plugin, name=None, idempotent=False): idempotent=idempotent, ) + async def _unregister_scheduler_plugin(self, name): + return await self.scheduler.unregister_scheduler_plugin(name=name) + + def unregister_scheduler_plugin(self, name): + """Unregisters a scheduler plugin + + See https://distributed.readthedocs.io/en/latest/plugins.html#scheduler-plugins + + Parameters + ---------- + name : str + Name of the plugin to unregister. See the :meth:`Client.register_scheduler_plugin` + docstring for more information. + + Examples + -------- + >>> class MyPlugin(SchedulerPlugin): + ... def __init__(self, *args, **kwargs): + ... pass # the constructor is up to you + ... async def start(self, scheduler: Scheduler) -> None: + ... pass + ... async def before_close(self) -> None: + ... pass + ... async def close(self) -> None: + ... pass + ... def restart(self, scheduler: Scheduler) -> None: + ... pass + + >>> plugin = MyPlugin(1, 2, 3) + >>> client.register_scheduler_plugin(plugin, name='foo') + >>> client.unregister_scheduler_plugin(name='foo') + + See Also + -------- + register_scheduler_plugin + """ + return self.sync(self._unregister_scheduler_plugin, name=name) + def register_worker_callbacks(self, setup=None): """ Registers a setup callback function for all current and future workers. 
diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index 7ab9e82514..01e651bf8b 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -5,6 +5,7 @@ import pytest from distributed import Scheduler, SchedulerPlugin, Worker, get_worker +from distributed.protocol.pickle import dumps from distributed.utils_test import captured_logger, gen_cluster, gen_test, inc @@ -402,6 +403,39 @@ def start(self, scheduler): assert n_plugins == len(s.plugins) +@gen_cluster(nthreads=[]) +async def test_unregister_scheduler_plugin(s): + class Plugin(SchedulerPlugin): + def __init__(self): + self.name = "plugin" + + plugin = Plugin() + await s.register_scheduler_plugin(plugin=dumps(plugin)) + assert "plugin" in s.plugins + + await s.unregister_scheduler_plugin(name="plugin") + assert "plugin" not in s.plugins + + with pytest.raises(ValueError, match="Could not find plugin"): + await s.unregister_scheduler_plugin(name="plugin") + + +@gen_cluster(client=True) +async def test_unregister_scheduler_plugin_from_client(c, s, a, b): + class Plugin(SchedulerPlugin): + name = "plugin" + + assert "plugin" not in s.plugins + await c.register_scheduler_plugin(Plugin()) + assert "plugin" in s.plugins + + await c.unregister_scheduler_plugin("plugin") + assert "plugin" not in s.plugins + + with pytest.raises(ValueError, match="Could not find plugin"): + await c.unregister_scheduler_plugin(name="plugin") + + @gen_cluster(client=True) async def test_log_event_plugin(c, s, a, b): class EventPlugin(SchedulerPlugin): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 76bafb54bf..ca612cd9ba 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3702,6 +3702,7 @@ def __init__( "get_task_stream": self.get_task_stream, "get_task_prefix_states": self.get_task_prefix_states, "register_scheduler_plugin": self.register_scheduler_plugin, + "unregister_scheduler_plugin": self.unregister_scheduler_plugin, "register_worker_plugin": self.register_worker_plugin, "unregister_worker_plugin": self.unregister_worker_plugin, "register_nanny_plugin": self.register_nanny_plugin, @@ -5749,11 +5750,7 @@ def add_plugin( self.plugins[name] = plugin - def remove_plugin( - self, - name: str | None = None, - plugin: SchedulerPlugin | None = None, - ) -> None: + def remove_plugin(self, name: str | None = None) -> None: """Remove external plugin from scheduler Parameters @@ -5801,6 +5798,10 @@ async def register_scheduler_plugin( self.add_plugin(plugin, name=name, idempotent=idempotent) + async def unregister_scheduler_plugin(self, name: str) -> None: + """Unregister a plugin on the scheduler.""" + self.remove_plugin(name) + def worker_send(self, worker: str, msg: dict[str, Any]) -> None: """Send message to worker From f0303aaff3a22fd00660312cfa3f74a1fa178938 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 24 Jul 2023 19:29:23 +0200 Subject: [PATCH 16/21] Automatically restart P2P shuffles when output worker leaves (#7970) Co-authored-by: Lawrence Mitchell --- distributed/shuffle/_exceptions.py | 5 + distributed/shuffle/_scheduler_plugin.py | 157 +++++--- distributed/shuffle/_shuffle.py | 5 + distributed/shuffle/_worker_plugin.py | 143 ++++--- distributed/shuffle/tests/test_shuffle.py | 430 ++++++++++++++-------- 5 files changed, 493 insertions(+), 247 deletions(-) create mode 100644 distributed/shuffle/_exceptions.py diff --git a/distributed/shuffle/_exceptions.py 
b/distributed/shuffle/_exceptions.py new file mode 100644 index 0000000000..57a54a15e7 --- /dev/null +++ b/distributed/shuffle/_exceptions.py @@ -0,0 +1,5 @@ +from __future__ import annotations + + +class ShuffleClosedError(RuntimeError): + pass diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index 911f4f89b9..ec670c0b07 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -6,7 +6,7 @@ import logging from collections import defaultdict from collections.abc import Callable, Iterable, Sequence -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import partial from itertools import product from typing import TYPE_CHECKING, Any, ClassVar @@ -34,7 +34,7 @@ logger = logging.getLogger(__name__) -@dataclass +@dataclass(eq=False) class ShuffleState(abc.ABC): _run_id_iterator: ClassVar[itertools.count] = itertools.count(1) @@ -42,6 +42,7 @@ class ShuffleState(abc.ABC): run_id: int output_workers: set[str] participating_workers: set[str] + _archived_by: str | None = field(default=None, init=False) @abc.abstractmethod def to_msg(self) -> dict[str, Any]: @@ -50,8 +51,11 @@ def to_msg(self) -> dict[str, Any]: def __str__(self) -> str: return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>" + def __hash__(self) -> int: + return hash(self.run_id) -@dataclass + +@dataclass(eq=False) class DataFrameShuffleState(ShuffleState): type: ClassVar[ShuffleType] = ShuffleType.DATAFRAME worker_for: dict[int, str] @@ -68,7 +72,7 @@ def to_msg(self) -> dict[str, Any]: } -@dataclass +@dataclass(eq=False) class ArrayRechunkState(ShuffleState): type: ClassVar[ShuffleType] = ShuffleType.ARRAY_RECHUNK worker_for: dict[NDIndex, str] @@ -90,19 +94,18 @@ def to_msg(self) -> dict[str, Any]: class ShuffleSchedulerPlugin(SchedulerPlugin): """ Shuffle plugin for the scheduler - This coordinates the individual worker plugins to ensure correctness and collects heartbeat messages for the dashboard. 
- See Also -------- ShuffleWorkerPlugin """ scheduler: Scheduler - states: dict[ShuffleId, ShuffleState] + active_shuffles: dict[ShuffleId, ShuffleState] heartbeats: defaultdict[ShuffleId, dict] - erred_shuffles: dict[ShuffleId, Exception] + _shuffles: defaultdict[ShuffleId, set[ShuffleState]] + _archived_by_stimulus: defaultdict[str, set[ShuffleState]] def __init__(self, scheduler: Scheduler): self.scheduler = scheduler @@ -115,9 +118,10 @@ def __init__(self, scheduler: Scheduler): } ) self.heartbeats = defaultdict(lambda: defaultdict(dict)) - self.states = {} - self.erred_shuffles = {} + self.active_shuffles = {} self.scheduler.add_plugin(self, name="shuffle") + self._shuffles = defaultdict(set) + self._archived_by_stimulus = defaultdict(set) async def start(self, scheduler: Scheduler) -> None: worker_plugin = ShuffleWorkerPlugin() @@ -126,18 +130,19 @@ async def start(self, scheduler: Scheduler) -> None: ) def shuffle_ids(self) -> set[ShuffleId]: - return set(self.states) + return set(self.active_shuffles) async def barrier(self, id: ShuffleId, run_id: int) -> None: - shuffle = self.states[id] + shuffle = self.active_shuffles[id] assert shuffle.run_id == run_id, f"{run_id=} does not match {shuffle}" msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id} await self.scheduler.broadcast( - msg=msg, workers=list(shuffle.participating_workers) + msg=msg, + workers=list(shuffle.participating_workers), ) def restrict_task(self, id: ShuffleId, run_id: int, key: str, worker: str) -> dict: - shuffle = self.states[id] + shuffle = self.active_shuffles[id] if shuffle.run_id > run_id: return { "status": "error", @@ -158,15 +163,19 @@ def heartbeat(self, ws: WorkerState, data: dict) -> None: self.heartbeats[shuffle_id][ws.address].update(d) def get(self, id: ShuffleId, worker: str) -> dict[str, Any]: - if exception := self.erred_shuffles.get(id): - return {"status": "error", "message": str(exception)} - state = self.states[id] + if worker not in self.scheduler.workers: + # This should never happen + raise RuntimeError( + f"Scheduler is unaware of this worker {worker!r}" + ) # pragma: nocover + state = self.active_shuffles[id] state.participating_workers.add(worker) return state.to_msg() def get_or_create( self, id: ShuffleId, + key: str, type: str, worker: str, spec: dict[str, Any], @@ -178,6 +187,7 @@ def get_or_create( # known by its name. If the name has been mangled, we cannot guarantee # that the shuffle works as intended and should fail instead. self._raise_if_barrier_unknown(id) + self._raise_if_task_not_processing(key) state: ShuffleState if type == ShuffleType.DATAFRAME: @@ -186,7 +196,8 @@ def get_or_create( state = self._create_array_rechunk_state(id, spec) else: # pragma: no cover raise TypeError(type) - self.states[id] = state + self.active_shuffles[id] = state + self._shuffles[id].add(state) state.participating_workers.add(worker) return state.to_msg() @@ -201,6 +212,11 @@ def _raise_if_barrier_unknown(self, id: ShuffleId) -> None: "into this by leaving a comment at distributed#7816." 
) + def _raise_if_task_not_processing(self, key: str) -> None: + task = self.scheduler.tasks[key] + if task.state != "processing": + raise RuntimeError(f"Expected {task} to be processing, is {task.state}.") + def _create_dataframe_shuffle_state( self, id: ShuffleId, spec: dict[str, Any] ) -> DataFrameShuffleState: @@ -309,34 +325,67 @@ def _unset_restriction(self, ts: TaskState) -> None: original_restrictions = ts.annotations.pop("shuffle_original_restrictions") self.scheduler.set_restrictions({ts.key: original_restrictions}) + def _restart_recommendations(self, id: ShuffleId) -> Recs: + barrier_task = self.scheduler.tasks[barrier_key(id)] + recs: Recs = {} + + for dt in barrier_task.dependents: + if dt.state == "erred": + return {} + recs.update({dt.key: "released"}) + + if barrier_task.state == "erred": + # This should never happen, a dependent of the barrier should already + # be `erred` + raise RuntimeError( + f"Expected dependents of {barrier_task=} to be 'erred' if " + "the barrier is." + ) # pragma: no cover + recs.update({barrier_task.key: "released"}) + + for dt in barrier_task.dependencies: + if dt.state == "erred": + # This should never happen, a dependent of the barrier should already + # be `erred` + raise RuntimeError( + f"Expected barrier and its dependents to be " + f"'erred' if the barrier's dependency {dt} is." + ) # pragma: no cover + recs.update({dt.key: "released"}) + return recs + + def _restart_shuffle( + self, id: ShuffleId, scheduler: Scheduler, *, stimulus_id: str + ) -> None: + recs = self._restart_recommendations(id) + self.scheduler.transitions(recs, stimulus_id=stimulus_id) + self.scheduler.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) + def remove_worker( self, scheduler: Scheduler, worker: str, *, stimulus_id: str, **kwargs: Any ) -> None: - from time import time - - stimulus_id = f"shuffle-failed-worker-left-{time()}" + """Restart all active shuffles when a participating worker leaves the cluster. + + .. note:: + Due to the order of operations in :meth:`~Scheduler.remove_worker`, the + shuffle may have already been archived by + :meth:`~ShuffleSchedulerPlugin.transition`. In this case, the + ``stimulus_id`` is used as a transaction identifier and all archived shuffles + with a matching `stimulus_id` are restarted. + """ - recs: Recs = {} - for shuffle_id, shuffle in self.states.items(): + # If processing the transactions causes a task to get released, this + # removes the shuffle from self.active_shuffles. Therefore, we must iterate + # over a copy. + for shuffle_id, shuffle in self.active_shuffles.copy().items(): if worker not in shuffle.participating_workers: continue exception = RuntimeError(f"Worker {worker} left during active {shuffle}") - self.erred_shuffles[shuffle_id] = exception self._fail_on_workers(shuffle, str(exception)) + self._clean_on_scheduler(shuffle_id, stimulus_id) - barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)] - if barrier_task.state == "memory": - for dt in barrier_task.dependents: - if worker not in dt.worker_restrictions: - continue - self._unset_restriction(dt) - recs.update({dt.key: "waiting"}) - # TODO: Do we need to handle other states? - - # If processing the transactions causes a task to get released, this - # removes the shuffle from self.states. Therefore, we must process them - # outside of the loop. 
- self.scheduler.transitions(recs, stimulus_id=stimulus_id) + for shuffle in self._archived_by_stimulus.get(stimulus_id, set()): + self._restart_shuffle(shuffle.id, scheduler, stimulus_id=stimulus_id) def transition( self, @@ -347,17 +396,25 @@ def transition( stimulus_id: str, **kwargs: Any, ) -> None: + """Clean up scheduler and worker state once a shuffle becomes inactive.""" if finish not in ("released", "forgotten"): return if not key.startswith("shuffle-barrier-"): return shuffle_id = id_from_key(key) - try: - shuffle = self.states[shuffle_id] - except KeyError: - return - self._fail_on_workers(shuffle, message=f"{shuffle} forgotten") - self._clean_on_scheduler(shuffle_id) + + if shuffle := self.active_shuffles.get(shuffle_id): + self._fail_on_workers(shuffle, message=f"{shuffle} forgotten") + self._clean_on_scheduler(shuffle_id, stimulus_id=stimulus_id) + + if finish == "forgotten": + shuffles = self._shuffles.pop(shuffle_id, set()) + for shuffle in shuffles: + if shuffle._archived_by: + archived = self._archived_by_stimulus[shuffle._archived_by] + archived.remove(shuffle) + if not archived: + del self._archived_by_stimulus[shuffle._archived_by] def _fail_on_workers(self, shuffle: ShuffleState, message: str) -> None: worker_msgs = { @@ -373,9 +430,12 @@ def _fail_on_workers(self, shuffle: ShuffleState, message: str) -> None: } self.scheduler.send_all({}, worker_msgs) - def _clean_on_scheduler(self, id: ShuffleId) -> None: - del self.states[id] - self.erred_shuffles.pop(id, None) + def _clean_on_scheduler(self, id: ShuffleId, stimulus_id: str | None) -> None: + shuffle = self.active_shuffles.pop(id) + if not shuffle._archived_by and stimulus_id: + shuffle._archived_by = stimulus_id + self._archived_by_stimulus[stimulus_id].add(shuffle) + with contextlib.suppress(KeyError): del self.heartbeats[id] @@ -384,9 +444,10 @@ def _clean_on_scheduler(self, id: ShuffleId) -> None: self._unset_restriction(dt) def restart(self, scheduler: Scheduler) -> None: - self.states.clear() + self.active_shuffles.clear() self.heartbeats.clear() - self.erred_shuffles.clear() + self._shuffles.clear() + self._archived_by_stimulus.clear() def get_worker_for_range_sharding( diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index ed9c22bb07..62bb13c90b 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -11,6 +11,7 @@ from distributed.exceptions import Reschedule from distributed.shuffle._arrow import check_dtype_support, check_minimal_arrow_version +from distributed.shuffle._exceptions import ShuffleClosedError logger = logging.getLogger("distributed.shuffle") if TYPE_CHECKING: @@ -69,6 +70,8 @@ def shuffle_transfer( column=column, parts_out=parts_out, ) + except ShuffleClosedError: + raise Reschedule() except Exception as e: raise RuntimeError(f"shuffle_transfer failed during shuffle {id}") from e @@ -82,6 +85,8 @@ def shuffle_unpack( ) except Reschedule as e: raise e + except ShuffleClosedError: + raise Reschedule() except Exception as e: raise RuntimeError(f"shuffle_unpack failed during shuffle {id}") from e diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index 8a3ab7ba72..0f2fcf415e 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -29,6 +29,7 @@ ) from distributed.shuffle._comms import CommShardsBuffer from distributed.shuffle._disk import DiskShardsBuffer +from distributed.shuffle._exceptions import ShuffleClosedError from distributed.shuffle._limiter 
import ResourceLimiter from distributed.shuffle._rechunk import ChunkedAxes, NDIndex, split_axes from distributed.shuffle._shuffle import ShuffleId, ShuffleType @@ -50,10 +51,6 @@ logger = logging.getLogger(__name__) -class ShuffleClosedError(RuntimeError): - pass - - class ShuffleRun(Generic[T_partition_id, T_partition_type]): def __init__( self, @@ -577,6 +574,7 @@ class ShuffleWorkerPlugin(WorkerPlugin): worker: Worker shuffles: dict[ShuffleId, ShuffleRun] _runs: set[ShuffleRun] + _runs_cleanup_condition: asyncio.Condition memory_limiter_comms: ResourceLimiter memory_limiter_disk: ResourceLimiter closed: bool @@ -592,6 +590,7 @@ def setup(self, worker: Worker) -> None: self.worker = worker self.shuffles = {} self._runs = set() + self._runs_cleanup_condition = asyncio.Condition() self.memory_limiter_comms = ResourceLimiter(parse_bytes("100 MiB")) self.memory_limiter_disk = ResourceLimiter(parse_bytes("1 GiB")) self.closed = False @@ -632,6 +631,12 @@ async def shuffle_inputs_done(self, shuffle_id: ShuffleId, run_id: int) -> None: shuffle = await self._get_shuffle_run(shuffle_id, run_id) await shuffle.inputs_done() + async def _close_shuffle_run(self, shuffle: ShuffleRun) -> None: + await shuffle.close() + async with self._runs_cleanup_condition: + self._runs.remove(shuffle) + self._runs_cleanup_condition.notify_all() + def shuffle_fail(self, shuffle_id: ShuffleId, run_id: int, message: str) -> None: """Fails the shuffle run with the message as exception and triggers cleanup. @@ -648,11 +653,9 @@ def shuffle_fail(self, shuffle_id: ShuffleId, run_id: int, message: str) -> None exception = RuntimeError(message) shuffle.fail(exception) - async def _(extension: ShuffleWorkerPlugin, shuffle: ShuffleRun) -> None: - await shuffle.close() - extension._runs.remove(shuffle) - - self.worker._ongoing_background_tasks.call_soon(_, self, shuffle) + self.worker._ongoing_background_tasks.call_soon( + self._close_shuffle_run, shuffle + ) def add_partition( self, @@ -726,6 +729,7 @@ async def _get_or_create_shuffle( self, shuffle_id: ShuffleId, type: ShuffleType, + key: str, **kwargs: Any, ) -> ShuffleRun: """Get or create a shuffle matching the ID and data spec. @@ -736,12 +740,15 @@ async def _get_or_create_shuffle( Unique identifier of the shuffle type: Type of the shuffle operation + key: + Task key triggering the function """ shuffle = self.shuffles.get(shuffle_id, None) if shuffle is None: shuffle = await self._refresh_shuffle( shuffle_id=shuffle_id, type=type, + key=key, kwargs=kwargs, ) @@ -763,6 +770,7 @@ async def _refresh_shuffle( self, shuffle_id: ShuffleId, type: ShuffleType, + key: str, kwargs: dict, ) -> ShuffleRun: ... 
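The hunks above replace the ad-hoc background cleanup with an asyncio.Condition (_runs_cleanup_condition) so that teardown can wait until every closed run has been removed from _runs. A stripped-down sketch of that pattern, using placeholder run objects and a stand-in close step rather than real ShuffleRun instances:

    # Minimal sketch of the Condition-based cleanup; RunTracker is a stand-in.
    import asyncio


    class RunTracker:
        def __init__(self):
            self._runs = set()
            self._runs_cleanup_condition = asyncio.Condition()

        async def _close_run(self, run):
            await asyncio.sleep(0)  # stand-in for the real run.close()
            async with self._runs_cleanup_condition:
                self._runs.remove(run)
                self._runs_cleanup_condition.notify_all()

        async def teardown(self):
            # Block until every background close has removed its run.
            async with self._runs_cleanup_condition:
                await self._runs_cleanup_condition.wait_for(lambda: not self._runs)


    async def main():
        tracker = RunTracker()
        tracker._runs.update({"run-1", "run-2"})
        closers = [
            asyncio.create_task(tracker._close_run(r)) for r in list(tracker._runs)
        ]
        await tracker.teardown()
        await asyncio.gather(*closers)
        print("all runs cleaned up")


    asyncio.run(main())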
@@ -771,8 +779,10 @@ async def _refresh_shuffle( self, shuffle_id: ShuffleId, type: ShuffleType | None = None, + key: str | None = None, kwargs: dict | None = None, ) -> ShuffleRun: + result: dict[str, Any] if type is None: result = await self.worker.scheduler.shuffle_get( id=shuffle_id, @@ -782,6 +792,7 @@ async def _refresh_shuffle( assert kwargs is not None result = await self.worker.scheduler.shuffle_get_or_create( id=shuffle_id, + key=key, type=type, spec={ "npartitions": kwargs["npartitions"], @@ -794,6 +805,7 @@ async def _refresh_shuffle( assert kwargs is not None result = await self.worker.scheduler.shuffle_get_or_create( id=shuffle_id, + key=key, type=type, spec=kwargs, worker=self.worker.address, @@ -812,67 +824,94 @@ async def _refresh_shuffle( return existing else: self.shuffles.pop(shuffle_id) - existing.fail(RuntimeError("Stale Shuffle")) + existing.fail( + RuntimeError("{existing!r} stale, expected run_id=={run_id}") + ) async def _( extension: ShuffleWorkerPlugin, shuffle: ShuffleRun ) -> None: await shuffle.close() - extension._runs.remove(shuffle) + async with extension._runs_cleanup_condition: + extension._runs.remove(shuffle) + extension._runs_cleanup_condition.notify_all() self.worker._ongoing_background_tasks.call_soon(_, self, existing) + + shuffle = self._create_shuffle_run(shuffle_id, result) + self.shuffles[shuffle_id] = shuffle + self._runs.add(shuffle) + return shuffle + + def _create_shuffle_run( + self, shuffle_id: ShuffleId, result: dict[str, Any] + ) -> ShuffleRun: shuffle: ShuffleRun if result["type"] == ShuffleType.DATAFRAME: - shuffle = DataFrameShuffleRun( - column=result["column"], - worker_for=result["worker_for"], - output_workers=result["output_workers"], - id=shuffle_id, - run_id=result["run_id"], - directory=os.path.join( - self.worker.local_directory, - f"shuffle-{shuffle_id}-{result['run_id']}", - ), - executor=self._executor, - local_address=self.worker.address, - rpc=self.worker.rpc, - scheduler=self.worker.scheduler, - memory_limiter_disk=self.memory_limiter_disk, - memory_limiter_comms=self.memory_limiter_comms, - ) + shuffle = self._create_dataframe_shuffle_run(shuffle_id, result) elif result["type"] == ShuffleType.ARRAY_RECHUNK: - shuffle = ArrayRechunkRun( - worker_for=result["worker_for"], - output_workers=result["output_workers"], - old=result["old"], - new=result["new"], - id=shuffle_id, - run_id=result["run_id"], - directory=os.path.join( - self.worker.local_directory, - f"shuffle-{shuffle_id}-{result['run_id']}", - ), - executor=self._executor, - local_address=self.worker.address, - rpc=self.worker.rpc, - scheduler=self.worker.scheduler, - memory_limiter_disk=self.memory_limiter_disk, - memory_limiter_comms=self.memory_limiter_comms, - ) + shuffle = self._create_array_rechunk_run(shuffle_id, result) else: # pragma: no cover raise TypeError(result["type"]) - self.shuffles[shuffle_id] = shuffle - self._runs.add(shuffle) return shuffle + def _create_dataframe_shuffle_run( + self, shuffle_id: ShuffleId, result: dict[str, Any] + ) -> DataFrameShuffleRun: + return DataFrameShuffleRun( + column=result["column"], + worker_for=result["worker_for"], + output_workers=result["output_workers"], + id=shuffle_id, + run_id=result["run_id"], + directory=os.path.join( + self.worker.local_directory, + f"shuffle-{shuffle_id}-{result['run_id']}", + ), + executor=self._executor, + local_address=self.worker.address, + rpc=self.worker.rpc, + scheduler=self.worker.scheduler, + memory_limiter_disk=self.memory_limiter_disk, + 
memory_limiter_comms=self.memory_limiter_comms, + ) + + def _create_array_rechunk_run( + self, shuffle_id: ShuffleId, result: dict[str, Any] + ) -> ArrayRechunkRun: + return ArrayRechunkRun( + worker_for=result["worker_for"], + output_workers=result["output_workers"], + old=result["old"], + new=result["new"], + id=shuffle_id, + run_id=result["run_id"], + directory=os.path.join( + self.worker.local_directory, + f"shuffle-{shuffle_id}-{result['run_id']}", + ), + executor=self._executor, + local_address=self.worker.address, + rpc=self.worker.rpc, + scheduler=self.worker.scheduler, + memory_limiter_disk=self.memory_limiter_disk, + memory_limiter_comms=self.memory_limiter_comms, + ) + async def teardown(self, worker: Worker) -> None: assert not self.closed self.closed = True + while self.shuffles: _, shuffle = self.shuffles.popitem() - await shuffle.close() - self._runs.remove(shuffle) + self.worker._ongoing_background_tasks.call_soon( + self._close_shuffle_run, shuffle + ) + + async with self._runs_cleanup_condition: + await self._runs_cleanup_condition.wait_for(lambda: not self._runs) + try: self._executor.shutdown(cancel_futures=True) except Exception: # pragma: no cover @@ -904,11 +943,13 @@ def get_or_create_shuffle( type: ShuffleType, **kwargs: Any, ) -> ShuffleRun: + key = thread_state.key return sync( self.worker.loop, self._get_or_create_shuffle, shuffle_id, type, + key, **kwargs, ) @@ -920,7 +961,7 @@ def get_output_partition( meta: pd.DataFrame | None = None, ) -> Any: """ - Task: Retrieve a shuffled output partition from the ShuffleExtension. + Task: Retrieve a shuffled output partition from the ShuffleWorkerPlugin. Calling this for a ``shuffle_id`` which is unknown or incomplete is an error. """ diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 664b31d011..e86f06eb83 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -9,7 +9,6 @@ from collections import defaultdict from collections.abc import Mapping from concurrent.futures import ThreadPoolExecutor -from contextlib import AsyncExitStack from itertools import count from typing import Any from unittest import mock @@ -24,8 +23,7 @@ from dask.utils import stringify from distributed.client import Client -from distributed.diagnostics.plugin import SchedulerPlugin -from distributed.scheduler import Scheduler +from distributed.scheduler import KilledWorker, Scheduler from distributed.scheduler import TaskState as SchedulerTaskState from distributed.shuffle._arrow import serialize_table from distributed.shuffle._limiter import ResourceLimiter @@ -49,6 +47,7 @@ ) from distributed.utils import Deadline from distributed.utils_test import ( + async_poll_for, cluster, gen_cluster, gen_test, @@ -72,7 +71,7 @@ async def check_worker_cleanup( worker: Worker, closed: bool = False, interval: float = 0.01, - timeout: int | None = None, + timeout: int | None = 5, ) -> None: """Assert that the worker has no shuffle state""" deadline = Deadline.after(timeout) @@ -91,15 +90,17 @@ async def check_worker_cleanup( async def check_scheduler_cleanup( - scheduler: Scheduler, interval: float = 0.01, timeout: int | None = None + scheduler: Scheduler, interval: float = 0.01, timeout: int | None = 5 ) -> None: """Assert that the scheduler has no shuffle state""" deadline = Deadline.after(timeout) plugin = scheduler.plugins["shuffle"] assert isinstance(plugin, ShuffleSchedulerPlugin) - while plugin.states and not deadline.expired: + while plugin._shuffles 
and not deadline.expired: await asyncio.sleep(interval) - assert not plugin.states + assert not plugin.active_shuffles + assert not plugin._shuffles, scheduler.tasks + assert not plugin._archived_by_stimulus assert not plugin.heartbeats @@ -300,12 +301,13 @@ async def test_closed_worker_during_transfer(c, s, a, b): freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) await b.close() - with pytest.raises(RuntimeError): - out = await c.compute(out) + x = await x + y = await y + assert x == y await c.close() await check_worker_cleanup(a) @@ -313,6 +315,85 @@ async def test_closed_worker_during_transfer(c, s, a, b): await check_scheduler_cleanup(s) +@gen_cluster( + client=True, + nthreads=[("", 1)] * 2, + config={"distributed.scheduler.allowed-failures": 0}, +) +async def test_restarting_during_transfer_raises_killed_worker(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = c.compute(out.x.size) + await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) + await b.close() + + with pytest.raises(KilledWorker): + await out + + await c.close() + await check_worker_cleanup(a) + await check_worker_cleanup(b, closed=True) + await check_scheduler_cleanup(s) + + +class BlockedGetOrCreateWorkerPlugin(ShuffleWorkerPlugin): + def setup(self, worker: Worker) -> None: + super().setup(worker) + self.in_get_or_create = asyncio.Event() + self.block_get_or_create = asyncio.Event() + + async def _get_or_create_shuffle(self, *args, **kwargs): + self.in_get_or_create.set() + await self.block_get_or_create.wait() + return await super()._get_or_create_shuffle(*args, **kwargs) + + +@gen_cluster( + client=True, + nthreads=[("", 1)] * 2, + config={"distributed.scheduler.allowed-failures": 0}, +) +async def test_get_or_create_from_dangling_transfer(c, s, a, b): + await c.register_worker_plugin(BlockedGetOrCreateWorkerPlugin(), name="shuffle") + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = c.compute(out.x.size) + + shuffle_extA = a.plugins["shuffle"] + shuffle_extB = b.plugins["shuffle"] + shuffle_extB.block_get_or_create.set() + + await shuffle_extA.in_get_or_create.wait() + await b.close() + await async_poll_for( + lambda: not any(ws.processing for ws in s.workers.values()), timeout=5 + ) + + with pytest.raises(KilledWorker): + await out + + await async_poll_for(lambda: not s.plugins["shuffle"].active_shuffles, timeout=5) + assert a.state.tasks + shuffle_extA.block_get_or_create.set() + await async_poll_for(lambda: not a.state.tasks, timeout=10) + + assert not s.plugins["shuffle"].active_shuffles + await check_worker_cleanup(a) + await check_worker_cleanup(b, closed=True) + await c.close() + await check_scheduler_cleanup(s) + + @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)]) async def test_crashed_worker_during_transfer(c, s, a): @@ -325,21 +406,21 @@ async def test_crashed_worker_during_transfer(c, s, a): freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) await wait_until_worker_has_tasks( "shuffle-transfer", killed_worker_address, 1, s ) await n.process.process.kill() - with 
pytest.raises(RuntimeError): - out = await c.compute(out) + x = await x + y = await y + assert x == y await c.close() await check_worker_cleanup(a) await check_scheduler_cleanup(s) -# TODO: Deduplicate instead of failing: distributed#7324 @gen_cluster(client=True, nthreads=[("", 1)] * 2) async def test_closed_input_only_worker_during_transfer(c, s, a, b): def mock_get_worker_for_range_sharding( @@ -358,12 +439,13 @@ def mock_get_worker_for_range_sharding( freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b, 0.001) await b.close() - with pytest.raises(RuntimeError): - out = await c.compute(out) + x = await x + y = await y + assert x == y await c.close() await check_worker_cleanup(a) @@ -371,7 +453,6 @@ def mock_get_worker_for_range_sharding( await check_scheduler_cleanup(s) -# TODO: Deduplicate instead of failing: distributed#7324 @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)], clean_kwargs={"processes": False}) async def test_crashed_input_only_worker_during_transfer(c, s, a): @@ -393,14 +474,15 @@ def mock_mock_get_worker_for_range_sharding( freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) await wait_until_worker_has_tasks( "shuffle-transfer", n.worker_address, 1, s ) await n.process.process.kill() - with pytest.raises(RuntimeError): - out = await c.compute(out) + x = await x + y = await y + assert x == y await c.close() await check_worker_cleanup(a) @@ -457,7 +539,7 @@ async def test_closed_worker_during_barrier(c, s, a, b): freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) shuffle_id = await wait_until_new_shuffle_is_initialized(s) key = barrier_key(shuffle_id) await wait_for_state(key, "processing", s) @@ -478,9 +560,72 @@ async def test_closed_worker_during_barrier(c, s, a, b): await close_worker.close() alive_shuffle.block_inputs_done.set() + alive_shuffles = alive_worker.extensions["shuffle"].shuffles - with pytest.raises(RuntimeError): - out = await c.compute(out) + def shuffle_restarted(): + try: + return alive_shuffles[shuffle_id].run_id > alive_shuffle.run_id + except KeyError: + return False + + await async_poll_for( + shuffle_restarted, + timeout=5, + ) + restarted_shuffle = alive_shuffles[shuffle_id] + restarted_shuffle.block_inputs_done.set() + + x = await x + y = await y + assert x == y + + await c.close() + await check_worker_cleanup(close_worker, closed=True) + await check_worker_cleanup(alive_worker) + await check_scheduler_cleanup(s) + + +@mock.patch( + "distributed.shuffle._worker_plugin.DataFrameShuffleRun", + BlockedInputsDoneShuffle, +) +@gen_cluster( + client=True, + nthreads=[("", 1)] * 2, + config={"distributed.scheduler.allowed-failures": 0}, +) +async def test_restarting_during_barrier_raises_killed_worker(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = c.compute(out.x.size) + shuffle_id = await wait_until_new_shuffle_is_initialized(s) + key = barrier_key(shuffle_id) + await wait_for_state(key, "processing", s) + shuffleA = get_shuffle_run_from_worker(shuffle_id, a) + shuffleB = get_shuffle_run_from_worker(shuffle_id, b) + await shuffleA.in_inputs_done.wait() + await shuffleB.in_inputs_done.wait() 
+ + ts = s.tasks[key] + processing_worker = a if ts.processing_on.address == a.address else b + if processing_worker == a: + close_worker, alive_worker = a, b + alive_shuffle = shuffleB + + else: + close_worker, alive_worker = b, a + alive_shuffle = shuffleA + await close_worker.close() + + with pytest.raises(KilledWorker): + await out + + alive_shuffle.block_inputs_done.set() await c.close() await check_worker_cleanup(close_worker, closed=True) @@ -501,7 +646,7 @@ async def test_closed_other_worker_during_barrier(c, s, a, b): freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) shuffle_id = await wait_until_new_shuffle_is_initialized(s) key = barrier_key(shuffle_id) @@ -524,9 +669,24 @@ async def test_closed_other_worker_during_barrier(c, s, a, b): await close_worker.close() alive_shuffle.block_inputs_done.set() + alive_shuffles = alive_worker.extensions["shuffle"].shuffles - with pytest.raises(RuntimeError, match="shuffle_barrier failed"): - out = await c.compute(out) + def shuffle_restarted(): + try: + return alive_shuffles[shuffle_id].run_id > alive_shuffle.run_id + except KeyError: + return False + + await async_poll_for( + shuffle_restarted, + timeout=5, + ) + restarted_shuffle = alive_shuffles[shuffle_id] + restarted_shuffle.block_inputs_done.set() + + x = await x + y = await y + assert x == y await c.close() await check_worker_cleanup(close_worker, closed=True) @@ -549,20 +709,34 @@ async def test_crashed_other_worker_during_barrier(c, s, a): freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) shuffle_id = await wait_until_new_shuffle_is_initialized(s) key = barrier_key(shuffle_id) # Ensure that barrier is not executed on the nanny s.set_restrictions({key: {a.address}}) await wait_for_state(key, "processing", s, interval=0) - + shuffles = a.extensions["shuffle"].shuffles shuffle = get_shuffle_run_from_worker(shuffle_id, a) await shuffle.in_inputs_done.wait() await n.process.process.kill() shuffle.block_inputs_done.set() - with pytest.raises(RuntimeError, match="shuffle"): - out = await c.compute(out) + def shuffle_restarted(): + try: + return shuffles[shuffle_id].run_id > shuffle.run_id + except KeyError: + return False + + await async_poll_for( + shuffle_restarted, + timeout=5, + ) + restarted_shuffle = get_shuffle_run_from_worker(shuffle_id, a) + restarted_shuffle.block_inputs_done.set() + + x = await x + y = await y + assert x == y await c.close() await check_worker_cleanup(a) @@ -578,12 +752,39 @@ async def test_closed_worker_during_unpack(c, s, a, b): freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) await wait_for_tasks_in_state("shuffle-p2p", "memory", 1, b) await b.close() - with pytest.raises(RuntimeError): - out = await c.compute(out) + x = await x + y = await y + assert x == y + + await c.close() + await check_worker_cleanup(a) + await check_worker_cleanup(b, closed=True) + await check_scheduler_cleanup(s) + + +@gen_cluster( + client=True, + nthreads=[("", 1)] * 2, + config={"distributed.scheduler.allowed-failures": 0}, +) +async def test_restarting_during_unpack_raises_killed_worker(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = c.compute(out.x.size) + await 
wait_for_tasks_in_state("shuffle-p2p", "memory", 1, b) + await b.close() + + with pytest.raises(KilledWorker): + await out await c.close() await check_worker_cleanup(a) @@ -602,14 +803,15 @@ async def test_crashed_worker_during_unpack(c, s, a): dtypes={"x": float, "y": float}, freq="10 s", ) + x = await c.compute(df.x.size) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + y = c.compute(out.x.size) + await wait_until_worker_has_tasks("shuffle-p2p", killed_worker_address, 1, s) await n.process.process.kill() - with pytest.raises( - RuntimeError, - ): - out = await c.compute(out) + + y = await y + assert x == y await c.close() await check_worker_cleanup(a) @@ -857,9 +1059,9 @@ async def test_clean_after_forgotten_early(c, s, a, b): await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, a) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) del out - await check_worker_cleanup(a, timeout=2) - await check_worker_cleanup(b, timeout=2) - await check_scheduler_cleanup(s, timeout=2) + await check_worker_cleanup(a) + await check_worker_cleanup(b) + await check_scheduler_cleanup(s) @gen_cluster(client=True) @@ -910,9 +1112,9 @@ async def test_repeat_shuffle_instance(c, s, a, b, wait_until_forgotten): await c.compute(out) - await check_worker_cleanup(a, timeout=2) - await check_worker_cleanup(b, timeout=2) - await check_scheduler_cleanup(s, timeout=2) + await check_worker_cleanup(a) + await check_worker_cleanup(b) + await check_scheduler_cleanup(s) @pytest.mark.parametrize("wait_until_forgotten", [True, False]) @@ -939,9 +1141,9 @@ async def test_repeat_shuffle_operation(c, s, a, b, wait_until_forgotten): await c.compute(dd.shuffle.shuffle(df, "x", shuffle="p2p")) - await check_worker_cleanup(a, timeout=2) - await check_worker_cleanup(b, timeout=2) - await check_scheduler_cleanup(s, timeout=2) + await check_worker_cleanup(a) + await check_worker_cleanup(b) + await check_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -1056,7 +1258,7 @@ async def test_new_worker(c, s, a, b): ) shuffled = dd.shuffle.shuffle(df, "x", shuffle="p2p") persisted = shuffled.persist() - while not s.plugins["shuffle"].states: + while not s.plugins["shuffle"].active_shuffles: await asyncio.sleep(0.001) async with Worker(s.address) as w: @@ -1131,12 +1333,12 @@ async def test_delete_some_results(c, s, a, b): while not s.tasks or not any(ts.state == "memory" for ts in s.tasks.values()): await asyncio.sleep(0.01) - x = x.partitions[: x.npartitions // 2].persist() + x = x.partitions[: x.npartitions // 2] + x = await c.compute(x.size) - await c.compute(x.size) - del x await check_worker_cleanup(a) await check_worker_cleanup(b) + del x await check_scheduler_cleanup(s) @@ -1515,9 +1717,9 @@ async def test_deduplicate_stale_transfer(c, s, a, b, wait_until_forgotten): y = await c.compute(df.x.size) assert x == y - await check_worker_cleanup(a, timeout=2) - await check_worker_cleanup(b, timeout=2) - await check_scheduler_cleanup(s, timeout=2) + await check_worker_cleanup(a) + await check_worker_cleanup(b) + await check_scheduler_cleanup(s) class BlockedBarrierShuffleWorkerPlugin(ShuffleWorkerPlugin): @@ -1571,9 +1773,9 @@ async def test_handle_stale_barrier(c, s, a, b, wait_until_forgotten): y = await y assert x == y - await check_worker_cleanup(a, timeout=2) - await check_worker_cleanup(b, timeout=2) - await check_scheduler_cleanup(s, timeout=2) + await check_worker_cleanup(a) + await check_worker_cleanup(b) + await check_scheduler_cleanup(s) @gen_cluster(client=True, 
nthreads=[("", 1)]) @@ -1643,9 +1845,29 @@ async def test_shuffle_run_consistency(c, s, a): worker_plugin.block_barrier.set() await out del out + while s.tasks: + await asyncio.sleep(0) + worker_plugin.block_barrier.clear() - await check_worker_cleanup(a, timeout=2) - await check_scheduler_cleanup(s, timeout=2) + out = dd.shuffle.shuffle(df, "y", shuffle="p2p") + out = out.persist() + independent_shuffle_id = await wait_until_new_shuffle_is_initialized(s) + assert shuffle_id != independent_shuffle_id + + independent_shuffle_dict = scheduler_ext.get( + independent_shuffle_id, a.worker_address + ) + + # Check invariant that the new run ID is larger than the previous + # for independent shuffles + assert new_shuffle_dict["run_id"] < independent_shuffle_dict["run_id"] + + worker_plugin.block_barrier.set() + await out + del out + + await check_worker_cleanup(a) + await check_scheduler_cleanup(s) class BlockedShuffleAccessAndFailWorkerPlugin(ShuffleWorkerPlugin): @@ -1748,94 +1970,6 @@ async def test_replace_stale_shuffle(c, s, a, b): await check_scheduler_cleanup(s) -class BlockedRemoveWorkerSchedulerPlugin(SchedulerPlugin): - def __init__(self, scheduler: Scheduler, *args: Any, **kwargs: Any): - self.scheduler = scheduler - super().__init__(*args, **kwargs) - self.in_remove_worker = asyncio.Event() - self.block_remove_worker = asyncio.Event() - self.scheduler.add_plugin(self, name="blocking") - - async def remove_worker(self, *args: Any, **kwargs: Any) -> None: - self.in_remove_worker.set() - await self.block_remove_worker.wait() - - -class BlockedBarrierSchedulerPlugin(ShuffleSchedulerPlugin): - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - self.in_barrier = asyncio.Event() - self.block_barrier = asyncio.Event() - - async def barrier(self, *args: Any, **kwargs: Any) -> None: - self.in_barrier.set() - await self.block_barrier.wait() - await super().barrier(*args, **kwargs) - - -@gen_cluster( - client=True, - nthreads=[], - scheduler_kwargs={ - "extensions": { - "blocking": BlockedRemoveWorkerSchedulerPlugin, - "shuffle": BlockedBarrierSchedulerPlugin, - } - }, -) -async def test_closed_worker_returns_before_barrier(c, s): - async with AsyncExitStack() as stack: - workers = [await stack.enter_async_context(Worker(s.address)) for _ in range(2)] - - df = dask.datasets.timeseries( - start="2000-01-01", - end="2000-01-10", - dtypes={"x": float, "y": float}, - freq="10 s", - ) - out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() - shuffle_id = await wait_until_new_shuffle_is_initialized(s) - key = barrier_key(shuffle_id) - await wait_for_state(key, "processing", s) - scheduler_plugin = s.plugins["shuffle"] - await scheduler_plugin.in_barrier.wait() - - flushes = [ - get_shuffle_run_from_worker(shuffle_id, w)._flush_comm() for w in workers - ] - await asyncio.gather(*flushes) - - ts = s.tasks[key] - to_close = None - for worker in workers: - if ts.processing_on.address != worker.address: - to_close = worker - break - assert to_close - closed_port = to_close.port - await to_close.close() - - blocking_plugin = s.plugins["blocking"] - assert blocking_plugin.in_remove_worker.is_set() - - workers.append( - await stack.enter_async_context(Worker(s.address, port=closed_port)) - ) - - scheduler_plugin.block_barrier.set() - - with pytest.raises( - RuntimeError, match=f"shuffle_barrier failed .* {shuffle_id}" - ): - await c.compute(out.x.size) - - blocking_plugin.block_remove_worker.set() - await c.close() - await 
asyncio.gather(*[check_worker_cleanup(w) for w in workers]) - await check_scheduler_cleanup(s) - - @gen_cluster(client=True) async def test_handle_null_partitions_p2p_shuffling(c, s, *workers): data = [ From 145c13aea1b4214a8bb1378f581104492b1be0c7 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 25 Jul 2023 21:34:11 +0100 Subject: [PATCH 17/21] Update gpuCI `RAPIDS_VER` to `23.10` (#8033) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- continuous_integration/gpuci/axis.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/continuous_integration/gpuci/axis.yaml b/continuous_integration/gpuci/axis.yaml index 97a27084e4..ea1e6464a9 100644 --- a/continuous_integration/gpuci/axis.yaml +++ b/continuous_integration/gpuci/axis.yaml @@ -9,6 +9,6 @@ LINUX_VER: - ubuntu18.04 RAPIDS_VER: -- "23.08" +- "23.10" excludes: From 9eb672840cc76f8718227fc1d2973cc6f233b9cd Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 27 Jul 2023 10:36:57 +0100 Subject: [PATCH 18/21] restore support for yield unsafe Client context managers and deprecate that support (#7987) --- distributed/client.py | 24 ++++++++++++++++++++++-- distributed/tests/test_client.py | 27 +++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index f01eb31190..290feabd01 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1502,7 +1502,17 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, traceback): if self._previous_as_current: - _current_client.reset(self._previous_as_current) + try: + _current_client.reset(self._previous_as_current) + except ValueError as e: + if not e.args[0].endswith(" was created in a different Context"): + raise # pragma: nocover + warnings.warn( + "It is deprecated to enter and exit the Client context " + "manager from different tasks", + DeprecationWarning, + stacklevel=2, + ) await self._close( # if we're handling an exception, we assume that it's more # important to deliver that exception than shutdown gracefully. 
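The ValueError handled in this hunk comes from contextvars rather than from distributed itself: a Token may only be reset in the context that created it, so exiting the client from a different task (or thread) than the one that entered it fails with a message ending in "was created in a different Context". A self-contained reproduction of that behaviour, with an illustrative variable name standing in for _current_client:

    # Standalone illustration of the ContextVar behaviour handled above.
    import asyncio
    import contextvars

    _current = contextvars.ContextVar("_current")


    async def main():
        token = None

        async def enter():
            nonlocal token
            token = _current.set("client")  # token belongs to this task's context

        async def leave():
            try:
                _current.reset(token)
            except ValueError as e:
                print(e)  # "<Token ...> was created in a different Context"

        await asyncio.create_task(enter())
        await asyncio.create_task(leave())


    asyncio.run(main())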
@@ -1512,7 +1522,17 @@ async def __aexit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback): if self._previous_as_current: - _current_client.reset(self._previous_as_current) + try: + _current_client.reset(self._previous_as_current) + except ValueError as e: + if not e.args[0].endswith(" was created in a different Context"): + raise # pragma: nocover + warnings.warn( + "It is deprecated to enter and exit the Client context " + "manager from different threads", + DeprecationWarning, + stacklevel=2, + ) self.close() def __del__(self): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 2ed731a160..fa60de1d88 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1274,6 +1274,33 @@ async def client_2(): await asyncio.gather(client_1(), client_2()) +@gen_cluster(client=False, nthreads=[]) +async def test_context_manager_used_from_different_tasks(s): + c = Client(s.address, asynchronous=True) + await asyncio.create_task(c.__aenter__()) + with pytest.warns( + DeprecationWarning, + match=r"It is deprecated to enter and exit the Client context manager " + "from different tasks", + ): + await asyncio.create_task(c.__aexit__(None, None, None)) + + +def test_context_manager_used_from_different_threads(s, loop): + c = Client(s["address"]) + with ( + concurrent.futures.ThreadPoolExecutor(1) as tp1, + concurrent.futures.ThreadPoolExecutor(1) as tp2, + ): + tp1.submit(c.__enter__).result() + with pytest.warns( + DeprecationWarning, + match=r"It is deprecated to enter and exit the Client context manager " + "from different threads", + ): + tp2.submit(c.__exit__, None, None, None).result() + + def test_global_clients(loop): assert _get_global_client() is None with pytest.raises( From 2751741721d3f2bee636b4281952ed0c9405cfeb Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Fri, 28 Jul 2023 11:12:46 +0200 Subject: [PATCH 19/21] Exclude comm handshake from connect timeout (#7698) Co-authored-by: Thomas Grainger --- distributed/comm/core.py | 30 +++-------------- distributed/comm/tests/test_comms.py | 9 +++-- distributed/tests/test_client.py | 16 --------- distributed/tests/test_scheduler.py | 49 +++++++++++++++------------- distributed/tests/test_utils_test.py | 13 ++------ distributed/worker.py | 3 +- 6 files changed, 39 insertions(+), 81 deletions(-) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index fa3b8cb52a..c9c4c880f2 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -7,7 +7,6 @@ import sys import weakref from abc import ABC, abstractmethod -from contextlib import suppress from typing import Any, ClassVar import dask @@ -264,20 +263,8 @@ async def on_connection( ) -> None: local_info = {**comm.handshake_info(), **(handshake_overrides or {})} - timeout = dask.config.get("distributed.comm.timeouts.connect") - timeout = parse_timedelta(timeout, default="seconds") - try: - # Timeout is to ensure that we'll terminate connections eventually. - # Connector side will employ smaller timeouts and we should only - # reach this if the comm is dead anyhow. 
- await wait_for(comm.write(local_info), timeout=timeout) - handshake = await wait_for(comm.read(), timeout=timeout) - # This would be better, but connections leak if worker is closed quickly - # write, handshake = await asyncio.gather(comm.write(local_info), comm.read()) - except Exception as e: - with suppress(Exception): - await comm.close() - raise CommClosedError(f"Comm {comm!r} closed.") from e + await comm.write(local_info) + handshake = await comm.read() comm.remote_info = handshake comm.remote_info["address"] = comm.peer_address @@ -386,17 +373,8 @@ def time_left(): **comm.handshake_info(), **(handshake_overrides or {}), } - try: - # This would be better, but connections leak if worker is closed quickly - # write, handshake = await asyncio.gather(comm.write(local_info), comm.read()) - handshake = await wait_for(comm.read(), time_left()) - await wait_for(comm.write(local_info), time_left()) - except Exception as exc: - with suppress(Exception): - await comm.close() - raise OSError( - f"Timed out during handshake while connecting to {addr} after {timeout} s" - ) from exc + await comm.write(local_info) + handshake = await comm.read() comm.remote_info = handshake comm.remote_info["address"] = comm._peer_addr diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 7c6bb5476e..b7fc94b96b 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -965,6 +965,7 @@ class UnreliableBackend(tcp.TCPBackend): listener.stop() +@pytest.mark.slow @gen_test() async def test_handshake_slow_comm(tcp, monkeypatch): class SlowComm(tcp.TCP): @@ -999,11 +1000,9 @@ def get_connector(self): import dask - with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}): - with pytest.raises( - IOError, match="Timed out during handshake while connecting to" - ): - await connect(listener.contact_address) + # The connect itself is fast. 
Only the handshake is slow + with dask.config.set({"distributed.comm.timeouts.connect": "500ms"}): + await connect(listener.contact_address) finally: listener.stop() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index fa60de1d88..5a9c5e8014 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -6011,22 +6011,6 @@ async def test_client_timeout_2(): assert stop - start < 1 -@gen_test() -async def test_client_active_bad_port(): - import tornado.httpserver - import tornado.web - - application = tornado.web.Application([(r"/", tornado.web.RequestHandler)]) - http_server = tornado.httpserver.HTTPServer(application) - http_server.listen(8080) - with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}): - c = Client("127.0.0.1:8080", asynchronous=True) - with pytest.raises((TimeoutError, IOError)): - async with c: - pass - http_server.stop() - - @pytest.mark.parametrize("direct", [True, False]) @gen_cluster(client=True, client_kwargs={"serializers": ["dask", "msgpack"]}) async def test_turn_off_pickle(c, s, a, b, direct): diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 4c0728f68d..de33ef542b 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1063,30 +1063,35 @@ async def kill(self, *, timeout, reason=None): @pytest.mark.slow @gen_cluster(client=True, Worker=SlowKillNanny, nthreads=[("", 1)] * 2) async def test_restart_nanny_timeout_exceeded(c, s, a, b): - f = c.submit(div, 1, 0) - fr = c.submit(inc, 1, resources={"FOO": 1}) - await wait(f) - assert s.erred_tasks - assert s.computations - assert s.unrunnable - assert s.tasks - - with pytest.raises( - TimeoutError, match=r"2/2 nanny worker\(s\) did not shut down within 1s" - ): - await c.restart(timeout="1s") - assert a.kill_called.is_set() - assert b.kill_called.is_set() + try: + f = c.submit(div, 1, 0) + fr = c.submit(inc, 1, resources={"FOO": 1}) + await wait(f) + assert s.erred_tasks + assert s.computations + assert s.unrunnable + assert s.tasks - assert not s.workers - assert not s.erred_tasks - assert not s.computations - assert not s.unrunnable - assert not s.tasks + with pytest.raises( + TimeoutError, match=r"2/2 nanny worker\(s\) did not shut down within 1s" + ): + await c.restart(timeout="1s") + assert a.kill_called.is_set() + assert b.kill_called.is_set() + + assert not s.workers + assert not s.erred_tasks + assert not s.computations + assert not s.unrunnable + assert not s.tasks + + assert not c.futures + assert f.status == "cancelled" + assert fr.status == "cancelled" + finally: + a.kill_proceed.set() + b.kill_proceed.set() - assert not c.futures - assert f.status == "cancelled" - assert fr.status == "cancelled" @gen_cluster(client=True, nthreads=[("", 1)] * 2) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 4438866477..0f8cb02ae0 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -599,16 +599,9 @@ async def test_dump_cluster_state_unresponsive_local_worker(s, a, b, tmp_path): @pytest.mark.slow -@gen_cluster( - client=True, - Worker=Nanny, - config={"distributed.comm.timeouts.connect": "600ms"}, -) +@gen_cluster(client=True, Worker=Nanny) async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmp_path): - clog_fut = asyncio.create_task( - c.run(lambda dask_scheduler: dask_scheduler.stop(), workers=[a.worker_address]) - ) - await asyncio.sleep(0.2) + await c.run(lambda 
dask_worker: dask_worker.stop(), workers=[a.worker_address]) await dump_cluster_state(s, [a, b], str(tmp_path), "dump") with open(f"{tmp_path}/dump.yaml") as fh: @@ -620,8 +613,6 @@ async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmp_path): "OSError('Timed out trying to connect to" ) - clog_fut.cancel() - # Note: WINDOWS constant doesn't work with `mypy --platform win32` if sys.platform == "win32": diff --git a/distributed/worker.py b/distributed/worker.py index fe52f7960b..391002cd5a 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1536,6 +1536,8 @@ async def close( # type: ignore for pc in self.periodic_callbacks.values(): pc.stop() + self.stop() + # Cancel async instructions await BaseWorker.close(self, timeout=timeout) @@ -1638,7 +1640,6 @@ def _close(executor, wait): executor=executor, wait=executor_wait ) # Just run it directly - self.stop() await self.rpc.close() self.status = Status.closed From d6758bdbef9bcd7087a3f4f9e69cd88f101b6b7c Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Fri, 28 Jul 2023 12:18:08 +0200 Subject: [PATCH 20/21] Fix linting (#8046) --- distributed/tests/test_scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index de33ef542b..d0703361d3 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1093,7 +1093,6 @@ async def test_restart_nanny_timeout_exceeded(c, s, a, b): b.kill_proceed.set() - @gen_cluster(client=True, nthreads=[("", 1)] * 2) async def test_restart_not_all_workers_return(c, s, a, b): with pytest.raises(TimeoutError, match="Waited for 2 worker"): From 9d9702e0ccc6a9aba8c7332799586298223286e4 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 28 Jul 2023 10:27:54 -0500 Subject: [PATCH 21/21] Use queued tasks in adaptive target (#8037) --- distributed/deploy/tests/test_adaptive.py | 25 ++++++++++++++++++++++- distributed/scheduler.py | 16 ++++++++++++--- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index fcc15ee0e9..3aae49f346 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -19,7 +19,7 @@ ) from distributed.compatibility import LINUX, MACOS, WINDOWS from distributed.metrics import time -from distributed.utils_test import async_poll_for, gen_test, slowinc +from distributed.utils_test import async_poll_for, gen_cluster, gen_test, slowinc def test_adaptive_local_cluster(loop): @@ -484,3 +484,26 @@ async def test_adaptive_stopped(): pc = instance.periodic_callback await async_poll_for(lambda: not pc.is_running(), timeout=5) + + +@pytest.mark.parametrize("saturation", [1, float("inf")]) +@gen_cluster( + client=True, + nthreads=[], + config={ + "distributed.scheduler.default-task-durations": {"slowinc": 1000}, + }, +) +async def test_scale_up_large_tasks(c, s, saturation): + s.WORKER_SATURATION = saturation + futures = c.map(slowinc, range(10)) + while not s.tasks: + await asyncio.sleep(0.001) + + assert s.adaptive_target() == 10 + + more_futures = c.map(slowinc, range(200)) + while len(s.tasks) != 200: + await asyncio.sleep(0.001) + + assert s.adaptive_target() == 200 diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ca612cd9ba..4649dc205d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -37,6 +37,7 @@ import psutil from sortedcontainers import SortedDict, SortedSet from tlz import 
( + concat, first, groupby, merge, @@ -45,6 +46,7 @@ partition, pluck, second, + take, valmap, ) from tornado.ioloop import IOLoop @@ -8061,15 +8063,23 @@ def adaptive_target(self, target_duration=None): target_duration = parse_timedelta(target_duration) # CPU + queued = take(100, concat([self.queued, self.unrunnable])) + queued_occupancy = 0 + for ts in queued: + if ts.prefix.duration_average == -1: + queued_occupancy += self.UNKNOWN_TASK_DURATION + else: + queued_occupancy += ts.prefix.duration_average + + if len(self.queued) + len(self.unrunnable) > 100: + queued_occupancy *= (len(self.queued) + len(self.unrunnable)) / 100 - # TODO consider any user-specified default task durations for queued tasks - queued_occupancy = len(self.queued) * self.UNKNOWN_TASK_DURATION cpu = math.ceil( (self.total_occupancy + queued_occupancy) / target_duration ) # TODO: threads per worker # Avoid a few long tasks from asking for many cores - tasks_ready = len(self.queued) + tasks_ready = len(self.queued) + len(self.unrunnable) for ws in self.workers.values(): tasks_ready += len(ws.processing)
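The estimate added above samples at most 100 queued or unrunnable tasks, sums their prefix duration averages (falling back to the unknown-task duration when no average exists yet), extrapolates linearly to the full queue, and divides by the target duration; the resulting CPU count is still bounded by the number of ready tasks further down, which is why the new test expects a target of 200 for 200 long-running tasks. A self-contained sketch of the arithmetic with illustrative constants (the real scheduler reads both values from its configuration):

    # Standalone sketch of the queued-occupancy estimate; constants are assumed.
    import math
    from itertools import islice

    UNKNOWN_TASK_DURATION = 0.5  # assumed fallback for tasks with no average yet
    TARGET_DURATION = 5.0        # assumed adaptive target duration in seconds


    def estimated_cpu(durations, n_queued, total_occupancy=0.0):
        # Sample at most 100 per-task duration estimates...
        sample = islice(durations, 100)
        occupancy = sum(d if d != -1 else UNKNOWN_TASK_DURATION for d in sample)
        # ...then extrapolate linearly to the whole queue.
        if n_queued > 100:
            occupancy *= n_queued / 100
        return math.ceil((total_occupancy + occupancy) / TARGET_DURATION)


    # 200 queued tasks whose prefix has a known 1 s average duration:
    print(estimated_cpu(iter([1.0] * 200), n_queued=200))  # ceil(200 / 5) == 40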