Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error handling in NIO server. #364

Merged
merged 11 commits into from
Feb 26, 2019
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ test-plugin:
test-plugin-nio:
swift build $(CFLAGS) --product protoc-gen-swiftgrpc
protoc Sources/Examples/Echo/echo.proto --proto_path=Sources/Examples/Echo --plugin=.build/debug/protoc-gen-swift --plugin=.build/debug/protoc-gen-swiftgrpc --swiftgrpc_out=/tmp --swiftgrpc_opt=Client=false,NIO=true
diff -u /tmp/echo.grpc.swift Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift
diff -u /tmp/echo.grpc.swift Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift

xcodebuild: project
xcodebuild -project SwiftGRPC.xcodeproj -configuration "Debug" -parallelizeTargets -target SwiftGRPC -target Echo -target Simple -target protoc-gen-swiftgrpc build
Expand Down
64 changes: 57 additions & 7 deletions Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,79 @@ import NIOHTTP1
/// Calls through to `processMessage` for individual messages it receives, which needs to be implemented by subclasses.
public class BaseCallHandler<RequestMessage: Message, ResponseMessage: Message>: GRPCCallHandler {
public func makeGRPCServerCodec() -> ChannelHandler { return GRPCServerCodec<RequestMessage, ResponseMessage>() }

/// Called whenever a message has been received.
///
/// Overridden by subclasses.
public func processMessage(_ message: RequestMessage) {
public func processMessage(_ message: RequestMessage) throws {
fatalError("needs to be overridden")
}

/// Called when the client has half-closed the stream, indicating that they won't send any further data.
///
/// Overridden by subclasses if the "end-of-stream" event is relevant.
public func endOfStreamReceived() { }

/// Whether this handler can still write messages to the client.
private var serverCanWrite = true

/// Called for each error recieved in `errorCaught(ctx:error:)`.
private weak var errorDelegate: ServerErrorDelegate?

public init(errorDelegate: ServerErrorDelegate? = nil) {
glbrntt marked this conversation as resolved.
Show resolved Hide resolved
self.errorDelegate = errorDelegate
}
}

extension BaseCallHandler: ChannelInboundHandler {
public typealias InboundIn = GRPCServerRequestPart<RequestMessage>
public typealias OutboundOut = GRPCServerResponsePart<ResponseMessage>

/// Passes errors to the user-provided `errorHandler`. After an error has been received an
/// appropriate status is written. Errors which don't conform to `GRPCStatusTransformable`
/// return a status with code `.internalError`.
public func errorCaught(ctx: ChannelHandlerContext, error: Error) {
errorDelegate?.observe(error)

let transformed = errorDelegate?.transform(error) ?? error
let status = (transformed as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError
self.write(ctx: ctx, data: NIOAny(GRPCServerResponsePart<ResponseMessage>.status(status)), promise: nil)
}

public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
switch self.unwrapInboundIn(data) {
case .head: preconditionFailure("should not have received headers")
case .message(let message): processMessage(message)
case .end: endOfStreamReceived()
case .head:
// Head should have been handled by `GRPCChannelHandler`.
self.errorCaught(ctx: ctx, error: GRPCStatus(code: .unknown, message: "unexpectedly received head"))

case .message(let message):
do {
try processMessage(message)
} catch {
self.errorCaught(ctx: ctx, error: error)
}

case .end:
endOfStreamReceived()
}
}
}

extension BaseCallHandler: ChannelOutboundHandler {
public typealias OutboundIn = GRPCServerResponsePart<ResponseMessage>
public typealias OutboundOut = GRPCServerResponsePart<ResponseMessage>

public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
glbrntt marked this conversation as resolved.
Show resolved Hide resolved
guard serverCanWrite else {
promise?.fail(error: GRPCStatus.processingError)
MrMage marked this conversation as resolved.
Show resolved Hide resolved
return
}

// We can only write one status; make sure we don't write again.
if case .status = unwrapOutboundIn(data) {
serverCanWrite = false
ctx.writeAndFlush(data, promise: promise)
} else {
ctx.write(data, promise: promise)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ public class BidirectionalStreamingCallHandler<RequestMessage: Message, Response

// We ask for a future of type `EventObserver` to allow the framework user to e.g. asynchronously authenticate a call.
// If authentication fails, they can simply fail the observer future, which causes the call to be terminated.
public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (StreamingResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
super.init()
public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (StreamingResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
super.init(errorDelegate: errorDelegate)
let context = StreamingResponseCallContextImpl<ResponseMessage>(channel: channel, request: request)
self.context = context
let eventObserver = eventObserverFactory(context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ public class ClientStreamingCallHandler<RequestMessage: Message, ResponseMessage

// We ask for a future of type `EventObserver` to allow the framework user to e.g. asynchronously authenticate a call.
// If authentication fails, they can simply fail the observer future, which causes the call to be terminated.
public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (UnaryResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
super.init()
public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (UnaryResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
super.init(errorDelegate: errorDelegate)
let context = UnaryResponseCallContextImpl<ResponseMessage>(channel: channel, request: request)
self.context = context
let eventObserver = eventObserverFactory(context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ public class ServerStreamingCallHandler<RequestMessage: Message, ResponseMessage

private var context: StreamingResponseCallContext<ResponseMessage>?

public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (StreamingResponseCallContext<ResponseMessage>) -> EventObserver) {
super.init()
public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (StreamingResponseCallContext<ResponseMessage>) -> EventObserver) {
super.init(errorDelegate: errorDelegate)
let context = StreamingResponseCallContextImpl<ResponseMessage>(channel: channel, request: request)
self.context = context
self.eventObserver = eventObserverFactory(context)
Expand All @@ -26,12 +26,10 @@ public class ServerStreamingCallHandler<RequestMessage: Message, ResponseMessage
}


public override func processMessage(_ message: RequestMessage) {
public override func processMessage(_ message: RequestMessage) throws {
guard let eventObserver = self.eventObserver,
let context = self.context else {
//! FIXME: Better handle this error?
print("multiple messages received on unary call")
return
throw GRPCStatus(code: .unimplemented, message: "multiple messages received on unary call")
MrMage marked this conversation as resolved.
Show resolved Hide resolved
}

let resultFuture = eventObserver(message)
Expand Down
10 changes: 4 additions & 6 deletions Sources/SwiftGRPCNIO/CallHandlers/UnaryCallHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ public class UnaryCallHandler<RequestMessage: Message, ResponseMessage: Message>

private var context: UnaryResponseCallContext<ResponseMessage>?

public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (UnaryResponseCallContext<ResponseMessage>) -> EventObserver) {
super.init()
public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (UnaryResponseCallContext<ResponseMessage>) -> EventObserver) {
super.init(errorDelegate: errorDelegate)
let context = UnaryResponseCallContextImpl<ResponseMessage>(channel: channel, request: request)
self.context = context
self.eventObserver = eventObserverFactory(context)
Expand All @@ -26,12 +26,10 @@ public class UnaryCallHandler<RequestMessage: Message, ResponseMessage: Message>
}
}

public override func processMessage(_ message: RequestMessage) {
public override func processMessage(_ message: RequestMessage) throws {
guard let eventObserver = self.eventObserver,
let context = self.context else {
//! FIXME: Better handle this error?
print("multiple messages received on unary call")
return
throw GRPCStatus(code: .unimplemented, message: "multiple messages received on unary call")
}

let resultFuture = eventObserver(message)
Expand Down
48 changes: 34 additions & 14 deletions Sources/SwiftGRPCNIO/GRPCChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public protocol CallHandlerProvider: class {

/// Determines, calls and returns the appropriate request handler (`GRPCCallHandler`), depending on the request's
/// method. Returns nil for methods not handled by this service.
func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel) -> GRPCCallHandler?
func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorDelegate: ServerErrorDelegate?) -> GRPCCallHandler?
}

/// Listens on a newly-opened HTTP2 subchannel and yields to the sub-handler matching a call, if available.
Expand All @@ -28,30 +28,32 @@ public protocol CallHandlerProvider: class {
/// for an `GRPCCallHandler` object. That object is then forwarded the individual gRPC messages.
public final class GRPCChannelHandler {
private let servicesByName: [String: CallHandlerProvider]
private weak var errorDelegate: ServerErrorDelegate?

public init(servicesByName: [String: CallHandlerProvider]) {
public init(servicesByName: [String: CallHandlerProvider], errorDelegate: ServerErrorDelegate? = nil) {
glbrntt marked this conversation as resolved.
Show resolved Hide resolved
self.servicesByName = servicesByName
self.errorDelegate = errorDelegate
}
}

extension GRPCChannelHandler: ChannelInboundHandler {
public typealias InboundIn = RawGRPCServerRequestPart
public typealias OutboundOut = RawGRPCServerResponsePart


public func errorCaught(ctx: ChannelHandlerContext, error: Error) {
MrMage marked this conversation as resolved.
Show resolved Hide resolved
errorDelegate?.observe(error)

let transformedError = (errorDelegate?.transform(error) ?? error)
glbrntt marked this conversation as resolved.
Show resolved Hide resolved
let status = (transformedError as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError
ctx.writeAndFlush(wrapOutboundOut(.status(status)), promise: nil)
}

public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
let requestPart = self.unwrapInboundIn(data)
switch requestPart {
case .head(let requestHead):
// URI format: "/package.Servicename/MethodName", resulting in the following components separated by a slash:
// - uriComponents[0]: empty
// - uriComponents[1]: service name (including the package name);
// `CallHandlerProvider`s should provide the service name including the package name.
// - uriComponents[2]: method name.
let uriComponents = requestHead.uri.components(separatedBy: "/")
guard uriComponents.count >= 3 && uriComponents[0].isEmpty,
let providerForServiceName = servicesByName[uriComponents[1]],
let callHandler = providerForServiceName.handleMethod(uriComponents[2], request: requestHead, serverHandler: self, channel: ctx.channel) else {
ctx.writeAndFlush(self.wrapOutboundOut(.status(.unimplemented(method: requestHead.uri))), promise: nil)
guard let callHandler = getCallHandler(channel: ctx.channel, requestHead: requestHead) else {
errorCaught(ctx: ctx, error: GRPCStatus.unimplemented(method: requestHead.uri))
return
}

Expand All @@ -71,7 +73,25 @@ extension GRPCChannelHandler: ChannelInboundHandler {
.whenComplete { ctx.pipeline.remove(handler: self, promise: handlerRemoved) }

case .message, .end:
preconditionFailure("received \(requestPart), should have been removed as a handler at this point")
// We can reach this point if we're receiving messages for a method that isn't implemented.
// A status resposne will have been fired which should also close the stream; there's not a
// lot we can do at this point.
break
}
}

private func getCallHandler(channel: Channel, requestHead: HTTPRequestHead) -> GRPCCallHandler? {
// URI format: "/package.Servicename/MethodName", resulting in the following components separated by a slash:
// - uriComponents[0]: empty
// - uriComponents[1]: service name (including the package name);
// `CallHandlerProvider`s should provide the service name including the package name.
// - uriComponents[2]: method name.
let uriComponents = requestHead.uri.components(separatedBy: "/")
guard uriComponents.count >= 3 && uriComponents[0].isEmpty,
let providerForServiceName = servicesByName[uriComponents[1]],
let callHandler = providerForServiceName.handleMethod(uriComponents[2], request: requestHead, serverHandler: self, channel: channel, errorDelegate: errorDelegate) else {
return nil
}
return callHandler
}
}
20 changes: 16 additions & 4 deletions Sources/SwiftGRPCNIO/GRPCServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ public final class GRPCServer {
hostname: String,
port: Int,
eventLoopGroup: EventLoopGroup,
serviceProviders: [CallHandlerProvider]) -> EventLoopFuture<GRPCServer> {
serviceProviders: [CallHandlerProvider],
errorDelegate: ServerErrorDelegate? = nil
MrMage marked this conversation as resolved.
Show resolved Hide resolved
) -> EventLoopFuture<GRPCServer> {
let servicesByName = Dictionary(uniqueKeysWithValues: serviceProviders.map { ($0.serviceName, $0) })
let bootstrap = ServerBootstrap(group: eventLoopGroup)
// Specify a backlog to avoid overloading the server.
Expand All @@ -27,7 +29,7 @@ public final class GRPCServer {
let multiplexer = HTTP2StreamMultiplexer { (channel, streamID) -> EventLoopFuture<Void> in
return channel.pipeline.add(handler: HTTP2ToHTTP1ServerCodec(streamID: streamID))
.then { channel.pipeline.add(handler: HTTP1ToRawGRPCServerCodec()) }
.then { channel.pipeline.add(handler: GRPCChannelHandler(servicesByName: servicesByName)) }
.then { channel.pipeline.add(handler: GRPCChannelHandler(servicesByName: servicesByName, errorDelegate: errorDelegate)) }
}

return channel.pipeline.add(handler: multiplexer)
Expand All @@ -39,20 +41,30 @@ public final class GRPCServer {
.childChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)

return bootstrap.bind(host: hostname, port: port)
.map { GRPCServer(channel: $0) }
.map { GRPCServer(channel: $0, errorDelegate: errorDelegate) }
}

private let channel: Channel
private var errorDelegate: ServerErrorDelegate?

private init(channel: Channel) {
private init(channel: Channel, errorDelegate: ServerErrorDelegate?) {
self.channel = channel

// Maintain a strong reference to ensure it lives as long as the server.
self.errorDelegate = errorDelegate

// `BaseCallHandler` holds a weak reference to the delegate; nil out this reference to avoid retain cycles.
glbrntt marked this conversation as resolved.
Show resolved Hide resolved
onClose.whenComplete {
self.errorDelegate = nil
}
}

/// Fired when the server shuts down.
public var onClose: EventLoopFuture<Void> {
return channel.closeFuture
}

/// Shut down the server; this should be called to avoid leaking resources.
public func close() -> EventLoopFuture<Void> {
return channel.close(mode: .all)
}
Expand Down
12 changes: 7 additions & 5 deletions Sources/SwiftGRPCNIO/GRPCServerCodec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public enum GRPCServerResponsePart<MessageType: Message> {
}

/// A simple channel handler that translates raw gRPC packets into decoded protobuf messages, and vice versa.
public final class GRPCServerCodec<RequestMessage: Message, ResponseMessage: Message> { }
public final class GRPCServerCodec<RequestMessage: Message, ResponseMessage: Message> {}

extension GRPCServerCodec: ChannelInboundHandler {
public typealias InboundIn = RawGRPCServerRequestPart
Expand All @@ -35,8 +35,7 @@ extension GRPCServerCodec: ChannelInboundHandler {
do {
ctx.fireChannelRead(self.wrapInboundOut(.message(try RequestMessage(serializedData: messageAsData))))
} catch {
//! FIXME: Ensure that the last handler in the pipeline returns `.dataLoss` here?
ctx.fireErrorCaught(error)
ctx.fireErrorCaught(GRPCStatus.requestProtoParseError)
}

case .end:
Expand All @@ -54,16 +53,19 @@ extension GRPCServerCodec: ChannelOutboundHandler {
switch responsePart {
case .headers(let headers):
ctx.write(self.wrapOutboundOut(.headers(headers)), promise: promise)

case .message(let message):
do {
let messageData = try message.serializedData()
var responseBuffer = ctx.channel.allocator.buffer(capacity: messageData.count)
responseBuffer.write(bytes: messageData)
ctx.write(self.wrapOutboundOut(.message(responseBuffer)), promise: promise)
} catch {
promise?.fail(error: error)
ctx.fireErrorCaught(error)
let status = GRPCStatus.responseProtoSerializationError
promise?.fail(error: status)
ctx.fireErrorCaught(status)
}

case .status(let status):
ctx.write(self.wrapOutboundOut(.status(status)), promise: promise)
}
Expand Down
15 changes: 15 additions & 0 deletions Sources/SwiftGRPCNIO/GRPCStatus.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,19 @@ public struct GRPCStatus: Error {
public static func unimplemented(method: String) -> GRPCStatus {
return GRPCStatus(code: .unimplemented, message: "unknown method " + method)
}

// These status codes are informed by: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
static internal let requestProtoParseError = GRPCStatus(code: .internalError, message: "could not parse request proto")
glbrntt marked this conversation as resolved.
Show resolved Hide resolved
static internal let responseProtoSerializationError = GRPCStatus(code: .internalError, message: "could not serialize response proto")
static internal let unsupportedCompression = GRPCStatus(code: .unimplemented, message: "compression is not supported on the server")
}

protocol GRPCStatusTransformable: Error {
glbrntt marked this conversation as resolved.
Show resolved Hide resolved
func asGRPCStatus() -> GRPCStatus
}

extension GRPCStatus: GRPCStatusTransformable {
func asGRPCStatus() -> GRPCStatus {
return self
}
}
Loading