Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing types to opentracing #13345

Merged
merged 6 commits into from
Jul 21, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion changelog.d/13328.misc
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Add type hints to `trace` decorator.
Add missing type hints to open tracing module.
1 change: 1 addition & 0 deletions changelog.d/13345.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to open tracing module.
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ disallow_untyped_defs = False
[mypy-synapse.http.matrixfederationclient]
disallow_untyped_defs = False

[mypy-synapse.logging.opentracing]
disallow_untyped_defs = False

[mypy-synapse.metrics._reactor_metrics]
disallow_untyped_defs = False
# This module imports select.epoll. That exists on Linux, but doesn't on macOS.
Expand Down
2 changes: 1 addition & 1 deletion synapse/federation/transport/server/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ async def new_func(
raise

# update the active opentracing span with the authenticated entity
set_tag("authenticated_entity", origin)
set_tag("authenticated_entity", str(origin))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jaeger automatically casts things to string, but according to the opentracing API these must be simple types.


# if the origin is authenticated and whitelisted, use its span context
# as the parent.
Expand Down
8 changes: 4 additions & 4 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict:
ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)

set_tag("device", device)
set_tag("ips", ips)
set_tag("device", str(device))
clokep marked this conversation as resolved.
Show resolved Hide resolved
set_tag("ips", str(ips))

return device

Expand Down Expand Up @@ -170,7 +170,7 @@ async def get_user_ids_changed(
"""

set_tag("user_id", user_id)
set_tag("from_token", from_token)
set_tag("from_token", str(from_token))
now_room_key = self.store.get_room_max_token()

room_ids = await self.store.get_rooms_for_user(user_id)
Expand Down Expand Up @@ -795,7 +795,7 @@ async def incoming_device_list_update(
"""

set_tag("origin", origin)
set_tag("edu_content", edu_content)
set_tag("edu_content", str(edu_content))
user_id = edu_content.pop("user_id")
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints
Expand Down
16 changes: 8 additions & 8 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ async def query_devices(
else:
remote_queries[user_id] = device_ids

set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))

