diff --git a/core/src/main/java/io/undertow/UndertowOptions.java b/core/src/main/java/io/undertow/UndertowOptions.java index b06cf46a1c..edafbf036f 100644 --- a/core/src/main/java/io/undertow/UndertowOptions.java +++ b/core/src/main/java/io/undertow/UndertowOptions.java @@ -271,6 +271,8 @@ public class UndertowOptions { */ public static final Option HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS = Option.simple(UndertowOptions.class, "HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS", Integer.class); + public static final int DEFAULT_HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS = -1; + public static final Option HTTP2_SETTINGS_INITIAL_WINDOW_SIZE = Option.simple(UndertowOptions.class, "HTTP2_SETTINGS_INITIAL_WINDOW_SIZE", Integer.class); public static final Option HTTP2_SETTINGS_MAX_FRAME_SIZE = Option.simple(UndertowOptions.class, "HTTP2_SETTINGS_MAX_FRAME_SIZE", Integer.class); @@ -397,6 +399,29 @@ public class UndertowOptions { */ public static final Option TRACK_ACTIVE_REQUESTS = Option.simple(UndertowOptions.class, "TRACK_ACTIVE_REQUESTS", Boolean.class); + /** + * Default value of {@link #RST_FRAMES_TIME_WINDOW} option. + */ + public static final int DEFAULT_RST_FRAMES_TIME_WINDOW = 30000; + /** + * Default value of {@link #MAX_RST_FRAMES_PER_WINDOW} option. + */ + public static final int DEFAULT_MAX_RST_FRAMES_PER_WINDOW = 200; + + /** + * Window of time per which the number of HTTP2 RST received frames is measured, in milliseconds. + * If a number of RST frames bigger than {@link #MAX_RST_FRAMES_PER_WINDOW} is received during this time window, + * the server will send a GO_AWAY frame with error code 11 ({@code ENHANCE_YOUR_CALM}) and it will close the connection. + */ + public static final Option RST_FRAMES_TIME_WINDOW = Option.simple(UndertowOptions.class, "MAX_RST_STREAM_TIME_WINDOW", Integer.class); + + /** + * Maximum number of HTTP2 RST frames received allowed during a time window. + * If a number of RST frames bigger than this limit is received during {@link #RST_FRAMES_TIME_WINDOW} milliseconds, + * the server will send a GO_AWAY frame with error code 11 ({@code ENHANCE_YOUR_CALM}) and it will close the connection. + */ + public static final Option MAX_RST_FRAMES_PER_WINDOW = Option.simple(UndertowOptions.class, "MAX_RST_STREAMS_PER_TIME_WINDOW", Integer.class); + private UndertowOptions() { } diff --git a/core/src/main/java/io/undertow/protocols/http2/Http2Channel.java b/core/src/main/java/io/undertow/protocols/http2/Http2Channel.java index 6e999bddef..06658b9d02 100644 --- a/core/src/main/java/io/undertow/protocols/http2/Http2Channel.java +++ b/core/src/main/java/io/undertow/protocols/http2/Http2Channel.java @@ -43,6 +43,7 @@ import org.xnio.channels.StreamSinkChannel; import org.xnio.ssl.SslConnection; +import javax.net.ssl.SSLSession; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.Channel; @@ -52,13 +53,15 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Queue; import java.util.Random; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import javax.net.ssl.SSLSession; /** * HTTP2 channel. @@ -121,8 +124,6 @@ public class Http2Channel extends AbstractFramedChannel sendConcurrentStreamsAtomicUpdater = AtomicIntegerFieldUpdater.newUpdater( Http2Channel.class, "sendConcurrentStreams"); @@ -200,6 +214,8 @@ public class Http2Channel extends AbstractFramedChannel e : currentStreams.entrySet()) { + private void closeSubChannels(Map streams) { + for (Map.Entry e : streams.entrySet()) { StreamHolder holder = e.getValue(); AbstractHttp2StreamSourceChannel receiver = holder.sourceChannel; if(receiver != null) { @@ -765,7 +797,7 @@ public void handleWindowUpdate(int streamId, int deltaWindowSize) throws IOExcep StreamHolder holder = currentStreams.get(streamId); Http2StreamSinkChannel stream = holder != null ? holder.sinkChannel : null; if (stream == null) { - if(isIdle(streamId)) { + if (sentRstStreams.find(streamId) == null && isIdle(streamId)) { sendGoAway(ERROR_PROTOCOL_ERROR); } } else { @@ -1117,7 +1149,7 @@ public void sendRstStream(int streamId, int statusCode) { //no point sending if the channel is closed return; } - handleRstStream(streamId); + sentRstStreams.store(streamId, handleRstStream(streamId, false)); if(UndertowLogger.REQUEST_IO_LOGGER.isDebugEnabled()) { UndertowLogger.REQUEST_IO_LOGGER.debugf(new ClosedChannelException(), "Sending rststream on channel %s stream %s", this, streamId); } @@ -1125,8 +1157,8 @@ public void sendRstStream(int streamId, int statusCode) { flushChannelIgnoreFailure(channel); } - private void handleRstStream(int streamId) { - StreamHolder holder = currentStreams.remove(streamId); + private StreamHolder handleRstStream(int streamId, boolean receivedRst) { + final StreamHolder holder = currentStreams.remove(streamId); if(holder != null) { if(streamId % 2 == (isClient() ? 1 : 0)) { sendConcurrentStreamsAtomicUpdater.getAndDecrement(this); @@ -1139,7 +1171,23 @@ private void handleRstStream(int streamId) { if (holder.sourceChannel != null) { holder.sourceChannel.rstStream(); } + if (receivedRst) { + long currentTimeMillis = System.currentTimeMillis(); + // reset the window tracking + if (currentTimeMillis - lastRstFrameMillis >= rstFramesTimeWindow) { + lastRstFrameMillis = currentTimeMillis; + receivedRstFramesPerWindow = 1; + } else { + // + receivedRstFramesPerWindow ++; + if (receivedRstFramesPerWindow > maxRstFramesPerWindow) { + sendGoAway(Http2Channel.ERROR_ENHANCE_YOUR_CALM); + UndertowLogger.REQUEST_IO_LOGGER.debugf("Reached maximum number of rst frames %s during %s ms, sending GO_AWAY 11", maxRstFramesPerWindow, rstFramesTimeWindow); + } + } + } } + return holder; } /** @@ -1175,8 +1223,9 @@ public boolean isThisGoneAway() { Http2StreamSourceChannel removeStreamSource(int streamId) { StreamHolder existing = currentStreams.get(streamId); - if(existing == null){ - return null; + if (existing == null) { + existing = sentRstStreams.find(streamId); + return existing == null? null : existing.sourceChannel; } existing.sourceClosed = true; Http2StreamSourceChannel ret = existing.sourceChannel; @@ -1195,7 +1244,10 @@ Http2StreamSourceChannel removeStreamSource(int streamId) { Http2StreamSourceChannel getIncomingStream(int streamId) { StreamHolder existing = currentStreams.get(streamId); if(existing == null){ - return null; + existing = sentRstStreams.find(streamId); + if (existing == null) { + return null; + } } return existing.sourceChannel; } @@ -1248,4 +1300,58 @@ private static final class StreamHolder { this.sinkChannel = sinkChannel; } } + + // cache that keeps track of streams until they can be evicted @see Http2Channel#RST_STREAM_EVICATION_TIME + private static final class StreamCache { + private Map streamHolders = new ConcurrentHashMap<>(); + // entries are sorted per creation time + private Queue entries = new ConcurrentLinkedQueue<>(); + + private void store(int streamId, StreamHolder streamHolder) { + if (streamHolder == null) { + return; + } + streamHolders.put(streamId, streamHolder); + entries.add(new StreamCacheEntry(streamId)); + } + private StreamHolder find(int streamId) { + for (Iterator iterator = entries.iterator(); iterator.hasNext();) { + StreamCacheEntry entry = iterator.next(); + if (entry.shouldEvict()) { + iterator.remove(); + StreamHolder holder = streamHolders.remove(entry.streamId); + AbstractHttp2StreamSourceChannel receiver = holder.sourceChannel; + if(receiver != null) { + IoUtils.safeClose(receiver); + } + Http2StreamSinkChannel sink = holder.sinkChannel; + if(sink != null) { + if (sink.isWritesShutdown()) { + ChannelListeners.invokeChannelListener(sink.getIoThread(), sink, ((ChannelListener.SimpleSetter) sink.getWriteSetter()).get()); + } + IoUtils.safeClose(sink); + } + } else break; + } + return streamHolders.get(streamId); + } + + private Map getStreamHolders() { + return streamHolders; + } + } + + private static class StreamCacheEntry { + int streamId; + long time; + + StreamCacheEntry(int streamId) { + this.streamId = streamId; + this.time = System.currentTimeMillis(); + } + + public boolean shouldEvict() { + return System.currentTimeMillis() - time > STREAM_CACHE_EVICTION_TIME_MS; + } + } } diff --git a/core/src/test/java/io/undertow/client/http2/DoSHttp2ClientConnection.java b/core/src/test/java/io/undertow/client/http2/DoSHttp2ClientConnection.java new file mode 100644 index 0000000000..43a171062f --- /dev/null +++ b/core/src/test/java/io/undertow/client/http2/DoSHttp2ClientConnection.java @@ -0,0 +1,523 @@ +package io.undertow.client.http2; + +import io.undertow.UndertowLogger; +import io.undertow.UndertowMessages; +import io.undertow.client.ClientCallback; +import io.undertow.client.ClientConnection; +import io.undertow.client.ClientExchange; +import io.undertow.client.ClientRequest; +import io.undertow.client.ClientStatistics; +import io.undertow.connector.ByteBufferPool; +import io.undertow.protocols.http2.AbstractHttp2StreamSourceChannel; +import io.undertow.protocols.http2.Http2Channel; +import io.undertow.protocols.http2.Http2GoAwayStreamSourceChannel; +import io.undertow.protocols.http2.Http2HeadersStreamSinkChannel; +import io.undertow.protocols.http2.Http2PingStreamSourceChannel; +import io.undertow.protocols.http2.Http2PushPromiseStreamSourceChannel; +import io.undertow.protocols.http2.Http2RstStreamStreamSourceChannel; +import io.undertow.protocols.http2.Http2StreamSourceChannel; +import io.undertow.server.protocol.http.HttpAttachments; +import io.undertow.util.HeaderMap; +import io.undertow.util.HeaderValues; +import io.undertow.util.Headers; +import io.undertow.util.HttpString; +import io.undertow.util.Methods; +import io.undertow.util.Protocols; +import org.xnio.ChannelExceptionHandler; +import org.xnio.ChannelListener; +import org.xnio.ChannelListeners; +import org.xnio.IoUtils; +import org.xnio.Option; +import org.xnio.StreamConnection; +import org.xnio.XnioIoThread; +import org.xnio.XnioWorker; +import org.xnio.channels.Channels; +import org.xnio.channels.StreamSinkChannel; + +import java.io.IOException; +import java.net.SocketAddress; +import java.nio.channels.ClosedChannelException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; + +import static io.undertow.protocols.http2.Http2Channel.AUTHORITY; +import static io.undertow.protocols.http2.Http2Channel.METHOD; +import static io.undertow.protocols.http2.Http2Channel.PATH; +import static io.undertow.protocols.http2.Http2Channel.SCHEME; +import static io.undertow.protocols.http2.Http2Channel.STATUS; +import static io.undertow.util.Headers.CONTENT_LENGTH; +import static io.undertow.util.Headers.TRANSFER_ENCODING; + +/** + * ClientConnection implementation that mimics the DDoS rapid reset attack. See UNDERTOW-2323. + * + * @author Stuart Douglas + * @author Flavia Rainone + */ +public class DoSHttp2ClientConnection extends Http2ClientConnection implements ClientConnection { + + private final Http2Channel http2Channel; + private final ChannelListener.SimpleSetter closeSetter = new ChannelListener.SimpleSetter<>(); + + private final Map currentExchanges = new ConcurrentHashMap<>(); + + private static final AtomicLong PING_COUNTER = new AtomicLong(); + + private Http2GoAwayStreamSourceChannel goAwayStreamSourceChannel = null; + + + private boolean initialUpgradeRequest; + private final String defaultHost; + private final ClientStatistics clientStatistics; + private final List> closeListeners = new CopyOnWriteArrayList<>(); + private final boolean secure; + + private final Map outstandingPings = new HashMap<>(); + + public DoSHttp2ClientConnection(Http2Channel http2Channel, boolean initialUpgradeRequest, String defaultHost, ClientStatistics clientStatistics, boolean secure) { + super (http2Channel, initialUpgradeRequest, defaultHost, clientStatistics, secure); + this.http2Channel = http2Channel; + this.defaultHost = defaultHost; + this.clientStatistics = clientStatistics; + this.secure = secure; + http2Channel.getReceiveSetter().set(new Http2ReceiveListener()); + http2Channel.resumeReceives(); + ChannelListener closeTask = channel -> { + ChannelListeners.invokeChannelListener(DoSHttp2ClientConnection.this, closeSetter.get()); + for (ChannelListener listener : closeListeners) { + listener.handleEvent(DoSHttp2ClientConnection.this); + } + for (Map.Entry entry : currentExchanges.entrySet()) { + entry.getValue().failed(new ClosedChannelException()); + } + currentExchanges.clear(); + }; + http2Channel.addCloseTask(closeTask); + this.initialUpgradeRequest = initialUpgradeRequest; + } + + @Override + public void sendRequest(ClientRequest request, ClientCallback clientCallback) { + if(!http2Channel.isOpen()) { + clientCallback.failed(new ClosedChannelException()); + return; + } + request.getRequestHeaders().put(METHOD, request.getMethod().toString()); + boolean connectRequest = request.getMethod().equals(Methods.CONNECT); + if(!connectRequest) { + request.getRequestHeaders().put(PATH, request.getPath()); + request.getRequestHeaders().put(SCHEME, secure ? "https" : "http"); + } + final String host = request.getRequestHeaders().getFirst(Headers.HOST); + if(host != null) { + request.getRequestHeaders().put(AUTHORITY, host); + } else { + request.getRequestHeaders().put(AUTHORITY, defaultHost); + } + request.getRequestHeaders().remove(Headers.HOST); + + + boolean hasContent = true; + + String fixedLengthString = request.getRequestHeaders().getFirst(CONTENT_LENGTH); + String transferEncodingString = request.getRequestHeaders().getLast(TRANSFER_ENCODING); + if (fixedLengthString != null) { + try { + long length = Long.parseLong(fixedLengthString); + hasContent = length != 0; + } catch (NumberFormatException e) { + handleError(new IOException(e)); + return; + } + } else if (transferEncodingString == null && !connectRequest) { + hasContent = false; + } + + request.getRequestHeaders().remove(Headers.CONNECTION); + request.getRequestHeaders().remove(Headers.KEEP_ALIVE); + request.getRequestHeaders().remove(Headers.TRANSFER_ENCODING); + + Http2HeadersStreamSinkChannel sinkChannel; + try { + sinkChannel = http2Channel.createStream(request.getRequestHeaders()); + // + } catch (Throwable t) { + IOException e = t instanceof IOException ? (IOException) t : new IOException(t); + clientCallback.failed(e); + return; + } + Http2ClientExchange exchange = new Http2ClientExchange(this, sinkChannel, request); + currentExchanges.put(sinkChannel.getStreamId(), exchange); + + sinkChannel.setTrailersProducer(() -> { + HeaderMap attachment = exchange.getAttachment(HttpAttachments.RESPONSE_TRAILERS); + Supplier supplier = exchange.getAttachment(HttpAttachments.RESPONSE_TRAILER_SUPPLIER); + if(attachment != null && supplier == null) { + return attachment; + } else if(attachment == null && supplier != null) { + return supplier.get(); + } else if(attachment != null) { + HeaderMap supplied = supplier.get(); + for(HeaderValues k : supplied) { + attachment.putAll(k.getHeaderName(), k); + } + return attachment; + } else { + return null; + } + }); + try { + sinkChannel.flush(); + } catch (IOException e) { + throw new RuntimeException(e); + } + //UNDERTOW-2323 send a rst stream right away + http2Channel.sendRstStream(sinkChannel.getStreamId(), Http2Channel.ERROR_CANCEL); + if(clientCallback != null) { + clientCallback.completed(exchange); + } + if (!hasContent) { + //if there is no content we flush the response channel. + //otherwise it is up to the user + try { + sinkChannel.shutdownWrites(); + if (!sinkChannel.flush()) { + sinkChannel.getWriteSetter().set(ChannelListeners.flushingChannelListener(null, + (ChannelExceptionHandler) (channel, exception) -> handleError(exception))); + sinkChannel.resumeWrites(); + } + } catch (Throwable e) { + handleError(e); + } + } + } + + private void handleError(Throwable t) { + IOException e = t instanceof IOException ? (IOException) t : new IOException(t); + UndertowLogger.REQUEST_IO_LOGGER.ioException(e); + IoUtils.safeClose(DoSHttp2ClientConnection.this); + for (Map.Entry entry : currentExchanges.entrySet()) { + try { + entry.getValue().failed(e); + } catch (Exception ex) { + UndertowLogger.REQUEST_IO_LOGGER.ioException(new IOException(ex)); + } + } + } + + @Override + public StreamConnection performUpgrade() { + throw UndertowMessages.MESSAGES.upgradeNotSupported(); + } + + @Override + public ByteBufferPool getBufferPool() { + return http2Channel.getBufferPool(); + } + + @Override + public SocketAddress getPeerAddress() { + return http2Channel.getPeerAddress(); + } + + @Override + public A getPeerAddress(Class type) { + return http2Channel.getPeerAddress(type); + } + + @Override + public ChannelListener.Setter getCloseSetter() { + return closeSetter; + } + + @Override + public SocketAddress getLocalAddress() { + return http2Channel.getLocalAddress(); + } + + @Override + public A getLocalAddress(Class type) { + return http2Channel.getLocalAddress(type); + } + + @Override + public XnioWorker getWorker() { + return http2Channel.getWorker(); + } + + @Override + public XnioIoThread getIoThread() { + return http2Channel.getIoThread(); + } + + @Override + public boolean isOpen() { + return http2Channel.isOpen() && !http2Channel.isPeerGoneAway() && !http2Channel.isThisGoneAway(); + } + + @Override + public void close() throws IOException { + try { + http2Channel.sendGoAway(0); + } finally { + for(Map.Entry entry : currentExchanges.entrySet()) { + entry.getValue().failed(new ClosedChannelException()); + } + currentExchanges.clear(); + } + } + + @Override + public T getOption(Option option) { + return null; + } + + @Override + public T setOption(Option option, T value) throws IllegalArgumentException { + return null; + } + + @Override + public ClientStatistics getStatistics() { + return clientStatistics; + } + + @Override + public void addCloseListener(ChannelListener listener) { + closeListeners.add(listener); + } + + @Override + public void sendPing(PingListener listener, long timeout, TimeUnit timeUnit) { + long count = PING_COUNTER.incrementAndGet(); + byte[] data = new byte[8]; + data[0] = (byte) count; + data[1] = (byte)(count << 8); + data[2] = (byte)(count << 16); + data[3] = (byte)(count << 24); + data[4] = (byte)(count << 32); + data[5] = (byte)(count << 40); + data[6] = (byte)(count << 48); + data[7] = (byte)(count << 54); + final PingKey key = new PingKey(data); + outstandingPings.put(key, listener); + if(timeout > 0) { + http2Channel.getIoThread().executeAfter(() -> { + PingListener listener1 = outstandingPings.remove(key); + if(listener1 != null) { + listener1.failed(UndertowMessages.MESSAGES.pingTimeout()); + } + }, timeout, timeUnit); + } + http2Channel.sendPing(data, (channel, exception) -> listener.failed(exception)); + } + + private class Http2ReceiveListener implements ChannelListener { + + // listener that handles events for channels after receiving a continue response + private class ContinueReceiveListener implements ChannelListener { + private final Http2Channel http2Channel; + + ContinueReceiveListener(Http2Channel http2Channel) { + this.http2Channel = http2Channel; + } + + @Override + public void handleEvent(AbstractHttp2StreamSourceChannel sourceChannel) { + // listener is added only to instances of Http2StreamSourceChannel + assert sourceChannel instanceof Http2StreamSourceChannel; + try { + // channel is already created, no need to invoke receive + final Http2StreamSourceChannel channel = (Http2StreamSourceChannel) sourceChannel; + if (channel.getHeaders().getFirst(STATUS) == null) { + // instead, process pending frames, so we can see if we have a final status + Channels.drain(channel, Long.MAX_VALUE); + if (channel.getHeaders().getFirst(STATUS) == null) { + // no status yet, return and wait for next event + return; + } + } + // finally, a new status + int statusCode = Integer.parseInt(Objects.requireNonNull(channel.getHeaders().getFirst(STATUS))); + Http2ClientExchange request = currentExchanges.get(channel.getStreamId()); + if (statusCode < 200) { + //this is an informational response 1xx response + if (statusCode == 100) { + //we got a continue response again, just set the continue response and wait for next event + request.setContinueResponse(request.createResponse(channel)); + } + Channels.drain(channel, Long.MAX_VALUE); + return; + } + // we got the final response, handle it + handleFinalResponse(http2Channel, request, channel); + } catch (Throwable t) { + handleThrowable(t); + } + } + } + + @Override + public void handleEvent(Http2Channel channel) { + try { + AbstractHttp2StreamSourceChannel result = channel.receive(); + if (result instanceof Http2StreamSourceChannel) { + final Http2StreamSourceChannel streamSourceChannel = (Http2StreamSourceChannel) result; + + int statusCode = Integer.parseInt(Objects.requireNonNull(streamSourceChannel.getHeaders().getFirst(STATUS))); + Http2ClientExchange request = currentExchanges.get(streamSourceChannel.getStreamId()); + if(statusCode < 200) { + //this is an informational response 1xx response + if(statusCode == 100) { + //a continue response + request.setContinueResponse(request.createResponse(streamSourceChannel)); + // switch to continue receive listener, because next frame we will already have the Http2StreamSourceChannel + // previously created, we just need to read the new pending frames as they arrive + streamSourceChannel.getReadSetter().set(new Http2ReceiveListener.ContinueReceiveListener(http2Channel)); + streamSourceChannel.resumeReads(); + } + Channels.drain(result, Long.MAX_VALUE); + return; + } + handleFinalResponse(channel, request, streamSourceChannel); + } else if (result instanceof Http2PingStreamSourceChannel) { + handlePing((Http2PingStreamSourceChannel) result); + } else if (result instanceof Http2RstStreamStreamSourceChannel) { + Http2RstStreamStreamSourceChannel rstStream = (Http2RstStreamStreamSourceChannel) result; + int stream = rstStream.getStreamId(); + UndertowLogger.REQUEST_LOGGER.debugf("Client received RST_STREAM for stream %s", stream); + Http2ClientExchange exchange = currentExchanges.remove(stream); + + if(exchange != null) { + //if we have not yet received a response we treat this as an error + exchange.failed(UndertowMessages.MESSAGES.http2StreamWasReset()); + } + Channels.drain(result, Long.MAX_VALUE); + } else if (result instanceof Http2PushPromiseStreamSourceChannel) { + Http2PushPromiseStreamSourceChannel stream = (Http2PushPromiseStreamSourceChannel) result; + Http2ClientExchange request = currentExchanges.get(stream.getAssociatedStreamId()); + if(request == null) { + channel.sendGoAway(Http2Channel.ERROR_PROTOCOL_ERROR); //according to the spec this is a connection error + } else if(request.getPushCallback() == null) { + channel.sendRstStream(stream.getPushedStreamId(), Http2Channel.ERROR_REFUSED_STREAM); + } else { + ClientRequest cr = new ClientRequest(); + cr.setMethod(new HttpString(stream.getHeaders().getFirst(METHOD))); + cr.setPath(stream.getHeaders().getFirst(PATH)); + cr.setProtocol(Protocols.HTTP_1_1); + for (HeaderValues header : stream.getHeaders()) { + cr.getRequestHeaders().putAll(header.getHeaderName(), header); + } + + Http2ClientExchange newExchange = new Http2ClientExchange( + DoSHttp2ClientConnection.this, null, cr); + + if(!request.getPushCallback().handlePush(request, newExchange)) { + // if no push handler just reset the stream + channel.sendRstStream(stream.getPushedStreamId(), Http2Channel.ERROR_REFUSED_STREAM); + IoUtils.safeClose(stream); + } else if (!http2Channel.addPushPromiseStream(stream.getPushedStreamId())) { + // if invalid stream id send connection error of type PROTOCOL_ERROR as spec + channel.sendGoAway(Http2Channel.ERROR_PROTOCOL_ERROR); + } else { + // add the pushed stream to current exchanges + currentExchanges.put(stream.getPushedStreamId(), newExchange); + } + } + Channels.drain(result, Long.MAX_VALUE); + + } else if (result instanceof Http2GoAwayStreamSourceChannel) { + goAwayStreamSourceChannel = (Http2GoAwayStreamSourceChannel) result; + close(); + } else if(result != null) { + Channels.drain(result, Long.MAX_VALUE); + } + + } catch (Throwable t) { + handleThrowable(t); + } + } + + private void handleFinalResponse(Http2Channel channel, Http2ClientExchange request, Http2StreamSourceChannel response) throws IOException { + response.setTrailersHandler(headerMap -> request.putAttachment(io.undertow.server.protocol.http.HttpAttachments.REQUEST_TRAILERS, headerMap)); + response.addCloseTask(channel1 -> currentExchanges.remove(response.getStreamId())); + response.setCompletionListener(channel12 -> currentExchanges.remove(response.getStreamId())); + if (request == null && initialUpgradeRequest) { + Channels.drain(response, Long.MAX_VALUE); + initialUpgradeRequest = false; + return; + } else if (request == null) { + channel.sendGoAway(io.undertow.protocols.http2.Http2Channel.ERROR_PROTOCOL_ERROR); + IoUtils.safeClose(DoSHttp2ClientConnection.this); + return; + } + request.responseReady(response); + } + + private void handlePing(Http2PingStreamSourceChannel frame) { + byte[] id = frame.getData(); + if (!frame.isAck()) { + //server side ping, return it + frame.getHttp2Channel().sendPing(id); + } else { + PingListener listener = outstandingPings.remove(new PingKey(id)); + if(listener != null) { + listener.acknowledged(); + } + } + } + + private void handleThrowable(Throwable t) { + final IOException e = t instanceof IOException ? (IOException) t : new IOException(t); + UndertowLogger.REQUEST_IO_LOGGER.ioException(e); + IoUtils.safeClose(DoSHttp2ClientConnection.this); + for (Map.Entry entry : currentExchanges.entrySet()) { + try { + entry.getValue().failed(e); + } catch (Throwable ex) { + UndertowLogger.REQUEST_IO_LOGGER.ioException(new IOException(ex)); + } + } + } + } + + private static final class PingKey{ + private final byte[] data; + + private PingKey(byte[] data) { + this.data = data; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + PingKey pingKey = (PingKey) o; + + return Arrays.equals(data, pingKey.data); + } + + @Override + public int hashCode() { + return Arrays.hashCode(data); + } + } + + /** + * Returns the last Http2GoAwayStreamSourceChannel received, if any. + * @return the last go away frame received + */ + public Http2GoAwayStreamSourceChannel getGoAwayStreamSourceChannel() { + return goAwayStreamSourceChannel; + } +} + diff --git a/core/src/test/java/io/undertow/protocols/http2/RapidResetDDoSUnitTestCase.java b/core/src/test/java/io/undertow/protocols/http2/RapidResetDDoSUnitTestCase.java new file mode 100644 index 0000000000..7b8f7b6674 --- /dev/null +++ b/core/src/test/java/io/undertow/protocols/http2/RapidResetDDoSUnitTestCase.java @@ -0,0 +1,330 @@ +/* + * JBoss, Home of Professional Open Source. + * Copyright 2023 Red Hat, Inc., and individual contributors + * as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.undertow.protocols.http2; + +import io.undertow.Undertow; +import io.undertow.UndertowLogger; +import io.undertow.UndertowOptions; +import io.undertow.client.ALPNClientSelector; +import io.undertow.client.ClientCallback; +import io.undertow.client.ClientConnection; +import io.undertow.client.ClientExchange; +import io.undertow.client.ClientProvider; +import io.undertow.client.ClientRequest; +import io.undertow.client.ClientResponse; +import io.undertow.client.http.HttpClientProvider; +import io.undertow.client.http2.DoSHttp2ClientConnection; +import io.undertow.client.http2.Http2ClientConnection; +import io.undertow.connector.ByteBufferPool; +import io.undertow.io.Sender; +import io.undertow.protocols.ssl.UndertowXnioSsl; +import io.undertow.server.HttpServerExchange; +import io.undertow.server.handlers.PathHandler; +import io.undertow.testutils.DefaultServer; +import io.undertow.testutils.category.UnitTest; +import io.undertow.util.AttachmentKey; +import io.undertow.util.Headers; +import io.undertow.util.Methods; +import io.undertow.util.StatusCodes; +import io.undertow.util.StringReadChannelListener; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.xnio.ChannelListener; +import org.xnio.ChannelListeners; +import org.xnio.FutureResult; +import org.xnio.IoFuture; +import org.xnio.IoUtils; +import org.xnio.OptionMap; +import org.xnio.Options; +import org.xnio.StreamConnection; +import org.xnio.Xnio; +import org.xnio.XnioWorker; +import org.xnio.channels.StreamSinkChannel; +import org.xnio.ssl.SslConnection; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.nio.channels.ClosedChannelException; +import java.security.PrivilegedAction; +import java.util.List; +import java.util.ServiceLoader; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static io.undertow.server.protocol.http2.Http2OpenListener.HTTP2; +import static io.undertow.testutils.StopServerWithExternalWorkerUtils.stopWorker; +import static java.security.AccessController.doPrivileged; + +/** + * Test that mimics the rapid reset DDoS attack. See UNDERTOW-2323. + * + * @author Flavia Rainone + */ +@Category(UnitTest.class) +public class RapidResetDDoSUnitTestCase { + + private static final String message = "Hello World!"; + public static final String MESSAGE = "/message"; + public static final String POST = "/post"; + private static XnioWorker worker; + private static Undertow defaultConfigServer; + private static Undertow overwrittenConfigServer; + private static final OptionMap DEFAULT_OPTIONS; + private static URI defaultConfigServerAddress; + private static URI overwrittenConfigServerAddress; + + private static final AttachmentKey RESPONSE_BODY = AttachmentKey.create(String.class); + + private static volatile DoSHttp2ClientConnection clientConnection; + private IOException exception; + + static { + final OptionMap.Builder builder = OptionMap.builder() + .set(Options.WORKER_IO_THREADS, 8) + .set(Options.TCP_NODELAY, true) + .set(Options.KEEP_ALIVE, true) + .set(Options.WORKER_NAME, "Client"); + + DEFAULT_OPTIONS = builder.getMap(); + } + + static void sendMessage(final HttpServerExchange exchange) { + exchange.setStatusCode(StatusCodes.OK); + final Sender sender = exchange.getResponseSender(); + sender.send(message); + } + + @BeforeClass + public static void beforeClass() throws Exception { + + int port = DefaultServer.getHostPort("default"); + + final PathHandler path = new PathHandler() + .addExactPath(MESSAGE, RapidResetDDoSUnitTestCase::sendMessage) + .addExactPath(POST, exchange -> exchange.getRequestReceiver().receiveFullString( + (exchange1, message) -> exchange1.getResponseSender().send(message))); + + defaultConfigServer = Undertow.builder() + .addHttpsListener(port + 1, DefaultServer.getHostAddress("default"), DefaultServer.getServerSslContext()) + .setServerOption(UndertowOptions.ENABLE_HTTP2, true) + .setSocketOption(Options.REUSE_ADDRESSES, true) + .setHandler(path::handleRequest) + .build(); + defaultConfigServer.start(); + + overwrittenConfigServer = Undertow.builder() + .addHttpsListener(port + 2, DefaultServer.getHostAddress("default"), DefaultServer.getServerSslContext()) + .setServerOption(UndertowOptions.ENABLE_HTTP2, true) + .setServerOption(UndertowOptions.RST_FRAMES_TIME_WINDOW, 5000) + .setServerOption(UndertowOptions.MAX_RST_FRAMES_PER_WINDOW, 50) + .setSocketOption(Options.REUSE_ADDRESSES, true) + .setHandler(path::handleRequest) + .build(); + overwrittenConfigServer.start(); + + defaultConfigServerAddress = new URI("https://" + DefaultServer.getHostAddress() + ":" + (port + 1)); + overwrittenConfigServerAddress = new URI("https://" + DefaultServer.getHostAddress() + ":" + (port + 2)); + + // Create xnio worker + worker = Xnio.getInstance().createWorker(null, DEFAULT_OPTIONS); + } + + @AfterClass + public static void afterClass() { + if (defaultConfigServer != null) + defaultConfigServer.stop(); + if (overwrittenConfigServer != null) + overwrittenConfigServer.stop(); + if (worker != null) + stopWorker(worker); + } + + @Test + public void testGoAwayWithDefaultConfig() throws Exception { + System.out.println("go away with default config"); + assertDoSRstFramesHandled(300, 200, true, defaultConfigServerAddress); + } + + @Test + public void testNoErrorWithDefaultConfig() throws Exception { + System.out.println("no error with default config"); + assertDoSRstFramesHandled(150, 200, false, defaultConfigServerAddress); + } + + @Test + public void testGoAwayWithOverwrittenConfig() throws Exception { + System.out.println("go away with overwritten config"); + assertDoSRstFramesHandled(100, 50, true, overwrittenConfigServerAddress); + } + + @Test + public void testNoErrorWithOverwrittenConfig() throws Exception { + System.out.println("no error with overwritten config"); + assertDoSRstFramesHandled(50, 50, false, overwrittenConfigServerAddress); + } + + public void assertDoSRstFramesHandled(int totalNumberOfRequests, int rstStreamLimit, boolean errorExpected, URI serverAddress) throws Exception { + final List responses = new CopyOnWriteArrayList<>(); + final CountDownLatch latch = new CountDownLatch(totalNumberOfRequests); + + ServiceLoader providers = doPrivileged((PrivilegedAction>) + () -> ServiceLoader.load(ClientProvider.class, this.getClass().getClassLoader())); + ClientProvider clientProvider = null; + for (ClientProvider provider : providers) { + for (String scheme : provider.handlesSchemes()) { + if (scheme.equals(serverAddress.getScheme())) { + clientProvider = provider; + break; + } + } + } + Assert.assertNotNull(clientProvider); + final FutureResult result = new FutureResult<>(); + ClientCallback listener = new ClientCallback<>() { + @Override public void completed(ClientConnection r) { + result.setResult(r); + } + + @Override public void failed(IOException e) { + result.setException(e); + } + }; + UndertowXnioSsl ssl = new UndertowXnioSsl(worker.getXnio(), OptionMap.EMPTY, DefaultServer.getClientSSLContext()); + OptionMap tlsOptions = OptionMap.builder() + .set(UndertowOptions.ENDPOINT_IDENTIFICATION_ALGORITHM, HttpClientProvider.DISABLE_HTTPS_ENDPOINT_IDENTIFICATION? "" : "HTTPS") + .set(Options.SSL_STARTTLS, true) + .getMap(); + ChannelListener openListener = connection -> ALPNClientSelector.runAlpn((SslConnection) connection, + connection1 -> { + UndertowLogger.ROOT_LOGGER.alpnConnectionFailed(connection1); + IoUtils.safeClose(connection1); + }, listener, alpnProtocol(listener, serverAddress.getHost(), DefaultServer.getBufferPool(), tlsOptions)); + + ssl.openSslConnection(worker, new InetSocketAddress(serverAddress.getHost(), serverAddress.getPort()), openListener, tlsOptions).addNotifier( + (IoFuture.Notifier) (ioFuture, o) -> { + if (ioFuture.getStatus() == IoFuture.Status.FAILED) { + listener.failed(ioFuture.getException()); + } + }, null); + + + final ClientConnection connection = result.getIoFuture().get(); + try { + connection.getIoThread().execute(() -> { + for (int i = 0; i < totalNumberOfRequests; i++) { + final ClientRequest request = new ClientRequest().setMethod(Methods.GET).setPath(MESSAGE); + request.getRequestHeaders().put(Headers.HOST, DefaultServer.getHostAddress()); + connection.sendRequest(request, createClientCallback(responses, latch)); + } + }); + + latch.await(200, TimeUnit.SECONDS); + + // server sent go away before processing and responding client frames, sometimes this happens, depends on the order of threads + // being executed + if (responses.isEmpty()) { + Assert.assertTrue(errorExpected); + Assert.assertNotNull(exception); + Assert.assertTrue(exception instanceof ClosedChannelException); + return; + } + Assert.assertEquals(errorExpected ? rstStreamLimit + 1 : totalNumberOfRequests, responses.size()); + for (final ClientResponse response : responses) { + final String responseBody = response.getAttachment(RESPONSE_BODY); + Assert.assertTrue("Unexpected response body: " + responseBody, responseBody.isEmpty() || responseBody.equals(message)); + } + if (errorExpected) { + Assert.assertNotNull(exception); + Assert.assertTrue(exception instanceof ClosedChannelException); + Http2GoAwayStreamSourceChannel http2GoAwayStreamSourceChannel = clientConnection.getGoAwayStreamSourceChannel(); + Assert.assertNotNull(http2GoAwayStreamSourceChannel); + Assert.assertEquals(11, http2GoAwayStreamSourceChannel.getStatus()); + } else { + Assert.assertNull(exception); + Assert.assertNull(clientConnection.getGoAwayStreamSourceChannel()); + } + } finally { + IoUtils.safeClose(connection); + } + } + + public static ALPNClientSelector.ALPNProtocol alpnProtocol(final ClientCallback listener, String defaultHost, ByteBufferPool bufferPool, OptionMap options) { + return new ALPNClientSelector.ALPNProtocol( + connection -> listener.completed(createHttp2Channel(connection, bufferPool, options, defaultHost)), HTTP2); + } + + private static Http2ClientConnection createHttp2Channel(StreamConnection connection, ByteBufferPool bufferPool, OptionMap options, String defaultHost) { + //first we set up statistics, if required + Http2Channel http2Channel = new Http2Channel(connection, null, bufferPool, null, true, false, options); + return clientConnection = new DoSHttp2ClientConnection(http2Channel, false, defaultHost, null, true); + } + + private ClientCallback createClientCallback(final List responses, final CountDownLatch latch) { + return new ClientCallback<>() { + @Override public void completed(ClientExchange result) { + result.setResponseListener(new ClientCallback<>() { + @Override public void completed(final ClientExchange result) { + responses.add(result.getResponse()); + new StringReadChannelListener(result.getConnection().getBufferPool()) { + + @Override protected void stringDone(String string) { + result.getResponse().putAttachment(RESPONSE_BODY, string); + latch.countDown(); + } + + @Override protected void error(IOException e) { + e.printStackTrace(); + exception = e; + latch.countDown(); + } + }.setup(result.getResponseChannel()); + } + + @Override public void failed(IOException e) { + e.printStackTrace(); + exception = e; + latch.countDown(); + } + }); + try { + result.getRequestChannel().shutdownWrites(); + if (!result.getRequestChannel().flush()) { + result.getRequestChannel().getWriteSetter().set(ChannelListeners.flushingChannelListener(null, null)); + result.getRequestChannel().resumeWrites(); + } + } catch (IOException e) { + e.printStackTrace(); + exception = e; + latch.countDown(); + } + } + + @Override public void failed(IOException e) { + e.printStackTrace(); + exception = e; + latch.countDown(); + } + }; + } +} \ No newline at end of file