Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Make cached account data/tags/admin types immutable (#16325)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Sep 18, 2023
1 parent 85bfd47 commit c1e244c
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 50 deletions.
1 change: 1 addition & 0 deletions changelog.d/16325.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
14 changes: 7 additions & 7 deletions synapse/app/admin_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import os
import sys
import tempfile
from typing import List, Mapping, Optional
from typing import List, Mapping, Optional, Sequence

from twisted.internet import defer, task

Expand Down Expand Up @@ -57,7 +57,7 @@
from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.storage.databases.main.tags import TagsWorkerStore
from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
from synapse.types import JsonDict, StateMap
from synapse.types import JsonMapping, StateMap
from synapse.util import SYNAPSE_VERSION
from synapse.util.logcontext import LoggingContext

Expand Down Expand Up @@ -198,15 +198,15 @@ def write_knock(
for event in state.values():
json.dump(event, fp=f)

def write_profile(self, profile: JsonDict) -> None:
def write_profile(self, profile: JsonMapping) -> None:
user_directory = os.path.join(self.base_directory, "user_data")
os.makedirs(user_directory, exist_ok=True)
profile_file = os.path.join(user_directory, "profile")

with open(profile_file, "a") as f:
json.dump(profile, fp=f)

def write_devices(self, devices: List[JsonDict]) -> None:
def write_devices(self, devices: Sequence[JsonMapping]) -> None:
user_directory = os.path.join(self.base_directory, "user_data")
os.makedirs(user_directory, exist_ok=True)
device_file = os.path.join(user_directory, "devices")
Expand All @@ -215,7 +215,7 @@ def write_devices(self, devices: List[JsonDict]) -> None:
with open(device_file, "a") as f:
json.dump(device, fp=f)

def write_connections(self, connections: List[JsonDict]) -> None:
def write_connections(self, connections: Sequence[JsonMapping]) -> None:
user_directory = os.path.join(self.base_directory, "user_data")
os.makedirs(user_directory, exist_ok=True)
connection_file = os.path.join(user_directory, "connections")
Expand All @@ -225,7 +225,7 @@ def write_connections(self, connections: List[JsonDict]) -> None:
json.dump(connection, fp=f)

def write_account_data(
self, file_name: str, account_data: Mapping[str, JsonDict]
self, file_name: str, account_data: Mapping[str, JsonMapping]
) -> None:
account_data_directory = os.path.join(
self.base_directory, "user_data", "account_data"
Expand All @@ -237,7 +237,7 @@ def write_account_data(
with open(account_data_file, "a") as f:
json.dump(account_data, fp=f)

def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None:
def write_media_id(self, media_id: str, media_metadata: JsonMapping) -> None:
file_directory = os.path.join(self.base_directory, "media_ids")
os.makedirs(file_directory, exist_ok=True)
media_id_file = os.path.join(file_directory, media_id)
Expand Down
18 changes: 9 additions & 9 deletions synapse/handlers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

import abc
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set

from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo
from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo
from synapse.visibility import filter_events_for_client

if TYPE_CHECKING:
Expand All @@ -35,7 +35,7 @@ def __init__(self, hs: "HomeServer"):
self._state_storage_controller = self._storage_controllers.state
self._msc3866_enabled = hs.config.experimental.msc3866.enabled

async def get_whois(self, user: UserID) -> JsonDict:
async def get_whois(self, user: UserID) -> JsonMapping:
connections = []

sessions = await self._store.get_user_ip_and_agents(user)
Expand All @@ -55,7 +55,7 @@ async def get_whois(self, user: UserID) -> JsonDict:

return ret

async def get_user(self, user: UserID) -> Optional[JsonDict]:
async def get_user(self, user: UserID) -> Optional[JsonMapping]:
"""Function to get user details"""
user_info: Optional[UserInfo] = await self._store.get_user_by_id(
user.to_string()
Expand Down Expand Up @@ -344,7 +344,7 @@ def write_knock(
raise NotImplementedError()

@abc.abstractmethod
def write_profile(self, profile: JsonDict) -> None:
def write_profile(self, profile: JsonMapping) -> None:
"""Write the profile of a user.
Args:
Expand All @@ -353,7 +353,7 @@ def write_profile(self, profile: JsonDict) -> None:
raise NotImplementedError()

@abc.abstractmethod
def write_devices(self, devices: List[JsonDict]) -> None:
def write_devices(self, devices: Sequence[JsonMapping]) -> None:
"""Write the devices of a user.
Args:
Expand All @@ -362,7 +362,7 @@ def write_devices(self, devices: List[JsonDict]) -> None:
raise NotImplementedError()

@abc.abstractmethod
def write_connections(self, connections: List[JsonDict]) -> None:
def write_connections(self, connections: Sequence[JsonMapping]) -> None:
"""Write the connections of a user.
Args:
Expand All @@ -372,7 +372,7 @@ def write_connections(self, connections: List[JsonDict]) -> None:

@abc.abstractmethod
def write_account_data(
self, file_name: str, account_data: Mapping[str, JsonDict]
self, file_name: str, account_data: Mapping[str, JsonMapping]
) -> None:
"""Write the account data of a user.
Expand All @@ -383,7 +383,7 @@ def write_account_data(
raise NotImplementedError()

@abc.abstractmethod
def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None:
def write_media_id(self, media_id: str, media_metadata: JsonMapping) -> None:
"""Write the media's metadata of a user.
Exports only the metadata, as this can be fetched from the database via
read only. In order to access the files, a connection to the correct
Expand Down
27 changes: 16 additions & 11 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from synapse.types import (
DeviceListUpdates,
JsonDict,
JsonMapping,
MutableStateMap,
Requester,
RoomStreamToken,
Expand Down Expand Up @@ -1793,19 +1794,23 @@ async def _generate_sync_entry_for_account_data(
)

if push_rules_changed:
global_account_data = dict(global_account_data)
global_account_data[
AccountDataTypes.PUSH_RULES
] = await self._push_rules_handler.push_rules_for_user(sync_config.user)
global_account_data = {
AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user(
sync_config.user
),
**global_account_data,
}
else:
all_global_account_data = await self.store.get_global_account_data_for_user(
user_id
)

global_account_data = dict(all_global_account_data)
global_account_data[
AccountDataTypes.PUSH_RULES
] = await self._push_rules_handler.push_rules_for_user(sync_config.user)
global_account_data = {
AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user(
sync_config.user
),
**all_global_account_data,
}

account_data_for_user = (
await sync_config.filter_collection.filter_global_account_data(
Expand Down Expand Up @@ -1909,7 +1914,7 @@ async def _generate_sync_entry_for_rooms(
blocks_all_rooms
or sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data()
):
account_data_by_room: Mapping[str, Mapping[str, JsonDict]] = {}
account_data_by_room: Mapping[str, Mapping[str, JsonMapping]] = {}
elif since_token and not sync_result_builder.full_state:
account_data_by_room = (
await self.store.get_updated_room_account_data_for_user(
Expand Down Expand Up @@ -2349,8 +2354,8 @@ async def _generate_room_entry(
sync_result_builder: "SyncResultBuilder",
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[Mapping[str, Mapping[str, Any]]],
account_data: Mapping[str, JsonDict],
tags: Optional[Mapping[str, JsonMapping]],
account_data: Mapping[str, JsonMapping],
always_include: bool = False,
) -> None:
"""Populates the `joined` and `archived` section of `sync_result_builder`
Expand Down
8 changes: 4 additions & 4 deletions synapse/rest/admin/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from synapse.rest.client._base import client_patterns
from synapse.storage.databases.main.registration import ExternalIDReuseException
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, JsonMapping, UserID

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -211,7 +211,7 @@ def __init__(self, hs: "HomeServer"):

async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
) -> Tuple[int, JsonMapping]:
await assert_requester_is_admin(self.auth, request)

target_user = UserID.from_string(user_id)
Expand All @@ -226,7 +226,7 @@ async def on_GET(

async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
) -> Tuple[int, JsonMapping]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester)

Expand Down Expand Up @@ -658,7 +658,7 @@ def __init__(self, hs: "HomeServer"):

async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
) -> Tuple[int, JsonMapping]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)

Expand Down
10 changes: 5 additions & 5 deletions synapse/rest/client/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, RoomID
from synapse.types import JsonDict, JsonMapping, RoomID

from ._base import client_patterns

Expand Down Expand Up @@ -95,7 +95,7 @@ async def on_PUT(

async def on_GET(
self, request: SynapseRequest, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
) -> Tuple[int, JsonMapping]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
Expand All @@ -106,7 +106,7 @@ async def on_GET(
and account_data_type == AccountDataTypes.PUSH_RULES
):
account_data: Optional[
JsonDict
JsonMapping
] = await self._push_rules_handler.push_rules_for_user(requester.user)
else:
account_data = await self.store.get_global_account_data_by_type_for_user(
Expand Down Expand Up @@ -236,7 +236,7 @@ async def on_GET(
user_id: str,
room_id: str,
account_data_type: str,
) -> Tuple[int, JsonDict]:
) -> Tuple[int, JsonMapping]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
Expand All @@ -253,7 +253,7 @@ async def on_GET(
self._hs.config.experimental.msc4010_push_rules_account_data
and account_data_type == AccountDataTypes.PUSH_RULES
):
account_data: Optional[JsonDict] = {}
account_data: Optional[JsonMapping] = {}
else:
account_data = await self.store.get_account_data_for_room_and_type(
user_id, room_id, account_data_type
Expand Down
14 changes: 7 additions & 7 deletions synapse/storage/databases/main/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.types import JsonDict, JsonMapping
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
Expand Down Expand Up @@ -119,7 +119,7 @@ def get_max_account_data_stream_id(self) -> int:
@cached()
async def get_global_account_data_for_user(
self, user_id: str
) -> Mapping[str, JsonDict]:
) -> Mapping[str, JsonMapping]:
"""
Get all the global client account_data for a user.
Expand Down Expand Up @@ -164,7 +164,7 @@ def get_global_account_data_for_user(
@cached()
async def get_room_account_data_for_user(
self, user_id: str
) -> Mapping[str, Mapping[str, JsonDict]]:
) -> Mapping[str, Mapping[str, JsonMapping]]:
"""
Get all of the per-room client account_data for a user.
Expand Down Expand Up @@ -213,7 +213,7 @@ def get_room_account_data_for_user_txn(
@cached(num_args=2, max_entries=5000, tree=True)
async def get_global_account_data_by_type_for_user(
self, user_id: str, data_type: str
) -> Optional[JsonDict]:
) -> Optional[JsonMapping]:
"""
Returns:
The account data.
Expand Down Expand Up @@ -265,7 +265,7 @@ def get_latest_stream_id_for_global_account_data_by_type_for_user_txn(
@cached(num_args=2, tree=True)
async def get_account_data_for_room(
self, user_id: str, room_id: str
) -> Mapping[str, JsonDict]:
) -> Mapping[str, JsonMapping]:
"""Get all the client account_data for a user for a room.
Args:
Expand Down Expand Up @@ -296,7 +296,7 @@ def get_account_data_for_room_txn(
@cached(num_args=3, max_entries=5000, tree=True)
async def get_account_data_for_room_and_type(
self, user_id: str, room_id: str, account_data_type: str
) -> Optional[JsonDict]:
) -> Optional[JsonMapping]:
"""Get the client account_data of given type for a user for a room.
Args:
Expand Down Expand Up @@ -394,7 +394,7 @@ def get_updated_room_account_data_txn(

async def get_updated_global_account_data_for_user(
self, user_id: str, stream_id: int
) -> Dict[str, JsonDict]:
) -> Mapping[str, JsonMapping]:
"""Get all the global account_data that's changed for a user.
Args:
Expand Down
7 changes: 3 additions & 4 deletions synapse/storage/databases/main/experimental_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, FrozenSet

from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.types import StrCollection
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
Expand All @@ -34,7 +33,7 @@ def __init__(
super().__init__(database, db_conn, hs)

@cached()
async def list_enabled_features(self, user_id: str) -> StrCollection:
async def list_enabled_features(self, user_id: str) -> FrozenSet[str]:
"""
Checks to see what features are enabled for a given user
Args:
Expand All @@ -49,7 +48,7 @@ async def list_enabled_features(self, user_id: str) -> StrCollection:
["feature"],
)

return [feature["feature"] for feature in enabled]
return frozenset(feature["feature"] for feature in enabled)

async def set_features_for_user(
self,
Expand Down
Loading

0 comments on commit c1e244c

Please sign in to comment.