diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index 51b741aa7..ce8a7a550 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -7,63 +7,32 @@ use syn::{Ident, Lit, LitStr}; pub(crate) fn generate(service: &Service, proto_path: &str) -> TokenStream { let methods = generate_methods(&service, proto_path); - let server_make_service = quote::format_ident!("{}Server", service.name); - let server_service = quote::format_ident!("{}ServerSvc", service.name); + let server_service = quote::format_ident!("{}Server", service.name); let server_trait = quote::format_ident!("{}", service.name); let generated_trait = generate_trait(service, proto_path, server_trait.clone()); let service_doc = generate_doc_comments(&service.comments.leading); - let server_new_doc = generate_doc_comment(&format!( - "Create a new {} from a type that implements {}.", - server_make_service, server_trait - )); + + // Transport based implementations + let path = format!("{}.{}", service.package, service.proto_name); + let transport = generate_transport(&server_service, &server_trait, &path); quote! { #generated_trait #service_doc - #[derive(Clone, Debug)] - pub struct #server_make_service { - inner: Arc, - } - - #[derive(Clone, Debug)] + #[derive(Debug)] #[doc(hidden)] pub struct #server_service { inner: Arc, } - impl #server_make_service { - #server_new_doc + impl #server_service { pub fn new(inner: T) -> Self { let inner = Arc::new(inner); - Self::from_shared(inner) - } - - pub fn from_shared(inner: Arc) -> Self { Self { inner } } } - impl #server_service { - pub fn new(inner: Arc) -> Self { - Self { inner } - } - } - - impl Service for #server_make_service { - type Response = #server_service; - type Error = Never; - type Future = Ready>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, _: R) -> Self::Future { - ok(#server_service::new(self.inner.clone())) - } - } - impl Service> for #server_service { type Response = http::Response; type Error = Never; @@ -89,6 +58,15 @@ pub(crate) fn generate(service: &Service, proto_path: &str) -> TokenStream { } } } + + impl Clone for #server_service { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { inner } + } + } + + #transport } } @@ -181,6 +159,30 @@ fn generate_trait_methods(service: &Service, proto_path: &str) -> TokenStream { stream } +#[cfg(feature = "transport")] +fn generate_transport( + server_service: &syn::Ident, + server_trait: &syn::Ident, + service_name: &str, +) -> TokenStream { + let service_name = syn::LitStr::new(service_name, proc_macro2::Span::call_site()); + + quote! { + impl tonic::transport::ServiceName for #server_service { + const NAME: &'static str = #service_name; + } + } +} + +#[cfg(not(feature = "transport"))] +fn generate_transport( + _server_service: &syn::Ident, + _server_trait: &syn::Ident, + _service_name: &str, +) -> TokenStream { + TokenStream::new() +} + fn generate_methods(service: &Service, proto_path: &str) -> TokenStream { let mut stream = TokenStream::new(); diff --git a/tonic-examples/Cargo.toml b/tonic-examples/Cargo.toml index a7800dd51..e424e55f3 100644 --- a/tonic-examples/Cargo.toml +++ b/tonic-examples/Cargo.toml @@ -52,6 +52,14 @@ path = "src/tls_client_auth/server.rs" name = "tls-client-auth-client" path = "src/tls_client_auth/client.rs" +[[bin]] +name = "multiplex-server" +path = "src/multiplex/server.rs" + +[[bin]] +name = "multiplex-client" +path = "src/multiplex/client.rs" + [[bin]] name = "gcp-client" path = "src/gcp/client.rs" diff --git a/tonic-examples/helloworld-tutorial.md b/tonic-examples/helloworld-tutorial.md index 4ae80b8e2..88d504aec 100644 --- a/tonic-examples/helloworld-tutorial.md +++ b/tonic-examples/helloworld-tutorial.md @@ -191,7 +191,8 @@ async fn main() -> Result<(), Box> { let greeter = MyGreeter {}; Server::builder() - .serve(addr, GreeterServer::new(greeter)) + .add_service(GreeterServer::new(greeter)) + .serve(addr) .await?; Ok(()) @@ -236,7 +237,8 @@ async fn main() -> Result<(), Box> { let greeter = MyGreeter {}; Server::builder() - .serve(addr, GreeterServer::new(greeter)) + .add_service(GreeterServer::new(greeter)) + .serve(addr) .await?; Ok(()) diff --git a/tonic-examples/routeguide-tutorial.md b/tonic-examples/routeguide-tutorial.md index 19c836395..d6b75f68b 100644 --- a/tonic-examples/routeguide-tutorial.md +++ b/tonic-examples/routeguide-tutorial.md @@ -558,7 +558,10 @@ async fn main() -> Result<(), Box> { let svc = server::RouteGuideServer::new(route_guide); - Server::builder().serve(addr, svc).await?; + Server::builder() + .add_service(svc) + .serve(addr) + .await?; Ok(()) } diff --git a/tonic-examples/src/authentication/server.rs b/tonic-examples/src/authentication/server.rs index d63bc9c9b..cf49da0c0 100644 --- a/tonic-examples/src/authentication/server.rs +++ b/tonic-examples/src/authentication/server.rs @@ -78,8 +78,8 @@ async fn main() -> Result<(), Box> { } } }) - .clone() - .serve(addr, pb::server::EchoServer::new(server)) + .add_service(pb::server::EchoServer::new(server)) + .serve(addr) .await?; Ok(()) diff --git a/tonic-examples/src/helloworld/server.rs b/tonic-examples/src/helloworld/server.rs index 9d231fb0a..211c4e692 100644 --- a/tonic-examples/src/helloworld/server.rs +++ b/tonic-examples/src/helloworld/server.rs @@ -33,7 +33,8 @@ async fn main() -> Result<(), Box> { let greeter = MyGreeter::default(); Server::builder() - .serve(addr, GreeterServer::new(greeter)) + .add_service(GreeterServer::new(greeter)) + .serve(addr) .await?; Ok(()) diff --git a/tonic-examples/src/load_balance/server.rs b/tonic-examples/src/load_balance/server.rs index fb2a3d9e6..ee68be498 100644 --- a/tonic-examples/src/load_balance/server.rs +++ b/tonic-examples/src/load_balance/server.rs @@ -60,7 +60,9 @@ async fn main() -> Result<(), Box> { let mut tx = tx.clone(); let server = EchoServer { addr }; - let serve = Server::builder().serve(addr, pb::server::EchoServer::new(server)); + let serve = Server::builder() + .add_service(pb::server::EchoServer::new(server)) + .serve(addr); tokio::spawn(async move { if let Err(e) = serve.await { diff --git a/tonic-examples/src/multiplex/client.rs b/tonic-examples/src/multiplex/client.rs new file mode 100644 index 000000000..6e8f19aac --- /dev/null +++ b/tonic-examples/src/multiplex/client.rs @@ -0,0 +1,37 @@ +pub mod hello_world { + tonic::include_proto!("helloworld"); +} + +pub mod echo { + tonic::include_proto!("grpc.examples.echo"); +} + +use echo::{client::EchoClient, EchoRequest}; +use hello_world::{client::GreeterClient, HelloRequest}; +use tonic::transport::Endpoint; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let channel = Endpoint::from_static("http://[::1]:50051").channel(); + + let mut greeter_client = GreeterClient::new(channel.clone()); + let mut echo_client = EchoClient::new(channel); + + let request = tonic::Request::new(HelloRequest { + name: "Tonic".into(), + }); + + let response = greeter_client.say_hello(request).await?; + + println!("GREETER RESPONSE={:?}", response); + + let request = tonic::Request::new(EchoRequest { + message: "hello".into(), + }); + + let response = echo_client.unary_echo(request).await?; + + println!("ECHO RESPONSE={:?}", response); + + Ok(()) +} diff --git a/tonic-examples/src/multiplex/server.rs b/tonic-examples/src/multiplex/server.rs new file mode 100644 index 000000000..a8af6e4af --- /dev/null +++ b/tonic-examples/src/multiplex/server.rs @@ -0,0 +1,69 @@ +use std::collections::VecDeque; +use tonic::{transport::Server, Request, Response, Status}; + +pub mod hello_world { + tonic::include_proto!("helloworld"); +} + +pub mod echo { + tonic::include_proto!("grpc.examples.echo"); +} + +use hello_world::{ + server::{Greeter, GreeterServer}, + HelloReply, HelloRequest, +}; + +use echo::{ + server::{Echo, EchoServer}, + EchoRequest, EchoResponse, +}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let addr = "[::1]:50051".parse().unwrap(); + + let greeter = GreeterServer::new(MyGreeter::default()); + let echo = EchoServer::new(MyEcho::default()); + + Server::builder() + .add_service(greeter) + .add_service(echo) + .serve(addr) + .await?; + + Ok(()) +} + +#[derive(Default)] +pub struct MyGreeter {} + +#[tonic::async_trait] +impl Greeter for MyGreeter { + async fn say_hello( + &self, + request: Request, + ) -> Result, Status> { + let reply = hello_world::HelloReply { + message: format!("Hello {}!", request.into_inner().name).into(), + }; + Ok(Response::new(reply)) + } +} + +#[derive(Default)] +pub struct MyEcho; + +#[tonic::async_trait] +impl Echo for MyEcho { + async fn unary_echo( + &self, + request: Request, + ) -> Result, Status> { + let message = request.into_inner().message; + Ok(Response::new(EchoResponse { message })) + } + + type ServerStreamingEchoStream = VecDeque>; + type BidirectionalStreamingEchoStream = VecDeque>; +} diff --git a/tonic-examples/src/routeguide/server.rs b/tonic-examples/src/routeguide/server.rs index 5732c52cc..b6ec19ba8 100644 --- a/tonic-examples/src/routeguide/server.rs +++ b/tonic-examples/src/routeguide/server.rs @@ -155,7 +155,7 @@ async fn main() -> Result<(), Box> { let svc = server::RouteGuideServer::new(route_guide); - Server::builder().serve(addr, svc).await?; + Server::builder().add_service(svc).serve(addr).await?; Ok(()) } diff --git a/tonic-examples/src/tls/server.rs b/tonic-examples/src/tls/server.rs index 0bac205aa..79b85451b 100644 --- a/tonic-examples/src/tls/server.rs +++ b/tonic-examples/src/tls/server.rs @@ -61,7 +61,8 @@ async fn main() -> Result<(), Box> { Server::builder() .tls_config(ServerTlsConfig::with_rustls().identity(identity)) .clone() - .serve(addr, pb::server::EchoServer::new(server)) + .add_service(pb::server::EchoServer::new(server)) + .serve(addr) .await?; Ok(()) diff --git a/tonic-examples/src/tls_client_auth/client.rs b/tonic-examples/src/tls_client_auth/client.rs index 8864e7738..2b13cd241 100644 --- a/tonic-examples/src/tls_client_auth/client.rs +++ b/tonic-examples/src/tls_client_auth/client.rs @@ -13,7 +13,7 @@ async fn main() -> Result<(), Box> { let client_key = tokio::fs::read("tonic-examples/data/tls/client1.key").await?; let client_identity = Identity::from_pem(client_cert, client_key); - let tls = ClientTlsConfig::with_openssl() + let tls = ClientTlsConfig::with_rustls() .domain_name("localhost") .ca_certificate(server_root_ca_cert) .identity(client_identity) diff --git a/tonic-examples/src/tls_client_auth/server.rs b/tonic-examples/src/tls_client_auth/server.rs index 1dddf98de..bff58ca26 100644 --- a/tonic-examples/src/tls_client_auth/server.rs +++ b/tonic-examples/src/tls_client_auth/server.rs @@ -44,8 +44,8 @@ async fn main() -> Result<(), Box> { Server::builder() .tls_config(&tls) - .clone() - .serve(addr, pb::server::EchoServer::new(server)) + .add_service(pb::server::EchoServer::new(server)) + .serve(addr) .await?; Ok(()) diff --git a/tonic-interop/src/bin/server.rs b/tonic-interop/src/bin/server.rs index dd0a7d052..2752d59f1 100644 --- a/tonic-interop/src/bin/server.rs +++ b/tonic-interop/src/bin/server.rs @@ -57,106 +57,15 @@ async fn main() -> std::result::Result<(), Box> { } }); + let test_service = server::TestServiceServer::new(server::TestService::default()); + let unimplemented_service = + server::UnimplementedServiceServer::new(server::UnimplementedService::default()); + builder - .serve( - addr, - router::Router { - test_service: std::sync::Arc::new(server::TestService), - unimplemented_service: std::sync::Arc::new(server::UnimplementedService), - }, - ) + .add_service(test_service) + .add_service(unimplemented_service) + .serve(addr) .await?; Ok(()) } - -mod router { - use futures_util::future; - use http::{Request, Response}; - use std::sync::Arc; - use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, - }; - use tonic::{body::BoxBody, transport::Body}; - use tonic_interop::server::{ - TestService, TestServiceServer, UnimplementedService, UnimplementedServiceServer, - }; - use tower::Service; - - #[derive(Clone)] - pub struct Router { - pub test_service: Arc, - pub unimplemented_service: Arc, - } - - impl Service<()> for Router { - type Response = Router; - type Error = Never; - type Future = future::Ready>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Ok(()).into() - } - - fn call(&mut self, _req: ()) -> Self::Future { - future::ok(self.clone()) - } - } - - impl Service> for Router { - type Response = Response; - type Error = Never; - type Future = - Pin, Never>> + Send + 'static>>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Ok(()).into() - } - - fn call(&mut self, req: Request) -> Self::Future { - let mut segments = req.uri().path().split("/"); - segments.next(); - let service = segments.next().unwrap(); - - match service { - "grpc.testing.TestService" => { - let me = self.clone(); - Box::pin(async move { - let mut svc = TestServiceServer::from_shared(me.test_service); - let mut svc = svc.call(()).await.unwrap(); - - let res = svc.call(req).await.unwrap(); - Ok(res) - }) - } - - "grpc.testing.UnimplementedService" => { - let me = self.clone(); - Box::pin(async move { - let mut svc = - UnimplementedServiceServer::from_shared(me.unimplemented_service); - let mut svc = svc.call(()).await.unwrap(); - - let res = svc.call(req).await.unwrap(); - Ok(res) - }) - } - - _ => unimplemented!(), - } - } - } - - #[derive(Debug)] - pub enum Never {} - - impl std::fmt::Display for Never { - fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match *self {} - } - } - - impl std::error::Error for Never {} -} diff --git a/tonic-interop/src/server.rs b/tonic-interop/src/server.rs index a0d9c7360..86c10fd8b 100644 --- a/tonic-interop/src/server.rs +++ b/tonic-interop/src/server.rs @@ -150,6 +150,7 @@ impl pb::server::TestService for TestService { } } +#[derive(Default)] pub struct UnimplementedService; #[tonic::async_trait] diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 846719ef5..961d8ce2d 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -50,7 +50,23 @@ //! # use futures_util::future::{err, ok}; //! # #[cfg(feature = "rustls")] //! # async fn do_thing() -> Result<(), Box> { -//! # let my_svc = service_fn(|_| ok::<_, tonic::Status>(service_fn(|req| err(tonic::Status::unimplemented(""))))); +//! # #[derive(Clone)] +//! # pub struct Svc; +//! # impl Service> for Svc { +//! # type Response = hyper::Response; +//! # type Error = tonic::Status; +//! # type Future = futures_util::future::Ready>; +//! # fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { +//! # Ok(()).into() +//! # } +//! # fn call(&mut self, _req: hyper::Request) -> Self::Future { +//! # unimplemented!() +//! # } +//! # } +//! # impl tonic::transport::ServiceName for Svc { +//! # const NAME: &'static str = "some_svc"; +//! # } +//! # let my_svc = Svc; //! let cert = std::fs::read_to_string("server.pem")?; //! let key = std::fs::read_to_string("server.key")?; //! @@ -64,8 +80,8 @@ //! println!("Request: {:?}", req); //! svc.call(req) //! }) -//! .clone() -//! .serve(addr, my_svc) +//! .add_service(my_svc) +//! .serve(addr) //! .await?; //! //! # Ok(()) @@ -88,7 +104,7 @@ pub use self::channel::Channel; pub use self::endpoint::Endpoint; pub use self::error::Error; #[doc(inline)] -pub use self::server::Server; +pub use self::server::{Server, ServiceName}; pub use self::tls::{Certificate, Identity}; pub use hyper::Body; diff --git a/tonic/src/transport/server.rs b/tonic/src/transport/server.rs index 702571c4f..886a74b7e 100644 --- a/tonic/src/transport/server.rs +++ b/tonic/src/transport/server.rs @@ -1,6 +1,6 @@ //! Server implementation and builder. -use super::service::{layer_fn, BoxedIo, ServiceBuilderExt}; +use super::service::{layer_fn, BoxedIo, Or, Routes, ServiceBuilderExt}; #[cfg(feature = "tls")] use super::{ service::TlsAcceptor, @@ -9,7 +9,7 @@ use super::{ }; use crate::body::BoxBody; use futures_core::Stream; -use futures_util::{ready, try_future::MapErr, TryFutureExt, TryStreamExt}; +use futures_util::{future, ready, try_future::MapErr, TryFutureExt, TryStreamExt}; use http::{Request, Response}; use hyper::{ server::{accept::Accept, conn}, @@ -31,7 +31,6 @@ use tower::{ Service, ServiceBuilder, }; -use tower_make::MakeService; #[cfg(feature = "tls")] use tracing::error; @@ -58,6 +57,22 @@ pub struct Server { max_concurrent_streams: Option, } +/// A stack based `Service` router. +#[derive(Debug)] +pub struct Router { + server: Server, + routes: Routes>, +} + +/// A trait to provide a static reference to the service's +/// name. This is used for routing service's within the router. +pub trait ServiceName { + /// The `Service-Name` as described [here]. + /// + /// [here]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests + const NAME: &'static str; +} + impl Server { /// Create a new server builder that can configure a [`Server`]. pub fn builder() -> Self { @@ -149,14 +164,26 @@ impl Server { self } - /// Consume this [`Server`] creating a future that will execute the server - /// on [`tokio`]'s default executor. - pub async fn serve(self, addr: SocketAddr, svc: M) -> Result<(), super::Error> + /// Create a router with the `S` typed service as the first service. + /// + /// This will clone the `Server` builder and create a router that will + /// route around different services. + pub fn add_service(&mut self, svc: S) -> Router + where + S: Service, Response = Response> + + ServiceName + + Clone + + Send + + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, + { + Router::new(self.clone(), svc) + } + + pub(crate) async fn serve(self, addr: SocketAddr, svc: S) -> Result<(), super::Error> where - M: Service<(), Response = S>, - M::Error: Into + Send + 'static, - M::Future: Send + 'static, - S: Service, Response = Response> + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, { @@ -210,6 +237,74 @@ impl Server { } } +impl Router { + pub(crate) fn new(server: Server, svc: S) -> Self + where + S: Service, Response = Response> + + ServiceName + + Clone + + Send + + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, + { + let svc_name = ::NAME; + let svc_route = format!("/{}", svc_name); + let pred = move |req: &Request| { + let path = req.uri().path(); + + path.starts_with(&svc_route) + }; + Self { + server, + routes: Routes::new(pred, svc, Unimplemented::default()), + } + } +} + +impl Router +where + A: Service, Response = Response> + Clone + Send + 'static, + A::Future: Send + 'static, + A::Error: Into + Send, + B: Service, Response = Response> + Clone + Send + 'static, + B::Future: Send + 'static, + B::Error: Into + Send, +{ + /// Add a new service to this router. + pub fn add_service(self, svc: S) -> Router>> + where + S: Service, Response = Response> + + ServiceName + + Clone + + Send + + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, + { + let Self { routes, server } = self; + + let svc_name = ::NAME; + let svc_route = format!("/{}", svc_name); + let pred = move |req: &Request| { + let path = req.uri().path(); + + path.starts_with(&svc_route) + }; + let routes = routes.push(pred, svc); + + Router { server, routes } + } + + /// Consume this [`Server`] creating a future that will execute the server + /// on [`tokio`]'s default executor. + /// + /// [`Server`]: struct.Server.html + pub async fn serve(self, addr: SocketAddr) -> Result<(), super::Error> { + self.server.serve(addr, self.routes).await + } +} + fn map_err(e: impl Into) -> super::Error { super::Error::from_source(super::ErrorKind::Server, e.into()) } @@ -371,19 +466,16 @@ where } } -struct MakeSvc { +struct MakeSvc { interceptor: Option, concurrency_limit: Option, // timeout: Option, - inner: M, + inner: S, } -impl Service for MakeSvc +impl Service for MakeSvc where - M: Service<(), Response = S>, - M::Error: Into + Send, - M::Future: Send + 'static, - S: Service, Response = Response> + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, { @@ -392,19 +484,17 @@ where type Future = Pin> + Send + 'static>>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - MakeService::poll_ready(&mut self.inner, cx).map_err(Into::into) + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Ok(()).into() } fn call(&mut self, _: T) -> Self::Future { let interceptor = self.interceptor.clone(); - let make = self.inner.make_service(()); + let svc = self.inner.clone(); let concurrency_limit = self.concurrency_limit; // let timeout = self.timeout.clone(); Box::pin(async move { - let svc = make.await.map_err(Into::into)?; - let svc = ServiceBuilder::new() .optional_layer(concurrency_limit.map(ConcurrencyLimitLayer::new)) // .optional_layer(timeout.map(TimeoutLayer::new)) @@ -421,3 +511,29 @@ where }) } } + +#[derive(Default, Clone, Debug)] +#[doc(hidden)] +pub struct Unimplemented { + _p: (), +} + +impl Service> for Unimplemented { + type Response = Response; + type Error = crate::Error; + type Future = future::Ready>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Ok(()).into() + } + + fn call(&mut self, _req: Request) -> Self::Future { + future::ok( + http::Response::builder() + .status(200) + .header("grpc-status", "12") + .body(BoxBody::empty()) + .unwrap(), + ) + } +} diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 3ee2cba5b..1f3c7e9cd 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -5,6 +5,7 @@ mod discover; mod either; mod io; mod layer; +mod router; #[cfg(feature = "tls")] mod tls; @@ -14,5 +15,6 @@ pub(crate) use self::connector::connector; pub(crate) use self::discover::ServiceList; pub(crate) use self::io::BoxedIo; pub(crate) use self::layer::{layer_fn, ServiceBuilderExt}; +pub(crate) use self::router::{Or, Routes}; #[cfg(feature = "tls")] pub(crate) use self::tls::{TlsAcceptor, TlsConnector}; diff --git a/tonic/src/transport/service/router.rs b/tonic/src/transport/service/router.rs new file mode 100644 index 000000000..484a9a670 --- /dev/null +++ b/tonic/src/transport/service/router.rs @@ -0,0 +1,129 @@ +use futures_util::{ + future::Either, + try_future::{MapErr, TryFutureExt}, +}; +use std::{ + fmt, + sync::Arc, + task::{Context, Poll}, +}; +use tower_service::Service; + +#[derive(Debug)] +pub(crate) struct Routes { + routes: Or, +} + +impl Routes { + pub(crate) fn new( + predicate: impl Fn(&Request) -> bool + Send + Sync + 'static, + a: A, + b: B, + ) -> Self { + let routes = Or::new(predicate, a, b); + Self { routes } + } +} + +impl Routes { + pub(crate) fn push( + self, + predicate: impl Fn(&Request) -> bool + Send + Sync + 'static, + route: C, + ) -> Routes, Request> { + let routes = Or::new(predicate, route, self.routes); + Routes { routes } + } +} + +impl Service for Routes +where + A: Service, + A::Future: Send + 'static, + A::Error: Into, + B: Service, + B::Future: Send + 'static, + B::Error: Into, +{ + type Response = A::Response; + type Error = crate::Error; + type Future = as Service>::Future; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Ok(()).into() + } + + fn call(&mut self, req: Request) -> Self::Future { + self.routes.call(req) + } +} + +impl Clone for Routes { + fn clone(&self) -> Self { + Self { + routes: self.routes.clone(), + } + } +} + +#[doc(hidden)] +pub struct Or { + predicate: Arc bool + Send + Sync + 'static>, + a: A, + b: B, +} + +impl Or { + pub(crate) fn new(predicate: F, a: A, b: B) -> Self + where + F: Fn(&Request) -> bool + Send + Sync + 'static, + { + let predicate = Arc::new(predicate); + Self { predicate, a, b } + } +} + +impl Service for Or +where + A: Service, + A::Future: Send + 'static, + A::Error: Into, + B: Service, + B::Future: Send + 'static, + B::Error: Into, +{ + type Response = A::Response; + type Error = crate::Error; + type Future = Either< + MapErr crate::Error>, + MapErr crate::Error>, + >; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Ok(()).into() + } + + fn call(&mut self, req: Request) -> Self::Future { + if (self.predicate)(&req) { + Either::Left(self.a.call(req).map_err(|e| e.into())) + } else { + Either::Right(self.b.call(req).map_err(|e| e.into())) + } + } +} + +impl Clone for Or { + fn clone(&self) -> Self { + Self { + predicate: self.predicate.clone(), + a: self.a.clone(), + b: self.b.clone(), + } + } +} + +impl fmt::Debug for Or { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Or {{ .. }}") + } +} diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index 22f4cb44e..2c8532be3 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -14,7 +14,6 @@ use tokio_rustls::{ webpki::DNSNameRef, TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector, }; -use tracing::trace; /// h2 alpn in wire format for openssl. #[cfg(feature = "openssl")] @@ -137,7 +136,7 @@ impl TlsConnector { let tls = tokio_openssl::connect(config, &self.domain, io).await?; match tls.ssl().selected_alpn_protocol() { - Some(b) if b == b"h2" => trace!("HTTP/2 succesfully negotiated."), + Some(b) if b == b"h2" => tracing::trace!("HTTP/2 succesfully negotiated."), _ => return Err(TlsError::H2NotNegotiated.into()), };