Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Additional type hints for client REST servlets (part 5) (#10736)
Browse files Browse the repository at this point in the history
Additionally this enforce type hints on all function signatures inside
of the synapse.rest.client package.
  • Loading branch information
clokep authored Sep 3, 2021
1 parent f58d202 commit ecbfa4f
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 68 deletions.
1 change: 1 addition & 0 deletions changelog.d/10736.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to REST servlets.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ files =
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py

[mypy-synapse.rest.client.*]
disallow_untyped_defs = True

[mypy-pymacaroons.*]
ignore_missing_imports = True

Expand Down
19 changes: 19 additions & 0 deletions synapse/http/servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,25 @@ def parse_string_from_args(
return strings[0]


@overload
def parse_json_value_from_request(request: Request) -> JsonDict:
...


@overload
def parse_json_value_from_request(
request: Request, allow_empty_body: Literal[False]
) -> JsonDict:
...


@overload
def parse_json_value_from_request(
request: Request, allow_empty_body: bool = False
) -> Optional[JsonDict]:
...


def parse_json_value_from_request(
request: Request, allow_empty_body: bool = False
) -> Optional[JsonDict]:
Expand Down
6 changes: 4 additions & 2 deletions synapse/rest/admin/server_notice_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Awaitable, Optional, Tuple

from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError, SynapseError
Expand Down Expand Up @@ -101,7 +101,9 @@ async def on_POST(

return 200, {"event_id": event.event_id}

def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
def on_PUT(
self, request: SynapseRequest, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, txn_id
)
11 changes: 7 additions & 4 deletions synapse/rest/client/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""
import logging
import re
from typing import Iterable, Pattern
from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast

from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
Expand Down Expand Up @@ -76,7 +76,10 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
)


def interactive_auth_handler(orig):
C = TypeVar("C", bound=Callable[..., Awaitable[Tuple[int, JsonDict]]])


def interactive_auth_handler(orig: C) -> C:
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
Takes a on_POST method which returns an Awaitable (errcode, body) response
Expand All @@ -91,10 +94,10 @@ async def on_POST(self, request):
await self.auth_handler.check_auth
"""

async def wrapped(*args, **kwargs):
async def wrapped(*args: Any, **kwargs: Any) -> Tuple[int, JsonDict]:
try:
return await orig(*args, **kwargs)
except InteractiveAuthIncompleteError as e:
return 401, e.result

return wrapped
return cast(C, wrapped)
10 changes: 7 additions & 3 deletions synapse/rest/client/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import logging
from functools import wraps
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple

from twisted.web.server import Request

Expand Down Expand Up @@ -43,14 +43,18 @@
logger = logging.getLogger(__name__)


def _validate_group_id(f):
def _validate_group_id(
f: Callable[..., Awaitable[Tuple[int, JsonDict]]]
) -> Callable[..., Awaitable[Tuple[int, JsonDict]]]:
"""Wrapper to validate the form of the group ID.
Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
"""

@wraps(f)
def wrapper(self, request: Request, group_id: str, *args, **kwargs):
def wrapper(
self: RestServlet, request: Request, group_id: str, *args: Any, **kwargs: Any
) -> Awaitable[Tuple[int, JsonDict]]:
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))

Expand Down
Loading

0 comments on commit ecbfa4f

Please sign in to comment.