From b90c3408001f762a32409f7e2cf688ebae39d89e Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Fri, 13 Dec 2019 17:14:49 -0500 Subject: [PATCH] feat(transport): Allow custom IO and UDS example (#184) Closes #136 --- examples/Cargo.toml | 10 ++- examples/src/uds/client.rs | 39 ++++++++++ examples/src/uds/server.rs | 48 ++++++++++++ tonic/src/transport/channel/endpoint.rs | 32 +++++++- tonic/src/transport/channel/mod.rs | 12 ++- tonic/src/transport/server/incoming.rs | 75 ++++++++++++++++++ tonic/src/transport/server/mod.rs | 93 +++++++++-------------- tonic/src/transport/service/connection.rs | 23 +++--- tonic/src/transport/service/connector.rs | 60 ++++++--------- tonic/src/transport/service/discover.rs | 6 +- tonic/src/transport/service/io.rs | 11 ++- tonic/src/transport/service/tls.rs | 12 ++- 12 files changed, 306 insertions(+), 115 deletions(-) create mode 100644 examples/src/uds/client.rs create mode 100644 examples/src/uds/server.rs create mode 100644 tonic/src/transport/server/incoming.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 282a0a46a..992bd5d01 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -78,12 +78,20 @@ path = "src/tracing/client.rs" name = "tracing-server" path = "src/tracing/server.rs" +[[bin]] +name = "uds-client" +path = "src/uds/client.rs" + +[[bin]] +name = "uds-server" +path = "src/uds/server.rs" + [dependencies] tonic = { path = "../tonic", features = ["tls"] } bytes = "0.4" prost = "0.5" -tokio = { version = "0.2", features = ["rt-threaded", "time", "stream", "fs", "macros"] } +tokio = { version = "0.2", features = ["rt-threaded", "time", "stream", "fs", "macros", "uds"] } futures = { version = "0.3", default-features = false, features = ["alloc"]} async-stream = "0.2" http = "0.2" diff --git a/examples/src/uds/client.rs b/examples/src/uds/client.rs new file mode 100644 index 000000000..0b30d024b --- /dev/null +++ b/examples/src/uds/client.rs @@ -0,0 +1,39 @@ +#[cfg(unix)] + +pub mod hello_world { + tonic::include_proto!("helloworld"); +} + +use hello_world::{greeter_client::GreeterClient, HelloRequest}; +use http::Uri; +use std::convert::TryFrom; +use tokio::net::UnixStream; +use tonic::transport::Endpoint; +use tower::service_fn; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // We will ignore this uri because uds do not use it + // if your connector does use the uri it will be provided + // as the request to the `MakeConnection`. + let channel = Endpoint::try_from("lttp://[::]:50051")? + .connect_with_connector(service_fn(|_: Uri| { + let path = "/tmp/tonic/helloworld"; + + // Connect to a Uds socket + UnixStream::connect(path) + })) + .await?; + + let mut client = GreeterClient::new(channel); + + let request = tonic::Request::new(HelloRequest { + name: "Tonic".into(), + }); + + let response = client.say_hello(request).await?; + + println!("RESPONSE={:?}", response); + + Ok(()) +} diff --git a/examples/src/uds/server.rs b/examples/src/uds/server.rs new file mode 100644 index 000000000..36258ab49 --- /dev/null +++ b/examples/src/uds/server.rs @@ -0,0 +1,48 @@ +use std::path::Path; +use tokio::net::UnixListener; +use tonic::{transport::Server, Request, Response, Status}; + +pub mod hello_world { + tonic::include_proto!("helloworld"); +} + +use hello_world::{ + greeter_server::{Greeter, GreeterServer}, + HelloReply, HelloRequest, +}; + +#[derive(Default)] +pub struct MyGreeter {} + +#[tonic::async_trait] +impl Greeter for MyGreeter { + async fn say_hello( + &self, + request: Request, + ) -> Result, Status> { + println!("Got a request: {:?}", request); + + let reply = hello_world::HelloReply { + message: format!("Hello {}!", request.into_inner().name).into(), + }; + Ok(Response::new(reply)) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let path = "/tmp/tonic/helloworld"; + + tokio::fs::create_dir_all(Path::new(path).parent().unwrap()).await?; + + let mut uds = UnixListener::bind(path)?; + + let greeter = MyGreeter::default(); + + Server::builder() + .add_service(GreeterServer::new(greeter)) + .serve_with_incoming(uds.incoming()) + .await?; + + Ok(()) +} diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index cf518d2b4..eb04f3314 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -1,3 +1,4 @@ +use super::super::service; use super::Channel; #[cfg(feature = "tls")] use super::ClientTlsConfig; @@ -12,6 +13,7 @@ use std::{ sync::Arc, time::Duration, }; +use tower_make::MakeConnection; /// Channel builder. /// @@ -182,7 +184,35 @@ impl Endpoint { /// Create a channel from this config. pub async fn connect(&self) -> Result { - Channel::connect(self.clone()).await + let mut http = hyper::client::connect::HttpConnector::new(); + http.enforce_http(false); + http.set_nodelay(self.tcp_nodelay); + http.set_keepalive(self.tcp_keepalive); + + #[cfg(feature = "tls")] + let connector = service::connector(http, self.tls.clone()); + + #[cfg(not(feature = "tls"))] + let connector = service::connector(http); + + Channel::connect(connector, self.clone()).await + } + + /// Connect with a custom connector. + pub async fn connect_with_connector(&self, connector: C) -> Result + where + C: MakeConnection + Send + 'static, + C::Connection: Unpin + Send + 'static, + C::Future: Send + 'static, + crate::Error: From + Send + 'static, + { + #[cfg(feature = "tls")] + let connector = service::connector(connector, self.tls.clone()); + + #[cfg(not(feature = "tls"))] + let connector = service::connector(connector); + + Channel::connect(connector, self.clone()).await } } diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index 586e1777e..cb1ec7f00 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -15,6 +15,7 @@ use http::{ uri::{InvalidUri, Uri}, Request, Response, }; +use hyper::client::connect::Connection as HyperConnection; use std::{ fmt, future::Future, @@ -22,6 +23,7 @@ use std::{ sync::Arc, task::{Context, Poll}, }; +use tokio::io::{AsyncRead, AsyncWrite}; use tower::{ buffer::{self, Buffer}, discover::Discover, @@ -121,11 +123,17 @@ impl Channel { Self::balance(discover, buffer_size, interceptor_headers) } - pub(crate) async fn connect(endpoint: Endpoint) -> Result { + pub(crate) async fn connect(connector: C, endpoint: Endpoint) -> Result + where + C: Service + Send + 'static, + C::Error: Into + Send, + C::Future: Unpin + Send, + C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + { let buffer_size = endpoint.buffer_size.clone().unwrap_or(DEFAULT_BUFFER_SIZE); let interceptor_headers = endpoint.interceptor_headers.clone(); - let svc = Connection::new(endpoint) + let svc = Connection::new(connector, endpoint) .await .map_err(|e| super::Error::from_source(super::ErrorKind::Client, e))?; diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs new file mode 100644 index 000000000..a6c78ce89 --- /dev/null +++ b/tonic/src/transport/server/incoming.rs @@ -0,0 +1,75 @@ +use super::Server; +use crate::transport::service::BoxedIo; +use futures_core::Stream; +use futures_util::stream::TryStreamExt; +use hyper::server::{ + accept::Accept, + conn::{AddrIncoming, AddrStream}, +}; +use std::{ + net::SocketAddr, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use tokio::io::{AsyncRead, AsyncWrite}; +#[cfg(feature = "tls")] +use tracing::error; + +#[cfg_attr(not(feature = "tls"), allow(unused_variables))] +pub(crate) fn tcp_incoming( + incoming: impl Stream>, + server: Server, +) -> impl Stream> +where + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IE: Into, +{ + async_stream::try_stream! { + futures_util::pin_mut!(incoming); + + while let Some(stream) = incoming.try_next().await? { + #[cfg(feature = "tls")] + { + if let Some(tls) = &server.tls { + let io = match tls.accept(stream).await { + Ok(io) => io, + Err(error) => { + error!(message = "Unable to accept incoming connection.", %error); + continue + }, + }; + yield BoxedIo::new(io); + continue; + } + } + + yield BoxedIo::new(stream); + } + } +} + +pub(crate) struct TcpIncoming { + inner: AddrIncoming, +} + +impl TcpIncoming { + pub(crate) fn new( + addr: SocketAddr, + nodelay: bool, + keepalive: Option, + ) -> Result { + let mut inner = AddrIncoming::bind(&addr)?; + inner.set_nodelay(nodelay); + inner.set_keepalive(keepalive); + Ok(TcpIncoming { inner }) + } +} + +impl Stream for TcpIncoming { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_accept(cx) + } +} diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index b2c89c186..e1bd03b8f 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -1,5 +1,6 @@ //! Server implementation and builder. +mod incoming; #[cfg(feature = "tls")] mod tls; @@ -9,18 +10,17 @@ pub use tls::ServerTlsConfig; #[cfg(feature = "tls")] use super::service::TlsAcceptor; -use super::service::{layer_fn, BoxedIo, Or, Routes, ServiceBuilderExt}; +use incoming::TcpIncoming; + +use super::service::{layer_fn, Or, Routes, ServiceBuilderExt}; use crate::body::BoxBody; +use futures_core::Stream; use futures_util::{ - future::{self, poll_fn, MapErr}, + future::{self, MapErr}, TryFutureExt, }; use http::{HeaderMap, Request, Response}; -use hyper::{ - server::{accept::Accept, conn}, - Body, -}; -use std::time::Duration; +use hyper::{server::accept, Body}; use std::{ fmt, future::Future, @@ -28,8 +28,9 @@ use std::{ pin::Pin, sync::Arc, task::{Context, Poll}, - // time::Duration, + time::Duration, }; +use tokio::io::{AsyncRead, AsyncWrite}; use tower::{ layer::{Layer, Stack}, limit::concurrency::ConcurrencyLimitLayer, @@ -37,8 +38,6 @@ use tower::{ Service, ServiceBuilder, }; -#[cfg(feature = "tls")] -use tracing::error; use tracing_futures::{Instrument, Instrumented}; type BoxService = tower::util::BoxService, Response, crate::Error>; @@ -242,16 +241,19 @@ impl Server { Router::new(self.clone(), svc) } - pub(crate) async fn serve_with_shutdown( + pub(crate) async fn serve_with_shutdown( self, - addr: SocketAddr, svc: S, + incoming: I, signal: Option, ) -> Result<(), super::Error> where S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, + I: Stream>, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IE: Into, F: Future, { let interceptor = self.interceptor.clone(); @@ -262,35 +264,8 @@ impl Server { let max_concurrent_streams = self.max_concurrent_streams; // let timeout = self.timeout.clone(); - let incoming = hyper::server::accept::from_stream::<_, _, crate::Error>( - async_stream::try_stream! { - let mut incoming = conn::AddrIncoming::bind(&addr)?; - - incoming.set_nodelay(self.tcp_nodelay); - incoming.set_keepalive(self.tcp_keepalive); - - - - while let Some(stream) = next_accept(&mut incoming).await? { - #[cfg(feature = "tls")] - { - if let Some(tls) = &self.tls { - let io = match tls.connect(stream.into_inner()).await { - Ok(io) => io, - Err(error) => { - error!(message = "Unable to accept incoming connection.", %error); - continue - }, - }; - yield BoxedIo::new(io); - continue; - } - } - - yield BoxedIo::new(stream); - } - }, - ); + let tcp = incoming::tcp_incoming(incoming, self); + let incoming = accept::from_stream::<_, _, crate::Error>(tcp); let svc = MakeSvc { inner: svc, @@ -384,8 +359,10 @@ where /// /// [`Server`]: struct.Server.html pub async fn serve(self, addr: SocketAddr) -> Result<(), super::Error> { + let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive) + .map_err(map_err)?; self.server - .serve_with_shutdown::<_, future::Ready<()>>(addr, self.routes, None) + .serve_with_shutdown::<_, _, future::Ready<()>, _, _>(self.routes, incoming, None) .await } @@ -399,8 +376,25 @@ where addr: SocketAddr, f: F, ) -> Result<(), super::Error> { + let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive) + .map_err(map_err)?; self.server - .serve_with_shutdown(addr, self.routes, Some(f)) + .serve_with_shutdown(self.routes, incoming, Some(f)) + .await + } + + /// Consume this [`Server`] creating a future that will execute the server on + /// the provided incoming stream of `AsyncRead + AsyncWrite`. + /// + /// [`Server`]: struct.Server.html + pub async fn serve_with_incoming(self, incoming: I) -> Result<(), super::Error> + where + I: Stream>, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IE: Into, + { + self.server + .serve_with_shutdown::<_, _, future::Ready<()>, _, _>(self.routes, incoming, None) .await } } @@ -523,16 +517,3 @@ impl Service> for Unimplemented { ) } } - -// Implement try_next for `Accept::poll_accept`. -async fn next_accept( - incoming: &mut conn::AddrIncoming, -) -> Result, crate::Error> { - let res = poll_fn(|cx| Pin::new(&mut *incoming).poll_accept(cx)).await; - - if let Some(res) = res { - Ok(Some(res?)) - } else { - return Ok(None); - } -} diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 2dd83174b..b90c1885d 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -1,6 +1,8 @@ -use super::{connector, layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin}; +use super::{layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin}; use crate::{body::BoxBody, transport::Endpoint}; +use http::Uri; use hyper::client::conn::Builder; +use hyper::client::connect::Connection as HyperConnection; use hyper::client::service::Connect as HyperConnect; use std::{ fmt, @@ -8,6 +10,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; +use tokio::io::{AsyncRead, AsyncWrite}; use tower::{ layer::Layer, limit::{concurrency::ConcurrencyLimitLayer, rate::RateLimitLayer}, @@ -26,17 +29,13 @@ pub(crate) struct Connection { } impl Connection { - pub(crate) async fn new(endpoint: Endpoint) -> Result { - #[cfg(feature = "tls")] - let connector = connector(endpoint.tls.clone()) - .set_keepalive(endpoint.tcp_keepalive) - .set_nodelay(endpoint.tcp_nodelay); - - #[cfg(not(feature = "tls"))] - let connector = connector() - .set_keepalive(endpoint.tcp_keepalive) - .set_nodelay(endpoint.tcp_nodelay); - + pub(crate) async fn new(connector: C, endpoint: Endpoint) -> Result + where + C: Service + Send + 'static, + C::Error: Into + Send, + C::Future: Unpin + Send, + C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + { let settings = Builder::new() .http2_initial_stream_window_size(endpoint.init_stream_window_size) .http2_initial_connection_window_size(endpoint.init_connection_window_size) diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/service/connector.rs index 552700b4d..c7a9f3c37 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/service/connector.rs @@ -2,64 +2,50 @@ use super::io::BoxedIo; #[cfg(feature = "tls")] use super::tls::TlsConnector; use http::Uri; -use hyper::client::connect::HttpConnector; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use std::time::Duration; use tower_make::MakeConnection; use tower_service::Service; #[cfg(not(feature = "tls"))] -pub(crate) fn connector() -> Connector { - Connector::new() +pub(crate) fn connector(inner: C) -> Connector { + Connector::new(inner) } #[cfg(feature = "tls")] -pub(crate) fn connector(tls: Option) -> Connector { - Connector::new(tls) +pub(crate) fn connector(inner: C, tls: Option) -> Connector { + Connector::new(inner, tls) } -pub(crate) struct Connector { - http: HttpConnector, +pub(crate) struct Connector { + inner: C, #[cfg(feature = "tls")] tls: Option, + #[cfg(not(feature = "tls"))] + #[allow(dead_code)] + tls: Option<()>, } -impl Connector { +impl Connector { #[cfg(not(feature = "tls"))] - pub(crate) fn new() -> Self { - Self { - http: Self::http_connector(), - } + pub(crate) fn new(inner: C) -> Self { + Self { inner, tls: None } } #[cfg(feature = "tls")] - fn new(tls: Option) -> Self { - Self { - http: Self::http_connector(), - tls, - } - } - - pub(crate) fn set_nodelay(mut self, enabled: bool) -> Self { - self.http.set_nodelay(enabled); - self - } - - pub(crate) fn set_keepalive(mut self, duration: Option) -> Self { - self.http.set_keepalive(duration); - self - } - - fn http_connector() -> HttpConnector { - let mut http = HttpConnector::new(); - http.enforce_http(false); - http + fn new(inner: C, tls: Option) -> Self { + Self { inner, tls } } } -impl Service for Connector { +impl Service for Connector +where + C: MakeConnection, + C::Connection: Unpin + Send + 'static, + C::Future: Send + 'static, + crate::Error: From + Send + 'static, +{ type Response = BoxedIo; type Error = crate::Error; @@ -67,11 +53,11 @@ impl Service for Connector { Pin> + Send + 'static>>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - MakeConnection::poll_ready(&mut self.http, cx).map_err(Into::into) + MakeConnection::poll_ready(self, cx).map_err(Into::into) } fn call(&mut self, uri: Uri) -> Self::Future { - let connect = MakeConnection::make_connection(&mut self.http, uri); + let connect = self.inner.make_connection(uri); #[cfg(feature = "tls")] let tls = self.tls.clone(); diff --git a/tonic/src/transport/service/discover.rs b/tonic/src/transport/service/discover.rs index 4635d8617..534e26c45 100644 --- a/tonic/src/transport/service/discover.rs +++ b/tonic/src/transport/service/discover.rs @@ -49,7 +49,11 @@ impl Discover for ServiceList { } if let Some(endpoint) = self.list.pop_front() { - let fut = Connection::new(endpoint); + let mut http = hyper::client::connect::HttpConnector::new(); + http.set_nodelay(endpoint.tcp_nodelay); + http.set_keepalive(endpoint.tcp_keepalive); + + let fut = Connection::new(http, endpoint); self.connecting = Some(Box::pin(fut)); } else { return Poll::Pending; diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index b6ba75229..e001f580f 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,14 +1,15 @@ +use hyper::client::connect::{Connected, Connection}; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite}; pub(in crate::transport) trait Io: - AsyncRead + AsyncWrite + Send + Unpin + 'static + AsyncRead + AsyncWrite + Send + 'static { } -impl Io for T where T: AsyncRead + AsyncWrite + Send + Unpin + 'static {} +impl Io for T where T: AsyncRead + AsyncWrite + Send + 'static {} pub(crate) struct BoxedIo(Pin>); @@ -18,6 +19,12 @@ impl BoxedIo { } } +impl Connection for BoxedIo { + fn connected(&self) -> Connected { + Connected::new() + } +} + impl AsyncRead for BoxedIo { fn poll_read( mut self: Pin<&mut Self>, diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index c619a4433..876ef1631 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -3,7 +3,7 @@ use crate::transport::{Certificate, Identity}; #[cfg(feature = "tls-roots")] use rustls_native_certs; use std::{fmt, sync::Arc}; -use tokio::net::TcpStream; +use tokio::io::{AsyncRead, AsyncWrite}; #[cfg(feature = "tls")] use tokio_rustls::{ rustls::{ClientConfig, NoClientAuth, ServerConfig, Session}, @@ -80,7 +80,10 @@ impl TlsConnector { }) } - pub(crate) async fn connect(&self, io: TcpStream) -> Result { + pub(crate) async fn connect(&self, io: I) -> Result + where + I: AsyncRead + AsyncWrite + Send + Unpin + 'static, + { let tls_io = { let dns = DNSNameRef::try_from_ascii_str(self.domain.as_str())?.to_owned(); @@ -154,7 +157,10 @@ impl TlsAcceptor { }) } - pub(crate) async fn connect(&self, io: TcpStream) -> Result { + pub(crate) async fn accept(&self, io: IO) -> Result + where + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { let io = { let acceptor = RustlsAcceptor::from(self.inner.clone()); let tls = acceptor.accept(io).await?;