Ensure client_desires_keys does not corrupt Scheduler state #8827

Merged (10 commits) on Aug 20, 2024
2 changes: 1 addition & 1 deletion distributed/actor.py
@@ -77,7 +77,7 @@
if not self._client:
try:
self._client = get_client()
self._future = Future(self._key, inform=False)
self._future = Future(self._key, self._client)

# ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable.
except ValueError:
self._client = None
54 changes: 25 additions & 29 deletions distributed/client.py
@@ -163,6 +163,9 @@
result = "\n".join([result, self.msg])
return result

def __reduce__(self):
return self.__class__, (self.key, self.reason, self.msg)


class FuturesCancelledError(CancelledError):
error_groups: list[CancelledFuturesGroup]
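
Note: the exception class above gains a `__reduce__` so that it survives pickling with all of its attributes intact. As a quick illustration with a hypothetical stand-in class (not the one in the diff):

```python
import pickle

class DemoCancelledError(Exception):
    """Hypothetical exception carrying extra attributes beyond ``args``."""

    def __init__(self, key, reason, msg=""):
        super().__init__(key)
        self.key, self.reason, self.msg = key, reason, msg

    def __reduce__(self):
        # Without this, pickle would rebuild the exception as
        # ``DemoCancelledError(*self.args)``, i.e. ``DemoCancelledError("x")``,
        # which fails because ``reason`` is a required argument.
        return self.__class__, (self.key, self.reason, self.msg)

err = pickle.loads(pickle.dumps(DemoCancelledError("x", "scheduler-restart", "details")))
assert (err.key, err.reason, err.msg) == ("x", "scheduler-restart", "details")
```
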
@@ -297,13 +300,12 @@
# Make sure this stays unique even across multiple processes or hosts
_uid = uuid.uuid4().hex

def __init__(self, key, client=None, inform=True, state=None, _id=None):
def __init__(self, key, client=None, state=None, _id=None):
self.key = key
self._cleared = False
self._client = client
self._id = _id or (Future._uid, next(Future._counter))
self._input_state = state
self._inform = inform
self._state = None
self._bind_late()

@@ -312,13 +314,11 @@
self._bind_late()
return self._client

def bind_client(self, client):
self._client = client
self._bind_late()


def _bind_late(self):
if not self._client:
try:
client = get_client()
except ValueError:
client = None
self._client = client
if self._client and not self._state:
self._client._inc_ref(self.key)
self._generation = self._client.generation
@@ -328,15 +328,6 @@
else:
self._state = self._client.futures[self.key] = FutureState(self.key)

if self._inform:
self._client._send_to_scheduler(
{
"op": "client-desires-keys",
"keys": [self.key],
"client": self._client.id,
}
)

if self._input_state is not None:
try:
handler = self._client._state_handlers[self._input_state]
@@ -588,13 +579,8 @@
except TypeError: # pragma: no cover
pass # Shutting down, add_callback may be None

@staticmethod
def make_future(key, id):
# Can't use kwargs in pickle __reduce__ methods
return Future(key=key, _id=id)

def __reduce__(self) -> str | tuple[Any, ...]:
return Future.make_future, (self.key, self._id)
return Future, (self.key,)

def __dask_tokenize__(self):
return (type(self).__name__, self.key, self._id)
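
With the `make_future` helper gone, a Future now pickles down to just its key and re-binds to a client lazily on the receiving side via `_bind_late`. A rough sketch of what this looks like, assuming a running local cluster:

```python
import pickle
from distributed import Client

client = Client(processes=False)  # illustrative in-process cluster
fut = client.submit(sum, [1, 2, 3])

# After this change the pickle payload carries only the key ...
assert fut.__reduce__() == (type(fut), (fut.key,))

# ... and an unpickled Future binds itself to the current client on demand.
restored = pickle.loads(pickle.dumps(fut))
assert restored.result() == 6
```
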
@@ -2161,7 +2147,7 @@

with self._refcount_lock:
if key in self.futures:
return Future(key, self, inform=False)
return Future(key, self)

if allow_other_workers and workers is None:
raise ValueError("Only use allow_other_workers= if using workers=")
@@ -2661,7 +2647,7 @@
timeout=timeout,
)

out = {k: Future(k, self, inform=False) for k in data}
out = {k: Future(k, self) for k in data}
for key, typ in types.items():
self.futures[key].finish(type=typ)

@@ -2969,12 +2955,14 @@
async def _get_dataset(self, name, default=no_default):
with self.as_current():
out = await self.scheduler.publish_get(name=name, client=self.id)

if out is None:
if default is no_default:
raise KeyError(f"Dataset '{name}' not found")
else:
return default
for fut in futures_of(out["data"]):
fut.bind_client(self)
self._inform_scheduler_of_futures()

return out["data"]

def get_dataset(self, name, default=no_default, **kwargs):
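
`_get_dataset` now explicitly re-binds any futures contained in a published dataset to the retrieving client and then notifies the scheduler about all of its keys in one batch. A hedged usage sketch (the dataset name `myset` and the second client are illustrative):

```python
from distributed import Client

client = Client(processes=False)            # illustrative in-process cluster
futs = client.map(abs, range(3))
client.publish_dataset(myset=futs)          # publish under an arbitrary name

other = Client(client.scheduler.address)    # a second client retrieving it
same = other.get_dataset("myset")           # futures come back bound to ``other``
print(other.gather(same))                   # [0, 1, 2]
```
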
@@ -3300,6 +3288,14 @@

return tuple(reversed(code))

def _inform_scheduler_of_futures(self):
self._send_to_scheduler(
{
"op": "client-desires-keys",
"keys": list(self.refcount),
}
)

def _graph_to_futures(
self,
dsk,
@@ -3348,7 +3344,7 @@
validate_key(key)

# Create futures before sending graph (helps avoid contention)
futures = {key: Future(key, self, inform=False) for key in keyset}
futures = {key: Future(key, self) for key in keyset}
# Circular import
from distributed.protocol import serialize
from distributed.protocol.serialize import ToPickle
@@ -3507,7 +3503,7 @@
if not changed:
changed = True
dsk = ensure_dict(dsk)
dsk[key] = Future(key, self, inform=False)
dsk[key] = Future(key, self)

if changed:
dsk, _ = dask.optimization.cull(dsk, keys)
@@ -6092,7 +6088,7 @@
stack.extend(x.values())
elif type(x) is SubgraphCallable:
stack.extend(x.dsk.values())
elif isinstance(x, Future):
elif isinstance(x, WrappedKey):
if x not in seen:
seen.add(x)
futures.append(x)
13 changes: 10 additions & 3 deletions distributed/queues.py
@@ -8,7 +8,7 @@
from dask.utils import parse_timedelta

from distributed.client import Future
from distributed.utils import wait_for
from distributed.utils import Deadline, wait_for
from distributed.worker import get_client

logger = logging.getLogger(__name__)
@@ -67,15 +67,22 @@
self.scheduler.client_releases_keys(keys=keys, client="queue-%s" % name)

async def put(self, name=None, key=None, data=None, client=None, timeout=None):
deadline = Deadline.after(timeout)

if key is not None:
while key not in self.scheduler.tasks:
await asyncio.sleep(0.01)
if deadline.expired:
raise TimeoutError(f"Task {key} unknown to scheduler.")


record = {"type": "Future", "value": key}
self.future_refcount[name, key] += 1
self.scheduler.client_desires_keys(keys=[key], client="queue-%s" % name)
else:
record = {"type": "msgpack", "value": data}
await wait_for(self.queues[name].put(record), timeout=timeout)
await wait_for(self.queues[name].put(record), timeout=deadline.remaining)


def future_release(self, name=None, key=None, client=None):
self.scheduler.client_desires_keys(keys=[key], client=client)

self.future_refcount[name, key] -= 1
if self.future_refcount[name, key] == 0:
self.scheduler.client_releases_keys(keys=[key], client="queue-%s" % name)
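
`QueueExtension.put` now budgets a single timeout across both waits: the poll for the key to become known to the scheduler, and the actual queue insertion. A minimal standalone sketch of that pattern with `Deadline` (not tied to the class above):

```python
import asyncio
from distributed.utils import Deadline, wait_for

async def put_with_budget(queue: asyncio.Queue, ready: asyncio.Event, timeout=None):
    """Wait for ``ready`` and then enqueue, all within one timeout budget."""
    deadline = Deadline.after(timeout)
    while not ready.is_set():            # stand-in for "key not in scheduler.tasks"
        await asyncio.sleep(0.01)
        if deadline.expired:
            raise TimeoutError("Task never became known to the scheduler.")
    # Whatever time is left in the budget is passed on to the second wait.
    await wait_for(queue.put("record"), timeout=deadline.remaining)
```
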
@@ -265,7 +272,7 @@

def process(d):
if d["type"] == "Future":
value = Future(d["value"], self.client, inform=True, state=d["state"])
value = Future(d["value"], self.client, state=d["state"])

if d["state"] == "erred":
value._state.set_error(d["exception"], d["traceback"])
self.client._send_to_scheduler(
12 changes: 5 additions & 7 deletions distributed/scheduler.py
@@ -670,9 +670,7 @@ def clean(self) -> WorkerState:
)
ws._occupancy_cache = self.occupancy

ws.executing = {
ts.key: duration for ts, duration in self.executing.items() # type: ignore
}
ws.executing = {ts.key: duration for ts, duration in self.executing.items()} # type: ignore
return ws

def __repr__(self) -> str:
@@ -4634,7 +4632,7 @@ def _match_graph_with_tasks(
): # bad key
lost_keys.add(k)
logger.info("User asked for computation on lost data, %s", k)
del dsk[k]
dsk.pop(k, None)
del dependencies[k]
if k in keys:
keys.remove(k)
@@ -5595,8 +5593,8 @@ def client_desires_keys(self, keys: Collection[Key], client: str) -> None:
for k in keys:
ts = self.tasks.get(k)
if ts is None:
# For publish, queues etc.
ts = self.new_task(k, None, "released")
warnings.warn(f"Client desires key {k!r} but key is unknown.")
continue
if ts.who_wants is None:
ts.who_wants = set()
ts.who_wants.add(cs)
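
With this change, asking the scheduler to track a key it has never heard of no longer fabricates a placeholder task in the `released` state (the source of the state corruption this PR addresses); the scheduler warns and skips the key instead. A rough test-style sketch, assuming direct access to a `Scheduler` instance as in a `gen_cluster` test:

```python
import pytest
from distributed.utils_test import gen_cluster

@gen_cluster(client=True)
async def test_unknown_key_only_warns(c, s, a, b):
    # Requesting a key the scheduler has never seen emits a warning ...
    with pytest.warns(UserWarning, match="unknown"):
        s.client_desires_keys(keys=["never-submitted"], client=c.id)
    # ... and no phantom TaskState is created for it.
    assert "never-submitted" not in s.tasks
```
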
@@ -9345,7 +9343,7 @@ def transition(
def _materialize_graph(
graph: HighLevelGraph, global_annotations: dict[str, Any], validate: bool
) -> tuple[dict[Key, T_runspec], dict[Key, set[Key]], dict[str, dict[Key, Any]]]:
dsk = ensure_dict(graph)
dsk: dict = ensure_dict(graph)
if validate:
for k in dsk:
validate_key(k)