Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing type hints to REST servlets. #10817

Merged
merged 8 commits into from
Sep 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion changelog.d/10785.misc
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Convert the internal `FileInfo` class to attrs and add type hints.
Add missing type hints to REST servlets.
1 change: 1 addition & 0 deletions changelog.d/10817.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to REST servlets.
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ files =
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py

[mypy-synapse.rest.client.*]
[mypy-synapse.rest.*]
disallow_untyped_defs = True

[mypy-synapse.util.batching_queue]
Expand Down
11 changes: 8 additions & 3 deletions synapse/rest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import JsonResource
from typing import TYPE_CHECKING

from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin
from synapse.rest.client import (
account,
Expand Down Expand Up @@ -57,6 +59,9 @@
voip,
)

if TYPE_CHECKING:
from synapse.server import HomeServer


class ClientRestResource(JsonResource):
"""Matrix Client API REST resource.
Expand All @@ -68,12 +73,12 @@ class ClientRestResource(JsonResource):
* etc
"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)

@staticmethod
def register_servlets(client_resource, hs):
def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None:
versions.register_servlets(hs, client_resource)

# Deprecated in r0
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/admin/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()

async def on_GET(
self, request: SynapseRequest, user_id, device_id: str
self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/admin/server_notice_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, hs: "HomeServer"):
self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs)

def register(self, json_resource: HttpServer):
def register(self, json_resource: HttpServer) -> None:
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/admin/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def __init__(self, hs: "HomeServer"):
self.nonces: Dict[str, int] = {}
self.hs = hs

def _clear_old_nonces(self):
def _clear_old_nonces(self) -> None:
"""
Clear out old nonces that are older than NONCE_TIMEOUT.
"""
Expand Down
39 changes: 17 additions & 22 deletions synapse/rest/consent/consent_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,22 @@
from hashlib import sha256
from http import HTTPStatus
from os import path
from typing import Dict, List
from typing import TYPE_CHECKING, Any, Dict, List

import jinja2
from jinja2 import TemplateNotFound

from twisted.web.server import Request

from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_bytes_from_args, parse_string
from synapse.types import UserID

if TYPE_CHECKING:
from synapse.server import HomeServer

# language to use for the templates. TODO: figure this out from Accept-Language
TEMPLATE_LANGUAGE = "en"

Expand Down Expand Up @@ -69,11 +74,7 @@ class ConsentResource(DirectServeHtmlResource):
against the user.
"""

def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): homeserver
"""
def __init__(self, hs: "HomeServer"):
super().__init__()

self.hs = hs
Expand Down Expand Up @@ -106,18 +107,14 @@ def __init__(self, hs):

self._hmac_secret = hs.config.form_secret.encode("utf-8")

async def _async_render_GET(self, request):
"""
Args:
request (twisted.web.http.Request):
"""
async def _async_render_GET(self, request: Request) -> None:
version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", default="")
userhmac = None
has_consented = False
public_version = username == ""
if not public_version:
args: Dict[bytes, List[bytes]] = request.args
args: Dict[bytes, List[bytes]] = request.args # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to ignore the type here ooi?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do this in many places -- Twisted has the wrong type for request.args.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, actually it isn't wrong, it's just missing. Without this we get:

synapse/rest/consent/consent_resource.py:118: error: Argument 1 to "parse_bytes_from_args" has incompatible type "Optional[Any]"; expected "Mapping[bytes, Sequence[bytes]]"  [arg-type]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a feeling this might be the 154312th time I've asked that Q....

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a feeling this might be the 154312th time I've asked that Q....

That's OK! I think it might be fixed in the next Twisted. I should double check / fix it.

userhmac_bytes = parse_bytes_from_args(args, "h", required=True)

self._check_hash(username, userhmac_bytes)
Expand Down Expand Up @@ -147,14 +144,10 @@ async def _async_render_GET(self, request):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")

async def _async_render_POST(self, request):
"""
Args:
request (twisted.web.http.Request):
"""
async def _async_render_POST(self, request: Request) -> None:
version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True)
args: Dict[bytes, List[bytes]] = request.args
args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac = parse_bytes_from_args(args, "h", required=True)

self._check_hash(username, userhmac)
Expand All @@ -177,7 +170,9 @@ async def _async_render_POST(self, request):
except TemplateNotFound:
raise NotFoundError("success.html not found")

def _render_template(self, request, template_name, **template_args):
def _render_template(
self, request: Request, template_name: str, **template_args: Any
) -> None:
# get_template checks for ".." so we don't need to worry too much
# about path traversal here.
template_html = self._jinja_env.get_template(
Expand All @@ -186,11 +181,11 @@ def _render_template(self, request, template_name, **template_args):
html = template_html.render(**template_args)
respond_with_html(request, 200, html)

def _check_hash(self, userid, userhmac):
def _check_hash(self, userid: str, userhmac: bytes) -> None:
"""
Args:
userid (unicode):
userhmac (bytes):
userid:
userhmac:

