From 6e556962ae4dd66f26796162ad626b1e12e239fe Mon Sep 17 00:00:00 2001
From: fjetter
Date: Thu, 15 Aug 2024 14:38:40 +0200
Subject: [PATCH] more fixes

---
 distributed/actor.py             |  2 +-
 distributed/client.py            | 45 +++++++++----------
 distributed/queues.py            |  3 +-
 distributed/scheduler.py         |  8 ++--
 distributed/tests/test_client.py | 75 --------------------------------
 distributed/tests/test_spans.py  | 25 +----------
 distributed/variable.py          | 13 ++++--
 7 files changed, 37 insertions(+), 134 deletions(-)

diff --git a/distributed/actor.py b/distributed/actor.py
index 1fdbf5dae40..27d6071e0b0 100644
--- a/distributed/actor.py
+++ b/distributed/actor.py
@@ -77,7 +77,7 @@ def _try_bind_worker_client(self):
         if not self._client:
             try:
                 self._client = get_client()
-                self._future = Future(self._key, inform=False)
+                self._future = Future(self._key, self._client, inform=False)
                 # ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable.
             except ValueError:
                 self._client = None
diff --git a/distributed/client.py b/distributed/client.py
index 2ece27802a8..dd8e386c914 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -297,12 +297,14 @@ class Future(WrappedKey):
     # Make sure this stays unique even across multiple processes or hosts
     _uid = uuid.uuid4().hex
 
-    def __init__(self, key, client=None, inform=True, state=None, _id=None):
+    def __init__(self, key, client=None, inform=None, state=None, _id=None):
         self.key = key
         self._cleared = False
         self._client = client
         self._id = _id or (Future._uid, next(Future._counter))
         self._input_state = state
+        if inform:
+            raise RuntimeError("Futures should not be informed")
         self._inform = inform
         self._state = None
         self._bind_late()
@@ -312,13 +314,11 @@ def client(self):
         self._bind_late()
         return self._client
 
+    def bind_client(self, client):
+        self._client = client
+        self._bind_late()
+
     def _bind_late(self):
-        if not self._client:
-            try:
-                client = get_client()
-            except ValueError:
-                client = None
-            self._client = client
         if self._client and not self._state:
             self._client._inc_ref(self.key)
             self._generation = self._client.generation
@@ -328,15 +328,6 @@ def _bind_late(self):
         else:
             self._state = self._client.futures[self.key] = FutureState(self.key)
 
-        if self._inform:
-            self._client._send_to_scheduler(
-                {
-                    "op": "client-desires-keys",
-                    "keys": [self.key],
-                    "client": self._client.id,
-                }
-            )
-
         if self._input_state is not None:
             try:
                 handler = self._client._state_handlers[self._input_state]
@@ -588,13 +579,8 @@ def release(self):
             except TypeError:  # pragma: no cover
                 pass  # Shutting down, add_callback may be None
 
-    @staticmethod
-    def make_future(key, id):
-        # Can't use kwargs in pickle __reduce__ methods
-        return Future(key=key, _id=id)
-
     def __reduce__(self) -> str | tuple[Any, ...]:
-        return Future.make_future, (self.key, self._id)
+        return Future, (self.key,)
 
     def __dask_tokenize__(self):
         return (type(self).__name__, self.key, self._id)
@@ -2969,12 +2955,14 @@ def list_datasets(self, **kwargs):
     async def _get_dataset(self, name, default=no_default):
         with self.as_current():
             out = await self.scheduler.publish_get(name=name, client=self.id)
-
         if out is None:
             if default is no_default:
                 raise KeyError(f"Dataset '{name}' not found")
             else:
                 return default
+        for fut in futures_of(out["data"]):
+            fut.bind_client(self)
+        self._inform_scheduler_of_futures()
         return out["data"]
 
     def get_dataset(self, name, default=no_default, **kwargs):
@@ -3300,6 +3288,14 @@ def _get_computation_code(
 
         return tuple(reversed(code))
 
+    def _inform_scheduler_of_futures(self):
+        self._send_to_scheduler(
+            {
+                "op": "client-desires-keys",
+                "keys": list(self.refcount),
+            }
+        )
+
     def _graph_to_futures(
         self,
         dsk,
@@ -6092,7 +6088,7 @@ def futures_of(o, client=None):
             stack.extend(x.values())
         elif type(x) is SubgraphCallable:
             stack.extend(x.dsk.values())
-        elif isinstance(x, Future):
+        elif isinstance(x, WrappedKey):
             if x not in seen:
                 seen.add(x)
                 futures.append(x)
@@ -6146,6 +6142,7 @@ def fire_and_forget(obj):
                     "op": "client-desires-keys",
                     "keys": [future.key],
                     "client": "fire-and-forget",
+                    "where": "fire-and-forget",
                 }
             )
 
diff --git a/distributed/queues.py b/distributed/queues.py
index 45f4f939613..7a313613ed0 100644
--- a/distributed/queues.py
+++ b/distributed/queues.py
@@ -82,6 +82,7 @@ async def put(self, name=None, key=None, data=None, client=None, timeout=None):
         await wait_for(self.queues[name].put(record), timeout=deadline.remaining)
 
     def future_release(self, name=None, key=None, client=None):
+        self.scheduler.client_desires_keys(keys=[key], client=client)
         self.future_refcount[name, key] -= 1
         if self.future_refcount[name, key] == 0:
             self.scheduler.client_releases_keys(keys=[key], client="queue-%s" % name)
@@ -271,7 +272,7 @@ async def _get(self, timeout=None, batch=False):
 
         def process(d):
             if d["type"] == "Future":
-                value = Future(d["value"], self.client, inform=True, state=d["state"])
+                value = Future(d["value"], self.client, state=d["state"])
                 if d["state"] == "erred":
                     value._state.set_error(d["exception"], d["traceback"])
                 self.client._send_to_scheduler(
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index a9b28fc8f61..d22e8dee117 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -670,9 +670,7 @@ def clean(self) -> WorkerState:
         )
         ws._occupancy_cache = self.occupancy
 
-        ws.executing = {
-            ts.key: duration for ts, duration in self.executing.items()  # type: ignore
-        }
+        ws.executing = {ts.key: duration for ts, duration in self.executing.items()}
         return ws
 
     def __repr__(self) -> str:
@@ -5591,7 +5589,7 @@ def client_desires_keys(self, keys: Collection[Key], client: str) -> None:
         for k in keys:
             ts = self.tasks.get(k)
             if ts is None:
-                warnings.warn(f"Client {client} desires key {k!r} but key is unknown.")
+                warnings.warn(f"Client desires key {k!r} but key is unknown.")
                 continue
             if ts.who_wants is None:
                 ts.who_wants = set()
@@ -9339,7 +9337,7 @@ def transition(
 def _materialize_graph(
     graph: HighLevelGraph, global_annotations: dict[str, Any]
 ) -> tuple[dict[Key, T_runspec], dict[Key, set[Key]], dict[str, dict[Key, Any]]]:
-    dsk = ensure_dict(graph)
+    dsk: dict = ensure_dict(graph)
     for k in dsk:
         validate_key(k)
     annotations_by_type: defaultdict[str, dict[Key, Any]] = defaultdict(dict)
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index f7a137e9eb8..c2374b4c3b6 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -102,7 +102,6 @@
     dec,
     div,
     double,
-    ensure_no_new_clients,
     gen_cluster,
     gen_test,
     get_cert,
@@ -3036,14 +3035,6 @@ async def test_rebalance_unprepared(c, s, a, b):
     s.validate_state()
 
 
-@gen_cluster(client=True, config=NO_AMM)
-async def test_rebalance_raises_on_explicit_missing_data(c, s, a, b):
-    """rebalance() raises KeyError if explicitly listed futures disappear"""
-    f = Future("x", client=c, state="memory")
-    with pytest.raises(KeyError, match="Could not rebalance keys:"):
-        await c.rebalance(futures=[f])
-
-
 @gen_cluster(client=True)
 async def test_receive_lost_key(c, s, a, b):
     x = c.submit(inc, 1, workers=[a.address])
@@ -4150,51 +4141,6 @@ async def test_scatter_compute_store_lose_processing(c, s, a, b):
     assert z.status == "cancelled"
 
 
-@gen_cluster()
-async def test_serialize_future(s, a, b):
-    async with (
-        Client(s.address, asynchronous=True) as c1,
-        Client(s.address, asynchronous=True) as c2,
-    ):
-        future = c1.submit(lambda: 1)
-        result = await future
-
-        for ci in (c1, c2):
-            with ensure_no_new_clients():
-                with ci.as_current():
-                    future2 = pickle.loads(pickle.dumps(future))
-                    assert future2.client is ci
-                    assert future2.key in ci.futures
-                    result2 = await future2
-                    assert result == result2
-            with temp_default_client(ci):
-                future2 = pickle.loads(pickle.dumps(future))
-
-
-@gen_cluster()
-async def test_serialize_future_without_client(s, a, b):
-    # Do not use a ctx manager to avoid having this being set as a current and/or default client
-    c1 = await Client(s.address, asynchronous=True, set_as_default=False)
-    try:
-        with ensure_no_new_clients():
-
-            def do_stuff():
-                return 1
-
-            future = c1.submit(do_stuff)
-            pickled = pickle.dumps(future)
-            unpickled_fut = pickle.loads(pickled)
-
-            with pytest.raises(RuntimeError):
-                await unpickled_fut
-
-            with c1.as_current():
-                unpickled_fut_ctx = pickle.loads(pickled)
-                assert await unpickled_fut_ctx == 1
-    finally:
-        await c1.close()
-
-
 @gen_cluster()
 async def test_temp_default_client(s, a, b):
     async with (
@@ -5836,27 +5782,6 @@ async def test_client_with_name(s, a, b):
     assert "foo" in text
 
 
-@gen_cluster(client=True)
-async def test_future_defaults_to_default_client(c, s, a, b):
-    x = c.submit(inc, 1)
-    await wait(x)
-
-    future = Future(x.key)
-    assert future.client is c
-
-
-@gen_cluster(client=True)
-async def test_future_auto_inform(c, s, a, b):
-    x = c.submit(inc, 1)
-    await wait(x)
-
-    async with Client(s.address, asynchronous=True) as client:
-        future = Future(x.key, client)
-
-        while future.status != "finished":
-            await asyncio.sleep(0.01)
-
-
 def test_client_async_before_loop_starts(cleanup):
     with pytest.raises(
         RuntimeError,
diff --git a/distributed/tests/test_spans.py b/distributed/tests/test_spans.py
index c269411bfba..63db4ccf4ce 100644
--- a/distributed/tests/test_spans.py
+++ b/distributed/tests/test_spans.py
@@ -8,7 +8,7 @@
 from dask import delayed
 
 import distributed
-from distributed import Client, Event, Future, Worker, span, wait
+from distributed import Client, Event, Worker, span, wait
 from distributed.diagnostics.plugin import SchedulerPlugin
 from distributed.utils_test import (
     NoSchedulerDelayWorker,
@@ -386,25 +386,6 @@ def test_no_tags():
         pass
 
 
-@gen_cluster(client=True)
-async def test_client_desires_keys_creates_tg(c, s, a, b):
-    """A TaskGroup object is created by client_desires_keys, and
-    only later gains runnable tasks
-
-    See also
-    --------
-    test_spans.py::test_client_desires_keys_creates_ts
-    test_spans.py::test_scatter_creates_ts
-    test_spans.py::test_scatter_creates_tg
-    """
-    x0 = Future(key="x-0")
-    await wait_for_state("x-0", "released", s)
-    assert s.tasks["x-0"].group.span_id is None
-    x1 = c.submit(inc, 1, key="x-1")
-    assert await x1 == 2
-    assert s.tasks["x-0"].group.span_id is not None
-
-
 @gen_cluster(client=True)
 async def test_scatter_creates_ts(c, s, a, b):
     """A TaskState object is created by scatter, and only later becomes runnable
@@ -412,8 +393,6 @@ async def test_scatter_creates_ts(c, s, a, b):
     See also
     --------
     test_scheduler.py::test_scatter_creates_ts
-    test_spans.py::test_client_desires_keys_creates_ts
-    test_spans.py::test_client_desires_keys_creates_tg
     test_spans.py::test_scatter_creates_tg
     """
     x1 = (await c.scatter({"x": 1}, workers=[a.address]))["x"]
@@ -433,8 +412,6 @@ async def test_scatter_creates_tg(c, s, a, b):
 
     See also
     --------
-    test_spans.py::test_client_desires_keys_creates_ts
-    test_spans.py::test_client_desires_keys_creates_tg
     test_spans.py::test_scatter_creates_ts
     """
     x0 = (await c.scatter({"x-0": 1}))["x-0"]
diff --git a/distributed/variable.py b/distributed/variable.py
index cc2d4abfd35..f02840d5034 100644
--- a/distributed/variable.py
+++ b/distributed/variable.py
@@ -39,7 +39,9 @@ def __init__(self, scheduler):
             {"variable_set": self.set, "variable_get": self.get}
         )
 
-        self.scheduler.stream_handlers["variable-future-release"] = self.future_release
+        self.scheduler.stream_handlers[
+            "variable-future-received-confirm"
+        ] = self.future_received_confirm
        self.scheduler.stream_handlers["variable_delete"] = self.delete
 
     async def set(self, name=None, key=None, data=None, client=None, timeout=None):
@@ -73,7 +75,10 @@ async def release(self, key, name):
         self.scheduler.client_releases_keys(keys=[key], client="variable-%s" % name)
         del self.waiting[key, name]
 
-    async def future_release(self, name=None, key=None, token=None, client=None):
+    async def future_received_confirm(
+        self, name=None, key=None, token=None, client=None
+    ):
+        self.scheduler.client_desires_keys([key], client)
         self.waiting[key, name].remove(token)
         if not self.waiting[key, name]:
             async with self.waiting_conditions[name]:
@@ -213,12 +218,12 @@ async def _get(self, timeout=None):
             timeout=timeout, name=self.name, client=self.client.id
         )
         if d["type"] == "Future":
-            value = Future(d["value"], self.client, inform=True, state=d["state"])
+            value = Future(d["value"], self.client, inform=False, state=d["state"])
             if d["state"] == "erred":
                 value._state.set_error(d["exception"], d["traceback"])
             self.client._send_to_scheduler(
                 {
-                    "op": "variable-future-release",
+                    "op": "variable-future-received-confirm",
                     "name": self.name,
                     "key": d["value"],
                     "token": d["token"],
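
The net effect of the client.py changes is a new (de)serialization contract for Future: __reduce__ now carries only the key, _bind_late no longer falls back to get_client(), and an unpickled future therefore comes back unbound until someone calls the new bind_client(). A minimal sketch of that round-trip, assuming a throwaway local cluster; apart from bind_client and the key-only __reduce__, which come from the diff, the setup is purely illustrative:

    import pickle

    from distributed import Client

    client = Client()  # assumption: local cluster for demonstration
    future = client.submit(lambda x: x + 1, 1)
    future.result()  # block until the result is in memory

    # Only the key survives the round-trip; no client is attached on load
    # and no implicit get_client() lookup happens anymore.
    restored = pickle.loads(pickle.dumps(future))
    assert restored.client is None

    # The receiving side must bind the future to a client explicitly.
    restored.bind_client(client)
    assert restored.result() == 2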
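
Published datasets exercise the same contract end to end: the patched _get_dataset walks the payload with futures_of (which now matches any WrappedKey, not just Future), binds each future to the requesting client, and then sends one batched client-desires-keys message via _inform_scheduler_of_futures instead of one message per Future.__init__. A hedged sketch of the user-visible flow; both clients and the dataset name are illustrative:

    from distributed import Client

    producer = Client()  # assumption: local cluster
    producer.publish_dataset(total=producer.submit(sum, [1, 2, 3]))

    # get_dataset() rebuilds the stored future, attaches it to this client
    # via bind_client(), then registers all held keys with the scheduler
    # in a single batched "client-desires-keys" message.
    consumer = Client(producer.scheduler.address)
    fut = consumer.get_dataset("total")
    assert fut.result() == 6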
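
The variable.py and queues.py changes close the other half of the handoff: since a reconstructed future no longer informs the scheduler from its constructor, the consumer sends an explicit variable-future-received-confirm (renamed from variable-future-release), and the scheduler extension calls client_desires_keys for the consumer before decrementing the producer's token refcount, so the key cannot be released in the window between the two. From user level, again assuming a running local cluster:

    from distributed import Client, Variable

    client = Client()  # assumption: local cluster
    var = Variable("x", client=client)
    var.set(client.submit(sum, [1, 2, 3]))  # producer stores a future

    # get() rebuilds the future bound to this client; under the hood the
    # client confirms receipt and the scheduler records the desire via
    # client_desires_keys() before releasing the producer's token.
    fut = var.get()
    assert fut.result() == 6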