Skip to content

Commit

Permalink
Merge branch 'main' of github.com:dask/distributed into rootish-many-tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin committed Jul 28, 2023
2 parents 8aaf5e5 + 9d9702e commit e3d393b
Show file tree
Hide file tree
Showing 39 changed files with 1,240 additions and 484 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-report.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
mv test_report.html test_short_report.html deploy/
- name: Deploy 🚀
uses: JamesIves/github-pages-deploy-action@v4.4.2
uses: JamesIves/github-pages-deploy-action@v4.4.3
with:
branch: gh-pages
folder: deploy
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.10.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies:
- pre-commit
- prometheus_client
- psutil
- pyarrow=7
- pyarrow>=7
- pytest
- pytest-cov
- pytest-faulthandler
Expand Down
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.11.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies:
- pre-commit
- prometheus_client
- psutil
- pyarrow=7
- pyarrow>=7
- pytest
- pytest-cov
- pytest-faulthandler
Expand Down
2 changes: 1 addition & 1 deletion continuous_integration/gpuci/axis.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ LINUX_VER:
- ubuntu18.04

RAPIDS_VER:
- "23.08"
- "23.10"

excludes:
62 changes: 60 additions & 2 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,7 +1502,17 @@ async def __aenter__(self):

async def __aexit__(self, exc_type, exc_value, traceback):
if self._previous_as_current:
_current_client.reset(self._previous_as_current)
try:
_current_client.reset(self._previous_as_current)
except ValueError as e:
if not e.args[0].endswith(" was created in a different Context"):
raise # pragma: nocover
warnings.warn(
"It is deprecated to enter and exit the Client context "
"manager from different tasks",
DeprecationWarning,
stacklevel=2,
)
await self._close(
# if we're handling an exception, we assume that it's more
# important to deliver that exception than shutdown gracefully.
Expand All @@ -1512,7 +1522,17 @@ async def __aexit__(self, exc_type, exc_value, traceback):

def __exit__(self, exc_type, exc_value, traceback):
if self._previous_as_current:
_current_client.reset(self._previous_as_current)
try:
_current_client.reset(self._previous_as_current)
except ValueError as e:
if not e.args[0].endswith(" was created in a different Context"):
raise # pragma: nocover
warnings.warn(
"It is deprecated to enter and exit the Client context "
"manager from different threads",
DeprecationWarning,
stacklevel=2,
)
self.close()

def __del__(self):
Expand Down Expand Up @@ -4840,6 +4860,44 @@ def register_scheduler_plugin(self, plugin, name=None, idempotent=False):
idempotent=idempotent,
)

async def _unregister_scheduler_plugin(self, name):
    """Coroutine backing :meth:`unregister_scheduler_plugin`.

    Awaits ``self.scheduler.unregister_scheduler_plugin(name=name)`` and
    hands its result back to the caller unchanged.

    Parameters
    ----------
    name : str
        Name of the scheduler plugin to unregister.
    """
    response = await self.scheduler.unregister_scheduler_plugin(name=name)
    return response

def unregister_scheduler_plugin(self, name):
    """Remove a previously registered scheduler plugin by name.

    See https://distributed.readthedocs.io/en/latest/plugins.html#scheduler-plugins

    Parameters
    ----------
    name : str
        Name of the plugin to unregister. See the
        :meth:`Client.register_scheduler_plugin` docstring for more
        information.

    Returns
    -------
    The value produced by running ``_unregister_scheduler_plugin`` through
    ``self.sync``.

    Examples
    --------
    >>> class MyPlugin(SchedulerPlugin):
    ...     def __init__(self, *args, **kwargs):
    ...         pass  # the constructor is up to you
    ...     async def start(self, scheduler: Scheduler) -> None:
    ...         pass
    ...     async def before_close(self) -> None:
    ...         pass
    ...     async def close(self) -> None:
    ...         pass
    ...     def restart(self, scheduler: Scheduler) -> None:
    ...         pass

    >>> plugin = MyPlugin(1, 2, 3)
    >>> client.register_scheduler_plugin(plugin, name='foo')
    >>> client.unregister_scheduler_plugin(name='foo')

    See Also
    --------
    register_scheduler_plugin
    """
    # Delegate to the coroutine implementation via the client's sync bridge.
    unregister = self._unregister_scheduler_plugin
    return self.sync(unregister, name=name)

