Skip to content

Commit

Permalink
Allow reads concurrent to pending writes.
Browse files Browse the repository at this point in the history
Thereby propagate pending write back-pressure via the
paused command channels, rather than by "blocking"
`Connection::next()` itself until a write completes.

Notably, no new buffers are introduced. When a frame write
cannot complete, the command channels are paused and I/O
reads continue while the write is pending. The paused
command channels exert write back-pressure on the streams
and the API, ensuring that the only frames we still try
to send are those sent in response to reading a frame.
Such frames do wait for completion of any prior pending
send operation, but because they are sent as a result of
reading a frame, the remote can in turn make progress on
its own writes, should it have been blocked on writing
before being able to read again. Unexpected write
deadlocks between peers that otherwise read and write
concurrently on substreams are thus avoided.
  • Loading branch information
Roman S. Borschel committed Feb 11, 2021
1 parent 4241ea4 commit 24d5464
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 55 deletions.
174 changes: 141 additions & 33 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ use futures::{
stream::{Fuse, FusedStream}
};
use nohash_hasher::IntMap;
use std::{fmt, sync::Arc, task::{Context, Poll}};
use std::{fmt, io, sync::Arc, task::{Context, Poll}};

pub use control::Control;
pub use stream::{Packet, State, Stream};
Expand Down Expand Up @@ -168,7 +168,7 @@ pub struct Connection<T> {
control_sender: mpsc::Sender<ControlCommand>,
control_receiver: Pausable<mpsc::Receiver<ControlCommand>>,
stream_sender: mpsc::Sender<StreamCommand>,
stream_receiver: mpsc::Receiver<StreamCommand>,
stream_receiver: Pausable<mpsc::Receiver<StreamCommand>>,
garbage: Vec<StreamId>, // see `Connection::garbage_collect()`
shutdown: Shutdown,
is_closed: bool
Expand Down Expand Up @@ -282,7 +282,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
control_sender,
control_receiver: Pausable::new(control_receiver),
stream_sender,
stream_receiver,
stream_receiver: Pausable::new(stream_receiver),
next_id: match mode {
Mode::Client => 1,
Mode::Server => 2
Expand Down Expand Up @@ -348,7 +348,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {

// Close and drain the stream command receiver.
if !self.stream_receiver.is_terminated() {
self.stream_receiver.close();
self.stream_receiver.stream().close();
while let Some(_cmd) = self.stream_receiver.next().await {
// drop it
}
Expand Down Expand Up @@ -382,12 +382,37 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {

let mut num_terminated = 0;

let mut next_inbound_frame =
// Determine if the stream command receiver is paused, which
// is only the case if we have pending writes, i.e. back-pressure
// from the underlying connection.
let stream_receiver_paused = self.stream_receiver.is_paused();

let mut next_io_event =
if self.socket.is_terminated() {
num_terminated += 1;
Either::Left(future::pending())
} else {
Either::Right(self.socket.try_next().err_into())
let socket = &mut self.socket;
let io = future::poll_fn(move |cx| {
// Progress writing.
match socket.get_mut().poll_send::<()>(cx, None) {
frame::PollSend::Pending(_) => {}
frame::PollSend::Ready(res) => {
res.or(Err(ConnectionError::Closed))?;
if stream_receiver_paused {
return Poll::Ready(Result::Ok(IoEvent::OutboundReady))
}
}
}
// Progress reading.
let next_frame = match futures::ready!(socket.poll_next_unpin(cx)) {
None => Ok(None),
Some(Err(e)) => Err(e.into()),
Some(Ok(f)) => Ok(Some(f))
};
Poll::Ready(Ok(IoEvent::Inbound(next_frame)))
});
Either::Right(io)
};

let mut next_stream_command =
Expand Down Expand Up @@ -415,14 +440,14 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
future::poll_fn(move |cx: &mut Context| {
let a = next_stream_command.poll_unpin(cx);
let b = next_control_command.poll_unpin(cx);
let c = next_inbound_frame.poll_unpin(cx);
let c = next_io_event.poll_unpin(cx);
if a.is_pending() && b.is_pending() && c.is_pending() {
return Poll::Pending
}
Poll::Ready((a, b, c))
});

let (stream_command, control_command, inbound_frame) = next_item.await;
let (stream_command, control_command, io_event) = next_item.await;

if let Poll::Ready(cmd) = control_command {
self.on_control_command(cmd).await?
Expand All @@ -432,14 +457,25 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
self.on_stream_command(cmd).await?
}

if let Poll::Ready(frame) = inbound_frame {
if let Some(stream) = self.on_frame(frame).await? {
self.socket.get_mut().flush().await.or(Err(ConnectionError::Closed))?;
return Ok(Some(stream))
match io_event? {
Poll::Ready(IoEvent::OutboundReady) => {
self.stream_receiver.unpause();
// Only unpause the control command receiver if we're not
// shutting down already.
if let Shutdown::NotStarted = self.shutdown {
self.control_receiver.unpause();
}
}
Poll::Ready(IoEvent::Inbound(frame)) => {
if let Some(stream) = self.on_frame(frame).await? {
self.flush_nowait().await.or(Err(ConnectionError::Closed))?;
return Ok(Some(stream))
}
}
Poll::Pending => {}
}

self.socket.get_mut().flush().await.or(Err(ConnectionError::Closed))?
self.flush_nowait().await.or(Err(ConnectionError::Closed))?;
}
}

Expand Down Expand Up @@ -467,8 +503,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
if extra_credit > 0 {
let mut frame = Frame::window_update(id, extra_credit);
frame.header_mut().syn();
log::trace!("{}: sending initial {}", self.id, frame.header());
self.socket.get_mut().send(&frame).await.or(Err(ConnectionError::Closed))?
log::trace!("{}/{}: sending initial {}", self.id, id, frame.header());
self.send(frame).await.or(Err(ConnectionError::Closed))?
}
let stream = {
let config = self.config.clone();
Expand All @@ -489,7 +525,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
let mut header = Header::data(id, 0);
header.rst();
let frame = Frame::new(header);
self.socket.get_mut().send(&frame).await.or(Err(ConnectionError::Closed))?
log::trace!("{}/{}: sending reset", self.id, id);
self.send(frame).await.or(Err(ConnectionError::Closed))?
}
}
}
Expand All @@ -499,12 +536,14 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
let _ = reply.send(());
return Ok(())
}
// Handle initial close command.
// Handle initial close command by pausing the control command
// receiver and closing the stream command receiver. I.e. we
// wait for the stream commands to drain.
debug_assert!(self.shutdown.has_not_started());
self.shutdown = Shutdown::InProgress(reply);
log::trace!("{}: shutting down connection", self.id);
self.control_receiver.pause();
self.stream_receiver.close()
self.stream_receiver.stream().close()
}
None => {
// We only get here after the whole connection shutdown is complete.
Expand All @@ -518,31 +557,35 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
Ok(())
}

/// Sends a frame to the remote via the free-standing `send` helper.
///
/// Delegates to the module-level `send`, passing along the command
/// receivers so that they can be paused if the write cannot complete
/// immediately (write back-pressure without blocking reads).
async fn send(&mut self, f: impl Into<Frame<()>>) -> Result<()> {
    send(self.id, &mut self.socket, &mut self.stream_receiver, &mut self.control_receiver, f).await
}

/// Process a command from one of our `Stream`s.
async fn on_stream_command(&mut self, cmd: Option<StreamCommand>) -> Result<()> {
match cmd {
Some(StreamCommand::SendFrame(frame)) => {
log::trace!("{}: sending: {}", self.id, frame.header());
self.socket.get_mut().send(&frame).await.or(Err(ConnectionError::Closed))?
log::trace!("{}/{}: sending: {}", self.id, frame.header().stream_id(), frame.header());
self.send(frame).await.or(Err(ConnectionError::Closed))?
}
Some(StreamCommand::CloseStream { id, ack }) => {
log::trace!("{}: closing stream {} of {}", self.id, id, self);
log::trace!("{}/{}: sending close", self.id, id);
let mut header = Header::data(id, 0);
header.fin();
if ack { header.ack() }
let frame = Frame::new(header);
self.socket.get_mut().send(&frame).await.or(Err(ConnectionError::Closed))?
self.send(frame).await.or(Err(ConnectionError::Closed))?
}
None => {
// We only get to this point when `self.stream_receiver`
// was closed which only happens in response to a close control
// command. Now that we are at the end of the stream command queue,
// we send the final term frame to the remote and complete the
// closure.
// closure by closing the already paused control command receiver.
debug_assert!(self.shutdown.is_in_progress());
log::debug!("{}: closing {}", self.id, self);
log::debug!("{}: sending term", self.id);
let frame = Frame::term();
self.socket.get_mut().send(&frame).await.or(Err(ConnectionError::Closed))?;
self.send(frame).await.or(Err(ConnectionError::Closed))?;
let shutdown = std::mem::replace(&mut self.shutdown, Shutdown::Complete);
if let Shutdown::InProgress(tx) = shutdown {
// Inform the `Control` that initiated the shutdown.
Expand Down Expand Up @@ -578,25 +621,25 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
log::trace!("{}: new inbound {} of {}", self.id, stream, self);
if let Some(f) = update {
log::trace!("{}/{}: sending update", self.id, f.header().stream_id());
self.socket.get_mut().send(&f).await.or(Err(ConnectionError::Closed))?
self.send(f).await.or(Err(ConnectionError::Closed))?
}
return Ok(Some(stream))
}
Action::Update(f) => {
log::trace!("{}/{}: sending update", self.id, f.header().stream_id());
self.socket.get_mut().send(&f).await.or(Err(ConnectionError::Closed))?
self.send(f).await.or(Err(ConnectionError::Closed))?
}
Action::Ping(f) => {
log::trace!("{}/{}: pong", self.id, f.header().stream_id());
self.socket.get_mut().send(&f).await.or(Err(ConnectionError::Closed))?
self.send(f).await.or(Err(ConnectionError::Closed))?
}
Action::Reset(f) => {
log::trace!("{}/{}: sending reset", self.id, f.header().stream_id());
self.socket.get_mut().send(&f).await.or(Err(ConnectionError::Closed))?
self.send(f).await.or(Err(ConnectionError::Closed))?
}
Action::Terminate(f) => {
log::trace!("{}: sending term", self.id);
self.socket.get_mut().send(&f).await.or(Err(ConnectionError::Closed))?
self.send(f).await.or(Err(ConnectionError::Closed))?
}
}
Ok(None)
Expand All @@ -605,7 +648,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
log::debug!("{}: socket eof", self.id);
Err(ConnectionError::Closed)
}
Err(e) if e.io_kind() == Some(std::io::ErrorKind::ConnectionReset) => {
Err(e) if e.io_kind() == Some(io::ErrorKind::ConnectionReset) => {
log::debug!("{}: connection reset", self.id);
Err(ConnectionError::Closed)
}
Expand Down Expand Up @@ -837,6 +880,14 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
}
}

/// Attempt a single flush of the underlying I/O stream, without
/// waiting for it to complete.
///
/// Flush errors are surfaced; a flush that is merely pending is
/// deliberately ignored so this call never blocks the caller.
async fn flush_nowait(&mut self) -> Result<()> {
    future::poll_fn(|cx| {
        // Only an error is propagated; both `Pending` and a completed
        // flush count as success here since we just *try* to flush.
        if let Poll::Ready(Err(e)) = self.socket.get_mut().poll_flush(cx) {
            return Poll::Ready(Err(e.into()))
        }
        Poll::Ready(Ok(()))
    }).await
}

/// Remove stale streams and send necessary messages to the remote.
///
/// If we ever get async destructors we can replace this with streams
Expand Down Expand Up @@ -902,8 +953,9 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
frame
};
if let Some(f) = frame {
log::trace!("{}: sending: {}", self.id, f.header());
self.socket.get_mut().send(&f).await.or(Err(ConnectionError::Closed))?
log::trace!("{}/{}: sending final frame: {}", self.id, stream_id, f.header());
send(conn_id, &mut self.socket, &mut self.stream_receiver,
&mut self.control_receiver, f).await.or(Err(ConnectionError::Closed))?
}
self.garbage.push(stream_id)
}
Expand Down Expand Up @@ -936,6 +988,62 @@ impl<T> Drop for Connection<T> {
}
}

