Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sync: make notify_waiters calls atomic #5458

Merged
merged 10 commits into from
Feb 19, 2023
190 changes: 148 additions & 42 deletions tokio/src/sync/notify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::Mutex;
use crate::util::linked_list::{self, LinkedList};
use crate::util::linked_list::{self, GuardedLinkedList, LinkedList};
use crate::util::WakeList;

use std::cell::UnsafeCell;
Expand All @@ -20,6 +20,7 @@ use std::sync::atomic::Ordering::SeqCst;
use std::task::{Context, Poll, Waker};

type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
type GuardedWaitList = GuardedLinkedList<Waiter, <Waiter as linked_list::Link>::Target>;

/// Notifies a single task to wake up.
///
Expand Down Expand Up @@ -198,10 +199,16 @@ type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
/// [`Semaphore`]: crate::sync::Semaphore
#[derive(Debug)]
pub struct Notify {
// This uses 2 bits to store one of `EMPTY`,
// `state` uses 2 bits to store one of `EMPTY`,
// `WAITING` or `NOTIFIED`. The rest of the bits
// are used to store the number of times `notify_waiters`
// was called.
//
// Throughout the code there are two assumptions:
// - state can be transitioned *from* `WAITING` only if
// `waiters` lock is held
// - number of times `notify_waiters` was called can
// be modified only if `waiters` lock is held
state: AtomicUsize,
waiters: Mutex<WaitList>,
}
Expand Down Expand Up @@ -229,6 +236,17 @@ struct Waiter {
_p: PhantomPinned,
}

impl Waiter {
fn new() -> Waiter {
Waiter {
pointers: linked_list::Pointers::new(),
waker: None,
notified: None,
_p: PhantomPinned,
}
}
}

generate_addr_of_methods! {
impl<> Waiter {
unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> {
Expand All @@ -237,6 +255,59 @@ generate_addr_of_methods! {
}
}

/// List used in `Notify::notify_waiters`. It wraps a guarded linked list
/// and gates the access to it on `notify.waiters` mutex. It also empties
/// the list on drop.
struct NotifyWaitersList<'a> {
list: GuardedWaitList,
is_empty: bool,
notify: &'a Notify,
}

impl<'a> NotifyWaitersList<'a> {
fn new(
unguarded_list: WaitList,
guard: Pin<&'a mut UnsafeCell<Waiter>>,
notify: &'a Notify,
) -> NotifyWaitersList<'a> {
// Safety: pointer to the guarding waiter is not null.
let guard_ptr = unsafe { NonNull::new_unchecked(guard.get()) };
let list = unguarded_list.into_guarded(guard_ptr);
NotifyWaitersList {
list,
is_empty: false,
notify,
}
}

/// Removes the last element from the guarded list. Modifying this list
/// requires an exclusive access to the main list in `Notify`.
fn pop_back_locked(&mut self, _waiters: &mut WaitList) -> Option<NonNull<Waiter>> {
let result = self.list.pop_back();
if result.is_none() {
// Save information about emptiness to avoid waiting for lock
// in the destructor.
self.is_empty = true;
}
result
}
}

impl Drop for NotifyWaitersList<'_> {
fn drop(&mut self) {
// If the list is not empty, we unlink all waiters from it.
// We do not wake the waiters to avoid double panics.
if !self.is_empty {
let _lock_guard = self.notify.waiters.lock();
while let Some(mut waiter) = self.list.pop_back() {
// Safety: we hold the lock.
let waiter = unsafe { waiter.as_mut() };
waiter.notified = Some(NotificationType::AllWaiters);
}
}
}
Comment on lines +297 to +308
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this destructor does not call waker.wake(), which means the remaining waiters will never be woken up unless manually polled. However, this is already the case in the current code because one panicking waker from a batch can result in the whole batch never being notified, see the discussion linked in #4069.

}

/// Future returned from [`Notify::notified()`].
///
/// This future is fused, so once it has completed, any future calls to poll
Expand All @@ -249,6 +320,9 @@ pub struct Notified<'a> {
/// The current state of the receiving process.
state: State,

/// Number of calls to `notify_waiters` at the time of creation.
notify_waiters_calls: usize,

/// Entry in the waiter `LinkedList`.
waiter: UnsafeCell<Waiter>,
}
Expand All @@ -258,7 +332,7 @@ unsafe impl<'a> Sync for Notified<'a> {}

