Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type annotations to trace decorator #13328

Merged
merged 11 commits into from
Jul 19, 2022
1 change: 1 addition & 0 deletions changelog.d/13328.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `trace` decorator.
2 changes: 1 addition & 1 deletion synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ async def query_user_devices(
)

async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: int
self, destination: str, content: JsonDict, timeout: Optional[int]
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.

Expand Down
2 changes: 1 addition & 1 deletion synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ async def query_user_devices(
)

async def claim_client_keys(
self, destination: str, query_content: JsonDict, timeout: int
self, destination: str, query_content: JsonDict, timeout: Optional[int]
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.

Expand Down
16 changes: 9 additions & 7 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple

import attr
from canonicaljson import encode_canonical_json
Expand Down Expand Up @@ -92,7 +92,11 @@ def __init__(self, hs: "HomeServer"):

@trace
async def query_devices(
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
self,
query_body: JsonDict,
timeout: int,
from_user_id: str,
from_device_id: Optional[str],
) -> JsonDict:
"""Handle a device key query from a client

Expand Down Expand Up @@ -120,9 +124,7 @@ async def query_devices(
the number of in-flight queries at a time.
"""
async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {}
)
device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})

# separate users by domain.
# make a map from domain to user_id to device_ids
Expand Down Expand Up @@ -392,7 +394,7 @@ async def get_cross_signing_keys_from_cache(

@trace
async def query_local_devices(
self, query: Dict[str, Optional[List[str]]]
self, query: Mapping[str, Optional[List[str]]]
) -> Dict[str, Dict[str, dict]]:
"""Get E2E device keys for local users

Expand Down Expand Up @@ -461,7 +463,7 @@ async def on_federation_query_client_keys(

@trace
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
Expand Down
50 changes: 28 additions & 22 deletions synapse/logging/opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,13 @@ def interesting_function(*args, **kwargs):
return something_usual_and_useful


Operation names can be explicitly set for a function by passing the
operation name to ``trace``
Operation names can be explicitly set for a function by using ``trace_with_opname``:

.. code-block:: python

from synapse.logging.opentracing import trace
from synapse.logging.opentracing import trace_with_opname

@trace(opname="a_better_operation_name")
@trace_with_opname("a_better_operation_name")
def interesting_badly_named_function(*args, **kwargs):
# Does all kinds of cool and expected things
return something_usual_and_useful
Expand Down Expand Up @@ -798,33 +797,31 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
# Tracing decorators


def trace(func=None, opname: Optional[str] = None):
def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Decorator to trace a function.
Sets the operation name to that of the function's or that given
as operation_name. See the module's doc string for usage
examples.
Decorator to trace a function with a custom opname.

See the module's doc string for usage examples.

"""

def decorator(func):
def decorator(func: Callable[P, R]) -> Callable[P, R]:
if opentracing is None:
return func # type: ignore[unreachable]

_opname = opname if opname else func.__name__

if inspect.iscoroutinefunction(func):

@wraps(func)
async def _trace_inner(*args, **kwargs):
with start_active_span(_opname):
return await func(*args, **kwargs)
async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
with start_active_span(opname):
return await func(*args, **kwargs) # type: ignore[misc]
Copy link
Member Author

@clokep clokep Jul 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For completeness this is ignoring:

error: Incompatible types in "await" (actual type "R", expected type "Awaitable[Any]")  [misc]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I think this is fine: within this branch, func is a coroutine so should return an awaitable. It's a shame mypy can't deduce this for itself, given the use of TypeGuard in typeshed.

Oh, but that bit of typeshed I quoted only just landed this morning! So we might find that this ignore is unnecessary in a future mypy release.


else:
# The other case here handles both sync functions and those
# decorated with inlineDeferred.
@wraps(func)
def _trace_inner(*args, **kwargs):
scope = start_active_span(_opname)
def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
scope = start_active_span(opname)
scope.__enter__()

try:
Expand Down Expand Up @@ -858,12 +855,21 @@ def err_back(result: R) -> R:
scope.__exit__(type(e), None, e.__traceback__)
raise

return _trace_inner
return _trace_inner # type: ignore[return-value]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the error here, out of interest?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

synapse/logging/opentracing.py:858: error: Incompatible return value type (got "Callable[P, Coroutine[Any, Any, R]]", expected "Callable[P, R]")  [return-value]

Actually now that I split this up I might be able to do simple overloads with returning Awaitable[R] in a couple situations and remove this ignore?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that's possible actually. Any idea if we should do something about this one?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As sad as I am to leave a typing problem lying around like this, I think I'd probably leave the ignore as is. The important thing is that we accurately annotate the decorator itself; its internals don't have an effect on the rest of the function.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its internals don't have an effect on the rest of the function.

This was kind of my conclusion after messing with it for a bit.


if func:
return decorator(func)
else:
return decorator
return decorator


def trace(func: Callable[P, R]) -> Callable[P, R]:
"""
Decorator to trace a function.

Sets the operation name to that of the function's name.

See the module's doc string for usage examples.
"""

return trace_with_opname(func.__name__)(func)


def tag_args(func: Callable[P, R]) -> Callable[P, R]:
Expand Down
4 changes: 2 additions & 2 deletions synapse/replication/http/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from synapse.http.server import HttpServer, is_method_cancellable
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
from synapse.logging.opentracing import trace
from synapse.logging.opentracing import trace_with_opname
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
Expand Down Expand Up @@ -196,7 +196,7 @@ def make_client(cls, hs: "HomeServer") -> Callable:
"ascii"
)

