Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rt: use optional non-zero value for task owner_id #5876

Merged
merged 5 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions tokio/src/runtime/task/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::runtime::task::state::State;
use crate::runtime::task::{Id, Schedule};
use crate::util::linked_list;

use std::num::NonZeroU64;
use std::pin::Pin;
use std::ptr::NonNull;
use std::task::{Context, Poll, Waker};
Expand Down Expand Up @@ -162,7 +163,7 @@ pub(crate) struct Header {
/// Table of function pointers for executing actions on the task.
pub(super) vtable: &'static Vtable,

/// This integer contains the id of the OwnedTasks or LocalOwnedTasks that
/// This non-zero integer contains the id of the OwnedTasks or LocalOwnedTasks that
/// this task is stored in. If the task is not in any list, should be the
/// id of the list that it was previously in, or zero if it has never been
/// in any list.
hds marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -173,7 +174,7 @@ pub(crate) struct Header {
/// The id is not unset when removed from a list because we want to be able
/// to read the id without synchronization, even if it is concurrently being
/// removed from the list.
pub(super) owner_id: UnsafeCell<u64>,
pub(super) owner_id: UnsafeCell<Option<NonZeroU64>>,

/// The tracing ID for this instrumented task.
#[cfg(all(tokio_unstable, feature = "tracing"))]
Expand Down Expand Up @@ -221,7 +222,7 @@ impl<T: Future, S: Schedule> Cell<T, S> {
state,
queue_next: UnsafeCell::new(None),
vtable,
owner_id: UnsafeCell::new(0),
owner_id: UnsafeCell::new(None),
#[cfg(all(tokio_unstable, feature = "tracing"))]
tracing_id,
}
Expand Down Expand Up @@ -394,13 +395,13 @@ impl Header {
}

// safety: The caller must guarantee exclusive access to this field, and
// must ensure that the id is either 0 or the id of the OwnedTasks
// must ensure that the id is either `None` or the id of the OwnedTasks
// containing this task.
pub(super) unsafe fn set_owner_id(&self, owner: u64) {
pub(super) unsafe fn set_owner_id(&self, owner: Option<NonZeroU64>) {
self.owner_id.with_mut(|ptr| *ptr = owner);
}
hds marked this conversation as resolved.
Show resolved Hide resolved

pub(super) fn get_owner_id(&self) -> u64 {
pub(super) fn get_owner_id(&self) -> Option<NonZeroU64> {
// safety: If there are concurrent writes, then that write has violated
// the safety requirements on `set_owner_id`.
unsafe { self.owner_id.with(|ptr| *ptr) }
Expand Down
37 changes: 14 additions & 23 deletions tokio/src/runtime/task/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task};
use crate::util::linked_list::{CountedLinkedList, Link, LinkedList};

use std::marker::PhantomData;
use std::num::NonZeroU64;

// The id from the module below is used to verify whether a given task is stored
// in this OwnedTasks, or some other task. The counter starts at one so we can
Expand All @@ -28,10 +29,10 @@ cfg_has_atomic_u64! {

static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);

fn get_next_id() -> u64 {
fn get_next_id() -> NonZeroU64 {
loop {
let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
if id != 0 {
if let Some(id) = NonZeroU64::new(id) {
return id;
}
}
Expand All @@ -43,27 +44,27 @@ cfg_not_has_atomic_u64! {

static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);

fn get_next_id() -> u64 {
fn get_next_id() -> NonZeroU64 {
loop {
let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
if id != 0 {
return u64::from(id);
if let Some(id) = NonZeroU64::new(id) {
hds marked this conversation as resolved.
Show resolved Hide resolved
return id;
}
}
}
}

pub(crate) struct OwnedTasks<S: 'static> {
inner: Mutex<CountedOwnedTasksInner<S>>,
id: u64,
id: NonZeroU64,
}
struct CountedOwnedTasksInner<S: 'static> {
list: CountedLinkedList<Task<S>, <Task<S> as Link>::Target>,
closed: bool,
}
pub(crate) struct LocalOwnedTasks<S: 'static> {
inner: UnsafeCell<OwnedTasksInner<S>>,
id: u64,
id: NonZeroU64,
_not_send_or_sync: PhantomData<*const ()>,
}
struct OwnedTasksInner<S: 'static> {
Expand Down Expand Up @@ -108,7 +109,7 @@ impl<S: 'static> OwnedTasks<S> {
unsafe {
// safety: We just created the task, so we have exclusive access
// to the field.
task.header().set_owner_id(self.id);
task.header().set_owner_id(Some(self.id));
}

let mut lock = self.inner.lock();
Expand All @@ -127,7 +128,7 @@ impl<S: 'static> OwnedTasks<S> {
/// a LocalNotified, giving the thread permission to poll this task.
#[inline]
pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
assert_eq!(task.header().get_owner_id(), self.id);
assert_eq!(task.header().get_owner_id(), Some(self.id));

// safety: All tasks bound to this OwnedTasks are Send, so it is safe
// to poll it on this thread no matter what thread we are on.
Expand Down Expand Up @@ -170,11 +171,7 @@ impl<S: 'static> OwnedTasks<S> {
}

pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
let task_id = task.header().get_owner_id();
if task_id == 0 {
// The task is unowned.
return None;
}
let task_id = task.header().get_owner_id()?;
hds marked this conversation as resolved.
Show resolved Hide resolved

assert_eq!(task_id, self.id);

Expand Down Expand Up @@ -228,7 +225,7 @@ impl<S: 'static> LocalOwnedTasks<S> {
unsafe {
// safety: We just created the task, so we have exclusive access
// to the field.
task.header().set_owner_id(self.id);
task.header().set_owner_id(Some(self.id));
}

if self.is_closed() {
Expand Down Expand Up @@ -257,11 +254,7 @@ impl<S: 'static> LocalOwnedTasks<S> {
}

pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
let task_id = task.header().get_owner_id();
if task_id == 0 {
// The task is unowned.
return None;
}
let task_id = task.header().get_owner_id()?;
hds marked this conversation as resolved.
Show resolved Hide resolved

assert_eq!(task_id, self.id);

Expand All @@ -275,7 +268,7 @@ impl<S: 'static> LocalOwnedTasks<S> {
/// it to a LocalNotified, giving the thread permission to poll this task.
#[inline]
pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
assert_eq!(task.header().get_owner_id(), self.id);
assert_eq!(task.header().get_owner_id(), Some(self.id));

// safety: The task was bound to this LocalOwnedTasks, and the
// LocalOwnedTasks is not Send or Sync, so we are on the right thread
Expand Down Expand Up @@ -315,11 +308,9 @@ mod tests {
#[test]
fn test_id_not_broken() {
let mut last_id = get_next_id();
assert_ne!(last_id, 0);

for _ in 0..1000 {
let next_id = get_next_id();
assert_ne!(next_id, 0);
assert!(last_id < next_id);
last_id = next_id;
}
Expand Down
Loading