Skip to content

Commit

Permalink
Support http read timeouts for transport-nio (elastic#41466)
Browse files Browse the repository at this point in the history
This is related to elastic#27260. Currently there is a setting
http.read_timeout that allows users to define a read timeout for the
http transport. This commit implements support for this functionality
with the transport-nio plugin. The behavior here is that a repeating
task will be scheduled for the interval defined. If there have been
no requests received since the last run and there are no inflight
requests, the channel will be closed.
  • Loading branch information
Tim-Brooks committed May 1, 2019
1 parent c86f797 commit 65edd78
Show file tree
Hide file tree
Showing 17 changed files with 263 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ public WriteOperation createWriteOperation(SocketChannelContext context, Object
return new FlushReadyWrite(context, (ByteBuffer[]) message, listener);
}

@Override
public void channelRegistered() {}

@Override
public List<FlushOperation> writeToBytes(WriteOperation writeOperation) {
assert writeOperation instanceof FlushReadyWrite : "Write operation must be flush ready";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
*/
public interface ReadWriteHandler {

/**
* This method is called when the channel is registered with its selector.
*/
void channelRegistered();

/**
* This method is called when a message is queued with a channel. It can be called from any thread.
* This method should validate that the message is a valid type and return a write operation object
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ protected FlushOperation getPendingFlush() {
@Override
protected void register() throws IOException {
super.register();
readWriteHandler.channelRegistered();
if (allowChannelPredicate.test(channel) == false) {
closeNow = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public Runnable scheduleAtRelativeTime(Runnable task, long relativeNanos) {
return delayedTask;
}

Runnable pollTask(long relativeNanos) {
public Runnable pollTask(long relativeNanos) {
DelayedTask task;
while ((task = tasks.peek()) != null) {
if (relativeNanos - task.deadline >= 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.network.CloseableChannel;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Setting.Property;
Expand All @@ -59,6 +58,7 @@
import org.elasticsearch.http.AbstractHttpServerTransport;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpReadTimeoutException;
import org.elasticsearch.http.HttpServerChannel;
import org.elasticsearch.http.netty4.cors.Netty4CorsConfig;
import org.elasticsearch.http.netty4.cors.Netty4CorsConfigBuilder;
Expand Down Expand Up @@ -289,12 +289,9 @@ protected void stopInternal() {
}

@Override
protected void onException(HttpChannel channel, Exception cause) {
public void onException(HttpChannel channel, Exception cause) {
if (cause instanceof ReadTimeoutException) {
if (logger.isTraceEnabled()) {
logger.trace("Http read timeout {}", channel);
}
CloseableChannel.closeChannel(channel);
super.onException(channel, new HttpReadTimeoutException(readTimeoutMillis, cause));
} else {
super.onException(channel, cause);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.PatternSyntaxException;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -346,7 +346,7 @@ public void dispatchBadRequest(final RestRequest request,
transport.start();
final TransportAddress remoteAddress = randomFrom(transport.boundAddress().boundAddresses());

AtomicBoolean channelClosed = new AtomicBoolean(false);
CountDownLatch channelClosedLatch = new CountDownLatch(1);

Bootstrap clientBootstrap = new Bootstrap().channel(NioSocketChannel.class).handler(new ChannelInitializer<SocketChannel>() {

Expand All @@ -357,9 +357,9 @@ protected void initChannel(SocketChannel ch) {
}
}).group(group);
ChannelFuture connect = clientBootstrap.connect(remoteAddress.address());
connect.channel().closeFuture().addListener(future -> channelClosed.set(true));
connect.channel().closeFuture().addListener(future -> channelClosedLatch.countDown());

assertBusy(() -> assertTrue("Channel should be closed due to read timeout", channelClosed.get()), 5, TimeUnit.SECONDS);
assertTrue("Channel should be closed due to read timeout", channelClosedLatch.await(1, TimeUnit.MINUTES));

} finally {
group.shutdownGracefully().await();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,31 +30,45 @@
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpPipelinedRequest;
import org.elasticsearch.http.HttpReadTimeoutException;
import org.elasticsearch.http.nio.cors.NioCorsConfig;
import org.elasticsearch.http.nio.cors.NioCorsHandler;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.ReadWriteHandler;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.TaskScheduler;
import org.elasticsearch.nio.WriteOperation;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.LongSupplier;

public class HttpReadWriteHandler implements ReadWriteHandler {

private final NettyAdaptor adaptor;
private final NioHttpChannel nioHttpChannel;
private final NioHttpServerTransport transport;
private final TaskScheduler taskScheduler;
private final LongSupplier nanoClock;
private final long readTimeoutNanos;
private boolean channelRegistered = false;
private boolean requestSinceReadTimeoutTrigger = false;
private int inFlightRequests = 0;

public HttpReadWriteHandler(NioHttpChannel nioHttpChannel, NioHttpServerTransport transport, HttpHandlingSettings settings,
NioCorsConfig corsConfig) {
NioCorsConfig corsConfig, TaskScheduler taskScheduler, LongSupplier nanoClock) {
this.nioHttpChannel = nioHttpChannel;
this.transport = transport;
this.taskScheduler = taskScheduler;
this.nanoClock = nanoClock;
this.readTimeoutNanos = TimeUnit.MILLISECONDS.toNanos(settings.getReadTimeoutMillis());

List<ChannelHandler> handlers = new ArrayList<>(5);
HttpRequestDecoder decoder = new HttpRequestDecoder(settings.getMaxInitialLineLength(), settings.getMaxHeaderSize(),
Expand All @@ -77,10 +91,21 @@ public HttpReadWriteHandler(NioHttpChannel nioHttpChannel, NioHttpServerTranspor
}

@Override
public int consumeReads(InboundChannelBuffer channelBuffer) throws IOException {
public void channelRegistered() {
channelRegistered = true;
if (readTimeoutNanos > 0) {
scheduleReadTimeout();
}
}

@Override
public int consumeReads(InboundChannelBuffer channelBuffer) {
assert channelRegistered : "channelRegistered should have been called";
int bytesConsumed = adaptor.read(channelBuffer.sliceAndRetainPagesTo(channelBuffer.getIndex()));
Object message;
while ((message = adaptor.pollInboundMessage()) != null) {
++inFlightRequests;
requestSinceReadTimeoutTrigger = true;
handleRequest(message);
}

Expand All @@ -96,6 +121,11 @@ public WriteOperation createWriteOperation(SocketChannelContext context, Object

@Override
public List<FlushOperation> writeToBytes(WriteOperation writeOperation) {
assert writeOperation.getObject() instanceof NioHttpResponse : "This channel only supports messages that are of type: "
+ NioHttpResponse.class + ". Found type: " + writeOperation.getObject().getClass() + ".";
assert channelRegistered : "channelRegistered should have been called";
--inFlightRequests;
assert inFlightRequests >= 0 : "Inflight requests should never drop below zero, found: " + inFlightRequests;
adaptor.write(writeOperation);
return pollFlushOperations();
}
Expand Down Expand Up @@ -152,4 +182,17 @@ private void handleRequest(Object msg) {
request.release();
}
}

private void maybeReadTimeout() {
if (requestSinceReadTimeoutTrigger == false && inFlightRequests == 0) {
transport.onException(nioHttpChannel, new HttpReadTimeoutException(TimeValue.nsecToMSec(readTimeoutNanos)));
} else {
requestSinceReadTimeoutTrigger = false;
scheduleReadTimeout();
}
}

private void scheduleReadTimeout() {
taskScheduler.scheduleAtRelativeTime(this::maybeReadTimeout, nanoClock.getAsLong() + readTimeoutNanos);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel)
return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
};
HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this,
handlingSettings, corsConfig);
handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInMillis);
Consumer<Exception> exceptionHandler = (e) -> onException(httpChannel, e);
SocketChannelContext context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpReadWritePipeline,
new InboundChannelBuffer(pageSupplier));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;

import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpReadTimeoutException;
import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.http.HttpTransportSettings;
Expand All @@ -48,6 +49,7 @@
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.TaskScheduler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
Expand All @@ -56,26 +58,23 @@

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.function.BiConsumer;

import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION_LEVEL;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_DETAILED_ERRORS_ENABLED;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_CHUNK_SIZE;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEADER_SIZE;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_INITIAL_LINE_LENGTH;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_RESET_COOKIES;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_PIPELINING_MAX_EVENTS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_READ_TIMEOUT;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
Expand All @@ -84,31 +83,24 @@
public class HttpReadWriteHandlerTests extends ESTestCase {

private HttpReadWriteHandler handler;
private NioHttpChannel nioHttpChannel;
private NioHttpChannel channel;
private NioHttpServerTransport transport;
private TaskScheduler taskScheduler;

private final RequestEncoder requestEncoder = new RequestEncoder();
private final ResponseDecoder responseDecoder = new ResponseDecoder();

@Before
public void setMocks() {
transport = mock(NioHttpServerTransport.class);
Settings settings = Settings.EMPTY;
ByteSizeValue maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.getDefault(settings);
ByteSizeValue maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.getDefault(settings);
ByteSizeValue maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.getDefault(settings);
HttpHandlingSettings httpHandlingSettings = new HttpHandlingSettings(1024,
Math.toIntExact(maxChunkSize.getBytes()),
Math.toIntExact(maxHeaderSize.getBytes()),
Math.toIntExact(maxInitialLineLength.getBytes()),
SETTING_HTTP_RESET_COOKIES.getDefault(settings),
SETTING_HTTP_COMPRESSION.getDefault(settings),
SETTING_HTTP_COMPRESSION_LEVEL.getDefault(settings),
SETTING_HTTP_DETAILED_ERRORS_ENABLED.getDefault(settings),
SETTING_PIPELINING_MAX_EVENTS.getDefault(settings),
SETTING_CORS_ENABLED.getDefault(settings));
nioHttpChannel = mock(NioHttpChannel.class);
handler = new HttpReadWriteHandler(nioHttpChannel, transport, httpHandlingSettings, NioCorsConfigBuilder.forAnyOrigin().build());
Settings settings = Settings.builder().put(SETTING_HTTP_MAX_CONTENT_LENGTH.getKey(), new ByteSizeValue(1024)).build();
HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
channel = mock(NioHttpChannel.class);
taskScheduler = mock(TaskScheduler.class);

NioCorsConfig corsConfig = NioCorsConfigBuilder.forAnyOrigin().build();
handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, corsConfig, taskScheduler, System::nanoTime);
handler.channelRegistered();
}

public void testSuccessfulDecodeHttpRequest() throws IOException {
Expand Down Expand Up @@ -188,7 +180,7 @@ public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() t
flushOperation.getListener().accept(null, null);
// Since we have keep-alive set to false, we should close the channel after the response has been
// flushed
verify(nioHttpChannel).close();
verify(channel).close();
} finally {
response.release();
}
Expand Down Expand Up @@ -335,10 +327,59 @@ public void testThatAnyOriginWorks() throws IOException {
}
}

private FullHttpResponse executeCorsRequest(final Settings settings, final String originValue, final String host) throws IOException {
@SuppressWarnings("unchecked")
public void testReadTimeout() throws IOException {
TimeValue timeValue = TimeValue.timeValueMillis(500);
Settings settings = Settings.builder().put(SETTING_HTTP_READ_TIMEOUT.getKey(), timeValue).build();
HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
NioCorsConfig nioCorsConfig = NioHttpServerTransport.buildCorsConfig(settings);
HttpReadWriteHandler handler = new HttpReadWriteHandler(nioHttpChannel, transport, httpHandlingSettings, nioCorsConfig);
DefaultFullHttpRequest nettyRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
NioHttpRequest nioHttpRequest = new NioHttpRequest(nettyRequest, 0);
NioHttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY);
httpResponse.addHeader(HttpHeaderNames.CONTENT_LENGTH.toString(), "0");

NioCorsConfig corsConfig = NioCorsConfigBuilder.forAnyOrigin().build();
TaskScheduler taskScheduler = new TaskScheduler();

Iterator<Integer> timeValues = Arrays.asList(0, 2, 4, 6, 8).iterator();
handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, corsConfig, taskScheduler, timeValues::next);
handler.channelRegistered();

prepareHandlerForResponse(handler);
SocketChannelContext context = mock(SocketChannelContext.class);
HttpWriteOperation writeOperation = new HttpWriteOperation(context, httpResponse, mock(BiConsumer.class));
handler.writeToBytes(writeOperation);

taskScheduler.pollTask(timeValue.getNanos() + 1).run();
// There was a read. Do not close.
verify(transport, times(0)).onException(eq(channel), any(HttpReadTimeoutException.class));

prepareHandlerForResponse(handler);
prepareHandlerForResponse(handler);

taskScheduler.pollTask(timeValue.getNanos() + 3).run();
// There was a read. Do not close.
verify(transport, times(0)).onException(eq(channel), any(HttpReadTimeoutException.class));

handler.writeToBytes(writeOperation);

taskScheduler.pollTask(timeValue.getNanos() + 5).run();
// There has not been a read, however there is still an inflight request. Do not close.
verify(transport, times(0)).onException(eq(channel), any(HttpReadTimeoutException.class));

handler.writeToBytes(writeOperation);

taskScheduler.pollTask(timeValue.getNanos() + 7).run();
// No reads and no inflight requests, close
verify(transport, times(1)).onException(eq(channel), any(HttpReadTimeoutException.class));
assertNull(taskScheduler.pollTask(timeValue.getNanos() + 9));
}

private FullHttpResponse executeCorsRequest(final Settings settings, final String originValue, final String host) throws IOException {
HttpHandlingSettings httpSettings = HttpHandlingSettings.fromSettings(settings);
NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings);
HttpReadWriteHandler handler = new HttpReadWriteHandler(channel, transport, httpSettings, corsConfig, taskScheduler,
System::nanoTime);
handler.channelRegistered();
prepareHandlerForResponse(handler);
DefaultFullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (originValue != null) {
Expand All @@ -360,7 +401,7 @@ private FullHttpResponse executeCorsRequest(final Settings settings, final Strin



private NioHttpRequest prepareHandlerForResponse(HttpReadWriteHandler handler) throws IOException {
private void prepareHandlerForResponse(HttpReadWriteHandler handler) throws IOException {
HttpMethod method = randomBoolean() ? HttpMethod.GET : HttpMethod.HEAD;
HttpVersion version = randomBoolean() ? HttpVersion.HTTP_1_0 : HttpVersion.HTTP_1_1;
String uri = "http://localhost:9090/" + randomAlphaOfLength(8);
Expand All @@ -385,7 +426,6 @@ private NioHttpRequest prepareHandlerForResponse(HttpReadWriteHandler handler) t
assertEquals(HttpRequest.HttpVersion.HTTP_1_0, nioHttpRequest.protocolVersion());
}
assertEquals(nioHttpRequest.uri(), uri);
return nioHttpRequest;
}

private InboundChannelBuffer toChannelBuffer(ByteBuf buf) {
Expand Down
Loading

0 comments on commit 65edd78

Please sign in to comment.