@trace(opname="outgoing_replication_request")
@trace_with_opname("outgoing_replication_request")
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name:
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
from synapse.types import JsonDict, StreamToken

from ._base import client_patterns, interactive_auth_handler
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(self, hs: "HomeServer"):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()

@trace(opname="upload_keys")
@trace_with_opname("upload_keys")
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
) -> Tuple[int, JsonDict]:
Expand Down
13 changes: 8 additions & 5 deletions synapse/rest/client/room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple, cast

from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
Expand Down Expand Up @@ -127,7 +127,7 @@ async def on_PUT(
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
version = parse_string(request, "version")
version = parse_string(request, "version", required=True)

if session_id:
body = {"sessions": {session_id: body}}
Expand Down Expand Up @@ -196,8 +196,11 @@ async def on_GET(
user_id = requester.user.to_string()
version = parse_string(request, "version", required=True)

room_keys = await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id
room_keys = cast(
JsonDict,
await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id
),
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
)

# Convert room_keys to the right format to return.
Expand Down Expand Up @@ -240,7 +243,7 @@ async def on_DELETE(

requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
version = parse_string(request, "version")
version = parse_string(request, "version", required=True)

ret = await self.e2e_room_keys_handler.delete_room_keys(
user_id, version, room_id, session_id
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/sendtodevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from synapse.http.server import HttpServer
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag, trace
from synapse.logging.opentracing import set_tag, trace_with_opname
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict

Expand All @@ -43,7 +43,7 @@ def __init__(self, hs: "HomeServer"):
self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()

@trace(opname="sendToDevice")
@trace_with_opname("sendToDevice")
def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
Expand Down
12 changes: 6 additions & 6 deletions synapse/rest/client/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace
from synapse.logging.opentracing import trace_with_opname
from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder

Expand Down Expand Up @@ -210,7 +210,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
logger.debug("Event formatting complete")
return 200, response_content

@trace(opname="sync.encode_response")
@trace_with_opname("sync.encode_response")
async def encode_response(
self,
time_now: int,
Expand Down Expand Up @@ -315,7 +315,7 @@ def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict:
]
}

@trace(opname="sync.encode_joined")
@trace_with_opname("sync.encode_joined")
async def encode_joined(
self,
rooms: List[JoinedSyncResult],
Expand All @@ -340,7 +340,7 @@ async def encode_joined(

return joined

@trace(opname="sync.encode_invited")
@trace_with_opname("sync.encode_invited")
async def encode_invited(
self,
rooms: List[InvitedSyncResult],
Expand Down Expand Up @@ -371,7 +371,7 @@ async def encode_invited(

return invited

@trace(opname="sync.encode_knocked")
@trace_with_opname("sync.encode_knocked")
async def encode_knocked(
self,
rooms: List[KnockedSyncResult],
Expand Down Expand Up @@ -420,7 +420,7 @@ async def encode_knocked(

return knocked

@trace(opname="sync.encode_archived")
@trace_with_opname("sync.encode_archived")
async def encode_archived(
self,
rooms: List[ArchivedSyncResult],
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def get_device_stream_token(self) -> int:

@trace
async def get_user_devices_from_cache(
self, query_list: List[Tuple[str, str]]
self, query_list: List[Tuple[str, Optional[str]]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.

Expand Down
Loading