Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve type hints for attrs classes #16276

Merged
merged 6 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16276.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
2 changes: 1 addition & 1 deletion synapse/config/oembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class OEmbedEndpointConfig:
# The API endpoint to fetch.
api_endpoint: str
# The patterns to match.
url_patterns: List[Pattern]
url_patterns: List[Pattern[str]]
clokep marked this conversation as resolved.
Show resolved Hide resolved
# The supported formats.
formats: Optional[List[str]]

Expand Down
8 changes: 3 additions & 5 deletions synapse/storage/controllers/persist_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,13 @@ def try_merge(self, task: "_EventPersistQueueTask") -> bool:


_EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask]
_PersistResult = TypeVar("_PersistResult")


@attr.s(auto_attribs=True, slots=True)
class _EventPersistQueueItem:
class _EventPersistQueueItem(Generic[_PersistResult]):
task: _EventPersistQueueTask
deferred: ObservableDeferred
deferred: ObservableDeferred[_PersistResult]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this agrees with Rich prior to d15d241?


parent_opentracing_span_contexts: List = attr.ib(factory=list)
"""A list of opentracing spans waiting for this batch"""
Expand All @@ -168,9 +169,6 @@ class _EventPersistQueueItem:
"""The opentracing span under which the persistence actually happened"""


_PersistResult = TypeVar("_PersistResult")


class _EventPeristenceQueue(Generic[_PersistResult]):
"""Queues up tasks so that they can be processed with only one concurrent
transaction per room.
Expand Down
25 changes: 11 additions & 14 deletions synapse/util/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import inspect
import itertools
import logging
import typing
from contextlib import asynccontextmanager
from typing import (
Any,
Expand All @@ -29,6 +30,7 @@
Collection,
Coroutine,
Dict,
Generator,
Generic,
Hashable,
Iterable,
Expand Down Expand Up @@ -398,7 +400,7 @@ class _LinearizerEntry:
# The number of things executing.
count: int
# Deferreds for the things blocked from executing.
deferreds: collections.OrderedDict
deferreds: typing.OrderedDict["defer.Deferred[None]", Literal[1]]


class Linearizer:
Expand Down Expand Up @@ -717,30 +719,25 @@ def failure_cb(val: Failure) -> None:
return new_d


# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DoneAwaitable: # should be: Generic[R]
class DoneAwaitable(Awaitable[R]):
"""Simple awaitable that returns the provided value."""

value: Any # should be: R
value: R

def __await__(self) -> Any:
return self

def __iter__(self) -> "DoneAwaitable":
return self

def __next__(self) -> None:
raise StopIteration(self.value)
def __await__(self) -> Generator[Any, None, R]:
yield None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure it matters, but is the yield less efficient?

I'm not 100% sure I understand removing __iter__ and __next__.

Copy link
Contributor Author

@DMRobertson DMRobertson Sep 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Err, being honest it's because I couldn't work out how to annotate the __iter__ and __next__ in a way that mypy was happy with. Can dig up the errors if you'd like.

I think this is equivalent though:

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I think it wasn't redundant because we were returning ourself, now __await__ itself is an iterator?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I think it wasn't redundant because we were returning ourself, now __await__ itself is an iterator?

This explains why __next__ was needed before my change. But I assert that __iter__ is redundant these days:

$ python3.8 -m asyncio
asyncio REPL 3.8.17 (default, Jun  8 2023, 00:00:00) 
[GCC 13.1.1 20230511 (Red Hat 13.1.1-2)] on linux
Use "await" directly instead of "asyncio.run()".
Type "help", "copyright", "credits" or "license" for more information.
>>> import asyncio
>>> class DummyAwaitable:
...     def __init__(self, value):
...         self.value = value
...     
...     def __await__(self):
...         return self
...     
...     def __next__(self):
...         raise StopIteration(self.value)
... 
>>> 
>>> await DummyAwaitable(1234)
1234

AFAICS you used to need __iter__ so that you could yield from, back in the days when async wasn't a keyword (a = yield from b instead of a = await b.) If I try to use my class in an old-style coroutine:

>>> import asyncio
>>> @asyncio.coroutine
... def f():
...     val = yield from DummyAwaitable(5678)
...     return val
... 
<console>:2: DeprecationWarning: "@coroutine" decorator is deprecated since Python 3.8, use "async def" instead
>>> await f()
Traceback (most recent call last):
  File "/usr/lib64/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/lib64/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "<console>", line 1, in <module>
  File "<console>", line 3, in f
TypeError: 'DummyAwaitable' object is not iterable

But if we define __iter__:

>>> class OldDummy(DummyAwaitable):
...     def __iter__(self):
...         return self
... 
>>> @asyncio.coroutine
... def g():
...     val = yield from OldDummy(5678)
...     return val
... 
<console>:2: DeprecationWarning: "@coroutine" decorator is deprecated since Python 3.8, use "async def" instead
>>>     
>>> await g()
5678

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Tbh we could probably use asyncio.Future instead of writing our own)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Tbh we could probably use asyncio.Future instead of writing our own)

Alas, Future docs say:

Deprecated since version 3.10: Deprecation warning is emitted if loop is not specified and there is no running event loop.

return self.value
Comment on lines +723 to +730
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't resist trying to simplify this. Needs careful checking though, I find all this confusing.

C.f. a0aef0b



def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
"""Convert a value to an awaitable if not already an awaitable."""
if inspect.isawaitable(value):
assert isinstance(value, Awaitable)
return value

# For some reason mypy doesn't deduce that value is not Awaitable here, even though
# inspect.isawaitable returns a TypeGuard.
assert not isinstance(value, Awaitable)
return DoneAwaitable(value)


Expand Down
10 changes: 4 additions & 6 deletions synapse/util/caches/dictionary_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import enum
import logging
import threading
from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
from typing import Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union

import attr
from typing_extensions import Literal
Expand All @@ -33,10 +33,8 @@
DV = TypeVar("DV")


# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DictionaryEntry: # should be: Generic[DKT, DV].
class DictionaryEntry(Generic[DKT, DV]):
"""Returned when getting an entry from the cache

