Skip to content

Commit

Permalink
Refactor channel connectivity to avoid multiple spin loops
Browse files Browse the repository at this point in the history
The current Swift implementation of the gRPC channel's connectivity observer spins up a new ConnectivityObserver instance which then starts a new "spin loop" thread and continually runs, observing changes to the underlying gRPC Core channel's connectivity and piping those back through a callback closure. This means that there's a new spin loop spun up for each observer for each channel.

We can avoid having to spin up multiple spin loops for each observer (keeping only 0 or 1 per channel) by allowing a single ConnectivityObserver instance to pipe changes back to multiple callbacks.
  • Loading branch information
rebello95 committed Feb 25, 2019
1 parent 5f24269 commit 0cd0451
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 41 deletions.
80 changes: 42 additions & 38 deletions Sources/SwiftGRPC/Core/Channel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,26 @@
* limitations under the License.
*/
#if SWIFT_PACKAGE
import CgRPC
import Dispatch
import CgRPC
import Dispatch
#endif
import Foundation

/// A gRPC Channel
public class Channel {
/// Pointer to underlying C representation
private let underlyingChannel: UnsafeMutableRawPointer

/// Completion queue for channel call operations
private let completionQueue: CompletionQueue
/// Observer for connectivity state changes.
private lazy var connectivityObserver = ConnectivityObserver(underlyingChannel: self.underlyingChannel)

/// Timeout for new calls
public var timeout: TimeInterval = 600.0

/// Default host to use for new calls
public var host: String

/// Connectivity state observers
private var connectivityObservers: [ConnectivityObserver] = []

/// Initializes a gRPC channel
///
/// - Parameter address: the address of the server to be called
Expand All @@ -47,12 +45,12 @@ public class Channel {
let argumentWrappers = arguments.map { $0.toCArg() }

underlyingChannel = withExtendedLifetime(argumentWrappers) {
var argumentValues = argumentWrappers.map { $0.wrapped }
if secure {
return cgrpc_channel_create_secure(address, kRootCertificates, nil, nil, &argumentValues, Int32(arguments.count))
} else {
return cgrpc_channel_create(address, &argumentValues, Int32(arguments.count))
}
var argumentValues = argumentWrappers.map { $0.wrapped }
if secure {
return cgrpc_channel_create_secure(address, kRootCertificates, nil, nil, &argumentValues, Int32(arguments.count))
} else {
return cgrpc_channel_create(address, &argumentValues, Int32(arguments.count))
}
}
completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
completionQueue.run() // start a loop that watches the channel's completion queue
Expand All @@ -66,10 +64,10 @@ public class Channel {
gRPC.initialize()
host = googleAddress
let argumentWrappers = arguments.map { $0.toCArg() }

underlyingChannel = withExtendedLifetime(argumentWrappers) {
var argumentValues = argumentWrappers.map { $0.wrapped }
return cgrpc_channel_create_google(googleAddress, &argumentValues, Int32(arguments.count))
var argumentValues = argumentWrappers.map { $0.wrapped }
return cgrpc_channel_create_google(googleAddress, &argumentValues, Int32(arguments.count))
}

completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
Expand All @@ -89,15 +87,15 @@ public class Channel {
let argumentWrappers = arguments.map { $0.toCArg() }

underlyingChannel = withExtendedLifetime(argumentWrappers) {
var argumentValues = argumentWrappers.map { $0.wrapped }
return cgrpc_channel_create_secure(address, certificates, clientCertificates, clientKey, &argumentValues, Int32(arguments.count))
var argumentValues = argumentWrappers.map { $0.wrapped }
return cgrpc_channel_create_secure(address, certificates, clientCertificates, clientKey, &argumentValues, Int32(arguments.count))
}
completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
completionQueue.run() // start a loop that watches the channel's completion queue
}

deinit {
connectivityObservers.forEach { $0.shutdown() }
connectivityObserver.shutdown()
cgrpc_channel_destroy(underlyingChannel)
completionQueue.shutdown()
}
Expand All @@ -109,7 +107,7 @@ public class Channel {
/// - Parameter timeout: a timeout value in seconds
/// - Returns: a Call object that can be used to perform the request
public func makeCall(_ method: String, host: String = "", timeout: TimeInterval? = nil) -> Call {
let host = (host == "") ? self.host : host
let host = host.isEmpty ? self.host : host
let timeout = timeout ?? self.timeout
let underlyingCall = cgrpc_channel_create_call(underlyingChannel, method, host, timeout)!
return Call(underlyingCall: underlyingCall, owned: true, completionQueue: completionQueue)
Expand All @@ -126,8 +124,8 @@ public class Channel {
/// Subscribe to connectivity state changes
///
/// - Parameter callback: block executed every time a new connectivity state is detected
public func subscribe(callback: @escaping (ConnectivityState) -> Void) {
connectivityObservers.append(ConnectivityObserver(underlyingChannel: underlyingChannel, currentState: connectivityState(), callback: callback))
public func addConnectivityObserver(callback: @escaping (ConnectivityState) -> Void) {
connectivityObserver.addConnectivityObserver(callback: callback)
}
}

Expand All @@ -136,18 +134,16 @@ private extension Channel {
private let completionQueue: CompletionQueue
private let underlyingChannel: UnsafeMutableRawPointer
private let underlyingCompletionQueue: UnsafeMutableRawPointer
private let callback: (ConnectivityState) -> Void
private var lastState: ConnectivityState
private var callbacks = [(ConnectivityState) -> Void]()
private var hasBeenShutdown = false
private let stateMutex: Mutex = Mutex()
private let stateMutex = Mutex()

init(underlyingChannel: UnsafeMutableRawPointer, currentState: ConnectivityState, callback: @escaping (ConnectivityState) -> ()) {
init(underlyingChannel: UnsafeMutableRawPointer) {
self.underlyingChannel = underlyingChannel
self.underlyingCompletionQueue = cgrpc_completion_queue_create_for_next()
self.completionQueue = CompletionQueue(underlyingCompletionQueue: self.underlyingCompletionQueue, name: "Connectivity State")
self.callback = callback
self.lastState = currentState
run()
self.completionQueue = CompletionQueue(underlyingCompletionQueue: self.underlyingCompletionQueue,
name: "Connectivity State")
self.run()
}

deinit {
Expand All @@ -156,31 +152,33 @@ private extension Channel {

private func run() {
let spinloopThreadQueue = DispatchQueue(label: "SwiftGRPC.ConnectivityObserver.run.spinloopThread")

var lastState = ConnectivityState(cgrpc_channel_check_connectivity_state(self.underlyingChannel, 0))
spinloopThreadQueue.async {
while true {
guard (self.stateMutex.synchronize{ !self.hasBeenShutdown }) else {
guard (self.stateMutex.synchronize { !self.hasBeenShutdown }) else {
return
}
guard let underlyingState = self.lastState.underlyingState else { return }

guard let underlyingState = lastState.underlyingState else { return }

let deadline: TimeInterval = 0.2
cgrpc_channel_watch_connectivity_state(self.underlyingChannel, self.underlyingCompletionQueue, underlyingState, deadline, nil)
cgrpc_channel_watch_connectivity_state(self.underlyingChannel, self.underlyingCompletionQueue,
underlyingState, deadline, nil)

let event = self.completionQueue.wait(timeout: deadline)

guard (self.stateMutex.synchronize{ !self.hasBeenShutdown }) else {
return
}

switch event.type {
case .complete:
let newState = ConnectivityState(cgrpc_channel_check_connectivity_state(self.underlyingChannel, 0))
guard newState != lastState else { continue }

if newState != self.lastState {
self.callback(newState)
lastState = newState
self.stateMutex.synchronize {
self.callbacks.forEach { callback in callback(newState) }
}
self.lastState = newState

case .queueShutdown:
return
Expand All @@ -192,6 +190,12 @@ private extension Channel {
}
}

func addConnectivityObserver(callback: @escaping (ConnectivityState) -> Void) {
self.stateMutex.synchronize {
self.callbacks.append(callback)
}
}

func shutdown() {
stateMutex.synchronize {
hasBeenShutdown = true
Expand Down
24 changes: 21 additions & 3 deletions Tests/SwiftGRPCTests/ChannelConnectivityTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ final class ChannelConnectivityTests: BasicEchoTestCase {

static var allTests: [(String, (ChannelConnectivityTests) -> () throws -> Void)] {
return [
("testDanglingConnectivityObserversDontCrash", testDanglingConnectivityObserversDontCrash)
("testDanglingConnectivityObserversDontCrash", testDanglingConnectivityObserversDontCrash),
("testMultipleConnectivityObserversAreCalled", testMultipleConnectivityObserversAreCalled),
]
}
}
Expand All @@ -30,12 +31,12 @@ extension ChannelConnectivityTests {
func testDanglingConnectivityObserversDontCrash() {
let completionHandlerExpectation = expectation(description: "completion handler called")

client?.channel.subscribe { connectivityState in
client.channel.addConnectivityObserver { connectivityState in
print("ConnectivityState: \(connectivityState)")
}

let request = Echo_EchoRequest(text: "foo bar baz foo bar baz")
_ = try! client!.expand(request) { callResult in
_ = try! client.expand(request) { callResult in
print("callResult.statusCode: \(callResult.statusCode)")
completionHandlerExpectation.fulfill()
}
Expand All @@ -46,4 +47,21 @@ extension ChannelConnectivityTests {

waitForExpectations(timeout: 0.5)
}

func testMultipleConnectivityObserversAreCalled() {
let completionHandlerExpectation = expectation(description: "completion handler called")
var firstObserverCalled = false
var secondObserverCalled = false

client.channel.addConnectivityObserver { _ in firstObserverCalled = true }
client.channel.addConnectivityObserver { _ in secondObserverCalled = true }

_ = try! client.expand(Echo_EchoRequest(text: "foo bar baz foo bar baz")) { _ in
completionHandlerExpectation.fulfill()
}

waitForExpectations(timeout: 0.5)
XCTAssertTrue(firstObserverCalled)
XCTAssertTrue(secondObserverCalled)
}
}

0 comments on commit 0cd0451

Please sign in to comment.