diff --git a/server/src/main/java/org/opensearch/transport/Header.java b/server/src/main/java/org/opensearch/transport/Header.java index ac30df8dda02c..fcfeb9c632075 100644 --- a/server/src/main/java/org/opensearch/transport/Header.java +++ b/server/src/main/java/org/opensearch/transport/Header.java @@ -55,6 +55,7 @@ public class Header { private static final String RESPONSE_NAME = "NO_ACTION_NAME_FOR_RESPONSES"; + private final TransportProtocol protocol; private final int networkMessageSize; private final Version version; private final long requestId; @@ -64,13 +65,18 @@ public class Header { Tuple, Map>> headers; Set features; - Header(int networkMessageSize, long requestId, byte status, Version version) { + Header(TransportProtocol protocol, int networkMessageSize, long requestId, byte status, Version version) { + this.protocol = protocol; this.networkMessageSize = networkMessageSize; this.version = version; this.requestId = requestId; this.status = status; } + TransportProtocol getTransportProtocol() { + return protocol; + } + public int getNetworkMessageSize() { return networkMessageSize; } @@ -142,6 +148,8 @@ void finishParsingHeader(StreamInput input) throws IOException { @Override public String toString() { return "Header{" + + protocol + + "}{" + networkMessageSize + "}{" + version diff --git a/server/src/main/java/org/opensearch/transport/InboundAggregator.java b/server/src/main/java/org/opensearch/transport/InboundAggregator.java index f52875d880b4f..e894331f3b64e 100644 --- a/server/src/main/java/org/opensearch/transport/InboundAggregator.java +++ b/server/src/main/java/org/opensearch/transport/InboundAggregator.java @@ -40,7 +40,6 @@ import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.bytes.CompositeBytesReference; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import java.io.IOException; import java.util.ArrayList; @@ -114,7 +113,7 @@ public void aggregate(ReleasableBytesReference content) { } } - public NativeInboundMessage finishAggregation() throws IOException { + public InboundMessage finishAggregation() throws IOException { ensureOpen(); final ReleasableBytesReference releasableContent; if (isFirstContent()) { @@ -128,7 +127,7 @@ public NativeInboundMessage finishAggregation() throws IOException { } final BreakerControl breakerControl = new BreakerControl(circuitBreaker); - final NativeInboundMessage aggregated = new NativeInboundMessage(currentHeader, releasableContent, breakerControl); + final InboundMessage aggregated = new InboundMessage(currentHeader, releasableContent, breakerControl); boolean success = false; try { if (aggregated.getHeader().needsToReadVariableHeader()) { @@ -143,7 +142,7 @@ public NativeInboundMessage finishAggregation() throws IOException { if (isShortCircuited()) { aggregated.close(); success = true; - return new NativeInboundMessage(aggregated.getHeader(), aggregationException); + return new InboundMessage(aggregated.getHeader(), aggregationException); } else { success = true; return aggregated; diff --git a/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java b/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java index 276891212e43f..ad839bc990018 100644 --- a/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java +++ b/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java @@ -9,24 +9,139 @@ package org.opensearch.transport; import org.opensearch.common.bytes.ReleasableBytesReference; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.lease.Releasables; +import org.opensearch.core.common.bytes.CompositeBytesReference; -import java.io.Closeable; import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.function.BiConsumer; /** - * Interface for handling inbound bytes. Can be implemented by different transport protocols. + * Handler for inbound bytes, using {@link InboundDecoder} to decode headers + * and {@link InboundAggregator} to assemble complete messages to forward to + * the given message handler to parse the message payload. */ -public interface InboundBytesHandler extends Closeable { +class InboundBytesHandler { - public void doHandleBytes( - TcpChannel channel, - ReleasableBytesReference reference, - BiConsumer messageHandler - ) throws IOException; + private static final ThreadLocal> fragmentList = ThreadLocal.withInitial(ArrayList::new); - public boolean canHandleBytes(ReleasableBytesReference reference); + private final ArrayDeque pending; + private final InboundDecoder decoder; + private final InboundAggregator aggregator; + private final StatsTracker statsTracker; + private boolean isClosed = false; + + InboundBytesHandler( + ArrayDeque pending, + InboundDecoder decoder, + InboundAggregator aggregator, + StatsTracker statsTracker + ) { + this.pending = pending; + this.decoder = decoder; + this.aggregator = aggregator; + this.statsTracker = statsTracker; + } + + public void close() { + isClosed = true; + } + + public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference, BiConsumer messageHandler) + throws IOException { + final ArrayList fragments = fragmentList.get(); + boolean continueHandling = true; + + while (continueHandling && isClosed == false) { + boolean continueDecoding = true; + while (continueDecoding && pending.isEmpty() == false) { + try (ReleasableBytesReference toDecode = getPendingBytes()) { + final int bytesDecoded = decoder.decode(toDecode, fragments::add); + if (bytesDecoded != 0) { + releasePendingBytes(bytesDecoded); + if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) { + continueDecoding = false; + } + } else { + continueDecoding = false; + } + } + } + + if (fragments.isEmpty()) { + continueHandling = false; + } else { + try { + forwardFragments(channel, fragments, messageHandler); + } finally { + for (Object fragment : fragments) { + if (fragment instanceof ReleasableBytesReference) { + ((ReleasableBytesReference) fragment).close(); + } + } + fragments.clear(); + } + } + } + } + + private ReleasableBytesReference getPendingBytes() { + if (pending.size() == 1) { + return pending.peekFirst().retain(); + } else { + final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()]; + int index = 0; + for (ReleasableBytesReference pendingReference : pending) { + bytesReferences[index] = pendingReference.retain(); + ++index; + } + final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences); + return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable); + } + } + + private void releasePendingBytes(int bytesConsumed) { + int bytesToRelease = bytesConsumed; + while (bytesToRelease != 0) { + try (ReleasableBytesReference reference = pending.pollFirst()) { + assert reference != null; + if (bytesToRelease < reference.length()) { + pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease)); + bytesToRelease -= bytesToRelease; + } else { + bytesToRelease -= reference.length(); + } + } + } + } + + private boolean endOfMessage(Object fragment) { + return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception; + } + + private void forwardFragments(TcpChannel channel, ArrayList fragments, BiConsumer messageHandler) + throws IOException { + for (Object fragment : fragments) { + if (fragment instanceof Header) { + assert aggregator.isAggregating() == false; + aggregator.headerReceived((Header) fragment); + } else if (fragment == InboundDecoder.PING) { + assert aggregator.isAggregating() == false; + messageHandler.accept(channel, InboundMessage.PING); + } else if (fragment == InboundDecoder.END_CONTENT) { + assert aggregator.isAggregating(); + try (InboundMessage aggregated = aggregator.finishAggregation()) { + statsTracker.markMessageReceived(); + messageHandler.accept(channel, aggregated); + } + } else { + assert aggregator.isAggregating(); + assert fragment instanceof ReleasableBytesReference; + aggregator.aggregate((ReleasableBytesReference) fragment); + } + } + } - @Override - void close(); } diff --git a/server/src/main/java/org/opensearch/transport/InboundDecoder.java b/server/src/main/java/org/opensearch/transport/InboundDecoder.java index d6b7a98e876b3..3e735d4be2420 100644 --- a/server/src/main/java/org/opensearch/transport/InboundDecoder.java +++ b/server/src/main/java/org/opensearch/transport/InboundDecoder.java @@ -187,11 +187,12 @@ private int headerBytesToRead(BytesReference reference) { // exposed for use in tests static Header readHeader(Version version, int networkMessageSize, BytesReference bytesReference) throws IOException { try (StreamInput streamInput = bytesReference.streamInput()) { - streamInput.skip(TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE); + TransportProtocol protocol = TransportProtocol.fromBytes(streamInput.readByte(), streamInput.readByte()); + streamInput.skip(TcpHeader.MESSAGE_LENGTH_SIZE); long requestId = streamInput.readLong(); byte status = streamInput.readByte(); Version remoteVersion = Version.fromId(streamInput.readInt()); - Header header = new Header(networkMessageSize, requestId, status, remoteVersion); + Header header = new Header(protocol, networkMessageSize, requestId, status, remoteVersion); final IllegalStateException invalidVersion = ensureVersionCompatibility(remoteVersion, version, header.isHandshake()); if (invalidVersion != null) { throw invalidVersion; diff --git a/server/src/main/java/org/opensearch/transport/InboundHandler.java b/server/src/main/java/org/opensearch/transport/InboundHandler.java index f77c44ea362cf..76a44832b08dc 100644 --- a/server/src/main/java/org/opensearch/transport/InboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/InboundHandler.java @@ -38,7 +38,6 @@ import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import java.io.IOException; import java.util.Map; @@ -56,7 +55,7 @@ public class InboundHandler { private volatile long slowLogThresholdMs = Long.MAX_VALUE; - private final Map protocolMessageHandlers; + private final Map protocolMessageHandlers; InboundHandler( String nodeName, @@ -75,7 +74,7 @@ public class InboundHandler { ) { this.threadPool = threadPool; this.protocolMessageHandlers = Map.of( - NativeInboundMessage.NATIVE_PROTOCOL, + TransportProtocol.NATIVE, new NativeMessageHandler( nodeName, version, @@ -107,16 +106,16 @@ void setSlowLogThreshold(TimeValue slowLogThreshold) { this.slowLogThresholdMs = slowLogThreshold.getMillis(); } - void inboundMessage(TcpChannel channel, ProtocolInboundMessage message) throws Exception { + void inboundMessage(TcpChannel channel, InboundMessage message) throws Exception { final long startTime = threadPool.relativeTimeInMillis(); channel.getChannelStats().markAccessed(startTime); messageReceivedFromPipeline(channel, message, startTime); } - private void messageReceivedFromPipeline(TcpChannel channel, ProtocolInboundMessage message, long startTime) throws IOException { - ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getProtocol()); + private void messageReceivedFromPipeline(TcpChannel channel, InboundMessage message, long startTime) throws IOException { + ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getTransportProtocol()); if (protocolMessageHandler == null) { - throw new IllegalStateException("No protocol message handler found for protocol: " + message.getProtocol()); + throw new IllegalStateException("No protocol message handler found for protocol: " + message.getTransportProtocol()); } protocolMessageHandler.messageReceived(channel, message, startTime, slowLogThresholdMs, messageListener); } diff --git a/server/src/main/java/org/opensearch/transport/InboundMessage.java b/server/src/main/java/org/opensearch/transport/InboundMessage.java new file mode 100644 index 0000000000000..576ab73ce9c98 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/InboundMessage.java @@ -0,0 +1,149 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.transport; + +import org.opensearch.common.annotation.PublicApi; +import org.opensearch.common.bytes.ReleasableBytesReference; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.lease.Releasables; +import org.opensearch.common.util.io.IOUtils; +import org.opensearch.core.common.io.stream.StreamInput; + +import java.io.IOException; + +/** + * Inbound data as a message + */ +@PublicApi(since = "1.0.0") +public class InboundMessage implements Releasable, ProtocolInboundMessage { + + static final InboundMessage PING = new InboundMessage(null, null, null, true, null); + + protected final Header header; + protected final ReleasableBytesReference content; + protected final Exception exception; + protected final boolean isPing; + private Releasable breakerRelease; + private StreamInput streamInput; + + public InboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) { + this(header, content, null, false, breakerRelease); + } + + public InboundMessage(Header header, Exception exception) { + this(header, null, exception, false, null); + } + + public InboundMessage(Header header, boolean isPing) { + this(header, null, null, isPing, null); + } + + private InboundMessage( + Header header, + ReleasableBytesReference content, + Exception exception, + boolean isPing, + Releasable breakerRelease + ) { + this.header = header; + this.content = content; + this.exception = exception; + this.isPing = isPing; + this.breakerRelease = breakerRelease; + } + + TransportProtocol getTransportProtocol() { + if (isPing) { + return TransportProtocol.NATIVE; + } + return header.getTransportProtocol(); + } + + public String getProtocol() { + return header.getTransportProtocol().toString(); + } + + public Header getHeader() { + return header; + } + + public int getContentLength() { + if (content == null) { + return 0; + } else { + return content.length(); + } + } + + public Exception getException() { + return exception; + } + + public boolean isPing() { + return isPing; + } + + public boolean isShortCircuit() { + return exception != null; + } + + public Releasable takeBreakerReleaseControl() { + final Releasable toReturn = breakerRelease; + breakerRelease = null; + if (toReturn != null) { + return toReturn; + } else { + return () -> {}; + } + } + + public StreamInput openOrGetStreamInput() throws IOException { + assert isPing == false && content != null; + if (streamInput == null) { + streamInput = content.streamInput(); + streamInput.setVersion(header.getVersion()); + } + return streamInput; + } + + @Override + public void close() { + IOUtils.closeWhileHandlingException(streamInput); + Releasables.closeWhileHandlingException(content, breakerRelease); + } + + @Override + public String toString() { + return "InboundMessage{" + header + "}"; + } +} diff --git a/server/src/main/java/org/opensearch/transport/InboundPipeline.java b/server/src/main/java/org/opensearch/transport/InboundPipeline.java index 5cee3bb975223..3acb43f58b443 100644 --- a/server/src/main/java/org/opensearch/transport/InboundPipeline.java +++ b/server/src/main/java/org/opensearch/transport/InboundPipeline.java @@ -38,11 +38,9 @@ import org.opensearch.common.lease.Releasables; import org.opensearch.common.util.PageCacheRecycler; import org.opensearch.core.common.breaker.CircuitBreaker; -import org.opensearch.transport.nativeprotocol.NativeInboundBytesHandler; import java.io.IOException; import java.util.ArrayDeque; -import java.util.List; import java.util.function.BiConsumer; import java.util.function.Function; import java.util.function.LongSupplier; @@ -62,9 +60,8 @@ public class InboundPipeline implements Releasable { private Exception uncaughtException; private final ArrayDeque pending = new ArrayDeque<>(2); private boolean isClosed = false; - private final BiConsumer messageHandler; - private final List protocolBytesHandlers; - private InboundBytesHandler currentHandler; + private final BiConsumer messageHandler; + private final InboundBytesHandler bytesHandler; public InboundPipeline( Version version, @@ -73,7 +70,7 @@ public InboundPipeline( LongSupplier relativeTimeInMillis, Supplier circuitBreaker, Function> registryFunction, - BiConsumer messageHandler + BiConsumer messageHandler ) { this( statsTracker, @@ -89,23 +86,20 @@ public InboundPipeline( LongSupplier relativeTimeInMillis, InboundDecoder decoder, InboundAggregator aggregator, - BiConsumer messageHandler + BiConsumer messageHandler ) { this.relativeTimeInMillis = relativeTimeInMillis; this.statsTracker = statsTracker; this.decoder = decoder; this.aggregator = aggregator; - this.protocolBytesHandlers = List.of(new NativeInboundBytesHandler(pending, decoder, aggregator, statsTracker)); + this.bytesHandler = new InboundBytesHandler(pending, decoder, aggregator, statsTracker); this.messageHandler = messageHandler; } @Override public void close() { isClosed = true; - if (currentHandler != null) { - currentHandler.close(); - currentHandler = null; - } + bytesHandler.close(); Releasables.closeWhileHandlingException(decoder, aggregator); Releasables.closeWhileHandlingException(pending); pending.clear(); @@ -127,22 +121,6 @@ public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong()); statsTracker.markBytesRead(reference.length()); pending.add(reference.retain()); - - // If we don't have a current handler, we should try to find one based on the protocol of the incoming bytes. - if (currentHandler == null) { - for (InboundBytesHandler handler : protocolBytesHandlers) { - if (handler.canHandleBytes(reference)) { - currentHandler = handler; - break; - } - } - } - - // If we have a current handler determined based on protocol, we should continue to use it for the fragmented bytes. - if (currentHandler != null) { - currentHandler.doHandleBytes(channel, reference, messageHandler); - } else { - throw new IllegalStateException("No bytes handler found for the incoming transport protocol"); - } + bytesHandler.doHandleBytes(channel, reference, messageHandler); } } diff --git a/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java b/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java index 4c972fdc14fa5..58adc2d3d68a5 100644 --- a/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java +++ b/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java @@ -52,7 +52,6 @@ import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.telemetry.tracing.channels.TraceableTcpTransportChannel; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import org.opensearch.transport.nativeprotocol.NativeOutboundHandler; import java.io.EOFException; @@ -119,7 +118,7 @@ public void messageReceived( long slowLogThresholdMs, TransportMessageListener messageListener ) throws IOException { - NativeInboundMessage inboundMessage = (NativeInboundMessage) message; + InboundMessage inboundMessage = (InboundMessage) message; TransportLogger.logInboundMessage(channel, inboundMessage); if (inboundMessage.isPing()) { keepAlive.receiveKeepAlive(channel); @@ -130,7 +129,7 @@ public void messageReceived( private void handleMessage( TcpChannel channel, - NativeInboundMessage message, + InboundMessage message, long startTime, long slowLogThresholdMs, TransportMessageListener messageListener @@ -202,7 +201,7 @@ private Map> extractHeaders(Map heade private void handleRequest( TcpChannel channel, Header header, - NativeInboundMessage message, + InboundMessage message, TransportMessageListener messageListener ) throws IOException { final String action = header.getActionName(); diff --git a/server/src/main/java/org/opensearch/transport/TcpTransport.java b/server/src/main/java/org/opensearch/transport/TcpTransport.java index ffa3168da0b3e..f56cd146ce953 100644 --- a/server/src/main/java/org/opensearch/transport/TcpTransport.java +++ b/server/src/main/java/org/opensearch/transport/TcpTransport.java @@ -777,13 +777,21 @@ protected void serverAcceptedChannel(TcpChannel channel) { */ protected abstract void stopInternal(); + /** + * @deprecated Use {{@link #inboundMessage(TcpChannel, InboundMessage)}} instead + */ + @Deprecated + public void inboundMessage(TcpChannel channel, ProtocolInboundMessage message) { + inboundMessage(channel, (InboundMessage) message); + } + /** * Handles inbound message that has been decoded. * * @param channel the channel the message is from * @param message the message */ - public void inboundMessage(TcpChannel channel, ProtocolInboundMessage message) { + public void inboundMessage(TcpChannel channel, InboundMessage message) { try { inboundHandler.inboundMessage(channel, message); } catch (Exception e) { diff --git a/server/src/main/java/org/opensearch/transport/TransportLogger.java b/server/src/main/java/org/opensearch/transport/TransportLogger.java index e780f643aafd7..997b3bb5ba18e 100644 --- a/server/src/main/java/org/opensearch/transport/TransportLogger.java +++ b/server/src/main/java/org/opensearch/transport/TransportLogger.java @@ -40,7 +40,6 @@ import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.compress.CompressorRegistry; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import java.io.IOException; @@ -65,7 +64,7 @@ static void logInboundMessage(TcpChannel channel, BytesReference message) { } } - static void logInboundMessage(TcpChannel channel, NativeInboundMessage message) { + static void logInboundMessage(TcpChannel channel, InboundMessage message) { if (logger.isTraceEnabled()) { try { String logMessage = format(channel, message, "READ"); @@ -137,7 +136,7 @@ private static String format(TcpChannel channel, BytesReference message, String return sb.toString(); } - private static String format(TcpChannel channel, NativeInboundMessage message, String event) throws IOException { + private static String format(TcpChannel channel, InboundMessage message, String event) throws IOException { final StringBuilder sb = new StringBuilder(); sb.append(channel); diff --git a/server/src/main/java/org/opensearch/transport/TransportProtocol.java b/server/src/main/java/org/opensearch/transport/TransportProtocol.java new file mode 100644 index 0000000000000..4a11520d38d56 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/TransportProtocol.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport; + +/** + * Enumeration of transport protocols. + */ +enum TransportProtocol { + /** + * The original, hand-rolled binary protocol used for node-to-node + * communication. Message schemas are defined implicitly in code using the + * StreamInput and StreamOutput classes to parse and generate binary data. + */ + NATIVE; + + public static TransportProtocol fromBytes(byte b1, byte b2) { + if (b1 == 'E' && b2 == 'S') { + return NATIVE; + } + + throw new IllegalArgumentException("Unknown transport protocol: [" + b1 + ", " + b2 + "]"); + } +} diff --git a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java deleted file mode 100644 index 97981aeb6736e..0000000000000 --- a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java +++ /dev/null @@ -1,166 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.transport.nativeprotocol; - -import org.opensearch.common.bytes.ReleasableBytesReference; -import org.opensearch.common.lease.Releasable; -import org.opensearch.common.lease.Releasables; -import org.opensearch.core.common.bytes.CompositeBytesReference; -import org.opensearch.transport.Header; -import org.opensearch.transport.InboundAggregator; -import org.opensearch.transport.InboundBytesHandler; -import org.opensearch.transport.InboundDecoder; -import org.opensearch.transport.ProtocolInboundMessage; -import org.opensearch.transport.StatsTracker; -import org.opensearch.transport.TcpChannel; - -import java.io.IOException; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.function.BiConsumer; - -/** - * Handler for inbound bytes for the native protocol. - */ -public class NativeInboundBytesHandler implements InboundBytesHandler { - - private static final ThreadLocal> fragmentList = ThreadLocal.withInitial(ArrayList::new); - private static final NativeInboundMessage PING_MESSAGE = new NativeInboundMessage(null, true); - - private final ArrayDeque pending; - private final InboundDecoder decoder; - private final InboundAggregator aggregator; - private final StatsTracker statsTracker; - private boolean isClosed = false; - - public NativeInboundBytesHandler( - ArrayDeque pending, - InboundDecoder decoder, - InboundAggregator aggregator, - StatsTracker statsTracker - ) { - this.pending = pending; - this.decoder = decoder; - this.aggregator = aggregator; - this.statsTracker = statsTracker; - } - - @Override - public void close() { - isClosed = true; - } - - @Override - public boolean canHandleBytes(ReleasableBytesReference reference) { - return true; - } - - @Override - public void doHandleBytes( - TcpChannel channel, - ReleasableBytesReference reference, - BiConsumer messageHandler - ) throws IOException { - final ArrayList fragments = fragmentList.get(); - boolean continueHandling = true; - - while (continueHandling && isClosed == false) { - boolean continueDecoding = true; - while (continueDecoding && pending.isEmpty() == false) { - try (ReleasableBytesReference toDecode = getPendingBytes()) { - final int bytesDecoded = decoder.decode(toDecode, fragments::add); - if (bytesDecoded != 0) { - releasePendingBytes(bytesDecoded); - if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) { - continueDecoding = false; - } - } else { - continueDecoding = false; - } - } - } - - if (fragments.isEmpty()) { - continueHandling = false; - } else { - try { - forwardFragments(channel, fragments, messageHandler); - } finally { - for (Object fragment : fragments) { - if (fragment instanceof ReleasableBytesReference) { - ((ReleasableBytesReference) fragment).close(); - } - } - fragments.clear(); - } - } - } - } - - private ReleasableBytesReference getPendingBytes() { - if (pending.size() == 1) { - return pending.peekFirst().retain(); - } else { - final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()]; - int index = 0; - for (ReleasableBytesReference pendingReference : pending) { - bytesReferences[index] = pendingReference.retain(); - ++index; - } - final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences); - return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable); - } - } - - private void releasePendingBytes(int bytesConsumed) { - int bytesToRelease = bytesConsumed; - while (bytesToRelease != 0) { - try (ReleasableBytesReference reference = pending.pollFirst()) { - assert reference != null; - if (bytesToRelease < reference.length()) { - pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease)); - bytesToRelease -= bytesToRelease; - } else { - bytesToRelease -= reference.length(); - } - } - } - } - - private boolean endOfMessage(Object fragment) { - return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception; - } - - private void forwardFragments( - TcpChannel channel, - ArrayList fragments, - BiConsumer messageHandler - ) throws IOException { - for (Object fragment : fragments) { - if (fragment instanceof Header) { - assert aggregator.isAggregating() == false; - aggregator.headerReceived((Header) fragment); - } else if (fragment == InboundDecoder.PING) { - assert aggregator.isAggregating() == false; - messageHandler.accept(channel, PING_MESSAGE); - } else if (fragment == InboundDecoder.END_CONTENT) { - assert aggregator.isAggregating(); - try (NativeInboundMessage aggregated = aggregator.finishAggregation()) { - statsTracker.markMessageReceived(); - messageHandler.accept(channel, aggregated); - } - } else { - assert aggregator.isAggregating(); - assert fragment instanceof ReleasableBytesReference; - aggregator.aggregate((ReleasableBytesReference) fragment); - } - } - } - -} diff --git a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java index 1143f129b6319..47dcb87e5a386 100644 --- a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java +++ b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java @@ -32,118 +32,34 @@ package org.opensearch.transport.nativeprotocol; -import org.opensearch.common.annotation.PublicApi; +import org.opensearch.common.annotation.DeprecatedApi; import org.opensearch.common.bytes.ReleasableBytesReference; import org.opensearch.common.lease.Releasable; -import org.opensearch.common.lease.Releasables; -import org.opensearch.common.util.io.IOUtils; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.transport.Header; -import org.opensearch.transport.ProtocolInboundMessage; - -import java.io.IOException; +import org.opensearch.transport.InboundMessage; /** * Inbound data as a message * - * @opensearch.api + * This class is deprecated in favor of {@link InboundMessage}. */ -@PublicApi(since = "2.14.0") -public class NativeInboundMessage implements Releasable, ProtocolInboundMessage { +@DeprecatedApi(since = "2.17.0") +public class NativeInboundMessage extends InboundMessage { /** * The protocol used to encode this message */ public static String NATIVE_PROTOCOL = "native"; - private final Header header; - private final ReleasableBytesReference content; - private final Exception exception; - private final boolean isPing; - private Releasable breakerRelease; - private StreamInput streamInput; - public NativeInboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) { - this.header = header; - this.content = content; - this.breakerRelease = breakerRelease; - this.exception = null; - this.isPing = false; + super(header, content, breakerRelease); } public NativeInboundMessage(Header header, Exception exception) { - this.header = header; - this.content = null; - this.breakerRelease = null; - this.exception = exception; - this.isPing = false; + super(header, exception); } public NativeInboundMessage(Header header, boolean isPing) { - this.header = header; - this.content = null; - this.breakerRelease = null; - this.exception = null; - this.isPing = isPing; - } - - @Override - public String getProtocol() { - return NATIVE_PROTOCOL; - } - - public Header getHeader() { - return header; - } - - public int getContentLength() { - if (content == null) { - return 0; - } else { - return content.length(); - } - } - - public Exception getException() { - return exception; - } - - public boolean isPing() { - return isPing; + super(header, isPing); } - - public boolean isShortCircuit() { - return exception != null; - } - - public Releasable takeBreakerReleaseControl() { - final Releasable toReturn = breakerRelease; - breakerRelease = null; - if (toReturn != null) { - return toReturn; - } else { - return () -> {}; - } - } - - public StreamInput openOrGetStreamInput() throws IOException { - assert isPing == false && content != null; - if (streamInput == null) { - streamInput = content.streamInput(); - streamInput.setVersion(header.getVersion()); - } - return streamInput; - } - - @Override - public void close() { - IOUtils.closeWhileHandlingException(streamInput); - Releasables.closeWhileHandlingException(content, breakerRelease); - } - - @Override - public String toString() { - return "InboundMessage{" + header + "}"; - } - } diff --git a/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java b/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java index 4ac78366360d7..6168fd1c6a307 100644 --- a/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java @@ -42,7 +42,6 @@ import org.opensearch.core.common.breaker.CircuitBreakingException; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import org.junit.Before; import java.io.IOException; @@ -79,7 +78,7 @@ public void setUp() throws Exception { public void testInboundAggregation() throws IOException { long requestId = randomNonNegativeLong(); - Header header = new Header(randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); + Header header = new Header(TransportProtocol.NATIVE, randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); header.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); header.actionName = "action_name"; // Initiate Message @@ -108,7 +107,7 @@ public void testInboundAggregation() throws IOException { } // Signal EOS - NativeInboundMessage aggregated = aggregator.finishAggregation(); + InboundMessage aggregated = aggregator.finishAggregation(); assertThat(aggregated, notNullValue()); assertFalse(aggregated.isPing()); @@ -126,7 +125,7 @@ public void testInboundAggregation() throws IOException { public void testInboundUnknownAction() throws IOException { long requestId = randomNonNegativeLong(); - Header header = new Header(randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); + Header header = new Header(TransportProtocol.NATIVE, randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); header.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); header.actionName = unknownAction; // Initiate Message @@ -139,7 +138,7 @@ public void testInboundUnknownAction() throws IOException { assertEquals(0, content.refCount()); // Signal EOS - NativeInboundMessage aggregated = aggregator.finishAggregation(); + InboundMessage aggregated = aggregator.finishAggregation(); assertThat(aggregated, notNullValue()); assertTrue(aggregated.isShortCircuit()); @@ -150,7 +149,13 @@ public void testInboundUnknownAction() throws IOException { public void testCircuitBreak() throws IOException { circuitBreaker.startBreaking(); // Actions are breakable - Header breakableHeader = new Header(randomInt(), randomNonNegativeLong(), TransportStatus.setRequest((byte) 0), Version.CURRENT); + Header breakableHeader = new Header( + TransportProtocol.NATIVE, + randomInt(), + randomNonNegativeLong(), + TransportStatus.setRequest((byte) 0), + Version.CURRENT + ); breakableHeader.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); breakableHeader.actionName = "action_name"; // Initiate Message @@ -162,7 +167,7 @@ public void testCircuitBreak() throws IOException { content1.close(); // Signal EOS - NativeInboundMessage aggregated1 = aggregator.finishAggregation(); + InboundMessage aggregated1 = aggregator.finishAggregation(); assertEquals(0, content1.refCount()); assertThat(aggregated1, notNullValue()); @@ -170,7 +175,13 @@ public void testCircuitBreak() throws IOException { assertThat(aggregated1.getException(), instanceOf(CircuitBreakingException.class)); // Actions marked as unbreakable are not broken - Header unbreakableHeader = new Header(randomInt(), randomNonNegativeLong(), TransportStatus.setRequest((byte) 0), Version.CURRENT); + Header unbreakableHeader = new Header( + TransportProtocol.NATIVE, + randomInt(), + randomNonNegativeLong(), + TransportStatus.setRequest((byte) 0), + Version.CURRENT + ); unbreakableHeader.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); unbreakableHeader.actionName = unBreakableAction; // Initiate Message @@ -181,7 +192,7 @@ public void testCircuitBreak() throws IOException { content2.close(); // Signal EOS - NativeInboundMessage aggregated2 = aggregator.finishAggregation(); + InboundMessage aggregated2 = aggregator.finishAggregation(); assertEquals(1, content2.refCount()); assertThat(aggregated2, notNullValue()); @@ -189,7 +200,13 @@ public void testCircuitBreak() throws IOException { // Handshakes are not broken final byte handshakeStatus = TransportStatus.setHandshake(TransportStatus.setRequest((byte) 0)); - Header handshakeHeader = new Header(randomInt(), randomNonNegativeLong(), handshakeStatus, Version.CURRENT); + Header handshakeHeader = new Header( + TransportProtocol.NATIVE, + randomInt(), + randomNonNegativeLong(), + handshakeStatus, + Version.CURRENT + ); handshakeHeader.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); handshakeHeader.actionName = "handshake"; // Initiate Message @@ -200,7 +217,7 @@ public void testCircuitBreak() throws IOException { content3.close(); // Signal EOS - NativeInboundMessage aggregated3 = aggregator.finishAggregation(); + InboundMessage aggregated3 = aggregator.finishAggregation(); assertEquals(1, content3.refCount()); assertThat(aggregated3, notNullValue()); @@ -209,7 +226,7 @@ public void testCircuitBreak() throws IOException { public void testCloseWillCloseContent() { long requestId = randomNonNegativeLong(); - Header header = new Header(randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); + Header header = new Header(TransportProtocol.NATIVE, randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); header.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); header.actionName = "action_name"; // Initiate Message @@ -249,7 +266,7 @@ public void testFinishAggregationWillFinishHeader() throws IOException { } else { actionName = "action_name"; } - Header header = new Header(randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); + Header header = new Header(TransportProtocol.NATIVE, randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); // Initiate Message aggregator.headerReceived(header); @@ -264,7 +281,7 @@ public void testFinishAggregationWillFinishHeader() throws IOException { content.close(); // Signal EOS - NativeInboundMessage aggregated = aggregator.finishAggregation(); + InboundMessage aggregated = aggregator.finishAggregation(); assertThat(aggregated, notNullValue()); assertFalse(header.needsToReadVariableHeader()); diff --git a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java index 2553e7740990b..7779db9dacc3c 100644 --- a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java @@ -57,7 +57,6 @@ import org.opensearch.test.VersionUtils; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import org.junit.After; import org.junit.Before; @@ -152,7 +151,7 @@ public void testPing() throws Exception { ); requestHandlers.registerHandler(registry); - handler.inboundMessage(channel, new NativeInboundMessage(null, true)); + handler.inboundMessage(channel, InboundMessage.PING); if (channel.isServerChannel()) { BytesReference ping = channel.getMessageCaptor().get(); assertEquals('E', ping.get(0)); @@ -215,12 +214,14 @@ public TestResponse read(StreamInput in) throws IOException { false ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); - Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( - requestHeader, - ReleasableBytesReference.wrap(requestContent), - () -> {} + Header requestHeader = new Header( + TransportProtocol.NATIVE, + fullRequestBytes.length() - 6, + requestId, + TransportStatus.setRequest((byte) 0), + version ); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -240,12 +241,8 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullResponseBytes = channel.getMessageCaptor().get(); BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize); - Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); - NativeInboundMessage responseMessage = new NativeInboundMessage( - responseHeader, - ReleasableBytesReference.wrap(responseContent), - () -> {} - ); + Header responseHeader = new Header(TransportProtocol.NATIVE, fullResponseBytes.length() - 6, requestId, responseStatus, version); + InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); @@ -267,12 +264,13 @@ public void testSendsErrorResponseToHandshakeFromCompatibleVersion() throws Exce final Version remoteVersion = VersionUtils.randomCompatibleVersion(random(), version); final long requestId = randomNonNegativeLong(); final Header requestHeader = new Header( + TransportProtocol.NATIVE, between(0, 100), requestId, TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), remoteVersion ); - final NativeInboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); + final InboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME; requestHeader.headers = Tuple.tuple(Map.of(), Map.of()); requestHeader.features = Set.of(); @@ -307,12 +305,13 @@ public void testClosesChannelOnErrorInHandshakeWithIncompatibleVersion() throws final Version remoteVersion = Version.fromId(randomIntBetween(0, version.minimumCompatibilityVersion().id - 1)); final long requestId = randomNonNegativeLong(); final Header requestHeader = new Header( + TransportProtocol.NATIVE, between(0, 100), requestId, TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), remoteVersion ); - final NativeInboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); + final InboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME; requestHeader.headers = Tuple.tuple(Map.of(), Map.of()); requestHeader.features = Set.of(); @@ -338,22 +337,19 @@ public void testLogsSlowInboundProcessing() throws Exception { final Version remoteVersion = Version.CURRENT; final long requestId = randomNonNegativeLong(); final Header requestHeader = new Header( + TransportProtocol.NATIVE, between(0, 100), requestId, TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), remoteVersion ); - final NativeInboundMessage requestMessage = new NativeInboundMessage( - requestHeader, - ReleasableBytesReference.wrap(BytesArray.EMPTY), - () -> { - try { - TimeUnit.SECONDS.sleep(1L); - } catch (InterruptedException e) { - throw new AssertionError(e); - } + final InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> { + try { + TimeUnit.SECONDS.sleep(1L); + } catch (InterruptedException e) { + throw new AssertionError(e); } - ); + }); requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME; requestHeader.headers = Tuple.tuple(Collections.emptyMap(), Collections.emptyMap()); requestHeader.features = Set.of(); @@ -424,12 +420,14 @@ public void onResponseSent(long requestId, String action, Exception error) { BytesReference fullRequestBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip()); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); - Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( - requestHeader, - ReleasableBytesReference.wrap(requestContent), - () -> {} + Header requestHeader = new Header( + TransportProtocol.NATIVE, + fullRequestBytes.length() - 6, + requestId, + TransportStatus.setRequest((byte) 0), + version ); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -493,12 +491,14 @@ public void onResponseSent(long requestId, String action, Exception error) { ); // Create the request payload by intentionally stripping 1 byte away BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize - 1); - Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( - requestHeader, - ReleasableBytesReference.wrap(requestContent), - () -> {} + Header requestHeader = new Header( + TransportProtocol.NATIVE, + fullRequestBytes.length() - 6, + requestId, + TransportStatus.setRequest((byte) 0), + version ); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -561,12 +561,14 @@ public TestResponse read(StreamInput in) throws IOException { false ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); - Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( - requestHeader, - ReleasableBytesReference.wrap(requestContent), - () -> {} + Header requestHeader = new Header( + TransportProtocol.NATIVE, + fullRequestBytes.length() - 6, + requestId, + TransportStatus.setRequest((byte) 0), + version ); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -587,12 +589,8 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullResponseBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip()); BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize); - Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); - NativeInboundMessage responseMessage = new NativeInboundMessage( - responseHeader, - ReleasableBytesReference.wrap(responseContent), - () -> {} - ); + Header responseHeader = new Header(TransportProtocol.NATIVE, fullResponseBytes.length() - 6, requestId, responseStatus, version); + InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); @@ -655,12 +653,14 @@ public TestResponse read(StreamInput in) throws IOException { false ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); - Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( - requestHeader, - ReleasableBytesReference.wrap(requestContent), - () -> {} + Header requestHeader = new Header( + TransportProtocol.NATIVE, + fullRequestBytes.length() - 6, + requestId, + TransportStatus.setRequest((byte) 0), + version ); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -676,12 +676,8 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullResponseBytes = channel.getMessageCaptor().get(); // Create the response payload by intentionally stripping 1 byte away BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize - 1); - Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); - NativeInboundMessage responseMessage = new NativeInboundMessage( - responseHeader, - ReleasableBytesReference.wrap(responseContent), - () -> {} - ); + Header responseHeader = new Header(TransportProtocol.NATIVE, fullResponseBytes.length() - 6, requestId, responseStatus, version); + InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); @@ -690,8 +686,8 @@ public TestResponse read(StreamInput in) throws IOException { assertThat(exceptionCaptor.get().getMessage(), containsString("Failed to deserialize response from handler")); } - private static NativeInboundMessage unreadableInboundHandshake(Version remoteVersion, Header requestHeader) { - return new NativeInboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {}) { + private static InboundMessage unreadableInboundHandshake(Version remoteVersion, Header requestHeader) { + return new InboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {}) { @Override public StreamInput openOrGetStreamInput() { final StreamInput streamInput = new InputStreamStreamInput(new InputStream() { diff --git a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java index 5a89bf1e0ead3..cd6c4cf260176 100644 --- a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java @@ -49,7 +49,6 @@ import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import java.io.IOException; import java.util.ArrayList; @@ -82,9 +81,8 @@ public void testPipelineHandlingForNativeProtocol() throws IOException { final List> expected = new ArrayList<>(); final List> actual = new ArrayList<>(); final List toRelease = new ArrayList<>(); - final BiConsumer messageHandler = (c, m) -> { + final BiConsumer messageHandler = (c, message) -> { try { - NativeInboundMessage message = (NativeInboundMessage) m; final Header header = message.getHeader(); final MessageData actualData; final Version version = header.getVersion(); @@ -199,7 +197,7 @@ public void testPipelineHandlingForNativeProtocol() throws IOException { } public void testDecodeExceptionIsPropagated() throws IOException { - BiConsumer messageHandler = (c, m) -> {}; + BiConsumer messageHandler = (c, m) -> {}; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); @@ -229,7 +227,7 @@ public void testDecodeExceptionIsPropagated() throws IOException { } public void testEnsureBodyIsNotPrematurelyReleased() throws IOException { - BiConsumer messageHandler = (c, m) -> {}; + BiConsumer messageHandler = (c, m) -> {}; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); diff --git a/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java index 01f19bea7a37f..11ca683c306bf 100644 --- a/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java @@ -52,7 +52,6 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import org.opensearch.transport.nativeprotocol.NativeOutboundHandler; import org.junit.After; import org.junit.Before; @@ -106,9 +105,8 @@ public void setUp() throws Exception { final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, (c, m) -> { try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { - NativeInboundMessage m1 = (NativeInboundMessage) m; - Streams.copy(m1.openOrGetStreamInput(), streamOutput); - message.set(new Tuple<>(m1.getHeader(), streamOutput.bytes())); + Streams.copy(m.openOrGetStreamInput(), streamOutput); + message.set(new Tuple<>(m.getHeader(), streamOutput.bytes())); } catch (IOException e) { throw new AssertionError(e); } diff --git a/server/src/test/java/org/opensearch/transport/TransportProtocolTests.java b/server/src/test/java/org/opensearch/transport/TransportProtocolTests.java new file mode 100644 index 0000000000000..024d3281fb76e --- /dev/null +++ b/server/src/test/java/org/opensearch/transport/TransportProtocolTests.java @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport; + +import org.opensearch.test.OpenSearchTestCase; + +public class TransportProtocolTests extends OpenSearchTestCase { + + public void testNativeProtocol() { + assertEquals(TransportProtocol.NATIVE, TransportProtocol.fromBytes((byte) 'E', (byte) 'S')); + } + + public void testInvalidProtocol() { + assertThrows(IllegalArgumentException.class, () -> TransportProtocol.fromBytes((byte) 'e', (byte) 'S')); + } +}