Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Trace functions which return Awaitable #15650

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15650.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for tracing functions which return `Awaitable`s.
37 changes: 26 additions & 11 deletions synapse/logging/opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def set_fates(clotho, lachesis, atropos, father="Zues", mother="Themis"):
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Collection,
ContextManager,
Expand Down Expand Up @@ -903,6 +904,7 @@ def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) ->
"""

if inspect.iscoroutinefunction(func):
# For this branch, we handle async functions like `async def func() -> RInner`.
# In this branch, R = Awaitable[RInner], for some other type RInner
@wraps(func)
async def _wrapper(
Expand All @@ -914,36 +916,49 @@ async def _wrapper(
return await func(*args, **kwargs) # type: ignore[misc]

else:
# The other case here handles both sync functions and those
# decorated with inlineDeferred.
# The other case here handles sync functions including those decorated with
# `@defer.inlineCallbacks` or that return a `Deferred` or other `Awaitable`.
@wraps(func)
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
scope = wrapping_logic(func, *args, **kwargs)
scope.__enter__()

try:
result = func(*args, **kwargs)

if isinstance(result, defer.Deferred):

def call_back(result: R) -> R:
scope.__exit__(None, None, None)
return result

def err_back(result: R) -> R:
# TODO: Pass the error details into `scope.__exit__(...)` for
# consistency with the other paths.
scope.__exit__(None, None, None)
return result

result.addCallbacks(call_back, err_back)

elif inspect.isawaitable(result):

async def wrap_awaitable() -> Any:
try:
assert isinstance(result, Awaitable)
awaited_result = await result
scope.__exit__(None, None, None)
return awaited_result
except Exception as e:
scope.__exit__(type(e), None, e.__traceback__)
raise

# The original method returned an awaitable, eg. a coroutine, so we
# create another awaitable wrapping it that calls
# `scope.__exit__(...)`.
return wrap_awaitable()
else:
if inspect.isawaitable(result):
logger.error(
"@trace may not have wrapped %s correctly! "
"The function is not async but returned a %s.",
func.__qualname__,
type(result).__name__,
)

# Just a simple sync function so we can just exit the scope and
# return the result without any fuss.
scope.__exit__(None, None, None)

return result
Expand Down
43 changes: 32 additions & 11 deletions tests/logging/test_opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast
from typing import Awaitable, cast

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactorClock
Expand Down Expand Up @@ -227,8 +227,6 @@ def test_trace_decorator_deferred(self) -> None:
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with functions that return deferreds
"""
reactor = MemoryReactorClock()

with LoggingContext("root context"):

@trace_with_opname("fixture_deferred_func", tracer=self._tracer)
Expand All @@ -240,9 +238,6 @@ def fixture_deferred_func() -> "defer.Deferred[str]":

result_d1 = fixture_deferred_func()

# let the tasks complete
reactor.pump((2,) * 8)

self.assertEqual(self.successResultOf(result_d1), "foo")

# the span should have been reported
Expand All @@ -256,8 +251,6 @@ def test_trace_decorator_async(self) -> None:
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with async functions
"""
reactor = MemoryReactorClock()

with LoggingContext("root context"):

@trace_with_opname("fixture_async_func", tracer=self._tracer)
Expand All @@ -267,13 +260,41 @@ async def fixture_async_func() -> str:

d1 = defer.ensureDeferred(fixture_async_func())

# let the tasks complete
reactor.pump((2,) * 8)

self.assertEqual(self.successResultOf(d1), "foo")

# the span should have been reported
self.assertEqual(
[span.operation_name for span in self._reporter.get_spans()],
["fixture_async_func"],
)

def test_trace_decorator_awaitable_return(self) -> None:
"""
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with functions that return an awaitable (e.g. a coroutine)
"""
with LoggingContext("root context"):
# Something we can return without `await` to get a coroutine
async def fixture_async_func() -> str:
return "foo"

# The actual kind of function we want to test that returns an awaitable
@trace_with_opname("fixture_awaitable_return_func", tracer=self._tracer)
@tag_args
def fixture_awaitable_return_func() -> Awaitable[str]:
return fixture_async_func()

# Something we can run with `defer.ensureDeferred(runner())` and pump the
# whole async tasks through to completion.
async def runner() -> str:
return await fixture_awaitable_return_func()

d1 = defer.ensureDeferred(runner())
MadLittleMods marked this conversation as resolved.
Show resolved Hide resolved

self.assertEqual(self.successResultOf(d1), "foo")

# the span should have been reported
self.assertEqual(
[span.operation_name for span in self._reporter.get_spans()],
["fixture_awaitable_return_func"],
)