Skip to content

Commit

Permalink
proxy: re-enabled vectored writes through our dynamic Io trait object. (
Browse files Browse the repository at this point in the history
#1167)

This adds `Io::write_buf_erased` that doesn't required `Self: Sized`, so
it can be called on trait objects. By using this method, specialized
methods of `TcpStream` (and others) can use their `write_buf` to do
vectored writes.

Since it can be easy to forget to call `Io::write_buf_erased` instead of
`Io::write_buf`, the concept of making a `Box<Io>` has been made
private. A new type, `BoxedIo`, implements all the super traits of `Io`,
while making the `Io` trait private to the `transport` module. Anything
hoping to use a `Box<Io>` can use a `BoxedIo` instead, and know that
the write buf erase dance is taken care of.

Adds a test to `transport::io` checking that the dance we've done does
indeed call the underlying specialized `write_buf` method.

Closes #1162
  • Loading branch information
seanmonstar committed Jun 20, 2018
1 parent ad65987 commit 8dcb95d
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 19 deletions.
8 changes: 4 additions & 4 deletions proxy/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use tokio::{
use conditional::Conditional;
use ctx::transport::TlsStatus;
use config::Addr;
use transport::{GetOriginalDst, Io, tls};
use transport::{AddrInfo, BoxedIo, GetOriginalDst, tls};

pub struct BoundPort {
inner: std::net::TcpListener,
Expand Down Expand Up @@ -47,7 +47,7 @@ pub enum Connecting {
/// subverted.
#[derive(Debug)]
pub struct Connection {
io: Box<Io>,
io: BoxedIo,
/// This buffer gets filled up when "peeking" bytes on this Connection.
///
/// This is used instead of MSG_PEEK in order to support TLS streams.
Expand Down Expand Up @@ -213,15 +213,15 @@ impl Future for Connecting {
impl Connection {
fn plain(io: TcpStream, why_no_tls: tls::ReasonForNoTls) -> Self {
Connection {
io: Box::new(io),
io: BoxedIo::new(io),
peek_buf: BytesMut::new(),
tls_status: Conditional::None(why_no_tls),
}
}

fn tls<S: tls::Session + std::fmt::Debug + 'static>(tls: tls::Connection<S>) -> Self {
Connection {
io: Box::new(tls),
io: BoxedIo::new(tls),
peek_buf: BytesMut::new(),
tls_status: Conditional::Some(()),
}
Expand Down
169 changes: 169 additions & 0 deletions proxy/src/transport/io.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
use std::io;
use std::net::{Shutdown, SocketAddr};

use bytes::Buf;
use futures::Poll;
use tokio::io::{AsyncRead, AsyncWrite};

use super::AddrInfo;
use self::internal::Io;

/// A public wrapper around a `Box<Io>`.
///
/// This type ensures that the proper write_buf method is called,
/// to allow vectored writes to occur.
#[derive(Debug)]
pub struct BoxedIo(Box<Io>);

impl BoxedIo {
pub fn new<T: Io + 'static>(io: T) -> Self {
BoxedIo(Box::new(io))
}

/// Since `Io` isn't publicly exported, but `Connection` wants
/// this method, it's just an inherent method.
pub fn shutdown_write(&mut self) -> Result<(), io::Error> {
self.0.shutdown_write()
}
}

impl io::Read for BoxedIo {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}

impl io::Write for BoxedIo {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}

fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}

impl AsyncRead for BoxedIo {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.0.prepare_uninitialized_buffer(buf)
}
}

impl AsyncWrite for BoxedIo {
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.0.shutdown()
}

fn write_buf<B: Buf>(&mut self, mut buf: &mut B) -> Poll<usize, io::Error> {
// A trait object of AsyncWrite would use the default write_buf,
// which doesn't allow vectored writes. Going through this method
// allows the trait object to call the specialized write_buf method.
self.0.write_buf_erased(&mut buf)
}
}

impl AddrInfo for BoxedIo {
fn local_addr(&self) -> Result<SocketAddr, io::Error> {
self.0.local_addr()
}

fn get_original_dst(&self) -> Option<SocketAddr> {
self.0.get_original_dst()
}
}

pub(super) mod internal {
use std::io;
use tokio::net::TcpStream;
use super::{AddrInfo, AsyncRead, AsyncWrite, Buf, Poll, Shutdown};

/// This trait is private, since it's purpose is for creating a dynamic
/// trait object, but doing so without care can lead not getting vectored
/// writes.
///
/// Instead, used the concrete `BoxedIo` type.
pub trait Io: AddrInfo + AsyncRead + AsyncWrite + Send {
fn shutdown_write(&mut self) -> Result<(), io::Error>;

/// This method is to allow using `Async::write_buf` even through a
/// trait object.
fn write_buf_erased(&mut self, buf: &mut Buf) -> Poll<usize, io::Error>;
}

impl Io for TcpStream {
fn shutdown_write(&mut self) -> Result<(), io::Error> {
TcpStream::shutdown(self, Shutdown::Write)
}

fn write_buf_erased(&mut self, mut buf: &mut Buf) -> Poll<usize, io::Error> {
self.write_buf(&mut buf)
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[derive(Debug)]
struct WriteBufDetector;

impl io::Read for WriteBufDetector {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
unimplemented!()
}
}

impl io::Write for WriteBufDetector {
fn write(&mut self, _: &[u8]) -> io::Result<usize> {
panic!("BoxedIo called wrong write_buf method");
}
fn flush(&mut self) -> io::Result<()> {
unimplemented!()
}
}

impl AsyncRead for WriteBufDetector {}

impl AsyncWrite for WriteBufDetector {
fn shutdown(&mut self) -> Poll<(), io::Error> {
unimplemented!()
}

fn write_buf<B: Buf>(&mut self, _: &mut B) -> Poll<usize, io::Error> {
Ok(0.into())
}
}

impl AddrInfo for WriteBufDetector {
fn local_addr(&self) -> Result<SocketAddr, io::Error> {
unimplemented!()
}

fn get_original_dst(&self) -> Option<SocketAddr> {
unimplemented!()
}
}

impl Io for WriteBufDetector {
fn shutdown_write(&mut self) -> Result<(), io::Error> {
unimplemented!()
}

fn write_buf_erased(&mut self, mut buf: &mut Buf) -> Poll<usize, io::Error> {
self.write_buf(&mut buf)
}
}


#[test]
fn boxed_io_uses_vectored_io() {
use bytes::IntoBuf;
let mut io = BoxedIo::new(WriteBufDetector);

// This method will trigger the panic in WriteBufDetector::write IFF
// BoxedIo doesn't call write_buf_erased, but write_buf, and triggering
// a regular write.
io.write_buf(&mut "hello".into_buf()).expect("write_buf");
}
}
16 changes: 2 additions & 14 deletions proxy/src/transport/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
use std::io;
use std::net::Shutdown;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;

mod connect;
mod addr_info;
mod io;
pub mod tls;

pub use self::connect::{
Expand All @@ -13,13 +9,5 @@ pub use self::connect::{
LookupAddressAndConnect,
};
pub use self::addr_info::{AddrInfo, GetOriginalDst, SoOriginalDst};
pub use self::io::BoxedIo;

pub trait Io: AddrInfo + AsyncRead + AsyncWrite + Send {
fn shutdown_write(&mut self) -> Result<(), io::Error>;
}

impl Io for TcpStream {
fn shutdown_write(&mut self) -> Result<(), io::Error> {
TcpStream::shutdown(self, Shutdown::Write)
}
}
6 changes: 5 additions & 1 deletion proxy/src/transport/tls/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use futures::Future;
use tokio::prelude::*;
use tokio::net::TcpStream;

use transport::{AddrInfo, Io};
use transport::{AddrInfo, io::internal::Io};

use super::{
identity::Identity,
Expand Down Expand Up @@ -110,4 +110,8 @@ impl<S: Session + Debug> Io for Connection<S> {
fn shutdown_write(&mut self) -> Result<(), io::Error> {
self.0.get_mut().0.shutdown_write()
}

fn write_buf_erased(&mut self, mut buf: &mut Buf) -> Poll<usize, io::Error> {
self.0.write_buf(&mut buf)
}
}

0 comments on commit 8dcb95d

Please sign in to comment.