From d280012b7a430ab91fddc4dc382e81e6eca548e0 Mon Sep 17 00:00:00 2001 From: jerryz920 Date: Wed, 8 May 2019 21:00:35 -0500 Subject: [PATCH] #33 don't do read if EOF is hit (#36) --- src/client.rs | 13 ++++++++----- src/common/mod.rs | 34 +++++++++++++++++++++++++--------- src/common/test_stream.rs | 14 +++++++++++++- src/lib.rs | 7 +++++++ src/server.rs | 16 ++++++++++------ 5 files changed, 63 insertions(+), 21 deletions(-) diff --git a/src/client.rs b/src/client.rs index 616c151..4ba6950 100644 --- a/src/client.rs +++ b/src/client.rs @@ -48,8 +48,9 @@ where #[inline] fn poll(&mut self) -> Poll { if let MidHandshake::Handshaking(stream) = self { + let state = stream.state; let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session); + let mut stream = Stream::new(io, session).set_eof(!state.readable()); if stream.session.is_handshaking() { try_nb!(stream.complete_io()); @@ -102,7 +103,7 @@ where self.read(buf) } TlsState::Stream | TlsState::WriteShutdown => { - let mut stream = Stream::new(&mut self.io, &mut self.session); + let mut stream = Stream::new(&mut self.io, &mut self.session).set_eof(!self.state.readable()); match stream.read(buf) { Ok(0) => { @@ -131,7 +132,7 @@ where IO: AsyncRead + AsyncWrite, { fn write(&mut self, buf: &[u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); + let mut stream = Stream::new(&mut self.io, &mut self.session).set_eof(!self.state.readable()); match self.state { #[cfg(feature = "early-data")] @@ -168,7 +169,9 @@ where } fn flush(&mut self) -> io::Result<()> { - Stream::new(&mut self.io, &mut self.session).flush()?; + Stream::new(&mut self.io, &mut self.session) + .set_eof(!self.state.readable()) + .flush()?; self.io.flush() } } @@ -192,7 +195,7 @@ where self.state.shutdown_write(); } - let mut stream = Stream::new(&mut self.io, &mut self.session); + let mut stream = Stream::new(&mut self.io, &mut self.session).set_eof(!self.state.readable()); try_nb!(stream.flush()); stream.io.shutdown() } diff --git a/src/common/mod.rs b/src/common/mod.rs index 14d2f71..99ef14b 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -5,10 +5,10 @@ use rustls::Session; use rustls::WriteV; use tokio_io::{ AsyncRead, AsyncWrite }; - pub struct Stream<'a, IO: 'a, S: 'a> { pub io: &'a mut IO, - pub session: &'a mut S + pub session: &'a mut S, + pub eof: bool, } pub trait WriteTls<'a, IO: AsyncRead + AsyncWrite, S: Session>: Read + Write { @@ -24,7 +24,18 @@ enum Focus { impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { - Stream { io, session } + Stream { + io, + session, + // The state so far is only used to detect EOF, so either Stream + // or EarlyData state should both be all right. + eof: false, + } + } + + pub fn set_eof(mut self, eof: bool) -> Self { + self.eof = eof; + self } pub fn complete_io(&mut self) -> io::Result<(usize, usize)> { @@ -54,7 +65,6 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { fn complete_inner_io(&mut self, focus: Focus) -> io::Result<(usize, usize)> { let mut wrlen = 0; let mut rdlen = 0; - let mut eof = false; loop { let mut write_would_block = false; @@ -71,12 +81,14 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { } } - if !eof && self.session.wants_read() { + if !self.eof && self.session.wants_read() { match self.complete_read_io() { - Ok(0) => eof = true, + Ok(0) => self.eof = true, Ok(n) => rdlen += n, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => read_would_block = true, - Err(err) => return Err(err) + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + read_would_block = true + } + Err(err) => return Err(err), } } @@ -86,7 +98,11 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { Focus::Writable => write_would_block, }; - match (eof, self.session.is_handshaking(), would_block) { + match ( + self.eof, + self.session.is_handshaking(), + would_block, + ) { (true, true, _) => return Err(io::ErrorKind::UnexpectedEof.into()), (_, false, true) => { let would_block = match focus { diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 744758a..6c2cb3c 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -11,7 +11,6 @@ use futures::{ Async, Poll }; use tokio_io::{ AsyncRead, AsyncWrite }; use super::Stream; - struct Good<'a>(&'a mut Session); impl<'a> Read for Good<'a> { @@ -149,6 +148,19 @@ fn stream_handshake_eof() -> io::Result<()> { Ok(()) } +#[test] +fn stream_eof() -> io::Result<()> { + let (mut server, mut client) = make_pair(); + do_handshake(&mut client, &mut server); + { + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut good, &mut client).set_eof(true); + let (r, _) = stream.complete_io()?; + assert!(r == 0); + } + Ok(()) +} + fn make_pair() -> (ServerSession, ClientSession) { const CERT: &str = include_str!("../../tests/end.cert"); const CHAIN: &str = include_str!("../../tests/end.chain"); diff --git a/src/lib.rs b/src/lib.rs index 04d7421..5193576 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,6 +51,13 @@ impl TlsState { _ => true, } } + + pub(crate) fn readable(self) -> bool { + match self { + TlsState::ReadShutdown | TlsState::FullyShutdown => false, + _ => true, + } + } } /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. diff --git a/src/server.rs b/src/server.rs index 1568414..91b9604 100644 --- a/src/server.rs +++ b/src/server.rs @@ -42,8 +42,9 @@ where #[inline] fn poll(&mut self) -> Poll { if let MidHandshake::Handshaking(stream) = self { + let state = stream.state; let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session); + let mut stream = Stream::new(io, session).set_eof(!state.readable()); if stream.session.is_handshaking() { try_nb!(stream.complete_io()); @@ -66,7 +67,7 @@ where IO: AsyncRead + AsyncWrite, { fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); + let mut stream = Stream::new(&mut self.io, &mut self.session).set_eof(!self.state.readable()); match self.state { TlsState::Stream | TlsState::WriteShutdown => match stream.read(buf) { @@ -97,12 +98,15 @@ where IO: AsyncRead + AsyncWrite, { fn write(&mut self, buf: &[u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); - stream.write(buf) + Stream::new(&mut self.io, &mut self.session) + .set_eof(!self.state.readable()) + .write(buf) } fn flush(&mut self) -> io::Result<()> { - Stream::new(&mut self.io, &mut self.session).flush()?; + Stream::new(&mut self.io, &mut self.session) + .set_eof(!self.state.readable()) + .flush()?; self.io.flush() } } @@ -126,7 +130,7 @@ where self.state.shutdown_write(); } - let mut stream = Stream::new(&mut self.io, &mut self.session); + let mut stream = Stream::new(&mut self.io, &mut self.session).set_eof(!self.state.readable()); try_nb!(stream.complete_io()); stream.io.shutdown() }