Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Generics for ObservableDeferred #10491

Merged
merged 3 commits into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10491.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type annotations for `ObservableDeferred`.
5 changes: 3 additions & 2 deletions synapse/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ def __init__(
self.last_notified_token = current_token
self.last_notified_ms = time_now_ms

with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
self.notify_deferred: ObservableDeferred[StreamToken] = ObservableDeferred(
defer.Deferred()
)

def notify(
self,
Expand Down
4 changes: 3 additions & 1 deletion synapse/storage/persist_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ async def add_to_queue(
end_item = queue[-1]
else:
# need to make a new queue item
deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
deferred = ObservableDeferred[_PersistResult](

Is that a thing we can do now?

Copy link
Contributor

@ShadowJonathan ShadowJonathan Jul 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First instinct tells me No (though not sure)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we can't, at least under Python 3.6. It produces an infinite recursion:

  File "/home/rav/work/synapse/synapse/storage/persist_events.py", line 174, in add_to_queue
    defer.Deferred(), consumeErrors=True
  File "/home/rav/.pyenv/versions/3.6.3/lib/python3.6/typing.py", line 1227, in __new__
    return _generic_new(cls.__next_in_mro__, cls, *args, **kwds)
  File "/home/rav/.pyenv/versions/3.6.3/lib/python3.6/typing.py", line 1193, in _generic_new
    obj.__orig_class__ = cls
  File "/home/rav/work/synapse/synapse/util/async_helpers.py", line 158, in __setattr__
    setattr(self._deferred, name, value)
  File "/home/rav/work/synapse/synapse/util/async_helpers.py", line 155, in __getattr__
    return getattr(self._deferred, name)
  File "/home/rav/work/synapse/synapse/util/async_helpers.py", line 155, in __getattr__
    return getattr(self._deferred, name)

No doubt this is fixable (probably by removing the getattr/setattr magic which tries to make ObservableDeferred a proxy for the original Deferred, for no good reason), but I can't be doing with that right now.

defer.Deferred(), consumeErrors=True
)

end_item = _EventPersistQueueItem(
events_and_contexts=[],
Expand Down
14 changes: 8 additions & 6 deletions synapse/util/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Awaitable,
Callable,
Dict,
Generic,
Hashable,
Iterable,
List,
Expand All @@ -39,6 +40,7 @@
from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import IReactorTime
from twisted.python import failure
from twisted.python.failure import Failure

from synapse.logging.context import (
PreserveLoggingContext,
Expand All @@ -52,7 +54,7 @@
_T = TypeVar("_T")


class ObservableDeferred:
class ObservableDeferred(Generic[_T]):
"""Wraps a deferred object so that we can add observer deferreds. These
observer deferreds do not affect the callback chain of the original
deferred.
Expand All @@ -70,7 +72,7 @@ class ObservableDeferred:

__slots__ = ["_deferred", "_observers", "_result"]

def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", set())
Expand Down Expand Up @@ -115,15 +117,15 @@ def errback(f):

deferred.addCallbacks(callback, errback)

def observe(self) -> defer.Deferred:
def observe(self) -> "defer.Deferred[_T]":
"""Observe the underlying deferred.

This returns a brand new deferred that is resolved when the underlying
deferred is resolved. Interacting with the returned deferred does not
effect the underlying deferred.
"""
if not self._result:
d: "defer.Deferred[Any]" = defer.Deferred()
d: "defer.Deferred[_T]" = defer.Deferred()

def remove(r):
self._observers.discard(d)
Expand All @@ -137,7 +139,7 @@ def remove(r):
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)

def observers(self) -> List[defer.Deferred]:
def observers(self) -> "List[defer.Deferred[_T]]":
return self._observers

def has_called(self) -> bool:
Expand All @@ -146,7 +148,7 @@ def has_called(self) -> bool:
def has_succeeded(self) -> bool:
return self._result is not None and self._result[0] is True

def get_result(self) -> Any:
def get_result(self) -> Union[_T, Failure]:
return self._result[1]

def __getattr__(self, name: str) -> Any:
Expand Down