Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

task/future: support spawning locally #24

Merged
merged 7 commits into from
Jan 2, 2020
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 123 additions & 42 deletions src/task/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
use crate::pool::{Local, Remote};
use crate::queue::Extras;

use std::borrow::Cow;
use std::cell::{Cell, UnsafeCell};
use std::future::Future;
use std::mem::ManuallyDrop;
Expand All @@ -18,12 +19,16 @@ use std::{fmt, mem};
/// details.
const DEFAULT_REPOLL_LIMIT: usize = 5;

struct TaskExtras {
extras: Extras,
remote: Option<Remote<TaskCell>>,
}

/// A [`Future`] task.
pub struct Task {
status: AtomicU8,
extras: UnsafeCell<TaskExtras>,
future: UnsafeCell<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
remote: Remote<TaskCell>,
extras: UnsafeCell<Extras>,
}

/// A [`Future`] task cell.
Expand Down Expand Up @@ -54,23 +59,21 @@ const COMPLETED: u8 = 4;

impl TaskCell {
/// Creates a [`Future`] task cell that is ready to be polled.
pub fn new<F: Future<Output = ()> + Send + 'static>(
future: F,
remote: Remote<TaskCell>,
extras: Extras,
) -> Self {
pub fn new<F: Future<Output = ()> + Send + 'static>(future: F, extras: Extras) -> Self {
TaskCell(Arc::new(Task {
status: AtomicU8::new(NOTIFIED),
future: UnsafeCell::new(Box::pin(future)),
remote,
extras: UnsafeCell::new(extras),
extras: UnsafeCell::new(TaskExtras {
extras,
remote: None,
}),
}))
}
}

impl crate::queue::TaskCell for TaskCell {
fn mut_extras(&mut self) -> &mut Extras {
unsafe { &mut *self.0.extras.get() }
unsafe { &mut (*self.0.extras.get()).extras }
}
}

