Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow reads concurrent to pending writes. #112

Merged
merged 15 commits into from
Apr 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ readme = "README.md"
edition = "2018"

[dependencies]
futures = { version = "0.3.4", default-features = false, features = ["std"] }
futures = { version = "0.3.12", default-features = false, features = ["std"] }
log = "0.4.8"
nohash-hasher = "0.2"
parking_lot = "0.11"
Expand Down
143 changes: 101 additions & 42 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,11 @@ use futures::{
channel::{mpsc, oneshot},
future::{self, Either},
prelude::*,
stream::{Fuse, FusedStream}
stream::{Fuse, FusedStream},
sink::SinkExt,
};
use nohash_hasher::IntMap;
use std::{fmt, sync::Arc, task::{Context, Poll}};
use std::{fmt, io, sync::Arc, task::{Context, Poll}};

pub use control::Control;
pub use stream::{Packet, State, Stream};
Expand Down Expand Up @@ -370,10 +371,47 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
loop {
self.garbage_collect().await?;

// For each channel and the socket we create a future that gets
// the next item. We will poll each future and if any one of them
// yields an item, we return the tuple of poll results which are
// then all processed.
// Wait for the frame sink to be ready or, if there is a pending
// write, for an incoming frame. I.e. as long as there is a pending
// write, we only read, unless a read results in needing to send a
// frame, in which case we must wait for the pending write to
// complete. When the frame sink is ready, we can proceed with
// waiting for a new stream or control command or another inbound
// frame.
let next_io_event = if self.socket.is_terminated() {
Either::Left(future::pending())
} else {
let socket = &mut self.socket;
let io = future::poll_fn(move |cx| {
if let Poll::Ready(res) = socket.poll_ready_unpin(cx) {
res.or(Err(ConnectionError::Closed))?;
return Poll::Ready(Result::Ok(IoEvent::OutboundReady))
}

// At this point we know the socket sink has a pending
// write, so we try to read the next frame instead.
let next_frame = futures::ready!(socket.poll_next_unpin(cx))
.transpose()
.map_err(ConnectionError::from);
Poll::Ready(Ok(IoEvent::Inbound(next_frame)))
});
Either::Right(io)
};

if let IoEvent::Inbound(frame) = next_io_event.await? {
if let Some(stream) = self.on_frame(frame).await? {
self.flush_nowait().await.or(Err(ConnectionError::Closed))?;
return Ok(Some(stream))
}
continue // The socket sink still has a pending write.
}

// Getting this far implies that the socket is ready to accept
// a new frame, so we can now listen for new commands while waiting
// for the next inbound frame. To that end, for each channel and the
// socket we create a future that gets the next item. We will poll
// each future and if any one of them yields an item, we return the
// tuple of poll results which are then all processed.
//
// For terminated sources we create non-finishing futures.
// This guarantees that if the remaining futures are pending
Expand All @@ -382,13 +420,12 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {

let mut num_terminated = 0;

let mut next_inbound_frame =
if self.socket.is_terminated() {
num_terminated += 1;
Either::Left(future::pending())
} else {
Either::Right(self.socket.try_next().err_into())
};
let mut next_frame = if self.socket.is_terminated() {
num_terminated += 1;
Either::Left(future::pending())
} else {
Either::Right(self.socket.next())
};

let mut next_stream_command =
if self.stream_receiver.is_terminated() {
Expand All @@ -415,14 +452,14 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
future::poll_fn(move |cx: &mut Context| {
let a = next_stream_command.poll_unpin(cx);
let b = next_control_command.poll_unpin(cx);
let c = next_inbound_frame.poll_unpin(cx);
let c = next_frame.poll_unpin(cx);
if a.is_pending() && b.is_pending() && c.is_pending() {
return Poll::Pending
}
Poll::Ready((a, b, c))
});

let (stream_command, control_command, inbound_frame) = next_item.await;
let (stream_command, control_command, frame) = next_item.await;

if let Poll::Ready(cmd) = control_command {
self.on_control_command(cmd).await?
Expand All @@ -432,14 +469,14 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
self.on_stream_command(cmd).await?
}

if let Poll::Ready(frame) = inbound_frame {
if let Some(stream) = self.on_frame(frame).await? {
self.socket.get_mut().flush().await.or(Err(ConnectionError::Closed))?;
if let Poll::Ready(frame) = frame {
if let Some(stream) = self.on_frame(frame.transpose().map_err(Into::into)).await? {
self.flush_nowait().await.or(Err(ConnectionError::Closed))?;
return Ok(Some(stream))
}
}

self.socket.get_mut().flush().await.or(Err(ConnectionError::Closed))?
self.flush_nowait().await.or(Err(ConnectionError::Closed))?;
}
}

Expand Down Expand Up @@ -467,8 +504,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
if extra_credit > 0 {
let mut frame = Frame::window_update(id, extra_credit);
frame.header_mut().syn();
log::trace!("{}: sending initial {}", self.id, frame.header());
self.socket.get_mut().send(&frame).await.or(Err(ConnectionError::Closed))?
log::trace!("{}/{}: sending initial {}", self.id, id, frame.header());
self.socket.feed(frame.into()).await.or(Err(ConnectionError::Closed))?
}
let stream = {
let config = self.config.clone();
Expand All @@ -489,7 +526,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
let mut header = Header::data(id, 0);
header.rst();
let frame = Frame::new(header);
self.socket.get_mut().send(&frame).await.or(Err(ConnectionError::Closed))?
log::trace!("{}/{}: sending reset", self.id, id);
self.socket.feed(frame.into()).await.or(Err(ConnectionError::Closed))?
}
}
}
Expand All @@ -499,7 +537,9 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
let _ = reply.send(());
return Ok(())
}
// Handle initial close command.
// Handle initial close command by pausing the control command
// receiver and closing the stream command receiver. I.e. we
// wait for the stream commands to drain.
debug_assert!(self.shutdown.has_not_started());
self.shutdown = Shutdown::InProgress(reply);
log::trace!("{}: shutting down connection", self.id);
Expand All @@ -522,27 +562,27 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
async fn on_stream_command(&mut self, cmd: Option<StreamCommand>) -> Result<()> {
match cmd {
Some(StreamCommand::SendFrame(frame)) => {
log::trace!("{}: sending: {}", self.id, frame.header());
self.socket.get_mut().send(&frame).await.or(Err(ConnectionError::Closed))?
log::trace!("{}/{}: sending: {}", self.id, frame.header().stream_id(), frame.header());
self.socket.feed(frame.into()).await.or(Err(ConnectionError::Closed))?
}
Some(StreamCommand::CloseStream { id, ack }) => {
log::trace!("{}: closing stream {} of {}", self.id, id, self);
log::trace!("{}/{}: sending close", self.id, id);
let mut header = Header::data(id, 0);
header.fin();
if ack { header.ack() }
let frame = Frame::new(header);
self.socket.get_mut().send(&frame).await.or(Err(ConnectionError::Closed))?
self.socket.feed(frame.into()).await.or(Err(ConnectionError::Closed))?
}
None => {
// We only get to this point when `self.stream_receiver`
// was closed which only happens in response to a close control
// command. Now that we are at the end of the stream command queue,
// we send the final term frame to the remote and complete the
// closure.
// closure by closing the already paused control command receiver.
debug_assert!(self.shutdown.is_in_progress());
log::debug!("{}: closing {}", self.id, self);
log::debug!("{}: sending term", self.id);
let frame = Frame::term();
self.socket.get_mut().send(&frame).await.or(Err(ConnectionError::Closed))?;
self.socket.feed(frame.into()).await.or(Err(ConnectionError::Closed))?;
let shutdown = std::mem::replace(&mut self.shutdown, Shutdown::Complete);
if let Shutdown::InProgress(tx) = shutdown {
// Inform the `Control` that initiated the shutdown.
Expand Down Expand Up @@ -578,25 +618,25 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
log::trace!("{}: new inbound {} of {}", self.id, stream, self);
if let Some(f) = update {
log::trace!("{}/{}: sending update", self.id, f.header().stream_id());
self.socket.get_mut().send(&f).await.or(Err(ConnectionError::Closed))?
self.socket.feed(f.into()).await.or(Err(ConnectionError::Closed))?
}
return Ok(Some(stream))
}
Action::Update(f) => {
log::trace!("{}/{}: sending update", self.id, f.header().stream_id());
self.socket.get_mut().send(&f).await.or(Err(ConnectionError::Closed))?
log::trace!("{}: sending update: {:?}", self.id, f.header());
self.socket.feed(f.into()).await.or(Err(ConnectionError::Closed))?
}
Action::Ping(f) => {
log::trace!("{}/{}: pong", self.id, f.header().stream_id());
self.socket.get_mut().send(&f).await.or(Err(ConnectionError::Closed))?
self.socket.feed(f.into()).await.or(Err(ConnectionError::Closed))?
}
Action::Reset(f) => {
log::trace!("{}/{}: sending reset", self.id, f.header().stream_id());
self.socket.get_mut().send(&f).await.or(Err(ConnectionError::Closed))?
self.socket.feed(f.into()).await.or(Err(ConnectionError::Closed))?
}
Action::Terminate(f) => {
log::trace!("{}: sending term", self.id);
self.socket.get_mut().send(&f).await.or(Err(ConnectionError::Closed))?
self.socket.feed(f.into()).await.or(Err(ConnectionError::Closed))?
}
}
Ok(None)
Expand All @@ -605,7 +645,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
log::debug!("{}: socket eof", self.id);
Err(ConnectionError::Closed)
}
Err(e) if e.io_kind() == Some(std::io::ErrorKind::ConnectionReset) => {
Err(e) if e.io_kind() == Some(io::ErrorKind::ConnectionReset) => {
log::debug!("{}: connection reset", self.id);
Err(ConnectionError::Closed)
}
Expand Down Expand Up @@ -712,7 +752,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
}
}
} else {
log::debug!("{}/{}: data for unknown stream, ignoring", self.id, stream_id);
log::trace!("{}/{}: data frame for unknown stream, possibly dropped earlier: {:?}",
self.id, stream_id, frame);
// We do not consider this a protocol violation and thus do not send a stream reset
// because we may still be processing pending `StreamCommand`s of this stream that were
// sent before it has been dropped and "garbage collected". Such a stream reset would
Expand Down Expand Up @@ -782,7 +823,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
w.wake()
}
} else {
log::debug!("{}/{}: window update for unknown stream, ignoring", self.id, stream_id);
log::trace!("{}/{}: window update for unknown stream, possibly dropped earlier: {:?}",
self.id, stream_id, frame);
// We do not consider this a protocol violation and thus do not send a stream reset
// because we may still be processing pending `StreamCommand`s of this stream that were
// sent before it has been dropped and "garbage collected". Such a stream reset would
Expand All @@ -805,7 +847,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
hdr.ack();
return Action::Ping(Frame::new(hdr))
}
log::debug!("{}/{}: ping for unknown stream", self.id, stream_id);
log::trace!("{}/{}: ping for unknown stream, possibly dropped earlier: {:?}",
self.id, stream_id, frame);
// We do not consider this a protocol violation and thus do not send a stream reset because
// we may still be processing pending `StreamCommand`s of this stream that were sent before
// it has been dropped and "garbage collected". Such a stream reset would interfere with the
Expand Down Expand Up @@ -837,6 +880,14 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
}
}

