Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle async requests in srping mvc library instrumentation #10868

Merged
merged 1 commit into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,16 @@
import io.opentelemetry.instrumentation.api.instrumenter.Instrumenter;
import io.opentelemetry.instrumentation.api.semconv.http.HttpServerRoute;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.servlet.AsyncContext;
import javax.servlet.AsyncEvent;
import javax.servlet.AsyncListener;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import org.springframework.core.Ordered;
import org.springframework.web.filter.OncePerRequestFilter;
Expand Down Expand Up @@ -53,17 +60,21 @@ public void doFilterInternal(
}

Context context = instrumenter.start(parentContext, request);
AsyncAwareHttpServletRequest asyncAwareRequest =
new AsyncAwareHttpServletRequest(request, response, context);
Throwable error = null;
try (Scope ignored = context.makeCurrent()) {
filterChain.doFilter(request, response);
filterChain.doFilter(asyncAwareRequest, response);
} catch (Throwable t) {
error = t;
throw t;
} finally {
if (httpRouteSupport.hasMappings()) {
HttpServerRoute.update(context, CONTROLLER, httpRouteSupport::getHttpRoute, request);
}
instrumenter.end(context, request, response, error);
if (error != null || asyncAwareRequest.isNotAsync()) {
instrumenter.end(context, request, response, error);
}
}
}

Expand All @@ -75,4 +86,88 @@ public int getOrder() {
// Run after all HIGHEST_PRECEDENCE items
return Ordered.HIGHEST_PRECEDENCE + 1;
}

private class AsyncAwareHttpServletRequest extends HttpServletRequestWrapper {
private final HttpServletRequest request;
private final HttpServletResponse response;
private final Context context;
private final AtomicBoolean listenerAttached = new AtomicBoolean();

AsyncAwareHttpServletRequest(
HttpServletRequest request, HttpServletResponse response, Context context) {
super(request);
this.request = request;
this.response = response;
this.context = context;
}

@Override
public AsyncContext startAsync() {
AsyncContext asyncContext = super.startAsync();
attachListener(asyncContext);
return asyncContext;
}

@Override
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) {
AsyncContext asyncContext = super.startAsync(servletRequest, servletResponse);
attachListener(asyncContext);
return asyncContext;
}

private void attachListener(AsyncContext asyncContext) {
if (!listenerAttached.compareAndSet(false, true)) {
return;
}

asyncContext.addListener(
new AsyncRequestCompletionListener(request, response, context), request, response);
}

boolean isNotAsync() {
return !listenerAttached.get();
}
}