# First get local devices.
# A map of destination -> failure response.
Expand Down Expand Up @@ -343,7 +343,7 @@ async def _query_devices_for_destination(
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
set_tag("reason", failure)
set_tag("reason", str(failure))

return

Expand Down Expand Up @@ -405,7 +405,7 @@ async def query_local_devices(
Returns:
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
set_tag("local_query", str(query))
local_query: List[Tuple[str, Optional[str]]] = []

result_dict: Dict[str, Dict[str, dict]] = {}
Expand Down Expand Up @@ -477,8 +477,8 @@ async def claim_one_time_keys(
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = one_time_keys

set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))

results = await self.store.claim_e2e_one_time_keys(local_query)

Expand Down Expand Up @@ -508,7 +508,7 @@ async def claim_client_keys(destination: str) -> None:
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
set_tag("reason", failure)
set_tag("reason", str(failure))

await make_deferred_yieldable(
defer.gatherResults(
Expand Down Expand Up @@ -611,7 +611,7 @@ async def upload_keys_for_user(

result = await self.store.count_e2e_one_time_keys(user_id, device_id)

set_tag("one_time_key_counts", result)
set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}

async def _upload_one_time_keys_for_user(
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/e2e_room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional, cast

from typing_extensions import Literal

Expand Down Expand Up @@ -97,7 +97,7 @@ async def get_room_keys(
user_id, version, room_id, session_id
)

log_kv(results)
log_kv(cast(JsonDict, results))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bit of a shame we need the noise here. We could maybe change our log_kv wrapper to accept a Mapping[str, Any] aka JsonMapping instead of JsonDict... but it's probably not worth it!

(Fine as it is, I'm just lamenting in passing)

return results

@trace
Expand Down
44 changes: 35 additions & 9 deletions synapse/logging/opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def set_fates(clotho, lachesis, atropos, father="Zues", mother="Themis"):
Type,
TypeVar,
Union,
cast,
overload,
)

import attr
Expand Down Expand Up @@ -328,6 +330,7 @@ class _Sentinel(enum.Enum):

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")


def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
Expand All @@ -343,22 +346,43 @@ def _only_if_tracing_inner(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
return _only_if_tracing_inner


def ensure_active_span(message: str, ret=None):
@overload
def ensure_active_span(
message: str,
) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]:
...


@overload
def ensure_active_span(
message: str, ret: T
) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]:
...


def ensure_active_span(
message: str, ret: Optional[T] = None
) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]:
"""Executes the operation only if opentracing is enabled and there is an active span.
If there is no active span it logs message at the error level.

Args:
message: Message which fills in "There was no active span when trying to %s"
in the error log if there is no active span and opentracing is enabled.
ret (object): return value if opentracing is None or there is no active span.
ret: return value if opentracing is None or there is no active span.

Returns (object): The result of the func or ret if opentracing is disabled or there
Returns:
The result of the func, falling back to ret if opentracing is disabled or there
was no active span.
"""

def ensure_active_span_inner_1(func):
def ensure_active_span_inner_1(
func: Callable[P, R]
) -> Callable[P, Union[Optional[T], R]]:
@wraps(func)
def ensure_active_span_inner_2(*args, **kwargs):
def ensure_active_span_inner_2(
*args: P.args, **kwargs: P.kwargs
) -> Union[Optional[T], R]:
if not opentracing:
return ret

Expand Down Expand Up @@ -464,7 +488,7 @@ def start_active_span(
finish_on_close: bool = True,
*,
tracer: Optional["opentracing.Tracer"] = None,
):
) -> "opentracing.Scope":
"""Starts an active opentracing span.

Records the start time for the span, and sets it as the "active span" in the
Expand Down Expand Up @@ -502,7 +526,7 @@ def start_active_span_follows_from(
*,
inherit_force_tracing: bool = False,
tracer: Optional["opentracing.Tracer"] = None,
):
) -> "opentracing.Scope":
"""Starts an active opentracing span, with additional references to previous spans

Args:
Expand Down Expand Up @@ -717,7 +741,9 @@ def inject_response_headers(response_headers: Headers) -> None:
response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")


@ensure_active_span("get the active span context as a dict", ret={})
@ensure_active_span(
"get the active span context as a dict", ret=cast(Dict[str, str], {})
)
def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
"""
Gets a span context as a dict. This can be used instead of manually
Expand Down Expand Up @@ -886,7 +912,7 @@ def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
for i, arg in enumerate(argspec.args[1:]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't your change, but I don't really understand the offset of 1 here. I'm guessing that we want to ignore the first arg to a function call because that arg is a self or cls. But if that's the case, shouldn't we pass start=1 to enumerate or else use args[i+1])?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know, I was a bit confused by it too but decided not to touch it.

set_tag("ARG_" + arg, args[i]) # type: ignore[index]
set_tag("args", args[len(argspec.args) :]) # type: ignore[index]
set_tag("kwargs", kwargs)
set_tag("kwargs", str(kwargs))
return func(*args, **kwargs)

return _tag_args_inner
Expand Down
2 changes: 1 addition & 1 deletion synapse/metrics/background_process_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ async def run() -> Optional[R]:
f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)}
)
else:
ctx = nullcontext()
ctx = nullcontext() # type: ignore[assignment]
with ctx:
return await func(*args, **kwargs)
except Exception:
Expand Down
4 changes: 3 additions & 1 deletion synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

# We want to enforce they do pass us one, but we ignore it and return
# changes after the "to" as well as before.
set_tag("to", parse_string(request, "to"))
#
# XXX This does not enforce that "to" is passed.
set_tag("to", str(parse_string(request, "to")))

from_token = await StreamToken.from_string(self.store, from_token_string)

Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ async def delete_messages_for_device(
(user_id, device_id), None
)

set_tag("last_deleted_stream_id", last_deleted_stream_id)
set_tag("last_deleted_stream_id", str(last_deleted_stream_id))

if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,8 +706,8 @@ async def get_user_devices_from_cache(
else:
results[user_id] = await self.get_cached_devices_for_user(user_id)

set_tag("in_cache", results)
set_tag("not_in_cache", user_ids_not_in_cache)
set_tag("in_cache", str(results))
set_tag("not_in_cache", str(user_ids_not_in_cache))

return user_ids_not_in_cache, results

Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ async def get_e2e_device_keys_for_cs_api(
key data. The key data will be a dict in the same format as the
DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
"""
set_tag("query_list", query_list)
set_tag("query_list", str(query_list))
if not query_list:
return {}

Expand Down Expand Up @@ -418,7 +418,7 @@ async def add_e2e_one_time_keys(
def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("new_keys", new_keys)
set_tag("new_keys", str(new_keys))
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
Expand Down Expand Up @@ -1161,7 +1161,7 @@ def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
set_tag("device_keys", device_keys)
set_tag("device_keys", str(device_keys))

old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
Expand Down
29 changes: 20 additions & 9 deletions tests/logging/test_opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactorClock

Expand Down Expand Up @@ -40,6 +42,14 @@


class LogContextScopeManagerTestCase(TestCase):
"""
Test logging contexts and active opentracing spans.

There's casts throughout this from generic opentracing objects (e.g.
opentracing.Span) to the ones specific to Jaeger since they have additional
properties that these tests depend on. This is safe since the only supported
opentracing backend is Jaeger.
"""
if LogContextScopeManager is None:
skip = "Requires opentracing" # type: ignore[unreachable]
if jaeger_client is None:
Expand Down Expand Up @@ -69,7 +79,7 @@ def test_start_active_span(self) -> None:

# start_active_span should start and activate a span.
scope = start_active_span("span", tracer=self._tracer)
span = scope.span
span = cast(jaeger_client.Span, scope.span)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this file we reference a bunch of properties which are specific to jaeger's implementation of Span so we cast from the generic opentracing.Span to jaeger_client.Span.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before I saw your comment, I drafted:

Here and below: are the casts justified because we know we're in a logging context? I.e. so we don't have to worry about the sentinel span?

Your explanation shows my guess wasn't correct. Might it be worth a quick comment?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hopefully edab4ec makes sense?

self.assertEqual(self._tracer.active_span, span)
self.assertIsNotNone(span.start_time)

Expand All @@ -91,6 +101,7 @@ def test_nested_spans(self) -> None:
with LoggingContext("root context"):
with start_active_span("root span", tracer=self._tracer) as root_scope:
self.assertEqual(self._tracer.active_span, root_scope.span)
root_context = cast(jaeger_client.SpanContext, root_scope.span.context)

scope1 = start_active_span(
"child1",
Expand All @@ -99,27 +110,27 @@ def test_nested_spans(self) -> None:
self.assertEqual(
self._tracer.active_span, scope1.span, "child1 was not activated"
)
self.assertEqual(
scope1.span.context.parent_id, root_scope.span.context.span_id
)
context1 = cast(jaeger_client.SpanContext, scope1.span.context)
self.assertEqual(context1.parent_id, root_context.span_id)

scope2 = start_active_span_follows_from(
"child2",
contexts=(scope1,),
tracer=self._tracer,
)
self.assertEqual(self._tracer.active_span, scope2.span)
self.assertEqual(
scope2.span.context.parent_id, scope1.span.context.span_id
)
context2 = cast(jaeger_client.SpanContext, scope2.span.context)
self.assertEqual(context2.parent_id, context1.span_id)

with scope1, scope2:
pass

# the root scope should be restored
self.assertEqual(self._tracer.active_span, root_scope.span)
self.assertIsNotNone(scope2.span.end_time)
self.assertIsNotNone(scope1.span.end_time)
span2 = cast(jaeger_client.Span, scope2.span)
span1 = cast(jaeger_client.Span, scope1.span)
self.assertIsNotNone(span2.end_time)
self.assertIsNotNone(span1.end_time)

self.assertIsNone(self._tracer.active_span)

Expand Down