/// Try to flush the underlying I/O stream, without waiting for it.
async fn flush_nowait(&mut self) -> Result<()> {
future::poll_fn(|cx| {
let _ = self.socket.get_mut().poll_flush_unpin(cx)?;
Poll::Ready(Ok(()))
}).await
}

/// Remove stale streams and send necessary messages to the remote.
///
/// If we ever get async destructors we can replace this with streams
Expand Down Expand Up @@ -902,8 +953,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
frame
};
if let Some(f) = frame {
log::trace!("{}: sending: {}", self.id, f.header());
self.socket.get_mut().send(&f).await.or(Err(ConnectionError::Closed))?
log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
self.socket.feed(f.into()).await.or(Err(ConnectionError::Closed))?
}
self.garbage.push(stream_id)
}
Expand Down Expand Up @@ -936,6 +987,14 @@ impl<T> Drop for Connection<T> {
}
}

/// Events related to reading from or writing to the underlying socket.
enum IoEvent {
/// A new inbound frame arrived.
Inbound(Result<Option<Frame<()>>>),
/// We can continue sending frames.
OutboundReady,
}

/// Turn a Yamux [`Connection`] into a [`futures::Stream`].
pub fn into_stream<T>(c: Connection<T>) -> impl futures::stream::Stream<Item = Result<Stream>>
where
Expand Down
10 changes: 9 additions & 1 deletion src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ impl<T> Frame<T> {
}
}

