diff --git a/ktor-server/ktor-server-test-suites/jvmAndNix/src/io/ktor/server/testing/suites/WebSocketEngineSuite.kt b/ktor-server/ktor-server-test-suites/jvmAndNix/src/io/ktor/server/testing/suites/WebSocketEngineSuite.kt index e59ab4e15e4..52096d8c5a6 100644 --- a/ktor-server/ktor-server-test-suites/jvmAndNix/src/io/ktor/server/testing/suites/WebSocketEngineSuite.kt +++ b/ktor-server/ktor-server-test-suites/jvmAndNix/src/io/ktor/server/testing/suites/WebSocketEngineSuite.kt @@ -14,6 +14,7 @@ import io.ktor.server.testing.* import io.ktor.server.websocket.* import io.ktor.util.* import io.ktor.utils.io.* +import io.ktor.utils.io.bits.* import io.ktor.utils.io.charsets.* import io.ktor.utils.io.core.* import io.ktor.websocket.* @@ -579,6 +580,49 @@ abstract class WebSocketEngineSuite() + val second = CompletableDeferred() + createAndStartServer { + webSocket("/") { + val frame = incoming.receive() + assertIs(frame) + first.complete(frame) + + val frame2 = incoming.receive() + assertIs(frame2) + second.complete(frame2) + } + } + + useSocket { + negotiateHttpWebSocket() + + output.apply { + repeat(2) { + writeFrameTest(Frame.Text(false, "Hello".toByteArray(), true, false, false), false) + writeFrameTest(Frame.Text(true, ", World".toByteArray(), false, false, false), false, opcode = 0) + } + writeFrameTest(Frame.Close(), false) + flush() + } + } + + fun checkFrame(frame: Frame) { + assertIs(frame) + assertTrue(frame.fin) + assertTrue(frame.rsv1) + assertFalse(frame.rsv2) + assertFalse(frame.rsv3) + + assertEquals("Hello, World", frame.readText()) + } + + checkFrame(first.await()) + checkFrame(second.await()) + } + private suspend fun Connection.negotiateHttpWebSocket() { // send upgrade request output.apply { @@ -621,6 +665,7 @@ abstract class WebSocketEngineSuite fail("Unexpected frame $frame: \n${hex(frame.data)}") } } @@ -684,3 +729,53 @@ abstract class WebSocketEngineSuite length + length <= 0xffff -> 126 + else -> 127 + } + + val maskAndLength = masking.flagAt(7) or formattedLength + + writeByte(maskAndLength.toByte()) + + when (formattedLength) { + 126 -> writeShort(length.toShort()) + 127 -> writeLong(length.toLong()) + } + + val data = ByteReadPacket(frame.data) + + val maskedData = when (masking) { + true -> { + val maskKey = Random.nextInt() + writeInt(maskKey) + data.mask(maskKey) + } + false -> data + } + writePacket(maskedData) +} + +internal fun Boolean.flagAt(at: Int) = if (this) 1 shl at else 0 + +private fun ByteReadPacket.mask(maskKey: Int): ByteReadPacket = withMemory(4) { maskMemory -> + maskMemory.storeIntAt(0, maskKey) + buildPacket { + repeat(remaining.toInt()) { i -> + writeByte((readByte().toInt() xor (maskMemory[i % 4].toInt())).toByte()) + } + } +} diff --git a/ktor-server/ktor-server-tomcat-jakarta/jvm/test/io/ktor/tests/server/tomcat/jakarta/TomcatWebSocketTest.kt b/ktor-server/ktor-server-tomcat-jakarta/jvm/test/io/ktor/tests/server/tomcat/jakarta/TomcatWebSocketTest.kt index a186baa3923..3e70505e829 100644 --- a/ktor-server/ktor-server-tomcat-jakarta/jvm/test/io/ktor/tests/server/tomcat/jakarta/TomcatWebSocketTest.kt +++ b/ktor-server/ktor-server-tomcat-jakarta/jvm/test/io/ktor/tests/server/tomcat/jakarta/TomcatWebSocketTest.kt @@ -14,4 +14,8 @@ class TomcatWebSocketTest : @Ignore override fun testClientClosingFirst() { } + + @Ignore + override fun testFragmentedFlagsFromTheFirstFrame() { + } } diff --git a/ktor-shared/ktor-websockets/common/src/io/ktor/websocket/DefaultWebSocketSession.kt b/ktor-shared/ktor-websockets/common/src/io/ktor/websocket/DefaultWebSocketSession.kt index 7e1384cc3ec..fd7edc36c97 100644 --- a/ktor-shared/ktor-websockets/common/src/io/ktor/websocket/DefaultWebSocketSession.kt +++ b/ktor-shared/ktor-websockets/common/src/io/ktor/websocket/DefaultWebSocketSession.kt @@ -159,7 +159,8 @@ internal class DefaultWebSocketSessionImpl( private fun runIncomingProcessor(ponger: SendChannel): Job = launch( IncomingProcessorCoroutineName + Dispatchers.Unconfined ) { - var last: BytePacketBuilder? = null + var firstFrame: Frame? = null + var frameBody: BytePacketBuilder? = null var closeFramePresented = false try { @OptIn(DelicateCoroutinesApi::class) @@ -177,31 +178,37 @@ internal class DefaultWebSocketSessionImpl( is Frame.Pong -> pinger.value?.send(frame) is Frame.Ping -> ponger.send(frame) else -> { - checkMaxFrameSize(last, frame) + checkMaxFrameSize(frameBody, frame) if (!frame.fin) { - if (last == null) { - last = BytePacketBuilder() + if (firstFrame == null) { + firstFrame = frame } + if (frameBody == null) { + frameBody = BytePacketBuilder() + } + + frameBody!!.writeFully(frame.data) + return@consumeEach + } - last!!.writeFully(frame.data) + if (firstFrame == null) { + filtered.send(processIncomingExtensions(frame)) return@consumeEach } - val frameToSend = last?.let { builder -> - builder.writeFully(frame.data) - Frame.byType( - fin = true, - frame.frameType, - builder.build().readBytes(), - frame.rsv1, - frame.rsv2, - frame.rsv3 - ) - } ?: frame - - last = null - filtered.send(processIncomingExtensions(frameToSend)) + frameBody!!.writeFully(frame.data) + val defragmented = Frame.byType( + fin = true, + firstFrame!!.frameType, + frameBody!!.build().readBytes(), + firstFrame!!.rsv1, + firstFrame!!.rsv2, + firstFrame!!.rsv3 + ) + + firstFrame = null + filtered.send(processIncomingExtensions(defragmented)) } } } @@ -211,7 +218,7 @@ internal class DefaultWebSocketSessionImpl( filtered.close(cause) } finally { ponger.close() - last?.release() + frameBody?.release() filtered.close() if (!closeFramePresented) {