#[derive(Debug)]
enum State {
Init(usize),
Init,
Waiting,
Done,
}
Expand Down Expand Up @@ -383,17 +457,13 @@ impl Notify {
/// ```
pub fn notified(&self) -> Notified<'_> {
// we load the number of times notify_waiters
// was called and store that in our initial state
// was called and store that in the future.
let state = self.state.load(SeqCst);
Notified {
notify: self,
state: State::Init(state >> NOTIFY_WAITERS_SHIFT),
waiter: UnsafeCell::new(Waiter {
pointers: linked_list::Pointers::new(),
waker: None,
notified: None,
_p: PhantomPinned,
}),
state: State::Init,
notify_waiters_calls: get_num_notify_waiters_calls(state),
waiter: UnsafeCell::new(Waiter::new()),
}
}

Expand Down Expand Up @@ -500,12 +570,9 @@ impl Notify {
/// }
/// ```
pub fn notify_waiters(&self) {
let mut wakers = WakeList::new();

// There are waiters, the lock must be acquired to notify.
let mut waiters = self.waiters.lock();

// The state must be reloaded while the lock is held. The state may only
// The state must be loaded while the lock is held. The state may only
// transition out of WAITING while the lock is held.
let curr = self.state.load(SeqCst);

Expand All @@ -516,12 +583,30 @@ impl Notify {
return;
}

// At this point, it is guaranteed that the state will not
// concurrently change, as holding the lock is required to
// transition **out** of `WAITING`.
// Increment the number of times this method was called
// and transition to empty.
let new_state = set_state(inc_num_notify_waiters_calls(curr), EMPTY);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this moved up? I'm not sure it matters, but it is not apparent it does not matter.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code is correct as long as the store happens while the mutex is held. Moving the store up increases the odds of a poll_notified succeeding early in some rare concurrent conditions.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was moved up for correctness. If we allowed to poll a pending future between chunks and observe the old counter value, then it would be possible to observe the inconsistency from the description (number 2.). This is because such future would return Pending, even though other waiters from the decoupled list could be already notified. notify_waiters_poll_consistency_many checks such scenarios.

self.state.store(new_state, SeqCst);

// It is critical for `GuardedLinkedList` safety that the guard node is
// pinned in memory and is not dropped until the guarded list is dropped.
let guard = UnsafeCell::new(Waiter::new());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably want to pin guard to ensure it doesn't accidentally move (the pin! macro might work here). Also, could you add a big comment saying it is critical for safety that guard does not move and is not dropped until the guarded list is dropped?

pin!(guard);

// We move all waiters to a secondary list. It uses a `GuardedLinkedList`
// underneath to allow every waiter to safely remove itself from it.
//
// * This list will be still guarded by the `waiters` lock.
// `NotifyWaitersList` wrapper makes sure we hold the lock to modify it.
// * This wrapper will empty the list on drop. It is critical for safety
// that we will not leave any list entry with a pointer to the local
// guard node after this function returns / panics.
let mut list = NotifyWaitersList::new(std::mem::take(&mut *waiters), guard, self);

let mut wakers = WakeList::new();
'outer: loop {
while wakers.can_push() {
match waiters.pop_back() {
match list.pop_back_locked(&mut waiters) {
Some(mut waiter) => {
// Safety: `waiters` lock is still held.
let waiter = unsafe { waiter.as_mut() };
Expand All @@ -540,20 +625,17 @@ impl Notify {
}
}

// Release the lock before notifying.
drop(waiters);

// One of the wakers may panic, but the remaining waiters will still
// be unlinked from the list in `NotifyWaitersList` destructor.
wakers.wake_all();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We must clean up the linked list even if this call panics.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good catch. I moved the list to a new struct, which makes sure the list is cleaned up on drop.


// Acquire the lock again.
waiters = self.waiters.lock();
}

// All waiters will be notified, the state must be transitioned to
// `EMPTY`. As transitioning **from** `WAITING` requires the lock to be
// held, a `store` is sufficient.
let new = set_state(inc_num_notify_waiters_calls(curr), EMPTY);
self.state.store(new, SeqCst);

// Release the lock before notifying
drop(waiters);

Expand Down Expand Up @@ -730,26 +812,32 @@ impl Notified<'_> {

/// A custom `project` implementation is used in place of `pin-project-lite`
/// as a custom drop implementation is needed.
fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &UnsafeCell<Waiter>) {
fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &usize, &UnsafeCell<Waiter>) {
unsafe {
// Safety: both `notify` and `state` are `Unpin`.
// Safety: `notify`, `state` and `notify_waiters_calls` are `Unpin`.

is_unpin::<&Notify>();
is_unpin::<AtomicUsize>();
is_unpin::<usize>();

let me = self.get_unchecked_mut();
(me.notify, &mut me.state, &me.waiter)
(
me.notify,
&mut me.state,
&me.notify_waiters_calls,
&me.waiter,
)
}
}

fn poll_notified(self: Pin<&mut Self>, waker: Option<&Waker>) -> Poll<()> {
use State::*;

let (notify, state, waiter) = self.project();
let (notify, state, notify_waiters_calls, waiter) = self.project();

loop {
match *state {
Init(initial_notify_waiters_calls) => {
Init => {
let curr = notify.state.load(SeqCst);

// Optimistically try acquiring a pending notification
Expand Down Expand Up @@ -779,7 +867,7 @@ impl Notified<'_> {

// if notify_waiters has been called after the future
// was created, then we are done
if get_num_notify_waiters_calls(curr) != initial_notify_waiters_calls {
if get_num_notify_waiters_calls(curr) != *notify_waiters_calls {
*state = Done;
return Poll::Ready(());
}
Expand Down Expand Up @@ -846,21 +934,37 @@ impl Notified<'_> {
return Poll::Pending;
}
Waiting => {
// Currently in the "Waiting" state, implying the caller has
// a waiter stored in the waiter list (guarded by
// `notify.waiters`). In order to access the waker fields,
// we must hold the lock.
// Currently in the "Waiting" state, implying the caller has a waiter stored in
// a waiter list (guarded by `notify.waiters`). In order to access the waker
// fields, we must acquire the lock.

let waiters = notify.waiters.lock();
let mut waiters = notify.waiters.lock();

// Load the state with the lock held.
let curr = notify.state.load(SeqCst);

// Safety: called while locked
let w = unsafe { &mut *waiter.get() };

if w.notified.is_some() {
// Our waker has been notified. Reset the fields and
// remove it from the list.
w.waker = None;
// Our waker has been notified and our waiter is already removed from
// the list. Reset the notification and convert to `Done`.
w.notified = None;
w.waker = None;
*state = Done;
} else if get_num_notify_waiters_calls(curr) != *notify_waiters_calls {
// Before we add a waiter to the list we check if these numbers are
// different while holding the lock. If these numbers are different now,
// it means that there is a call to `notify_waiters` in progress and this
// waiter must be contained by a guarded list used in `notify_waiters`.
// We can treat the waiter as notified and remove it from the list, as
// it would have been notified in the `notify_waiters` call anyways.

w.waker = None;

// Safety: we hold the lock, so we have an exclusive access to the list.
// The list is used in `notify_waiters`, so it must be guarded.
unsafe { waiters.remove(NonNull::new_unchecked(w)) };

*state = Done;
} else {
Expand Down Expand Up @@ -906,7 +1010,7 @@ impl Drop for Notified<'_> {
use State::*;

// Safety: The type only transitions to a "Waiting" state when pinned.
let (notify, state, waiter) = unsafe { Pin::new_unchecked(self).project() };
let (notify, state, _, waiter) = unsafe { Pin::new_unchecked(self).project() };

// This is where we ensure safety. The `Notified` value is being
// dropped, which means we must ensure that the waiter entry is no
Expand All @@ -917,8 +1021,10 @@ impl Drop for Notified<'_> {

// remove the entry from the list (if not already removed)
//
// safety: the waiter is only added to `waiters` by virtue of it
// being the only `LinkedList` available to the type.
// Safety: we hold the lock, so we have an exclusive access to every list the
// waiter may be contained in. If the node is not contained in the `waiters`
// list, then it is contained by a guarded list used by `notify_waiters` and
// in such case it must be a middle node.
unsafe { waiters.remove(NonNull::new_unchecked(waiter.get())) };

if waiters.is_empty() && get_state(notify_state) == WAITING {
Expand Down
Loading