From ed3e7e95a5401b9b224640e17908c2182286197d Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Fri, 4 Oct 2019 19:17:41 -0400 Subject: [PATCH] fix(codec): Fix buffer decode panic on full (#43) * fix(codec): Fix buffer decode panic on full This is a naive fix for the buffer growing beyond capacity and producing a panic. Ideally we should do a better job of not having to allocate for new messages by using a link list. * fmt --- tonic/src/codec/decode.rs | 16 +++++- tonic/src/codec/encode.rs | 4 +- tonic/src/codec/mod.rs | 3 + tonic/src/codec/prost.rs | 4 +- tonic/src/codec/tests.rs | 112 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 134 insertions(+), 5 deletions(-) create mode 100644 tonic/src/codec/tests.rs diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 6b657fd0b..413afdb45 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -12,6 +12,8 @@ use std::{ }; use tracing::{debug, trace}; +const BUFFER_SIZE: usize = 8 * 1024; + /// Streaming requests and responses. /// /// This will wrap some inner [`Body`] and [`Decoder`] and provide an interface @@ -70,6 +72,7 @@ impl Streaming { { Self::new(decoder, body, Direction::Request) } + fn new(decoder: D, body: B, direction: Direction) -> Self where B: Body + Send + 'static, @@ -82,8 +85,7 @@ impl Streaming { body: BoxBody::map_from(body), state: State::ReadHeader, direction, - // FIXME: update this with a reasonable size - buf: BytesMut::with_capacity(1024 * 1024), + buf: BytesMut::with_capacity(BUFFER_SIZE), trailers: None, } } @@ -234,6 +236,16 @@ impl Stream for Streaming { }; if let Some(data) = chunk { + if data.remaining() > self.buf.remaining_mut() { + let amt = if data.remaining() > BUFFER_SIZE { + data.remaining() + } else { + BUFFER_SIZE + }; + + self.buf.reserve(amt); + } + self.buf.put(data); } else { // FIXME: improve buf usage. diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index bfa269a2b..c157595f8 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -9,6 +9,8 @@ use std::pin::Pin; use std::task::{Context, Poll}; use tokio_codec::Encoder; +const BUFFER_SIZE: usize = 8 * 1024; + pub(crate) fn encode_server( encoder: T, source: U, @@ -39,7 +41,7 @@ where U: Stream>, { async_stream::stream! { - let mut buf = BytesMut::with_capacity(1024 * 1024); + let mut buf = BytesMut::with_capacity(BUFFER_SIZE); futures_util::pin_mut!(source); loop { diff --git a/tonic/src/codec/mod.rs b/tonic/src/codec/mod.rs index 5859223c8..c32c729af 100644 --- a/tonic/src/codec/mod.rs +++ b/tonic/src/codec/mod.rs @@ -8,6 +8,9 @@ mod encode; #[cfg(feature = "prost")] mod prost; +#[cfg(test)] +mod tests; + pub use self::decode::Streaming; pub(crate) use self::encode::{encode_client, encode_server}; #[cfg(feature = "prost")] diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs index 776c907d7..495dc11f5 100644 --- a/tonic/src/codec/prost.rs +++ b/tonic/src/codec/prost.rs @@ -40,7 +40,7 @@ where } /// A [`Encoder`] that knows how to encode `T`. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct ProstEncoder(PhantomData); impl Encoder for ProstEncoder { @@ -60,7 +60,7 @@ impl Encoder for ProstEncoder { } /// A [`Decoder`] that knows how to decode `U`. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct ProstDecoder(PhantomData); impl Decoder for ProstDecoder { diff --git a/tonic/src/codec/tests.rs b/tonic/src/codec/tests.rs new file mode 100644 index 000000000..a61d063ab --- /dev/null +++ b/tonic/src/codec/tests.rs @@ -0,0 +1,112 @@ +use super::{ + encode_server, + prost::{ProstDecoder, ProstEncoder}, + Streaming, +}; +use crate::Status; +use bytes::{Buf, BufMut, Bytes, BytesMut, IntoBuf}; +use http_body::Body; +use prost::Message; +use std::{ + io::Cursor, + pin::Pin, + task::{Context, Poll}, +}; + +#[derive(Clone, PartialEq, prost::Message)] +struct Msg { + #[prost(bytes, tag = "1")] + data: Vec, +} + +#[tokio::test] +async fn decode() { + let decoder = ProstDecoder::::default(); + + let data = Vec::from(&[0u8; 1024][..]); + let msg = Msg { data }; + + let mut buf = BytesMut::new(); + let len = msg.encoded_len(); + + buf.reserve(len + 5); + buf.put_u8(0); + buf.put_u32_be(len as u32); + msg.encode(&mut buf).unwrap(); + + let body = MockBody(buf.freeze(), 0, 100); + + let mut stream = Streaming::new_request(decoder, body); + + while let Some(_) = stream.message().await.unwrap() {} +} + +#[tokio::test] +async fn encode() { + let encoder = ProstEncoder::::default(); + + let data = Vec::from(&[0u8; 1024][..]); + let msg = Msg { data }; + + let messages = std::iter::repeat(Ok::<_, Status>(msg)).take(10000); + let source = futures_util::stream::iter(messages); + + let body = encode_server(encoder, source); + + futures_util::pin_mut!(body); + + while let Some(r) = body.next().await { + r.unwrap(); + } +} + +#[derive(Debug)] +struct MockBody(Bytes, usize, usize); + +impl Body for MockBody { + type Data = Data; + type Error = Status; + + fn poll_data( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.1 > self.2 { + self.1 += 1; + let data = Data(self.0.clone().into_buf()); + Poll::Ready(Some(Ok(data))) + } else { + Poll::Ready(None) + } + } + + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + drop(cx); + Poll::Ready(Ok(None)) + } +} + +struct Data(Cursor); + +impl Into for Data { + fn into(self) -> Bytes { + self.0.into_inner() + } +} + +impl Buf for Data { + fn remaining(&self) -> usize { + self.0.remaining() + } + + fn bytes(&self) -> &[u8] { + self.0.bytes() + } + + fn advance(&mut self, cnt: usize) { + self.0.advance(cnt) + } +}