Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Aug 15, 2024
1 parent 883d1de commit 6e55696
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 134 deletions.
2 changes: 1 addition & 1 deletion distributed/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _try_bind_worker_client(self):
if not self._client:
try:
self._client = get_client()
self._future = Future(self._key, inform=False)
self._future = Future(self._key, self._client, inform=False)
# ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable.
except ValueError:
self._client = None
Expand Down
45 changes: 21 additions & 24 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,14 @@ class Future(WrappedKey):
# Make sure this stays unique even across multiple processes or hosts
_uid = uuid.uuid4().hex

def __init__(self, key, client=None, inform=True, state=None, _id=None):
def __init__(self, key, client=None, inform=None, state=None, _id=None):
self.key = key
self._cleared = False
self._client = client
self._id = _id or (Future._uid, next(Future._counter))
self._input_state = state
if inform:
raise RuntimeError("Futures should not be informed")
self._inform = inform
self._state = None
self._bind_late()
Expand All @@ -312,13 +314,11 @@ def client(self):
self._bind_late()
return self._client

def bind_client(self, client):
self._client = client
self._bind_late()

def _bind_late(self):
if not self._client:
try:
client = get_client()
except ValueError:
client = None
self._client = client
if self._client and not self._state:
self._client._inc_ref(self.key)
self._generation = self._client.generation
Expand All @@ -328,15 +328,6 @@ def _bind_late(self):
else:
self._state = self._client.futures[self.key] = FutureState(self.key)

if self._inform:
self._client._send_to_scheduler(
{
"op": "client-desires-keys",
"keys": [self.key],
"client": self._client.id,
}
)

if self._input_state is not None:
try:
handler = self._client._state_handlers[self._input_state]
Expand Down Expand Up @@ -588,13 +579,8 @@ def release(self):
except TypeError: # pragma: no cover
pass # Shutting down, add_callback may be None

@staticmethod
def make_future(key, id):
# Can't use kwargs in pickle __reduce__ methods
return Future(key=key, _id=id)

def __reduce__(self) -> str | tuple[Any, ...]:
return Future.make_future, (self.key, self._id)
return Future, (self.key,)

def __dask_tokenize__(self):
return (type(self).__name__, self.key, self._id)
Expand Down Expand Up @@ -2969,12 +2955,14 @@ def list_datasets(self, **kwargs):
async def _get_dataset(self, name, default=no_default):
with self.as_current():
out = await self.scheduler.publish_get(name=name, client=self.id)

if out is None:
if default is no_default:
raise KeyError(f"Dataset '{name}' not found")
else:
return default
for fut in futures_of(out["data"]):
fut.bind_client(self)
self._inform_scheduler_of_futures()
return out["data"]

def get_dataset(self, name, default=no_default, **kwargs):
Expand Down Expand Up @@ -3300,6 +3288,14 @@ def _get_computation_code(

return tuple(reversed(code))

def _inform_scheduler_of_futures(self):
self._send_to_scheduler(
{
"op": "client-desires-keys",
"keys": list(self.refcount),
}
)

def _graph_to_futures(
self,
dsk,
Expand Down Expand Up @@ -6092,7 +6088,7 @@ def futures_of(o, client=None):
stack.extend(x.values())
elif type(x) is SubgraphCallable:
stack.extend(x.dsk.values())
elif isinstance(x, Future):
elif isinstance(x, WrappedKey):
if x not in seen:
seen.add(x)
futures.append(x)
Expand Down Expand Up @@ -6146,6 +6142,7 @@ def fire_and_forget(obj):
"op": "client-desires-keys",
"keys": [future.key],
"client": "fire-and-forget",
"where": "fire-and-forget",
}
)

Expand Down
3 changes: 2 additions & 1 deletion distributed/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ async def put(self, name=None, key=None, data=None, client=None, timeout=None):
await wait_for(self.queues[name].put(record), timeout=deadline.remaining)

def future_release(self, name=None, key=None, client=None):
self.scheduler.client_desires_keys(keys=[key], client=client)
self.future_refcount[name, key] -= 1
if self.future_refcount[name, key] == 0:
self.scheduler.client_releases_keys(keys=[key], client="queue-%s" % name)
Expand Down Expand Up @@ -271,7 +272,7 @@ async def _get(self, timeout=None, batch=False):