impl<A: header::private::Sealed> From<Frame<A>> for Frame<()> {
fn from(f: Frame<A>) -> Frame<()> {
Frame {
header: f.header.into(),
body: f.body
}
}
}

impl Frame<()> {
pub(crate) fn into_data(self) -> Frame<Data> {
Frame { header: self.header.into_data(), body: self.body }
Expand Down Expand Up @@ -117,4 +126,3 @@ impl Frame<GoAway> {
}
}
}

9 changes: 8 additions & 1 deletion src/frame/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ impl<T> Header<T> {
}
}

impl<A: private::Sealed> From<Header<A>> for Header<()> {
fn from(h: Header<A>) -> Header<()> {
h.cast()
}
}

impl Header<()> {
pub(crate) fn into_data(self) -> Header<Data> {
debug_assert_eq!(self.tag, Tag::Data);
Expand Down Expand Up @@ -242,12 +248,13 @@ pub trait HasRst: private::Sealed {}
impl HasRst for Data {}
impl HasRst for WindowUpdate {}

mod private {
pub(super) mod private {
pub trait Sealed {}

impl Sealed for super::Data {}
impl Sealed for super::WindowUpdate {}
impl Sealed for super::Ping {}
impl Sealed for super::GoAway {}
impl<A: Sealed, B: Sealed> Sealed for super::Either<A, B> {}
}

Expand Down
Loading