/// Outcome of polling the underlying socket for I/O progress.
enum IoEvent {
    /// A previously pending frame write completed, so sending of
    /// frames (and thus processing of commands) may resume.
    OutboundReady,
    /// The result of reading from the socket: a new inbound frame,
    /// end-of-stream (`Ok(None)`), or a read/decode error.
    Inbound(Result<Option<Frame<()>>>),
}

/// Sends a single frame on the given `io` stream.
///
/// Behaviour with respect to back-pressure:
///
///   * If a previous send operation is still pending, this call waits
///     for that operation to complete before handing over `frame`.
///   * If `io` accepts the frame but cannot fully write it out yet,
///     both command receivers are paused and the call returns without
///     waiting for the write to finish.
async fn send<T: AsyncRead + AsyncWrite + Unpin>(
    id: Id,
    io: &mut Fuse<frame::Io<T>>,
    stream_receiver: &mut Pausable<mpsc::Receiver<StreamCommand>>,
    control_receiver: &mut Pausable<mpsc::Receiver<ControlCommand>>,
    frame: impl Into<Frame<()>>
) -> Result<()> {
    let mut pending = Some(frame.into());
    future::poll_fn(move |cx| {
        let next = pending.take().expect("Frame has not yet been taken by `io`.");
        match io.get_mut().poll_send(cx, Some(next)) {
            frame::PollSend::Pending(Some(f)) => {
                // `io` did not take the frame: a prior write is still in
                // flight. Keep the frame for the next poll and wait.
                debug_assert!(stream_receiver.is_paused());
                log::debug!("{}: send: Prior write pending. Waiting.", id);
                pending = Some(f);
                Poll::Pending
            }
            frame::PollSend::Pending(None) => {
                // `io` took the frame but could not fully write it to the
                // underlying socket. Pause command processing so that no
                // further writes are initiated while reads can continue.
                // The only frames still sent while paused are reactions to
                // frames being read, which lets the remote make progress
                // if it is currently blocked on writing — avoiding
                // needless deadlocks between write-blocked peers.
                log::trace!("{}: send: Write pending. Continuing with paused command streams.", id);
                stream_receiver.pause();
                control_receiver.pause();
                Poll::Ready(Ok(()))
            }
            frame::PollSend::Ready(Err(e)) => Poll::Ready(Err(e.into())),
            // Unpausing of the command streams is left to `Connection::next()`.
            frame::PollSend::Ready(Ok(())) => Poll::Ready(Ok(()))
        }
    }).await
}

