From 10aff097901eb63d20b56c85cfcc2d196492c4e1 Mon Sep 17 00:00:00 2001 From: Michael Rebello Date: Tue, 26 Feb 2019 10:20:49 -0800 Subject: [PATCH] Refactor channel connectivity to avoid multiple spin loops (#380) The current Swift implementation of the gRPC channel's connectivity observer spins up a new ConnectivityObserver instance which then starts a new "spin loop" thread and continually runs, observing changes to the underlying gRPC Core channel's connectivity and piping those back through a callback closure. This means that there's a new spin loop spun up for each observer for each channel. We can avoid having to spin up multiple spin loops for each observer (keeping only 0 or 1 per channel) by allowing a single ConnectivityObserver instance to pipe changes back to multiple callbacks. --- Sources/SwiftGRPC/Core/Channel.swift | 120 +++++------------- .../Core/ChannelConnectivityObserver.swift | 96 ++++++++++++++ .../ChannelConnectivityTests.swift | 25 +++- 3 files changed, 148 insertions(+), 93 deletions(-) create mode 100644 Sources/SwiftGRPC/Core/ChannelConnectivityObserver.swift diff --git a/Sources/SwiftGRPC/Core/Channel.swift b/Sources/SwiftGRPC/Core/Channel.swift index 43f626c09..4fffdf704 100644 --- a/Sources/SwiftGRPC/Core/Channel.swift +++ b/Sources/SwiftGRPC/Core/Channel.swift @@ -14,18 +14,19 @@ * limitations under the License. */ #if SWIFT_PACKAGE - import CgRPC - import Dispatch +import CgRPC #endif import Foundation /// A gRPC Channel public class Channel { + private let mutex = Mutex() /// Pointer to underlying C representation private let underlyingChannel: UnsafeMutableRawPointer - /// Completion queue for channel call operations private let completionQueue: CompletionQueue + /// Observer for connectivity state changes. Created lazily if needed + private var connectivityObserver: ConnectivityObserver? /// Timeout for new calls public var timeout: TimeInterval = 600.0 @@ -33,9 +34,6 @@ public class Channel { /// Default host to use for new calls public var host: String - /// Connectivity state observers - private var connectivityObservers: [ConnectivityObserver] = [] - /// Initializes a gRPC channel /// /// - Parameter address: the address of the server to be called @@ -47,12 +45,12 @@ public class Channel { let argumentWrappers = arguments.map { $0.toCArg() } underlyingChannel = withExtendedLifetime(argumentWrappers) { - var argumentValues = argumentWrappers.map { $0.wrapped } - if secure { - return cgrpc_channel_create_secure(address, kRootCertificates, nil, nil, &argumentValues, Int32(arguments.count)) - } else { - return cgrpc_channel_create(address, &argumentValues, Int32(arguments.count)) - } + var argumentValues = argumentWrappers.map { $0.wrapped } + if secure { + return cgrpc_channel_create_secure(address, kRootCertificates, nil, nil, &argumentValues, Int32(arguments.count)) + } else { + return cgrpc_channel_create(address, &argumentValues, Int32(arguments.count)) + } } completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client") completionQueue.run() // start a loop that watches the channel's completion queue @@ -66,10 +64,10 @@ public class Channel { gRPC.initialize() host = googleAddress let argumentWrappers = arguments.map { $0.toCArg() } - + underlyingChannel = withExtendedLifetime(argumentWrappers) { - var argumentValues = argumentWrappers.map { $0.wrapped } - return cgrpc_channel_create_google(googleAddress, &argumentValues, Int32(arguments.count)) + var argumentValues = argumentWrappers.map { $0.wrapped } + return cgrpc_channel_create_google(googleAddress, &argumentValues, Int32(arguments.count)) } completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client") @@ -89,17 +87,19 @@ public class Channel { let argumentWrappers = arguments.map { $0.toCArg() } underlyingChannel = withExtendedLifetime(argumentWrappers) { - var argumentValues = argumentWrappers.map { $0.wrapped } - return cgrpc_channel_create_secure(address, certificates, clientCertificates, clientKey, &argumentValues, Int32(arguments.count)) + var argumentValues = argumentWrappers.map { $0.wrapped } + return cgrpc_channel_create_secure(address, certificates, clientCertificates, clientKey, &argumentValues, Int32(arguments.count)) } completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client") completionQueue.run() // start a loop that watches the channel's completion queue } deinit { - connectivityObservers.forEach { $0.shutdown() } - cgrpc_channel_destroy(underlyingChannel) - completionQueue.shutdown() + self.mutex.synchronize { + self.connectivityObserver?.shutdown() + } + cgrpc_channel_destroy(self.underlyingChannel) + self.completionQueue.shutdown() } /// Constructs a Call object to make a gRPC API call @@ -109,7 +109,7 @@ public class Channel { /// - Parameter timeout: a timeout value in seconds /// - Returns: a Call object that can be used to perform the request public func makeCall(_ method: String, host: String = "", timeout: TimeInterval? = nil) -> Call { - let host = (host == "") ? self.host : host + let host = host.isEmpty ? self.host : host let timeout = timeout ?? self.timeout let underlyingCall = cgrpc_channel_create_call(underlyingChannel, method, host, timeout)! return Call(underlyingCall: underlyingCall, owned: true, completionQueue: completionQueue) @@ -126,77 +126,17 @@ public class Channel { /// Subscribe to connectivity state changes /// /// - Parameter callback: block executed every time a new connectivity state is detected - public func subscribe(callback: @escaping (ConnectivityState) -> Void) { - connectivityObservers.append(ConnectivityObserver(underlyingChannel: underlyingChannel, currentState: connectivityState(), callback: callback)) - } -} - -private extension Channel { - final class ConnectivityObserver { - private let completionQueue: CompletionQueue - private let underlyingChannel: UnsafeMutableRawPointer - private let underlyingCompletionQueue: UnsafeMutableRawPointer - private let callback: (ConnectivityState) -> Void - private var lastState: ConnectivityState - private var hasBeenShutdown = false - private let stateMutex: Mutex = Mutex() - - init(underlyingChannel: UnsafeMutableRawPointer, currentState: ConnectivityState, callback: @escaping (ConnectivityState) -> ()) { - self.underlyingChannel = underlyingChannel - self.underlyingCompletionQueue = cgrpc_completion_queue_create_for_next() - self.completionQueue = CompletionQueue(underlyingCompletionQueue: self.underlyingCompletionQueue, name: "Connectivity State") - self.callback = callback - self.lastState = currentState - run() - } - - deinit { - shutdown() - } - - private func run() { - let spinloopThreadQueue = DispatchQueue(label: "SwiftGRPC.ConnectivityObserver.run.spinloopThread") - - spinloopThreadQueue.async { - while true { - guard (self.stateMutex.synchronize{ !self.hasBeenShutdown }) else { - return - } - - guard let underlyingState = self.lastState.underlyingState else { return } - - let deadline: TimeInterval = 0.2 - cgrpc_channel_watch_connectivity_state(self.underlyingChannel, self.underlyingCompletionQueue, underlyingState, deadline, nil) - let event = self.completionQueue.wait(timeout: deadline) - - guard (self.stateMutex.synchronize{ !self.hasBeenShutdown }) else { - return - } - - switch event.type { - case .complete: - let newState = ConnectivityState(cgrpc_channel_check_connectivity_state(self.underlyingChannel, 0)) - - if newState != self.lastState { - self.callback(newState) - } - self.lastState = newState - - case .queueShutdown: - return - - default: - continue - } - } + public func addConnectivityObserver(callback: @escaping (ConnectivityState) -> Void) { + self.mutex.synchronize { + let observer: ConnectivityObserver + if let existingObserver = self.connectivityObserver { + observer = existingObserver + } else { + observer = ConnectivityObserver(underlyingChannel: self.underlyingChannel) + self.connectivityObserver = observer } - } - func shutdown() { - stateMutex.synchronize { - hasBeenShutdown = true - } - completionQueue.shutdown() + observer.addConnectivityObserver(callback: callback) } } } diff --git a/Sources/SwiftGRPC/Core/ChannelConnectivityObserver.swift b/Sources/SwiftGRPC/Core/ChannelConnectivityObserver.swift new file mode 100644 index 000000000..ecef95f56 --- /dev/null +++ b/Sources/SwiftGRPC/Core/ChannelConnectivityObserver.swift @@ -0,0 +1,96 @@ +/* + * Copyright 2016, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if SWIFT_PACKAGE +import CgRPC +import Dispatch +#endif +import Foundation + +extension Channel { + /// Provides an interface for observing the connectivity of a given channel. + final class ConnectivityObserver { + private let mutex = Mutex() + private let completionQueue: CompletionQueue + private let underlyingChannel: UnsafeMutableRawPointer + private let underlyingCompletionQueue: UnsafeMutableRawPointer + private var callbacks = [(ConnectivityState) -> Void]() + private var hasBeenShutdown = false + + init(underlyingChannel: UnsafeMutableRawPointer) { + self.underlyingChannel = underlyingChannel + self.underlyingCompletionQueue = cgrpc_completion_queue_create_for_next() + self.completionQueue = CompletionQueue(underlyingCompletionQueue: self.underlyingCompletionQueue, + name: "Connectivity State") + self.run() + } + + deinit { + self.shutdown() + } + + func addConnectivityObserver(callback: @escaping (ConnectivityState) -> Void) { + self.mutex.synchronize { + self.callbacks.append(callback) + } + } + + func shutdown() { + self.mutex.synchronize { + guard !self.hasBeenShutdown else { return } + + self.hasBeenShutdown = true + self.completionQueue.shutdown() + } + } + + // MARK: - Private + + private func run() { + let spinloopThreadQueue = DispatchQueue(label: "SwiftGRPC.ConnectivityObserver.run.spinloopThread") + var lastState = ConnectivityState(cgrpc_channel_check_connectivity_state(self.underlyingChannel, 0)) + spinloopThreadQueue.async { + while (self.mutex.synchronize { !self.hasBeenShutdown }) { + guard let underlyingState = lastState.underlyingState else { return } + + let deadline: TimeInterval = 0.2 + cgrpc_channel_watch_connectivity_state(self.underlyingChannel, self.underlyingCompletionQueue, + underlyingState, deadline, nil) + + let event = self.completionQueue.wait(timeout: deadline) + guard (self.mutex.synchronize { !self.hasBeenShutdown }) else { + return + } + + switch event.type { + case .complete: + let newState = ConnectivityState(cgrpc_channel_check_connectivity_state(self.underlyingChannel, 0)) + guard newState != lastState else { continue } + + let callbacks = self.mutex.synchronize { Array(self.callbacks) } + lastState = newState + callbacks.forEach { callback in callback(newState) } + + case .queueShutdown: + return + + default: + continue + } + } + } + } + } +} diff --git a/Tests/SwiftGRPCTests/ChannelConnectivityTests.swift b/Tests/SwiftGRPCTests/ChannelConnectivityTests.swift index 2271e15cf..e0e628473 100644 --- a/Tests/SwiftGRPCTests/ChannelConnectivityTests.swift +++ b/Tests/SwiftGRPCTests/ChannelConnectivityTests.swift @@ -21,7 +21,8 @@ final class ChannelConnectivityTests: BasicEchoTestCase { static var allTests: [(String, (ChannelConnectivityTests) -> () throws -> Void)] { return [ - ("testDanglingConnectivityObserversDontCrash", testDanglingConnectivityObserversDontCrash) + ("testDanglingConnectivityObserversDontCrash", testDanglingConnectivityObserversDontCrash), + ("testMultipleConnectivityObserversAreCalled", testMultipleConnectivityObserversAreCalled), ] } } @@ -30,12 +31,12 @@ extension ChannelConnectivityTests { func testDanglingConnectivityObserversDontCrash() { let completionHandlerExpectation = expectation(description: "completion handler called") - client?.channel.subscribe { connectivityState in + client.channel.addConnectivityObserver { connectivityState in print("ConnectivityState: \(connectivityState)") } let request = Echo_EchoRequest(text: "foo bar baz foo bar baz") - _ = try! client!.expand(request) { callResult in + _ = try! client.expand(request) { callResult in print("callResult.statusCode: \(callResult.statusCode)") completionHandlerExpectation.fulfill() } @@ -46,4 +47,22 @@ extension ChannelConnectivityTests { waitForExpectations(timeout: 0.5) } + + func testMultipleConnectivityObserversAreCalled() { + // Linux doesn't yet support `assertForOverFulfill` or `expectedFulfillmentCount`, and since these are + // called multiple times, we can't use expectations. https://bugs.swift.org/browse/SR-6249 + var firstObserverCalled = false + var secondObserverCalled = false + client.channel.addConnectivityObserver { _ in firstObserverCalled = true } + client.channel.addConnectivityObserver { _ in secondObserverCalled = true } + + let completionHandlerExpectation = expectation(description: "completion handler called") + _ = try! client.expand(Echo_EchoRequest(text: "foo bar baz foo bar baz")) { _ in + completionHandlerExpectation.fulfill() + } + + waitForExpectations(timeout: 0.5) + XCTAssertTrue(firstObserverCalled) + XCTAssertTrue(secondObserverCalled) + } }