Skip to content

Commit

Permalink
Improve error handling in NIO server. (grpc#364)
Browse files Browse the repository at this point in the history
* Improve error handling in NIO server.

- Adds a user-configurable error handler to the server
- Updates NIO server codegen to provide an optional error handler
- Errors are handled by GRPCChannelHandler or BaseCallHandler,
  depending on the pipeline state
- Adds some error handling tests
- Tidies some logic in HTTP1ToRawGRPCServerCodec
- Extends message handling logic in HTTP1ToRawGRPCServerCodec
  to handle messages split across multiple ByteBuffers (i.e. when a
  message exceeds a the size of a frame)

* Update error delegate

* Strongly hold errorDelegate in the server until shutdown

* More errors to a dedicated enum, fix typos, etc.

* Renaming, typo fixes

* Split out GRPCChannelHandlerTests and HTTPToRawGRPCServerCodecTests

* Update LinuxMain

* Add missing commas to LinuxMain

* Fix grpc-web testUnaryLotsOfRequests on Linux

* Disable broken Linux test
  • Loading branch information
glbrntt authored and MrMage committed Feb 26, 2019
1 parent 587218a commit 158c4ef
Show file tree
Hide file tree
Showing 23 changed files with 826 additions and 166 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ test-plugin:
test-plugin-nio:
swift build $(CFLAGS) --product protoc-gen-swiftgrpc
protoc Sources/Examples/Echo/echo.proto --proto_path=Sources/Examples/Echo --plugin=.build/debug/protoc-gen-swift --plugin=.build/debug/protoc-gen-swiftgrpc --swiftgrpc_out=/tmp --swiftgrpc_opt=Client=false,NIO=true
diff -u /tmp/echo.grpc.swift Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift
diff -u /tmp/echo.grpc.swift Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift

xcodebuild: project
xcodebuild -project SwiftGRPC.xcodeproj -configuration "Debug" -parallelizeTargets -target SwiftGRPC -target Echo -target Simple -target protoc-gen-swiftgrpc build
Expand Down
64 changes: 57 additions & 7 deletions Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,79 @@ import NIOHTTP1
/// Calls through to `processMessage` for individual messages it receives, which needs to be implemented by subclasses.
public class BaseCallHandler<RequestMessage: Message, ResponseMessage: Message>: GRPCCallHandler {
public func makeGRPCServerCodec() -> ChannelHandler { return GRPCServerCodec<RequestMessage, ResponseMessage>() }

/// Called whenever a message has been received.
///
/// Overridden by subclasses.
public func processMessage(_ message: RequestMessage) {
public func processMessage(_ message: RequestMessage) throws {
fatalError("needs to be overridden")
}

/// Called when the client has half-closed the stream, indicating that they won't send any further data.
///
/// Overridden by subclasses if the "end-of-stream" event is relevant.
public func endOfStreamReceived() { }

/// Whether this handler can still write messages to the client.
private var serverCanWrite = true

/// Called for each error recieved in `errorCaught(ctx:error:)`.
private weak var errorDelegate: ServerErrorDelegate?

public init(errorDelegate: ServerErrorDelegate?) {
self.errorDelegate = errorDelegate
}
}

extension BaseCallHandler: ChannelInboundHandler {
public typealias InboundIn = GRPCServerRequestPart<RequestMessage>
public typealias OutboundOut = GRPCServerResponsePart<ResponseMessage>

/// Passes errors to the user-provided `errorHandler`. After an error has been received an
/// appropriate status is written. Errors which don't conform to `GRPCStatusTransformable`
/// return a status with code `.internalError`.
public func errorCaught(ctx: ChannelHandlerContext, error: Error) {
errorDelegate?.observe(error)

let transformed = errorDelegate?.transform(error) ?? error
let status = (transformed as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError
self.write(ctx: ctx, data: NIOAny(GRPCServerResponsePart<ResponseMessage>.status(status)), promise: nil)
}

public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
switch self.unwrapInboundIn(data) {
case .head: preconditionFailure("should not have received headers")
case .message(let message): processMessage(message)
case .end: endOfStreamReceived()
case .head(let requestHead):
// Head should have been handled by `GRPCChannelHandler`.
self.errorCaught(ctx: ctx, error: GRPCServerError.invalidState("unexpected request head received \(requestHead)"))

case .message(let message):
do {
try processMessage(message)
} catch {
self.errorCaught(ctx: ctx, error: error)
}

case .end:
endOfStreamReceived()
}
}
}

extension BaseCallHandler: ChannelOutboundHandler {
public typealias OutboundIn = GRPCServerResponsePart<ResponseMessage>
public typealias OutboundOut = GRPCServerResponsePart<ResponseMessage>

public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
guard serverCanWrite else {
promise?.fail(error: GRPCServerError.serverNotWritable)
return
}

// We can only write one status; make sure we don't write again.
if case .status = unwrapOutboundIn(data) {
serverCanWrite = false
ctx.writeAndFlush(data, promise: promise)
} else {
ctx.write(data, promise: promise)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ public class BidirectionalStreamingCallHandler<RequestMessage: Message, Response

// We ask for a future of type `EventObserver` to allow the framework user to e.g. asynchronously authenticate a call.
// If authentication fails, they can simply fail the observer future, which causes the call to be terminated.
public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (StreamingResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
super.init()
public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (StreamingResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
super.init(errorDelegate: errorDelegate)
let context = StreamingResponseCallContextImpl<ResponseMessage>(channel: channel, request: request)
self.context = context
let eventObserver = eventObserverFactory(context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ public class ClientStreamingCallHandler<RequestMessage: Message, ResponseMessage

// We ask for a future of type `EventObserver` to allow the framework user to e.g. asynchronously authenticate a call.
// If authentication fails, they can simply fail the observer future, which causes the call to be terminated.
public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (UnaryResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
super.init()
public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (UnaryResponseCallContext<ResponseMessage>) -> EventLoopFuture<EventObserver>) {
super.init(errorDelegate: errorDelegate)
let context = UnaryResponseCallContextImpl<ResponseMessage>(channel: channel, request: request)
self.context = context
let eventObserver = eventObserverFactory(context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ public class ServerStreamingCallHandler<RequestMessage: Message, ResponseMessage

private var context: StreamingResponseCallContext<ResponseMessage>?

public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (StreamingResponseCallContext<ResponseMessage>) -> EventObserver) {
super.init()
public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (StreamingResponseCallContext<ResponseMessage>) -> EventObserver) {
super.init(errorDelegate: errorDelegate)
let context = StreamingResponseCallContextImpl<ResponseMessage>(channel: channel, request: request)
self.context = context
self.eventObserver = eventObserverFactory(context)
Expand All @@ -26,12 +26,10 @@ public class ServerStreamingCallHandler<RequestMessage: Message, ResponseMessage
}


public override func processMessage(_ message: RequestMessage) {
public override func processMessage(_ message: RequestMessage) throws {
guard let eventObserver = self.eventObserver,
let context = self.context else {
//! FIXME: Better handle this error?
print("multiple messages received on unary call")
return
throw GRPCServerError.requestCardinalityViolation
}

let resultFuture = eventObserver(message)
Expand Down
10 changes: 4 additions & 6 deletions Sources/SwiftGRPCNIO/CallHandlers/UnaryCallHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ public class UnaryCallHandler<RequestMessage: Message, ResponseMessage: Message>

private var context: UnaryResponseCallContext<ResponseMessage>?

public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (UnaryResponseCallContext<ResponseMessage>) -> EventObserver) {
super.init()
public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (UnaryResponseCallContext<ResponseMessage>) -> EventObserver) {
super.init(errorDelegate: errorDelegate)
let context = UnaryResponseCallContextImpl<ResponseMessage>(channel: channel, request: request)
self.context = context
self.eventObserver = eventObserverFactory(context)
Expand All @@ -26,12 +26,10 @@ public class UnaryCallHandler<RequestMessage: Message, ResponseMessage: Message>
}
}

public override func processMessage(_ message: RequestMessage) {
public override func processMessage(_ message: RequestMessage) throws {
guard let eventObserver = self.eventObserver,
let context = self.context else {
//! FIXME: Better handle this error?
print("multiple messages received on unary call")
return
throw GRPCServerError.requestCardinalityViolation
}

let resultFuture = eventObserver(message)
Expand Down
46 changes: 33 additions & 13 deletions Sources/SwiftGRPCNIO/GRPCChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public protocol CallHandlerProvider: class {

/// Determines, calls and returns the appropriate request handler (`GRPCCallHandler`), depending on the request's
/// method. Returns nil for methods not handled by this service.
func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel) -> GRPCCallHandler?
func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorDelegate: ServerErrorDelegate?) -> GRPCCallHandler?
}

/// Listens on a newly-opened HTTP2 subchannel and yields to the sub-handler matching a call, if available.
Expand All @@ -28,30 +28,32 @@ public protocol CallHandlerProvider: class {
/// for an `GRPCCallHandler` object. That object is then forwarded the individual gRPC messages.
public final class GRPCChannelHandler {
private let servicesByName: [String: CallHandlerProvider]
private weak var errorDelegate: ServerErrorDelegate?

public init(servicesByName: [String: CallHandlerProvider]) {
public init(servicesByName: [String: CallHandlerProvider], errorDelegate: ServerErrorDelegate?) {
self.servicesByName = servicesByName
self.errorDelegate = errorDelegate
}
}

extension GRPCChannelHandler: ChannelInboundHandler {
public typealias InboundIn = RawGRPCServerRequestPart
public typealias OutboundOut = RawGRPCServerResponsePart

public func errorCaught(ctx: ChannelHandlerContext, error: Error) {
errorDelegate?.observe(error)

let transformedError = errorDelegate?.transform(error) ?? error
let status = (transformedError as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError
ctx.writeAndFlush(wrapOutboundOut(.status(status)), promise: nil)
}

public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
let requestPart = self.unwrapInboundIn(data)
switch requestPart {
case .head(let requestHead):
// URI format: "/package.Servicename/MethodName", resulting in the following components separated by a slash:
// - uriComponents[0]: empty
// - uriComponents[1]: service name (including the package name);
// `CallHandlerProvider`s should provide the service name including the package name.
// - uriComponents[2]: method name.
let uriComponents = requestHead.uri.components(separatedBy: "/")
guard uriComponents.count >= 3 && uriComponents[0].isEmpty,
let providerForServiceName = servicesByName[uriComponents[1]],
let callHandler = providerForServiceName.handleMethod(uriComponents[2], request: requestHead, serverHandler: self, channel: ctx.channel) else {
ctx.writeAndFlush(self.wrapOutboundOut(.status(.unimplemented(method: requestHead.uri))), promise: nil)
guard let callHandler = getCallHandler(channel: ctx.channel, requestHead: requestHead) else {
errorCaught(ctx: ctx, error: GRPCServerError.unimplementedMethod(requestHead.uri))
return
}

Expand All @@ -75,7 +77,25 @@ extension GRPCChannelHandler: ChannelInboundHandler {
.whenComplete { ctx.pipeline.remove(handler: self, promise: handlerRemoved) }

case .message, .end:
preconditionFailure("received \(requestPart), should have been removed as a handler at this point")
// We can reach this point if we're receiving messages for a method that isn't implemented.
// A status resposne will have been fired which should also close the stream; there's not a
// lot we can do at this point.
break
}
}

private func getCallHandler(channel: Channel, requestHead: HTTPRequestHead) -> GRPCCallHandler? {
// URI format: "/package.Servicename/MethodName", resulting in the following components separated by a slash:
// - uriComponents[0]: empty
// - uriComponents[1]: service name (including the package name);
// `CallHandlerProvider`s should provide the service name including the package name.
// - uriComponents[2]: method name.
let uriComponents = requestHead.uri.components(separatedBy: "/")
guard uriComponents.count >= 3 && uriComponents[0].isEmpty,
let providerForServiceName = servicesByName[uriComponents[1]],
let callHandler = providerForServiceName.handleMethod(uriComponents[2], request: requestHead, serverHandler: self, channel: channel, errorDelegate: errorDelegate) else {
return nil
}
return callHandler
}
}
20 changes: 16 additions & 4 deletions Sources/SwiftGRPCNIO/GRPCServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ public final class GRPCServer {
hostname: String,
port: Int,
eventLoopGroup: EventLoopGroup,
serviceProviders: [CallHandlerProvider]) -> EventLoopFuture<GRPCServer> {
serviceProviders: [CallHandlerProvider],
errorDelegate: ServerErrorDelegate? = LoggingServerErrorDelegate()
) -> EventLoopFuture<GRPCServer> {
let servicesByName = Dictionary(uniqueKeysWithValues: serviceProviders.map { ($0.serviceName, $0) })
let bootstrap = ServerBootstrap(group: eventLoopGroup)
// Specify a backlog to avoid overloading the server.
Expand All @@ -25,7 +27,7 @@ public final class GRPCServer {
return channel.pipeline.add(handler: HTTPProtocolSwitcher {
channel -> EventLoopFuture<Void> in
return channel.pipeline.add(handler: HTTP1ToRawGRPCServerCodec())
.then { channel.pipeline.add(handler: GRPCChannelHandler(servicesByName: servicesByName)) }
.then { channel.pipeline.add(handler: GRPCChannelHandler(servicesByName: servicesByName, errorDelegate: errorDelegate)) }
})
}

Expand All @@ -34,20 +36,30 @@ public final class GRPCServer {
.childChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)

return bootstrap.bind(host: hostname, port: port)
.map { GRPCServer(channel: $0) }
.map { GRPCServer(channel: $0, errorDelegate: errorDelegate) }
}

private let channel: Channel
private var errorDelegate: ServerErrorDelegate?

private init(channel: Channel) {
private init(channel: Channel, errorDelegate: ServerErrorDelegate?) {
self.channel = channel

// Maintain a strong reference to ensure it lives as long as the server.
self.errorDelegate = errorDelegate

// nil out errorDelegate to avoid retain cycles.
onClose.whenComplete {
self.errorDelegate = nil
}
}

/// Fired when the server shuts down.
public var onClose: EventLoopFuture<Void> {
return channel.closeFuture
}

/// Shut down the server; this should be called to avoid leaking resources.
public func close() -> EventLoopFuture<Void> {
return channel.close(mode: .all)
}
Expand Down
8 changes: 5 additions & 3 deletions Sources/SwiftGRPCNIO/GRPCServerCodec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public enum GRPCServerResponsePart<MessageType: Message> {
}

/// A simple channel handler that translates raw gRPC packets into decoded protobuf messages, and vice versa.
public final class GRPCServerCodec<RequestMessage: Message, ResponseMessage: Message> { }
public final class GRPCServerCodec<RequestMessage: Message, ResponseMessage: Message> {}

extension GRPCServerCodec: ChannelInboundHandler {
public typealias InboundIn = RawGRPCServerRequestPart
Expand All @@ -35,8 +35,7 @@ extension GRPCServerCodec: ChannelInboundHandler {
do {
ctx.fireChannelRead(self.wrapInboundOut(.message(try RequestMessage(serializedData: messageAsData))))
} catch {
//! FIXME: Ensure that the last handler in the pipeline returns `.dataLoss` here?
ctx.fireErrorCaught(error)
ctx.fireErrorCaught(GRPCServerError.requestProtoParseFailure)
}

case .end:
Expand All @@ -54,16 +53,19 @@ extension GRPCServerCodec: ChannelOutboundHandler {
switch responsePart {
case .headers(let headers):
ctx.write(self.wrapOutboundOut(.headers(headers)), promise: promise)

case .message(let message):
do {
let messageData = try message.serializedData()
var responseBuffer = ctx.channel.allocator.buffer(capacity: messageData.count)
responseBuffer.write(bytes: messageData)
ctx.write(self.wrapOutboundOut(.message(responseBuffer)), promise: promise)
} catch {
let error = GRPCServerError.responseProtoSerializationFailure
promise?.fail(error: error)
ctx.fireErrorCaught(error)
}

case .status(let status):
ctx.write(self.wrapOutboundOut(.status(status)), promise: promise)
}
Expand Down
Loading

0 comments on commit 158c4ef

Please sign in to comment.