Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve type hints for attrs classes #16276

Merged
merged 6 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions synapse/storage/controllers/persist_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,13 @@ def try_merge(self, task: "_EventPersistQueueTask") -> bool:


_EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask]
_PersistResult = TypeVar("_PersistResult")


@attr.s(auto_attribs=True, slots=True)
class _EventPersistQueueItem:
class _EventPersistQueueItem(Generic[_PersistResult]):
task: _EventPersistQueueTask
deferred: ObservableDeferred
deferred: ObservableDeferred[_PersistResult]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this agrees with Rich prior to d15d241?


parent_opentracing_span_contexts: List = attr.ib(factory=list)
"""A list of opentracing spans waiting for this batch"""
Expand All @@ -168,9 +169,6 @@ class _EventPersistQueueItem:
"""The opentracing span under which the persistence actually happened"""


_PersistResult = TypeVar("_PersistResult")


class _EventPeristenceQueue(Generic[_PersistResult]):
"""Queues up tasks so that they can be processed with only one concurrent
transaction per room.
Expand Down
2 changes: 1 addition & 1 deletion synapse/util/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import inspect
import itertools
import logging
import typing
from contextlib import asynccontextmanager
from typing import (
Any,
Expand All @@ -43,7 +44,6 @@
)

import attr
import typing
from typing_extensions import Concatenate, Literal, ParamSpec

from twisted.internet import defer
Expand Down
10 changes: 4 additions & 6 deletions synapse/util/caches/dictionary_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import enum
import logging
import threading
from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
from typing import Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union

import attr
from typing_extensions import Literal
Expand All @@ -33,10 +33,8 @@
DV = TypeVar("DV")


# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DictionaryEntry: # should be: Generic[DKT, DV].
class DictionaryEntry(Generic[DKT, DV]):
"""Returned when getting an entry from the cache

If `full` is true then `known_absent` will be the empty set.
Expand All @@ -50,8 +48,8 @@ class DictionaryEntry: # should be: Generic[DKT, DV].
"""

full: bool
known_absent: Set[Any] # should be: Set[DKT]
value: Dict[Any, Any] # should be: Dict[DKT, DV]
known_absent: Set[DKT]
value: Dict[DKT, DV]

def __len__(self) -> int:
return len(self.value)
Expand Down
6 changes: 3 additions & 3 deletions synapse/util/caches/expiringcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(
self._expiry_ms = expiry_ms
self._reset_expiry_on_get = reset_expiry_on_get

self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict()
self._cache: OrderedDict[KT, _CacheEntry[VT]] = OrderedDict()

self.iterable = iterable

Expand Down Expand Up @@ -218,6 +218,6 @@ def set_cache_factor(self, factor: float) -> bool:


@attr.s(slots=True, auto_attribs=True)
class _CacheEntry:
class _CacheEntry(Generic[VT]):
time: int
value: Any
value: VT
10 changes: 5 additions & 5 deletions synapse/util/caches/ttlcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]):

def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
# map from key to _CacheEntry
self._data: Dict[KT, _CacheEntry] = {}
self._data: Dict[KT, _CacheEntry[KT, VT]] = {}

# the _CacheEntries, sorted by expiry time
self._expiry_list: SortedList[_CacheEntry] = SortedList()
self._expiry_list: SortedList[_CacheEntry[KT, VT]] = SortedList()

self._timer = timer

Expand Down Expand Up @@ -160,11 +160,11 @@ def expire(self) -> None:


@attr.s(frozen=True, slots=True, auto_attribs=True)
class _CacheEntry: # Should be Generic[KT, VT]. See python-attrs/attrs#313
class _CacheEntry(Generic[KT, VT]):
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
"""TTLCache entry"""

# expiry_time is the first attribute, so that entries are sorted by expiry.
expiry_time: float
ttl: float
key: Any # should be KT
value: Any # should be VT
key: KT
value: VT