def process(d):
if d["type"] == "Future":
value = Future(d["value"], self.client, inform=True, state=d["state"])
value = Future(d["value"], self.client, state=d["state"])
if d["state"] == "erred":
value._state.set_error(d["exception"], d["traceback"])
self.client._send_to_scheduler(
Expand Down
8 changes: 3 additions & 5 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,9 +670,7 @@ def clean(self) -> WorkerState:
)
ws._occupancy_cache = self.occupancy

ws.executing = {
ts.key: duration for ts, duration in self.executing.items() # type: ignore
}
ws.executing = {ts.key: duration for ts, duration in self.executing.items()}
return ws

def __repr__(self) -> str:
Expand Down Expand Up @@ -5591,7 +5589,7 @@ def client_desires_keys(self, keys: Collection[Key], client: str) -> None:
for k in keys:
ts = self.tasks.get(k)
if ts is None:
warnings.warn(f"Client {client} desires key {k!r} but key is unknown.")
warnings.warn(f"Client desires key {k!r} but key is unknown.")
continue
if ts.who_wants is None:
ts.who_wants = set()
Expand Down Expand Up @@ -9339,7 +9337,7 @@ def transition(
def _materialize_graph(
graph: HighLevelGraph, global_annotations: dict[str, Any]
) -> tuple[dict[Key, T_runspec], dict[Key, set[Key]], dict[str, dict[Key, Any]]]:
dsk = ensure_dict(graph)
dsk: dict = ensure_dict(graph)
for k in dsk:
validate_key(k)
annotations_by_type: defaultdict[str, dict[Key, Any]] = defaultdict(dict)
Expand Down
75 changes: 0 additions & 75 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@
dec,
div,
double,
ensure_no_new_clients,
gen_cluster,
gen_test,
get_cert,
Expand Down Expand Up @@ -3036,14 +3035,6 @@ async def test_rebalance_unprepared(c, s, a, b):
s.validate_state()


@gen_cluster(client=True, config=NO_AMM)
async def test_rebalance_raises_on_explicit_missing_data(c, s, a, b):
"""rebalance() raises KeyError if explicitly listed futures disappear"""
f = Future("x", client=c, state="memory")
with pytest.raises(KeyError, match="Could not rebalance keys:"):
await c.rebalance(futures=[f])


@gen_cluster(client=True)
async def test_receive_lost_key(c, s, a, b):
x = c.submit(inc, 1, workers=[a.address])
Expand Down Expand Up @@ -4150,51 +4141,6 @@ async def test_scatter_compute_store_lose_processing(c, s, a, b):
assert z.status == "cancelled"


@gen_cluster()
async def test_serialize_future(s, a, b):
async with (
Client(s.address, asynchronous=True) as c1,
Client(s.address, asynchronous=True) as c2,
):
future = c1.submit(lambda: 1)
result = await future

for ci in (c1, c2):
with ensure_no_new_clients():
with ci.as_current():
future2 = pickle.loads(pickle.dumps(future))
assert future2.client is ci
assert future2.key in ci.futures
result2 = await future2
assert result == result2
with temp_default_client(ci):
future2 = pickle.loads(pickle.dumps(future))


@gen_cluster()
async def test_serialize_future_without_client(s, a, b):
# Do not use a ctx manager to avoid having this being set as a current and/or default client
c1 = await Client(s.address, asynchronous=True, set_as_default=False)
try:
with ensure_no_new_clients():

def do_stuff():
return 1

future = c1.submit(do_stuff)
pickled = pickle.dumps(future)
unpickled_fut = pickle.loads(pickled)

with pytest.raises(RuntimeError):
await unpickled_fut

with c1.as_current():
unpickled_fut_ctx = pickle.loads(pickled)
assert await unpickled_fut_ctx == 1
finally:
await c1.close()


@gen_cluster()
async def test_temp_default_client(s, a, b):
async with (
Expand Down Expand Up @@ -5836,27 +5782,6 @@ async def test_client_with_name(s, a, b):
assert "foo" in text


@gen_cluster(client=True)
async def test_future_defaults_to_default_client(c, s, a, b):
x = c.submit(inc, 1)
await wait(x)

future = Future(x.key)
assert future.client is c


@gen_cluster(client=True)
async def test_future_auto_inform(c, s, a, b):
x = c.submit(inc, 1)
await wait(x)

async with Client(s.address, asynchronous=True) as client:
future = Future(x.key, client)

while future.status != "finished":
await asyncio.sleep(0.01)


def test_client_async_before_loop_starts(cleanup):
with pytest.raises(
RuntimeError,
Expand Down
25 changes: 1 addition & 24 deletions distributed/tests/test_spans.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dask import delayed

import distributed
from distributed import Client, Event, Future, Worker, span, wait
from distributed import Client, Event, Worker, span, wait
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.utils_test import (
NoSchedulerDelayWorker,
Expand Down Expand Up @@ -386,34 +386,13 @@ def test_no_tags():
pass


@gen_cluster(client=True)
async def test_client_desires_keys_creates_tg(c, s, a, b):
"""A TaskGroup object is created by client_desires_keys, and
only later gains runnable tasks
See also
--------
test_spans.py::test_client_desires_keys_creates_ts
test_spans.py::test_scatter_creates_ts
test_spans.py::test_scatter_creates_tg
"""
x0 = Future(key="x-0")
await wait_for_state("x-0", "released", s)
assert s.tasks["x-0"].group.span_id is None
x1 = c.submit(inc, 1, key="x-1")
assert await x1 == 2
assert s.tasks["x-0"].group.span_id is not None


@gen_cluster(client=True)
async def test_scatter_creates_ts(c, s, a, b):
"""A TaskState object is created by scatter, and only later becomes runnable
See also
--------
test_scheduler.py::test_scatter_creates_ts
test_spans.py::test_client_desires_keys_creates_ts
test_spans.py::test_client_desires_keys_creates_tg
test_spans.py::test_scatter_creates_tg
"""
x1 = (await c.scatter({"x": 1}, workers=[a.address]))["x"]
Expand All @@ -433,8 +412,6 @@ async def test_scatter_creates_tg(c, s, a, b):
See also
--------
test_spans.py::test_client_desires_keys_creates_ts
test_spans.py::test_client_desires_keys_creates_tg
test_spans.py::test_scatter_creates_ts
"""
x0 = (await c.scatter({"x-0": 1}))["x-0"]
Expand Down
13 changes: 9 additions & 4 deletions distributed/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def __init__(self, scheduler):
{"variable_set": self.set, "variable_get": self.get}
)

self.scheduler.stream_handlers["variable-future-release"] = self.future_release
self.scheduler.stream_handlers[
"variable-future-received-confirm"
] = self.future_received_confirm
self.scheduler.stream_handlers["variable_delete"] = self.delete

async def set(self, name=None, key=None, data=None, client=None, timeout=None):
Expand Down Expand Up @@ -73,7 +75,10 @@ async def release(self, key, name):
self.scheduler.client_releases_keys(keys=[key], client="variable-%s" % name)
del self.waiting[key, name]

async def future_release(self, name=None, key=None, token=None, client=None):
async def future_received_confirm(
self, name=None, key=None, token=None, client=None
):
self.scheduler.client_desires_keys([key], client)
self.waiting[key, name].remove(token)
if not self.waiting[key, name]:
async with self.waiting_conditions[name]:
Expand Down Expand Up @@ -213,12 +218,12 @@ async def _get(self, timeout=None):
timeout=timeout, name=self.name, client=self.client.id
)
if d["type"] == "Future":
value = Future(d["value"], self.client, inform=True, state=d["state"])
value = Future(d["value"], self.client, inform=False, state=d["state"])
if d["state"] == "erred":
value._state.set_error(d["exception"], d["traceback"])
self.client._send_to_scheduler(
{
"op": "variable-future-release",
"op": "variable-future-received-confirm",
"name": self.name,
"key": d["value"],
"token": d["token"],
Expand Down

0 comments on commit 6e55696

Please sign in to comment.