Expand All @@ -85,6 +88,15 @@ unsafe fn waker(task: *const Task) -> Waker {
#[inline]
unsafe fn clone_raw(this: *const ()) -> RawWaker {
let task_cell = clone_task(this as *const Task);
let extras = { &mut *task_cell.0.extras.get() };
Copy link
Contributor

@sticnarf sticnarf Dec 31, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should use an immutable reference here and get the mutable reference only at L98 to make the program sound.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it unsound?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two wakers in different threads concurrently, then there can be two multable references to the extras at the same time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you mean clone_raw can be executed concurrently due to L100, then it's unsound either to move it to L98. Lock is required.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, you're right. Executing L98 and L91 concurrently is unsound.

if extras.remote.is_none() {
// `Future` guarantees that waker has to be cloned before getting out
// of the thread pool scope. And `Runner` guarantees `LOCAL` is
// initialized whenever the future is polled in the scope.
LOCAL.with(|l| {
extras.remote = Some((&*l.get()).remote());
})
}
RawWaker::new(
Arc::into_raw(task_cell.0) as *const (),
&RawWakerVTable::new(clone_raw, wake_raw, wake_ref_raw, drop_raw),
Expand All @@ -96,25 +108,26 @@ unsafe fn drop_raw(this: *const ()) {
drop(task_cell(this as *const Task))
}

unsafe fn wake_impl(task_cell: &TaskCell) {
let task = &task_cell.0;
let mut status = task.status.load(SeqCst);
unsafe fn wake_impl(task: Cow<'_, Arc<Task>>) {
let mut status = task.as_ref().status.load(SeqCst);
loop {
match status {
IDLE => {
match task
.as_ref()
.status
.compare_exchange_weak(IDLE, NOTIFIED, SeqCst, SeqCst)
{
Ok(_) => {
task.remote.spawn(clone_task(&**task));
wake_task(task, false);
break;
}
Err(cur) => status = cur,
}
}
POLLING => {
match task
.as_ref()
.status
.compare_exchange_weak(POLLING, NOTIFIED, SeqCst, SeqCst)
{
Expand All @@ -130,13 +143,13 @@ unsafe fn wake_impl(task_cell: &TaskCell) {
#[inline]
unsafe fn wake_raw(this: *const ()) {
let task_cell = task_cell(this as *const Task);
wake_impl(&task_cell);
wake_impl(Cow::Owned(task_cell.0));
}

#[inline]
unsafe fn wake_ref_raw(this: *const ()) {
let task_cell = ManuallyDrop::new(task_cell(this as *const Task));
wake_impl(&task_cell);
wake_impl(Cow::Borrowed(&task_cell.0));
}

#[inline]
Expand All @@ -151,6 +164,47 @@ unsafe fn clone_task(task: *const Task) -> TaskCell {
task_cell
}

thread_local! {
/// Local queue reference that is set before polling and unset after polled.
static LOCAL: Cell<*mut Local<TaskCell>> = Cell::new(std::ptr::null_mut());
}

unsafe fn wake_task(task: Cow<'_, Arc<Task>>, reschedule: bool) {
LOCAL.with(|ptr| {
if ptr.get().is_null() {
// It's out of polling process, has to be spawn to global queue.
// It needs to clone to make it safe as it's unclear whether `self`
// is still used inside method `spawn` after `TaskCell` is dropped.
Copy link
Contributor

@sticnarf sticnarf Dec 31, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strictly speaking, there is still some risk because Waker is accidentally Sync. rust-lang/rust#66481
It means the user might be able to move the reference to other threads (for example, using crossbeam::scoped`) and break our assumption.

However, I think it can be ignored. Maybe we can give some documentation and use expect to give some messages?

(*task.as_ref().extras.get())
.remote
.as_ref()
.unwrap()
.spawn(TaskCell(task.clone().into_owned()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I still don't understand why we need a clone here...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TaskCell can be dropped inside method spawn, which can make self invalid.

} else if reschedule {
// It's requested explicitly to schedule to global queue.
(*ptr.get()).spawn_remote(TaskCell(task.into_owned()));
} else {
// Otherwise spawns to local queue for best locality.
(*ptr.get()).spawn(TaskCell(task.into_owned()));
}
})
}

struct Scope<'a>(&'a mut Local<TaskCell>);

impl<'a> Scope<'a> {
fn new(l: &'a mut Local<TaskCell>) -> Scope<'a> {
LOCAL.with(|c| c.set(l));
Scope(l)
}
}

impl<'a> Drop for Scope<'a> {
fn drop(&mut self) {
LOCAL.with(|c| c.set(std::ptr::null_mut()));
}
}

/// [`Future`] task runner.
#[derive(Clone)]
pub struct Runner {
Expand Down Expand Up @@ -182,7 +236,8 @@ thread_local! {
impl crate::pool::Runner for Runner {
type TaskCell = TaskCell;

fn handle(&mut self, _local: &mut Local<TaskCell>, task_cell: TaskCell) -> bool {
fn handle(&mut self, local: &mut Local<TaskCell>, task_cell: TaskCell) -> bool {
let _scope = Scope::new(local);
let task = task_cell.0;
unsafe {
let waker = ManuallyDrop::new(waker(&*task));
Expand All @@ -197,10 +252,9 @@ impl crate::pool::Runner for Runner {
match task.status.compare_exchange(POLLING, IDLE, SeqCst, SeqCst) {
Ok(_) => return false,
Err(NOTIFIED) => {
if repoll_times >= self.repoll_limit
|| NEED_RESCHEDULE.with(|r| r.replace(false))
{
task.remote.spawn(clone_task(&*task));
let need_reschedule = NEED_RESCHEDULE.with(|r| r.replace(false));
if repoll_times >= self.repoll_limit || need_reschedule {
wake_task(Cow::Owned(task), need_reschedule);
return false;
} else {
repoll_times += 1;
Expand All @@ -213,7 +267,7 @@ impl crate::pool::Runner for Runner {
}
}

/// Gives up a timeslice to the task scheduler.
/// Gives up a time slice to the task scheduler.
///
/// It is only guaranteed to work in yatp.
pub async fn reschedule() {
Expand Down Expand Up @@ -321,11 +375,9 @@ mod tests {
WakeLater::new(waker_tx.clone()).await;
res_tx.send(2).unwrap();
};
local.remote.spawn(TaskCell::new(
fut,
local.remote.clone(),
Extras::single_level(),
));
local
.remote
.spawn(TaskCell::new(fut, Extras::single_level()));

local.handle_once();
assert_eq!(res_rx.recv().unwrap(), 1);
Expand Down Expand Up @@ -386,11 +438,9 @@ mod tests {
PendingOnce::new().await;
res_tx.send(2).unwrap();
};
local.remote.spawn(TaskCell::new(
fut,
local.remote.clone(),
Extras::single_level(),
));
local
.remote
.spawn(TaskCell::new(fut, Extras::single_level()));

local.handle_once();
assert_eq!(res_rx.recv().unwrap(), 1);
Expand All @@ -411,11 +461,9 @@ mod tests {
PendingOnce::new().await;
res_tx.send(4).unwrap();
};
local.remote.spawn(TaskCell::new(
fut,
local.remote.clone(),
Extras::single_level(),
));
local
.remote
.spawn(TaskCell::new(fut, Extras::single_level()));

local.handle_once();
assert_eq!(res_rx.recv().unwrap(), 1);
Expand All @@ -427,29 +475,62 @@ mod tests {
assert_eq!(res_rx.recv().unwrap(), 4);
}

struct ForwardWaker {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's identical to WakeLater?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly. Didn't notice it before.

first_poll: bool,
tx: mpsc::Sender<Waker>,
}

impl Future for ForwardWaker {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
if self.first_poll {
self.first_poll = false;
self.tx.send(cx.waker().clone()).unwrap();
Poll::Pending
} else {
Poll::Ready(())
}
}
}

#[test]
fn test_reschedule() {
let mut local = MockLocal::default();
let (res_tx, res_rx) = mpsc::channel();
let (waker_tx, waker_rx) = mpsc::channel();

let fut = async move {
res_tx.send(1).unwrap();
reschedule().await;
res_tx.send(2).unwrap();
PendingOnce::new().await;
res_tx.send(3).unwrap();
ForwardWaker {
first_poll: true,
tx: waker_tx,
}
.await;
res_tx.send(4).unwrap();
};
local.remote.spawn(TaskCell::new(
fut,
local.remote.clone(),
Extras::single_level(),
));
local
.remote
.spawn(TaskCell::new(fut, Extras::single_level()));

local.handle_once();
assert_eq!(res_rx.recv().unwrap(), 1);
assert!(res_rx.try_recv().is_err());
local.handle_once();
assert_eq!(res_rx.recv().unwrap(), 2);
assert_eq!(res_rx.recv().unwrap(), 3);
assert!(res_rx.try_recv().is_err());

// `ForwardWaker` has not been notified yet, `handle_once` should
// handle nothing.
local.handle_once();
assert!(res_rx.try_recv().is_err());
let waker = waker_rx.try_recv().unwrap();
waker.wake();
local.handle_once();
assert_eq!(res_rx.try_recv().unwrap(), 4);
}
}