private class AsyncRequestCompletionListener implements AsyncListener {
private final HttpServletRequest request;
private final HttpServletResponse response;
private final Context context;
private final AtomicBoolean responseHandled = new AtomicBoolean();

AsyncRequestCompletionListener(
HttpServletRequest request, HttpServletResponse response, Context context) {
this.request = request;
this.response = response;
this.context = context;
}

@Override
public void onComplete(AsyncEvent asyncEvent) {
if (responseHandled.compareAndSet(false, true)) {
instrumenter.end(context, request, response, null);
}
}

@Override
public void onTimeout(AsyncEvent asyncEvent) {
if (responseHandled.compareAndSet(false, true)) {
instrumenter.end(context, request, response, null);
}
}

@Override
public void onError(AsyncEvent asyncEvent) {
if (responseHandled.compareAndSet(false, true)) {
instrumenter.end(context, request, response, asyncEvent.getThrowable());
}
}

@Override
public void onStartAsync(AsyncEvent asyncEvent) {
asyncEvent
.getAsyncContext()
.addListener(this, asyncEvent.getSuppliedRequest(), asyncEvent.getSuppliedResponse());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@
import static java.util.Collections.singletonList;

import io.opentelemetry.api.GlobalOpenTelemetry;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.testing.junit.http.AbstractHttpServerTest;
import io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint;
import java.util.Properties;
import java.util.concurrent.CompletableFuture;
import javax.servlet.Filter;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
Expand All @@ -37,6 +41,7 @@

@SpringBootApplication
class TestWebSpringBootApp {
static final ServerEndpoint ASYNC_ENDPOINT = new ServerEndpoint("ASYNC", "async", 200, "success");

static ConfigurableApplicationContext start(int port, String contextPath) {
Properties props = new Properties();
Expand Down Expand Up @@ -122,6 +127,26 @@ String indexed_child(@RequestParam("id") String id) {
});
}

@RequestMapping("/async")
@ResponseBody
CompletableFuture<String> async() {
Context context = Context.current();
return CompletableFuture.supplyAsync(
() -> {
// Sleep a bit so that the future completes after the controller method. This helps to
// verify whether request ends after the future has completed not after when the
// controller method has completed.
try {
Thread.sleep(100);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
try (Scope ignored = context.makeCurrent()) {
return controller(ASYNC_ENDPOINT, ASYNC_ENDPOINT::getBody);
}
});
}

@ExceptionHandler
ResponseEntity<String> handleException(Throwable throwable) {
return new ResponseEntity<>(throwable.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@

package io.opentelemetry.instrumentation.spring.webmvc.v5_3;

import static org.assertj.core.api.Assertions.assertThat;

import io.opentelemetry.instrumentation.api.internal.HttpConstants;
import io.opentelemetry.instrumentation.testing.junit.InstrumentationExtension;
import io.opentelemetry.instrumentation.testing.junit.http.AbstractHttpServerTest;
import io.opentelemetry.instrumentation.testing.junit.http.HttpServerInstrumentationExtension;
import io.opentelemetry.instrumentation.testing.junit.http.HttpServerTestOptions;
import io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint;
import io.opentelemetry.testing.internal.armeria.common.AggregatedHttpRequest;
import io.opentelemetry.testing.internal.armeria.common.AggregatedHttpResponse;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.springframework.context.ConfigurableApplicationContext;

Expand Down Expand Up @@ -49,4 +54,18 @@ protected void configure(HttpServerTestOptions options) {
return expectedHttpRoute(endpoint, method);
});
}

@Test
void asyncRequest() {
ServerEndpoint endpoint = TestWebSpringBootApp.ASYNC_ENDPOINT;
String method = "GET";
AggregatedHttpRequest request = request(endpoint, method);
AggregatedHttpResponse response = client.execute(request).aggregate().join();

assertThat(response.status().code()).isEqualTo(endpoint.getStatus());
assertThat(response.contentUtf8()).isEqualTo(endpoint.getBody());

String spanId = assertResponseHasCustomizedHeaders(response, endpoint, null);
assertTheTraces(1, null, null, spanId, method, endpoint);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.api.instrumenter.Instrumenter;
import io.opentelemetry.instrumentation.api.semconv.http.HttpServerRoute;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncEvent;
import jakarta.servlet.AsyncListener;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;
import org.springframework.core.Ordered;
import org.springframework.web.filter.OncePerRequestFilter;

Expand Down Expand Up @@ -53,17 +60,21 @@ public void doFilterInternal(
}

Context context = instrumenter.start(parentContext, request);
AsyncAwareHttpServletRequest asyncAwareRequest =
new AsyncAwareHttpServletRequest(request, response, context);
Throwable error = null;
try (Scope ignored = context.makeCurrent()) {
filterChain.doFilter(request, response);
filterChain.doFilter(asyncAwareRequest, response);
} catch (Throwable t) {
error = t;
throw t;
} finally {
if (httpRouteSupport.hasMappings()) {
HttpServerRoute.update(context, CONTROLLER, httpRouteSupport::getHttpRoute, request);
}
instrumenter.end(context, request, response, error);
if (error != null || asyncAwareRequest.isNotAsync()) {
instrumenter.end(context, request, response, error);
}
}
}

Expand All @@ -75,4 +86,88 @@ public int getOrder() {
// Run after all HIGHEST_PRECEDENCE items
return Ordered.HIGHEST_PRECEDENCE + 1;
}

private class AsyncAwareHttpServletRequest extends HttpServletRequestWrapper {
private final HttpServletRequest request;
private final HttpServletResponse response;
private final Context context;
private final AtomicBoolean listenerAttached = new AtomicBoolean();

AsyncAwareHttpServletRequest(
HttpServletRequest request, HttpServletResponse response, Context context) {
super(request);
this.request = request;
this.response = response;
this.context = context;
}

@Override
public AsyncContext startAsync() {
AsyncContext asyncContext = super.startAsync();
attachListener(asyncContext);
return asyncContext;
}

@Override
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) {
AsyncContext asyncContext = super.startAsync(servletRequest, servletResponse);
attachListener(asyncContext);
return asyncContext;
}

private void attachListener(AsyncContext asyncContext) {
if (!listenerAttached.compareAndSet(false, true)) {
return;
}

asyncContext.addListener(
new AsyncRequestCompletionListener(request, response, context), request, response);
}

boolean isNotAsync() {
return !listenerAttached.get();
}
}

private class AsyncRequestCompletionListener implements AsyncListener {
private final HttpServletRequest request;
private final HttpServletResponse response;
private final Context context;
private final AtomicBoolean responseHandled = new AtomicBoolean();

AsyncRequestCompletionListener(
HttpServletRequest request, HttpServletResponse response, Context context) {
this.request = request;
this.response = response;
this.context = context;
}

@Override
public void onComplete(AsyncEvent asyncEvent) {
if (responseHandled.compareAndSet(false, true)) {
instrumenter.end(context, request, response, null);
}
}

@Override
public void onTimeout(AsyncEvent asyncEvent) {
if (responseHandled.compareAndSet(false, true)) {
instrumenter.end(context, request, response, null);
}
}

@Override
public void onError(AsyncEvent asyncEvent) {
if (responseHandled.compareAndSet(false, true)) {
instrumenter.end(context, request, response, asyncEvent.getThrowable());
}
}

@Override
public void onStartAsync(AsyncEvent asyncEvent) {
asyncEvent
.getAsyncContext()
.addListener(this, asyncEvent.getSuppliedRequest(), asyncEvent.getSuppliedResponse());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@
import static java.util.Collections.singletonList;

import io.opentelemetry.api.GlobalOpenTelemetry;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.testing.junit.http.AbstractHttpServerTest;
import io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint;
import jakarta.servlet.Filter;
import java.util.Properties;
import java.util.concurrent.CompletableFuture;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.ConfigurableApplicationContext;
Expand All @@ -37,6 +41,7 @@

@SpringBootApplication
class TestWebSpringBootApp {
static final ServerEndpoint ASYNC_ENDPOINT = new ServerEndpoint("ASYNC", "async", 200, "success");

static ConfigurableApplicationContext start(int port, String contextPath) {
Properties props = new Properties();
Expand Down Expand Up @@ -122,6 +127,26 @@ String indexed_child(@RequestParam("id") String id) {
});
}

@RequestMapping("/async")
@ResponseBody
CompletableFuture<String> async() {
Context context = Context.current();
return CompletableFuture.supplyAsync(
() -> {
// Sleep a bit so that the future completes after the controller method. This helps to
// verify whether request ends after the future has completed not after when the
// controller method has completed.
try {
Thread.sleep(100);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
try (Scope ignored = context.makeCurrent()) {
return controller(ASYNC_ENDPOINT, ASYNC_ENDPOINT::getBody);
}
});
}

@ExceptionHandler
ResponseEntity<String> handleException(Throwable throwable) {
return new ResponseEntity<>(throwable.getMessage(), HttpStatus.INTERNAL_SERVER_ERROR);
Expand Down
Loading
Loading