Skip to content

Commit

Permalink
feat(server): add Http::max_buf_size() option
Browse files Browse the repository at this point in the history
The internal connection's read and write bufs will be restricted from
growing bigger than the configured `max_buf_size`.

Closes #1368
  • Loading branch information
seanmonstar committed Jan 24, 2018
1 parent 7cb72d2 commit d22deb6
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 17 deletions.
6 changes: 5 additions & 1 deletion src/proto/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ where I: AsyncRead + AsyncWrite,
self.io.set_flush_pipeline(enabled);
}

/// Caps how large this connection's internal read and write buffers
/// may grow; simply forwarded to the buffered I/O layer.
pub fn set_max_buf_size(&mut self, max: usize) {
self.io.set_max_buf_size(max);
}

#[cfg(feature = "tokio-proto")]
fn poll_incoming(&mut self) -> Poll<Option<Frame<super::MessageHead<T::Incoming>, super::Chunk, ::Error>>, io::Error> {
trace!("Conn::poll_incoming()");
Expand Down Expand Up @@ -1221,7 +1225,7 @@ mod tests {
let _: Result<(), ()> = future::lazy(|| {
let io = AsyncIo::new_buf(vec![], 0);
let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default());
let max = ::proto::io::MAX_BUFFER_SIZE + 4096;
let max = ::proto::io::DEFAULT_MAX_BUFFER_SIZE + 4096;
conn.state.writing = Writing::Body(Encoder::length((max * 2) as u64), None);

assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'a'; 1024 * 8].into()) }).unwrap().is_ready());
Expand Down
45 changes: 29 additions & 16 deletions src/proto/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ use super::{Http1Transaction, MessageHead};
use bytes::{BytesMut, Bytes};

// first allocation made for a buffer
const INIT_BUFFER_SIZE: usize = 8192;

/// Default ceiling for read/write buffer growth (8 KiB + 400 KiB),
/// used unless overridden via `set_max_buf_size`.
pub const DEFAULT_MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100;

