sync: reduce contention in broadcast channel

Implement an atomic linked list that allows pushing
waiters concurrently, which reduces contention.

Fixes: tokio-rs#5465
vnetserg committed Jan 14, 2024
1 parent 4f98a68 commit a6856c2
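The change replaces the tail's `Mutex` with an `RwLock` and keeps the waiters in an atomic linked list, so receivers can enqueue themselves concurrently while holding only the shared (read) lock, and the sender takes the exclusive (write) lock only when it drains and wakes them. The sketch below is a hypothetical, simplified illustration of that pattern, not the tokio implementation: a Treiber-style `AtomicStack` with a CAS-loop `push` stands in for the intrusive `AtomicLinkedList`, and plain `std` types are used instead of the `loom` wrappers.

```rust
use std::ptr;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::RwLock;
use std::thread;

// One heap-allocated entry in the stack.
struct Node<T> {
    value: T,
    next: *mut Node<T>,
}

/// A minimal Treiber-style stack: `push` is a lock-free CAS loop, so any
/// number of threads may push concurrently without blocking one another.
struct AtomicStack<T> {
    head: AtomicPtr<Node<T>>,
}

impl<T> AtomicStack<T> {
    fn new() -> Self {
        Self {
            head: AtomicPtr::new(ptr::null_mut()),
        }
    }

    /// Push a value; safe to call from many threads at once.
    fn push(&self, value: T) {
        let node = Box::into_raw(Box::new(Node {
            value,
            next: ptr::null_mut(),
        }));
        let mut head = self.head.load(Ordering::Relaxed);
        loop {
            // Link the new node in front of the current head, then try to
            // publish it; on contention, retry with the updated head.
            unsafe { (*node).next = head };
            match self
                .head
                .compare_exchange_weak(head, node, Ordering::Release, Ordering::Relaxed)
            {
                Ok(_) => return,
                Err(actual) => head = actual,
            }
        }
    }

    /// Detach the whole list atomically and return its elements.
    fn drain(&self) -> Vec<T> {
        let mut out = Vec::new();
        let mut cur = self.head.swap(ptr::null_mut(), Ordering::Acquire);
        while !cur.is_null() {
            // Take ownership back so each node is freed exactly once.
            let boxed = unsafe { Box::from_raw(cur) };
            let Node { value, next } = *boxed;
            out.push(value);
            cur = next;
        }
        out
    }
}

impl<T> Drop for AtomicStack<T> {
    fn drop(&mut self) {
        // Free any nodes that were never drained.
        self.drain();
    }
}

fn main() {
    // Stand-in for `Shared::tail`: the wait list lives behind an RwLock.
    let tail = RwLock::new(AtomicStack::new());

    // "Receivers" register waiters concurrently under the *read* lock; the
    // push itself is atomic, so they no longer serialize on a mutex.
    thread::scope(|s| {
        for id in 0..8u32 {
            let tail = &tail;
            s.spawn(move || {
                tail.read().unwrap().push(id);
            });
        }
    });

    // The "sender" takes the *write* lock when draining and waking waiters.
    let woken = tail.write().unwrap().drain();
    assert_eq!(woken.len(), 8);
}
```

The point of the CAS loop is that concurrent `push` calls never block each other; contention only costs a retry, whereas with a `Mutex`-protected list every registering receiver serializes on the same lock.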
Showing 3 changed files with 262 additions and 63 deletions.

2 changes: 1 addition & 1 deletion tokio/src/loom/std/mod.rs
@@ -59,7 +59,7 @@ pub(crate) mod sync {
#[cfg(all(feature = "parking_lot", not(miri)))]
#[allow(unused_imports)]
pub(crate) use crate::loom::std::parking_lot::{
- Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult,
+ Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard, WaitTimeoutResult,
};

#[cfg(not(all(feature = "parking_lot", not(miri))))]
88 changes: 57 additions & 31 deletions tokio/src/sync/broadcast.rs
@@ -118,8 +118,8 @@

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicUsize;
- use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
- use crate::util::linked_list::{self, GuardedLinkedList, LinkedList};
+ use crate::loom::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
+ use crate::util::linked_list::{self, AtomicLinkedList, GuardedLinkedList};
use crate::util::WakeList;

use std::fmt;
@@ -310,7 +310,7 @@ struct Shared<T> {
mask: usize,

/// Tail of the queue. Includes the rx wait list.
- tail: Mutex<Tail>,
+ tail: RwLock<Tail>,

/// Number of outstanding Sender handles.
num_tx: AtomicUsize,
@@ -328,7 +328,7 @@ struct Tail {
closed: bool,

/// Receivers waiting for a value.
- waiters: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
+ waiters: AtomicLinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
}

/// Slot in the buffer.
@@ -521,11 +521,11 @@ impl<T> Sender<T> {
let shared = Arc::new(Shared {
buffer: buffer.into_boxed_slice(),
mask: capacity - 1,
- tail: Mutex::new(Tail {
+ tail: RwLock::new(Tail {
pos: 0,
rx_cnt: receiver_count,
closed: false,
- waiters: LinkedList::new(),
+ waiters: AtomicLinkedList::new(),
}),
num_tx: AtomicUsize::new(1),
});
@@ -585,7 +585,7 @@ impl<T> Sender<T> {
/// }
/// ```
pub fn send(&self, value: T) -> Result<usize, SendError<T>> {
- let mut tail = self.shared.tail.lock();
+ let mut tail = self.shared.tail.write().unwrap();

if tail.rx_cnt == 0 {
return Err(SendError(value));
@@ -688,7 +688,7 @@ impl<T> Sender<T> {
/// }
/// ```
pub fn len(&self) -> usize {
- let tail = self.shared.tail.lock();
+ let tail = self.shared.tail.read().unwrap();

let base_idx = (tail.pos & self.shared.mask as u64) as usize;
let mut low = 0;
@@ -735,7 +735,7 @@ impl<T> Sender<T> {
/// }
/// ```
pub fn is_empty(&self) -> bool {
- let tail = self.shared.tail.lock();
+ let tail = self.shared.tail.read().unwrap();

let idx = (tail.pos.wrapping_sub(1) & self.shared.mask as u64) as usize;
self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0
@@ -778,7 +778,7 @@ impl<T> Sender<T> {
/// }
/// ```
pub fn receiver_count(&self) -> usize {
- let tail = self.shared.tail.lock();
+ let tail = self.shared.tail.read().unwrap();
tail.rx_cnt
}

@@ -806,7 +806,7 @@ impl<T> Sender<T> {
}

fn close_channel(&self) {
- let mut tail = self.shared.tail.lock();
+ let mut tail = self.shared.tail.write().unwrap();
tail.closed = true;

self.shared.notify_rx(tail);
@@ -815,7 +815,7 @@ impl<T> Sender<T> {

/// Create a new `Receiver` which reads starting from the tail.
fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> {
- let mut tail = shared.tail.lock();
+ let mut tail = shared.tail.write().unwrap();

assert!(tail.rx_cnt != MAX_RECEIVERS, "max receivers");

Expand All @@ -842,20 +842,20 @@ impl<'a, T> Drop for WaitersList<'a, T> {
// If the list is not empty, we unlink all waiters from it.
// We do not wake the waiters to avoid double panics.
if !self.is_empty {
- let _lock_guard = self.shared.tail.lock();
+ let _lock_guard = self.shared.tail.write().unwrap();
while self.list.pop_back().is_some() {}
}
}
}

impl<'a, T> WaitersList<'a, T> {
fn new(
- unguarded_list: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
+ unguarded_list: AtomicLinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
guard: Pin<&'a Waiter>,
shared: &'a Shared<T>,
) -> Self {
let guard_ptr = NonNull::from(guard.get_ref());
- let list = unguarded_list.into_guarded(guard_ptr);
+ let list = unguarded_list.into_list().into_guarded(guard_ptr);
WaitersList {
list,
is_empty: false,
@@ -877,7 +877,7 @@ impl<'a, T> WaitersList<'a, T> {
}

impl<T> Shared<T> {
- fn notify_rx<'a, 'b: 'a>(&'b self, mut tail: MutexGuard<'a, Tail>) {
+ fn notify_rx<'a, 'b: 'a>(&'b self, mut tail: RwLockWriteGuard<'a, Tail>) {
// It is critical for `GuardedLinkedList` safety that the guard node is
// pinned in memory and is not dropped until the guarded list is dropped.
let guard = Waiter::new();
@@ -925,7 +925,7 @@ impl<T> Shared<T> {
wakers.wake_all();

// Acquire the lock again.
- tail = self.tail.lock();
+ tail = self.tail.write().unwrap();
}

// Release the lock before waking.
@@ -987,7 +987,7 @@ impl<T> Receiver<T> {
/// }
/// ```
pub fn len(&self) -> usize {
- let next_send_pos = self.shared.tail.lock().pos;
+ let next_send_pos = self.shared.tail.read().unwrap().pos;
(next_send_pos - self.next) as usize
}

@@ -1065,7 +1065,7 @@ impl<T> Receiver<T> {

let mut old_waker = None;

- let mut tail = self.shared.tail.lock();
+ let tail = self.shared.tail.read().unwrap();

// Acquire slot lock again
slot = self.shared.buffer[idx].read().unwrap();
@@ -1086,7 +1086,16 @@ impl<T> Receiver<T> {

// Store the waker
if let Some((waiter, waker)) = waiter {
- // Safety: called while locked.
+ // Safety: called while holding a read lock on tail.
+ // It suffices since we only update two waiter members:
+ // - `waiter.waker` - all other accesses of this member are
+ // write-lock protected,
+ // - `waiter.queued` - all other accesses of this member are
+ // either write-lock protected or read-lock protected with
+ // exclusive reference to the `Recv` that contains the waiter.
+ // Concurrent calls to `recv_ref` with the same waiter
+ // are impossible because it implies ownership of the `Recv`
+ // that contains it.
unsafe {
// Only queue if not already queued
waiter.with_mut(|ptr| {
@@ -1106,6 +1115,11 @@ impl<T> Receiver<T> {

if !(*ptr).queued {
(*ptr).queued = true;
+ // Safety:
+ // - `waiter` is not already queued,
+ // - calling `recv_ref` with a waiter implies ownership
+ // of its `Recv`. As such, this waiter cannot be pushed
+ // concurrently by some other thread.
tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr));
}
});
@@ -1331,7 +1345,7 @@ impl<T: Clone> Receiver<T> {

impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
- let mut tail = self.shared.tail.lock();
+ let mut tail = self.shared.tail.write().unwrap();

tail.rx_cnt -= 1;
let until = tail.pos;
@@ -1402,22 +1416,34 @@ where

impl<'a, T> Drop for Recv<'a, T> {
fn drop(&mut self) {
- // Acquire the tail lock. This is required for safety before accessing
+ // Acquire a read lock on tail. This is required for safety before accessing
// the waiter node.
- let mut tail = self.receiver.shared.tail.lock();
+ let tail = self.receiver.shared.tail.read().unwrap();

- // safety: tail lock is held
+ // Safety: we hold read lock on tail AND have exclusive reference to `Recv`.
let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });

if queued {
- // Remove the node
+ // Optimistic check failed. To remove the waiter, we need a write lock.
+ drop(tail);
+ let mut tail = self.receiver.shared.tail.write().unwrap();
+
+ // Double check that the waiter is still enqueued,
+ // in case it was removed before we reacquired the lock.
//
- // safety: tail lock is held and the wait node is verified to be in
- // the list.
- unsafe {
- self.waiter.with_mut(|ptr| {
- tail.waiters.remove((&mut *ptr).into());
- });
+ // Safety: tail write lock is held.
+ let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });
+
+ if queued {
+ // Remove the node.
+ //
+ // Safety: tail write lock is held and the wait node is verified to be in
+ // the list.
+ unsafe {
+ self.waiter.with_mut(|ptr| {
+ tail.waiters.remove((&mut *ptr).into());
+ });
+ }
}
}
}
}
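The rewritten `Recv::drop` above cancels a waiter with an optimistic, double-checked pattern: it first reads the `queued` flag under the cheap shared lock, and only if the waiter still appears to be enqueued does it drop that guard, take the write lock, and re-check before unlinking (a notifying sender may have dequeued it in between). Below is a standalone sketch of the same pattern; the `WaitList` of ids is hypothetical and stands in for the intrusive list in the real code.

```rust
use std::collections::HashSet;
use std::sync::RwLock;

/// Hypothetical stand-in for the tail's wait list, keyed by waiter id.
struct WaitList {
    queued: RwLock<HashSet<u64>>,
}

impl WaitList {
    /// Cancel a waiter, mirroring the optimistic check in `Recv::drop`:
    /// most cancellations find the waiter already dequeued (it was woken),
    /// so the common path only touches the shared read lock.
    fn cancel(&self, id: u64) {
        // Phase 1: cheap check under the read lock.
        let maybe_queued = self.queued.read().unwrap().contains(&id);
        if maybe_queued {
            // Phase 2: the state may have changed between releasing the read
            // lock and acquiring the write lock, so re-check before removing.
            let mut queued = self.queued.write().unwrap();
            if queued.contains(&id) {
                queued.remove(&id);
            }
        }
    }
}

fn main() {
    let list = WaitList {
        queued: RwLock::new([1u64, 2, 3].into_iter().collect()),
    };
    list.cancel(2);
    assert!(!list.queued.read().unwrap().contains(&2));
}
```

The re-check under the write lock is what makes releasing the read guard safe: in the real code, unlinking a node that another thread already removed from the intrusive list would be unsound, and re-reading `queued` while holding the exclusive lock rules that out.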