def register_worker_callbacks(self, setup=None):
"""
Registers a setup callback function for all current and future workers.
Expand Down
30 changes: 4 additions & 26 deletions distributed/comm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sys
import weakref
from abc import ABC, abstractmethod
from contextlib import suppress
from typing import Any, ClassVar

import dask
Expand Down Expand Up @@ -264,20 +263,8 @@ async def on_connection(
) -> None:
local_info = {**comm.handshake_info(), **(handshake_overrides or {})}

timeout = dask.config.get("distributed.comm.timeouts.connect")
timeout = parse_timedelta(timeout, default="seconds")
try:
# Timeout is to ensure that we'll terminate connections eventually.
# Connector side will employ smaller timeouts and we should only
# reach this if the comm is dead anyhow.
await wait_for(comm.write(local_info), timeout=timeout)
handshake = await wait_for(comm.read(), timeout=timeout)
# This would be better, but connections leak if worker is closed quickly
# write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
except Exception as e:
with suppress(Exception):
await comm.close()
raise CommClosedError(f"Comm {comm!r} closed.") from e
await comm.write(local_info)
handshake = await comm.read()

comm.remote_info = handshake
comm.remote_info["address"] = comm.peer_address
Expand Down Expand Up @@ -386,17 +373,8 @@ def time_left():
**comm.handshake_info(),
**(handshake_overrides or {}),
}
try:
# This would be better, but connections leak if worker is closed quickly
# write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
handshake = await wait_for(comm.read(), time_left())
await wait_for(comm.write(local_info), time_left())
except Exception as exc:
with suppress(Exception):
await comm.close()
raise OSError(
f"Timed out during handshake while connecting to {addr} after {timeout} s"
) from exc
await comm.write(local_info)
handshake = await comm.read()

comm.remote_info = handshake
comm.remote_info["address"] = comm._peer_addr
Expand Down
9 changes: 4 additions & 5 deletions distributed/comm/tests/test_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,7 @@ class UnreliableBackend(tcp.TCPBackend):
listener.stop()


@pytest.mark.slow
@gen_test()
async def test_handshake_slow_comm(tcp, monkeypatch):
class SlowComm(tcp.TCP):
Expand Down Expand Up @@ -999,11 +1000,9 @@ def get_connector(self):

import dask

with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}):
with pytest.raises(
IOError, match="Timed out during handshake while connecting to"
):
await connect(listener.contact_address)
# The connect itself is fast. Only the handshake is slow
with dask.config.set({"distributed.comm.timeouts.connect": "500ms"}):
await connect(listener.contact_address)
finally:
listener.stop()

Expand Down
25 changes: 24 additions & 1 deletion distributed/deploy/tests/test_adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from distributed.compatibility import LINUX, MACOS, WINDOWS
from distributed.metrics import time
from distributed.utils_test import async_poll_for, gen_test, slowinc
from distributed.utils_test import async_poll_for, gen_cluster, gen_test, slowinc


def test_adaptive_local_cluster(loop):
Expand Down Expand Up @@ -484,3 +484,26 @@ async def test_adaptive_stopped():
pc = instance.periodic_callback

await async_poll_for(lambda: not pc.is_running(), timeout=5)


@pytest.mark.parametrize("saturation", [1, float("inf")])
@gen_cluster(
    client=True,
    nthreads=[],
    config={
        "distributed.scheduler.default-task-durations": {"slowinc": 1000},
    },
)
async def test_scale_up_large_tasks(c, s, saturation):
    """Check that ``adaptive_target`` tracks the number of known tasks.

    The cluster starts with no workers (``nthreads=[]``) and ``slowinc`` is
    configured with a default task duration of 1000. The test runs once per
    parametrized ``WORKER_SATURATION`` value.
    """
    s.WORKER_SATURATION = saturation

    # Keep a reference to the futures for the duration of the test.
    first_batch = c.map(slowinc, range(10))
    # Poll until the scheduler has registered the submitted tasks.
    while not s.tasks:
        await asyncio.sleep(0.001)

    assert s.adaptive_target() == 10

    second_batch = c.map(slowinc, range(200))
    # Poll until the scheduler knows about exactly 200 tasks.
    while len(s.tasks) != 200:
        await asyncio.sleep(0.001)

    assert s.adaptive_target() == 200
11 changes: 9 additions & 2 deletions distributed/diagnostics/graph_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid

from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.utils import TupleComparable


class GraphLayout(SchedulerPlugin):
Expand Down Expand Up @@ -48,7 +49,9 @@ def __init__(self, scheduler):
def update_graph(
self, scheduler, *, dependencies=None, priority=None, tasks=None, **kwargs
):
stack = sorted(tasks, key=lambda k: priority.get(k, 0), reverse=True)
stack = sorted(
tasks, key=lambda k: TupleComparable(priority.get(k, 0)), reverse=True
)
while stack:
key = stack.pop()
if key in self.x or key not in scheduler.tasks:
Expand All @@ -58,7 +61,11 @@ def update_graph(
if not all(dep in self.y for dep in deps):
stack.append(key)
stack.extend(
sorted(deps, key=lambda k: priority.get(k, 0), reverse=True)
sorted(
deps,
key=lambda k: TupleComparable(priority.get(k, 0)),
reverse=True,
)
)
continue
else:
Expand Down
5 changes: 4 additions & 1 deletion distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def transition(
start: TaskStateState,
finish: TaskStateState,
*args: Any,
stimulus_id: str,
**kwargs: Any,
) -> None:
"""Run whenever a task changes state
Expand All @@ -143,6 +144,8 @@ def transition(
One of released, waiting, processing, memory, error.
finish : string
Final state of the transition.
stimulus_id: string
ID of stimulus causing the transition.
*args, **kwargs :
More options passed when transitioning
This may include worker ID, compute time, etc.
Expand All @@ -164,7 +167,7 @@ def add_worker(self, scheduler: Scheduler, worker: str) -> None | Awaitable[None
"""

def remove_worker(
self, scheduler: Scheduler, worker: str
self, scheduler: Scheduler, worker: str, *, stimulus_id: str, **kwargs: Any
) -> None | Awaitable[None]:
"""Run when a worker leaves the cluster
Expand Down
8 changes: 6 additions & 2 deletions distributed/diagnostics/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ async def setup(self):
logger.debug("Set up Progress keys")

for k in errors:
self.transition(k, None, "erred", exception=True)
self.transition(
k, None, "erred", stimulus_id="progress-setup", exception=True
)

def transition(self, key, start, finish, *args, **kwargs):
if key in self.keys and start == "processing" and finish == "memory":
Expand Down Expand Up @@ -240,7 +242,9 @@ def group_key(k):
self.keys[k] = set()

for k in errors:
self.transition(k, None, "erred", exception=True)
self.transition(
k, None, "erred", stimulus_id="multiprogress-setup", exception=True
)
logger.debug("Set up Progress keys")

def transition(self, key, start, finish, *args, **kwargs):
Expand Down
11 changes: 11 additions & 0 deletions distributed/diagnostics/tests/test_graph_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,14 @@ async def test_unique_positions(c, s, a, b):

y_positions = [(gl.x[k], gl.y[k]) for k in gl.x]
assert len(y_positions) == len(set(y_positions))


@gen_cluster(client=True)
async def test_layout_scatter(c, s, a, b):
    """GraphLayout should record state updates for tasks that consume scattered data."""
    layout = GraphLayout(s)
    s.add_plugin(layout)

    # Broadcast-scatter the inputs, then submit five tasks that depend on them.
    scattered = await c.scatter([1, 2, 3], broadcast=True)
    results = []
    for _ in range(5):
        results.append(c.submit(sum, scattered))
    await wait(results)

    assert len(layout.state_updates) > 0
Loading

0 comments on commit e3d393b

Please sign in to comment.