pub struct Buffered<T> {
flush_pipeline: bool,
io: T,
max_buf_size: usize,
read_blocked: bool,
read_buf: BytesMut,
write_buf: WriteBuf,
Expand All @@ -34,6 +35,7 @@ impl<T: AsyncRead + AsyncWrite> Buffered<T> {
Buffered {
flush_pipeline: false,
io: io,
max_buf_size: DEFAULT_MAX_BUFFER_SIZE,
read_buf: BytesMut::with_capacity(0),
write_buf: WriteBuf::new(),
read_blocked: false,
Expand All @@ -44,14 +46,19 @@ impl<T: AsyncRead + AsyncWrite> Buffered<T> {
self.flush_pipeline = enabled;
}

/// Sets the growth cap for both directions: the read side is checked
/// against `self.max_buf_size` while parsing, the write side is
/// enforced inside `WriteBuf::maybe_reserve`.
pub fn set_max_buf_size(&mut self, max: usize) {
self.max_buf_size = max;
self.write_buf.max_buf_size = max;
}

/// Borrows the bytes currently buffered from the underlying transport.
pub fn read_buf(&self) -> &[u8] {
    &self.read_buf[..]
}

pub fn write_buf_mut(&mut self) -> &mut Vec<u8> {
self.write_buf.maybe_reset();
self.write_buf.maybe_reserve(0);
&mut self.write_buf.0.bytes
&mut self.write_buf.buf.bytes
}

pub fn consume_leading_lines(&mut self) {
Expand All @@ -75,8 +82,8 @@ impl<T: AsyncRead + AsyncWrite> Buffered<T> {
return Ok(Async::Ready(head))
},
None => {
if self.read_buf.capacity() >= MAX_BUFFER_SIZE {
debug!("MAX_BUFFER_SIZE reached, closing");
if self.read_buf.capacity() >= self.max_buf_size {
debug!("max_buf_size ({}) reached, closing", self.max_buf_size);
return Err(::Error::TooLarge);
}
},
Expand Down Expand Up @@ -259,22 +266,28 @@ impl<T: Write> AtomicWrite for T {

// an internal buffer to collect writes before flushes
#[derive(Debug)]
struct WriteBuf(Cursor<Vec<u8>>);
struct WriteBuf{
buf: Cursor<Vec<u8>>,
max_buf_size: usize,
}

impl WriteBuf {
/// Creates an empty write buffer starting at the default size cap;
/// the cap can be lowered/raised later via `set_max_buf_size`.
fn new() -> WriteBuf {
    WriteBuf {
        buf: Cursor::new(Vec::new()),
        max_buf_size: DEFAULT_MAX_BUFFER_SIZE,
    }
}

/// Drains as much buffered data into `w` as it will accept,
/// returning the number of bytes written.
fn write_into<W: Write>(&mut self, w: &mut W) -> io::Result<usize> {
    self.buf.write_to(w)
}

fn buffer(&mut self, data: &[u8]) -> usize {
trace!("WriteBuf::buffer() len = {:?}", data.len());
self.maybe_reset();
self.maybe_reserve(data.len());
let vec = &mut self.0.bytes;
let vec = &mut self.buf.bytes;
let len = cmp::min(vec.capacity() - vec.len(), data.len());
assert!(vec.capacity() - vec.len() >= len);
unsafe {
Expand All @@ -291,28 +304,28 @@ impl WriteBuf {
}

fn remaining(&self) -> usize {
self.0.remaining()
self.buf.remaining()
}

/// Grows the buffer toward `needed` extra bytes without ever letting
/// total capacity exceed `max_buf_size`; callers then append only
/// what fits (see `buffer`).
#[inline]
fn maybe_reserve(&mut self, needed: usize) {
    let vec = &mut self.buf.bytes;
    let cap = vec.capacity();
    if cap == 0 {
        // first use: allocate at least INIT_BUFFER_SIZE, capped
        let init = cmp::min(self.max_buf_size, cmp::max(INIT_BUFFER_SIZE, needed));
        trace!("WriteBuf reserving initial {}", init);
        vec.reserve(init);
    } else if cap < self.max_buf_size {
        // grow only up to the remaining headroom below the cap
        vec.reserve(cmp::min(needed, self.max_buf_size - cap));
        trace!("WriteBuf reserved {}", vec.capacity() - cap);
    }
}

fn maybe_reset(&mut self) {
if self.0.pos != 0 && self.0.remaining() == 0 {
self.0.pos = 0;
if self.buf.pos != 0 && self.buf.remaining() == 0 {
self.buf.pos = 0;
unsafe {
self.0.bytes.set_len(0);
self.buf.bytes.set_len(0);
}
}
}
Expand Down
12 changes: 12 additions & 0 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ pub use self::service::{const_service, service_fn};
/// which handle a connection to an HTTP server. Each instance of `Http` can be
/// configured with various protocol-level options such as keepalive.
pub struct Http<B = ::Chunk> {
max_buf_size: Option<usize>,
keep_alive: bool,
pipeline: bool,
_marker: PhantomData<B>,
Expand Down Expand Up @@ -129,6 +130,7 @@ impl<B: AsRef<[u8]> + 'static> Http<B> {
pub fn new() -> Http<B> {
Http {
keep_alive: true,
max_buf_size: None,
pipeline: false,
_marker: PhantomData,
}
Expand All @@ -142,6 +144,12 @@ impl<B: AsRef<[u8]> + 'static> Http<B> {
self
}

/// Set the maximum buffer size for the connection.
///
/// The connection's internal read and write buffers will be
/// restricted from growing larger than `max` bytes; when unset,
/// an internal default cap applies.
pub fn max_buf_size(&mut self, max: usize) -> &mut Self {
self.max_buf_size = Some(max);
self
}

/// Aggregates flushes to better support pipelined responses.
///
/// Experimental, may have bugs.
Expand Down Expand Up @@ -226,6 +234,7 @@ impl<B: AsRef<[u8]> + 'static> Http<B> {
new_service: new_service,
protocol: Http {
keep_alive: self.keep_alive,
max_buf_size: self.max_buf_size,
pipeline: self.pipeline,
_marker: PhantomData,
},
Expand All @@ -250,6 +259,9 @@ impl<B: AsRef<[u8]> + 'static> Http<B> {
};
let mut conn = proto::Conn::new(io, ka);
conn.set_flush_pipeline(self.pipeline);
if let Some(max) = self.max_buf_size {
conn.set_max_buf_size(max);
}
Connection {
conn: proto::dispatch::Dispatcher::new(proto::dispatch::Server::new(service), conn),
}
Expand Down
3 changes: 3 additions & 0 deletions src/server/server_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ impl<T, B> ServerProto<T> for Http<B>
};
let mut conn = proto::Conn::new(io, ka);
conn.set_flush_pipeline(self.pipeline);
if let Some(max) = self.max_buf_size {
conn.set_max_buf_size(max);
}
__ProtoBindTransport {
inner: future::ok(conn),
}
Expand Down
35 changes: 35 additions & 0 deletions tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,41 @@ fn illegal_request_length_returns_400_response() {
core.run(fut).unwrap_err();
}

// A request whose buffered head exceeds the configured `max_buf_size`
// must be rejected with a 400 response and a serving error.
#[test]
fn max_buf_size() {
let _ = pretty_env_logger::try_init();
let mut core = Core::new().unwrap();
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
let addr = listener.local_addr().unwrap();

// the cap under test; the client's request line alone exceeds it
const MAX: usize = 16_000;

// client thread: send a request whose URI is MAX bytes of 'a'
thread::spawn(move || {
let mut tcp = connect(&addr);
tcp.write_all(b"POST /").expect("write 1");
tcp.write_all(&vec![b'a'; MAX]).expect("write 2");
tcp.write_all(b" HTTP/1.1\r\n\r\n").expect("write 3");
let mut buf = [0; 256];
tcp.read(&mut buf).expect("read 1");

// the server should refuse the oversized head with a 400
let expected = "HTTP/1.1 400 ";
assert_eq!(s(&buf[..expected.len()]), expected);
});

// server side: accept one connection served with the lowered cap
let fut = listener.incoming()
.into_future()
.map_err(|_| unreachable!())
.and_then(|(item, _incoming)| {
let (socket, _) = item.unwrap();
Http::<hyper::Chunk>::new()
.max_buf_size(MAX)
.serve_connection(socket, HelloWorld)
.map(|_| ())
});

// serving fails once the cap is hit, hence the expected error
core.run(fut).unwrap_err();
}

#[test]
fn remote_addr() {
let server = serve();
Expand Down

0 comments on commit d22deb6

Please sign in to comment.