diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java index adf82005d06..44f149b4b4e 100644 --- a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java @@ -36,6 +36,7 @@ import java.util.concurrent.Executor; import java.util.concurrent.Flow; import java.util.concurrent.TimeUnit; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.lang.Nullable; @@ -94,38 +95,25 @@ public URI getURI() { @Override @SuppressWarnings("NullAway") protected ClientHttpResponse executeInternal(HttpHeaders headers, @Nullable Body body) throws IOException { - HttpRequest request = buildRequest(headers, body); - CompletableFuture> responsefuture = - this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream()); + CompletableFuture> responseFuture = null; try { + HttpRequest request = buildRequest(headers, body); + responseFuture = this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream()); + if (this.timeout != null) { - CompletableFuture timeoutFuture = new CompletableFuture() - .completeOnTimeout(null, this.timeout.toMillis(), TimeUnit.MILLISECONDS); - timeoutFuture.thenRun(() -> { - if (!responsefuture.cancel(true) && !responsefuture.isCompletedExceptionally()) { - try { - responsefuture.resultNow().body().close(); - } catch (IOException ignored) {} - } - }); - var response = responsefuture.get(); - return new JdkClientHttpResponse(response.statusCode(), response.headers(), new FilterInputStream(response.body()) { - - @Override - public void close() throws IOException { - timeoutFuture.cancel(false); - super.close(); - } - }); - - } else { - var response = responsefuture.get(); - return new JdkClientHttpResponse(response.statusCode(), response.headers(), response.body()); + TimeoutHandler timeoutHandler = new TimeoutHandler(responseFuture, this.timeout); + HttpResponse response = responseFuture.get(); + InputStream inputStream = timeoutHandler.wrapInputStream(response); + return new JdkClientHttpResponse(response, inputStream); + } + else { + HttpResponse response = responseFuture.get(); + return new JdkClientHttpResponse(response, response.body()); } } catch (InterruptedException ex) { Thread.currentThread().interrupt(); - responsefuture.cancel(true); + responseFuture.cancel(true); throw new IOException("Request was interrupted: " + ex.getMessage(), ex); } catch (ExecutionException ex) { @@ -149,7 +137,6 @@ else if (cause instanceof IOException ioEx) { } } - private HttpRequest buildRequest(HttpHeaders headers, @Nullable Body body) { HttpRequest.Builder builder = HttpRequest.newBuilder().uri(this.uri); @@ -225,4 +212,52 @@ public ByteBuffer map(byte[] b, int off, int len) { } } + + /** + * Temporary workaround to use instead of {@link HttpRequest.Builder#timeout(Duration)} + * until JDK-8258397 + * is fixed. Essentially, create a future wiht a timeout handler, and use it + * to close the response. + * @see OpenJDK discussion thread + */ + private static final class TimeoutHandler { + + private final CompletableFuture timeoutFuture; + + private TimeoutHandler(CompletableFuture> future, Duration timeout) { + + this.timeoutFuture = new CompletableFuture() + .completeOnTimeout(null, timeout.toMillis(), TimeUnit.MILLISECONDS); + + this.timeoutFuture.thenRun(() -> { + if (future.cancel(true) || future.isCompletedExceptionally() || !future.isDone()) { + return; + } + try { + future.get().body().close(); + } + catch (Exception ex) { + // ignore + } + }); + + } + + @Nullable + public InputStream wrapInputStream(HttpResponse response) { + InputStream body = response.body(); + if (body == null) { + return body; + } + return new FilterInputStream(body) { + + @Override + public void close() throws IOException { + TimeoutHandler.this.timeoutFuture.cancel(false); + super.close(); + } + }; + } + } + } diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java index a26b5531d9b..6fdeefa0d6c 100644 --- a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.io.InputStream; import java.net.http.HttpClient; +import java.net.http.HttpResponse; import java.util.List; import java.util.Locale; import java.util.Map; @@ -26,6 +27,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatusCode; +import org.springframework.lang.Nullable; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.util.MultiValueMap; @@ -40,21 +42,21 @@ */ class JdkClientHttpResponse implements ClientHttpResponse { - private final int statusCode; + private final HttpResponse response; private final HttpHeaders headers; private final InputStream body; - public JdkClientHttpResponse(int statusCode, java.net.http.HttpHeaders headers, InputStream body) { - this.statusCode = statusCode; - this.headers = adaptHeaders(headers); - this.body = body != null ? body : InputStream.nullInputStream(); + public JdkClientHttpResponse(HttpResponse response, @Nullable InputStream body) { + this.response = response; + this.headers = adaptHeaders(response); + this.body = (body != null ? body : InputStream.nullInputStream()); } - private static HttpHeaders adaptHeaders(java.net.http.HttpHeaders headers) { - Map> rawHeaders = headers.map(); + private static HttpHeaders adaptHeaders(HttpResponse response) { + Map> rawHeaders = response.headers().map(); Map> map = new LinkedCaseInsensitiveMap<>(rawHeaders.size(), Locale.ENGLISH); MultiValueMap multiValueMap = CollectionUtils.toMultiValueMap(map); multiValueMap.putAll(rawHeaders); @@ -64,7 +66,7 @@ private static HttpHeaders adaptHeaders(java.net.http.HttpHeaders headers) { @Override public HttpStatusCode getStatusCode() { - return HttpStatusCode.valueOf(statusCode); + return HttpStatusCode.valueOf(this.response.statusCode()); } @Override