From 636220b642cb6649e9616570d6f972f7cfd2a835 Mon Sep 17 00:00:00 2001 From: Dmytro Nosan Date: Wed, 29 May 2019 18:50:52 +0300 Subject: [PATCH] Update `RestTemplateBuilder` and `TestRestTemplate` to use a custom request factory to add authentication headers. Prior to this commit, the `RestTemplateBuilder` and `TestRestTemplate` used the `BasicAuthenticationInterceptor` interceptor to add headers. Unfortunately, adding any interceptor causes the entire message body to be read into a byte array. This causes an `OutOfMemoryError` whenever a large file is uploaded. Closes gh-15078 --- .../test/web/client/TestRestTemplate.java | 49 ++++++----- .../web/client/TestRestTemplateTests.java | 39 ++++----- .../boot/web/client/BasicAuthentication.java | 84 +++++++++++++++++++ ...uthenticationClientHttpRequestFactory.java | 67 +++++++++++++++ .../boot/web/client/RestTemplateBuilder.java | 56 +++++++++---- ...ticationClientHttpRequestFactoryTests.java | 81 ++++++++++++++++++ .../web/client/RestTemplateBuilderTests.java | 25 +++--- 7 files changed, 329 insertions(+), 72 deletions(-) create mode 100644 spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java create mode 100644 spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java create mode 100644 spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java diff --git a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java index 4adea0980d2c..69a893fde2d5 100644 --- a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java +++ b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java @@ -19,11 +19,8 @@ import java.io.IOException; import java.lang.reflect.Field; import java.net.URI; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Supplier; @@ -41,6 +38,8 @@ import org.springframework.beans.BeanInstantiationException; import org.springframework.beans.BeanUtils; +import org.springframework.boot.web.client.BasicAuthentication; +import org.springframework.boot.web.client.BasicAuthenticationClientHttpRequestFactory; import org.springframework.boot.web.client.ClientHttpRequestFactorySupplier; import org.springframework.boot.web.client.RestTemplateBuilder; import org.springframework.boot.web.client.RootUriTemplateHandler; @@ -50,12 +49,11 @@ import org.springframework.http.HttpMethod; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; +import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.ClientHttpRequestFactory; -import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.http.client.InterceptingClientHttpRequestFactory; -import org.springframework.http.client.support.BasicAuthenticationInterceptor; import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; import org.springframework.web.client.DefaultResponseErrorHandler; @@ -86,6 +84,7 @@ * @author Phillip Webb * @author Andy Wilkinson * @author Kristine Jetzke + * @author Dmytro Nosan * @since 1.4.0 */ public class TestRestTemplate { @@ -154,31 +153,37 @@ private TestRestTemplate(RestTemplate restTemplate, String username, String pass private Class getRequestFactoryClass( RestTemplate restTemplate) { + return getRequestFactory(restTemplate).getClass(); + } + + private ClientHttpRequestFactory getRequestFactory(RestTemplate restTemplate) { ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); - if (InterceptingClientHttpRequestFactory.class - .isAssignableFrom(requestFactory.getClass())) { - Field requestFactoryField = ReflectionUtils.findField(RestTemplate.class, - "requestFactory"); - ReflectionUtils.makeAccessible(requestFactoryField); - requestFactory = (ClientHttpRequestFactory) ReflectionUtils - .getField(requestFactoryField, restTemplate); + while (requestFactory instanceof InterceptingClientHttpRequestFactory + || requestFactory instanceof BasicAuthenticationClientHttpRequestFactory) { + requestFactory = unwrapRequestFactoryIfNecessary(requestFactory); + } + return requestFactory; + } + + private ClientHttpRequestFactory unwrapRequestFactoryIfNecessary( + ClientHttpRequestFactory requestFactory) { + if (!(requestFactory instanceof AbstractClientHttpRequestFactoryWrapper)) { + return requestFactory; } - return requestFactory.getClass(); + Field field = ReflectionUtils.findField( + AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); + ReflectionUtils.makeAccessible(field); + return (ClientHttpRequestFactory) ReflectionUtils.getField(field, requestFactory); } private void addAuthentication(RestTemplate restTemplate, String username, String password) { - if (username == null) { + if (username == null || password == null) { return; } - List interceptors = restTemplate.getInterceptors(); - if (interceptors == null) { - interceptors = Collections.emptyList(); - } - interceptors = new ArrayList<>(interceptors); - interceptors.removeIf(BasicAuthenticationInterceptor.class::isInstance); - interceptors.add(new BasicAuthenticationInterceptor(username, password)); - restTemplate.setInterceptors(interceptors); + ClientHttpRequestFactory requestFactory = getRequestFactory(restTemplate); + restTemplate.setRequestFactory(new BasicAuthenticationClientHttpRequestFactory( + new BasicAuthentication(username, password), requestFactory)); } /** diff --git a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java index efc1efeece13..ec6d88f62d8f 100644 --- a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java +++ b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java @@ -20,13 +20,13 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.net.URI; -import java.util.List; import org.apache.http.client.config.RequestConfig; import org.junit.jupiter.api.Test; import org.springframework.boot.test.web.client.TestRestTemplate.CustomHttpComponentsClientHttpRequestFactory; import org.springframework.boot.test.web.client.TestRestTemplate.HttpClientOption; +import org.springframework.boot.web.client.BasicAuthenticationClientHttpRequestFactory; import org.springframework.boot.web.client.RestTemplateBuilder; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpEntity; @@ -35,12 +35,9 @@ import org.springframework.http.RequestEntity; import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequestFactory; -import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; -import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory; -import org.springframework.http.client.support.BasicAuthenticationInterceptor; import org.springframework.mock.env.MockEnvironment; import org.springframework.mock.http.client.MockClientHttpRequest; import org.springframework.mock.http.client.MockClientHttpResponse; @@ -150,7 +147,7 @@ public void getRootUriRootUriNotSet() { public void authenticated() { assertThat(new TestRestTemplate("user", "password").getRestTemplate() .getRequestFactory()) - .isInstanceOf(InterceptingClientHttpRequestFactory.class); + .isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class); } @Test @@ -235,16 +232,15 @@ public void withBasicAuthAddsBasicAuthInterceptorWhenNotAlreadyPresent() { .containsExactlyElementsOf( originalTemplate.getRestTemplate().getMessageConverters()); assertThat(basicAuthTemplate.getRestTemplate().getRequestFactory()) - .isInstanceOf(InterceptingClientHttpRequestFactory.class); + .isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class); assertThat(ReflectionTestUtils.getField( basicAuthTemplate.getRestTemplate().getRequestFactory(), "requestFactory")) .isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class); assertThat(basicAuthTemplate.getRestTemplate().getUriTemplateHandler()) .isSameAs(originalTemplate.getRestTemplate().getUriTemplateHandler()); - assertThat(basicAuthTemplate.getRestTemplate().getInterceptors()).hasSize(1); - assertBasicAuthorizationInterceptorCredentials(basicAuthTemplate, "user", - "password"); + assertThat(basicAuthTemplate.getRestTemplate().getInterceptors()).isEmpty(); + assertBasicAuthorizationCredentials(basicAuthTemplate, "user", "password"); } @Test @@ -256,14 +252,14 @@ public void withBasicAuthReplacesBasicAuthInterceptorWhenAlreadyPresent() { .containsExactlyElementsOf( original.getRestTemplate().getMessageConverters()); assertThat(basicAuth.getRestTemplate().getRequestFactory()) - .isInstanceOf(InterceptingClientHttpRequestFactory.class); + .isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class); assertThat(ReflectionTestUtils.getField( basicAuth.getRestTemplate().getRequestFactory(), "requestFactory")) .isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class); assertThat(basicAuth.getRestTemplate().getUriTemplateHandler()) .isSameAs(original.getRestTemplate().getUriTemplateHandler()); - assertThat(basicAuth.getRestTemplate().getInterceptors()).hasSize(1); - assertBasicAuthorizationInterceptorCredentials(basicAuth, "user", "password"); + assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty(); + assertBasicAuthorizationCredentials(basicAuth, "user", "password"); } @Test @@ -394,17 +390,14 @@ private void verifyRelativeUriHandling(TestRestTemplateCallback callback) verify(requestFactory).createRequest(eq(absoluteUri), any(HttpMethod.class)); } - private void assertBasicAuthorizationInterceptorCredentials( - TestRestTemplate testRestTemplate, String username, String password) { - @SuppressWarnings("unchecked") - List requestFactoryInterceptors = (List) ReflectionTestUtils - .getField(testRestTemplate.getRestTemplate().getRequestFactory(), - "interceptors"); - assertThat(requestFactoryInterceptors).hasSize(1); - ClientHttpRequestInterceptor interceptor = requestFactoryInterceptors.get(0); - assertThat(interceptor).isInstanceOf(BasicAuthenticationInterceptor.class); - assertThat(interceptor).hasFieldOrPropertyWithValue("username", username); - assertThat(interceptor).hasFieldOrPropertyWithValue("password", password); + private void assertBasicAuthorizationCredentials(TestRestTemplate testRestTemplate, + String username, String password) { + ClientHttpRequestFactory requestFactory = testRestTemplate.getRestTemplate() + .getRequestFactory(); + Object authentication = ReflectionTestUtils.getField(requestFactory, + "authentication"); + assertThat(authentication).hasFieldOrPropertyWithValue("username", username); + assertThat(authentication).hasFieldOrPropertyWithValue("password", password); } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java new file mode 100644 index 000000000000..40c169cbc54c --- /dev/null +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthentication.java @@ -0,0 +1,84 @@ +/* + * Copyright 2012-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.nio.charset.Charset; + +import org.springframework.util.Assert; + +/** + * Basic authentication properties. + * + * @author Dmytro Nosan + * @since 2.2.0 + */ +public class BasicAuthentication { + + private final String username; + + private final String password; + + private final Charset charset; + + /** + * Create a new {@link BasicAuthentication}. + * @param username the username to use + * @param password the password to use + */ + public BasicAuthentication(String username, String password) { + this(username, password, null); + } + + /** + * Create a new {@link BasicAuthentication}. + * @param username the username to use + * @param password the password to use + * @param charset the charset to use + */ + public BasicAuthentication(String username, String password, Charset charset) { + Assert.notNull(username, "Username must not be null"); + Assert.notNull(password, "Password must not be null"); + this.username = username; + this.password = password; + this.charset = charset; + } + + /** + * The username to use. + * @return the username, never {@code null}. + */ + public String getUsername() { + return this.username; + } + + /** + * The password to use. + * @return the password, never {@code null}. + */ + public String getPassword() { + return this.password; + } + + /** + * The charset to use. + * @return the charset, or {@code null}. + */ + public Charset getCharset() { + return this.charset; + } + +} diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java new file mode 100644 index 000000000000..dffe789e01a0 --- /dev/null +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactory.java @@ -0,0 +1,67 @@ +/* + * Copyright 2012-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.io.IOException; +import java.net.URI; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.util.Assert; + +/** + * {@link ClientHttpRequestFactory} to apply a given HTTP Basic Authentication + * username/password pair, unless a custom Authorization header has been set before. + * + * @author Dmytro Nosan + * @since 2.2.0 + */ +public class BasicAuthenticationClientHttpRequestFactory + extends AbstractClientHttpRequestFactoryWrapper { + + private final BasicAuthentication authentication; + + /** + * Create a new {@link BasicAuthenticationClientHttpRequestFactory} which adds + * {@link HttpHeaders#AUTHORIZATION} header for the given authentication. + * @param authentication the authentication to use + * @param clientHttpRequestFactory the factory to use + */ + public BasicAuthenticationClientHttpRequestFactory(BasicAuthentication authentication, + ClientHttpRequestFactory clientHttpRequestFactory) { + super(clientHttpRequestFactory); + Assert.notNull(authentication, "Authentication must not be null"); + this.authentication = authentication; + } + + @Override + protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod, + ClientHttpRequestFactory requestFactory) throws IOException { + BasicAuthentication authentication = this.authentication; + ClientHttpRequest request = requestFactory.createRequest(uri, httpMethod); + HttpHeaders headers = request.getHeaders(); + if (!headers.containsKey(HttpHeaders.AUTHORIZATION)) { + headers.setBasicAuth(authentication.getUsername(), + authentication.getPassword(), authentication.getCharset()); + } + return request; + } + +} diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java index 4a5a09820e21..ffc10e69e069 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java @@ -33,6 +33,7 @@ import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestInterceptor; +import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.http.client.support.BasicAuthenticationInterceptor; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.util.Assert; @@ -58,6 +59,7 @@ * @author Phillip Webb * @author Andy Wilkinson * @author Brian Clozel + * @author Dmytro Nosan * @since 1.4.0 */ public class RestTemplateBuilder { @@ -74,7 +76,7 @@ public class RestTemplateBuilder { private final ResponseErrorHandler errorHandler; - private final BasicAuthenticationInterceptor basicAuthentication; + private final BasicAuthentication basicAuthentication; private final Set restTemplateCustomizers; @@ -106,7 +108,7 @@ private RestTemplateBuilder(boolean detectRequestFactory, String rootUri, Set> messageConverters, Supplier requestFactorySupplier, UriTemplateHandler uriTemplateHandler, ResponseErrorHandler errorHandler, - BasicAuthenticationInterceptor basicAuthentication, + BasicAuthentication basicAuthentication, Set restTemplateCustomizers, RequestFactoryCustomizer requestFactoryCustomizer, Set interceptors) { @@ -379,10 +381,21 @@ public RestTemplateBuilder errorHandler(ResponseErrorHandler errorHandler) { * @since 2.1.0 */ public RestTemplateBuilder basicAuthentication(String username, String password) { + return basicAuthentication(new BasicAuthentication(username, password)); + } + + /** + * Add HTTP basic authentication to requests. See + * {@link BasicAuthenticationInterceptor} for details. + * @param basicAuthentication the authentication + * @return a new builder instance + * @since 2.2.0 + */ + public RestTemplateBuilder basicAuthentication( + BasicAuthentication basicAuthentication) { return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, this.messageConverters, this.requestFactorySupplier, - this.uriTemplateHandler, this.errorHandler, - new BasicAuthenticationInterceptor(username, password), + this.uriTemplateHandler, this.errorHandler, basicAuthentication, this.restTemplateCustomizers, this.requestFactoryCustomizer, this.interceptors); } @@ -534,7 +547,7 @@ public T configure(T restTemplate) { RootUriTemplateHandler.addTo(restTemplate, this.rootUri); } if (this.basicAuthentication != null) { - restTemplate.getInterceptors().add(this.basicAuthentication); + configureBasicAuthentication(restTemplate); } restTemplate.getInterceptors().addAll(this.interceptors); if (!CollectionUtils.isEmpty(this.restTemplateCustomizers)) { @@ -561,6 +574,27 @@ else if (this.detectRequestFactory) { } } + private void configureBasicAuthentication(RestTemplate restTemplate) { + ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); + while (requestFactory instanceof InterceptingClientHttpRequestFactory + || requestFactory instanceof BasicAuthenticationClientHttpRequestFactory) { + requestFactory = unwrapRequestFactory(requestFactory); + } + restTemplate.setRequestFactory(new BasicAuthenticationClientHttpRequestFactory( + this.basicAuthentication, requestFactory)); + } + + private static ClientHttpRequestFactory unwrapRequestFactory( + ClientHttpRequestFactory requestFactory) { + if (!(requestFactory instanceof AbstractClientHttpRequestFactoryWrapper)) { + return requestFactory; + } + Field field = ReflectionUtils.findField( + AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); + ReflectionUtils.makeAccessible(field); + return (ClientHttpRequestFactory) ReflectionUtils.getField(field, requestFactory); + } + private Set append(Set set, Collection additions) { Set result = new LinkedHashSet<>((set != null) ? set : Collections.emptySet()); result.addAll(additions); @@ -607,18 +641,10 @@ public void accept(ClientHttpRequestFactory requestFactory) { private ClientHttpRequestFactory unwrapRequestFactoryIfNecessary( ClientHttpRequestFactory requestFactory) { - if (!(requestFactory instanceof AbstractClientHttpRequestFactoryWrapper)) { - return requestFactory; - } ClientHttpRequestFactory unwrappedRequestFactory = requestFactory; - Field field = ReflectionUtils.findField( - AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); - ReflectionUtils.makeAccessible(field); - do { - unwrappedRequestFactory = (ClientHttpRequestFactory) ReflectionUtils - .getField(field, unwrappedRequestFactory); + while (unwrappedRequestFactory instanceof AbstractClientHttpRequestFactoryWrapper) { + unwrappedRequestFactory = unwrapRequestFactory(unwrappedRequestFactory); } - while (unwrappedRequestFactory instanceof AbstractClientHttpRequestFactoryWrapper); return unwrappedRequestFactory; } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java new file mode 100644 index 000000000000..ba7f10cf16ef --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/BasicAuthenticationClientHttpRequestFactoryTests.java @@ -0,0 +1,81 @@ +/* + * Copyright 2012-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.io.IOException; +import java.net.URI; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; + +/** + * Tests for {@link BasicAuthenticationClientHttpRequestFactory}. + * + * @author Dmytro Nosan + */ +public class BasicAuthenticationClientHttpRequestFactoryTests { + + private final HttpHeaders httpHeaders = new HttpHeaders(); + + private final BasicAuthentication authentication = new BasicAuthentication("spring", + "boot"); + + private ClientHttpRequestFactory requestFactory; + + @Before + public void setUp() throws IOException { + ClientHttpRequestFactory requestFactory = Mockito + .mock(ClientHttpRequestFactory.class); + ClientHttpRequest request = Mockito.mock(ClientHttpRequest.class); + Mockito.when(requestFactory.createRequest(any(), any())).thenReturn(request); + Mockito.when(request.getHeaders()).thenReturn(this.httpHeaders); + this.requestFactory = new BasicAuthenticationClientHttpRequestFactory( + this.authentication, requestFactory); + } + + @Test + public void shouldAddAuthorizationHeader() throws IOException { + ClientHttpRequest request = createRequest(); + assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)) + .containsExactly("Basic c3ByaW5nOmJvb3Q="); + } + + @Test + public void shouldNotAddAuthorizationHeaderAlreadyContainsHeader() + throws IOException { + this.httpHeaders.setBasicAuth("boot", "spring"); + ClientHttpRequest request = createRequest(); + assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)) + .doesNotContain("Basic c3ByaW5nOmJvb3Q="); + + } + + private ClientHttpRequest createRequest() throws IOException { + return this.requestFactory.createRequest(URI.create("http://localhost:8080"), + HttpMethod.POST); + } + +} diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java index f2070c7f9a29..0dc4ac24eab6 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java @@ -16,6 +16,7 @@ package org.springframework.boot.web.client; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Collections; import java.util.Set; @@ -35,7 +36,6 @@ import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory; -import org.springframework.http.client.support.BasicAuthenticationInterceptor; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.ResourceHttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter; @@ -324,12 +324,13 @@ public void errorHandlerShouldApply() { @Test public void basicAuthenticationShouldApply() { - RestTemplate template = this.builder.basicAuthentication("spring", "boot") + BasicAuthentication basicAuthentication = new BasicAuthentication("spring", + "boot", StandardCharsets.UTF_8); + RestTemplate template = this.builder.basicAuthentication(basicAuthentication) .build(); - ClientHttpRequestInterceptor interceptor = template.getInterceptors().get(0); - assertThat(interceptor).isInstanceOf(BasicAuthenticationInterceptor.class); - assertThat(interceptor).extracting("username").containsExactly("spring"); - assertThat(interceptor).extracting("password").containsExactly("boot"); + ClientHttpRequestFactory requestFactory = template.getRequestFactory(); + assertThat(requestFactory).hasFieldOrPropertyWithValue("authentication", + basicAuthentication); } @Test @@ -406,19 +407,19 @@ public void customizerShouldBeAppliedAtTheEnd() { .messageConverters(this.messageConverter).rootUri("http://localhost:8080") .errorHandler(errorHandler).basicAuthentication("spring", "boot") .requestFactory(() -> requestFactory).customizers((restTemplate) -> { - assertThat(restTemplate.getInterceptors()).hasSize(2) - .contains(this.interceptor).anyMatch( - (ic) -> ic instanceof BasicAuthenticationInterceptor); + assertThat(restTemplate.getInterceptors()).hasSize(1); assertThat(restTemplate.getMessageConverters()) .contains(this.messageConverter); assertThat(restTemplate.getUriTemplateHandler()) .isInstanceOf(RootUriTemplateHandler.class); assertThat(restTemplate.getErrorHandler()).isEqualTo(errorHandler); - ClientHttpRequestFactory actualRequestFactory = restTemplate + ClientHttpRequestFactory interceptingRequestFactory = restTemplate .getRequestFactory(); - assertThat(actualRequestFactory) + assertThat(interceptingRequestFactory) .isInstanceOf(InterceptingClientHttpRequestFactory.class); - assertThat(actualRequestFactory).hasFieldOrPropertyWithValue( + Object basicAuthRequestFactory = ReflectionTestUtils + .getField(interceptingRequestFactory, "requestFactory"); + assertThat(basicAuthRequestFactory).hasFieldOrPropertyWithValue( "requestFactory", requestFactory); }).build(); }