diff --git a/Makefile b/Makefile index 9a9f13565..771ef91eb 100644 --- a/Makefile +++ b/Makefile @@ -59,7 +59,7 @@ test-plugin: test-plugin-nio: swift build $(CFLAGS) --product protoc-gen-swiftgrpc protoc Sources/Examples/Echo/echo.proto --proto_path=Sources/Examples/Echo --plugin=.build/debug/protoc-gen-swift --plugin=.build/debug/protoc-gen-swiftgrpc --swiftgrpc_out=/tmp --swiftgrpc_opt=Client=false,NIO=true - diff -u /tmp/echo.grpc.swift Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift + diff -u /tmp/echo.grpc.swift Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift xcodebuild: project xcodebuild -project SwiftGRPC.xcodeproj -configuration "Debug" -parallelizeTargets -target SwiftGRPC -target Echo -target Simple -target protoc-gen-swiftgrpc build diff --git a/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift index 33d18d5e6..ae3986fd5 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift @@ -8,29 +8,79 @@ import NIOHTTP1 /// Calls through to `processMessage` for individual messages it receives, which needs to be implemented by subclasses. public class BaseCallHandler: GRPCCallHandler { public func makeGRPCServerCodec() -> ChannelHandler { return GRPCServerCodec() } - + /// Called whenever a message has been received. /// /// Overridden by subclasses. - public func processMessage(_ message: RequestMessage) { + public func processMessage(_ message: RequestMessage) throws { fatalError("needs to be overridden") } - + /// Called when the client has half-closed the stream, indicating that they won't send any further data. /// /// Overridden by subclasses if the "end-of-stream" event is relevant. public func endOfStreamReceived() { } + + /// Whether this handler can still write messages to the client. + private var serverCanWrite = true + + /// Called for each error recieved in `errorCaught(ctx:error:)`. + private weak var errorDelegate: ServerErrorDelegate? + + public init(errorDelegate: ServerErrorDelegate?) { + self.errorDelegate = errorDelegate + } } extension BaseCallHandler: ChannelInboundHandler { public typealias InboundIn = GRPCServerRequestPart - public typealias OutboundOut = GRPCServerResponsePart + + /// Passes errors to the user-provided `errorHandler`. After an error has been received an + /// appropriate status is written. Errors which don't conform to `GRPCStatusTransformable` + /// return a status with code `.internalError`. + public func errorCaught(ctx: ChannelHandlerContext, error: Error) { + errorDelegate?.observe(error) + + let transformed = errorDelegate?.transform(error) ?? error + let status = (transformed as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError + self.write(ctx: ctx, data: NIOAny(GRPCServerResponsePart.status(status)), promise: nil) + } public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { switch self.unwrapInboundIn(data) { - case .head: preconditionFailure("should not have received headers") - case .message(let message): processMessage(message) - case .end: endOfStreamReceived() + case .head(let requestHead): + // Head should have been handled by `GRPCChannelHandler`. + self.errorCaught(ctx: ctx, error: GRPCServerError.invalidState("unexpected request head received \(requestHead)")) + + case .message(let message): + do { + try processMessage(message) + } catch { + self.errorCaught(ctx: ctx, error: error) + } + + case .end: + endOfStreamReceived() + } + } +} + +extension BaseCallHandler: ChannelOutboundHandler { + public typealias OutboundIn = GRPCServerResponsePart + public typealias OutboundOut = GRPCServerResponsePart + + public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + guard serverCanWrite else { + promise?.fail(error: GRPCServerError.serverNotWritable) + return + } + + // We can only write one status; make sure we don't write again. + if case .status = unwrapOutboundIn(data) { + serverCanWrite = false + ctx.writeAndFlush(data, promise: promise) + } else { + ctx.write(data, promise: promise) } } } diff --git a/Sources/SwiftGRPCNIO/CallHandlers/BidirectionalStreamingCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/BidirectionalStreamingCallHandler.swift index 46f4b7622..2d5a4294c 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/BidirectionalStreamingCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/BidirectionalStreamingCallHandler.swift @@ -15,8 +15,8 @@ public class BidirectionalStreamingCallHandler) -> EventLoopFuture) { - super.init() + public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (StreamingResponseCallContext) -> EventLoopFuture) { + super.init(errorDelegate: errorDelegate) let context = StreamingResponseCallContextImpl(channel: channel, request: request) self.context = context let eventObserver = eventObserverFactory(context) diff --git a/Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift index bd03ae744..a6213a497 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift @@ -14,8 +14,8 @@ public class ClientStreamingCallHandler) -> EventLoopFuture) { - super.init() + public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (UnaryResponseCallContext) -> EventLoopFuture) { + super.init(errorDelegate: errorDelegate) let context = UnaryResponseCallContextImpl(channel: channel, request: request) self.context = context let eventObserver = eventObserverFactory(context) diff --git a/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift index 893745c69..6374cea58 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift @@ -13,8 +13,8 @@ public class ServerStreamingCallHandler? - public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (StreamingResponseCallContext) -> EventObserver) { - super.init() + public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (StreamingResponseCallContext) -> EventObserver) { + super.init(errorDelegate: errorDelegate) let context = StreamingResponseCallContextImpl(channel: channel, request: request) self.context = context self.eventObserver = eventObserverFactory(context) @@ -26,12 +26,10 @@ public class ServerStreamingCallHandler private var context: UnaryResponseCallContext? - public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (UnaryResponseCallContext) -> EventObserver) { - super.init() + public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (UnaryResponseCallContext) -> EventObserver) { + super.init(errorDelegate: errorDelegate) let context = UnaryResponseCallContextImpl(channel: channel, request: request) self.context = context self.eventObserver = eventObserverFactory(context) @@ -26,12 +26,10 @@ public class UnaryCallHandler } } - public override func processMessage(_ message: RequestMessage) { + public override func processMessage(_ message: RequestMessage) throws { guard let eventObserver = self.eventObserver, let context = self.context else { - //! FIXME: Better handle this error? - print("multiple messages received on unary call") - return + throw GRPCServerError.requestCardinalityViolation } let resultFuture = eventObserver(message) diff --git a/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift b/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift index 8f1e0d6e9..d18b7a4dc 100644 --- a/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift +++ b/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift @@ -19,7 +19,7 @@ public protocol CallHandlerProvider: class { /// Determines, calls and returns the appropriate request handler (`GRPCCallHandler`), depending on the request's /// method. Returns nil for methods not handled by this service. - func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel) -> GRPCCallHandler? + func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorDelegate: ServerErrorDelegate?) -> GRPCCallHandler? } /// Listens on a newly-opened HTTP2 subchannel and yields to the sub-handler matching a call, if available. @@ -28,9 +28,11 @@ public protocol CallHandlerProvider: class { /// for an `GRPCCallHandler` object. That object is then forwarded the individual gRPC messages. public final class GRPCChannelHandler { private let servicesByName: [String: CallHandlerProvider] + private weak var errorDelegate: ServerErrorDelegate? - public init(servicesByName: [String: CallHandlerProvider]) { + public init(servicesByName: [String: CallHandlerProvider], errorDelegate: ServerErrorDelegate?) { self.servicesByName = servicesByName + self.errorDelegate = errorDelegate } } @@ -38,20 +40,20 @@ extension GRPCChannelHandler: ChannelInboundHandler { public typealias InboundIn = RawGRPCServerRequestPart public typealias OutboundOut = RawGRPCServerResponsePart + public func errorCaught(ctx: ChannelHandlerContext, error: Error) { + errorDelegate?.observe(error) + + let transformedError = errorDelegate?.transform(error) ?? error + let status = (transformedError as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError + ctx.writeAndFlush(wrapOutboundOut(.status(status)), promise: nil) + } + public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { let requestPart = self.unwrapInboundIn(data) switch requestPart { case .head(let requestHead): - // URI format: "/package.Servicename/MethodName", resulting in the following components separated by a slash: - // - uriComponents[0]: empty - // - uriComponents[1]: service name (including the package name); - // `CallHandlerProvider`s should provide the service name including the package name. - // - uriComponents[2]: method name. - let uriComponents = requestHead.uri.components(separatedBy: "/") - guard uriComponents.count >= 3 && uriComponents[0].isEmpty, - let providerForServiceName = servicesByName[uriComponents[1]], - let callHandler = providerForServiceName.handleMethod(uriComponents[2], request: requestHead, serverHandler: self, channel: ctx.channel) else { - ctx.writeAndFlush(self.wrapOutboundOut(.status(.unimplemented(method: requestHead.uri))), promise: nil) + guard let callHandler = getCallHandler(channel: ctx.channel, requestHead: requestHead) else { + errorCaught(ctx: ctx, error: GRPCServerError.unimplementedMethod(requestHead.uri)) return } @@ -75,7 +77,25 @@ extension GRPCChannelHandler: ChannelInboundHandler { .whenComplete { ctx.pipeline.remove(handler: self, promise: handlerRemoved) } case .message, .end: - preconditionFailure("received \(requestPart), should have been removed as a handler at this point") + // We can reach this point if we're receiving messages for a method that isn't implemented. + // A status resposne will have been fired which should also close the stream; there's not a + // lot we can do at this point. + break + } + } + + private func getCallHandler(channel: Channel, requestHead: HTTPRequestHead) -> GRPCCallHandler? { + // URI format: "/package.Servicename/MethodName", resulting in the following components separated by a slash: + // - uriComponents[0]: empty + // - uriComponents[1]: service name (including the package name); + // `CallHandlerProvider`s should provide the service name including the package name. + // - uriComponents[2]: method name. + let uriComponents = requestHead.uri.components(separatedBy: "/") + guard uriComponents.count >= 3 && uriComponents[0].isEmpty, + let providerForServiceName = servicesByName[uriComponents[1]], + let callHandler = providerForServiceName.handleMethod(uriComponents[2], request: requestHead, serverHandler: self, channel: channel, errorDelegate: errorDelegate) else { + return nil } + return callHandler } } diff --git a/Sources/SwiftGRPCNIO/GRPCServer.swift b/Sources/SwiftGRPCNIO/GRPCServer.swift index ef9a7b7e1..b5fa24d81 100644 --- a/Sources/SwiftGRPCNIO/GRPCServer.swift +++ b/Sources/SwiftGRPCNIO/GRPCServer.swift @@ -12,7 +12,9 @@ public final class GRPCServer { hostname: String, port: Int, eventLoopGroup: EventLoopGroup, - serviceProviders: [CallHandlerProvider]) -> EventLoopFuture { + serviceProviders: [CallHandlerProvider], + errorDelegate: ServerErrorDelegate? = LoggingServerErrorDelegate() + ) -> EventLoopFuture { let servicesByName = Dictionary(uniqueKeysWithValues: serviceProviders.map { ($0.serviceName, $0) }) let bootstrap = ServerBootstrap(group: eventLoopGroup) // Specify a backlog to avoid overloading the server. @@ -25,7 +27,7 @@ public final class GRPCServer { return channel.pipeline.add(handler: HTTPProtocolSwitcher { channel -> EventLoopFuture in return channel.pipeline.add(handler: HTTP1ToRawGRPCServerCodec()) - .then { channel.pipeline.add(handler: GRPCChannelHandler(servicesByName: servicesByName)) } + .then { channel.pipeline.add(handler: GRPCChannelHandler(servicesByName: servicesByName, errorDelegate: errorDelegate)) } }) } @@ -34,13 +36,22 @@ public final class GRPCServer { .childChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) return bootstrap.bind(host: hostname, port: port) - .map { GRPCServer(channel: $0) } + .map { GRPCServer(channel: $0, errorDelegate: errorDelegate) } } private let channel: Channel + private var errorDelegate: ServerErrorDelegate? - private init(channel: Channel) { + private init(channel: Channel, errorDelegate: ServerErrorDelegate?) { self.channel = channel + + // Maintain a strong reference to ensure it lives as long as the server. + self.errorDelegate = errorDelegate + + // nil out errorDelegate to avoid retain cycles. + onClose.whenComplete { + self.errorDelegate = nil + } } /// Fired when the server shuts down. @@ -48,6 +59,7 @@ public final class GRPCServer { return channel.closeFuture } + /// Shut down the server; this should be called to avoid leaking resources. public func close() -> EventLoopFuture { return channel.close(mode: .all) } diff --git a/Sources/SwiftGRPCNIO/GRPCServerCodec.swift b/Sources/SwiftGRPCNIO/GRPCServerCodec.swift index 21652e8bb..9193c9ea9 100644 --- a/Sources/SwiftGRPCNIO/GRPCServerCodec.swift +++ b/Sources/SwiftGRPCNIO/GRPCServerCodec.swift @@ -19,7 +19,7 @@ public enum GRPCServerResponsePart { } /// A simple channel handler that translates raw gRPC packets into decoded protobuf messages, and vice versa. -public final class GRPCServerCodec { } +public final class GRPCServerCodec {} extension GRPCServerCodec: ChannelInboundHandler { public typealias InboundIn = RawGRPCServerRequestPart @@ -35,8 +35,7 @@ extension GRPCServerCodec: ChannelInboundHandler { do { ctx.fireChannelRead(self.wrapInboundOut(.message(try RequestMessage(serializedData: messageAsData)))) } catch { - //! FIXME: Ensure that the last handler in the pipeline returns `.dataLoss` here? - ctx.fireErrorCaught(error) + ctx.fireErrorCaught(GRPCServerError.requestProtoParseFailure) } case .end: @@ -54,6 +53,7 @@ extension GRPCServerCodec: ChannelOutboundHandler { switch responsePart { case .headers(let headers): ctx.write(self.wrapOutboundOut(.headers(headers)), promise: promise) + case .message(let message): do { let messageData = try message.serializedData() @@ -61,9 +61,11 @@ extension GRPCServerCodec: ChannelOutboundHandler { responseBuffer.write(bytes: messageData) ctx.write(self.wrapOutboundOut(.message(responseBuffer)), promise: promise) } catch { + let error = GRPCServerError.responseProtoSerializationFailure promise?.fail(error: error) ctx.fireErrorCaught(error) } + case .status(let status): ctx.write(self.wrapOutboundOut(.status(status)), promise: promise) } diff --git a/Sources/SwiftGRPCNIO/GRPCServerError.swift b/Sources/SwiftGRPCNIO/GRPCServerError.swift new file mode 100644 index 000000000..309fe3ad7 --- /dev/null +++ b/Sources/SwiftGRPCNIO/GRPCServerError.swift @@ -0,0 +1,76 @@ +/* + * Copyright 2019, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Foundation + +public enum GRPCServerError: Error, Equatable { + /// The RPC method is not implemented on the server. + case unimplementedMethod(String) + + /// It was not possible to decode a base64 message (gRPC-Web only). + case base64DecodeError + + /// It was not possible to parse the request protobuf. + case requestProtoParseFailure + + /// It was not possible to serialize the response protobuf. + case responseProtoSerializationFailure + + /// The given compression mechanism is not supported. + case unsupportedCompressionMechanism(String) + + /// Compression was indicated in the gRPC message, but not for the call. + case unexpectedCompression + + /// More than one request was sent for a unary-request call. + case requestCardinalityViolation + + /// The server received a message when it was not in a writable state. + case serverNotWritable + + /// An invalid state has been reached; something has gone very wrong. + case invalidState(String) +} + +extension GRPCServerError: GRPCStatusTransformable { + public func asGRPCStatus() -> GRPCStatus { + // These status codes are informed by: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md + switch self { + case .unimplementedMethod(let method): + return GRPCStatus(code: .unimplemented, message: "unknown method \(method)") + + case .base64DecodeError: + return GRPCStatus(code: .internalError, message: "could not decode base64 message") + + case .requestProtoParseFailure: + return GRPCStatus(code: .internalError, message: "could not parse request proto") + + case .responseProtoSerializationFailure: + return GRPCStatus(code: .internalError, message: "could not serialize response proto") + + case .unsupportedCompressionMechanism(let mechanism): + return GRPCStatus(code: .unimplemented, message: "unsupported compression mechanism \(mechanism)") + + case .unexpectedCompression: + return GRPCStatus(code: .unimplemented, message: "compression was enabled for this gRPC message but not for this call") + + case .requestCardinalityViolation: + return GRPCStatus(code: .unimplemented, message: "request cardinality violation; method requires exactly one request but client sent more") + + case .serverNotWritable, .invalidState: + return GRPCStatus.processingError + } + } +} diff --git a/Sources/SwiftGRPCNIO/GRPCStatus.swift b/Sources/SwiftGRPCNIO/GRPCStatus.swift index 9e4109b82..7a023242c 100644 --- a/Sources/SwiftGRPCNIO/GRPCStatus.swift +++ b/Sources/SwiftGRPCNIO/GRPCStatus.swift @@ -2,7 +2,7 @@ import Foundation import NIOHTTP1 /// Encapsulates the result of a gRPC call. -public struct GRPCStatus: Error { +public struct GRPCStatus: Error, Equatable { /// The code to return in the `grpc-status` header. public let code: StatusCode /// The message to return in the `grpc-message` header. @@ -22,9 +22,14 @@ public struct GRPCStatus: Error { public static let ok = GRPCStatus(code: .ok, message: "OK") /// "Internal server error" status. public static let processingError = GRPCStatus(code: .internalError, message: "unknown error processing request") +} + +public protocol GRPCStatusTransformable: Error { + func asGRPCStatus() -> GRPCStatus +} - /// Status indicating that the given method is not implemented. - public static func unimplemented(method: String) -> GRPCStatus { - return GRPCStatus(code: .unimplemented, message: "unknown method " + method) +extension GRPCStatus: GRPCStatusTransformable { + public func asGRPCStatus() -> GRPCStatus { + return self } } diff --git a/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift b/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift index 440afd511..2b502398a 100644 --- a/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift +++ b/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift @@ -28,31 +28,8 @@ public enum RawGRPCServerResponsePart { /// /// The translation from HTTP2 to HTTP1 is done by `HTTP2ToHTTP1ServerCodec`. public final class HTTP1ToRawGRPCServerCodec { - /// Expected content types for incoming requests. - private enum ContentType: String { - /// Binary encoded gRPC request. - case binary = "application/grpc" - /// Base64 encoded gRPC-Web request. - case text = "application/grpc-web-text" - /// Binary encoded gRPC-Web request. - case web = "application/grpc-web" - } - - private enum State { - case expectingHeaders - case expectingCompressedFlag - case expectingMessageLength - case receivedMessageLength(UInt32) - - var expectingBody: Bool { - switch self { - case .expectingHeaders: return false - case .expectingCompressedFlag, .expectingMessageLength, .receivedMessageLength: return true - } - } - } - - private var state = State.expectingHeaders + // 1-byte for compression flag, 4-bytes for message length. + private let protobufMetadataSize = 5 private var contentType: ContentType? @@ -62,7 +39,8 @@ public final class HTTP1ToRawGRPCServerCodec { // would then have to be re-assigned into the class variable for the changes to take effect. // By force unwrapping, we avoid those reassignments, and the code is a bit cleaner. - // Buffer to store binary encoded protos as they're being received. + // Buffer to store binary encoded protos as they're being received if the proto is split across + // multiple buffers. private var binaryRequestBuffer: NIO.ByteBuffer! // Buffers to store text encoded protos. Only used when content-type is application/grpc-web-text. @@ -70,6 +48,40 @@ public final class HTTP1ToRawGRPCServerCodec { // the HTTP1.1 pipeline, as it's starting to get in the way of readability. private var requestTextBuffer: NIO.ByteBuffer! private var responseTextBuffer: NIO.ByteBuffer! + + var inboundState = InboundState.expectingHeaders + var outboundState = OutboundState.expectingHeaders +} + +extension HTTP1ToRawGRPCServerCodec { + /// Expected content types for incoming requests. + private enum ContentType: String { + /// Binary encoded gRPC request. + case binary = "application/grpc" + /// Base64 encoded gRPC-Web request. + case text = "application/grpc-web-text" + /// Binary encoded gRPC-Web request. + case web = "application/grpc-web" + } + + enum InboundState { + case expectingHeaders + case expectingBody(Body) + // ignore any additional messages; e.g. we've seen .end or we've sent an error and are waiting for the stream to close. + case ignore + + enum Body { + case expectingCompressedFlag + case expectingMessageLength + case expectingMoreMessageBytes(UInt32) + } + } + + enum OutboundState { + case expectingHeaders + case expectingBodyOrStatus + case ignore + } } extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler { @@ -77,89 +89,133 @@ extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler { public typealias InboundOut = RawGRPCServerRequestPart public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { - switch self.unwrapInboundIn(data) { - case .head(let requestHead): - guard case .expectingHeaders = state - else { preconditionFailure("received headers while in state \(state)") } - - state = .expectingCompressedFlag - binaryRequestBuffer = ctx.channel.allocator.buffer(capacity: 5) - if let contentTypeHeader = requestHead.headers["content-type"].first { - contentType = ContentType(rawValue: contentTypeHeader) - } else { - // If the Content-Type is not present, assume the request is binary encoded gRPC. - contentType = .binary + if case .ignore = inboundState { return } + + do { + switch self.unwrapInboundIn(data) { + case .head(let requestHead): + inboundState = try processHead(ctx: ctx, requestHead: requestHead) + + case .body(var body): + inboundState = try processBody(ctx: ctx, body: &body) + + case .end(let trailers): + inboundState = try processEnd(ctx: ctx, trailers: trailers) } - if contentType == .text { - requestTextBuffer = ctx.channel.allocator.buffer(capacity: 0) + } catch { + ctx.fireErrorCaught(error) + inboundState = .ignore + } + } + + func processHead(ctx: ChannelHandlerContext, requestHead: HTTPRequestHead) throws -> InboundState { + guard case .expectingHeaders = inboundState else { + throw GRPCServerError.invalidState("expecteded state .expectingHeaders, got \(inboundState)") + } + + if let contentTypeHeader = requestHead.headers["content-type"].first { + contentType = ContentType(rawValue: contentTypeHeader) + } else { + // If the Content-Type is not present, assume the request is binary encoded gRPC. + contentType = .binary + } + + if contentType == .text { + requestTextBuffer = ctx.channel.allocator.buffer(capacity: 0) + } + + ctx.fireChannelRead(self.wrapInboundOut(.head(requestHead))) + return .expectingBody(.expectingCompressedFlag) + } + + func processBody(ctx: ChannelHandlerContext, body: inout ByteBuffer) throws -> InboundState { + guard case .expectingBody(let bodyState) = inboundState else { + throw GRPCServerError.invalidState("expecteded state .expectingBody(_), got \(inboundState)") + } + + // If the contentType is text, then decode the incoming bytes as base64 encoded, and append + // it to the binary buffer. If the request is chunked, this section will process the text + // in the biggest chunk that is multiple of 4, leaving the unread bytes in the textBuffer + // where it will expect a new incoming chunk. + if contentType == .text { + precondition(requestTextBuffer != nil) + requestTextBuffer.write(buffer: &body) + + // Read in chunks of 4 bytes as base64 encoded strings will always be multiples of 4. + let readyBytes = requestTextBuffer.readableBytes - (requestTextBuffer.readableBytes % 4) + guard let base64Encoded = requestTextBuffer.readString(length: readyBytes), + let decodedData = Data(base64Encoded: base64Encoded) else { + throw GRPCServerError.base64DecodeError } - ctx.fireChannelRead(self.wrapInboundOut(.head(requestHead))) - case .body(var body): - precondition(binaryRequestBuffer != nil, "buffer not initialized") - assert(state.expectingBody, "received body while in state \(state)") + body.write(bytes: decodedData) + } + + return .expectingBody(try processBodyState(ctx: ctx, initialState: bodyState, messageBuffer: &body)) + } - // If the contentType is text, then decode the incoming bytes as base64 encoded, and append - // it to the binary buffer. If the request is chunked, this section will process the text - // in the biggest chunk that is multiple of 4, leaving the unread bytes in the textBuffer - // where it will expect a new incoming chunk. - if contentType == .text { - precondition(requestTextBuffer != nil) - requestTextBuffer.write(buffer: &body) - // Read in chunks of 4 bytes as base64 encoded strings will always be multiples of 4. - let readyBytes = requestTextBuffer.readableBytes - (requestTextBuffer.readableBytes % 4) - guard let base64Encoded = requestTextBuffer.readString(length:readyBytes), - let decodedData = Data(base64Encoded: base64Encoded) else { - //! FIXME: Improve error handling when the message couldn't be decoded as base64. - ctx.close(mode: .all, promise: nil) - return + func processBodyState(ctx: ChannelHandlerContext, initialState: InboundState.Body, messageBuffer: inout ByteBuffer) throws -> InboundState.Body { + var bodyState = initialState + + // Iterate over all available incoming data, trying to read length-delimited messages. + // Each message has the following format: + // - 1 byte "compressed" flag (currently always zero, as we do not support for compression) + // - 4 byte signed-integer payload length (N) + // - N bytes payload (normally a valid wire-format protocol buffer) + while true { + switch bodyState { + case .expectingCompressedFlag: + guard let compressedFlag: Int8 = messageBuffer.readInteger() else { return .expectingCompressedFlag } + + // TODO: Add support for compression. + guard compressedFlag == 0 else { throw GRPCServerError.unexpectedCompression } + + bodyState = .expectingMessageLength + + case .expectingMessageLength: + guard let messageLength: UInt32 = messageBuffer.readInteger() else { return .expectingMessageLength } + bodyState = .expectingMoreMessageBytes(messageLength) + + case .expectingMoreMessageBytes(let bytesOutstanding): + // We need to account for messages being spread across multiple `ByteBuffer`s so buffer them + // into `buffer`. Note: when messages are contained within a single `ByteBuffer` we're just + // taking a slice so don't incur any extra writes. + guard messageBuffer.readableBytes >= bytesOutstanding else { + let remainingBytes = bytesOutstanding - numericCast(messageBuffer.readableBytes) + + if self.binaryRequestBuffer != nil { + self.binaryRequestBuffer.write(buffer: &messageBuffer) + } else { + messageBuffer.reserveCapacity(numericCast(bytesOutstanding)) + self.binaryRequestBuffer = messageBuffer + } + return .expectingMoreMessageBytes(remainingBytes) } - binaryRequestBuffer.write(bytes: decodedData) - } else { - binaryRequestBuffer.write(buffer: &body) - } - // Iterate over all available incoming data, trying to read length-delimited messages. - // Each message has the following format: - // - 1 byte "compressed" flag (currently always zero, as we do not support for compression) - // - 4 byte signed-integer payload length (N) - // - N bytes payload (normally a valid wire-format protocol buffer) - requestProcessing: while true { - switch state { - case .expectingHeaders: preconditionFailure("unexpected state \(state)") - case .expectingCompressedFlag: - guard let compressionFlag: Int8 = binaryRequestBuffer.readInteger() else { break requestProcessing } - //! FIXME: Avoid crashing here and instead drop the connection. - precondition(compressionFlag == 0, "unexpected compression flag \(compressionFlag); compression is not supported and we did not indicate support for it") - state = .expectingMessageLength - - case .expectingMessageLength: - guard let messageLength: UInt32 = binaryRequestBuffer.readInteger() else { break requestProcessing } - state = .receivedMessageLength(messageLength) - - case .receivedMessageLength(let messageLength): - guard let messageBytes = binaryRequestBuffer.readBytes(length: numericCast(messageLength)) else { break } - - //! FIXME: Use a slice of this buffer instead of copying to a new buffer. - var messageBuffer = ctx.channel.allocator.buffer(capacity: messageBytes.count) - messageBuffer.write(bytes: messageBytes) - ctx.fireChannelRead(self.wrapInboundOut(.message(messageBuffer))) - //! FIXME: Call buffer.discardReadBytes() here? - //! ALTERNATIVE: Check if the buffer has no further data right now, then clear it. - - state = .expectingCompressedFlag + // We know buffer.readableBytes >= messageLength, so it's okay to force unwrap here. + var slice = messageBuffer.readSlice(length: numericCast(bytesOutstanding))! + + if self.binaryRequestBuffer != nil { + self.binaryRequestBuffer.write(buffer: &slice) + ctx.fireChannelRead(self.wrapInboundOut(.message(self.binaryRequestBuffer))) + } else { + ctx.fireChannelRead(self.wrapInboundOut(.message(slice))) } - } - case .end(let trailers): - if let trailers = trailers { - //! FIXME: Better handle this error. - print("unexpected trailers received: \(trailers)") - return + self.binaryRequestBuffer = nil + bodyState = .expectingCompressedFlag } - ctx.fireChannelRead(self.wrapInboundOut(.end)) } } + + private func processEnd(ctx: ChannelHandlerContext, trailers: HTTPHeaders?) throws -> InboundState { + if let trailers = trailers { + throw GRPCServerError.invalidState("unexpected trailers received \(trailers)") + } + + ctx.fireChannelRead(self.wrapInboundOut(.end)) + return .ignore + } } extension HTTP1ToRawGRPCServerCodec: ChannelOutboundHandler { @@ -167,10 +223,12 @@ extension HTTP1ToRawGRPCServerCodec: ChannelOutboundHandler { public typealias OutboundOut = HTTPServerResponsePart public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { - let responsePart = self.unwrapOutboundIn(data) - switch responsePart { - case .headers: - var headers = HTTPHeaders() + if case .ignore = outboundState { return } + + switch self.unwrapOutboundIn(data) { + case .headers(var headers): + guard case .expectingHeaders = outboundState else { return } + var version = HTTPVersion(major: 2, minor: 0) if let contentType = contentType { headers.add(name: "content-type", value: contentType.rawValue) @@ -184,9 +242,13 @@ extension HTTP1ToRawGRPCServerCodec: ChannelOutboundHandler { } ctx.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: version, status: .ok, headers: headers))), promise: promise) + outboundState = .expectingBodyOrStatus + case .message(var messageBytes): - // Write out a length-delimited message payload. See `channelRead` fpor the corresponding format. - var responseBuffer = ctx.channel.allocator.buffer(capacity: messageBytes.readableBytes + 5) + guard case .expectingBodyOrStatus = outboundState else { return } + + // Write out a length-delimited message payload. See `processBodyState` for the corresponding format. + var responseBuffer = ctx.channel.allocator.buffer(capacity: messageBytes.readableBytes + protobufMetadataSize) responseBuffer.write(integer: Int8(0)) // Compression flag: no compression responseBuffer.write(integer: UInt32(messageBytes.readableBytes)) responseBuffer.write(buffer: &messageBytes) @@ -203,8 +265,16 @@ extension HTTP1ToRawGRPCServerCodec: ChannelOutboundHandler { } else { ctx.write(self.wrapOutboundOut(.body(.byteBuffer(responseBuffer))), promise: promise) } + outboundState = .expectingBodyOrStatus case .status(let status): + // If we error before sending the initial headers (e.g. unimplemented method) then we won't have sent the request head. + // NIOHTTP2 doesn't support sending a single frame as a "Trailers-Only" response so we still need to loop back and + // send the request head first. + if case .expectingHeaders = outboundState { + self.write(ctx: ctx, data: NIOAny(RawGRPCServerResponsePart.headers(HTTPHeaders())), promise: nil) + } + var trailers = status.trailingMetadata trailers.add(name: "grpc-status", value: String(describing: status.code.rawValue)) trailers.add(name: "grpc-message", value: status.message) @@ -236,6 +306,9 @@ extension HTTP1ToRawGRPCServerCodec: ChannelOutboundHandler { } else { ctx.write(self.wrapOutboundOut(.end(trailers)), promise: promise) } + + outboundState = .ignore + inboundState = .ignore } } } diff --git a/Sources/SwiftGRPCNIO/LoggingServerErrorDelegate.swift b/Sources/SwiftGRPCNIO/LoggingServerErrorDelegate.swift new file mode 100644 index 000000000..b0a30c178 --- /dev/null +++ b/Sources/SwiftGRPCNIO/LoggingServerErrorDelegate.swift @@ -0,0 +1,24 @@ +/* + * Copyright 2019, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Foundation + +public class LoggingServerErrorDelegate: ServerErrorDelegate { + public init() {} + + public func observe(_ error: Error) { + print("[grpc-server][\(Date())] \(error)") + } +} diff --git a/Sources/SwiftGRPCNIO/ServerErrorDelegate.swift b/Sources/SwiftGRPCNIO/ServerErrorDelegate.swift new file mode 100644 index 000000000..83521a6e3 --- /dev/null +++ b/Sources/SwiftGRPCNIO/ServerErrorDelegate.swift @@ -0,0 +1,40 @@ +/* + * Copyright 2019, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Foundation +import NIO + +public protocol ServerErrorDelegate: class { + //! FIXME: Provide more context about where the error was thrown. + /// Called when an error is thrown in the channel pipeline. + func observe(_ error: Error) + + /// Transforms the given error into a new error. + /// + /// This allows framework users to transform errors which may be out of their control + /// due to third-party libraries, for example, into more meaningful errors or + /// `GRPCStatus` errors. Errors returned from this protocol are not passed to + /// `observe`. + /// + /// - note: + /// This defaults to returning the provided error. + func transform(_ error: Error) -> Error +} + +public extension ServerErrorDelegate { + func transform(_ error: Error) -> Error { + return error + } +} diff --git a/Sources/protoc-gen-swiftgrpc/Generator-Server.swift b/Sources/protoc-gen-swiftgrpc/Generator-Server.swift index 45de736ff..a7de29b47 100644 --- a/Sources/protoc-gen-swiftgrpc/Generator-Server.swift +++ b/Sources/protoc-gen-swiftgrpc/Generator-Server.swift @@ -85,7 +85,7 @@ extension Generator { if options.generateNIOImplementation { println("/// Determines, calls and returns the appropriate request handler, depending on the request's method.") println("/// Returns nil for methods not handled by this service.") - println("\(access) func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel) -> GRPCCallHandler? {") + println("\(access) func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorDelegate: ServerErrorDelegate?) -> GRPCCallHandler? {") indent() println("switch methodName {") for method in service.methods { @@ -99,7 +99,7 @@ extension Generator { case .clientStreaming: callHandlerType = "ClientStreamingCallHandler" case .bidirectionalStreaming: callHandlerType = "BidirectionalStreamingCallHandler" } - println("return \(callHandlerType)(channel: channel, request: request) { context in") + println("return \(callHandlerType)(channel: channel, request: request, errorDelegate: errorDelegate) { context in") indent() switch streamingType(method) { case .unary, .serverStreaming: diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index b0abb3e49..91d0a9807 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -37,5 +37,8 @@ XCTMain([ testCase(ServerTimeoutTests.allTests), // SwiftGRPCNIO - testCase(NIOServerTests.allTests) + testCase(NIOServerTests.allTests), + testCase(NIOServerWebTests.allTests), + testCase(GRPCChannelHandlerTests.allTests), + testCase(HTTP1ToRawGRPCServerCodecTests.allTests) ]) diff --git a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift new file mode 100644 index 000000000..0999801fd --- /dev/null +++ b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift @@ -0,0 +1,59 @@ +import Foundation +import NIO +import NIOHTTP1 +@testable import SwiftGRPCNIO +import XCTest + +class CollectingChannelHandler: ChannelOutboundHandler { + var responses: [OutboundIn] = [] + + func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + responses.append(unwrapOutboundIn(data)) + } +} + +class CollectingServerErrorDelegate: ServerErrorDelegate { + var errors: [Error] = [] + + func observe(_ error: Error) { + self.errors.append(error) + } +} + +class GRPCChannelHandlerResponseCapturingTestCase: XCTestCase { + static let echoProvider: [String: CallHandlerProvider] = ["echo.Echo": EchoProvider_NIO()] + class var defaultServiceProvider: [String: CallHandlerProvider] { + return echoProvider + } + + func configureChannel(withHandlers handlers: [ChannelHandler]) -> EventLoopFuture { + let channel = EmbeddedChannel() + return channel.pipeline.addHandlers(handlers, first: true) + .map { _ in channel } + } + + var errorCollector: CollectingServerErrorDelegate = CollectingServerErrorDelegate() + + /// Waits for `count` responses to be collected and then returns them. The test fails if the number + /// of collected responses does not match the expected. + /// + /// - Parameters: + /// - count: expected number of responses. + /// - servicesByName: service providers keyed by their service name. + /// - callback: a callback called after the channel has been setup, intended to "fill" the channel + /// with messages. The callback is called before this function returns. + /// - Returns: The responses collected from the pipeline. + func waitForGRPCChannelHandlerResponses( + count: Int, + servicesByName: [String: CallHandlerProvider] = defaultServiceProvider, + callback: @escaping (EmbeddedChannel) throws -> Void + ) throws -> [RawGRPCServerResponsePart] { + let collector = CollectingChannelHandler() + try configureChannel(withHandlers: [collector, GRPCChannelHandler(servicesByName: servicesByName, errorDelegate: errorCollector)]) + .thenThrowing(callback) + .wait() + + XCTAssertEqual(count, collector.responses.count) + return collector.responses + } +} diff --git a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift new file mode 100644 index 000000000..b97c3f49f --- /dev/null +++ b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift @@ -0,0 +1,67 @@ +import Foundation +import XCTest +import NIO +import NIOHTTP1 +@testable import SwiftGRPCNIO + +class GRPCChannelHandlerTests: GRPCChannelHandlerResponseCapturingTestCase { + static var allTests: [(String, (GRPCChannelHandlerTests) -> () throws -> Void)] { + return [ + ("testUnimplementedMethodReturnsUnimplementedStatus", testUnimplementedMethodReturnsUnimplementedStatus), + ("testImplementedMethodReturnsHeadersMessageAndStatus", testImplementedMethodReturnsHeadersMessageAndStatus), + ("testImplementedMethodReturnsStatusForBadlyFormedProto", testImplementedMethodReturnsStatusForBadlyFormedProto), + ] + } + + func testUnimplementedMethodReturnsUnimplementedStatus() throws { + let responses = try waitForGRPCChannelHandlerResponses(count: 1) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "unimplementedMethodName") + try channel.writeInbound(RawGRPCServerRequestPart.head(requestHead)) + } + + let expectedError = GRPCServerError.unimplementedMethod("unimplementedMethodName") + XCTAssertEqual([expectedError], errorCollector.errors as? [GRPCServerError]) + + XCTAssertNoThrow(try extractStatus(responses[0])) { status in + XCTAssertEqual(status, expectedError.asGRPCStatus()) + } + } + + func testImplementedMethodReturnsHeadersMessageAndStatus() throws { + let responses = try waitForGRPCChannelHandlerResponses(count: 3) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + try channel.writeInbound(RawGRPCServerRequestPart.head(requestHead)) + + let request = Echo_EchoRequest.with { $0.text = "echo!" } + let requestData = try request.serializedData() + var buffer = channel.allocator.buffer(capacity: requestData.count) + buffer.write(bytes: requestData) + try channel.writeInbound(RawGRPCServerRequestPart.message(buffer)) + } + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractMessage(responses[1])) + XCTAssertNoThrow(try extractStatus(responses[2])) { status in + XCTAssertEqual(status, .ok) + } + } + + func testImplementedMethodReturnsStatusForBadlyFormedProto() throws { + let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + try channel.writeInbound(RawGRPCServerRequestPart.head(requestHead)) + + var buffer = channel.allocator.buffer(capacity: 3) + buffer.write(bytes: [1, 2, 3]) + try channel.writeInbound(RawGRPCServerRequestPart.message(buffer)) + } + + let expectedError = GRPCServerError.requestProtoParseFailure + XCTAssertEqual([expectedError], errorCollector.errors as? [GRPCServerError]) + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractStatus(responses[1])) { status in + XCTAssertEqual(status, expectedError.asGRPCStatus()) + } + } +} diff --git a/Tests/SwiftGRPCNIOTests/HTTP1ToRawGRPCServerCodecTests.swift b/Tests/SwiftGRPCNIOTests/HTTP1ToRawGRPCServerCodecTests.swift new file mode 100644 index 000000000..bb17fef8e --- /dev/null +++ b/Tests/SwiftGRPCNIOTests/HTTP1ToRawGRPCServerCodecTests.swift @@ -0,0 +1,153 @@ +import Foundation +import XCTest +import NIO +import NIOHTTP1 +@testable import SwiftGRPCNIO + +func gRPCMessage(channel: EmbeddedChannel, compression: Bool = false, message: Data? = nil) -> ByteBuffer { + let messageLength = message?.count ?? 0 + var buffer = channel.allocator.buffer(capacity: 5 + messageLength) + buffer.write(integer: Int8(compression ? 1 : 0)) + buffer.write(integer: UInt32(messageLength)) + if let bytes = message { + buffer.write(bytes: bytes) + } + return buffer +} + +class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCase { + static var allTests: [(String, (HTTP1ToRawGRPCServerCodecTests) -> () throws -> Void)] { + return [ + ("testInternalErrorStatusReturnedWhenCompressionFlagIsSet", testInternalErrorStatusReturnedWhenCompressionFlagIsSet), + ("testMessageCanBeSentAcrossMultipleByteBuffers", testMessageCanBeSentAcrossMultipleByteBuffers), + ("testInternalErrorStatusIsReturnedIfMessageCannotBeDeserialized", testInternalErrorStatusIsReturnedIfMessageCannotBeDeserialized), + ("testInternalErrorStatusIsReturnedWhenSendingTrailersInRequest", testInternalErrorStatusIsReturnedWhenSendingTrailersInRequest), + ("testOnlyOneStatusIsReturned", testOnlyOneStatusIsReturned), + ] + } + + func testInternalErrorStatusReturnedWhenCompressionFlagIsSet() throws { + let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel, compression: true))) + } + + let expectedError = GRPCServerError.unexpectedCompression + XCTAssertEqual([expectedError], errorCollector.errors as? [GRPCServerError]) + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractStatus(responses[1])) { status in + XCTAssertEqual(status, expectedError.asGRPCStatus()) + } + } + + func testMessageCanBeSentAcrossMultipleByteBuffers() throws { + let responses = try waitForGRPCChannelHandlerResponses(count: 3) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + // Sending the header allocates a buffer. + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + + let request = Echo_EchoRequest.with { $0.text = "echo!" } + let requestAsData = try request.serializedData() + + var buffer = channel.allocator.buffer(capacity: 1) + buffer.write(integer: Int8(0)) + try channel.writeInbound(HTTPServerRequestPart.body(buffer)) + + buffer = channel.allocator.buffer(capacity: 4) + buffer.write(integer: Int32(requestAsData.count)) + try channel.writeInbound(HTTPServerRequestPart.body(buffer)) + + buffer = channel.allocator.buffer(capacity: requestAsData.count) + buffer.write(bytes: requestAsData) + try channel.writeInbound(HTTPServerRequestPart.body(buffer)) + } + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractMessage(responses[1])) + XCTAssertNoThrow(try extractStatus(responses[2])) { status in + XCTAssertEqual(status, .ok) + } + } + + func testInternalErrorStatusIsReturnedIfMessageCannotBeDeserialized() throws { + let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + + let buffer = gRPCMessage(channel: channel, message: Data(bytes: [42])) + try channel.writeInbound(HTTPServerRequestPart.body(buffer)) + } + + let expectedError = GRPCServerError.requestProtoParseFailure + XCTAssertEqual([expectedError], errorCollector.errors as? [GRPCServerError]) + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractStatus(responses[1])) { status in + XCTAssertEqual(status, expectedError.asGRPCStatus()) + } + } + + func testInternalErrorStatusIsReturnedWhenSendingTrailersInRequest() throws { + let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in + // We have to use "Collect" (client streaming) as the tests rely on `EmbeddedChannel` which runs in this thread. + // In the current server implementation, responses from unary calls send a status immediately after sending the response. + // As such, a unary "Get" would return an "ok" status before the trailers would be sent. + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Collect") + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel))) + + var trailers = HTTPHeaders() + trailers.add(name: "foo", value: "bar") + try channel.writeInbound(HTTPServerRequestPart.end(trailers)) + } + + XCTAssertEqual(errorCollector.errors.count, 1) + + if case .invalidState(let message)? = errorCollector.errors.first as? GRPCServerError { + XCTAssert(message.contains("trailers")) + } else { + XCTFail("\(String(describing: errorCollector.errors.first)) was not GRPCError.invalidState") + } + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractStatus(responses[1])) { status in + XCTAssertEqual(status, .processingError) + } + } + + func testOnlyOneStatusIsReturned() throws { + let responses = try waitForGRPCChannelHandlerResponses(count: 3) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel))) + + // Sending trailers with `.end` should trigger an error. However, writing a message to a unary call + // will trigger a response and status to be sent back. Since we're using `EmbeddedChannel` this will + // be done before the trailers are sent. If a 4th resposne were to be sent (for the error status) then + // the test would fail. + + var trailers = HTTPHeaders() + trailers.add(name: "foo", value: "bar") + try channel.writeInbound(HTTPServerRequestPart.end(trailers)) + } + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractMessage(responses[1])) + XCTAssertNoThrow(try extractStatus(responses[2])) { status in + XCTAssertEqual(status, .ok) + } + } + + override func waitForGRPCChannelHandlerResponses( + count: Int, + servicesByName: [String: CallHandlerProvider] = GRPCChannelHandlerResponseCapturingTestCase.echoProvider, + callback: @escaping (EmbeddedChannel) throws -> Void + ) throws -> [RawGRPCServerResponsePart] { + return try super.waitForGRPCChannelHandlerResponses(count: count, servicesByName: servicesByName) { channel in + _ = channel.pipeline.addHandlers(HTTP1ToRawGRPCServerCodec(), first: true) + .thenThrowing { _ in try callback(channel) } + } + } +} diff --git a/Tests/SwiftGRPCNIOTests/NIOServerTests.swift b/Tests/SwiftGRPCNIOTests/NIOServerTests.swift index 5940923b9..475a2d63b 100644 --- a/Tests/SwiftGRPCNIOTests/NIOServerTests.swift +++ b/Tests/SwiftGRPCNIOTests/NIOServerTests.swift @@ -122,6 +122,14 @@ extension NIOServerTests { XCTAssertEqual("Swift echo get: foo", try! client.get(Echo_EchoRequest(text: "foo")).text) } + func testUnaryWithLargeData() throws { + // Default max frame size is: 16,384. We'll exceed this as we also have to send the size and compression flag. + let longMessage = String(repeating: "e", count: 16_384) + XCTAssertNoThrow(try client.get(Echo_EchoRequest(text: longMessage))) { response in + XCTAssertEqual("Swift echo get: \(longMessage)", response.text) + } + } + func testUnaryLotsOfRequests() { // Sending that many requests at once can sometimes trip things up, it seems. client.timeout = 5.0 @@ -135,6 +143,10 @@ extension NIOServerTests { } print("total time for \(numberOfRequests) requests: \(Double(clock() - clockStart) / Double(CLOCKS_PER_SEC))") } + + func testUnaryEmptyRequest() throws { + XCTAssertNoThrow(try client.get(Echo_EchoRequest())) + } } extension NIOServerTests { diff --git a/Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift b/Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift index 342975796..ce87064b5 100644 --- a/Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift +++ b/Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift @@ -25,7 +25,8 @@ class NIOServerWebTests: NIOServerTestCase { static var allTests: [(String, (NIOServerWebTests) -> () throws -> Void)] { return [ ("testUnary", testUnary), - ("testUnaryLotsOfRequests", testUnaryLotsOfRequests), + //! FIXME: Broken on Linux: https://github.com/grpc/grpc-swift/issues/382 + // ("testUnaryLotsOfRequests", testUnaryLotsOfRequests), ("testServerStreaming", testServerStreaming), ] } @@ -114,13 +115,15 @@ extension NIOServerWebTests { // Sending that many requests at once can sometimes trip things up, it seems. let clockStart = clock() let numberOfRequests = 2_000 + let completionHandlerExpectation = expectation(description: "completion handler called") -#if os(macOS) - // Linux version of Swift doesn't have this API yet. + // Linux version of Swift doesn't have the `expectedFulfillmentCount` API yet. // Implemented in https://github.com/apple/swift-corelibs-xctest/pull/228 but not yet // released. - completionHandlerExpectation.expectedFulfillmentCount = numberOfRequests -#endif + // + // Wait for the expected number of responses (i.e. `numberOfRequests`) instead. + var responses = 0 + for i in 0..( + _ expression: @autoclosure () throws -> T, + _ message: String = "", + file: StaticString = #file, + line: UInt = #line, + validate: (T) -> Void +) { + var value: T? = nil + XCTAssertNoThrow(try value = expression(), message, file: file, line: line) + value.map { validate($0) } +} + +struct CaseExtractError: Error { + let message: String +} + +@discardableResult +func extractHeaders(_ response: RawGRPCServerResponsePart) throws -> HTTPHeaders { + guard case .headers(let headers) = response else { + throw CaseExtractError(message: "\(response) did not match .headers") + } + return headers +} + +@discardableResult +func extractMessage(_ response: RawGRPCServerResponsePart) throws -> ByteBuffer { + guard case .message(let message) = response else { + throw CaseExtractError(message: "\(response) did not match .message") + } + return message +} + +@discardableResult +func extractStatus(_ response: RawGRPCServerResponsePart) throws -> GRPCStatus { + guard case .status(let status) = response else { + throw CaseExtractError(message: "\(response) did not match .status") + } + return status +} diff --git a/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift b/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift index e3cfbfb44..ecf86a285 100644 --- a/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift +++ b/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift @@ -40,29 +40,29 @@ extension Echo_EchoProvider_NIO { /// Determines, calls and returns the appropriate request handler, depending on the request's method. /// Returns nil for methods not handled by this service. - internal func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel) -> GRPCCallHandler? { + internal func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorDelegate: ServerErrorDelegate?) -> GRPCCallHandler? { switch methodName { case "Get": - return UnaryCallHandler(channel: channel, request: request) { context in + return UnaryCallHandler(channel: channel, request: request, errorDelegate: errorDelegate) { context in return { request in self.get(request: request, context: context) } } case "Expand": - return ServerStreamingCallHandler(channel: channel, request: request) { context in + return ServerStreamingCallHandler(channel: channel, request: request, errorDelegate: errorDelegate) { context in return { request in self.expand(request: request, context: context) } } case "Collect": - return ClientStreamingCallHandler(channel: channel, request: request) { context in + return ClientStreamingCallHandler(channel: channel, request: request, errorDelegate: errorDelegate) { context in return self.collect(context: context) } case "Update": - return BidirectionalStreamingCallHandler(channel: channel, request: request) { context in + return BidirectionalStreamingCallHandler(channel: channel, request: request, errorDelegate: errorDelegate) { context in return self.update(context: context) }