/// Turn a Yamux [`Connection`] into a [`futures::Stream`].
pub fn into_stream<T>(c: Connection<T>) -> impl futures::stream::Stream<Item = Result<Stream>>
where
Expand Down
12 changes: 10 additions & 2 deletions src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use futures::future::Either;
use header::{Header, StreamId, Data, WindowUpdate, GoAway, Ping};
use std::{convert::TryInto, num::TryFromIntError};

pub(crate) use io::Io;
pub(crate) use io::{Io, PollSend};
pub use io::FrameDecodeError;

/// A Yamux message frame consisting of header and body.
Expand Down Expand Up @@ -49,6 +49,15 @@ impl<T> Frame<T> {
}
}

/// Erases the typed header tag of a frame, yielding the generic
/// `Frame<()>` representation; the body is carried over unchanged.
impl<A: header::private::Sealed> From<Frame<A>> for Frame<()> {
    fn from(f: Frame<A>) -> Frame<()> {
        let Frame { header, body } = f;
        Frame { header: header.into(), body }
    }
}

impl Frame<()> {
pub(crate) fn into_data(self) -> Frame<Data> {
Frame { header: self.header.into_data(), body: self.body }
Expand Down Expand Up @@ -117,4 +126,3 @@ impl Frame<GoAway> {
}
}
}

9 changes: 8 additions & 1 deletion src/frame/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ impl<T> Header<T> {
}
}

/// Converts any sealed, typed header into the untyped `Header<()>`
/// via `cast()`.
impl<A: private::Sealed> From<Header<A>> for Header<()> {
    fn from(h: Header<A>) -> Header<()> {
        h.cast()
    }
}

impl Header<()> {
pub(crate) fn into_data(self) -> Header<Data> {
debug_assert_eq!(self.tag, Tag::Data);
Expand Down Expand Up @@ -242,12 +248,13 @@ pub trait HasRst: private::Sealed {}
impl HasRst for Data {}
impl HasRst for WindowUpdate {}

mod private {
pub(super) mod private {
pub trait Sealed {}

impl Sealed for super::Data {}
impl Sealed for super::WindowUpdate {}
impl Sealed for super::Ping {}
impl Sealed for super::GoAway {}
impl<A: Sealed, B: Sealed> Sealed for super::Either<A, B> {}
}

Expand Down
Loading

0 comments on commit 24d5464

Please sign in to comment.