From 6dfc20e1db455be12b0a647533c65bbfd6ae78f2 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 18 Feb 2022 18:19:52 +0100 Subject: [PATCH] feat(transport): port router to axum (#830) --- tonic-build/src/server.rs | 2 +- tonic/Cargo.toml | 2 + tonic/src/body.rs | 9 ++ tonic/src/codegen.rs | 11 -- tonic/src/service/interceptor.rs | 37 +++-- tonic/src/transport/server/mod.rs | 221 +++++--------------------- tonic/src/transport/service/mod.rs | 2 +- tonic/src/transport/service/router.rs | 172 ++++++++------------ 8 files changed, 146 insertions(+), 310 deletions(-) diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index 4d405b374..ac432ad28 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -124,7 +124,7 @@ pub fn generate( B::Error: Into + Send + 'static, { type Response = http::Response; - type Error = Never; + type Error = std::convert::Infallible; type Future = BoxFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 4912df98b..2889ec03c 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -32,6 +32,7 @@ tls-roots = ["tls-roots-common", "rustls-native-certs"] tls-roots-common = ["tls"] tls-webpki-roots = ["tls-roots-common", "webpki-roots"] transport = [ + "axum", "h2", "hyper", "tokio", @@ -77,6 +78,7 @@ tokio = {version = "1.0.1", features = ["net"], optional = true} tokio-stream = "0.1" tower = {version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true} tracing-futures = {version = "0.2", optional = true} +axum = {version = "0.4", default_features = false, optional = true} # rustls rustls-pemfile = { version = "0.2.1", optional = true } diff --git a/tonic/src/body.rs b/tonic/src/body.rs index d66a4fcb2..53a917ec6 100644 --- a/tonic/src/body.rs +++ b/tonic/src/body.rs @@ -5,6 +5,15 @@ use http_body::Body; /// A type erased HTTP body used for tonic services. pub type BoxBody = http_body::combinators::UnsyncBoxBody; +/// Convert a [`http_body::Body`] into a [`BoxBody`]. +pub(crate) fn boxed(body: B) -> BoxBody +where + B: http_body::Body + Send + 'static, + B::Error: Into, +{ + body.map_err(crate::Status::map_error).boxed_unsync() +} + // this also exists in `crate::codegen` but we need it here since `codegen` has // `#[cfg(feature = "codegen")]`. /// Create an empty `BoxBody` diff --git a/tonic/src/codegen.rs b/tonic/src/codegen.rs index 78a2c88a8..fae80235f 100644 --- a/tonic/src/codegen.rs +++ b/tonic/src/codegen.rs @@ -24,17 +24,6 @@ pub mod http { pub use http::*; } -#[derive(Debug)] -pub enum Never {} - -impl std::fmt::Display for Never { - fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match *self {} - } -} - -impl std::error::Error for Never {} - pub fn empty_body() -> crate::body::BoxBody { http_body::Empty::new() .map_err(|err| match err {}) diff --git a/tonic/src/service/interceptor.rs b/tonic/src/service/interceptor.rs index 441eb9163..83fc05820 100644 --- a/tonic/src/service/interceptor.rs +++ b/tonic/src/service/interceptor.rs @@ -2,7 +2,11 @@ //! //! See [`Interceptor`] for more details. -use crate::{request::SanitizeHeaders, Status}; +use crate::{ + body::{boxed, BoxBody}, + request::SanitizeHeaders, + Status, +}; use bytes::Bytes; use pin_project::pin_project; use std::{ @@ -145,15 +149,16 @@ where F: Interceptor, S: Service, Response = http::Response>, S::Error: Into, + ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { - type Response = http::Response; - type Error = crate::Error; + type Response = http::Response; + type Error = S::Error; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx).map_err(Into::into) + self.inner.poll_ready(cx) } fn call(&mut self, req: http::Request) -> Self::Future { @@ -171,7 +176,7 @@ where let req = req.into_http(uri, SanitizeHeaders::No); ResponseFuture::future(self.inner.call(req)) } - Err(status) => ResponseFuture::error(status), + Err(status) => ResponseFuture::status(status), } } } @@ -200,9 +205,9 @@ impl ResponseFuture { } } - fn error(status: Status) -> Self { + fn status(status: Status) -> Self { Self { - kind: Kind::Error(Some(status)), + kind: Kind::Status(Some(status)), } } } @@ -211,7 +216,7 @@ impl ResponseFuture { #[derive(Debug)] enum Kind { Future(#[pin] F), - Error(Option), + Status(Option), } impl Future for ResponseFuture @@ -221,14 +226,20 @@ where B: Default + http_body::Body + Send + 'static, B::Error: Into, { - type Output = Result, crate::Error>; + type Output = Result, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project().kind.project() { - KindProj::Future(future) => future.poll(cx).map_err(Into::into), - KindProj::Error(status) => { - let response = status.take().unwrap().to_http().map(|_| B::default()); - + KindProj::Future(future) => future + .poll(cx) + .map(|result| result.map(|res| res.map(boxed))), + KindProj::Status(status) => { + let response = status + .take() + .unwrap() + .to_http() + .map(|_| B::default()) + .map(boxed); Poll::Ready(Ok(response)) } } diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 6a391cd62..5b869c47c 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -9,6 +9,7 @@ mod tls; #[cfg(unix)] mod unix; +pub use super::service::Routes; pub use conn::{Connected, TcpConnectInfo}; #[cfg(feature = "tls")] pub use tls::ServerTlsConfig; @@ -31,19 +32,17 @@ pub(crate) use tokio_rustls::server::TlsStream; use crate::transport::Error; use self::recover_error::RecoverError; -use super::service::{GrpcTimeout, Or, Routes, ServerIo}; +use super::service::{GrpcTimeout, ServerIo}; use crate::body::BoxBody; use bytes::Bytes; use futures_core::Stream; -use futures_util::{ - future::{self, MapErr}, - ready, TryFutureExt, -}; +use futures_util::{future, ready}; use http::{Request, Response}; use http_body::Body as _; use hyper::{server::accept, Body}; use pin_project::pin_project; use std::{ + convert::Infallible, fmt, future::Future, marker::PhantomData, @@ -94,41 +93,9 @@ pub struct Server { /// A stack based `Service` router. #[derive(Debug)] -pub struct Router { +pub struct Router { server: Server, - routes: Routes>, -} - -/// A service that is produced from a Tonic `Router`. -/// -/// This service implementation will route between multiple Tonic -/// gRPC endpoints and can be consumed with the rest of the `tower` -/// ecosystem. -#[derive(Debug, Clone)] -pub struct RouterService { - inner: S, -} - -impl Service> for RouterService -where - S: Service, Response = Response> + Clone + Send + 'static, - S::Future: Send + 'static, - S::Error: Into + Send, -{ - type Response = Response; - type Error = crate::Error; - - #[allow(clippy::type_complexity)] - type Future = MapErr crate::Error>; - - #[inline] - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - self.inner.call(req).map_err(Into::into) - } + routes: Routes, } /// A trait to provide a static reference to the service's @@ -330,18 +297,17 @@ impl Server { /// /// This will clone the `Server` builder and create a router that will /// route around different services. - pub fn add_service(&mut self, svc: S) -> Router + pub fn add_service(&mut self, svc: S) -> Router where - S: Service, Response = Response> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send + 'static, S::Future: Send + 'static, - S::Error: Into + Send, L: Clone, { - Router::new(self.clone(), svc) + Router::new(self.clone(), Routes::new(svc)) } /// Create a router with the optional `S` typed service as the first service. @@ -352,25 +318,18 @@ impl Server { /// # Note /// Even when the argument given is `None` this will capture *all* requests to this service name. /// As a result, one cannot use this to toggle between two identically named implementations. - pub fn add_optional_service( - &mut self, - svc: Option, - ) -> Router, Unimplemented, L> + pub fn add_optional_service(&mut self, svc: Option) -> Router where - S: Service, Response = Response> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send + 'static, S::Future: Send + 'static, - S::Error: Into + Send, L: Clone, { - let svc = match svc { - Some(some) => Either::A(some), - None => Either::B(Unimplemented::default()), - }; - Router::new(self.clone(), svc) + let routes = svc.map(Routes::new).unwrap_or_default(); + Router::new(self.clone(), routes) } /// Set the [Tower] [`Layer`] all services will be wrapped in. @@ -523,63 +482,25 @@ impl Server { } } -impl Router { - pub(crate) fn new(server: Server, svc: S) -> Self - where - S: Service, Response = Response> - + NamedService - + Clone - + Send - + 'static, - S::Future: Send + 'static, - S::Error: Into + Send, - { - let svc_name = ::NAME; - let svc_route = format!("/{}", svc_name); - let pred = move |req: &Request| { - let path = req.uri().path(); - - path.starts_with(&svc_route) - }; - Self { - server, - routes: Routes::new(pred, svc, Unimplemented::default()), - } +impl Router { + pub(crate) fn new(server: Server, routes: Routes) -> Self { + Self { server, routes } } } -impl Router -where - A: Service, Response = Response> + Clone + Send + 'static, - A::Future: Send + 'static, - A::Error: Into + Send, - B: Service, Response = Response> + Clone + Send + 'static, - B::Future: Send + 'static, - B::Error: Into + Send, -{ +impl Router { /// Add a new service to this router. - pub fn add_service(self, svc: S) -> Router>, L> + pub fn add_service(mut self, svc: S) -> Self where - S: Service, Response = Response> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send + 'static, S::Future: Send + 'static, - S::Error: Into + Send, { - let Self { routes, server } = self; - - let svc_name = ::NAME; - let svc_route = format!("/{}", svc_name); - let pred = move |req: &Request| { - let path = req.uri().path(); - - path.starts_with(&svc_route) - }; - let routes = routes.push(pred, svc); - - Router { server, routes } + self.routes = self.routes.add_service(svc); + self } /// Add a new optional service to this router. @@ -588,35 +509,19 @@ where /// Even when the argument given is `None` this will capture *all* requests to this service name. /// As a result, one cannot use this to toggle between two identically named implementations. #[allow(clippy::type_complexity)] - pub fn add_optional_service( - self, - svc: Option, - ) -> Router, Or>, L> + pub fn add_optional_service(mut self, svc: Option) -> Self where - S: Service, Response = Response> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send + 'static, S::Future: Send + 'static, - S::Error: Into + Send, { - let Self { routes, server } = self; - - let svc_name = ::NAME; - let svc_route = format!("/{}", svc_name); - let pred = move |req: &Request| { - let path = req.uri().path(); - - path.starts_with(&svc_route) - }; - let svc = match svc { - Some(some) => Either::A(some), - None => Either::B(Unimplemented::default()), - }; - let routes = routes.push(pred, svc); - - Router { server, routes } + if let Some(svc) = svc { + self.routes = self.routes.add_service(svc); + } + self } /// Consume this [`Server`] creating a future that will execute the server @@ -626,12 +531,10 @@ where /// [tokio]: https://docs.rs/tokio pub async fn serve(self, addr: SocketAddr) -> Result<(), super::Error> where - L: Layer>>, + L: Layer, L::Service: Service, Response = Response> + Clone + Send + 'static, - <>>>::Service as Service>>::Future: - Send + 'static, - <>>>::Service as Service>>::Error: - Into + Send, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -658,12 +561,10 @@ where signal: F, ) -> Result<(), super::Error> where - L: Layer>>, + L: Layer, L::Service: Service, Response = Response> + Clone + Send + 'static, - <>>>::Service as Service>>::Future: - Send + 'static, - <>>>::Service as Service>>::Error: - Into + Send, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -687,12 +588,10 @@ where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, - L: Layer>>, + L: Layer, L::Service: Service, Response = Response> + Clone + Send + 'static, - <>>>::Service as Service>>::Future: - Send + 'static, - <>>>::Service as Service>>::Error: - Into + Send, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -722,12 +621,10 @@ where IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, F: Future, - L: Layer>>, + L: Layer, L::Service: Service, Response = Response> + Clone + Send + 'static, - <>>>::Service as Service>>::Future: - Send + 'static, - <>>>::Service as Service>>::Error: - Into + Send, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -737,19 +634,16 @@ where } /// Create a tower service out of a router. - pub fn into_service(self) -> RouterService + pub fn into_service(self) -> L::Service where - L: Layer>>, + L: Layer, L::Service: Service, Response = Response> + Clone + Send + 'static, - <>>>::Service as Service>>::Future: - Send + 'static, - <>>>::Service as Service>>::Error: - Into + Send, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { - let inner = self.server.layer.layer(self.routes); - RouterService { inner } + self.server.layer.layer(self.routes) } } @@ -905,30 +799,3 @@ where future::ready(Ok(svc)) } } - -#[derive(Default, Clone, Debug)] -#[doc(hidden)] -pub struct Unimplemented { - _p: (), -} - -impl Service> for Unimplemented { - type Response = Response; - type Error = crate::Error; - type Future = future::Ready>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Ok(()).into() - } - - fn call(&mut self, _req: Request) -> Self::Future { - future::ok( - http::Response::builder() - .status(200) - .header("grpc-status", "12") - .header("content-type", "application/grpc") - .body(crate::body::empty_body()) - .unwrap(), - ) - } -} diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 4e1d89c0c..da7b46cca 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -16,9 +16,9 @@ pub(crate) use self::connector::connector; pub(crate) use self::discover::DynamicServiceStream; pub(crate) use self::grpc_timeout::GrpcTimeout; pub(crate) use self::io::ServerIo; -pub(crate) use self::router::{Or, Routes}; #[cfg(feature = "tls")] pub(crate) use self::tls::{TlsAcceptor, TlsConnector}; pub(crate) use self::user_agent::UserAgent; pub use self::grpc_timeout::TimeoutExpired; +pub use self::router::Routes; diff --git a/tonic/src/transport/service/router.rs b/tonic/src/transport/service/router.rs index 2eff8cafe..5051369ae 100644 --- a/tonic/src/transport/service/router.rs +++ b/tonic/src/transport/service/router.rs @@ -1,132 +1,90 @@ -use futures_util::{ - future::Either, - future::{MapErr, TryFutureExt}, +use crate::{ + body::{boxed, BoxBody}, + transport::NamedService, }; +use axum::handler::Handler; +use http::{Request, Response}; +use hyper::Body; +use pin_project::pin_project; use std::{ - fmt, - sync::Arc, + convert::Infallible, + future::Future, + pin::Pin, task::{Context, Poll}, }; +use tower::ServiceExt; use tower_service::Service; -#[doc(hidden)] -#[derive(Debug)] -pub struct Routes { - routes: Or, -} - -impl Routes { - pub(crate) fn new( - predicate: impl Fn(&Request) -> bool + Send + Sync + 'static, - a: A, - b: B, - ) -> Self { - let routes = Or::new(predicate, a, b); - Self { routes } - } +/// A [`Service`] router. +#[derive(Debug, Default, Clone)] +pub struct Routes { + router: axum::Router, } -impl Routes { - pub(crate) fn push( - self, - predicate: impl Fn(&Request) -> bool + Send + Sync + 'static, - route: C, - ) -> Routes, Request> { - let routes = Or::new(predicate, route, self.routes); - Routes { routes } - } -} - -impl Service for Routes -where - A: Service, - A::Future: Send + 'static, - A::Error: Into, - B: Service, - B::Future: Send + 'static, - B::Error: Into, -{ - type Response = A::Response; - type Error = crate::Error; - type Future = as Service>::Future; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Ok(()).into() - } - - fn call(&mut self, req: Request) -> Self::Future { - self.routes.call(req) - } -} - -impl Clone for Routes { - fn clone(&self) -> Self { - Self { - routes: self.routes.clone(), - } +impl Routes { + pub(crate) fn new(svc: S) -> Self + where + S: Service, Response = Response, Error = Infallible> + + NamedService + + Clone + + Send + + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, + { + let router = axum::Router::new().fallback(unimplemented.into_service()); + Self { router }.add_service(svc) } -} - -#[doc(hidden)] -pub struct Or { - predicate: Arc bool + Send + Sync + 'static>, - a: A, - b: B, -} -impl Or { - pub(crate) fn new(predicate: F, a: A, b: B) -> Self + pub(crate) fn add_service(mut self, svc: S) -> Self where - F: Fn(&Request) -> bool + Send + Sync + 'static, + S: Service, Response = Response, Error = Infallible> + + NamedService + + Clone + + Send + + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, { - let predicate = Arc::new(predicate); - Self { predicate, a, b } + let svc = svc.map_response(|res| res.map(axum::body::boxed)); + self.router = self.router.route(&format!("/{}/*rest", S::NAME), svc); + self } } -impl Service for Or -where - A: Service, - A::Future: Send + 'static, - A::Error: Into, - B: Service, - B::Future: Send + 'static, - B::Error: Into, -{ - type Response = A::Response; - type Error = crate::Error; +async fn unimplemented() -> impl axum::response::IntoResponse { + let status = http::StatusCode::OK; + let headers = + axum::response::Headers([("grpc-status", "12"), ("content-type", "application/grpc")]); + (status, headers) +} - #[allow(clippy::type_complexity)] - type Future = Either< - MapErr crate::Error>, - MapErr crate::Error>, - >; +impl Service> for Routes { + type Response = Response; + type Error = crate::Error; + type Future = RoutesFuture; - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Ok(()).into() + #[inline] + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { - if (self.predicate)(&req) { - Either::Left(self.a.call(req).map_err(|e| e.into())) - } else { - Either::Right(self.b.call(req).map_err(|e| e.into())) - } + fn call(&mut self, req: Request) -> Self::Future { + RoutesFuture(self.router.call(req)) } } -impl Clone for Or { - fn clone(&self) -> Self { - Self { - predicate: self.predicate.clone(), - a: self.a.clone(), - b: self.b.clone(), - } - } -} +#[pin_project] +#[derive(Debug)] +pub struct RoutesFuture(#[pin] axum::routing::future::RouterFuture); + +impl Future for RoutesFuture { + type Output = Result, crate::Error>; -impl fmt::Debug for Or { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Or {{ .. }}") + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match futures_util::ready!(self.project().0.poll(cx)) { + Ok(res) => Ok(res.map(boxed)).into(), + Err(err) => match err {}, + } } }