diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java index 000e871e92781..ef1e188a22e0a 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java @@ -20,25 +20,13 @@ package org.elasticsearch.nio; import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.ClosedChannelException; -import java.util.LinkedList; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.BiConsumer; import java.util.function.Consumer; public class BytesChannelContext extends SocketChannelContext { - private final ReadConsumer readConsumer; - private final InboundChannelBuffer channelBuffer; - private final LinkedList queued = new LinkedList<>(); - private final AtomicBoolean isClosing = new AtomicBoolean(false); - public BytesChannelContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler, - ReadConsumer readConsumer, InboundChannelBuffer channelBuffer) { - super(channel, selector, exceptionHandler); - this.readConsumer = readConsumer; - this.channelBuffer = channelBuffer; + ReadWriteHandler handler, InboundChannelBuffer channelBuffer) { + super(channel, selector, exceptionHandler, handler, channelBuffer); } @Override @@ -56,55 +44,30 @@ public int read() throws IOException { channelBuffer.incrementIndex(bytesRead); - int bytesConsumed = Integer.MAX_VALUE; - while (bytesConsumed > 0 && channelBuffer.getIndex() > 0) { - bytesConsumed = readConsumer.consumeReads(channelBuffer); - channelBuffer.release(bytesConsumed); - } + handleReadBytes(); return bytesRead; } - @Override - public void sendMessage(ByteBuffer[] buffers, BiConsumer listener) { - if (isClosing.get()) { - listener.accept(null, new ClosedChannelException()); - return; - } - - BytesWriteOperation writeOperation = new BytesWriteOperation(this, buffers, listener); - SocketSelector selector = getSelector(); - if (selector.isOnCurrentThread() == false) { - selector.queueWrite(writeOperation); - return; - } - - selector.queueWriteInChannelBuffer(writeOperation); - } - - @Override - public void queueWriteOperation(WriteOperation writeOperation) { - getSelector().assertOnSelectorThread(); - queued.add((BytesWriteOperation) writeOperation); - } - @Override public void flushChannel() throws IOException { getSelector().assertOnSelectorThread(); - int ops = queued.size(); - if (ops == 1) { - singleFlush(queued.pop()); - } else if (ops > 1) { - multiFlush(); + boolean lastOpCompleted = true; + FlushOperation flushOperation; + while (lastOpCompleted && (flushOperation = getPendingFlush()) != null) { + try { + if (singleFlush(flushOperation)) { + currentFlushOperationComplete(); + } else { + lastOpCompleted = false; + } + } catch (IOException e) { + currentFlushOperationFailed(e); + throw e; + } } } - @Override - public boolean hasQueuedWriteOps() { - getSelector().assertOnSelectorThread(); - return queued.isEmpty() == false; - } - @Override public void closeChannel() { if (isClosing.compareAndSet(false, true)) { @@ -117,51 +80,12 @@ public boolean selectorShouldClose() { return isPeerClosed() || hasIOException() || isClosing.get(); } - @Override - public void closeFromSelector() throws IOException { - getSelector().assertOnSelectorThread(); - if (channel.isOpen()) { - IOException channelCloseException = null; - try { - super.closeFromSelector(); - } catch (IOException e) { - channelCloseException = e; - } - // Set to true in order to reject new writes before queuing with selector - isClosing.set(true); - channelBuffer.close(); - for (BytesWriteOperation op : queued) { - getSelector().executeFailedListener(op.getListener(), new ClosedChannelException()); - } - queued.clear(); - if (channelCloseException != null) { - throw channelCloseException; - } - } - } - - private void singleFlush(BytesWriteOperation headOp) throws IOException { - try { - int written = flushToChannel(headOp.getBuffersToWrite()); - headOp.incrementIndex(written); - } catch (IOException e) { - getSelector().executeFailedListener(headOp.getListener(), e); - throw e; - } - - if (headOp.isFullyFlushed()) { - getSelector().executeListener(headOp.getListener(), null); - } else { - queued.push(headOp); - } - } - - private void multiFlush() throws IOException { - boolean lastOpCompleted = true; - while (lastOpCompleted && queued.isEmpty() == false) { - BytesWriteOperation op = queued.pop(); - singleFlush(op); - lastOpCompleted = op.isFullyFlushed(); - } + /** + * Returns a boolean indicating if the operation was fully flushed. + */ + private boolean singleFlush(FlushOperation flushOperation) throws IOException { + int written = flushToChannel(flushOperation.getBuffersToWrite()); + flushOperation.incrementIndex(written); + return flushOperation.isFullyFlushed(); } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java new file mode 100644 index 0000000000000..ba379e2873210 --- /dev/null +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java @@ -0,0 +1,47 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.nio; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; +import java.util.function.BiConsumer; + +public abstract class BytesWriteHandler implements ReadWriteHandler { + + private static final List EMPTY_LIST = Collections.emptyList(); + + public WriteOperation createWriteOperation(SocketChannelContext context, Object message, BiConsumer listener) { + assert message instanceof ByteBuffer[] : "This channel only supports messages that are of type: " + ByteBuffer[].class + + ". Found type: " + message.getClass() + "."; + return new FlushReadyWrite(context, (ByteBuffer[]) message, listener); + } + + public List writeToBytes(WriteOperation writeOperation) { + assert writeOperation instanceof FlushReadyWrite : "Write operation must be flush ready"; + return Collections.singletonList((FlushReadyWrite) writeOperation); + } + + public List pollFlushOperations() { + return EMPTY_LIST; + } + + public void close() {} +} diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/FlushOperation.java similarity index 86% rename from libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java rename to libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/FlushOperation.java index 37c6e49727634..3102c972a6795 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/FlushOperation.java @@ -23,17 +23,15 @@ import java.util.Arrays; import java.util.function.BiConsumer; -public class BytesWriteOperation implements WriteOperation { +public class FlushOperation { - private final SocketChannelContext channelContext; private final BiConsumer listener; private final ByteBuffer[] buffers; private final int[] offsets; private final int length; private int internalIndex; - public BytesWriteOperation(SocketChannelContext channelContext, ByteBuffer[] buffers, BiConsumer listener) { - this.channelContext = channelContext; + public FlushOperation(ByteBuffer[] buffers, BiConsumer listener) { this.listener = listener; this.buffers = buffers; this.offsets = new int[buffers.length]; @@ -46,16 +44,10 @@ public BytesWriteOperation(SocketChannelContext channelContext, ByteBuffer[] buf length = offset; } - @Override public BiConsumer getListener() { return listener; } - @Override - public SocketChannelContext getChannel() { - return channelContext; - } - public boolean isFullyFlushed() { assert length >= internalIndex : "Should never have an index that is greater than the length [length=" + length + ", index=" + internalIndex + "]"; @@ -84,5 +76,4 @@ public ByteBuffer[] getBuffersToWrite() { return postIndexBuffers; } - } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/FlushReadyWrite.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/FlushReadyWrite.java new file mode 100644 index 0000000000000..65bc8f17aaf4b --- /dev/null +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/FlushReadyWrite.java @@ -0,0 +1,45 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.nio; + +import java.nio.ByteBuffer; +import java.util.function.BiConsumer; + +public class FlushReadyWrite extends FlushOperation implements WriteOperation { + + private final SocketChannelContext channelContext; + private final ByteBuffer[] buffers; + + FlushReadyWrite(SocketChannelContext channelContext, ByteBuffer[] buffers, BiConsumer listener) { + super(buffers, listener); + this.channelContext = channelContext; + this.buffers = buffers; + } + + @Override + public SocketChannelContext getChannel() { + return channelContext; + } + + @Override + public ByteBuffer[] getObject() { + return buffers; + } +} diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ReadWriteHandler.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ReadWriteHandler.java new file mode 100644 index 0000000000000..f0637ea265280 --- /dev/null +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ReadWriteHandler.java @@ -0,0 +1,71 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.nio; + +import java.io.IOException; +import java.util.List; +import java.util.function.BiConsumer; + +/** + * Implements the application specific logic for handling inbound and outbound messages for a channel. + */ +public interface ReadWriteHandler { + + /** + * This method is called when a message is queued with a channel. It can be called from any thread. + * This method should validate that the message is a valid type and return a write operation object + * to be queued with the channel + * + * @param context the channel context + * @param message the message + * @param listener the listener to be called when the message is sent + * @return the write operation to be queued + */ + WriteOperation createWriteOperation(SocketChannelContext context, Object message, BiConsumer listener); + + /** + * This method is called on the event loop thread. It should serialize a write operation object to bytes + * that can be flushed to the raw nio channel. + * + * @param writeOperation to be converted to bytes + * @return the operations to flush the bytes to the channel + */ + List writeToBytes(WriteOperation writeOperation); + + /** + * Returns any flush operations that are ready to flush. This exists as a way to check if any flush + * operations were produced during a read call. + * + * @return flush operations + */ + List pollFlushOperations(); + + /** + * This method handles bytes that have been read from the network. It should return the number of bytes + * consumed so that they can be released. + * + * @param channelBuffer of bytes read from the network + * @return the number of bytes consumed + * @throws IOException if an exception occurs + */ + int consumeReads(InboundChannelBuffer channelBuffer) throws IOException; + + void close() throws IOException; +} diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java index 3bf47a98e0267..f2d299a9d328a 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java @@ -19,10 +19,16 @@ package org.elasticsearch.nio; +import org.elasticsearch.nio.utils.ExceptionsHelper; + import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; import java.nio.channels.SocketChannel; +import java.util.ArrayList; +import java.util.LinkedList; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -33,21 +39,28 @@ * close behavior is required, it should be implemented in this context. * * The only methods of the context that should ever be called from a non-selector thread are - * {@link #closeChannel()} and {@link #sendMessage(ByteBuffer[], BiConsumer)}. + * {@link #closeChannel()} and {@link #sendMessage(Object, BiConsumer)}. */ public abstract class SocketChannelContext extends ChannelContext { protected final NioSocketChannel channel; + protected final InboundChannelBuffer channelBuffer; + protected final AtomicBoolean isClosing = new AtomicBoolean(false); + private final ReadWriteHandler readWriteHandler; private final SocketSelector selector; private final CompletableFuture connectContext = new CompletableFuture<>(); + private final LinkedList pendingFlushes = new LinkedList<>(); private boolean ioException; private boolean peerClosed; private Exception connectException; - protected SocketChannelContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler) { + protected SocketChannelContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler, + ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer) { super(channel.getRawChannel(), exceptionHandler); this.selector = selector; this.channel = channel; + this.readWriteHandler = readWriteHandler; + this.channelBuffer = channelBuffer; } @Override @@ -108,15 +121,94 @@ public boolean connect() throws IOException { return isConnected; } - public abstract int read() throws IOException; + public void sendMessage(Object message, BiConsumer listener) { + if (isClosing.get()) { + listener.accept(null, new ClosedChannelException()); + return; + } - public abstract void sendMessage(ByteBuffer[] buffers, BiConsumer listener); + WriteOperation writeOperation = readWriteHandler.createWriteOperation(this, message, listener); + + SocketSelector selector = getSelector(); + if (selector.isOnCurrentThread() == false) { + selector.queueWrite(writeOperation); + return; + } + + selector.queueWriteInChannelBuffer(writeOperation); + } + + public void queueWriteOperation(WriteOperation writeOperation) { + getSelector().assertOnSelectorThread(); + pendingFlushes.addAll(readWriteHandler.writeToBytes(writeOperation)); + } - public abstract void queueWriteOperation(WriteOperation writeOperation); + public abstract int read() throws IOException; public abstract void flushChannel() throws IOException; - public abstract boolean hasQueuedWriteOps(); + protected void currentFlushOperationFailed(IOException e) { + FlushOperation flushOperation = pendingFlushes.pollFirst(); + getSelector().executeFailedListener(flushOperation.getListener(), e); + } + + protected void currentFlushOperationComplete() { + FlushOperation flushOperation = pendingFlushes.pollFirst(); + getSelector().executeListener(flushOperation.getListener(), null); + } + + protected FlushOperation getPendingFlush() { + return pendingFlushes.peekFirst(); + } + + @Override + public void closeFromSelector() throws IOException { + getSelector().assertOnSelectorThread(); + if (channel.isOpen()) { + ArrayList closingExceptions = new ArrayList<>(3); + try { + super.closeFromSelector(); + } catch (IOException e) { + closingExceptions.add(e); + } + // Set to true in order to reject new writes before queuing with selector + isClosing.set(true); + + // Poll for new flush operations to close + pendingFlushes.addAll(readWriteHandler.pollFlushOperations()); + FlushOperation flushOperation; + while ((flushOperation = pendingFlushes.pollFirst()) != null) { + selector.executeFailedListener(flushOperation.getListener(), new ClosedChannelException()); + } + + try { + readWriteHandler.close(); + } catch (IOException e) { + closingExceptions.add(e); + } + channelBuffer.close(); + + if (closingExceptions.isEmpty() == false) { + ExceptionsHelper.rethrowAndSuppress(closingExceptions); + } + } + } + + protected void handleReadBytes() throws IOException { + int bytesConsumed = Integer.MAX_VALUE; + while (bytesConsumed > 0 && channelBuffer.getIndex() > 0) { + bytesConsumed = readWriteHandler.consumeReads(channelBuffer); + channelBuffer.release(bytesConsumed); + } + + // Some protocols might produce messages to flush during a read operation. + pendingFlushes.addAll(readWriteHandler.pollFlushOperations()); + } + + public boolean readyForFlush() { + getSelector().assertOnSelectorThread(); + return pendingFlushes.isEmpty() == false; + } /** * This method indicates if a selector should close this channel. @@ -178,9 +270,4 @@ protected int flushToChannel(ByteBuffer[] buffers) throws IOException { throw e; } } - - @FunctionalInterface - public interface ReadConsumer { - int consumeReads(InboundChannelBuffer channelBuffer) throws IOException; - } } diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java index b1f738647619b..cacee47e96196 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java @@ -48,7 +48,7 @@ protected void handleRegistration(SocketChannelContext context) throws IOExcepti context.register(); SelectionKey selectionKey = context.getSelectionKey(); selectionKey.attach(context); - if (context.hasQueuedWriteOps()) { + if (context.readyForFlush()) { SelectionKeyUtils.setConnectReadAndWriteInterested(selectionKey); } else { SelectionKeyUtils.setConnectAndReadInterested(selectionKey); @@ -150,7 +150,7 @@ protected void postHandling(SocketChannelContext context) { } else { SelectionKey selectionKey = context.getSelectionKey(); boolean currentlyWriteInterested = SelectionKeyUtils.isWriteInterested(selectionKey); - boolean pendingWrites = context.hasQueuedWriteOps(); + boolean pendingWrites = context.readyForFlush(); if (currentlyWriteInterested == false && pendingWrites) { SelectionKeyUtils.setWriteInterested(selectionKey); } else if (currentlyWriteInterested && pendingWrites == false) { diff --git a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java index 665b9f7759e11..25de6ab7326f3 100644 --- a/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java +++ b/libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ - package org.elasticsearch.nio; import java.util.function.BiConsumer; @@ -24,11 +23,14 @@ /** * This is a basic write operation that can be queued with a channel. The only requirements of a write * operation is that is has a listener and a reference to its channel. The actual conversion of the write - * operation implementation to bytes will be performed by the {@link SocketChannelContext}. + * operation implementation to bytes will be performed by the {@link ReadWriteHandler}. */ public interface WriteOperation { BiConsumer getListener(); SocketChannelContext getChannel(); + + Object getObject(); + } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java index d9de0ab1361c3..addfcdedbf99f 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java @@ -19,23 +19,19 @@ package org.elasticsearch.nio; +import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.test.ESTestCase; import org.junit.Before; -import org.mockito.ArgumentCaptor; import java.io.IOException; import java.nio.ByteBuffer; -import java.nio.channels.ClosedChannelException; import java.nio.channels.SocketChannel; import java.util.function.BiConsumer; import java.util.function.Consumer; -import java.util.function.Supplier; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; -import static org.mockito.Matchers.isNull; -import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -43,20 +39,19 @@ public class BytesChannelContextTests extends ESTestCase { - private SocketChannelContext.ReadConsumer readConsumer; + private CheckedFunction readConsumer; private NioSocketChannel channel; private SocketChannel rawChannel; private BytesChannelContext context; private InboundChannelBuffer channelBuffer; private SocketSelector selector; - private Consumer exceptionHandler; private BiConsumer listener; private int messageLength; @Before @SuppressWarnings("unchecked") public void init() { - readConsumer = mock(SocketChannelContext.ReadConsumer.class); + readConsumer = mock(CheckedFunction.class); messageLength = randomInt(96) + 20; selector = mock(SocketSelector.class); @@ -64,9 +59,9 @@ public void init() { channel = mock(NioSocketChannel.class); rawChannel = mock(SocketChannel.class); channelBuffer = InboundChannelBuffer.allocatingInstance(); - exceptionHandler = mock(Consumer.class); + TestReadWriteHandler handler = new TestReadWriteHandler(readConsumer); when(channel.getRawChannel()).thenReturn(rawChannel); - context = new BytesChannelContext(channel, selector, exceptionHandler, readConsumer, channelBuffer); + context = new BytesChannelContext(channel, selector, mock(Consumer.class), handler, channelBuffer); when(selector.isOnCurrentThread()).thenReturn(true); } @@ -80,13 +75,13 @@ public void testSuccessfulRead() throws IOException { return bytes.length; }); - when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, 0); + when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, 0); assertEquals(messageLength, context.read()); assertEquals(0, channelBuffer.getIndex()); assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity()); - verify(readConsumer, times(1)).consumeReads(channelBuffer); + verify(readConsumer, times(1)).apply(channelBuffer); } public void testMultipleReadsConsumed() throws IOException { @@ -98,13 +93,13 @@ public void testMultipleReadsConsumed() throws IOException { return bytes.length; }); - when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, messageLength, 0); + when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, messageLength, 0); assertEquals(bytes.length, context.read()); assertEquals(0, channelBuffer.getIndex()); assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity()); - verify(readConsumer, times(2)).consumeReads(channelBuffer); + verify(readConsumer, times(2)).apply(channelBuffer); } public void testPartialRead() throws IOException { @@ -117,20 +112,20 @@ public void testPartialRead() throws IOException { }); - when(readConsumer.consumeReads(channelBuffer)).thenReturn(0); + when(readConsumer.apply(channelBuffer)).thenReturn(0); assertEquals(messageLength, context.read()); assertEquals(bytes.length, channelBuffer.getIndex()); - verify(readConsumer, times(1)).consumeReads(channelBuffer); + verify(readConsumer, times(1)).apply(channelBuffer); - when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength * 2, 0); + when(readConsumer.apply(channelBuffer)).thenReturn(messageLength * 2, 0); assertEquals(messageLength, context.read()); assertEquals(0, channelBuffer.getIndex()); assertEquals(BigArrays.BYTE_PAGE_SIZE - (bytes.length * 2), channelBuffer.getCapacity()); - verify(readConsumer, times(2)).consumeReads(channelBuffer); + verify(readConsumer, times(2)).apply(channelBuffer); } public void testReadThrowsIOException() throws IOException { @@ -157,186 +152,100 @@ public void testReadLessThanZeroMeansReadyForClose() throws IOException { assertTrue(context.selectorShouldClose()); } - @SuppressWarnings("unchecked") - public void testCloseClosesChannelBuffer() throws IOException { - try (SocketChannel realChannel = SocketChannel.open()) { - when(channel.getRawChannel()).thenReturn(realChannel); - context = new BytesChannelContext(channel, selector, exceptionHandler, readConsumer, channelBuffer); - - when(channel.isOpen()).thenReturn(true); - Runnable closer = mock(Runnable.class); - Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer); - InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); - buffer.ensureCapacity(1); - BytesChannelContext context = new BytesChannelContext(channel, selector, exceptionHandler, readConsumer, buffer); - context.closeFromSelector(); - verify(closer).run(); - } - } - - public void testWriteFailsIfClosing() { - context.closeChannel(); - - ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; - context.sendMessage(buffers, listener); - - verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class)); - } - - public void testSendMessageFromDifferentThreadIsQueuedWithSelector() throws Exception { - ArgumentCaptor writeOpCaptor = ArgumentCaptor.forClass(BytesWriteOperation.class); - - when(selector.isOnCurrentThread()).thenReturn(false); - - ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; - context.sendMessage(buffers, listener); - - verify(selector).queueWrite(writeOpCaptor.capture()); - BytesWriteOperation writeOp = writeOpCaptor.getValue(); - - assertSame(listener, writeOp.getListener()); - assertSame(context, writeOp.getChannel()); - assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]); - } - - public void testSendMessageFromSameThreadIsQueuedInChannel() { - ArgumentCaptor writeOpCaptor = ArgumentCaptor.forClass(BytesWriteOperation.class); - - ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; - context.sendMessage(buffers, listener); - - verify(selector).queueWriteInChannelBuffer(writeOpCaptor.capture()); - BytesWriteOperation writeOp = writeOpCaptor.getValue(); - - assertSame(listener, writeOp.getListener()); - assertSame(context, writeOp.getChannel()); - assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]); - } - - public void testWriteIsQueuedInChannel() { - assertFalse(context.hasQueuedWriteOps()); - - ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; - context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener)); - - assertTrue(context.hasQueuedWriteOps()); - } - - @SuppressWarnings("unchecked") - public void testWriteOpsClearedOnClose() throws Exception { - try (SocketChannel realChannel = SocketChannel.open()) { - when(channel.getRawChannel()).thenReturn(realChannel); - context = new BytesChannelContext(channel, selector, exceptionHandler, readConsumer, channelBuffer); - - assertFalse(context.hasQueuedWriteOps()); - - ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; - context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener)); - - assertTrue(context.hasQueuedWriteOps()); - - when(channel.isOpen()).thenReturn(true); - context.closeFromSelector(); - - verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class)); - - assertFalse(context.hasQueuedWriteOps()); - } - } - + @SuppressWarnings("varargs") public void testQueuedWriteIsFlushedInFlushCall() throws Exception { - assertFalse(context.hasQueuedWriteOps()); + assertFalse(context.readyForFlush()); ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - BytesWriteOperation writeOperation = mock(BytesWriteOperation.class); - context.queueWriteOperation(writeOperation); - assertTrue(context.hasQueuedWriteOps()); + FlushReadyWrite flushOperation = mock(FlushReadyWrite.class); + context.queueWriteOperation(flushOperation); + + assertTrue(context.readyForFlush()); - when(writeOperation.getBuffersToWrite()).thenReturn(buffers); - when(writeOperation.isFullyFlushed()).thenReturn(true); - when(writeOperation.getListener()).thenReturn(listener); + when(flushOperation.getBuffersToWrite()).thenReturn(buffers); + when(flushOperation.isFullyFlushed()).thenReturn(true); + when(flushOperation.getListener()).thenReturn(listener); context.flushChannel(); verify(rawChannel).write(buffers, 0, buffers.length); verify(selector).executeListener(listener, null); - assertFalse(context.hasQueuedWriteOps()); + assertFalse(context.readyForFlush()); } public void testPartialFlush() throws IOException { - assertFalse(context.hasQueuedWriteOps()); - - BytesWriteOperation writeOperation = mock(BytesWriteOperation.class); - context.queueWriteOperation(writeOperation); + assertFalse(context.readyForFlush()); + FlushReadyWrite flushOperation = mock(FlushReadyWrite.class); + context.queueWriteOperation(flushOperation); + assertTrue(context.readyForFlush()); - assertTrue(context.hasQueuedWriteOps()); - - when(writeOperation.isFullyFlushed()).thenReturn(false); - when(writeOperation.getBuffersToWrite()).thenReturn(new ByteBuffer[0]); + when(flushOperation.isFullyFlushed()).thenReturn(false); + when(flushOperation.getBuffersToWrite()).thenReturn(new ByteBuffer[0]); context.flushChannel(); verify(listener, times(0)).accept(null, null); - assertTrue(context.hasQueuedWriteOps()); + assertTrue(context.readyForFlush()); } @SuppressWarnings("unchecked") public void testMultipleWritesPartialFlushes() throws IOException { - assertFalse(context.hasQueuedWriteOps()); + assertFalse(context.readyForFlush()); BiConsumer listener2 = mock(BiConsumer.class); - BytesWriteOperation writeOperation1 = mock(BytesWriteOperation.class); - BytesWriteOperation writeOperation2 = mock(BytesWriteOperation.class); - when(writeOperation1.getBuffersToWrite()).thenReturn(new ByteBuffer[0]); - when(writeOperation2.getBuffersToWrite()).thenReturn(new ByteBuffer[0]); - when(writeOperation1.getListener()).thenReturn(listener); - when(writeOperation2.getListener()).thenReturn(listener2); - context.queueWriteOperation(writeOperation1); - context.queueWriteOperation(writeOperation2); - - assertTrue(context.hasQueuedWriteOps()); - - when(writeOperation1.isFullyFlushed()).thenReturn(true); - when(writeOperation2.isFullyFlushed()).thenReturn(false); + FlushReadyWrite flushOperation1 = mock(FlushReadyWrite.class); + FlushReadyWrite flushOperation2 = mock(FlushReadyWrite.class); + when(flushOperation1.getBuffersToWrite()).thenReturn(new ByteBuffer[0]); + when(flushOperation2.getBuffersToWrite()).thenReturn(new ByteBuffer[0]); + when(flushOperation1.getListener()).thenReturn(listener); + when(flushOperation2.getListener()).thenReturn(listener2); + + context.queueWriteOperation(flushOperation1); + context.queueWriteOperation(flushOperation2); + + assertTrue(context.readyForFlush()); + + when(flushOperation1.isFullyFlushed()).thenReturn(true); + when(flushOperation2.isFullyFlushed()).thenReturn(false); context.flushChannel(); verify(selector).executeListener(listener, null); verify(listener2, times(0)).accept(null, null); - assertTrue(context.hasQueuedWriteOps()); + assertTrue(context.readyForFlush()); - when(writeOperation2.isFullyFlushed()).thenReturn(true); + when(flushOperation2.isFullyFlushed()).thenReturn(true); context.flushChannel(); verify(selector).executeListener(listener2, null); - assertFalse(context.hasQueuedWriteOps()); + assertFalse(context.readyForFlush()); } public void testWhenIOExceptionThrownListenerIsCalled() throws IOException { - assertFalse(context.hasQueuedWriteOps()); + assertFalse(context.readyForFlush()); ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - BytesWriteOperation writeOperation = mock(BytesWriteOperation.class); - context.queueWriteOperation(writeOperation); + FlushReadyWrite flushOperation = mock(FlushReadyWrite.class); + context.queueWriteOperation(flushOperation); - assertTrue(context.hasQueuedWriteOps()); + assertTrue(context.readyForFlush()); IOException exception = new IOException(); - when(writeOperation.getBuffersToWrite()).thenReturn(buffers); + when(flushOperation.getBuffersToWrite()).thenReturn(buffers); when(rawChannel.write(buffers, 0, buffers.length)).thenThrow(exception); - when(writeOperation.getListener()).thenReturn(listener); + when(flushOperation.getListener()).thenReturn(listener); expectThrows(IOException.class, () -> context.flushChannel()); verify(selector).executeFailedListener(listener, exception); - assertFalse(context.hasQueuedWriteOps()); + assertFalse(context.readyForFlush()); } public void testWriteIOExceptionMeansChannelReadyToClose() throws IOException { ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - BytesWriteOperation writeOperation = mock(BytesWriteOperation.class); - context.queueWriteOperation(writeOperation); + FlushReadyWrite flushOperation = mock(FlushReadyWrite.class); + context.queueWriteOperation(flushOperation); IOException exception = new IOException(); - when(writeOperation.getBuffersToWrite()).thenReturn(buffers); + when(flushOperation.getBuffersToWrite()).thenReturn(buffers); when(rawChannel.write(buffers, 0, buffers.length)).thenThrow(exception); assertFalse(context.selectorShouldClose()); @@ -344,7 +253,7 @@ public void testWriteIOExceptionMeansChannelReadyToClose() throws IOException { assertTrue(context.selectorShouldClose()); } - public void initiateCloseSchedulesCloseWithSelector() { + public void testInitiateCloseSchedulesCloseWithSelector() { context.closeChannel(); verify(selector).queueChannelClose(channel); } @@ -356,4 +265,18 @@ private static byte[] createMessage(int length) { } return bytes; } + + private static class TestReadWriteHandler extends BytesWriteHandler { + + private final CheckedFunction fn; + + private TestReadWriteHandler(CheckedFunction fn) { + this.fn = fn; + } + + @Override + public int consumeReads(InboundChannelBuffer channelBuffer) throws IOException { + return fn.apply(channelBuffer); + } + } } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesWriteOperationTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/FlushOperationTests.java similarity index 87% rename from libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesWriteOperationTests.java rename to libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/FlushOperationTests.java index 05afc80a49086..a244de51f3591 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesWriteOperationTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/FlushOperationTests.java @@ -29,22 +29,19 @@ import static org.mockito.Mockito.mock; -public class BytesWriteOperationTests extends ESTestCase { +public class FlushOperationTests extends ESTestCase { - private SocketChannelContext channelContext; private BiConsumer listener; @Before @SuppressWarnings("unchecked") public void setFields() { - channelContext = mock(SocketChannelContext.class); listener = mock(BiConsumer.class); - } public void testFullyFlushedMarker() { ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - BytesWriteOperation writeOp = new BytesWriteOperation(channelContext, buffers, listener); + FlushOperation writeOp = new FlushOperation(buffers, listener); writeOp.incrementIndex(10); @@ -53,7 +50,7 @@ public void testFullyFlushedMarker() { public void testPartiallyFlushedMarker() { ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - BytesWriteOperation writeOp = new BytesWriteOperation(channelContext, buffers, listener); + FlushOperation writeOp = new FlushOperation(buffers, listener); writeOp.incrementIndex(5); @@ -62,7 +59,7 @@ public void testPartiallyFlushedMarker() { public void testMultipleFlushesWithCompositeBuffer() throws IOException { ByteBuffer[] buffers = {ByteBuffer.allocate(10), ByteBuffer.allocate(15), ByteBuffer.allocate(3)}; - BytesWriteOperation writeOp = new BytesWriteOperation(channelContext, buffers, listener); + FlushOperation writeOp = new FlushOperation(buffers, listener); ArgumentCaptor buffersCaptor = ArgumentCaptor.forClass(ByteBuffer[].class); diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java index 17e6b7acba283..d6787f7cc1534 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java @@ -21,18 +21,27 @@ import org.elasticsearch.test.ESTestCase; import org.junit.Before; +import org.mockito.ArgumentCaptor; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; import java.nio.channels.SocketChannel; +import java.util.Arrays; +import java.util.Collections; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.Consumer; +import java.util.function.Supplier; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.isNull; +import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class SocketChannelContextTests extends ESTestCase { @@ -41,6 +50,9 @@ public class SocketChannelContextTests extends ESTestCase { private TestSocketChannelContext context; private Consumer exceptionHandler; private NioSocketChannel channel; + private BiConsumer listener; + private SocketSelector selector; + private ReadWriteHandler readWriteHandler; @SuppressWarnings("unchecked") @Before @@ -49,9 +61,15 @@ public void setup() throws Exception { rawChannel = mock(SocketChannel.class); channel = mock(NioSocketChannel.class); + listener = mock(BiConsumer.class); when(channel.getRawChannel()).thenReturn(rawChannel); exceptionHandler = mock(Consumer.class); - context = new TestSocketChannelContext(channel, mock(SocketSelector.class), exceptionHandler); + selector = mock(SocketSelector.class); + readWriteHandler = mock(ReadWriteHandler.class); + InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance(); + context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); + + when(selector.isOnCurrentThread()).thenReturn(true); } public void testIOExceptionSetIfEncountered() throws IOException { @@ -119,10 +137,147 @@ public void testConnectFails() throws IOException { assertSame(ioException, exception.get()); } + public void testWriteFailsIfClosing() { + context.closeChannel(); + + ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; + context.sendMessage(buffers, listener); + + verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class)); + } + + public void testSendMessageFromDifferentThreadIsQueuedWithSelector() throws Exception { + ArgumentCaptor writeOpCaptor = ArgumentCaptor.forClass(WriteOperation.class); + + when(selector.isOnCurrentThread()).thenReturn(false); + + ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; + WriteOperation writeOperation = mock(WriteOperation.class); + when(readWriteHandler.createWriteOperation(context, buffers, listener)).thenReturn(writeOperation); + context.sendMessage(buffers, listener); + + verify(selector).queueWrite(writeOpCaptor.capture()); + WriteOperation writeOp = writeOpCaptor.getValue(); + + assertSame(writeOperation, writeOp); + } + + public void testSendMessageFromSameThreadIsQueuedInChannel() { + ArgumentCaptor writeOpCaptor = ArgumentCaptor.forClass(WriteOperation.class); + + ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; + WriteOperation writeOperation = mock(WriteOperation.class); + when(readWriteHandler.createWriteOperation(context, buffers, listener)).thenReturn(writeOperation); + context.sendMessage(buffers, listener); + + verify(selector).queueWriteInChannelBuffer(writeOpCaptor.capture()); + WriteOperation writeOp = writeOpCaptor.getValue(); + + assertSame(writeOperation, writeOp); + } + + public void testWriteIsQueuedInChannel() { + assertFalse(context.readyForFlush()); + + ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; + FlushReadyWrite writeOperation = new FlushReadyWrite(context, buffer, listener); + when(readWriteHandler.writeToBytes(writeOperation)).thenReturn(Collections.singletonList(writeOperation)); + context.queueWriteOperation(writeOperation); + + verify(readWriteHandler).writeToBytes(writeOperation); + assertTrue(context.readyForFlush()); + } + + public void testHandleReadBytesWillCheckForNewFlushOperations() throws IOException { + assertFalse(context.readyForFlush()); + when(readWriteHandler.pollFlushOperations()).thenReturn(Collections.singletonList(mock(FlushOperation.class))); + context.handleReadBytes(); + assertTrue(context.readyForFlush()); + } + + @SuppressWarnings({"unchecked", "varargs"}) + public void testFlushOpsClearedOnClose() throws Exception { + try (SocketChannel realChannel = SocketChannel.open()) { + when(channel.getRawChannel()).thenReturn(realChannel); + InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance(); + context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); + + assertFalse(context.readyForFlush()); + + ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; + WriteOperation writeOperation = mock(WriteOperation.class); + BiConsumer listener2 = mock(BiConsumer.class); + when(readWriteHandler.writeToBytes(writeOperation)).thenReturn(Arrays.asList(new FlushOperation(buffer, listener), + new FlushOperation(buffer, listener2))); + context.queueWriteOperation(writeOperation); + + assertTrue(context.readyForFlush()); + + when(channel.isOpen()).thenReturn(true); + context.closeFromSelector(); + + verify(selector, times(1)).executeFailedListener(same(listener), any(ClosedChannelException.class)); + verify(selector, times(1)).executeFailedListener(same(listener2), any(ClosedChannelException.class)); + + assertFalse(context.readyForFlush()); + } + } + + @SuppressWarnings({"unchecked", "varargs"}) + public void testWillPollForFlushOpsToClose() throws Exception { + try (SocketChannel realChannel = SocketChannel.open()) { + when(channel.getRawChannel()).thenReturn(realChannel); + InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance(); + context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); + + + ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; + BiConsumer listener2 = mock(BiConsumer.class); + + assertFalse(context.readyForFlush()); + when(channel.isOpen()).thenReturn(true); + when(readWriteHandler.pollFlushOperations()).thenReturn(Arrays.asList(new FlushOperation(buffer, listener), + new FlushOperation(buffer, listener2))); + context.closeFromSelector(); + + verify(selector, times(1)).executeFailedListener(same(listener), any(ClosedChannelException.class)); + verify(selector, times(1)).executeFailedListener(same(listener2), any(ClosedChannelException.class)); + + assertFalse(context.readyForFlush()); + } + } + + public void testCloseClosesWriteProducer() throws IOException { + try (SocketChannel realChannel = SocketChannel.open()) { + when(channel.getRawChannel()).thenReturn(realChannel); + when(channel.isOpen()).thenReturn(true); + InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance(); + BytesChannelContext context = new BytesChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer); + context.closeFromSelector(); + verify(readWriteHandler).close(); + } + } + + @SuppressWarnings("unchecked") + public void testCloseClosesChannelBuffer() throws IOException { + try (SocketChannel realChannel = SocketChannel.open()) { + when(channel.getRawChannel()).thenReturn(realChannel); + when(channel.isOpen()).thenReturn(true); + Runnable closer = mock(Runnable.class); + Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer); + InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); + buffer.ensureCapacity(1); + TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer); + context.closeFromSelector(); + verify(closer).run(); + } + } + private static class TestSocketChannelContext extends SocketChannelContext { - private TestSocketChannelContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler) { - super(channel, selector, exceptionHandler); + private TestSocketChannelContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler, + ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer) { + super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); } @Override @@ -135,16 +290,6 @@ public int read() throws IOException { } } - @Override - public void sendMessage(ByteBuffer[] buffers, BiConsumer listener) { - - } - - @Override - public void queueWriteOperation(WriteOperation writeOperation) { - - } - @Override public void flushChannel() throws IOException { if (randomBoolean()) { @@ -155,11 +300,6 @@ public void flushChannel() throws IOException { } } - @Override - public boolean hasQueuedWriteOps() { - return false; - } - @Override public boolean selectorShouldClose() { return false; @@ -167,7 +307,15 @@ public boolean selectorShouldClose() { @Override public void closeChannel() { + isClosing.set(true); + } + } + private static byte[] createMessage(int length) { + byte[] bytes = new byte[length]; + for (int i = 0; i < length; ++i) { + bytes[i] = randomByte(); } + return bytes; } } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java index 4f476c1ff6b22..a80563f7d74db 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java @@ -26,6 +26,7 @@ import java.nio.channels.CancelledKeyException; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; +import java.util.Collections; import java.util.function.Consumer; import static org.mockito.Mockito.mock; @@ -37,6 +38,7 @@ public class SocketEventHandlerTests extends ESTestCase { private Consumer exceptionHandler; + private ReadWriteHandler readWriteHandler; private SocketEventHandler handler; private NioSocketChannel channel; private SocketChannel rawChannel; @@ -46,13 +48,14 @@ public class SocketEventHandlerTests extends ESTestCase { @SuppressWarnings("unchecked") public void setUpHandler() throws IOException { exceptionHandler = mock(Consumer.class); + readWriteHandler = mock(ReadWriteHandler.class); SocketSelector selector = mock(SocketSelector.class); handler = new SocketEventHandler(logger); rawChannel = mock(SocketChannel.class); channel = new NioSocketChannel(rawChannel); when(rawChannel.finishConnect()).thenReturn(true); - context = new DoNotRegisterContext(channel, selector, exceptionHandler, new TestSelectionKey(0)); + context = new DoNotRegisterContext(channel, selector, exceptionHandler, new TestSelectionKey(0), readWriteHandler); channel.setContext(context); handler.handleRegistration(context); @@ -83,7 +86,9 @@ public void testRegisterAddsAttachment() throws IOException { } public void testRegisterWithPendingWritesAddsOP_CONNECTAndOP_READAndOP_WRITEInterest() throws IOException { - channel.getContext().queueWriteOperation(mock(BytesWriteOperation.class)); + FlushReadyWrite flushReadyWrite = mock(FlushReadyWrite.class); + when(readWriteHandler.writeToBytes(flushReadyWrite)).thenReturn(Collections.singletonList(flushReadyWrite)); + channel.getContext().queueWriteOperation(flushReadyWrite); handler.handleRegistration(context); assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT | SelectionKey.OP_WRITE, context.getSelectionKey().interestOps()); } @@ -162,7 +167,7 @@ public void testPostHandlingWillAddWriteIfNecessary() throws IOException { TestSelectionKey selectionKey = new TestSelectionKey(SelectionKey.OP_READ); SocketChannelContext context = mock(SocketChannelContext.class); when(context.getSelectionKey()).thenReturn(selectionKey); - when(context.hasQueuedWriteOps()).thenReturn(true); + when(context.readyForFlush()).thenReturn(true); NioSocketChannel channel = mock(NioSocketChannel.class); when(channel.getContext()).thenReturn(context); @@ -176,7 +181,7 @@ public void testPostHandlingWillRemoveWriteIfNecessary() throws IOException { TestSelectionKey key = new TestSelectionKey(SelectionKey.OP_READ | SelectionKey.OP_WRITE); SocketChannelContext context = mock(SocketChannelContext.class); when(context.getSelectionKey()).thenReturn(key); - when(context.hasQueuedWriteOps()).thenReturn(false); + when(context.readyForFlush()).thenReturn(false); NioSocketChannel channel = mock(NioSocketChannel.class); when(channel.getContext()).thenReturn(context); @@ -192,8 +197,8 @@ private class DoNotRegisterContext extends BytesChannelContext { private final TestSelectionKey selectionKey; DoNotRegisterContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler, - TestSelectionKey selectionKey) { - super(channel, selector, exceptionHandler, mock(ReadConsumer.class), InboundChannelBuffer.allocatingInstance()); + TestSelectionKey selectionKey, ReadWriteHandler handler) { + super(channel, selector, exceptionHandler, handler, InboundChannelBuffer.allocatingInstance()); this.selectionKey = selectionKey; } diff --git a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java index 223f14455f96d..a68f5c05dad5a 100644 --- a/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java +++ b/libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java @@ -117,13 +117,13 @@ public void testSuccessfullyRegisterChannelWillAttemptConnect() throws Exception public void testQueueWriteWhenNotRunning() throws Exception { socketSelector.close(); - socketSelector.queueWrite(new BytesWriteOperation(channelContext, buffers, listener)); + socketSelector.queueWrite(new FlushReadyWrite(channelContext, buffers, listener)); verify(listener).accept(isNull(Void.class), any(ClosedSelectorException.class)); } public void testQueueWriteChannelIsClosed() throws Exception { - BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); + WriteOperation writeOperation = new FlushReadyWrite(channelContext, buffers, listener); socketSelector.queueWrite(writeOperation); when(channelContext.isOpen()).thenReturn(false); @@ -136,7 +136,7 @@ public void testQueueWriteChannelIsClosed() throws Exception { public void testQueueWriteSelectionKeyThrowsException() throws Exception { SelectionKey selectionKey = mock(SelectionKey.class); - BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); + WriteOperation writeOperation = new FlushReadyWrite(channelContext, buffers, listener); CancelledKeyException cancelledKeyException = new CancelledKeyException(); socketSelector.queueWrite(writeOperation); @@ -149,7 +149,7 @@ public void testQueueWriteSelectionKeyThrowsException() throws Exception { } public void testQueueWriteSuccessful() throws Exception { - BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); + WriteOperation writeOperation = new FlushReadyWrite(channelContext, buffers, listener); socketSelector.queueWrite(writeOperation); assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0); @@ -161,7 +161,7 @@ public void testQueueWriteSuccessful() throws Exception { } public void testQueueDirectlyInChannelBufferSuccessful() throws Exception { - BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); + WriteOperation writeOperation = new FlushReadyWrite(channelContext, buffers, listener); assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0); @@ -174,7 +174,7 @@ public void testQueueDirectlyInChannelBufferSuccessful() throws Exception { public void testQueueDirectlyInChannelBufferSelectionKeyThrowsException() throws Exception { SelectionKey selectionKey = mock(SelectionKey.class); - BytesWriteOperation writeOperation = new BytesWriteOperation(channelContext, buffers, listener); + WriteOperation writeOperation = new FlushReadyWrite(channelContext, buffers, listener); CancelledKeyException cancelledKeyException = new CancelledKeyException(); when(channelContext.getSelectionKey()).thenReturn(selectionKey); @@ -277,7 +277,7 @@ public void testCleanup() throws Exception { socketSelector.preSelect(); - socketSelector.queueWrite(new BytesWriteOperation(channelContext, buffers, listener)); + socketSelector.queueWrite(new FlushReadyWrite(channelContext, buffers, listener)); socketSelector.scheduleForRegistration(unregisteredChannel); TestSelectionKey testSelectionKey = new TestSelectionKey(0); diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java index 12db47908d1f3..6e39a7f50d2cd 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java @@ -40,6 +40,7 @@ import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.netty4.cors.Netty4CorsHandler; import org.elasticsearch.http.netty4.pipelining.HttpPipelinedRequest; import org.elasticsearch.rest.AbstractRestChannel; @@ -60,27 +61,29 @@ final class Netty4HttpChannel extends AbstractRestChannel { private final FullHttpRequest nettyRequest; private final HttpPipelinedRequest pipelinedRequest; private final ThreadContext threadContext; + private final HttpHandlingSettings handlingSettings; /** * @param transport The corresponding NettyHttpServerTransport where this channel belongs to. * @param request The request that is handled by this channel. * @param pipelinedRequest If HTTP pipelining is enabled provide the corresponding pipelined request. May be null if - * HTTP pipelining is disabled. - * @param detailedErrorsEnabled true iff error messages should include stack traces. + * HTTP pipelining is disabled. + * @param handlingSettings true iff error messages should include stack traces. * @param threadContext the thread context for the channel */ Netty4HttpChannel( final Netty4HttpServerTransport transport, final Netty4HttpRequest request, final HttpPipelinedRequest pipelinedRequest, - final boolean detailedErrorsEnabled, + final HttpHandlingSettings handlingSettings, final ThreadContext threadContext) { - super(request, detailedErrorsEnabled); + super(request, handlingSettings.getDetailedErrorsEnabled()); this.transport = transport; this.channel = request.getChannel(); this.nettyRequest = request.request(); this.pipelinedRequest = pipelinedRequest; this.threadContext = threadContext; + this.handlingSettings = handlingSettings; } @Override @@ -170,7 +173,7 @@ private void setHeaderField(HttpResponse resp, String headerField, String value, } private void addCookies(HttpResponse resp) { - if (transport.resetCookies) { + if (handlingSettings.isResetCookies()) { String cookieString = nettyRequest.headers().get(HttpHeaderNames.COOKIE); if (cookieString != null) { Set cookies = ServerCookieDecoder.STRICT.decode(cookieString); @@ -222,8 +225,6 @@ private FullHttpResponse newResponse(ByteBuf buffer) { return response; } - private static final HttpResponseStatus TOO_MANY_REQUESTS = new HttpResponseStatus(429, "Too Many Requests"); - private static Map MAP; static { @@ -266,7 +267,7 @@ private FullHttpResponse newResponse(ByteBuf buffer) { map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST); map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST); map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.TOO_MANY_REQUESTS, TOO_MANY_REQUESTS); + map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS); map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR); map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED); map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY); @@ -279,5 +280,4 @@ private FullHttpResponse newResponse(ByteBuf buffer) { private static HttpResponseStatus getStatus(RestStatus status) { return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR); } - } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java index 1fd18b2a016d7..74429c8dda9b7 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java @@ -29,6 +29,7 @@ import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaders; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.netty4.pipelining.HttpPipelinedRequest; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.transport.netty4.Netty4Utils; @@ -39,14 +40,15 @@ class Netty4HttpRequestHandler extends SimpleChannelInboundHandler { private final Netty4HttpServerTransport serverTransport; + private final HttpHandlingSettings handlingSettings; private final boolean httpPipeliningEnabled; - private final boolean detailedErrorsEnabled; private final ThreadContext threadContext; - Netty4HttpRequestHandler(Netty4HttpServerTransport serverTransport, boolean detailedErrorsEnabled, ThreadContext threadContext) { + Netty4HttpRequestHandler(Netty4HttpServerTransport serverTransport, HttpHandlingSettings handlingSettings, + ThreadContext threadContext) { this.serverTransport = serverTransport; this.httpPipeliningEnabled = serverTransport.pipelining; - this.detailedErrorsEnabled = detailedErrorsEnabled; + this.handlingSettings = handlingSettings; this.threadContext = threadContext; } @@ -109,7 +111,7 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Except Netty4HttpChannel innerChannel; try { innerChannel = - new Netty4HttpChannel(serverTransport, httpRequest, pipelinedRequest, detailedErrorsEnabled, threadContext); + new Netty4HttpChannel(serverTransport, httpRequest, pipelinedRequest, handlingSettings, threadContext); } catch (final IllegalArgumentException e) { if (badRequestCause == null) { badRequestCause = e; @@ -124,7 +126,7 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Except copy, ctx.channel()); innerChannel = - new Netty4HttpChannel(serverTransport, innerRequest, pipelinedRequest, detailedErrorsEnabled, threadContext); + new Netty4HttpChannel(serverTransport, innerRequest, pipelinedRequest, handlingSettings, threadContext); } channel = innerChannel; } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java index c8c2c4829d2cf..8e5bace46aa7e 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java @@ -19,8 +19,6 @@ package org.elasticsearch.http.netty4; -import com.carrotsearch.hppc.IntHashSet; -import com.carrotsearch.hppc.IntSet; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; @@ -44,15 +42,12 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.util.Supplier; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.network.NetworkAddress; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.transport.BoundTransportAddress; import org.elasticsearch.common.transport.NetworkExceptionHelper; -import org.elasticsearch.common.transport.PortsRange; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; @@ -62,18 +57,14 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.http.BindHttpException; -import org.elasticsearch.http.HttpInfo; -import org.elasticsearch.http.HttpServerTransport; +import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.HttpStats; import org.elasticsearch.http.netty4.cors.Netty4CorsConfig; import org.elasticsearch.http.netty4.cors.Netty4CorsConfigBuilder; import org.elasticsearch.http.netty4.cors.Netty4CorsHandler; import org.elasticsearch.http.netty4.pipelining.HttpPipeliningHandler; -import org.elasticsearch.rest.RestChannel; -import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestUtils; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.BindTransportException; import org.elasticsearch.transport.netty4.Netty4OpenChannelsHandler; import org.elasticsearch.transport.netty4.Netty4Utils; @@ -94,7 +85,6 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_MAX_AGE; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_BIND_HOST; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION_LEVEL; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_DETAILED_ERRORS_ENABLED; @@ -102,9 +92,6 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEADER_SIZE; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_INITIAL_LINE_LENGTH; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PORT; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_HOST; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_PORT; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_READ_TIMEOUT; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_RESET_COOKIES; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_TCP_KEEP_ALIVE; @@ -116,7 +103,7 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_PIPELINING_MAX_EVENTS; import static org.elasticsearch.http.netty4.cors.Netty4CorsHandler.ANY_ORIGIN; -public class Netty4HttpServerTransport extends AbstractLifecycleComponent implements HttpServerTransport { +public class Netty4HttpServerTransport extends AbstractHttpServerTransport { static { Netty4Utils.setup(); @@ -167,11 +154,8 @@ public class Netty4HttpServerTransport extends AbstractLifecycleComponent implem public static final Setting SETTING_HTTP_NETTY_RECEIVE_PREDICTOR_SIZE = Setting.byteSizeSetting("http.netty.receive_predictor_size", new ByteSizeValue(64, ByteSizeUnit.KB), Property.NodeScope); - - protected final NetworkService networkService; protected final BigArrays bigArrays; - protected final ByteSizeValue maxContentLength; protected final ByteSizeValue maxInitialLineLength; protected final ByteSizeValue maxHeaderSize; protected final ByteSizeValue maxChunkSize; @@ -182,20 +166,6 @@ public class Netty4HttpServerTransport extends AbstractLifecycleComponent implem protected final int pipeliningMaxEvents; - protected final boolean compression; - - protected final int compressionLevel; - - protected final boolean resetCookies; - - protected final PortsRange port; - - protected final String bindHosts[]; - - protected final String publishHosts[]; - - protected final boolean detailedErrorsEnabled; - protected final ThreadPool threadPool; /** * The registry used to construct parsers so they support {@link XContentParser#namedObject(Class, String, Object)}. */ @@ -211,14 +181,13 @@ public class Netty4HttpServerTransport extends AbstractLifecycleComponent implem private final int readTimeoutMillis; protected final int maxCompositeBufferComponents; - private final Dispatcher dispatcher; protected volatile ServerBootstrap serverBootstrap; - protected volatile BoundTransportAddress boundAddress; - protected final List serverChannels = new ArrayList<>(); + protected final HttpHandlingSettings httpHandlingSettings; + // package private for testing Netty4OpenChannelsHandler serverOpenChannels; @@ -227,49 +196,40 @@ public class Netty4HttpServerTransport extends AbstractLifecycleComponent implem public Netty4HttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, Dispatcher dispatcher) { - super(settings); + super(settings, networkService, threadPool, dispatcher); Netty4Utils.setAvailableProcessors(EsExecutors.PROCESSORS_SETTING.get(settings)); - this.networkService = networkService; this.bigArrays = bigArrays; - this.threadPool = threadPool; this.xContentRegistry = xContentRegistry; - this.dispatcher = dispatcher; - ByteSizeValue maxContentLength = SETTING_HTTP_MAX_CONTENT_LENGTH.get(settings); this.maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings); this.maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings); this.maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.get(settings); - this.resetCookies = SETTING_HTTP_RESET_COOKIES.get(settings); + this.httpHandlingSettings = new HttpHandlingSettings(Math.toIntExact(maxContentLength.getBytes()), + Math.toIntExact(maxChunkSize.getBytes()), + Math.toIntExact(maxHeaderSize.getBytes()), + Math.toIntExact(maxInitialLineLength.getBytes()), + SETTING_HTTP_RESET_COOKIES.get(settings), + SETTING_HTTP_COMPRESSION.get(settings), + SETTING_HTTP_COMPRESSION_LEVEL.get(settings), + SETTING_HTTP_DETAILED_ERRORS_ENABLED.get(settings)); + this.maxCompositeBufferComponents = SETTING_HTTP_NETTY_MAX_COMPOSITE_BUFFER_COMPONENTS.get(settings); this.workerCount = SETTING_HTTP_WORKER_COUNT.get(settings); - this.port = SETTING_HTTP_PORT.get(settings); - // we can't make the network.bind_host a fallback since we already fall back to http.host hence the extra conditional here - List httpBindHost = SETTING_HTTP_BIND_HOST.get(settings); - this.bindHosts = (httpBindHost.isEmpty() ? NetworkService.GLOBAL_NETWORK_BINDHOST_SETTING.get(settings) : httpBindHost) - .toArray(Strings.EMPTY_ARRAY); - // we can't make the network.publish_host a fallback since we already fall back to http.host hence the extra conditional here - List httpPublishHost = SETTING_HTTP_PUBLISH_HOST.get(settings); - this.publishHosts = (httpPublishHost.isEmpty() ? NetworkService.GLOBAL_NETWORK_PUBLISHHOST_SETTING.get(settings) : httpPublishHost) - .toArray(Strings.EMPTY_ARRAY); + this.tcpNoDelay = SETTING_HTTP_TCP_NO_DELAY.get(settings); this.tcpKeepAlive = SETTING_HTTP_TCP_KEEP_ALIVE.get(settings); this.reuseAddress = SETTING_HTTP_TCP_REUSE_ADDRESS.get(settings); this.tcpSendBufferSize = SETTING_HTTP_TCP_SEND_BUFFER_SIZE.get(settings); this.tcpReceiveBufferSize = SETTING_HTTP_TCP_RECEIVE_BUFFER_SIZE.get(settings); - this.detailedErrorsEnabled = SETTING_HTTP_DETAILED_ERRORS_ENABLED.get(settings); this.readTimeoutMillis = Math.toIntExact(SETTING_HTTP_READ_TIMEOUT.get(settings).getMillis()); ByteSizeValue receivePredictor = SETTING_HTTP_NETTY_RECEIVE_PREDICTOR_SIZE.get(settings); recvByteBufAllocator = new FixedRecvByteBufAllocator(receivePredictor.bytesAsInt()); - this.compression = SETTING_HTTP_COMPRESSION.get(settings); - this.compressionLevel = SETTING_HTTP_COMPRESSION_LEVEL.get(settings); this.pipelining = SETTING_PIPELINING.get(settings); this.pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings); this.corsConfig = buildCorsConfig(settings); - this.maxContentLength = maxContentLength; - logger.debug("using max_chunk_size[{}], max_header_size[{}], max_initial_line_length[{}], max_content_length[{}], " + "receive_predictor[{}], max_composite_buffer_components[{}], pipelining[{}], pipelining_max_events[{}]", maxChunkSize, maxHeaderSize, maxInitialLineLength, this.maxContentLength, receivePredictor, maxCompositeBufferComponents, @@ -326,65 +286,6 @@ protected void doStart() { } } - private BoundTransportAddress createBoundHttpAddress() { - // Bind and start to accept incoming connections. - InetAddress hostAddresses[]; - try { - hostAddresses = networkService.resolveBindHostAddresses(bindHosts); - } catch (IOException e) { - throw new BindHttpException("Failed to resolve host [" + Arrays.toString(bindHosts) + "]", e); - } - - List boundAddresses = new ArrayList<>(hostAddresses.length); - for (InetAddress address : hostAddresses) { - boundAddresses.add(bindAddress(address)); - } - - final InetAddress publishInetAddress; - try { - publishInetAddress = networkService.resolvePublishHostAddresses(publishHosts); - } catch (Exception e) { - throw new BindTransportException("Failed to resolve publish address", e); - } - - final int publishPort = resolvePublishPort(settings, boundAddresses, publishInetAddress); - final InetSocketAddress publishAddress = new InetSocketAddress(publishInetAddress, publishPort); - return new BoundTransportAddress(boundAddresses.toArray(new TransportAddress[0]), new TransportAddress(publishAddress)); - } - - // package private for tests - static int resolvePublishPort(Settings settings, List boundAddresses, InetAddress publishInetAddress) { - int publishPort = SETTING_HTTP_PUBLISH_PORT.get(settings); - - if (publishPort < 0) { - for (TransportAddress boundAddress : boundAddresses) { - InetAddress boundInetAddress = boundAddress.address().getAddress(); - if (boundInetAddress.isAnyLocalAddress() || boundInetAddress.equals(publishInetAddress)) { - publishPort = boundAddress.getPort(); - break; - } - } - } - - // if no matching boundAddress found, check if there is a unique port for all bound addresses - if (publishPort < 0) { - final IntSet ports = new IntHashSet(); - for (TransportAddress boundAddress : boundAddresses) { - ports.add(boundAddress.getPort()); - } - if (ports.size() == 1) { - publishPort = ports.iterator().next().value; - } - } - - if (publishPort < 0) { - throw new BindHttpException("Failed to auto-resolve http publish port, multiple bound addresses " + boundAddresses + - " with distinct ports and none of them matched the publish address (" + publishInetAddress + "). " + - "Please specify a unique port by setting " + SETTING_HTTP_PORT.getKey() + " or " + SETTING_HTTP_PUBLISH_PORT.getKey()); - } - return publishPort; - } - // package private for testing static Netty4CorsConfig buildCorsConfig(Settings settings) { if (SETTING_CORS_ENABLED.get(settings) == false) { @@ -419,7 +320,8 @@ static Netty4CorsConfig buildCorsConfig(Settings settings) { .build(); } - private TransportAddress bindAddress(final InetAddress hostAddress) { + @Override + protected TransportAddress bindAddress(final InetAddress hostAddress) { final AtomicReference lastException = new AtomicReference<>(); final AtomicReference boundSocket = new AtomicReference<>(); boolean success = port.iterate(portNumber -> { @@ -473,20 +375,6 @@ protected void doStop() { protected void doClose() { } - @Override - public BoundTransportAddress boundAddress() { - return this.boundAddress; - } - - @Override - public HttpInfo info() { - BoundTransportAddress boundTransportAddress = boundAddress(); - if (boundTransportAddress == null) { - return null; - } - return new HttpInfo(boundTransportAddress, maxContentLength.getBytes()); - } - @Override public HttpStats stats() { Netty4OpenChannelsHandler channels = serverOpenChannels; @@ -497,20 +385,6 @@ public Netty4CorsConfig getCorsConfig() { return corsConfig; } - void dispatchRequest(final RestRequest request, final RestChannel channel) { - final ThreadContext threadContext = threadPool.getThreadContext(); - try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - dispatcher.dispatchRequest(request, channel, threadContext); - } - } - - void dispatchBadRequest(final RestRequest request, final RestChannel channel, final Throwable cause) { - final ThreadContext threadContext = threadPool.getThreadContext(); - try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - dispatcher.dispatchBadRequest(request, channel, threadContext, cause); - } - } - protected void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { if (cause instanceof ReadTimeoutException) { if (logger.isTraceEnabled()) { @@ -539,20 +413,22 @@ protected void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throw } public ChannelHandler configureServerChannelHandler() { - return new HttpChannelHandler(this, detailedErrorsEnabled, threadPool.getThreadContext()); + return new HttpChannelHandler(this, httpHandlingSettings, threadPool.getThreadContext()); } protected static class HttpChannelHandler extends ChannelInitializer { private final Netty4HttpServerTransport transport; private final Netty4HttpRequestHandler requestHandler; + private final HttpHandlingSettings handlingSettings; protected HttpChannelHandler( final Netty4HttpServerTransport transport, - final boolean detailedErrorsEnabled, + final HttpHandlingSettings handlingSettings, final ThreadContext threadContext) { this.transport = transport; - this.requestHandler = new Netty4HttpRequestHandler(transport, detailedErrorsEnabled, threadContext); + this.handlingSettings = handlingSettings; + this.requestHandler = new Netty4HttpRequestHandler(transport, handlingSettings, threadContext); } @Override @@ -560,18 +436,18 @@ protected void initChannel(Channel ch) throws Exception { ch.pipeline().addLast("openChannels", transport.serverOpenChannels); ch.pipeline().addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS)); final HttpRequestDecoder decoder = new HttpRequestDecoder( - Math.toIntExact(transport.maxInitialLineLength.getBytes()), - Math.toIntExact(transport.maxHeaderSize.getBytes()), - Math.toIntExact(transport.maxChunkSize.getBytes())); + handlingSettings.getMaxInitialLineLength(), + handlingSettings.getMaxHeaderSize(), + handlingSettings.getMaxChunkSize()); decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR); ch.pipeline().addLast("decoder", decoder); ch.pipeline().addLast("decoder_compress", new HttpContentDecompressor()); ch.pipeline().addLast("encoder", new HttpResponseEncoder()); - final HttpObjectAggregator aggregator = new HttpObjectAggregator(Math.toIntExact(transport.maxContentLength.getBytes())); + final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength()); aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents); ch.pipeline().addLast("aggregator", aggregator); - if (transport.compression) { - ch.pipeline().addLast("encoder_compress", new HttpContentCompressor(transport.compressionLevel)); + if (handlingSettings.isCompression()) { + ch.pipeline().addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); } if (SETTING_CORS_ENABLED.get(transport.settings())) { ch.pipeline().addLast("cors", new Netty4CorsHandler(transport.getCorsConfig())); @@ -587,7 +463,6 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E Netty4Utils.maybeDie(cause); super.exceptionCaught(ctx, cause); } - } } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java index 918e98fd2e7c0..0ef1ea585b11c 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java @@ -56,6 +56,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.HttpTransportSettings; import org.elasticsearch.http.NullDispatcher; import org.elasticsearch.http.netty4.cors.Netty4CorsHandler; @@ -212,10 +213,11 @@ public void testHeadersSet() { httpRequest.headers().add(HttpHeaderNames.ORIGIN, "remote"); final WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel(); Netty4HttpRequest request = new Netty4HttpRequest(xContentRegistry(), httpRequest, writeCapturingChannel); + HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; // send a response Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, null, randomBoolean(), threadPool.getThreadContext()); + new Netty4HttpChannel(httpServerTransport, request, null, handlingSettings, threadPool.getThreadContext()); TestResponse resp = new TestResponse(); final String customHeader = "custom-header"; final String customHeaderValue = "xyz"; @@ -242,8 +244,9 @@ public void testReleaseOnSendToClosedChannel() { final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); final Netty4HttpRequest request = new Netty4HttpRequest(registry, httpRequest, embeddedChannel); final HttpPipelinedRequest pipelinedRequest = randomBoolean() ? new HttpPipelinedRequest(request.request(), 1) : null; + HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; final Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, pipelinedRequest, randomBoolean(), threadPool.getThreadContext()); + new Netty4HttpChannel(httpServerTransport, request, pipelinedRequest, handlingSettings, threadPool.getThreadContext()); final TestResponse response = new TestResponse(bigArrays); assertThat(response.content(), instanceOf(Releasable.class)); embeddedChannel.close(); @@ -261,8 +264,9 @@ public void testReleaseOnSendToChannelAfterException() throws IOException { final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); final Netty4HttpRequest request = new Netty4HttpRequest(registry, httpRequest, embeddedChannel); final HttpPipelinedRequest pipelinedRequest = randomBoolean() ? new HttpPipelinedRequest(request.request(), 1) : null; + HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; final Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, pipelinedRequest, randomBoolean(), threadPool.getThreadContext()); + new Netty4HttpChannel(httpServerTransport, request, pipelinedRequest, handlingSettings, threadPool.getThreadContext()); final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, JsonXContent.contentBuilder().startObject().endObject()); assertThat(response.content(), not(instanceOf(Releasable.class))); @@ -306,8 +310,9 @@ public void testConnectionClose() throws Exception { // send a response, the channel close status should match assertTrue(embeddedChannel.isOpen()); + HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; final Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, null, randomBoolean(), threadPool.getThreadContext()); + new Netty4HttpChannel(httpServerTransport, request, null, handlingSettings, threadPool.getThreadContext()); final TestResponse resp = new TestResponse(); channel.sendResponse(resp); assertThat(embeddedChannel.isOpen(), equalTo(!close)); @@ -332,9 +337,10 @@ private FullHttpResponse executeRequest(final Settings settings, final String or final WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel(); final Netty4HttpRequest request = new Netty4HttpRequest(xContentRegistry(), httpRequest, writeCapturingChannel); + HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, null, randomBoolean(), threadPool.getThreadContext()); + new Netty4HttpChannel(httpServerTransport, request, null, handlingSettings, threadPool.getThreadContext()); channel.sendResponse(new TestResponse()); // get the response diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java index 91a5465f6a764..0eb14a8a76e9b 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java @@ -184,7 +184,7 @@ private class CustomHttpChannelHandler extends Netty4HttpServerTransport.HttpCha private final ExecutorService executorService; CustomHttpChannelHandler(Netty4HttpServerTransport transport, ExecutorService executorService, ThreadContext threadContext) { - super(transport, randomBoolean(), threadContext); + super(transport, transport.httpHandlingSettings, threadContext); this.executorService = executorService; } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/rest/Netty4BadRequestIT.java b/modules/transport-netty4/src/test/java/org/elasticsearch/rest/Netty4BadRequestIT.java index 028770ed22469..bc89558d3c6dc 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/rest/Netty4BadRequestIT.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/rest/Netty4BadRequestIT.java @@ -99,5 +99,4 @@ public void testInvalidHeaderValue() throws IOException { assertThat(map.get("type"), equalTo("content_type_header_exception")); assertThat(map.get("reason"), equalTo("java.lang.IllegalArgumentException: invalid Content-Type header []")); } - } diff --git a/plugins/transport-nio/build.gradle b/plugins/transport-nio/build.gradle index 60fef4b34241d..e278ebf47983e 100644 --- a/plugins/transport-nio/build.gradle +++ b/plugins/transport-nio/build.gradle @@ -29,4 +29,115 @@ compileTestJava.options.compilerArgs << "-Xlint:-rawtypes,-unchecked" dependencies { compile "org.elasticsearch:elasticsearch-nio:${version}" -} \ No newline at end of file + + // network stack + compile "io.netty:netty-buffer:4.1.16.Final" + compile "io.netty:netty-codec:4.1.16.Final" + compile "io.netty:netty-codec-http:4.1.16.Final" + compile "io.netty:netty-common:4.1.16.Final" + compile "io.netty:netty-handler:4.1.16.Final" + compile "io.netty:netty-resolver:4.1.16.Final" + compile "io.netty:netty-transport:4.1.16.Final" +} + +thirdPartyAudit.excludes = [ + // classes are missing + + // from io.netty.handler.codec.protobuf.ProtobufDecoder (netty) + 'com.google.protobuf.ExtensionRegistry', + 'com.google.protobuf.MessageLite$Builder', + 'com.google.protobuf.MessageLite', + 'com.google.protobuf.Parser', + + // from io.netty.logging.CommonsLoggerFactory (netty) + 'org.apache.commons.logging.Log', + 'org.apache.commons.logging.LogFactory', + + // from io.netty.handler.ssl.OpenSslEngine (netty) + 'io.netty.internal.tcnative.Buffer', + 'io.netty.internal.tcnative.Library', + 'io.netty.internal.tcnative.SSL', + 'io.netty.internal.tcnative.SSLContext', + + // from io.netty.handler.ssl.util.BouncyCastleSelfSignedCertGenerator (netty) + 'org.bouncycastle.asn1.x500.X500Name', + 'org.bouncycastle.cert.X509v3CertificateBuilder', + 'org.bouncycastle.cert.jcajce.JcaX509CertificateConverter', + 'org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder', + 'org.bouncycastle.jce.provider.BouncyCastleProvider', + 'org.bouncycastle.operator.jcajce.JcaContentSignerBuilder', + + // from io.netty.handler.ssl.JettyNpnSslEngine (netty) + 'org.eclipse.jetty.npn.NextProtoNego$ClientProvider', + 'org.eclipse.jetty.npn.NextProtoNego$ServerProvider', + 'org.eclipse.jetty.npn.NextProtoNego', + + // from io.netty.handler.codec.marshalling.ChannelBufferByteInput (netty) + 'org.jboss.marshalling.ByteInput', + + // from io.netty.handler.codec.marshalling.ChannelBufferByteOutput (netty) + 'org.jboss.marshalling.ByteOutput', + + // from io.netty.handler.codec.marshalling.CompatibleMarshallingEncoder (netty) + 'org.jboss.marshalling.Marshaller', + + // from io.netty.handler.codec.marshalling.ContextBoundUnmarshallerProvider (netty) + 'org.jboss.marshalling.MarshallerFactory', + 'org.jboss.marshalling.MarshallingConfiguration', + 'org.jboss.marshalling.Unmarshaller', + + // from io.netty.util.internal.logging.InternalLoggerFactory (netty) - it's optional + 'org.slf4j.Logger', + 'org.slf4j.LoggerFactory', + + 'com.google.protobuf.ExtensionRegistryLite', + 'com.google.protobuf.MessageLiteOrBuilder', + 'com.google.protobuf.nano.CodedOutputByteBufferNano', + 'com.google.protobuf.nano.MessageNano', + 'com.jcraft.jzlib.Deflater', + 'com.jcraft.jzlib.Inflater', + 'com.jcraft.jzlib.JZlib$WrapperType', + 'com.jcraft.jzlib.JZlib', + 'com.ning.compress.BufferRecycler', + 'com.ning.compress.lzf.ChunkDecoder', + 'com.ning.compress.lzf.ChunkEncoder', + 'com.ning.compress.lzf.LZFEncoder', + 'com.ning.compress.lzf.util.ChunkDecoderFactory', + 'com.ning.compress.lzf.util.ChunkEncoderFactory', + 'lzma.sdk.lzma.Encoder', + 'net.jpountz.lz4.LZ4Compressor', + 'net.jpountz.lz4.LZ4Factory', + 'net.jpountz.lz4.LZ4FastDecompressor', + 'net.jpountz.xxhash.StreamingXXHash32', + 'net.jpountz.xxhash.XXHashFactory', + 'io.netty.internal.tcnative.CertificateRequestedCallback', + 'io.netty.internal.tcnative.CertificateRequestedCallback$KeyMaterial', + 'io.netty.internal.tcnative.CertificateVerifier', + 'io.netty.internal.tcnative.SessionTicketKey', + 'io.netty.internal.tcnative.SniHostNameMatcher', + 'org.eclipse.jetty.alpn.ALPN$ClientProvider', + 'org.eclipse.jetty.alpn.ALPN$ServerProvider', + 'org.eclipse.jetty.alpn.ALPN', + + 'io.netty.handler.ssl.util.OpenJdkSelfSignedCertGenerator', + 'io.netty.util.internal.PlatformDependent0', + 'io.netty.util.internal.PlatformDependent0$1', + 'io.netty.util.internal.PlatformDependent0$2', + 'io.netty.util.internal.PlatformDependent0$3', + 'io.netty.util.internal.shaded.org.jctools.queues.BaseLinkedQueueConsumerNodeRef', + 'io.netty.util.internal.shaded.org.jctools.queues.BaseLinkedQueueProducerNodeRef', + 'io.netty.util.internal.shaded.org.jctools.queues.BaseMpscLinkedArrayQueueColdProducerFields', + 'io.netty.util.internal.shaded.org.jctools.queues.BaseMpscLinkedArrayQueueConsumerFields', + 'io.netty.util.internal.shaded.org.jctools.queues.BaseMpscLinkedArrayQueueProducerFields', + 'io.netty.util.internal.shaded.org.jctools.queues.LinkedQueueNode', + 'io.netty.util.internal.shaded.org.jctools.queues.MpscArrayQueueConsumerIndexField', + 'io.netty.util.internal.shaded.org.jctools.queues.MpscArrayQueueProducerIndexField', + 'io.netty.util.internal.shaded.org.jctools.queues.MpscArrayQueueProducerLimitField', + 'io.netty.util.internal.shaded.org.jctools.util.UnsafeAccess', + 'io.netty.util.internal.shaded.org.jctools.util.UnsafeRefArrayAccess', + + 'org.conscrypt.AllocatedBuffer', + 'org.conscrypt.BufferAllocator', + 'org.conscrypt.Conscrypt$Engines', + 'org.conscrypt.HandshakeListener' +] \ No newline at end of file diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/ByteBufUtils.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/ByteBufUtils.java new file mode 100644 index 0000000000000..b4108b3e6c7d0 --- /dev/null +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/ByteBufUtils.java @@ -0,0 +1,252 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.http.nio; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefIterator; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.StreamInput; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +class ByteBufUtils { + + /** + * Turns the given BytesReference into a ByteBuf. Note: the returned ByteBuf will reference the internal + * pages of the BytesReference. Don't free the bytes of reference before the ByteBuf goes out of scope. + */ + static ByteBuf toByteBuf(final BytesReference reference) { + if (reference.length() == 0) { + return Unpooled.EMPTY_BUFFER; + } + if (reference instanceof ByteBufBytesReference) { + return ((ByteBufBytesReference) reference).toByteBuf(); + } else { + final BytesRefIterator iterator = reference.iterator(); + // usually we have one, two, or three components from the header, the message, and a buffer + final List buffers = new ArrayList<>(3); + try { + BytesRef slice; + while ((slice = iterator.next()) != null) { + buffers.add(Unpooled.wrappedBuffer(slice.bytes, slice.offset, slice.length)); + } + final CompositeByteBuf composite = Unpooled.compositeBuffer(buffers.size()); + composite.addComponents(true, buffers); + return composite; + } catch (IOException ex) { + throw new AssertionError("no IO happens here", ex); + } + } + } + + static BytesReference toBytesReference(final ByteBuf buffer) { + return new ByteBufBytesReference(buffer, buffer.readableBytes()); + } + + private static class ByteBufBytesReference extends BytesReference { + + private final ByteBuf buffer; + private final int length; + private final int offset; + + ByteBufBytesReference(ByteBuf buffer, int length) { + this.buffer = buffer; + this.length = length; + this.offset = buffer.readerIndex(); + assert length <= buffer.readableBytes() : "length[" + length +"] > " + buffer.readableBytes(); + } + + @Override + public byte get(int index) { + return buffer.getByte(offset + index); + } + + @Override + public int length() { + return length; + } + + @Override + public BytesReference slice(int from, int length) { + return new ByteBufBytesReference(buffer.slice(offset + from, length), length); + } + + @Override + public StreamInput streamInput() { + return new ByteBufStreamInput(buffer.duplicate(), length); + } + + @Override + public void writeTo(OutputStream os) throws IOException { + buffer.getBytes(offset, os, length); + } + + ByteBuf toByteBuf() { + return buffer.duplicate(); + } + + @Override + public String utf8ToString() { + return buffer.toString(offset, length, StandardCharsets.UTF_8); + } + + @Override + public BytesRef toBytesRef() { + if (buffer.hasArray()) { + return new BytesRef(buffer.array(), buffer.arrayOffset() + offset, length); + } + final byte[] copy = new byte[length]; + buffer.getBytes(offset, copy); + return new BytesRef(copy); + } + + @Override + public long ramBytesUsed() { + return buffer.capacity(); + } + + } + + private static class ByteBufStreamInput extends StreamInput { + + private final ByteBuf buffer; + private final int endIndex; + + ByteBufStreamInput(ByteBuf buffer, int length) { + if (length > buffer.readableBytes()) { + throw new IndexOutOfBoundsException(); + } + this.buffer = buffer; + int startIndex = buffer.readerIndex(); + endIndex = startIndex + length; + buffer.markReaderIndex(); + } + + @Override + public BytesReference readBytesReference(int length) throws IOException { + // NOTE: It is unsafe to share a reference of the internal structure, so we + // use the default implementation which will copy the bytes. It is unsafe because + // a netty ByteBuf might be pooled which requires a manual release to prevent + // memory leaks. + return super.readBytesReference(length); + } + + @Override + public BytesRef readBytesRef(int length) throws IOException { + // NOTE: It is unsafe to share a reference of the internal structure, so we + // use the default implementation which will copy the bytes. It is unsafe because + // a netty ByteBuf might be pooled which requires a manual release to prevent + // memory leaks. + return super.readBytesRef(length); + } + + @Override + public int available() throws IOException { + return endIndex - buffer.readerIndex(); + } + + @Override + protected void ensureCanReadBytes(int length) throws EOFException { + int bytesAvailable = endIndex - buffer.readerIndex(); + if (bytesAvailable < length) { + throw new EOFException("tried to read: " + length + " bytes but only " + bytesAvailable + " remaining"); + } + } + + @Override + public void mark(int readlimit) { + buffer.markReaderIndex(); + } + + @Override + public boolean markSupported() { + return true; + } + + @Override + public int read() throws IOException { + if (available() == 0) { + return -1; + } + return buffer.readByte() & 0xff; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if (len == 0) { + return 0; + } + int available = available(); + if (available == 0) { + return -1; + } + + len = Math.min(available, len); + buffer.readBytes(b, off, len); + return len; + } + + @Override + public void reset() throws IOException { + buffer.resetReaderIndex(); + } + + @Override + public long skip(long n) throws IOException { + if (n > Integer.MAX_VALUE) { + return skipBytes(Integer.MAX_VALUE); + } else { + return skipBytes((int) n); + } + } + + public int skipBytes(int n) throws IOException { + int nBytes = Math.min(available(), n); + buffer.skipBytes(nBytes); + return nBytes; + } + + + @Override + public byte readByte() throws IOException { + return buffer.readByte(); + } + + @Override + public void readBytes(byte[] b, int offset, int len) throws IOException { + int read = read(b, offset, len); + if (read < len) { + throw new IndexOutOfBoundsException(); + } + } + + @Override + public void close() throws IOException { + // nothing to do here + } + } +} diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java new file mode 100644 index 0000000000000..f1d18ddacbd13 --- /dev/null +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java @@ -0,0 +1,225 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.nio; + +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandler; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpContentCompressor; +import io.netty.handler.codec.http.HttpContentDecompressor; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpRequestDecoder; +import io.netty.handler.codec.http.HttpResponseEncoder; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.http.HttpHandlingSettings; +import org.elasticsearch.nio.FlushOperation; +import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.ReadWriteHandler; +import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.SocketChannelContext; +import org.elasticsearch.nio.WriteOperation; +import org.elasticsearch.rest.RestRequest; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.BiConsumer; + +public class HttpReadWriteHandler implements ReadWriteHandler { + + private final NettyAdaptor adaptor; + private final NioSocketChannel nioChannel; + private final NioHttpServerTransport transport; + private final HttpHandlingSettings settings; + private final NamedXContentRegistry xContentRegistry; + private final ThreadContext threadContext; + + HttpReadWriteHandler(NioSocketChannel nioChannel, NioHttpServerTransport transport, HttpHandlingSettings settings, + NamedXContentRegistry xContentRegistry, ThreadContext threadContext) { + this.nioChannel = nioChannel; + this.transport = transport; + this.settings = settings; + this.xContentRegistry = xContentRegistry; + this.threadContext = threadContext; + + List handlers = new ArrayList<>(5); + HttpRequestDecoder decoder = new HttpRequestDecoder(settings.getMaxInitialLineLength(), settings.getMaxHeaderSize(), + settings.getMaxChunkSize()); + decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR); + handlers.add(decoder); + handlers.add(new HttpContentDecompressor()); + handlers.add(new HttpResponseEncoder()); + handlers.add(new HttpObjectAggregator(settings.getMaxContentLength())); + if (settings.isCompression()) { + handlers.add(new HttpContentCompressor(settings.getCompressionLevel())); + } + + adaptor = new NettyAdaptor(handlers.toArray(new ChannelHandler[0])); + adaptor.addCloseListener((v, e) -> nioChannel.close()); + } + + @Override + public int consumeReads(InboundChannelBuffer channelBuffer) throws IOException { + int bytesConsumed = adaptor.read(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())); + Object message; + while ((message = adaptor.pollInboundMessage()) != null) { + handleRequest(message); + } + + return bytesConsumed; + } + + @Override + public WriteOperation createWriteOperation(SocketChannelContext context, Object message, BiConsumer listener) { + assert message instanceof FullHttpResponse : "This channel only supports messages that are of type: " + FullHttpResponse.class + + ". Found type: " + message.getClass() + "."; + return new HttpWriteOperation(context, (FullHttpResponse) message, listener); + } + + @Override + public List writeToBytes(WriteOperation writeOperation) { + adaptor.write(writeOperation); + return pollFlushOperations(); + } + + @Override + public List pollFlushOperations() { + ArrayList copiedOperations = new ArrayList<>(adaptor.getOutboundCount()); + FlushOperation flushOperation; + while ((flushOperation = adaptor.pollOutboundOperation()) != null) { + copiedOperations.add(flushOperation); + } + return copiedOperations; + } + + @Override + public void close() throws IOException { + try { + adaptor.close(); + } catch (Exception e) { + throw new IOException(e); + } + } + + private void handleRequest(Object msg) { + final FullHttpRequest request = (FullHttpRequest) msg; + + final FullHttpRequest copiedRequest = + new DefaultFullHttpRequest( + request.protocolVersion(), + request.method(), + request.uri(), + Unpooled.copiedBuffer(request.content()), + request.headers(), + request.trailingHeaders()); + + Exception badRequestCause = null; + + /* + * We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there + * are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we + * attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header, + * or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the + * underlying exception that caused us to treat the request as bad. + */ + final NioHttpRequest httpRequest; + { + NioHttpRequest innerHttpRequest; + try { + innerHttpRequest = new NioHttpRequest(xContentRegistry, copiedRequest); + } catch (final RestRequest.ContentTypeHeaderException e) { + badRequestCause = e; + innerHttpRequest = requestWithoutContentTypeHeader(copiedRequest, badRequestCause); + } catch (final RestRequest.BadParameterException e) { + badRequestCause = e; + innerHttpRequest = requestWithoutParameters(copiedRequest); + } + httpRequest = innerHttpRequest; + } + + /* + * We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid + * parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an + * IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of these + * parameter values. + */ + final NioHttpChannel channel; + { + NioHttpChannel innerChannel; + try { + innerChannel = new NioHttpChannel(nioChannel, transport.getBigArrays(), httpRequest, settings, threadContext); + } catch (final IllegalArgumentException e) { + if (badRequestCause == null) { + badRequestCause = e; + } else { + badRequestCause.addSuppressed(e); + } + final NioHttpRequest innerRequest = + new NioHttpRequest( + xContentRegistry, + Collections.emptyMap(), // we are going to dispatch the request as a bad request, drop all parameters + copiedRequest.uri(), + copiedRequest); + innerChannel = new NioHttpChannel(nioChannel, transport.getBigArrays(), innerRequest, settings, threadContext); + } + channel = innerChannel; + } + + if (request.decoderResult().isFailure()) { + transport.dispatchBadRequest(httpRequest, channel, request.decoderResult().cause()); + } else if (badRequestCause != null) { + transport.dispatchBadRequest(httpRequest, channel, badRequestCause); + } else { + transport.dispatchRequest(httpRequest, channel); + } + } + + private NioHttpRequest requestWithoutContentTypeHeader(final FullHttpRequest request, final Exception badRequestCause) { + final HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders(); + headersWithoutContentTypeHeader.add(request.headers()); + headersWithoutContentTypeHeader.remove("Content-Type"); + final FullHttpRequest requestWithoutContentTypeHeader = + new DefaultFullHttpRequest( + request.protocolVersion(), + request.method(), + request.uri(), + request.content(), + headersWithoutContentTypeHeader, // remove the Content-Type header so as to not parse it again + request.trailingHeaders()); // Content-Type can not be a trailing header + try { + return new NioHttpRequest(xContentRegistry, requestWithoutContentTypeHeader); + } catch (final RestRequest.BadParameterException e) { + badRequestCause.addSuppressed(e); + return requestWithoutParameters(requestWithoutContentTypeHeader); + } + } + + private NioHttpRequest requestWithoutParameters(final FullHttpRequest request) { + // remove all parameters as at least one is incorrectly encoded + return new NioHttpRequest(xContentRegistry, Collections.emptyMap(), request.uri(), request); + } +} diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpWriteOperation.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpWriteOperation.java new file mode 100644 index 0000000000000..c838ae85e9d40 --- /dev/null +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpWriteOperation.java @@ -0,0 +1,54 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.nio; + +import io.netty.handler.codec.http.FullHttpResponse; +import org.elasticsearch.nio.SocketChannelContext; +import org.elasticsearch.nio.WriteOperation; + +import java.util.function.BiConsumer; + +public class HttpWriteOperation implements WriteOperation { + + private final SocketChannelContext channelContext; + private final FullHttpResponse response; + private final BiConsumer listener; + + HttpWriteOperation(SocketChannelContext channelContext, FullHttpResponse response, BiConsumer listener) { + this.channelContext = channelContext; + this.response = response; + this.listener = listener; + } + + @Override + public BiConsumer getListener() { + return listener; + } + + @Override + public SocketChannelContext getChannel() { + return channelContext; + } + + @Override + public FullHttpResponse getObject() { + return response; + } +} diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NettyAdaptor.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NettyAdaptor.java new file mode 100644 index 0000000000000..3344a31264121 --- /dev/null +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NettyAdaptor.java @@ -0,0 +1,131 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.nio; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.nio.FlushOperation; +import org.elasticsearch.nio.WriteOperation; + +import java.nio.ByteBuffer; +import java.util.LinkedList; +import java.util.function.BiConsumer; + +public class NettyAdaptor implements AutoCloseable { + + private final EmbeddedChannel nettyChannel; + private final LinkedList flushOperations = new LinkedList<>(); + + NettyAdaptor(ChannelHandler... handlers) { + nettyChannel = new EmbeddedChannel(); + nettyChannel.pipeline().addLast("write_captor", new ChannelOutboundHandlerAdapter() { + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + // This is a little tricky. The embedded channel will complete the promise once it writes the message + // to its outbound buffer. We do not want to complete the promise until the message is sent. So we + // intercept the promise and pass a different promise back to the rest of the pipeline. + + try { + ByteBuf message = (ByteBuf) msg; + promise.addListener((f) -> message.release()); + NettyListener listener; + if (promise instanceof NettyListener) { + listener = (NettyListener) promise; + } else { + listener = new NettyListener(promise); + } + flushOperations.add(new FlushOperation(message.nioBuffers(), listener)); + } catch (Exception e) { + promise.setFailure(e); + } + } + }); + nettyChannel.pipeline().addLast(handlers); + } + + @Override + public void close() throws Exception { + assert flushOperations.isEmpty() : "Should close outbound operations before calling close"; + + ChannelFuture closeFuture = nettyChannel.close(); + // This should be safe as we are not a real network channel + closeFuture.await(); + if (closeFuture.isSuccess() == false) { + Throwable cause = closeFuture.cause(); + ExceptionsHelper.dieOnError(cause); + throw (Exception) cause; + } + } + + public void addCloseListener(BiConsumer listener) { + nettyChannel.closeFuture().addListener(f -> { + if (f.isSuccess()) { + listener.accept(null, null); + } else { + final Throwable cause = f.cause(); + ExceptionsHelper.dieOnError(cause); + assert cause instanceof Exception; + listener.accept(null, (Exception) cause); + } + }); + } + + public int read(ByteBuffer[] buffers) { + ByteBuf byteBuf = Unpooled.wrappedBuffer(buffers); + int initialReaderIndex = byteBuf.readerIndex(); + nettyChannel.writeInbound(byteBuf); + return byteBuf.readerIndex() - initialReaderIndex; + } + + public Object pollInboundMessage() { + return nettyChannel.readInbound(); + } + + public void write(WriteOperation writeOperation) { + ChannelPromise channelPromise = nettyChannel.newPromise(); + channelPromise.addListener(f -> { + BiConsumer consumer = writeOperation.getListener(); + if (f.cause() == null) { + consumer.accept(null, null); + } else { + ExceptionsHelper.dieOnError(f.cause()); + consumer.accept(null, f.cause()); + } + }); + + nettyChannel.writeAndFlush(writeOperation.getObject(), new NettyListener(channelPromise)); + } + + public FlushOperation pollOutboundOperation() { + return flushOperations.pollFirst(); + } + + public int getOutboundCount() { + return flushOperations.size(); + } +} diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NettyListener.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NettyListener.java new file mode 100644 index 0000000000000..e806b0d23ce3a --- /dev/null +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NettyListener.java @@ -0,0 +1,214 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.nio; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelPromise; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.concurrent.FutureUtils; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.BiConsumer; + +/** + * This is an {@link BiConsumer} that interfaces with netty code. It wraps a netty promise and will + * complete that promise when accept is called. It delegates the normal promise methods to the underlying + * promise. + */ +public class NettyListener implements BiConsumer, ChannelPromise { + + private final ChannelPromise promise; + + NettyListener(ChannelPromise promise) { + this.promise = promise; + } + + @Override + public void accept(Void v, Throwable throwable) { + if (throwable == null) { + promise.setSuccess(); + } else { + promise.setFailure(throwable); + } + } + + @Override + public Channel channel() { + return promise.channel(); + } + + @Override + public ChannelPromise setSuccess(Void result) { + return promise.setSuccess(result); + } + + @Override + public boolean trySuccess(Void result) { + return promise.trySuccess(result); + } + + @Override + public ChannelPromise setSuccess() { + return promise.setSuccess(); + } + + @Override + public boolean trySuccess() { + return promise.trySuccess(); + } + + @Override + public ChannelPromise setFailure(Throwable cause) { + return promise.setFailure(cause); + } + + @Override + public boolean tryFailure(Throwable cause) { + return promise.tryFailure(cause); + } + + @Override + public boolean setUncancellable() { + return promise.setUncancellable(); + } + + @Override + public boolean isSuccess() { + return promise.isSuccess(); + } + + @Override + public boolean isCancellable() { + return promise.isCancellable(); + } + + @Override + public Throwable cause() { + return promise.cause(); + } + + @Override + public ChannelPromise addListener(GenericFutureListener> listener) { + return promise.addListener(listener); + } + + @Override + @SafeVarargs + @SuppressWarnings("varargs") + public final ChannelPromise addListeners(GenericFutureListener>... listeners) { + return promise.addListeners(listeners); + } + + @Override + public ChannelPromise removeListener(GenericFutureListener> listener) { + return promise.removeListener(listener); + } + + @Override + @SafeVarargs + @SuppressWarnings("varargs") + public final ChannelPromise removeListeners(GenericFutureListener>... listeners) { + return promise.removeListeners(listeners); + } + + @Override + public ChannelPromise sync() throws InterruptedException { + return promise.sync(); + } + + @Override + public ChannelPromise syncUninterruptibly() { + return promise.syncUninterruptibly(); + } + + @Override + public ChannelPromise await() throws InterruptedException { + return promise.await(); + } + + @Override + public ChannelPromise awaitUninterruptibly() { + return promise.awaitUninterruptibly(); + } + + @Override + public boolean await(long timeout, TimeUnit unit) throws InterruptedException { + return promise.await(timeout, unit); + } + + @Override + public boolean await(long timeoutMillis) throws InterruptedException { + return promise.await(timeoutMillis); + } + + @Override + public boolean awaitUninterruptibly(long timeout, TimeUnit unit) { + return promise.awaitUninterruptibly(timeout, unit); + } + + @Override + public boolean awaitUninterruptibly(long timeoutMillis) { + return promise.awaitUninterruptibly(timeoutMillis); + } + + @Override + public Void getNow() { + return promise.getNow(); + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return FutureUtils.cancel(promise); + } + + @Override + public boolean isCancelled() { + return promise.isCancelled(); + } + + @Override + public boolean isDone() { + return promise.isDone(); + } + + @Override + public Void get() throws InterruptedException, ExecutionException { + return promise.get(); + } + + @Override + public Void get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + return promise.get(timeout, unit); + } + + @Override + public boolean isVoid() { + return promise.isVoid(); + } + + @Override + public ChannelPromise unvoid() { + return promise.unvoid(); + } +} diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java new file mode 100644 index 0000000000000..672c6d5abad0e --- /dev/null +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java @@ -0,0 +1,254 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.nio; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.cookie.Cookie; +import io.netty.handler.codec.http.cookie.ServerCookieDecoder; +import io.netty.handler.codec.http.cookie.ServerCookieEncoder; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.lease.Releasables; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.http.HttpHandlingSettings; +import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.rest.AbstractRestChannel; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.rest.RestStatus; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.EnumMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class NioHttpChannel extends AbstractRestChannel { + + private final BigArrays bigArrays; + private final ThreadContext threadContext; + private final FullHttpRequest nettyRequest; + private final NioSocketChannel nioChannel; + private final boolean resetCookies; + + NioHttpChannel(NioSocketChannel nioChannel, BigArrays bigArrays, NioHttpRequest request, + HttpHandlingSettings settings, ThreadContext threadContext) { + super(request, settings.getDetailedErrorsEnabled()); + this.nioChannel = nioChannel; + this.bigArrays = bigArrays; + this.threadContext = threadContext; + this.nettyRequest = request.getRequest(); + this.resetCookies = settings.isResetCookies(); + } + + @Override + public void sendResponse(RestResponse response) { + // if the response object was created upstream, then use it; + // otherwise, create a new one + ByteBuf buffer = ByteBufUtils.toByteBuf(response.content()); + final FullHttpResponse resp; + if (HttpMethod.HEAD.equals(nettyRequest.method())) { + resp = newResponse(Unpooled.EMPTY_BUFFER); + } else { + resp = newResponse(buffer); + } + resp.setStatus(getStatus(response.status())); + + String opaque = nettyRequest.headers().get("X-Opaque-Id"); + if (opaque != null) { + setHeaderField(resp, "X-Opaque-Id", opaque); + } + + // Add all custom headers + addCustomHeaders(resp, response.getHeaders()); + addCustomHeaders(resp, threadContext.getResponseHeaders()); + + ArrayList toClose = new ArrayList<>(3); + + boolean success = false; + try { + // If our response doesn't specify a content-type header, set one + setHeaderField(resp, HttpHeaderNames.CONTENT_TYPE.toString(), response.contentType(), false); + // If our response has no content-length, calculate and set one + setHeaderField(resp, HttpHeaderNames.CONTENT_LENGTH.toString(), String.valueOf(buffer.readableBytes()), false); + + addCookies(resp); + + BytesReference content = response.content(); + if (content instanceof Releasable) { + toClose.add((Releasable) content); + } + BytesStreamOutput bytesStreamOutput = bytesOutputOrNull(); + if (bytesStreamOutput instanceof ReleasableBytesStreamOutput) { + toClose.add((Releasable) bytesStreamOutput); + } + + if (isCloseConnection()) { + toClose.add(nioChannel::close); + } + + nioChannel.getContext().sendMessage(resp, (aVoid, throwable) -> { + Releasables.close(toClose); + }); + success = true; + } finally { + if (success == false) { + Releasables.close(toClose); + } + } + } + + @Override + protected BytesStreamOutput newBytesOutput() { + return new ReleasableBytesStreamOutput(bigArrays); + } + + private void setHeaderField(HttpResponse resp, String headerField, String value) { + setHeaderField(resp, headerField, value, true); + } + + private void setHeaderField(HttpResponse resp, String headerField, String value, boolean override) { + if (override || !resp.headers().contains(headerField)) { + resp.headers().add(headerField, value); + } + } + + private void addCookies(HttpResponse resp) { + if (resetCookies) { + String cookieString = nettyRequest.headers().get(HttpHeaderNames.COOKIE); + if (cookieString != null) { + Set cookies = ServerCookieDecoder.STRICT.decode(cookieString); + if (!cookies.isEmpty()) { + // Reset the cookies if necessary. + resp.headers().set(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.STRICT.encode(cookies)); + } + } + } + } + + private void addCustomHeaders(HttpResponse response, Map> customHeaders) { + if (customHeaders != null) { + for (Map.Entry> headerEntry : customHeaders.entrySet()) { + for (String headerValue : headerEntry.getValue()) { + setHeaderField(response, headerEntry.getKey(), headerValue); + } + } + } + } + + // Create a new {@link HttpResponse} to transmit the response for the netty request. + private FullHttpResponse newResponse(ByteBuf buffer) { + final boolean http10 = isHttp10(); + final boolean close = isCloseConnection(); + // Build the response object. + final HttpResponseStatus status = HttpResponseStatus.OK; // default to initialize + final FullHttpResponse response; + if (http10) { + response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_0, status, buffer); + if (!close) { + response.headers().add(HttpHeaderNames.CONNECTION, "Keep-Alive"); + } + } else { + response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buffer); + } + return response; + } + + // Determine if the request protocol version is HTTP 1.0 + private boolean isHttp10() { + return nettyRequest.protocolVersion().equals(HttpVersion.HTTP_1_0); + } + + // Determine if the request connection should be closed on completion. + private boolean isCloseConnection() { + final boolean http10 = isHttp10(); + return HttpHeaderValues.CLOSE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)) || + (http10 && !HttpHeaderValues.KEEP_ALIVE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION))); + } + + private static Map MAP; + + static { + EnumMap map = new EnumMap<>(RestStatus.class); + map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE); + map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS); + map.put(RestStatus.OK, HttpResponseStatus.OK); + map.put(RestStatus.CREATED, HttpResponseStatus.CREATED); + map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED); + map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION); + map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT); + map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT); + map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT); + map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this?? + map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES); + map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY); + map.put(RestStatus.FOUND, HttpResponseStatus.FOUND); + map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER); + map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED); + map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY); + map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT); + map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED); + map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED); + map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN); + map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND); + map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED); + map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE); + map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED); + map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT); + map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT); + map.put(RestStatus.GONE, HttpResponseStatus.GONE); + map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED); + map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED); + map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); + map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG); + map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE); + map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE); + map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED); + map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS); + map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR); + map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED); + map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY); + map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE); + map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT); + map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED); + MAP = Collections.unmodifiableMap(map); + } + + private static HttpResponseStatus getStatus(RestStatus status) { + return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR); + } +} diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java new file mode 100644 index 0000000000000..b5bfcc6b0cca2 --- /dev/null +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java @@ -0,0 +1,186 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.nio; + +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.rest.RestRequest; + +import java.util.AbstractMap; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class NioHttpRequest extends RestRequest { + + private final FullHttpRequest request; + private final BytesReference content; + + NioHttpRequest(NamedXContentRegistry xContentRegistry, FullHttpRequest request) { + super(xContentRegistry, request.uri(), new HttpHeadersMap(request.headers())); + this.request = request; + if (request.content().isReadable()) { + this.content = ByteBufUtils.toBytesReference(request.content()); + } else { + this.content = BytesArray.EMPTY; + } + + } + + NioHttpRequest(NamedXContentRegistry xContentRegistry, Map params, String uri, FullHttpRequest request) { + super(xContentRegistry, params, uri, new HttpHeadersMap(request.headers())); + this.request = request; + if (request.content().isReadable()) { + this.content = ByteBufUtils.toBytesReference(request.content()); + } else { + this.content = BytesArray.EMPTY; + } + } + + @Override + public Method method() { + HttpMethod httpMethod = request.method(); + if (httpMethod == HttpMethod.GET) + return Method.GET; + + if (httpMethod == HttpMethod.POST) + return Method.POST; + + if (httpMethod == HttpMethod.PUT) + return Method.PUT; + + if (httpMethod == HttpMethod.DELETE) + return Method.DELETE; + + if (httpMethod == HttpMethod.HEAD) { + return Method.HEAD; + } + + if (httpMethod == HttpMethod.OPTIONS) { + return Method.OPTIONS; + } + + return Method.GET; + } + + @Override + public String uri() { + return request.uri(); + } + + @Override + public boolean hasContent() { + return content.length() > 0; + } + + @Override + public BytesReference content() { + return content; + } + + public FullHttpRequest getRequest() { + return request; + } + + /** + * A wrapper of {@link HttpHeaders} that implements a map to prevent copying unnecessarily. This class does not support modifications + * and due to the underlying implementation, it performs case insensitive lookups of key to values. + * + * It is important to note that this implementation does have some downsides in that each invocation of the + * {@link #values()} and {@link #entrySet()} methods will perform a copy of the values in the HttpHeaders rather than returning a + * view of the underlying values. + */ + private static class HttpHeadersMap implements Map> { + + private final HttpHeaders httpHeaders; + + private HttpHeadersMap(HttpHeaders httpHeaders) { + this.httpHeaders = httpHeaders; + } + + @Override + public int size() { + return httpHeaders.size(); + } + + @Override + public boolean isEmpty() { + return httpHeaders.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return key instanceof String && httpHeaders.contains((String) key); + } + + @Override + public boolean containsValue(Object value) { + return value instanceof List && httpHeaders.names().stream().map(httpHeaders::getAll).anyMatch(value::equals); + } + + @Override + public List get(Object key) { + return key instanceof String ? httpHeaders.getAll((String) key) : null; + } + + @Override + public List put(String key, List value) { + throw new UnsupportedOperationException("modifications are not supported"); + } + + @Override + public List remove(Object key) { + throw new UnsupportedOperationException("modifications are not supported"); + } + + @Override + public void putAll(Map> m) { + throw new UnsupportedOperationException("modifications are not supported"); + } + + @Override + public void clear() { + throw new UnsupportedOperationException("modifications are not supported"); + } + + @Override + public Set keySet() { + return httpHeaders.names(); + } + + @Override + public Collection> values() { + return httpHeaders.names().stream().map(k -> Collections.unmodifiableList(httpHeaders.getAll(k))).collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return httpHeaders.names().stream().map(k -> new AbstractMap.SimpleImmutableEntry<>(k, httpHeaders.getAll(k))) + .collect(Collectors.toSet()); + } + } +} diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java new file mode 100644 index 0000000000000..bdbee715bd0cf --- /dev/null +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java @@ -0,0 +1,322 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.nio; + +import io.netty.handler.timeout.ReadTimeoutException; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.logging.log4j.util.Supplier; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.network.NetworkAddress; +import org.elasticsearch.common.network.NetworkService; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.NetworkExceptionHelper; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.http.BindHttpException; +import org.elasticsearch.http.HttpHandlingSettings; +import org.elasticsearch.http.HttpServerTransport; +import org.elasticsearch.http.HttpStats; +import org.elasticsearch.http.netty4.AbstractHttpServerTransport; +import org.elasticsearch.nio.AcceptingSelector; +import org.elasticsearch.nio.AcceptorEventHandler; +import org.elasticsearch.nio.BytesChannelContext; +import org.elasticsearch.nio.ChannelFactory; +import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.NioChannel; +import org.elasticsearch.nio.NioGroup; +import org.elasticsearch.nio.NioServerSocketChannel; +import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.ServerChannelContext; +import org.elasticsearch.nio.SocketChannelContext; +import org.elasticsearch.nio.SocketEventHandler; +import org.elasticsearch.nio.SocketSelector; +import org.elasticsearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import static org.elasticsearch.common.settings.Setting.intSetting; +import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadFactory; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION_LEVEL; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_DETAILED_ERRORS_ENABLED; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_CHUNK_SIZE; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEADER_SIZE; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_INITIAL_LINE_LENGTH; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_RESET_COOKIES; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_TCP_KEEP_ALIVE; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_TCP_NO_DELAY; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_TCP_RECEIVE_BUFFER_SIZE; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_TCP_REUSE_ADDRESS; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_TCP_SEND_BUFFER_SIZE; + +public class NioHttpServerTransport extends AbstractHttpServerTransport { + + public static final Setting NIO_HTTP_ACCEPTOR_COUNT = + intSetting("http.nio.acceptor_count", 1, 1, Setting.Property.NodeScope); + public static final Setting NIO_HTTP_WORKER_COUNT = + new Setting<>("http.nio.worker_count", + (s) -> Integer.toString(EsExecutors.numberOfProcessors(s) * 2), + (s) -> Setting.parseInt(s, 1, "http.nio.worker_count"), Setting.Property.NodeScope); + + private static final String TRANSPORT_WORKER_THREAD_NAME_PREFIX = "http_nio_transport_worker"; + private static final String TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX = "http_nio_transport_acceptor"; + + private final BigArrays bigArrays; + private final ThreadPool threadPool; + private final NamedXContentRegistry xContentRegistry; + + private final HttpHandlingSettings httpHandlingSettings; + + private final boolean tcpNoDelay; + private final boolean tcpKeepAlive; + private final boolean reuseAddress; + private final int tcpSendBufferSize; + private final int tcpReceiveBufferSize; + + private final Set serverChannels = Collections.newSetFromMap(new ConcurrentHashMap<>()); + private final Set socketChannels = Collections.newSetFromMap(new ConcurrentHashMap<>()); + private NioGroup nioGroup; + private HttpChannelFactory channelFactory; + + public NioHttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool, + NamedXContentRegistry xContentRegistry, HttpServerTransport.Dispatcher dispatcher) { + super(settings, networkService, threadPool, dispatcher); + this.bigArrays = bigArrays; + this.threadPool = threadPool; + this.xContentRegistry = xContentRegistry; + + ByteSizeValue maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings); + ByteSizeValue maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings); + ByteSizeValue maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.get(settings); + this.httpHandlingSettings = new HttpHandlingSettings(Math.toIntExact(maxContentLength.getBytes()), + Math.toIntExact(maxChunkSize.getBytes()), + Math.toIntExact(maxHeaderSize.getBytes()), + Math.toIntExact(maxInitialLineLength.getBytes()), + SETTING_HTTP_RESET_COOKIES.get(settings), + SETTING_HTTP_COMPRESSION.get(settings), + SETTING_HTTP_COMPRESSION_LEVEL.get(settings), + SETTING_HTTP_DETAILED_ERRORS_ENABLED.get(settings)); + + this.tcpNoDelay = SETTING_HTTP_TCP_NO_DELAY.get(settings); + this.tcpKeepAlive = SETTING_HTTP_TCP_KEEP_ALIVE.get(settings); + this.reuseAddress = SETTING_HTTP_TCP_REUSE_ADDRESS.get(settings); + this.tcpSendBufferSize = Math.toIntExact(SETTING_HTTP_TCP_SEND_BUFFER_SIZE.get(settings).getBytes()); + this.tcpReceiveBufferSize = Math.toIntExact(SETTING_HTTP_TCP_RECEIVE_BUFFER_SIZE.get(settings).getBytes()); + + + logger.debug("using max_chunk_size[{}], max_header_size[{}], max_initial_line_length[{}], max_content_length[{}]", + maxChunkSize, maxHeaderSize, maxInitialLineLength, maxContentLength); + } + + BigArrays getBigArrays() { + return bigArrays; + } + + @Override + protected void doStart() { + boolean success = false; + try { + int acceptorCount = NIO_HTTP_ACCEPTOR_COUNT.get(settings); + int workerCount = NIO_HTTP_WORKER_COUNT.get(settings); + nioGroup = new NioGroup(logger, daemonThreadFactory(this.settings, TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX), acceptorCount, + AcceptorEventHandler::new, daemonThreadFactory(this.settings, TRANSPORT_WORKER_THREAD_NAME_PREFIX), + workerCount, SocketEventHandler::new); + channelFactory = new HttpChannelFactory(); + this.boundAddress = createBoundHttpAddress(); + + if (logger.isInfoEnabled()) { + logger.info("{}", boundAddress); + } + + success = true; + } catch (IOException e) { + throw new ElasticsearchException(e); + } finally { + if (success == false) { + doStop(); // otherwise we leak threads since we never moved to started + } + } + } + + @Override + protected void doStop() { + synchronized (serverChannels) { + if (serverChannels.isEmpty() == false) { + try { + closeChannels(new ArrayList<>(serverChannels)); + } catch (Exception e) { + logger.error("unexpected exception while closing http server channels", e); + } + serverChannels.clear(); + } + } + + try { + closeChannels(new ArrayList<>(socketChannels)); + } catch (Exception e) { + logger.warn("unexpected exception while closing http channels", e); + } + socketChannels.clear(); + + try { + nioGroup.close(); + } catch (Exception e) { + logger.warn("unexpected exception while stopping nio group", e); + } + } + + @Override + protected void doClose() throws IOException { + } + + @Override + protected TransportAddress bindAddress(InetAddress hostAddress) { + final AtomicReference lastException = new AtomicReference<>(); + final AtomicReference boundSocket = new AtomicReference<>(); + boolean success = port.iterate(portNumber -> { + try { + synchronized (serverChannels) { + InetSocketAddress address = new InetSocketAddress(hostAddress, portNumber); + NioServerSocketChannel channel = nioGroup.bindServerChannel(address, channelFactory); + serverChannels.add(channel); + boundSocket.set(channel.getLocalAddress()); + } + } catch (Exception e) { + lastException.set(e); + return false; + } + return true; + }); + if (success == false) { + throw new BindHttpException("Failed to bind to [" + port.getPortRangeString() + "]", lastException.get()); + } + + if (logger.isDebugEnabled()) { + logger.debug("Bound http to address {{}}", NetworkAddress.format(boundSocket.get())); + } + return new TransportAddress(boundSocket.get()); + } + + @Override + public HttpStats stats() { + return new HttpStats(serverChannels.size(), socketChannels.size()); + } + + protected void exceptionCaught(NioSocketChannel channel, Exception cause) { + if (cause instanceof ReadTimeoutException) { + if (logger.isTraceEnabled()) { + logger.trace("Read timeout [{}]", channel.getRemoteAddress()); + } + channel.close(); + } else { + if (lifecycle.started() == false) { + // ignore + return; + } + if (NetworkExceptionHelper.isCloseConnectionException(cause) == false) { + logger.warn( + (Supplier) () -> new ParameterizedMessage( + "caught exception while handling client http traffic, closing connection {}", channel), + cause); + channel.close(); + } else { + logger.debug( + (Supplier) () -> new ParameterizedMessage( + "caught exception while handling client http traffic, closing connection {}", channel), + cause); + channel.close(); + } + } + } + + private void closeChannels(List channels) { + List> futures = new ArrayList<>(channels.size()); + + for (NioChannel channel : channels) { + PlainActionFuture future = PlainActionFuture.newFuture(); + channel.addCloseListener(ActionListener.toBiConsumer(future)); + futures.add(future); + channel.close(); + } + + List closeExceptions = new ArrayList<>(); + for (ActionFuture f : futures) { + try { + f.actionGet(); + } catch (RuntimeException e) { + closeExceptions.add(e); + } + } + + ExceptionsHelper.rethrowAndSuppress(closeExceptions); + } + + private void acceptChannel(NioSocketChannel socketChannel) { + socketChannels.add(socketChannel); + } + + private class HttpChannelFactory extends ChannelFactory { + + private HttpChannelFactory() { + super(new RawChannelFactory(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize)); + } + + @Override + public NioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException { + NioSocketChannel nioChannel = new NioSocketChannel(channel); + HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(nioChannel,NioHttpServerTransport.this, + httpHandlingSettings, xContentRegistry, threadPool.getThreadContext()); + Consumer exceptionHandler = (e) -> exceptionCaught(nioChannel, e); + SocketChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, httpReadWritePipeline, + InboundChannelBuffer.allocatingInstance()); + nioChannel.setContext(context); + return nioChannel; + } + + @Override + public NioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException { + NioServerSocketChannel nioChannel = new NioServerSocketChannel(channel); + ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, NioHttpServerTransport.this::acceptChannel, + (e) -> {}); + nioChannel.setContext(context); + return nioChannel; + } + + } +} diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index eb3d7f3d710dc..9d794f951c8d2 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -21,7 +21,6 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.recycler.Recycler; @@ -39,7 +38,6 @@ import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ServerChannelContext; -import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketEventHandler; import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.threadpool.ThreadPool; @@ -184,10 +182,9 @@ public TcpNioSocketChannel createChannel(SocketSelector selector, SocketChannel Recycler.V bytes = pageCacheRecycler.bytePage(false); return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; - SocketChannelContext.ReadConsumer nioReadConsumer = channelBuffer -> - consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex()))); + TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, NioTransport.this); Consumer exceptionHandler = (e) -> exceptionCaught(nioChannel, e); - BytesChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, nioReadConsumer, + BytesChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, new InboundChannelBuffer(pageSupplier)); nioChannel.setContext(context); return nioChannel; diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransportPlugin.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransportPlugin.java index 029507a5ba49d..422e3e9b83330 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransportPlugin.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransportPlugin.java @@ -19,14 +19,15 @@ package org.elasticsearch.transport.nio; -import org.elasticsearch.bootstrap.BootstrapCheck; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.network.NetworkModule; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.http.HttpServerTransport; +import org.elasticsearch.http.nio.NioHttpServerTransport; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.plugins.NetworkPlugin; import org.elasticsearch.plugins.Plugin; @@ -42,10 +43,13 @@ public class NioTransportPlugin extends Plugin implements NetworkPlugin { public static final String NIO_TRANSPORT_NAME = "nio-transport"; + public static final String NIO_HTTP_TRANSPORT_NAME = "nio-http-transport"; @Override public List> getSettings() { return Arrays.asList( + NioHttpServerTransport.NIO_HTTP_ACCEPTOR_COUNT, + NioHttpServerTransport.NIO_HTTP_WORKER_COUNT, NioTransport.NIO_WORKER_COUNT, NioTransport.NIO_ACCEPTOR_COUNT ); @@ -61,4 +65,15 @@ public Map> getTransports(Settings settings, ThreadP () -> new NioTransport(settings, threadPool, networkService, bigArrays, pageCacheRecycler, namedWriteableRegistry, circuitBreakerService)); } + + @Override + public Map> getHttpTransports(Settings settings, ThreadPool threadPool, BigArrays bigArrays, + CircuitBreakerService circuitBreakerService, + NamedWriteableRegistry namedWriteableRegistry, + NamedXContentRegistry xContentRegistry, + NetworkService networkService, + HttpServerTransport.Dispatcher dispatcher) { + return Collections.singletonMap(NIO_HTTP_TRANSPORT_NAME, + () -> new NioHttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher)); + } } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpReadWriteHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpReadWriteHandler.java new file mode 100644 index 0000000000000..f2d07b180855c --- /dev/null +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpReadWriteHandler.java @@ -0,0 +1,44 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.nio.BytesWriteHandler; +import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.transport.TcpTransport; + +import java.io.IOException; + +public class TcpReadWriteHandler extends BytesWriteHandler { + + private final TcpNioSocketChannel channel; + private final TcpTransport transport; + + public TcpReadWriteHandler(TcpNioSocketChannel channel, TcpTransport transport) { + this.channel = channel; + this.transport = transport; + } + + @Override + public int consumeReads(InboundChannelBuffer channelBuffer) throws IOException { + BytesReference bytesReference = BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())); + return transport.consumeNetworkReads(channel, bytesReference); + } +} diff --git a/plugins/transport-nio/src/main/plugin-metadata/plugin-security.policy b/plugins/transport-nio/src/main/plugin-metadata/plugin-security.policy index 2dbe07bd8a5c6..8c8fe7c327412 100644 --- a/plugins/transport-nio/src/main/plugin-metadata/plugin-security.policy +++ b/plugins/transport-nio/src/main/plugin-metadata/plugin-security.policy @@ -21,3 +21,9 @@ grant codeBase "${codebase.elasticsearch-nio}" { // elasticsearch-nio makes and accepts socket connections permission java.net.SocketPermission "*", "accept,connect"; }; + +grant codeBase "${codebase.netty-common}" { + // This should only currently be required as we use the netty http client for tests + // netty makes and accepts socket connections + permission java.net.SocketPermission "*", "accept,connect"; +}; diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java new file mode 100644 index 0000000000000..dce8319d2fc82 --- /dev/null +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java @@ -0,0 +1,241 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.nio; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpRequestEncoder; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseDecoder; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpVersion; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.http.HttpHandlingSettings; +import org.elasticsearch.nio.FlushOperation; +import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.SocketChannelContext; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.function.BiConsumer; + +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION_LEVEL; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_DETAILED_ERRORS_ENABLED; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_CHUNK_SIZE; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEADER_SIZE; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_INITIAL_LINE_LENGTH; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_RESET_COOKIES; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; + +public class HttpReadWriteHandlerTests extends ESTestCase { + + private HttpReadWriteHandler handler; + private NioSocketChannel nioSocketChannel; + private NioHttpServerTransport transport; + + private final RequestEncoder requestEncoder = new RequestEncoder(); + private final ResponseDecoder responseDecoder = new ResponseDecoder(); + + @Before + @SuppressWarnings("unchecked") + public void setMocks() { + transport = mock(NioHttpServerTransport.class); + Settings settings = Settings.EMPTY; + ByteSizeValue maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.getDefault(settings); + ByteSizeValue maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.getDefault(settings); + ByteSizeValue maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.getDefault(settings); + HttpHandlingSettings httpHandlingSettings = new HttpHandlingSettings(1024, + Math.toIntExact(maxChunkSize.getBytes()), + Math.toIntExact(maxHeaderSize.getBytes()), + Math.toIntExact(maxInitialLineLength.getBytes()), + SETTING_HTTP_RESET_COOKIES.getDefault(settings), + SETTING_HTTP_COMPRESSION.getDefault(settings), + SETTING_HTTP_COMPRESSION_LEVEL.getDefault(settings), + SETTING_HTTP_DETAILED_ERRORS_ENABLED.getDefault(settings)); + ThreadContext threadContext = new ThreadContext(settings); + nioSocketChannel = mock(NioSocketChannel.class); + handler = new HttpReadWriteHandler(nioSocketChannel, transport, httpHandlingSettings, NamedXContentRegistry.EMPTY, threadContext); + } + + public void testSuccessfulDecodeHttpRequest() throws IOException { + String uri = "localhost:9090/" + randomAlphaOfLength(8); + HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); + + ByteBuf buf = requestEncoder.encode(httpRequest); + int slicePoint = randomInt(buf.writerIndex() - 1); + + ByteBuf slicedBuf = buf.retainedSlice(0, slicePoint); + ByteBuf slicedBuf2 = buf.retainedSlice(slicePoint, buf.writerIndex()); + handler.consumeReads(toChannelBuffer(slicedBuf)); + + verify(transport, times(0)).dispatchRequest(any(RestRequest.class), any(RestChannel.class)); + + handler.consumeReads(toChannelBuffer(slicedBuf2)); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(RestRequest.class); + verify(transport).dispatchRequest(requestCaptor.capture(), any(RestChannel.class)); + + NioHttpRequest nioHttpRequest = (NioHttpRequest) requestCaptor.getValue(); + FullHttpRequest nettyHttpRequest = nioHttpRequest.getRequest(); + assertEquals(httpRequest.protocolVersion(), nettyHttpRequest.protocolVersion()); + assertEquals(httpRequest.method(), nettyHttpRequest.method()); + } + + public void testDecodeHttpRequestError() throws IOException { + String uri = "localhost:9090/" + randomAlphaOfLength(8); + HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); + + ByteBuf buf = requestEncoder.encode(httpRequest); + buf.setByte(0, ' '); + buf.setByte(1, ' '); + buf.setByte(2, ' '); + + handler.consumeReads(toChannelBuffer(buf)); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Throwable.class); + verify(transport).dispatchBadRequest(any(RestRequest.class), any(RestChannel.class), exceptionCaptor.capture()); + + assertTrue(exceptionCaptor.getValue() instanceof IllegalArgumentException); + } + + public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() throws IOException { + String uri = "localhost:9090/" + randomAlphaOfLength(8); + HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, false); + HttpUtil.setContentLength(httpRequest, 1025); + HttpUtil.setKeepAlive(httpRequest, false); + + ByteBuf buf = requestEncoder.encode(httpRequest); + + handler.consumeReads(toChannelBuffer(buf)); + + verifyZeroInteractions(transport); + + List flushOperations = handler.pollFlushOperations(); + assertFalse(flushOperations.isEmpty()); + + FlushOperation flushOperation = flushOperations.get(0); + HttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite())); + assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion()); + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); + + flushOperation.getListener().accept(null, null); + // Since we have keep-alive set to false, we should close the channel after the response has been + // flushed + verify(nioSocketChannel).close(); + } + + @SuppressWarnings("unchecked") + public void testEncodeHttpResponse() throws IOException { + prepareHandlerForResponse(handler); + + FullHttpResponse fullHttpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + + SocketChannelContext context = mock(SocketChannelContext.class); + HttpWriteOperation writeOperation = new HttpWriteOperation(context, fullHttpResponse, mock(BiConsumer.class)); + List flushOperations = handler.writeToBytes(writeOperation); + + HttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperations.get(0).getBuffersToWrite())); + + assertEquals(HttpResponseStatus.OK, response.status()); + assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion()); + } + + private FullHttpRequest prepareHandlerForResponse(HttpReadWriteHandler adaptor) throws IOException { + HttpMethod method = HttpMethod.GET; + HttpVersion version = HttpVersion.HTTP_1_1; + String uri = "http://localhost:9090/" + randomAlphaOfLength(8); + + HttpRequest request = new DefaultFullHttpRequest(version, method, uri); + ByteBuf buf = requestEncoder.encode(request); + + handler.consumeReads(toChannelBuffer(buf)); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(RestRequest.class); + verify(transport).dispatchRequest(requestCaptor.capture(), any(RestChannel.class)); + + NioHttpRequest nioHttpRequest = (NioHttpRequest) requestCaptor.getValue(); + FullHttpRequest requestParsed = nioHttpRequest.getRequest(); + assertNotNull(requestParsed); + assertEquals(requestParsed.method(), method); + assertEquals(requestParsed.protocolVersion(), version); + assertEquals(requestParsed.uri(), uri); + return requestParsed; + } + + private InboundChannelBuffer toChannelBuffer(ByteBuf buf) { + InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance(); + int readableBytes = buf.readableBytes(); + buffer.ensureCapacity(readableBytes); + int bytesWritten = 0; + ByteBuffer[] byteBuffers = buffer.sliceBuffersTo(readableBytes); + int i = 0; + while (bytesWritten != readableBytes) { + ByteBuffer byteBuffer = byteBuffers[i++]; + int initialRemaining = byteBuffer.remaining(); + buf.readBytes(byteBuffer); + bytesWritten += initialRemaining - byteBuffer.remaining(); + } + buffer.incrementIndex(bytesWritten); + return buffer; + } + + private static class RequestEncoder { + + private final EmbeddedChannel requestEncoder = new EmbeddedChannel(new HttpRequestEncoder()); + + private ByteBuf encode(HttpRequest httpRequest) { + requestEncoder.writeOutbound(httpRequest); + return requestEncoder.readOutbound(); + } + } + + private static class ResponseDecoder { + + private final EmbeddedChannel responseDecoder = new EmbeddedChannel(new HttpResponseDecoder()); + + private HttpResponse decode(ByteBuf response) { + responseDecoder.writeInbound(response); + return responseDecoder.readInbound(); + } + } +} diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/Netty4HttpClient.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/Netty4HttpClient.java new file mode 100644 index 0000000000000..32f294f47ce9c --- /dev/null +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/Netty4HttpClient.java @@ -0,0 +1,200 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.nio; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpObject; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpRequestEncoder; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseDecoder; +import io.netty.handler.codec.http.HttpVersion; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; + +import java.io.Closeable; +import java.net.SocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static io.netty.handler.codec.http.HttpHeaderNames.HOST; +import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; + +/** + * Tiny helper to send http requests over netty. + */ +class Netty4HttpClient implements Closeable { + + static Collection returnHttpResponseBodies(Collection responses) { + List list = new ArrayList<>(responses.size()); + for (FullHttpResponse response : responses) { + list.add(response.content().toString(StandardCharsets.UTF_8)); + } + return list; + } + + static Collection returnOpaqueIds(Collection responses) { + List list = new ArrayList<>(responses.size()); + for (HttpResponse response : responses) { + list.add(response.headers().get("X-Opaque-Id")); + } + return list; + } + + private final Bootstrap clientBootstrap; + + Netty4HttpClient() { + clientBootstrap = new Bootstrap().channel(NioSocketChannel.class).group(new NioEventLoopGroup()); + } + + public Collection get(SocketAddress remoteAddress, String... uris) throws InterruptedException { + Collection requests = new ArrayList<>(uris.length); + for (int i = 0; i < uris.length; i++) { + final HttpRequest httpRequest = new DefaultFullHttpRequest(HTTP_1_1, HttpMethod.GET, uris[i]); + httpRequest.headers().add(HOST, "localhost"); + httpRequest.headers().add("X-Opaque-ID", String.valueOf(i)); + requests.add(httpRequest); + } + return sendRequests(remoteAddress, requests); + } + + @SafeVarargs // Safe not because it doesn't do anything with the type parameters but because it won't leak them into other methods. + public final Collection post(SocketAddress remoteAddress, Tuple... urisAndBodies) + throws InterruptedException { + return processRequestsWithBody(HttpMethod.POST, remoteAddress, urisAndBodies); + } + + public final FullHttpResponse post(SocketAddress remoteAddress, FullHttpRequest httpRequest) throws InterruptedException { + Collection responses = sendRequests(remoteAddress, Collections.singleton(httpRequest)); + assert responses.size() == 1 : "expected 1 and only 1 http response"; + return responses.iterator().next(); + } + + @SafeVarargs // Safe not because it doesn't do anything with the type parameters but because it won't leak them into other methods. + public final Collection put(SocketAddress remoteAddress, Tuple... urisAndBodies) + throws InterruptedException { + return processRequestsWithBody(HttpMethod.PUT, remoteAddress, urisAndBodies); + } + + private Collection processRequestsWithBody(HttpMethod method, SocketAddress remoteAddress, Tuple... urisAndBodies) throws InterruptedException { + Collection requests = new ArrayList<>(urisAndBodies.length); + for (Tuple uriAndBody : urisAndBodies) { + ByteBuf content = Unpooled.copiedBuffer(uriAndBody.v2(), StandardCharsets.UTF_8); + HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, method, uriAndBody.v1(), content); + request.headers().add(HttpHeaderNames.HOST, "localhost"); + request.headers().add(HttpHeaderNames.CONTENT_LENGTH, content.readableBytes()); + request.headers().add(HttpHeaderNames.CONTENT_TYPE, "application/json"); + requests.add(request); + } + return sendRequests(remoteAddress, requests); + } + + private synchronized Collection sendRequests( + final SocketAddress remoteAddress, + final Collection requests) throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(requests.size()); + final Collection content = Collections.synchronizedList(new ArrayList<>(requests.size())); + + clientBootstrap.handler(new CountDownLatchHandler(latch, content)); + + ChannelFuture channelFuture = null; + try { + channelFuture = clientBootstrap.connect(remoteAddress); + channelFuture.sync(); + + for (HttpRequest request : requests) { + channelFuture.channel().writeAndFlush(request); + } + latch.await(30, TimeUnit.SECONDS); + + } finally { + if (channelFuture != null) { + channelFuture.channel().close().sync(); + } + } + + return content; + } + + @Override + public void close() { + clientBootstrap.config().group().shutdownGracefully().awaitUninterruptibly(); + } + + /** + * helper factory which adds returned data to a list and uses a count down latch to decide when done + */ + private static class CountDownLatchHandler extends ChannelInitializer { + + private final CountDownLatch latch; + private final Collection content; + + CountDownLatchHandler(final CountDownLatch latch, final Collection content) { + this.latch = latch; + this.content = content; + } + + @Override + protected void initChannel(SocketChannel ch) throws Exception { + final int maxContentLength = new ByteSizeValue(100, ByteSizeUnit.MB).bytesAsInt(); + ch.pipeline().addLast(new HttpResponseDecoder()); + ch.pipeline().addLast(new HttpRequestEncoder()); + ch.pipeline().addLast(new HttpObjectAggregator(maxContentLength)); + ch.pipeline().addLast(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, HttpObject msg) throws Exception { + final FullHttpResponse response = (FullHttpResponse) msg; + content.add(response.copy()); + latch.countDown(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + super.exceptionCaught(ctx, cause); + latch.countDown(); + } + }); + } + + } + +} diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NettyAdaptorTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NettyAdaptorTests.java new file mode 100644 index 0000000000000..d6944a5f510e2 --- /dev/null +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NettyAdaptorTests.java @@ -0,0 +1,177 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.nio; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.SimpleChannelInboundHandler; +import org.elasticsearch.nio.FlushOperation; +import org.elasticsearch.test.ESTestCase; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Locale; +import java.util.concurrent.atomic.AtomicBoolean; + +public class NettyAdaptorTests extends ESTestCase { + + public void testBasicRead() { + TenIntsToStringsHandler handler = new TenIntsToStringsHandler(); + NettyAdaptor nettyAdaptor = new NettyAdaptor(handler); + ByteBuffer message = ByteBuffer.allocate(40); + for (int i = 0; i < 10; ++i) { + message.putInt(i); + } + message.flip(); + ByteBuffer[] buffers = {message}; + assertEquals(40, nettyAdaptor.read(buffers)); + assertEquals("0123456789", handler.result); + } + + public void testBasicReadWithExcessData() { + TenIntsToStringsHandler handler = new TenIntsToStringsHandler(); + NettyAdaptor nettyAdaptor = new NettyAdaptor(handler); + ByteBuffer message = ByteBuffer.allocate(52); + for (int i = 0; i < 13; ++i) { + message.putInt(i); + } + message.flip(); + ByteBuffer[] buffers = {message}; + assertEquals(40, nettyAdaptor.read(buffers)); + assertEquals("0123456789", handler.result); + } + + public void testUncaughtReadExceptionsBubbleUp() { + NettyAdaptor nettyAdaptor = new NettyAdaptor(new TenIntsToStringsHandler()); + ByteBuffer message = ByteBuffer.allocate(40); + for (int i = 0; i < 9; ++i) { + message.putInt(i); + } + message.flip(); + ByteBuffer[] buffers = {message}; + expectThrows(IllegalStateException.class, () -> nettyAdaptor.read(buffers)); + } + + public void testWriteInsidePipelineIsCaptured() { + TenIntsToStringsHandler tenIntsToStringsHandler = new TenIntsToStringsHandler(); + PromiseCheckerHandler promiseCheckerHandler = new PromiseCheckerHandler(); + NettyAdaptor nettyAdaptor = new NettyAdaptor(new CapitalizeWriteHandler(), + promiseCheckerHandler, + new WriteInMiddleHandler(), + tenIntsToStringsHandler); + byte[] bytes = "SHOULD_WRITE".getBytes(StandardCharsets.UTF_8); + ByteBuffer message = ByteBuffer.wrap(bytes); + ByteBuffer[] buffers = {message}; + assertNull(nettyAdaptor.pollOutboundOperation()); + nettyAdaptor.read(buffers); + assertFalse(tenIntsToStringsHandler.wasCalled); + FlushOperation flushOperation = nettyAdaptor.pollOutboundOperation(); + assertNotNull(flushOperation); + assertEquals("FAILED", Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite()).toString(StandardCharsets.UTF_8)); + assertFalse(promiseCheckerHandler.isCalled.get()); + flushOperation.getListener().accept(null, null); + assertTrue(promiseCheckerHandler.isCalled.get()); + } + + public void testCloseListener() { + AtomicBoolean listenerCalled = new AtomicBoolean(false); + CloseChannelHandler handler = new CloseChannelHandler(); + NettyAdaptor nettyAdaptor = new NettyAdaptor(handler); + byte[] bytes = "SHOULD_CLOSE".getBytes(StandardCharsets.UTF_8); + ByteBuffer[] buffers = {ByteBuffer.wrap(bytes)}; + nettyAdaptor.addCloseListener((v, e) -> listenerCalled.set(true)); + assertFalse(listenerCalled.get()); + nettyAdaptor.read(buffers); + assertTrue(listenerCalled.get()); + + } + + private class TenIntsToStringsHandler extends SimpleChannelInboundHandler { + + private String result; + boolean wasCalled = false; + + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + wasCalled = true; + if (msg.readableBytes() < 10 * 4) { + throw new IllegalStateException("Must have ten ints"); + } + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < 10; ++i) { + builder.append(msg.readInt()); + } + result = builder.toString(); + } + } + + private class WriteInMiddleHandler extends ChannelInboundHandlerAdapter { + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + ByteBuf buffer = (ByteBuf) msg; + String bufferString = buffer.toString(StandardCharsets.UTF_8); + if (bufferString.equals("SHOULD_WRITE")) { + ctx.writeAndFlush("Failed"); + } else { + throw new IllegalArgumentException("Only accept SHOULD_WRITE message"); + } + } + } + + private class CapitalizeWriteHandler extends ChannelOutboundHandlerAdapter { + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + String string = (String) msg; + assert string.equals("Failed") : "Should be the same was what we wrote."; + super.write(ctx, Unpooled.wrappedBuffer(string.toUpperCase(Locale.ROOT).getBytes(StandardCharsets.UTF_8)), promise); + } + } + + private class PromiseCheckerHandler extends ChannelOutboundHandlerAdapter { + + private AtomicBoolean isCalled = new AtomicBoolean(false); + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + promise.addListener((f) -> isCalled.set(true)); + super.write(ctx, msg, promise); + } + } + + private class CloseChannelHandler extends ChannelInboundHandlerAdapter { + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + ByteBuf buffer = (ByteBuf) msg; + String bufferString = buffer.toString(StandardCharsets.UTF_8); + if (bufferString.equals("SHOULD_CLOSE")) { + ctx.close(); + } else { + throw new IllegalArgumentException("Only accept SHOULD_CLOSE message"); + } + } + } +} diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java new file mode 100644 index 0000000000000..4741bd69a527a --- /dev/null +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java @@ -0,0 +1,353 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.nio; + +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpVersion; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.network.NetworkService; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.MockPageCacheRecycler; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.http.BindHttpException; +import org.elasticsearch.http.HttpServerTransport; +import org.elasticsearch.http.HttpTransportSettings; +import org.elasticsearch.http.NullDispatcher; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.rest.BytesRestResponse; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.rest.RestStatus.BAD_REQUEST; +import static org.elasticsearch.rest.RestStatus.OK; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +/** + * Tests for the {@link NioHttpServerTransport} class. + */ +public class NioHttpServerTransportTests extends ESTestCase { + + private NetworkService networkService; + private ThreadPool threadPool; + private MockBigArrays bigArrays; + + @Before + public void setup() throws Exception { + networkService = new NetworkService(Collections.emptyList()); + threadPool = new TestThreadPool("test"); + bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); + } + + @After + public void shutdown() throws Exception { + if (threadPool != null) { + threadPool.shutdownNow(); + } + threadPool = null; + networkService = null; + bigArrays = null; + } + +// public void testCorsConfig() { +// final Set methods = new HashSet<>(Arrays.asList("get", "options", "post")); +// final Set headers = new HashSet<>(Arrays.asList("Content-Type", "Content-Length")); +// final String prefix = randomBoolean() ? " " : ""; // sometimes have a leading whitespace between comma delimited elements +// final Settings settings = Settings.builder() +// .put(SETTING_CORS_ENABLED.getKey(), true) +// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*") +// .put(SETTING_CORS_ALLOW_METHODS.getKey(), collectionToDelimitedString(methods, ",", prefix, "")) +// .put(SETTING_CORS_ALLOW_HEADERS.getKey(), collectionToDelimitedString(headers, ",", prefix, "")) +// .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) +// .build(); +// final Netty4CorsConfig corsConfig = Netty4HttpServerTransport.buildCorsConfig(settings); +// assertTrue(corsConfig.isAnyOriginSupported()); +// assertEquals(headers, corsConfig.allowedRequestHeaders()); +// assertEquals(methods, corsConfig.allowedRequestMethods().stream().map(HttpMethod::name).collect(Collectors.toSet())); +// } + +// public void testCorsConfigWithDefaults() { +// final Set methods = Strings.commaDelimitedListToSet(SETTING_CORS_ALLOW_METHODS.getDefault(Settings.EMPTY)); +// final Set headers = Strings.commaDelimitedListToSet(SETTING_CORS_ALLOW_HEADERS.getDefault(Settings.EMPTY)); +// final long maxAge = SETTING_CORS_MAX_AGE.getDefault(Settings.EMPTY); +// final Settings settings = Settings.builder().put(SETTING_CORS_ENABLED.getKey(), true).build(); +// final Netty4CorsConfig corsConfig = Netty4HttpServerTransport.buildCorsConfig(settings); +// assertFalse(corsConfig.isAnyOriginSupported()); +// assertEquals(Collections.emptySet(), corsConfig.origins().get()); +// assertEquals(headers, corsConfig.allowedRequestHeaders()); +// assertEquals(methods, corsConfig.allowedRequestMethods().stream().map(HttpMethod::name).collect(Collectors.toSet())); +// assertEquals(maxAge, corsConfig.maxAge()); +// assertFalse(corsConfig.isCredentialsAllowed()); +// } + + /** + * Test that {@link NioHttpServerTransport} supports the "Expect: 100-continue" HTTP header + * @throws InterruptedException if the client communication with the server is interrupted + */ + public void testExpectContinueHeader() throws InterruptedException { + final Settings settings = Settings.EMPTY; + final int contentLength = randomIntBetween(1, HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH.get(settings).bytesAsInt()); + runExpectHeaderTest(settings, HttpHeaderValues.CONTINUE.toString(), contentLength, HttpResponseStatus.CONTINUE); + } + + /** + * Test that {@link NioHttpServerTransport} responds to a + * 100-continue expectation with too large a content-length + * with a 413 status. + * @throws InterruptedException if the client communication with the server is interrupted + */ + public void testExpectContinueHeaderContentLengthTooLong() throws InterruptedException { + final String key = HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH.getKey(); + final int maxContentLength = randomIntBetween(1, 104857600); + final Settings settings = Settings.builder().put(key, maxContentLength + "b").build(); + final int contentLength = randomIntBetween(maxContentLength + 1, Integer.MAX_VALUE); + runExpectHeaderTest( + settings, HttpHeaderValues.CONTINUE.toString(), contentLength, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); + } + + /** + * Test that {@link NioHttpServerTransport} responds to an unsupported expectation with a 417 status. + * @throws InterruptedException if the client communication with the server is interrupted + */ + public void testExpectUnsupportedExpectation() throws InterruptedException { + runExpectHeaderTest(Settings.EMPTY, "chocolate=yummy", 0, HttpResponseStatus.EXPECTATION_FAILED); + } + + private void runExpectHeaderTest( + final Settings settings, + final String expectation, + final int contentLength, + final HttpResponseStatus expectedStatus) throws InterruptedException { + final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { + @Override + public void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) { + channel.sendResponse(new BytesRestResponse(OK, BytesRestResponse.TEXT_CONTENT_TYPE, new BytesArray("done"))); + } + + @Override + public void dispatchBadRequest(RestRequest request, RestChannel channel, ThreadContext threadContext, Throwable cause) { + throw new AssertionError(); + } + }; + try (NioHttpServerTransport transport = new NioHttpServerTransport(settings, networkService, bigArrays, threadPool, + xContentRegistry(), dispatcher)) { + transport.start(); + final TransportAddress remoteAddress = randomFrom(transport.boundAddress().boundAddresses()); + try (Netty4HttpClient client = new Netty4HttpClient()) { + final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); + request.headers().set(HttpHeaderNames.EXPECT, expectation); + HttpUtil.setContentLength(request, contentLength); + + final FullHttpResponse response = client.post(remoteAddress.address(), request); + assertThat(response.status(), equalTo(expectedStatus)); + if (expectedStatus.equals(HttpResponseStatus.CONTINUE)) { + final FullHttpRequest continuationRequest = + new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/", Unpooled.EMPTY_BUFFER); + final FullHttpResponse continuationResponse = client.post(remoteAddress.address(), continuationRequest); + + assertThat(continuationResponse.status(), is(HttpResponseStatus.OK)); + assertThat(new String(ByteBufUtil.getBytes(continuationResponse.content()), StandardCharsets.UTF_8), is("done")); + } + } + } + } + + public void testBindUnavailableAddress() { + try (NioHttpServerTransport transport = new NioHttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, + xContentRegistry(), new NullDispatcher())) { + transport.start(); + TransportAddress remoteAddress = randomFrom(transport.boundAddress().boundAddresses()); + Settings settings = Settings.builder().put("http.port", remoteAddress.getPort()).build(); + try (NioHttpServerTransport otherTransport = new NioHttpServerTransport(settings, networkService, bigArrays, threadPool, + xContentRegistry(), new NullDispatcher())) { + BindHttpException bindHttpException = expectThrows(BindHttpException.class, () -> otherTransport.start()); + assertEquals("Failed to bind to [" + remoteAddress.getPort() + "]", bindHttpException.getMessage()); + } + } + } + + public void testBadRequest() throws InterruptedException { + final AtomicReference causeReference = new AtomicReference<>(); + final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { + + @Override + public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { + throw new AssertionError(); + } + + @Override + public void dispatchBadRequest(final RestRequest request, + final RestChannel channel, + final ThreadContext threadContext, + final Throwable cause) { + causeReference.set(cause); + try { + final ElasticsearchException e = new ElasticsearchException("you sent a bad request and you should feel bad"); + channel.sendResponse(new BytesRestResponse(channel, BAD_REQUEST, e)); + } catch (final IOException e) { + throw new AssertionError(e); + } + } + + }; + + final Settings settings; + final int maxInitialLineLength; + final Setting httpMaxInitialLineLengthSetting = HttpTransportSettings.SETTING_HTTP_MAX_INITIAL_LINE_LENGTH; + if (randomBoolean()) { + maxInitialLineLength = httpMaxInitialLineLengthSetting.getDefault(Settings.EMPTY).bytesAsInt(); + settings = Settings.EMPTY; + } else { + maxInitialLineLength = randomIntBetween(1, 8192); + settings = Settings.builder().put(httpMaxInitialLineLengthSetting.getKey(), maxInitialLineLength + "b").build(); + } + + try (NioHttpServerTransport transport = + new NioHttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) { + transport.start(); + final TransportAddress remoteAddress = randomFrom(transport.boundAddress().boundAddresses()); + + try (Netty4HttpClient client = new Netty4HttpClient()) { + final String url = "/" + new String(new byte[maxInitialLineLength], Charset.forName("UTF-8")); + final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, url); + + final FullHttpResponse response = client.post(remoteAddress.address(), request); + assertThat(response.status(), equalTo(HttpResponseStatus.BAD_REQUEST)); + assertThat( + new String(response.content().array(), Charset.forName("UTF-8")), + containsString("you sent a bad request and you should feel bad")); + } + } + + assertNotNull(causeReference.get()); + assertThat(causeReference.get(), instanceOf(TooLongFrameException.class)); + } + + public void testDispatchDoesNotModifyThreadContext() throws InterruptedException { + final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { + + @Override + public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { + threadContext.putHeader("foo", "bar"); + threadContext.putTransient("bar", "baz"); + } + + @Override + public void dispatchBadRequest(final RestRequest request, + final RestChannel channel, + final ThreadContext threadContext, + final Throwable cause) { + threadContext.putHeader("foo_bad", "bar"); + threadContext.putTransient("bar_bad", "baz"); + } + + }; + + try (NioHttpServerTransport transport = + new NioHttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) { + transport.start(); + + transport.dispatchRequest(null, null); + assertNull(threadPool.getThreadContext().getHeader("foo")); + assertNull(threadPool.getThreadContext().getTransient("bar")); + + transport.dispatchBadRequest(null, null, null); + assertNull(threadPool.getThreadContext().getHeader("foo_bad")); + assertNull(threadPool.getThreadContext().getTransient("bar_bad")); + } + } + +// public void testReadTimeout() throws Exception { +// final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { +// +// @Override +// public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { +// throw new AssertionError("Should not have received a dispatched request"); +// } +// +// @Override +// public void dispatchBadRequest(final RestRequest request, +// final RestChannel channel, +// final ThreadContext threadContext, +// final Throwable cause) { +// throw new AssertionError("Should not have received a dispatched request"); +// } +// +// }; +// +// Settings settings = Settings.builder() +// .put(HttpTransportSettings.SETTING_HTTP_READ_TIMEOUT.getKey(), new TimeValue(randomIntBetween(100, 300))) +// .build(); +// +// +// NioEventLoopGroup group = new NioEventLoopGroup(); +// try (NioHttpServerTransport transport = +// new NioHttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) { +// transport.start(); +// final TransportAddress remoteAddress = randomFrom(transport.boundAddress().boundAddresses()); +// +// AtomicBoolean channelClosed = new AtomicBoolean(false); +// +// Bootstrap clientBootstrap = new Bootstrap().channel(NioSocketChannel.class).handler(new ChannelInitializer() { +// +// @Override +// protected void initChannel(SocketChannel ch) { +// ch.pipeline().addLast(new ChannelHandlerAdapter() {}); +// +// } +// }).group(group); +// ChannelFuture connect = clientBootstrap.connect(remoteAddress.address()); +// connect.channel().closeFuture().addListener(future -> channelClosed.set(true)); +// +// assertBusy(() -> assertTrue("Channel should be closed due to read timeout", channelClosed.get()), 5, TimeUnit.SECONDS); +// +// } finally { +// group.shutdownGracefully().await(); +// } +// } +} diff --git a/server/src/main/java/org/elasticsearch/ExceptionsHelper.java b/server/src/main/java/org/elasticsearch/ExceptionsHelper.java index 0427685b8ef4f..dff14bc8b393b 100644 --- a/server/src/main/java/org/elasticsearch/ExceptionsHelper.java +++ b/server/src/main/java/org/elasticsearch/ExceptionsHelper.java @@ -242,6 +242,35 @@ public static boolean reThrowIfNotNull(@Nullable Throwable e) { return true; } + /** + * If the specified cause is an unrecoverable error, this method will rethrow the cause on a separate thread so that it can not be + * caught and bubbles up to the uncaught exception handler. + * + * @param throwable the throwable to test + */ + public static void dieOnError(Throwable throwable) { + final Optional maybeError = ExceptionsHelper.maybeError(throwable, logger); + if (maybeError.isPresent()) { + /* + * Here be dragons. We want to rethrow this so that it bubbles up to the uncaught exception handler. Yet, Netty wraps too many + * invocations of user-code in try/catch blocks that swallow all throwables. This means that a rethrow here will not bubble up + * to where we want it to. So, we fork a thread and throw the exception from there where Netty can not get to it. We do not wrap + * the exception so as to not lose the original cause during exit. + */ + try { + // try to log the current stack trace + final String formatted = ExceptionsHelper.formatStackTrace(Thread.currentThread().getStackTrace()); + logger.error("fatal error\n{}", formatted); + } finally { + new Thread( + () -> { + throw maybeError.get(); + }) + .start(); + } + } + } + /** * Deduplicate the failures by exception message and index. */ diff --git a/server/src/main/java/org/elasticsearch/http/HttpHandlingSettings.java b/server/src/main/java/org/elasticsearch/http/HttpHandlingSettings.java new file mode 100644 index 0000000000000..f86049292f3fd --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/HttpHandlingSettings.java @@ -0,0 +1,76 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +public class HttpHandlingSettings { + + private final int maxContentLength; + private final int maxChunkSize; + private final int maxHeaderSize; + private final int maxInitialLineLength; + private final boolean resetCookies; + private final boolean compression; + private final int compressionLevel; + private final boolean detailedErrorsEnabled; + + public HttpHandlingSettings(int maxContentLength, int maxChunkSize, int maxHeaderSize, int maxInitialLineLength, + boolean resetCookies, boolean compression, int compressionLevel, boolean detailedErrorsEnabled) { + this.maxContentLength = maxContentLength; + this.maxChunkSize = maxChunkSize; + this.maxHeaderSize = maxHeaderSize; + this.maxInitialLineLength = maxInitialLineLength; + this.resetCookies = resetCookies; + this.compression = compression; + this.compressionLevel = compressionLevel; + this.detailedErrorsEnabled = detailedErrorsEnabled; + } + + public int getMaxContentLength() { + return maxContentLength; + } + + public int getMaxChunkSize() { + return maxChunkSize; + } + + public int getMaxHeaderSize() { + return maxHeaderSize; + } + + public int getMaxInitialLineLength() { + return maxInitialLineLength; + } + + public boolean isResetCookies() { + return resetCookies; + } + + public boolean isCompression() { + return compression; + } + + public int getCompressionLevel() { + return compressionLevel; + } + + public boolean getDetailedErrorsEnabled() { + return detailedErrorsEnabled; + } +} diff --git a/server/src/main/java/org/elasticsearch/http/HttpServerTransport.java b/server/src/main/java/org/elasticsearch/http/HttpServerTransport.java index b5a720e0160bc..de345a39fd6d9 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpServerTransport.java +++ b/server/src/main/java/org/elasticsearch/http/HttpServerTransport.java @@ -62,5 +62,4 @@ interface Dispatcher { void dispatchBadRequest(RestRequest request, RestChannel channel, ThreadContext threadContext, Throwable cause); } - } diff --git a/server/src/main/java/org/elasticsearch/http/netty4/AbstractHttpServerTransport.java b/server/src/main/java/org/elasticsearch/http/netty4/AbstractHttpServerTransport.java new file mode 100644 index 0000000000000..a0b3632310b24 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/netty4/AbstractHttpServerTransport.java @@ -0,0 +1,174 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.netty4; + +import com.carrotsearch.hppc.IntHashSet; +import com.carrotsearch.hppc.IntSet; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.component.AbstractLifecycleComponent; +import org.elasticsearch.common.network.NetworkService; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.BoundTransportAddress; +import org.elasticsearch.common.transport.PortsRange; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.http.BindHttpException; +import org.elasticsearch.http.HttpInfo; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.BindTransportException; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_BIND_HOST; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PORT; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_HOST; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_PORT; + +public abstract class AbstractHttpServerTransport extends AbstractLifecycleComponent implements org.elasticsearch.http.HttpServerTransport { + + protected final NetworkService networkService; + protected final ThreadPool threadPool; + protected final Dispatcher dispatcher; + + protected final String[] bindHosts; + protected final String[] publishHosts; + protected final PortsRange port; + protected final ByteSizeValue maxContentLength; + + protected volatile BoundTransportAddress boundAddress; + + protected AbstractHttpServerTransport(Settings settings, NetworkService networkService, ThreadPool threadPool, Dispatcher dispatcher) { + super(settings); + this.networkService = networkService; + this.threadPool = threadPool; + this.dispatcher = dispatcher; + + // we can't make the network.bind_host a fallback since we already fall back to http.host hence the extra conditional here + List httpBindHost = SETTING_HTTP_BIND_HOST.get(settings); + this.bindHosts = (httpBindHost.isEmpty() ? NetworkService.GLOBAL_NETWORK_BINDHOST_SETTING.get(settings) : httpBindHost) + .toArray(Strings.EMPTY_ARRAY); + // we can't make the network.publish_host a fallback since we already fall back to http.host hence the extra conditional here + List httpPublishHost = SETTING_HTTP_PUBLISH_HOST.get(settings); + this.publishHosts = (httpPublishHost.isEmpty() ? NetworkService.GLOBAL_NETWORK_PUBLISHHOST_SETTING.get(settings) : httpPublishHost) + .toArray(Strings.EMPTY_ARRAY); + + this.port = SETTING_HTTP_PORT.get(settings); + + this.maxContentLength = SETTING_HTTP_MAX_CONTENT_LENGTH.get(settings); + } + + @Override + public BoundTransportAddress boundAddress() { + return this.boundAddress; + } + + @Override + public HttpInfo info() { + BoundTransportAddress boundTransportAddress = boundAddress(); + if (boundTransportAddress == null) { + return null; + } + return new HttpInfo(boundTransportAddress, maxContentLength.getBytes()); + } + + protected BoundTransportAddress createBoundHttpAddress() { + // Bind and start to accept incoming connections. + InetAddress hostAddresses[]; + try { + hostAddresses = networkService.resolveBindHostAddresses(bindHosts); + } catch (IOException e) { + throw new BindHttpException("Failed to resolve host [" + Arrays.toString(bindHosts) + "]", e); + } + + List boundAddresses = new ArrayList<>(hostAddresses.length); + for (InetAddress address : hostAddresses) { + boundAddresses.add(bindAddress(address)); + } + + final InetAddress publishInetAddress; + try { + publishInetAddress = networkService.resolvePublishHostAddresses(publishHosts); + } catch (Exception e) { + throw new BindTransportException("Failed to resolve publish address", e); + } + + final int publishPort = resolvePublishPort(settings, boundAddresses, publishInetAddress); + final InetSocketAddress publishAddress = new InetSocketAddress(publishInetAddress, publishPort); + return new BoundTransportAddress(boundAddresses.toArray(new TransportAddress[0]), new TransportAddress(publishAddress)); + } + + protected abstract TransportAddress bindAddress(InetAddress hostAddress); + + // package private for tests + static int resolvePublishPort(Settings settings, List boundAddresses, InetAddress publishInetAddress) { + int publishPort = SETTING_HTTP_PUBLISH_PORT.get(settings); + + if (publishPort < 0) { + for (TransportAddress boundAddress : boundAddresses) { + InetAddress boundInetAddress = boundAddress.address().getAddress(); + if (boundInetAddress.isAnyLocalAddress() || boundInetAddress.equals(publishInetAddress)) { + publishPort = boundAddress.getPort(); + break; + } + } + } + + // if no matching boundAddress found, check if there is a unique port for all bound addresses + if (publishPort < 0) { + final IntSet ports = new IntHashSet(); + for (TransportAddress boundAddress : boundAddresses) { + ports.add(boundAddress.getPort()); + } + if (ports.size() == 1) { + publishPort = ports.iterator().next().value; + } + } + + if (publishPort < 0) { + throw new BindHttpException("Failed to auto-resolve http publish port, multiple bound addresses " + boundAddresses + + " with distinct ports and none of them matched the publish address (" + publishInetAddress + "). " + + "Please specify a unique port by setting " + SETTING_HTTP_PORT.getKey() + " or " + SETTING_HTTP_PUBLISH_PORT.getKey()); + } + return publishPort; + } + + public void dispatchRequest(final RestRequest request, final RestChannel channel) { + final ThreadContext threadContext = threadPool.getThreadContext(); + try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + dispatcher.dispatchRequest(request, channel, threadContext); + } + } + + public void dispatchBadRequest(final RestRequest request, final RestChannel channel, final Throwable cause) { + final ThreadContext threadContext = threadPool.getThreadContext(); + try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + dispatcher.dispatchBadRequest(request, channel, threadContext, cause); + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java index 5271ac6a14837..36e282f32959d 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java @@ -32,13 +32,13 @@ import org.elasticsearch.nio.AcceptingSelector; import org.elasticsearch.nio.AcceptorEventHandler; import org.elasticsearch.nio.BytesChannelContext; +import org.elasticsearch.nio.BytesWriteHandler; import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioServerSocketChannel; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ServerChannelContext; -import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TcpChannel; @@ -162,10 +162,9 @@ public MockSocketChannel createChannel(SocketSelector selector, SocketChannel ch Recycler.V bytes = pageCacheRecycler.bytePage(false); return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; - SocketChannelContext.ReadConsumer nioReadConsumer = channelBuffer -> - consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex()))); + MockTcpReadWriteHandler readWriteHandler = new MockTcpReadWriteHandler(nioChannel, MockNioTransport.this); BytesChannelContext context = new BytesChannelContext(nioChannel, selector, (e) -> exceptionCaught(nioChannel, e), - nioReadConsumer, new InboundChannelBuffer(pageSupplier)); + readWriteHandler, new InboundChannelBuffer(pageSupplier)); nioChannel.setContext(context); return nioChannel; } @@ -180,6 +179,23 @@ public MockServerChannel createServerChannel(AcceptingSelector selector, ServerS } } + private static class MockTcpReadWriteHandler extends BytesWriteHandler { + + private final MockSocketChannel channel; + private final TcpTransport transport; + + private MockTcpReadWriteHandler(MockSocketChannel channel, TcpTransport transport) { + this.channel = channel; + this.transport = transport; + } + + @Override + public int consumeReads(InboundChannelBuffer channelBuffer) throws IOException { + BytesReference bytesReference = BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())); + return transport.consumeNetworkReads(channel, bytesReference); + } + } + private static class MockServerChannel extends NioServerSocketChannel implements TcpChannel { private final String profile; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java index 5b4543ccaf275..01916b9138031 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java @@ -104,7 +104,7 @@ public ChannelHandler configureServerChannelHandler() { private final class HttpSslChannelHandler extends HttpChannelHandler { HttpSslChannelHandler() { - super(SecurityNetty4HttpServerTransport.this, detailedErrorsEnabled, threadPool.getThreadContext()); + super(SecurityNetty4HttpServerTransport.this, httpHandlingSettings, threadPool.getThreadContext()); } @Override diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java index a4e88ec70f203..075e68183933f 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java @@ -5,43 +5,33 @@ */ package org.elasticsearch.xpack.security.transport.nio; -import org.elasticsearch.nio.BytesWriteOperation; +import org.elasticsearch.core.internal.io.IOUtils; +import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.ReadWriteHandler; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.nio.WriteOperation; -import org.elasticsearch.nio.utils.ExceptionsHelper; import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.ClosedChannelException; -import java.util.ArrayList; -import java.util.LinkedList; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import java.util.function.Consumer; /** * Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake * with the peer channel. Once the handshake is complete, any data from the peer channel will be decrypted - * before being passed to the {@link ReadConsumer}. Outbound data will - * be encrypted before being flushed to the channel. + * before being passed to the {@link ReadWriteHandler}. Outbound data will be encrypted before being flushed + * to the channel. */ public final class SSLChannelContext extends SocketChannelContext { - private final LinkedList queued = new LinkedList<>(); private final SSLDriver sslDriver; - private final ReadConsumer readConsumer; - private final InboundChannelBuffer buffer; - private final AtomicBoolean isClosing = new AtomicBoolean(false); SSLChannelContext(NioSocketChannel channel, SocketSelector selector, Consumer exceptionHandler, SSLDriver sslDriver, - ReadConsumer readConsumer, InboundChannelBuffer buffer) { - super(channel, selector, exceptionHandler); + ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer) { + super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); this.sslDriver = sslDriver; - this.readConsumer = readConsumer; - this.buffer = buffer; } @Override @@ -50,32 +40,13 @@ public void register() throws IOException { sslDriver.init(); } - @Override - public void sendMessage(ByteBuffer[] buffers, BiConsumer listener) { - if (isClosing.get()) { - listener.accept(null, new ClosedChannelException()); - return; - } - - BytesWriteOperation writeOperation = new BytesWriteOperation(this, buffers, listener); - SocketSelector selector = getSelector(); - if (selector.isOnCurrentThread() == false) { - // If this message is being sent from another thread, we queue the write to be handled by the - // network thread - selector.queueWrite(writeOperation); - return; - } - - selector.queueWriteInChannelBuffer(writeOperation); - } - @Override public void queueWriteOperation(WriteOperation writeOperation) { getSelector().assertOnSelectorThread(); if (writeOperation instanceof CloseNotifyOperation) { sslDriver.initiateClose(); } else { - queued.add((BytesWriteOperation) writeOperation); + super.queueWriteOperation(writeOperation); } } @@ -96,28 +67,25 @@ public void flushChannel() throws IOException { // If the driver is ready for application writes, we can attempt to proceed with any queued writes. if (sslDriver.readyForApplicationWrites()) { - BytesWriteOperation currentOperation = queued.peekFirst(); - while (sslDriver.hasFlushPending() == false && currentOperation != null) { + FlushOperation currentFlush; + while (sslDriver.hasFlushPending() == false && (currentFlush = getPendingFlush()) != null) { // If the current operation has been fully consumed (encrypted) we now know that it has been // sent (as we only get to this point if the write buffer has been fully flushed). - if (currentOperation.isFullyFlushed()) { - queued.removeFirst(); - getSelector().executeListener(currentOperation.getListener(), null); - currentOperation = queued.peekFirst(); + if (currentFlush.isFullyFlushed()) { + currentFlushOperationComplete(); } else { try { // Attempt to encrypt application write data. The encrypted data ends up in the // outbound write buffer. - int bytesEncrypted = sslDriver.applicationWrite(currentOperation.getBuffersToWrite()); + int bytesEncrypted = sslDriver.applicationWrite(currentFlush.getBuffersToWrite()); if (bytesEncrypted == 0) { break; } - currentOperation.incrementIndex(bytesEncrypted); + currentFlush.incrementIndex(bytesEncrypted); // Flush the write buffer to the channel flushToChannel(sslDriver.getNetworkWriteBuffer()); } catch (IOException e) { - queued.removeFirst(); - getSelector().executeFailedListener(currentOperation.getListener(), e); + currentFlushOperationFailed(e); throw e; } } @@ -136,10 +104,10 @@ public void flushChannel() throws IOException { } @Override - public boolean hasQueuedWriteOps() { + public boolean readyForFlush() { getSelector().assertOnSelectorThread(); if (sslDriver.readyForApplicationWrites()) { - return sslDriver.hasFlushPending() || queued.isEmpty() == false; + return sslDriver.hasFlushPending() || super.readyForFlush(); } else { return sslDriver.hasFlushPending() || sslDriver.needsNonApplicationWrite(); } @@ -156,13 +124,9 @@ public int read() throws IOException { return bytesRead; } - sslDriver.read(buffer); + sslDriver.read(channelBuffer); - int bytesConsumed = Integer.MAX_VALUE; - while (bytesConsumed > 0 && buffer.getIndex() > 0) { - bytesConsumed = readConsumer.consumeReads(buffer); - buffer.release(bytesConsumed); - } + handleReadBytes(); return bytesRead; } @@ -189,31 +153,14 @@ public void closeChannel() { public void closeFromSelector() throws IOException { getSelector().assertOnSelectorThread(); if (channel.isOpen()) { - // Set to true in order to reject new writes before queuing with selector - isClosing.set(true); - ArrayList closingExceptions = new ArrayList<>(2); - try { - super.closeFromSelector(); - } catch (IOException e) { - closingExceptions.add(e); - } - try { - buffer.close(); - for (BytesWriteOperation op : queued) { - getSelector().executeFailedListener(op.getListener(), new ClosedChannelException()); - } - queued.clear(); - sslDriver.close(); - } catch (IOException e) { - closingExceptions.add(e); - } - ExceptionsHelper.rethrowAndSuppress(closingExceptions); + IOUtils.close(super::closeFromSelector, sslDriver::close); } } private static class CloseNotifyOperation implements WriteOperation { private static final BiConsumer LISTENER = (v, t) -> {}; + private static final Object WRITE_OBJECT = new Object(); private final SocketChannelContext channelContext; private CloseNotifyOperation(SocketChannelContext channelContext) { @@ -229,5 +176,10 @@ public BiConsumer getListener() { public SocketChannelContext getChannel() { return channelContext; } + + @Override + public Object getObject() { + return WRITE_OBJECT; + } } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java index 7773404762eb1..0f511af6b57d8 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.security.transport.nio; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.recycler.Recycler; @@ -17,20 +16,19 @@ import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ServerChannelContext; -import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.nio.NioTransport; import org.elasticsearch.transport.nio.TcpNioServerSocketChannel; import org.elasticsearch.transport.nio.TcpNioSocketChannel; +import org.elasticsearch.transport.nio.TcpReadWriteHandler; import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.security.transport.netty4.SecurityNetty4Transport; import org.elasticsearch.xpack.core.ssl.SSLConfiguration; import org.elasticsearch.xpack.core.ssl.SSLService; import javax.net.ssl.SSLEngine; - import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ServerSocketChannel; @@ -128,12 +126,10 @@ public TcpNioSocketChannel createChannel(SocketSelector selector, SocketChannel return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; - SocketChannelContext.ReadConsumer nioReadConsumer = channelBuffer -> - consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex()))); + TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this); InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); Consumer exceptionHandler = (e) -> exceptionCaught(nioChannel, e); - SSLChannelContext context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, nioReadConsumer, - buffer); + SSLChannelContext context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, buffer); nioChannel.setContext(context); return nioChannel; } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java index 884b348721fc5..fc501c68922e5 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java @@ -5,11 +5,12 @@ */ package org.elasticsearch.xpack.security.transport.nio; +import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.nio.BytesWriteOperation; +import org.elasticsearch.nio.BytesWriteHandler; +import org.elasticsearch.nio.FlushReadyWrite; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketSelector; import org.elasticsearch.nio.WriteOperation; import org.elasticsearch.test.ESTestCase; @@ -19,16 +20,12 @@ import java.io.IOException; import java.nio.ByteBuffer; -import java.nio.channels.ClosedChannelException; import java.nio.channels.Selector; import java.nio.channels.SocketChannel; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; import java.util.function.Consumer; -import java.util.function.Supplier; import static org.mockito.Matchers.any; -import static org.mockito.Matchers.isNull; import static org.mockito.Matchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -38,7 +35,7 @@ public class SSLChannelContextTests extends ESTestCase { - private SocketChannelContext.ReadConsumer readConsumer; + private CheckedFunction readConsumer; private NioSocketChannel channel; private SocketChannel rawChannel; private SSLChannelContext context; @@ -54,7 +51,8 @@ public class SSLChannelContextTests extends ESTestCase { @Before @SuppressWarnings("unchecked") public void init() { - readConsumer = mock(SocketChannelContext.ReadConsumer.class); + readConsumer = mock(CheckedFunction.class); + TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer); messageLength = randomInt(96) + 20; selector = mock(SocketSelector.class); @@ -65,7 +63,7 @@ public void init() { channelBuffer = InboundChannelBuffer.allocatingInstance(); when(channel.getRawChannel()).thenReturn(rawChannel); exceptionHandler = mock(Consumer.class); - context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, channelBuffer); + context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); when(selector.isOnCurrentThread()).thenReturn(true); when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer); @@ -78,13 +76,13 @@ public void testSuccessfulRead() throws IOException { when(rawChannel.read(same(readBuffer))).thenReturn(bytes.length); doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); - when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, 0); + when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, 0); assertEquals(messageLength, context.read()); assertEquals(0, channelBuffer.getIndex()); assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity()); - verify(readConsumer, times(1)).consumeReads(channelBuffer); + verify(readConsumer, times(1)).apply(channelBuffer); } public void testMultipleReadsConsumed() throws IOException { @@ -93,13 +91,13 @@ public void testMultipleReadsConsumed() throws IOException { when(rawChannel.read(same(readBuffer))).thenReturn(bytes.length); doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); - when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, messageLength, 0); + when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, messageLength, 0); assertEquals(bytes.length, context.read()); assertEquals(0, channelBuffer.getIndex()); assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity()); - verify(readConsumer, times(2)).consumeReads(channelBuffer); + verify(readConsumer, times(2)).apply(channelBuffer); } public void testPartialRead() throws IOException { @@ -109,20 +107,20 @@ public void testPartialRead() throws IOException { doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); - when(readConsumer.consumeReads(channelBuffer)).thenReturn(0); + when(readConsumer.apply(channelBuffer)).thenReturn(0); assertEquals(messageLength, context.read()); assertEquals(bytes.length, channelBuffer.getIndex()); - verify(readConsumer, times(1)).consumeReads(channelBuffer); + verify(readConsumer, times(1)).apply(channelBuffer); - when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength * 2, 0); + when(readConsumer.apply(channelBuffer)).thenReturn(messageLength * 2, 0); assertEquals(messageLength, context.read()); assertEquals(0, channelBuffer.getIndex()); assertEquals(BigArrays.BYTE_PAGE_SIZE - (bytes.length * 2), channelBuffer.getCapacity()); - verify(readConsumer, times(2)).consumeReads(channelBuffer); + verify(readConsumer, times(2)).apply(channelBuffer); } public void testReadThrowsIOException() throws IOException { @@ -149,50 +147,12 @@ public void testReadLessThanZeroMeansReadyForClose() throws IOException { assertTrue(context.selectorShouldClose()); } - @SuppressWarnings("unchecked") - public void testCloseClosesChannelBuffer() throws IOException { - try (SocketChannel realChannel = SocketChannel.open()) { - when(channel.getRawChannel()).thenReturn(realChannel); - - AtomicInteger closeCount = new AtomicInteger(0); - Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), - closeCount::incrementAndGet); - InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); - buffer.ensureCapacity(1); - SSLChannelContext context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, buffer); - when(channel.isOpen()).thenReturn(true); - context.closeFromSelector(); - assertEquals(1, closeCount.get()); - } - } - - @SuppressWarnings("unchecked") - public void testWriteOpsClearedOnClose() throws IOException { - try (SocketChannel realChannel = SocketChannel.open()) { - when(channel.getRawChannel()).thenReturn(realChannel); - context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, channelBuffer); - assertFalse(context.hasQueuedWriteOps()); - - ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; - context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener)); - - when(sslDriver.readyForApplicationWrites()).thenReturn(true); - assertTrue(context.hasQueuedWriteOps()); - - when(channel.isOpen()).thenReturn(true); - context.closeFromSelector(); - - verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class)); - - assertFalse(context.hasQueuedWriteOps()); - } - } - @SuppressWarnings("unchecked") public void testSSLDriverClosedOnClose() throws IOException { try (SocketChannel realChannel = SocketChannel.open()) { when(channel.getRawChannel()).thenReturn(realChannel); - context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, channelBuffer); + TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer); + context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); when(channel.isOpen()).thenReturn(true); context.closeFromSelector(); @@ -200,66 +160,14 @@ public void testSSLDriverClosedOnClose() throws IOException { } } - public void testWriteFailsIfClosing() { - context.closeChannel(); - - ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; - context.sendMessage(buffers, listener); - - verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class)); - } - - public void testSendMessageFromDifferentThreadIsQueuedWithSelector() throws Exception { - ArgumentCaptor writeOpCaptor = ArgumentCaptor.forClass(BytesWriteOperation.class); - - when(selector.isOnCurrentThread()).thenReturn(false); - - ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; - context.sendMessage(buffers, listener); - - verify(selector).queueWrite(writeOpCaptor.capture()); - BytesWriteOperation writeOp = writeOpCaptor.getValue(); - - assertSame(listener, writeOp.getListener()); - assertSame(context, writeOp.getChannel()); - assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]); - } - - public void testSendMessageFromSameThreadIsQueuedInChannel() { - ArgumentCaptor writeOpCaptor = ArgumentCaptor.forClass(BytesWriteOperation.class); - - ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))}; - context.sendMessage(buffers, listener); - - verify(selector).queueWriteInChannelBuffer(writeOpCaptor.capture()); - BytesWriteOperation writeOp = writeOpCaptor.getValue(); - - assertSame(listener, writeOp.getListener()); - assertSame(context, writeOp.getChannel()); - assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]); - } - - public void testWriteIsQueuedInChannel() { - when(sslDriver.readyForApplicationWrites()).thenReturn(true); - when(sslDriver.hasFlushPending()).thenReturn(false); - when(sslDriver.needsNonApplicationWrite()).thenReturn(false); - assertFalse(context.hasQueuedWriteOps()); - - ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; - context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener)); - - assertTrue(context.hasQueuedWriteOps()); - } - public void testQueuedWritesAreIgnoredWhenNotReadyForAppWrites() { when(sslDriver.readyForApplicationWrites()).thenReturn(false); when(sslDriver.hasFlushPending()).thenReturn(false); when(sslDriver.needsNonApplicationWrite()).thenReturn(false); - ByteBuffer[] buffer = {ByteBuffer.allocate(10)}; - context.queueWriteOperation(new BytesWriteOperation(context, buffer, listener)); + context.queueWriteOperation(mock(FlushReadyWrite.class)); - assertFalse(context.hasQueuedWriteOps()); + assertFalse(context.readyForFlush()); } public void testPendingFlushMeansWriteInterested() { @@ -267,7 +175,7 @@ public void testPendingFlushMeansWriteInterested() { when(sslDriver.hasFlushPending()).thenReturn(true); when(sslDriver.needsNonApplicationWrite()).thenReturn(false); - assertTrue(context.hasQueuedWriteOps()); + assertTrue(context.readyForFlush()); } public void testNeedsNonAppWritesMeansWriteInterested() { @@ -275,14 +183,14 @@ public void testNeedsNonAppWritesMeansWriteInterested() { when(sslDriver.hasFlushPending()).thenReturn(false); when(sslDriver.needsNonApplicationWrite()).thenReturn(true); - assertTrue(context.hasQueuedWriteOps()); + assertTrue(context.readyForFlush()); } public void testNotWritesInterestInAppMode() { when(sslDriver.readyForApplicationWrites()).thenReturn(true); when(sslDriver.hasFlushPending()).thenReturn(false); - assertFalse(context.hasQueuedWriteOps()); + assertFalse(context.readyForFlush()); verify(sslDriver, times(0)).needsNonApplicationWrite(); } @@ -320,40 +228,40 @@ public void testNonAppWritesStopIfBufferNotFullyFlushed() throws Exception { public void testQueuedWriteIsFlushedInFlushCall() throws Exception { ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - BytesWriteOperation writeOperation = mock(BytesWriteOperation.class); - context.queueWriteOperation(writeOperation); + FlushReadyWrite flushOperation = mock(FlushReadyWrite.class); + context.queueWriteOperation(flushOperation); - when(writeOperation.getBuffersToWrite()).thenReturn(buffers); - when(writeOperation.getListener()).thenReturn(listener); + when(flushOperation.getBuffersToWrite()).thenReturn(buffers); + when(flushOperation.getListener()).thenReturn(listener); when(sslDriver.hasFlushPending()).thenReturn(false, false, false, false); when(sslDriver.readyForApplicationWrites()).thenReturn(true); when(sslDriver.applicationWrite(buffers)).thenReturn(10); - when(writeOperation.isFullyFlushed()).thenReturn(false,true); + when(flushOperation.isFullyFlushed()).thenReturn(false,true); context.flushChannel(); - verify(writeOperation).incrementIndex(10); + verify(flushOperation).incrementIndex(10); verify(rawChannel, times(1)).write(sslDriver.getNetworkWriteBuffer()); verify(selector).executeListener(listener, null); - assertFalse(context.hasQueuedWriteOps()); + assertFalse(context.readyForFlush()); } public void testPartialFlush() throws IOException { ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - BytesWriteOperation writeOperation = mock(BytesWriteOperation.class); - context.queueWriteOperation(writeOperation); + FlushReadyWrite flushOperation = mock(FlushReadyWrite.class); + context.queueWriteOperation(flushOperation); - when(writeOperation.getBuffersToWrite()).thenReturn(buffers); - when(writeOperation.getListener()).thenReturn(listener); + when(flushOperation.getBuffersToWrite()).thenReturn(buffers); + when(flushOperation.getListener()).thenReturn(listener); when(sslDriver.hasFlushPending()).thenReturn(false, false, true); when(sslDriver.readyForApplicationWrites()).thenReturn(true); when(sslDriver.applicationWrite(buffers)).thenReturn(5); - when(writeOperation.isFullyFlushed()).thenReturn(false, false); + when(flushOperation.isFullyFlushed()).thenReturn(false, false); context.flushChannel(); - verify(writeOperation).incrementIndex(5); + verify(flushOperation).incrementIndex(5); verify(rawChannel, times(1)).write(sslDriver.getNetworkWriteBuffer()); verify(selector, times(0)).executeListener(listener, null); - assertTrue(context.hasQueuedWriteOps()); + assertTrue(context.readyForFlush()); } @SuppressWarnings("unchecked") @@ -361,48 +269,48 @@ public void testMultipleWritesPartialFlushes() throws IOException { BiConsumer listener2 = mock(BiConsumer.class); ByteBuffer[] buffers1 = {ByteBuffer.allocate(10)}; ByteBuffer[] buffers2 = {ByteBuffer.allocate(5)}; - BytesWriteOperation writeOperation1 = mock(BytesWriteOperation.class); - BytesWriteOperation writeOperation2 = mock(BytesWriteOperation.class); - when(writeOperation1.getBuffersToWrite()).thenReturn(buffers1); - when(writeOperation2.getBuffersToWrite()).thenReturn(buffers2); - when(writeOperation1.getListener()).thenReturn(listener); - when(writeOperation2.getListener()).thenReturn(listener2); - context.queueWriteOperation(writeOperation1); - context.queueWriteOperation(writeOperation2); + FlushReadyWrite flushOperation1 = mock(FlushReadyWrite.class); + FlushReadyWrite flushOperation2 = mock(FlushReadyWrite.class); + when(flushOperation1.getBuffersToWrite()).thenReturn(buffers1); + when(flushOperation2.getBuffersToWrite()).thenReturn(buffers2); + when(flushOperation1.getListener()).thenReturn(listener); + when(flushOperation2.getListener()).thenReturn(listener2); + context.queueWriteOperation(flushOperation1); + context.queueWriteOperation(flushOperation2); when(sslDriver.hasFlushPending()).thenReturn(false, false, false, false, false, true); when(sslDriver.readyForApplicationWrites()).thenReturn(true); when(sslDriver.applicationWrite(buffers1)).thenReturn(5, 5); when(sslDriver.applicationWrite(buffers2)).thenReturn(3); - when(writeOperation1.isFullyFlushed()).thenReturn(false, false, true); - when(writeOperation2.isFullyFlushed()).thenReturn(false); + when(flushOperation1.isFullyFlushed()).thenReturn(false, false, true); + when(flushOperation2.isFullyFlushed()).thenReturn(false); context.flushChannel(); - verify(writeOperation1, times(2)).incrementIndex(5); + verify(flushOperation1, times(2)).incrementIndex(5); verify(rawChannel, times(3)).write(sslDriver.getNetworkWriteBuffer()); verify(selector).executeListener(listener, null); verify(selector, times(0)).executeListener(listener2, null); - assertTrue(context.hasQueuedWriteOps()); + assertTrue(context.readyForFlush()); } public void testWhenIOExceptionThrownListenerIsCalled() throws IOException { ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - BytesWriteOperation writeOperation = mock(BytesWriteOperation.class); - context.queueWriteOperation(writeOperation); + FlushReadyWrite flushOperation = mock(FlushReadyWrite.class); + context.queueWriteOperation(flushOperation); IOException exception = new IOException(); - when(writeOperation.getBuffersToWrite()).thenReturn(buffers); - when(writeOperation.getListener()).thenReturn(listener); + when(flushOperation.getBuffersToWrite()).thenReturn(buffers); + when(flushOperation.getListener()).thenReturn(listener); when(sslDriver.hasFlushPending()).thenReturn(false, false); when(sslDriver.readyForApplicationWrites()).thenReturn(true); when(sslDriver.applicationWrite(buffers)).thenReturn(5); when(rawChannel.write(sslDriver.getNetworkWriteBuffer())).thenThrow(exception); - when(writeOperation.isFullyFlushed()).thenReturn(false); + when(flushOperation.isFullyFlushed()).thenReturn(false); expectThrows(IOException.class, () -> context.flushChannel()); - verify(writeOperation).incrementIndex(5); + verify(flushOperation).incrementIndex(5); verify(selector).executeFailedListener(listener, exception); - assertFalse(context.hasQueuedWriteOps()); + assertFalse(context.readyForFlush()); } public void testWriteIOExceptionMeansChannelReadyToClose() throws Exception { @@ -426,7 +334,7 @@ public void testInitiateCloseFromDifferentThreadSchedulesCloseNotify() { when(selector.isOnCurrentThread()).thenReturn(false, true); context.closeChannel(); - ArgumentCaptor captor = ArgumentCaptor.forClass(WriteOperation.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(FlushReadyWrite.class); verify(selector).queueWrite(captor.capture()); context.queueWriteOperation(captor.getValue()); @@ -450,7 +358,8 @@ public void testRegisterInitiatesDriver() throws IOException { realSocket.configureBlocking(false); when(selector.rawSelector()).thenReturn(realSelector); when(channel.getRawChannel()).thenReturn(realSocket); - context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readConsumer, channelBuffer); + TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer); + context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); context.register(); verify(sslDriver).init(); } @@ -475,4 +384,18 @@ private static byte[] createMessage(int length) { } return bytes; } + + private static class TestReadWriteHandler extends BytesWriteHandler { + + private final CheckedFunction fn; + + private TestReadWriteHandler(CheckedFunction fn) { + this.fn = fn; + } + + @Override + public int consumeReads(InboundChannelBuffer channelBuffer) throws IOException { + return fn.apply(channelBuffer); + } + } }