Skip to content

Commit

Permalink
#33 don't do read if EOF is hit (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaiyan920 authored and quininer committed May 9, 2019
1 parent 00f1022 commit d280012
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 21 deletions.
13 changes: 8 additions & 5 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ where
#[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let MidHandshake::Handshaking(stream) = self {
let state = stream.state;
let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session);
let mut stream = Stream::new(io, session).set_eof(!state.readable());

if stream.session.is_handshaking() {
try_nb!(stream.complete_io());
Expand Down Expand Up @@ -102,7 +103,7 @@ where
self.read(buf)
}
TlsState::Stream | TlsState::WriteShutdown => {
let mut stream = Stream::new(&mut self.io, &mut self.session);
let mut stream = Stream::new(&mut self.io, &mut self.session).set_eof(!self.state.readable());

match stream.read(buf) {
Ok(0) => {
Expand Down Expand Up @@ -131,7 +132,7 @@ where
IO: AsyncRead + AsyncWrite,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
let mut stream = Stream::new(&mut self.io, &mut self.session).set_eof(!self.state.readable());

match self.state {
#[cfg(feature = "early-data")]
Expand Down Expand Up @@ -168,7 +169,9 @@ where
}

fn flush(&mut self) -> io::Result<()> {
Stream::new(&mut self.io, &mut self.session).flush()?;
Stream::new(&mut self.io, &mut self.session)
.set_eof(!self.state.readable())
.flush()?;
self.io.flush()
}
}
Expand All @@ -192,7 +195,7 @@ where
self.state.shutdown_write();
}

let mut stream = Stream::new(&mut self.io, &mut self.session);
let mut stream = Stream::new(&mut self.io, &mut self.session).set_eof(!self.state.readable());
try_nb!(stream.flush());
stream.io.shutdown()
}
Expand Down
34 changes: 25 additions & 9 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use rustls::Session;
use rustls::WriteV;
use tokio_io::{ AsyncRead, AsyncWrite };


pub struct Stream<'a, IO: 'a, S: 'a> {
pub io: &'a mut IO,
pub session: &'a mut S
pub session: &'a mut S,
pub eof: bool,
}

pub trait WriteTls<'a, IO: AsyncRead + AsyncWrite, S: Session>: Read + Write {
Expand All @@ -24,7 +24,18 @@ enum Focus {

impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> {
pub fn new(io: &'a mut IO, session: &'a mut S) -> Self {
Stream { io, session }
Stream {
io,
session,
// The state so far is only used to detect EOF, so either Stream
// or EarlyData state should both be all right.
eof: false,
}
}

pub fn set_eof(mut self, eof: bool) -> Self {
self.eof = eof;
self
}

pub fn complete_io(&mut self) -> io::Result<(usize, usize)> {
Expand Down Expand Up @@ -54,7 +65,6 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> {
fn complete_inner_io(&mut self, focus: Focus) -> io::Result<(usize, usize)> {
let mut wrlen = 0;
let mut rdlen = 0;
let mut eof = false;

loop {
let mut write_would_block = false;
Expand All @@ -71,12 +81,14 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> {
}
}

if !eof && self.session.wants_read() {
if !self.eof && self.session.wants_read() {
match self.complete_read_io() {
Ok(0) => eof = true,
Ok(0) => self.eof = true,
Ok(n) => rdlen += n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => read_would_block = true,
Err(err) => return Err(err)
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
read_would_block = true
}
Err(err) => return Err(err),
}
}

Expand All @@ -86,7 +98,11 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> {
Focus::Writable => write_would_block,
};

match (eof, self.session.is_handshaking(), would_block) {
match (
self.eof,
self.session.is_handshaking(),
would_block,
) {
(true, true, _) => return Err(io::ErrorKind::UnexpectedEof.into()),
(_, false, true) => {
let would_block = match focus {
Expand Down
14 changes: 13 additions & 1 deletion src/common/test_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use futures::{ Async, Poll };
use tokio_io::{ AsyncRead, AsyncWrite };
use super::Stream;


struct Good<'a>(&'a mut Session);

impl<'a> Read for Good<'a> {
Expand Down Expand Up @@ -149,6 +148,19 @@ fn stream_handshake_eof() -> io::Result<()> {
Ok(())
}

#[test]
fn stream_eof() -> io::Result<()> {
let (mut server, mut client) = make_pair();
do_handshake(&mut client, &mut server);
{
let mut good = Good(&mut server);
let mut stream = Stream::new(&mut good, &mut client).set_eof(true);
let (r, _) = stream.complete_io()?;
assert!(r == 0);
}
Ok(())
}

fn make_pair() -> (ServerSession, ClientSession) {
const CERT: &str = include_str!("../../tests/end.cert");
const CHAIN: &str = include_str!("../../tests/end.chain");
Expand Down
7 changes: 7 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ impl TlsState {
_ => true,
}
}

pub(crate) fn readable(self) -> bool {
match self {
TlsState::ReadShutdown | TlsState::FullyShutdown => false,
_ => true,
}
}
}

/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
Expand Down
16 changes: 10 additions & 6 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ where
#[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let MidHandshake::Handshaking(stream) = self {
let state = stream.state;
let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session);
let mut stream = Stream::new(io, session).set_eof(!state.readable());

if stream.session.is_handshaking() {
try_nb!(stream.complete_io());
Expand All @@ -66,7 +67,7 @@ where
IO: AsyncRead + AsyncWrite,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
let mut stream = Stream::new(&mut self.io, &mut self.session).set_eof(!self.state.readable());

match self.state {
TlsState::Stream | TlsState::WriteShutdown => match stream.read(buf) {
Expand Down Expand Up @@ -97,12 +98,15 @@ where
IO: AsyncRead + AsyncWrite,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
stream.write(buf)
Stream::new(&mut self.io, &mut self.session)
.set_eof(!self.state.readable())
.write(buf)
}

fn flush(&mut self) -> io::Result<()> {
Stream::new(&mut self.io, &mut self.session).flush()?;
Stream::new(&mut self.io, &mut self.session)
.set_eof(!self.state.readable())
.flush()?;
self.io.flush()
}
}
Expand All @@ -126,7 +130,7 @@ where
self.state.shutdown_write();
}

let mut stream = Stream::new(&mut self.io, &mut self.session);
let mut stream = Stream::new(&mut self.io, &mut self.session).set_eof(!self.state.readable());
try_nb!(stream.complete_io());
stream.io.shutdown()
}
Expand Down

0 comments on commit d280012

Please sign in to comment.