If `full` is true then `known_absent` will be the empty set.
Expand All @@ -50,8 +48,8 @@ class DictionaryEntry: # should be: Generic[DKT, DV].
"""

full: bool
known_absent: Set[Any] # should be: Set[DKT]
value: Dict[Any, Any] # should be: Dict[DKT, DV]
known_absent: Set[DKT]
value: Dict[DKT, DV]

def __len__(self) -> int:
return len(self.value)
Expand Down
20 changes: 12 additions & 8 deletions synapse/util/caches/expiringcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
from collections import OrderedDict
from typing import Any, Generic, Optional, TypeVar, Union, overload
from typing import Any, Generic, Iterable, Optional, TypeVar, Union, overload

import attr
from typing_extensions import Literal
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(
self._expiry_ms = expiry_ms
self._reset_expiry_on_get = reset_expiry_on_get

self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict()
self._cache: OrderedDict[KT, _CacheEntry[VT]] = OrderedDict()

self.iterable = iterable

Expand All @@ -100,7 +100,10 @@ def evict(self) -> None:
while self._max_size and len(self) > self._max_size:
_key, value = self._cache.popitem(last=False)
if self.iterable:
self.metrics.inc_evictions(EvictionReason.size, len(value.value))
# type-ignore, here and below: if self.iterable is true, then the value
# type VT should be Sized (i.e. have a __len__ method). We don't enforce
# this via the type system at present.
clokep marked this conversation as resolved.
Show resolved Hide resolved
self.metrics.inc_evictions(EvictionReason.size, len(value.value)) # type: ignore[arg-type]
else:
self.metrics.inc_evictions(EvictionReason.size)

Expand Down Expand Up @@ -134,7 +137,7 @@ def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
return default

if self.iterable:
self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value))
self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value)) # type: ignore[arg-type]
else:
self.metrics.inc_evictions(EvictionReason.invalidation)

Expand Down Expand Up @@ -182,7 +185,7 @@ async def _prune_cache(self) -> None:
for k in keys_to_delete:
value = self._cache.pop(k)
if self.iterable:
self.metrics.inc_evictions(EvictionReason.time, len(value.value))
self.metrics.inc_evictions(EvictionReason.time, len(value.value)) # type: ignore[arg-type]
else:
self.metrics.inc_evictions(EvictionReason.time)

Expand All @@ -195,7 +198,8 @@ async def _prune_cache(self) -> None:

def __len__(self) -> int:
if self.iterable:
return sum(len(entry.value) for entry in self._cache.values())
g: Iterable[int] = (len(entry.value) for entry in self._cache.values()) # type: ignore[arg-type]
return sum(g)
else:
return len(self._cache)

Expand All @@ -218,6 +222,6 @@ def set_cache_factor(self, factor: float) -> bool:


@attr.s(slots=True, auto_attribs=True)
class _CacheEntry:
class _CacheEntry(Generic[VT]):
time: int
value: Any
value: VT
10 changes: 5 additions & 5 deletions synapse/util/caches/ttlcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]):

def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
# map from key to _CacheEntry
self._data: Dict[KT, _CacheEntry] = {}
self._data: Dict[KT, _CacheEntry[KT, VT]] = {}

# the _CacheEntries, sorted by expiry time
self._expiry_list: SortedList[_CacheEntry] = SortedList()
self._expiry_list: SortedList[_CacheEntry[KT, VT]] = SortedList()

self._timer = timer

Expand Down Expand Up @@ -160,11 +160,11 @@ def expire(self) -> None:


@attr.s(frozen=True, slots=True, auto_attribs=True)
class _CacheEntry: # Should be Generic[KT, VT]. See python-attrs/attrs#313
class _CacheEntry(Generic[KT, VT]):
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
"""TTLCache entry"""

# expiry_time is the first attribute, so that entries are sorted by expiry.
expiry_time: float
ttl: float
key: Any # should be KT
value: Any # should be VT
key: KT
value: VT
Loading