Raises:
SynapseError if the hash doesn't match
Expand Down
3 changes: 2 additions & 1 deletion synapse/rest/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from twisted.web.resource import Resource
from twisted.web.server import Request


class HealthResource(Resource):
Expand All @@ -25,6 +26,6 @@ class HealthResource(Resource):

isLeaf = 1

def render_GET(self, request):
def render_GET(self, request: Request) -> bytes:
request.setHeader(b"Content-Type", b"text/plain")
return b"OK"
7 changes: 6 additions & 1 deletion synapse/rest/key/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from twisted.web.resource import Resource

from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey

if TYPE_CHECKING:
from synapse.server import HomeServer


class KeyApiV2Resource(Resource):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"server", LocalKey(hs))
self.putChild(b"query", RemoteKey(hs))
15 changes: 10 additions & 5 deletions synapse/rest/key/v2/local_key_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from typing import TYPE_CHECKING

from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64

from twisted.web.resource import Resource
from twisted.web.server import Request

from synapse.http.server import respond_with_json_bytes
from synapse.types import JsonDict

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,18 +63,18 @@ class LocalKey(Resource):

isLeaf = True

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.config = hs.config
self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)

def update_response_body(self, time_now_msec):
def update_response_body(self, time_now_msec: int) -> None:
refresh_interval = self.config.key_refresh_interval
self.valid_until_ts = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object())

def response_json_object(self):
def response_json_object(self) -> JsonDict:
verify_keys = {}
for key in self.config.signing_key:
verify_key_bytes = key.verify_key.encode()
Expand All @@ -94,7 +99,7 @@ def response_json_object(self):
json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object

def render_GET(self, request):
def render_GET(self, request: Request) -> int:
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
Expand Down
30 changes: 21 additions & 9 deletions synapse/rest/key/v2/remote_key_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,23 @@
# limitations under the License.

import logging
from typing import Dict
from typing import TYPE_CHECKING, Dict

from signedjson.sign import sign_json

from twisted.web.server import Request

from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -85,7 +91,7 @@ class RemoteKey(DirectServeJsonResource):

isLeaf = True

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()

self.fetcher = ServerKeyFetcher(hs)
Expand All @@ -94,7 +100,8 @@ def __init__(self, hs):
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config

async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
assert request.postpath is not None
if len(request.postpath) == 1:
(server,) = request.postpath
query: dict = {server.decode("ascii"): {}}
Expand All @@ -110,14 +117,19 @@ async def _async_render_GET(self, request):

await self.query_keys(request, query, query_remote_on_cache_miss=True)

async def _async_render_POST(self, request):
async def _async_render_POST(self, request: Request) -> None:
content = parse_json_object_from_request(request)

query = content["server_keys"]

await self.query_keys(request, query, query_remote_on_cache_miss=True)

async def query_keys(self, request, query, query_remote_on_cache_miss=False):
async def query_keys(
self,
request: Request,
query: JsonDict,
query_remote_on_cache_miss: bool = False,
) -> None:
logger.info("Handling query for keys %r", query)

store_queries = []
Expand All @@ -142,8 +154,8 @@ async def query_keys(self, request, query, query_remote_on_cache_miss=False):

# Note that the value is unused.
cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results]
for (server_name, key_id, _), key_results in cached.items():
results = [(result["ts_added_ms"], result) for result in key_results]

if not results and key_id is not None:
cache_misses.setdefault(server_name, {})[key_id] = 0
Expand Down Expand Up @@ -230,6 +242,6 @@ async def query_keys(self, request, query, query_remote_on_cache_miss=False):

signed_keys.append(key_json)

results = {"server_keys": signed_keys}
response = {"server_keys": signed_keys}

respond_with_json(request, 200, results, canonical_json=True)
respond_with_json(request, 200, response, canonical_json=True)
Loading