diff --git a/Cargo.toml b/Cargo.toml index 05975fc1ba..c9fbdcdb10 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,6 +69,11 @@ runtime = [ "tokio-net", "tokio-timer", ] + +# unstable features +stream = [] + +# internal features used in CI nightly = [] __internal_flaky_tests = [] __internal_happy_eyeballs_tests = [] diff --git a/src/server/accept.rs b/src/server/accept.rs new file mode 100644 index 0000000000..9e1de03c78 --- /dev/null +++ b/src/server/accept.rs @@ -0,0 +1,99 @@ +//! The `Accept` trait and supporting types. +//! +//! This module contains: +//! +//! - The [`Accept`](Accept) trait used to asynchronously accept incoming +//! connections. +//! - Utilities like `poll_fn` to ease creating a custom `Accept`. + +#[cfg(feature = "stream")] +use futures_core::Stream; + +use crate::common::{Pin, task::{self, Poll}}; + +/// Asynchronously accept incoming connections. +pub trait Accept { + /// The connection type that can be accepted. + type Conn; + /// The error type that can occur when accepting a connection. + type Error; + + /// Poll to accept the next connection. + fn poll_accept(self: Pin<&mut Self>, cx: &mut task::Context<'_>) + -> Poll>>; +} + +/// Create an `Accept` with a polling function. +/// +/// # Example +/// +/// ``` +/// use std::task::Poll; +/// use hyper::server::{accept, Server}; +/// +/// # let mock_conn = (); +/// // If we created some mocked connection... +/// let mut conn = Some(mock_conn); +/// +/// // And accept just the mocked conn once... +/// let once = accept::poll_fn(move |cx| { +/// Poll::Ready(conn.take().map(Ok::<_, ()>)) +/// }); +/// +/// let builder = Server::builder(once); +/// ``` +pub fn poll_fn(func: F) -> impl Accept +where + F: FnMut(&mut task::Context<'_>) -> Poll>>, +{ + struct PollFn(F); + + impl Accept for PollFn + where + F: FnMut(&mut task::Context<'_>) -> Poll>>, + { + type Conn = IO; + type Error = E; + fn poll_accept(self: Pin<&mut Self>, cx: &mut task::Context<'_>) + -> Poll>> + { + unsafe { + (self.get_unchecked_mut().0)(cx) + } + } + } + + PollFn(func) +} + +/// Adapt a `Stream` of incoming connections into an `Accept`. +/// +/// # Unstable +/// +/// This function requires enabling the unstable `stream` feature in your +/// `Cargo.toml`. +#[cfg(feature = "stream")] +pub fn from_stream(stream: S) -> impl Accept +where + S: Stream>, +{ + struct FromStream(S); + + impl Accept for FromStream + where + S: Stream>, + { + type Conn = IO; + type Error = E; + fn poll_accept(self: Pin<&mut Self>, cx: &mut task::Context<'_>) + -> Poll>> + { + unsafe { + Pin::new_unchecked(&mut self.get_unchecked_mut().0) + .poll_next(cx) + } + } + } + + FromStream(stream) +} diff --git a/src/server/conn.rs b/src/server/conn.rs index d0789d0e4f..7a707860e5 100644 --- a/src/server/conn.rs +++ b/src/server/conn.rs @@ -28,6 +28,7 @@ use crate::error::{Kind, Parse}; use crate::proto; use crate::service::{MakeServiceRef, Service}; use crate::upgrade::Upgraded; +use super::Accept; pub(super) use self::spawn_all::NoopWatcher; use self::spawn_all::NewSvcTask; @@ -403,13 +404,10 @@ impl Http { } } - /// Bind the provided `addr` with the default `Handle` and return [`Serve`](Serve). - /// - /// This method will bind the `addr` provided with a new TCP listener ready - /// to accept connections. Each connection will be processed with the - /// `make_service` object provided, creating a new service per - /// connection. #[cfg(feature = "runtime")] + #[doc(hidden)] + #[deprecated] + #[allow(deprecated)] pub fn serve_addr(&self, addr: &SocketAddr, make_service: S) -> crate::Result> where S: MakeServiceRef< @@ -428,13 +426,10 @@ impl Http { Ok(self.serve_incoming(incoming, make_service)) } - /// Bind the provided `addr` with the `Handle` and return a [`Serve`](Serve) - /// - /// This method will bind the `addr` provided with a new TCP listener ready - /// to accept connections. Each connection will be processed with the - /// `make_service` object provided, creating a new service per - /// connection. #[cfg(feature = "runtime")] + #[doc(hidden)] + #[deprecated] + #[allow(deprecated)] pub fn serve_addr_handle(&self, addr: &SocketAddr, handle: &Handle, make_service: S) -> crate::Result> where S: MakeServiceRef< @@ -453,10 +448,11 @@ impl Http { Ok(self.serve_incoming(incoming, make_service)) } - /// Bind the provided stream of incoming IO objects with a `MakeService`. + #[doc(hidden)] + #[deprecated] pub fn serve_incoming(&self, incoming: I, make_service: S) -> Serve where - I: Stream>, + I: Accept, IE: Into>, IO: AsyncRead + AsyncWrite + Unpin, S: MakeServiceRef< @@ -678,13 +674,6 @@ where // ===== impl Serve ===== impl Serve { - /// Spawn all incoming connections onto the executor in `Http`. - pub(super) fn spawn_all(self) -> SpawnAll { - SpawnAll { - serve: self, - } - } - /// Get a reference to the incoming stream. #[inline] pub fn incoming_ref(&self) -> &I { @@ -696,22 +685,28 @@ impl Serve { pub fn incoming_mut(&mut self) -> &mut I { &mut self.incoming } + + /// Spawn all incoming connections onto the executor in `Http`. + pub(super) fn spawn_all(self) -> SpawnAll { + SpawnAll { + serve: self, + } + } } -impl Stream for Serve + + + +impl Serve where - I: Stream>, + I: Accept, IO: AsyncRead + AsyncWrite + Unpin, IE: Into>, S: MakeServiceRef, - //S::Error2: Into>, - //SME: Into>, B: Payload, E: H2Exec<>::Future, B>, { - type Item = crate::Result>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + fn poll_next_(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll>>> { match ready!(self.project().make_service.poll_ready_ref(cx)) { Ok(()) => (), Err(e) => { @@ -720,7 +715,7 @@ where } } - if let Some(item) = ready!(self.project().incoming.poll_next(cx)) { + if let Some(item) = ready!(self.project().incoming.poll_accept(cx)) { let io = item.map_err(crate::Error::new_accept)?; let new_fut = self.project().make_service.make_service_ref(&io); Poll::Ready(Some(Ok(Connecting { @@ -734,6 +729,23 @@ where } } +// deprecated +impl Stream for Serve +where + I: Accept, + IO: AsyncRead + AsyncWrite + Unpin, + IE: Into>, + S: MakeServiceRef, + B: Payload, + E: H2Exec<>::Future, B>, +{ + type Item = crate::Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + self.poll_next_(cx) + } +} + // ===== impl Connecting ===== @@ -772,7 +784,7 @@ impl SpawnAll { impl SpawnAll where - I: Stream>, + I: Accept, IE: Into>, IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: MakeServiceRef< @@ -789,7 +801,7 @@ where W: Watcher, { loop { - if let Some(connecting) = ready!(self.project().serve.poll_next(cx)?) { + if let Some(connecting) = ready!(self.project().serve.poll_next_(cx)?) { let fut = NewSvcTask::new(connecting, watcher.clone()); self.project().serve.project().protocol.exec.execute_new_svc(fut)?; } else { diff --git a/src/server/mod.rs b/src/server/mod.rs index ef306afa96..152f6733a6 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -48,6 +48,7 @@ //! # fn main() {} //! ``` +pub mod accept; pub mod conn; mod shutdown; #[cfg(feature = "runtime")] mod tcp; @@ -58,7 +59,6 @@ use std::fmt; #[cfg(feature = "runtime")] use std::time::Duration; -use futures_core::Stream; use tokio_io::{AsyncRead, AsyncWrite}; use pin_project::pin_project; @@ -66,6 +66,7 @@ use crate::body::{Body, Payload}; use crate::common::exec::{Exec, H2Exec, NewSvcExec}; use crate::common::{Future, Pin, Poll, Unpin, task}; use crate::service::{MakeServiceRef, Service}; +use self::accept::Accept; // Renamed `Http` as `Http_` for now so that people upgrading don't see an // error that `hyper::server::Http` is private... use self::conn::{Http as Http_, NoopWatcher, SpawnAll}; @@ -143,7 +144,7 @@ impl Server { impl Server where - I: Stream>, + I: Accept, IE: Into>, IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: MakeServiceRef, @@ -200,7 +201,7 @@ where impl Future for Server where - I: Stream>, + I: Accept, IE: Into>, IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: MakeServiceRef, @@ -380,17 +381,17 @@ impl Builder { /// // Finally, spawn `server` onto an Executor... /// # } /// ``` - pub fn serve(self, new_service: S) -> Server + pub fn serve(self, new_service: S) -> Server where - I: Stream>, - IE: Into>, - IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, - S: MakeServiceRef, + I: Accept, + I::Error: Into>, + I::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: MakeServiceRef, S::Error: Into>, S::Service: 'static, B: Payload, B::Data: Unpin, - E: NewSvcExec, + E: NewSvcExec, E: H2Exec<>::Future, B>, { let serve = self.protocol.serve_incoming(self.incoming, new_service); diff --git a/src/server/shutdown.rs b/src/server/shutdown.rs index 3080ae9b83..4591a1a0db 100644 --- a/src/server/shutdown.rs +++ b/src/server/shutdown.rs @@ -1,6 +1,5 @@ use std::error::Error as StdError; -use futures_core::Stream; use tokio_io::{AsyncRead, AsyncWrite}; use pin_project::{pin_project, project}; @@ -9,6 +8,7 @@ use crate::common::drain::{self, Draining, Signal, Watch, Watching}; use crate::common::exec::{H2Exec, NewSvcExec}; use crate::common::{Future, Pin, Poll, Unpin, task}; use crate::service::{MakeServiceRef, Service}; +use super::Accept; use super::conn::{SpawnAll, UpgradeableConnection, Watcher}; #[allow(missing_debug_implementations)] @@ -46,7 +46,7 @@ impl Graceful { impl Future for Graceful where - I: Stream>, + I: Accept, IE: Into>, IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: MakeServiceRef, diff --git a/src/server/tcp.rs b/src/server/tcp.rs index 47e064d746..17e9a14a3e 100644 --- a/src/server/tcp.rs +++ b/src/server/tcp.rs @@ -3,7 +3,6 @@ use std::io; use std::net::{SocketAddr, TcpListener as StdTcpListener}; use std::time::Duration; -use futures_core::Stream; use futures_util::FutureExt as _; use tokio_net::driver::Handle; use tokio_net::tcp::TcpListener; @@ -11,6 +10,7 @@ use tokio_timer::Delay; use crate::common::{Future, Pin, Poll, task}; +use super::Accept; pub use self::addr_stream::AddrStream; /// A stream of connections from binding to an address. @@ -156,6 +156,7 @@ impl AddrIncoming { } } +/* impl Stream for AddrIncoming { type Item = io::Result; @@ -164,6 +165,17 @@ impl Stream for AddrIncoming { Poll::Ready(Some(result)) } } +*/ + +impl Accept for AddrIncoming { + type Conn = AddrStream; + type Error = io::Error; + + fn poll_accept(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll>> { + let result = ready!(self.poll_next_(cx)); + Poll::Ready(Some(result)) + } +} /// This function defines errors that are per-connection. Which basically /// means that if we get this error from `accept()` system call it means diff --git a/tests/support/mod.rs b/tests/support/mod.rs index 601e699e95..858802f86b 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -322,7 +322,18 @@ pub fn __run_test(cfg: __TestConfig) { let serve_handles = Arc::new(Mutex::new( cfg.server_msgs )); + + let expected_connections = cfg.connections; + let mut cnt = 0; let new_service = make_service_fn(move |_| { + cnt += 1; + assert!( + cnt <= expected_connections, + "server expected {} connections, received {}", + expected_connections, + cnt + ); + // Move a clone into the service_fn let serve_handles = serve_handles.clone(); future::ok::<_, hyper::Error>(service_fn(move |req: Request| { @@ -352,36 +363,15 @@ pub fn __run_test(cfg: __TestConfig) { })) }); - let serve = hyper::server::conn::Http::new() + let server = hyper::Server::bind(&SocketAddr::from(([127, 0, 0, 1], 0))) .http2_only(cfg.server_version == 2) - .serve_addr( - &SocketAddr::from(([127, 0, 0, 1], 0)), - new_service, - ) - .expect("serve_addr"); + .serve(new_service); - let mut addr = serve.incoming_ref().local_addr(); - let expected_connections = cfg.connections; - let server = serve - .try_fold(0, move |cnt, connecting| { - let cnt = cnt + 1; - assert!( - cnt <= expected_connections, - "server expected {} connections, received {}", - expected_connections, - cnt - ); - let fut = connecting - .then(|res| res.expect("connecting")) - .map(|conn_res| conn_res.expect("server connection error")); - tokio::spawn(fut); - future::ok::<_, hyper::Error>(cnt) - }) - .map(|res| { - let _ = res.expect("serve error"); - }); + let mut addr = server.local_addr(); - rt.spawn(server); + rt.spawn(server.map(|result| { + let _ = result.expect("server error"); + })); if cfg.proxy { let (proxy_addr, proxy) = naive_proxy(ProxyConfig { @@ -393,7 +383,6 @@ pub fn __run_test(cfg: __TestConfig) { addr = proxy_addr; } - let make_request = Arc::new(move |client: &Client, creq: __CReq, cres: __CRes| { let uri = format!("http://{}{}", addr, creq.uri); let mut req = Request::builder()