Skip to content

Commit

Permalink
io: make copy cooperative (#6265)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rustin170506 committed Jan 6, 2024
1 parent 9780bf4 commit 3275cfb
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 2 deletions.
75 changes: 73 additions & 2 deletions tokio/src/io/util/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,19 @@ impl CopyBuffer {
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
ready!(crate::trace::trace_leaf(cx));
#[cfg(any(
feature = "fs",
feature = "io-std",
feature = "net",
feature = "process",
feature = "rt",
feature = "signal",
feature = "sync",
feature = "time",
))]
// Keep track of task budget
let coop = ready!(crate::runtime::coop::poll_proceed(cx));
loop {
// If our buffer is empty, then we need to read some data to
// continue.
Expand All @@ -90,13 +103,49 @@ impl CopyBuffer {
self.cap = 0;

match self.poll_fill_buf(cx, reader.as_mut()) {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Ready(Ok(())) => {
#[cfg(any(
feature = "fs",
feature = "io-std",
feature = "net",
feature = "process",
feature = "rt",
feature = "signal",
feature = "sync",
feature = "time",
))]
coop.made_progress();
}
Poll::Ready(Err(err)) => {
#[cfg(any(
feature = "fs",
feature = "io-std",
feature = "net",
feature = "process",
feature = "rt",
feature = "signal",
feature = "sync",
feature = "time",
))]
coop.made_progress();
return Poll::Ready(Err(err));
}
Poll::Pending => {
// Try flushing when the reader has no progress to avoid deadlock
// when the reader depends on buffered writer.
if self.need_flush {
ready!(writer.as_mut().poll_flush(cx))?;
#[cfg(any(
feature = "fs",
feature = "io-std",
feature = "net",
feature = "process",
feature = "rt",
feature = "signal",
feature = "sync",
feature = "time",
))]
coop.made_progress();
self.need_flush = false;
}

Expand All @@ -108,6 +157,17 @@ impl CopyBuffer {
// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
#[cfg(any(
feature = "fs",
feature = "io-std",
feature = "net",
feature = "process",
feature = "rt",
feature = "signal",
feature = "sync",
feature = "time",
))]
coop.made_progress();
if i == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
Expand All @@ -132,6 +192,17 @@ impl CopyBuffer {
// data and finish the transfer.
if self.pos == self.cap && self.read_done {
ready!(writer.as_mut().poll_flush(cx))?;
#[cfg(any(
feature = "fs",
feature = "io-std",
feature = "net",
feature = "process",
feature = "rt",
feature = "signal",
feature = "sync",
feature = "time",
))]
coop.made_progress();
return Poll::Ready(Ok(self.amt));
}
}
Expand Down
15 changes: 15 additions & 0 deletions tokio/tests/io_copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,18 @@ async fn proxy() {

assert_eq!(n, 1024);
}

#[tokio::test]
async fn copy_is_cooperative() {
tokio::select! {
biased;
_ = async {
loop {
let mut reader: &[u8] = b"hello";
let mut writer: Vec<u8> = vec![];
let _ = io::copy(&mut reader, &mut writer).await;
}
} => {},
_ = tokio::task::yield_now() => {}
}
}
25 changes: 25 additions & 0 deletions tokio/tests/io_copy_bidirectional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,28 @@ async fn immediate_exit_on_read_error() {

assert!(copy_bidirectional(&mut a, &mut b).await.is_err());
}

#[tokio::test]
async fn copy_bidirectional_is_cooperative() {
tokio::select! {
biased;
_ = async {
loop {
let payload = b"here, take this";

let mut a = tokio_test::io::Builder::new()
.read(payload)
.write(payload)
.build();

let mut b = tokio_test::io::Builder::new()
.read(payload)
.write(payload)
.build();

let _ = copy_bidirectional(&mut a, &mut b).await;
}
} => {},
_ = tokio::task::yield_now() => {}
}
}

0 comments on commit 3275cfb

Please sign in to comment.