Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing type hints to replication.http. #11856

Merged
merged 4 commits into from
Feb 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11856.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to replication code.
2 changes: 1 addition & 1 deletion synapse/replication/http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, hs: "HomeServer"):
super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs)

def register_servlets(self, hs: "HomeServer"):
def register_servlets(self, hs: "HomeServer") -> None:
send_event.register_servlets(hs, self)
federation.register_servlets(hs, self)
presence.register_servlets(hs, self)
Expand Down
31 changes: 20 additions & 11 deletions synapse/replication/http/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,20 @@
import abc
import logging
import re
import urllib
import urllib.parse
from inspect import signature
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple

from prometheus_client import Counter, Gauge

from twisted.web.server import Request

from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer
from synapse.logging import opentracing
from synapse.logging.opentracing import trace
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string

Expand Down Expand Up @@ -113,10 +117,12 @@ def __init__(self, hs: "HomeServer"):
if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret

def _check_auth(self, request) -> None:
def _check_auth(self, request: Request) -> None:
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")

if not auth_headers:
raise RuntimeError("Missing Authorization header.")
if len(auth_headers) > 1:
raise RuntimeError("Too many Authorization headers.")
parts = auth_headers[0].split(b" ")
Expand All @@ -129,7 +135,7 @@ def _check_auth(self, request) -> None:
raise RuntimeError("Invalid Authorization header.")

@abc.abstractmethod
async def _serialize_payload(**kwargs):
async def _serialize_payload(**kwargs) -> JsonDict:
"""Static method that is called when creating a request.

Concrete implementations should have explicit parameters (rather than
Expand All @@ -144,19 +150,20 @@ async def _serialize_payload(**kwargs):
return {}

@abc.abstractmethod
async def _handle_request(self, request, **kwargs):
async def _handle_request(
self, request: Request, **kwargs: Any
) -> Tuple[int, JsonDict]:
"""Handle incoming request.

This is called with the request object and PATH_ARGS.

Returns:
tuple[int, dict]: HTTP status code and a JSON serialisable dict
to be used as response body of request.
HTTP status code and a JSON serialisable dict to be used as response
body of request.
"""
pass

@classmethod
def make_client(cls, hs: "HomeServer"):
def make_client(cls, hs: "HomeServer") -> Callable:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
"""Create a client that makes requests.

Returns a callable that accepts the same parameters as
Expand All @@ -182,7 +189,7 @@ def make_client(cls, hs: "HomeServer"):
)

@trace(opname="outgoing_replication_request")
async def send_request(*, instance_name="master", **kwargs):
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
Expand Down Expand Up @@ -268,7 +275,7 @@ async def send_request(*, instance_name="master", **kwargs):

return send_request

def register(self, http_server):
def register(self, http_server: HttpServer) -> None:
"""Called by the server to register this as a handler to the
appropriate path.
"""
Expand All @@ -289,7 +296,9 @@ def register(self, http_server):
self.__class__.__name__,
)

async def _check_auth_and_handle(self, request, **kwargs):
async def _check_auth_and_handle(
self, request: Request, **kwargs: Any
) -> Tuple[int, JsonDict]:
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.
Expand Down
38 changes: 28 additions & 10 deletions synapse/replication/http/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple

from twisted.web.server import Request

from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -48,14 +52,18 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()

@staticmethod
async def _serialize_payload(user_id, account_data_type, content):
async def _serialize_payload( # type: ignore[override]
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
user_id: str, account_data_type: str, content: JsonDict
) -> JsonDict:
payload = {
"content": content,
}

return payload

async def _handle_request(self, request, user_id, account_data_type):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

max_stream_id = await self.handler.add_account_data_for_user(
Expand Down Expand Up @@ -89,14 +97,18 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()

@staticmethod
async def _serialize_payload(user_id, room_id, account_data_type, content):
async def _serialize_payload( # type: ignore[override]
user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> JsonDict:
payload = {
"content": content,
}

return payload

async def _handle_request(self, request, user_id, room_id, account_data_type):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

max_stream_id = await self.handler.add_account_data_to_room(
Expand Down Expand Up @@ -130,14 +142,18 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()

@staticmethod
async def _serialize_payload(user_id, room_id, tag, content):
async def _serialize_payload( # type: ignore[override]
user_id: str, room_id: str, tag: str, content: JsonDict
) -> JsonDict:
payload = {
"content": content,
}

return payload

async def _handle_request(self, request, user_id, room_id, tag):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

max_stream_id = await self.handler.add_tag_to_room(
Expand Down Expand Up @@ -173,11 +189,13 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()

@staticmethod
async def _serialize_payload(user_id, room_id, tag):
async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override]

return {}

async def _handle_request(self, request, user_id, room_id, tag):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_tag_from_room(
user_id,
room_id,
Expand All @@ -187,7 +205,7 @@ async def _handle_request(self, request, user_id, room_id, tag):
return 200, {"max_stream_id": max_stream_id}


def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationUserAccountDataRestServlet(hs).register(http_server)
ReplicationRoomAccountDataRestServlet(hs).register(http_server)
ReplicationAddTagRestServlet(hs).register(http_server)
Expand Down
14 changes: 10 additions & 4 deletions synapse/replication/http/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple

from twisted.web.server import Request

from synapse.http.server import HttpServer
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -63,14 +67,16 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()

@staticmethod
async def _serialize_payload(user_id):
async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override]
return {}

async def _handle_request(self, request, user_id):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
user_devices = await self.device_list_updater.user_device_resync(user_id)

return 200, user_devices


def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
Loading