Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error handling in NIO server. #364

Merged
merged 11 commits into from
Feb 26, 2019
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ test-plugin:
test-plugin-nio:
swift build $(CFLAGS) --product protoc-gen-swiftgrpc
protoc Sources/Examples/Echo/echo.proto --proto_path=Sources/Examples/Echo --plugin=.build/debug/protoc-gen-swift --plugin=.build/debug/protoc-gen-swiftgrpc --swiftgrpc_out=/tmp --swiftgrpc_opt=Client=false,NIO=true
diff -u /tmp/echo.grpc.swift Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift
diff -u /tmp/echo.grpc.swift Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift

xcodebuild: project
xcodebuild -project SwiftGRPC.xcodeproj -configuration "Debug" -parallelizeTargets -target SwiftGRPC -target Echo -target Simple -target protoc-gen-swiftgrpc build
Expand Down
64 changes: 57 additions & 7 deletions Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,79 @@ import NIOHTTP1
/// Calls through to `processMessage` for individual messages it receives, which needs to be implemented by subclasses.
public class BaseCallHandler<RequestMessage: Message, ResponseMessage: Message>: GRPCCallHandler {
public func makeGRPCServerCodec() -> ChannelHandler { return GRPCServerCodec<RequestMessage, ResponseMessage>() }

/// Called whenever a message has been received.
///
/// Overridden by subclasses.
public func processMessage(_ message: RequestMessage) {
public func processMessage(_ message: RequestMessage) throws {
fatalError("needs to be overridden")
}

/// Called when the client has half-closed the stream, indicating that they won't send any further data.
///
/// Overridden by subclasses if the "end-of-stream" event is relevant.
public func endOfStreamReceived() { }

/// Whether this handler can still write messages to the client.
private var serverCanWrite = true

/// Called for each error recieved in `errorCaught(ctx:error:)`.
private weak var errorDelegate: ServerErrorDelegate?

public init(errorDelegate: ServerErrorDelegate?) {
self.errorDelegate = errorDelegate
}
}

extension BaseCallHandler: ChannelInboundHandler {
public typealias InboundIn = GRPCServerRequestPart<RequestMessage>
public typealias OutboundOut = GRPCServerResponsePart<ResponseMessage>

/// Passes errors to the user-provided `errorHandler`. After an error has been received an
/// appropriate status is written. Errors which don't conform to `GRPCStatusTransformable`
/// return a status with code `.internalError`.
public func errorCaught(ctx: ChannelHandlerContext, error: Error) {
errorDelegate?.observe(error)

let transformed = errorDelegate?.transform(error) ?? error
let status = (transformed as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError
self.write(ctx: ctx, data: NIOAny(GRPCServerResponsePart<ResponseMessage>.status(status)), promise: nil)
}

public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
switch self.unwrapInboundIn(data) {
case .head: preconditionFailure("should not have received headers")
case .message(let message): processMessage(message)
case .end: endOfStreamReceived()
case .head(let requestHead):
// Head should have been handled by `GRPCChannelHandler`.
self.errorCaught(ctx: ctx, error: GRPCError.invalidState("unexpected request head received \(requestHead)"))

case .message(let message):
do {
try processMessage(message)
} catch {
self.errorCaught(ctx: ctx, error: error)
}

case .end:
endOfStreamReceived()
}
}
}

extension BaseCallHandler: ChannelOutboundHandler {
public typealias OutboundIn = GRPCServerResponsePart<ResponseMessage>
public typealias OutboundOut = GRPCServerResponsePart<ResponseMessage>

public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
glbrntt marked this conversation as resolved.
Show resolved Hide resolved
guard serverCanWrite else {
promise?.fail(error: GRPCError.serverNotWritable)
return
}

// We can only write one status; make sure we don't write again.
if case .status = unwrapOutboundIn(data) {
serverCanWrite = false
ctx.writeAndFlush(data, promise: promise)
} else {
ctx.write(data, promise: promise)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ public class BidirectionalStreamingCallHandler<RequestMessage: Message, Response

// We ask for a future of type `EventObserver` to allow the framework user to e.g. asynchronously authenticate a call.
// If authentication fails, they can simply fail the observer future, which causes the call to be terminated.
public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (StreamingResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
super.init()
public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (StreamingResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
super.init(errorDelegate: errorDelegate)
let context = StreamingResponseCallContextImpl<ResponseMessage>(channel: channel, request: request)
self.context = context
let eventObserver = eventObserverFactory(context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ public class ClientStreamingCallHandler<RequestMessage: Message, ResponseMessage

// We ask for a future of type `EventObserver` to allow the framework user to e.g. asynchronously authenticate a call.
// If authentication fails, they can simply fail the observer future, which causes the call to be terminated.
public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (UnaryResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
super.init()
public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (UnaryResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
super.init(errorDelegate: errorDelegate)
let context = UnaryResponseCallContextImpl<ResponseMessage>(channel: channel, request: request)
self.context = context
let eventObserver = eventObserverFactory(context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ public class ServerStreamingCallHandler<RequestMessage: Message, ResponseMessage

private var context: StreamingResponseCallContext<ResponseMessage>?

public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (StreamingResponseCallContext<ResponseMessage>) -> EventObserver) {
super.init()
public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (StreamingResponseCallContext<ResponseMessage>) -> EventObserver) {
super.init(errorDelegate: errorDelegate)
let context = StreamingResponseCallContextImpl<ResponseMessage>(channel: channel, request: request)
self.context = context
self.eventObserver = eventObserverFactory(context)
Expand All @@ -26,12 +26,10 @@ public class ServerStreamingCallHandler<RequestMessage: Message, ResponseMessage
}


public override func processMessage(_ message: RequestMessage) {
public override func processMessage(_ message: RequestMessage) throws {
guard let eventObserver = self.eventObserver,
let context = self.context else {
//! FIXME: Better handle this error?
print("multiple messages received on unary call")
return
throw GRPCError.requestCardinalityViolation
}

let resultFuture = eventObserver(message)
Expand Down
10 changes: 4 additions & 6 deletions Sources/SwiftGRPCNIO/CallHandlers/UnaryCallHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ public class UnaryCallHandler<RequestMessage: Message, ResponseMessage: Message>

private var context: UnaryResponseCallContext<ResponseMessage>?

public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (UnaryResponseCallContext<ResponseMessage>) -> EventObserver) {
super.init()
public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (UnaryResponseCallContext<ResponseMessage>) -> EventObserver) {
super.init(errorDelegate: errorDelegate)
let context = UnaryResponseCallContextImpl<ResponseMessage>(channel: channel, request: request)
self.context = context
self.eventObserver = eventObserverFactory(context)
Expand All @@ -26,12 +26,10 @@ public class UnaryCallHandler<RequestMessage: Message, ResponseMessage: Message>
}
}

public override func processMessage(_ message: RequestMessage) {
public override func processMessage(_ message: RequestMessage) throws {
guard let eventObserver = self.eventObserver,
let context = self.context else {
//! FIXME: Better handle this error?
print("multiple messages received on unary call")
return
throw GRPCError.requestCardinalityViolation
}

let resultFuture = eventObserver(message)
Expand Down
48 changes: 34 additions & 14 deletions Sources/SwiftGRPCNIO/GRPCChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public protocol CallHandlerProvider: class {

/// Determines, calls and returns the appropriate request handler (`GRPCCallHandler`), depending on the request's
/// method. Returns nil for methods not handled by this service.
func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel) -> GRPCCallHandler?
func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorDelegate: ServerErrorDelegate?) -> GRPCCallHandler?
}

/// Listens on a newly-opened HTTP2 subchannel and yields to the sub-handler matching a call, if available.
Expand All @@ -28,30 +28,32 @@ public protocol CallHandlerProvider: class {
/// for an `GRPCCallHandler` object. That object is then forwarded the individual gRPC messages.
public final class GRPCChannelHandler {
private let servicesByName: [String: CallHandlerProvider]
private weak var errorDelegate: ServerErrorDelegate?

public init(servicesByName: [String: CallHandlerProvider]) {
public init(servicesByName: [String: CallHandlerProvider], errorDelegate: ServerErrorDelegate?) {
self.servicesByName = servicesByName
self.errorDelegate = errorDelegate
}
}

extension GRPCChannelHandler: ChannelInboundHandler {
public typealias InboundIn = RawGRPCServerRequestPart
public typealias OutboundOut = RawGRPCServerResponsePart


public func errorCaught(ctx: ChannelHandlerContext, error: Error) {
MrMage marked this conversation as resolved.
Show resolved Hide resolved
errorDelegate?.observe(error)

let transformedError = errorDelegate?.transform(error) ?? error
let status = (transformedError as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError
ctx.writeAndFlush(wrapOutboundOut(.status(status)), promise: nil)
}

public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
let requestPart = self.unwrapInboundIn(data)
switch requestPart {
case .head(let requestHead):
// URI format: "/package.Servicename/MethodName", resulting in the following components separated by a slash:
// - uriComponents[0]: empty
// - uriComponents[1]: service name (including the package name);
// `CallHandlerProvider`s should provide the service name including the package name.
// - uriComponents[2]: method name.
let uriComponents = requestHead.uri.components(separatedBy: "/")
guard uriComponents.count >= 3 && uriComponents[0].isEmpty,
let providerForServiceName = servicesByName[uriComponents[1]],
let callHandler = providerForServiceName.handleMethod(uriComponents[2], request: requestHead, serverHandler: self, channel: ctx.channel) else {
ctx.writeAndFlush(self.wrapOutboundOut(.status(.unimplemented(method: requestHead.uri))), promise: nil)
guard let callHandler = getCallHandler(channel: ctx.channel, requestHead: requestHead) else {
errorCaught(ctx: ctx, error: GRPCError.unimplementedMethod(requestHead.uri))
return
}

Expand All @@ -71,7 +73,25 @@ extension GRPCChannelHandler: ChannelInboundHandler {
.whenComplete { ctx.pipeline.remove(handler: self, promise: handlerRemoved) }

case .message, .end:
preconditionFailure("received \(requestPart), should have been removed as a handler at this point")
// We can reach this point if we're receiving messages for a method that isn't implemented.
// A status resposne will have been fired which should also close the stream; there's not a
// lot we can do at this point.
break
}
}

private func getCallHandler(channel: Channel, requestHead: HTTPRequestHead) -> GRPCCallHandler? {
// URI format: "/package.Servicename/MethodName", resulting in the following components separated by a slash:
// - uriComponents[0]: empty
// - uriComponents[1]: service name (including the package name);
// `CallHandlerProvider`s should provide the service name including the package name.
// - uriComponents[2]: method name.
let uriComponents = requestHead.uri.components(separatedBy: "/")
guard uriComponents.count >= 3 && uriComponents[0].isEmpty,
let providerForServiceName = servicesByName[uriComponents[1]],
let callHandler = providerForServiceName.handleMethod(uriComponents[2], request: requestHead, serverHandler: self, channel: channel, errorDelegate: errorDelegate) else {
return nil
}
return callHandler
}
}
70 changes: 70 additions & 0 deletions Sources/SwiftGRPCNIO/GRPCError.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright 2019, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import Foundation

public enum GRPCError: Error, Equatable {
MrMage marked this conversation as resolved.
Show resolved Hide resolved
/// The RPC method is not implemented on the server.
case unimplementedMethod(String)

/// It was not possible to parse the request protobuf.
case requestProtoParseFailure

/// It was not possible to serialize the response protobuf.
case responseProtoSerializationFailure

/// The given compression mechanism is not supported.
case unsupportedCompressionMechanism(String)

/// Compression was indicated in the gRPC message, but not for the call.
case unexpectedCompression

/// More than one request was sent for a unary-request call.
case requestCardinalityViolation

/// The server received a message when it was not in a writable state.
case serverNotWritable

/// An invalid state has been reached; something has gone very wrong.
case invalidState(String)
}

extension GRPCError: GRPCStatusTransformable {
public func asGRPCStatus() -> GRPCStatus {
// These status codes are informed by: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
switch self {
case .unimplementedMethod(let method):
return GRPCStatus(code: .unimplemented, message: "unknown method \(method)")

case .requestProtoParseFailure:
return GRPCStatus(code: .internalError, message: "could not parse request proto")

case .responseProtoSerializationFailure:
return GRPCStatus(code: .internalError, message: "could not serialize response proto")

case .unsupportedCompressionMechanism(let mechanism):
return GRPCStatus(code: .unimplemented, message: "unsupported compression mechanism \(mechanism)")

case .unexpectedCompression:
return GRPCStatus(code: .unimplemented, message: "compression was enabled for this gRPC message but not for this call")

case .requestCardinalityViolation:
return GRPCStatus(code: .unimplemented, message: "request cardinality violation; method requires exactly one request but client sent more")

case .serverNotWritable, .invalidState:
return GRPCStatus.processingError
}
}
}
20 changes: 16 additions & 4 deletions Sources/SwiftGRPCNIO/GRPCServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ public final class GRPCServer {
hostname: String,
port: Int,
eventLoopGroup: EventLoopGroup,
serviceProviders: [CallHandlerProvider]) -> EventLoopFuture<GRPCServer> {
serviceProviders: [CallHandlerProvider],
errorDelegate: ServerErrorDelegate? = LoggingServerErrorDelegate()
) -> EventLoopFuture<GRPCServer> {
let servicesByName = Dictionary(uniqueKeysWithValues: serviceProviders.map { ($0.serviceName, $0) })
let bootstrap = ServerBootstrap(group: eventLoopGroup)
// Specify a backlog to avoid overloading the server.
Expand All @@ -27,7 +29,7 @@ public final class GRPCServer {
let multiplexer = HTTP2StreamMultiplexer { (channel, streamID) -> EventLoopFuture<Void> in
return channel.pipeline.add(handler: HTTP2ToHTTP1ServerCodec(streamID: streamID))
.then { channel.pipeline.add(handler: HTTP1ToRawGRPCServerCodec()) }
.then { channel.pipeline.add(handler: GRPCChannelHandler(servicesByName: servicesByName)) }
.then { channel.pipeline.add(handler: GRPCChannelHandler(servicesByName: servicesByName, errorDelegate: errorDelegate)) }
}

return channel.pipeline.add(handler: multiplexer)
Expand All @@ -39,20 +41,30 @@ public final class GRPCServer {
.childChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)

return bootstrap.bind(host: hostname, port: port)
.map { GRPCServer(channel: $0) }
.map { GRPCServer(channel: $0, errorDelegate: errorDelegate) }
}

private let channel: Channel
private var errorDelegate: ServerErrorDelegate?

private init(channel: Channel) {
private init(channel: Channel, errorDelegate: ServerErrorDelegate?) {
self.channel = channel

// Maintain a strong reference to ensure it lives as long as the server.
self.errorDelegate = errorDelegate

// nil out errorDelegate to avoid retain cycles.
onClose.whenComplete {
self.errorDelegate = nil
}
}

/// Fired when the server shuts down.
public var onClose: EventLoopFuture<Void> {
return channel.closeFuture
}

/// Shut down the server; this should be called to avoid leaking resources.
public func close() -> EventLoopFuture<Void> {
return channel.close(mode: .all)
}
Expand Down
Loading