From 12e8c6219d76ad0ee19f02758f734bdd4fb87c0e Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 29 Dec 2023 12:17:58 +0100 Subject: [PATCH] Support graceful shutdown on `serve` (#2398) --- .../fail/argument_not_extractor.stderr | 2 +- .../fail/parts_extracting_body.stderr | 2 +- axum/CHANGELOG.md | 2 + axum/Cargo.toml | 6 +- axum/src/macros.rs | 12 + axum/src/serve.rs | 249 +++++++++++++++--- examples/graceful-shutdown/Cargo.toml | 2 +- examples/graceful-shutdown/src/main.rs | 113 +------- 8 files changed, 249 insertions(+), 139 deletions(-) diff --git a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr index f33529e1a5..d946782586 100644 --- a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr +++ b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr @@ -13,8 +13,8 @@ error[E0277]: the trait bound `bool: FromRequestParts<()>` is not satisfied > > > - as FromRequestParts> > + as FromRequestParts> and $N others = note: required for `bool` to implement `FromRequest<(), axum_core::extract::private::ViaParts>` note: required by a bound in `__axum_macros_check_handler_0_from_request_check` diff --git a/axum-macros/tests/from_request/fail/parts_extracting_body.stderr b/axum-macros/tests/from_request/fail/parts_extracting_body.stderr index 3fb190a2a0..fbd58ea013 100644 --- a/axum-macros/tests/from_request/fail/parts_extracting_body.stderr +++ b/axum-macros/tests/from_request/fail/parts_extracting_body.stderr @@ -14,5 +14,5 @@ error[E0277]: the trait bound `String: FromRequestParts` is not satisfied > > > - as FromRequestParts> + > and $N others diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 1eb5499543..61bc8d8a9b 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -10,10 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** `Body` implements `From<()>` now ([#2411]) - **change:** Update version of multer used internally for multipart ([#2433]) - **change:** Update tokio-tungstenite to 0.21 ([#2435]) +- **added:** Support graceful shutdown on `serve` ([#2398]) [#2411]: https://github.com/tokio-rs/axum/pull/2411 [#2433]: https://github.com/tokio-rs/axum/pull/2433 [#2435]: https://github.com/tokio-rs/axum/pull/2435 +[#2398]: https://github.com/tokio-rs/axum/pull/2398 # 0.7.2 (03. December, 2023) diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 3001c97f60..2e341db924 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -22,7 +22,7 @@ matched-path = [] multipart = ["dep:multer"] original-uri = [] query = ["dep:serde_urlencoded"] -tokio = ["dep:hyper-util", "dep:tokio", "tokio/net", "tokio/rt", "tower/make"] +tokio = ["dep:hyper-util", "dep:tokio", "tokio/net", "tokio/rt", "tower/make", "tokio/macros"] tower-log = ["tower/log"] tracing = ["dep:tracing", "axum-core/tracing"] ws = ["dep:hyper", "tokio", "dep:tokio-tungstenite", "dep:sha1", "dep:base64"] @@ -53,8 +53,8 @@ tower-service = "0.3" # optional dependencies axum-macros = { path = "../axum-macros", version = "0.4.0", optional = true } base64 = { version = "0.21.0", optional = true } -hyper = { version = "1.0.0", optional = true } -hyper-util = { version = "0.1.1", features = ["tokio", "server", "server-auto"], optional = true } +hyper = { version = "1.1.0", optional = true } +hyper-util = { version = "0.1.2", features = ["tokio", "server", "server-auto"], optional = true } multer = { version = "3.0.0", optional = true } serde_json = { version = "1.0", features = ["raw_value"], optional = true } serde_path_to_error = { version = "0.1.8", optional = true } diff --git a/axum/src/macros.rs b/axum/src/macros.rs index 7024f77998..5b8a335ef4 100644 --- a/axum/src/macros.rs +++ b/axum/src/macros.rs @@ -67,6 +67,13 @@ macro_rules! all_the_tuples { }; } +#[cfg(feature = "tracing")] +macro_rules! trace { + ($($tt:tt)*) => { + tracing::trace!($($tt)*) + } +} + #[cfg(feature = "tracing")] macro_rules! error { ($($tt:tt)*) => { @@ -74,6 +81,11 @@ macro_rules! error { }; } +#[cfg(not(feature = "tracing"))] +macro_rules! trace { + ($($tt:tt)*) => {}; +} + #[cfg(not(feature = "tracing"))] macro_rules! error { ($($tt:tt)*) => {}; diff --git a/axum/src/serve.rs b/axum/src/serve.rs index f9b48c41e2..9850af2787 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -2,24 +2,29 @@ use std::{ convert::Infallible, - future::{Future, IntoFuture}, + fmt::Debug, + future::{poll_fn, Future, IntoFuture}, io, marker::PhantomData, net::SocketAddr, pin::Pin, + sync::Arc, task::{Context, Poll}, time::Duration, }; use axum_core::{body::Body, extract::Request, response::Response}; -use futures_util::future::poll_fn; +use futures_util::{pin_mut, FutureExt}; use hyper::body::Incoming; use hyper_util::{ rt::{TokioExecutor, TokioIo}, server::conn::auto::Builder, }; use pin_project_lite::pin_project; -use tokio::net::{TcpListener, TcpStream}; +use tokio::{ + net::{TcpListener, TcpStream}, + sync::watch, +}; use tower::util::{Oneshot, ServiceExt}; use tower_service::Service; @@ -110,9 +115,45 @@ pub struct Serve { } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl std::fmt::Debug for Serve +impl Serve { + /// Prepares a server to handle graceful shutdown when the provided future completes. + /// + /// # Example + /// + /// ``` + /// use axum::{Router, routing::get}; + /// + /// # async { + /// let router = Router::new().route("/", get(|| async { "Hello, World!" })); + /// + /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + /// axum::serve(listener, router) + /// .with_graceful_shutdown(shutdown_signal()) + /// .await + /// .unwrap(); + /// # }; + /// + /// async fn shutdown_signal() { + /// // ... + /// } + /// ``` + pub fn with_graceful_shutdown(self, signal: F) -> WithGracefulShutdown + where + F: Future + Send + 'static, + { + WithGracefulShutdown { + tcp_listener: self.tcp_listener, + make_service: self.make_service, + signal, + _marker: PhantomData, + } + } +} + +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +impl Debug for Serve where - M: std::fmt::Debug, + M: Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let Self { @@ -148,30 +189,9 @@ where } = self; loop { - let (tcp_stream, remote_addr) = match tcp_listener.accept().await { - Ok(conn) => conn, - Err(e) => { - // Connection errors can be ignored directly, continue - // by accepting the next request. - if is_connection_error(&e) { - continue; - } - - // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186) - // - // > A possible scenario is that the process has hit the max open files - // > allowed, and so trying to accept a new connection will fail with - // > `EMFILE`. In some cases, it's preferable to just wait for some time, if - // > the application will likely close some files (or connections), and try - // > to accept the connection again. If this option is `true`, the error - // > will be logged at the `error` level, since it is still a big deal, - // > and then the listener will sleep for 1 second. - // - // hyper allowed customizing this but axum does not. - error!("accept error: {e}"); - tokio::time::sleep(Duration::from_secs(1)).await; - continue; - } + let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await { + Some(conn) => conn, + None => continue, }; let tcp_stream = TokioIo::new(tcp_stream); @@ -191,7 +211,7 @@ where service: tower_service, }; - tokio::task::spawn(async move { + tokio::spawn(async move { match Builder::new(TokioExecutor::new()) // upgrades needed for websockets .serve_connection_with_upgrades(tcp_stream, hyper_service) @@ -212,6 +232,149 @@ where } } +/// Serve future with graceful shutdown enabled. +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +pub struct WithGracefulShutdown { + tcp_listener: TcpListener, + make_service: M, + signal: F, + _marker: PhantomData, +} + +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +impl Debug for WithGracefulShutdown +where + M: Debug, + S: Debug, + F: Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let Self { + tcp_listener, + make_service, + signal, + _marker: _, + } = self; + + f.debug_struct("WithGracefulShutdown") + .field("tcp_listener", tcp_listener) + .field("make_service", make_service) + .field("signal", signal) + .finish() + } +} + +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +impl IntoFuture for WithGracefulShutdown +where + M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, + for<'a> >>::Future: Send, + S: Service + Clone + Send + 'static, + S::Future: Send, + F: Future + Send + 'static, +{ + type Output = io::Result<()>; + type IntoFuture = private::ServeFuture; + + fn into_future(self) -> Self::IntoFuture { + let Self { + tcp_listener, + mut make_service, + signal, + _marker: _, + } = self; + + let (signal_tx, signal_rx) = watch::channel(()); + let signal_tx = Arc::new(signal_tx); + tokio::spawn(async move { + signal.await; + trace!("received graceful shutdown signal. Telling tasks to shutdown"); + drop(signal_rx); + }); + + let (close_tx, close_rx) = watch::channel(()); + + private::ServeFuture(Box::pin(async move { + loop { + let (tcp_stream, remote_addr) = tokio::select! { + conn = tcp_accept(&tcp_listener) => { + match conn { + Some(conn) => conn, + None => continue, + } + } + _ = signal_tx.closed() => { + trace!("signal received, not accepting new connections"); + break; + } + }; + let tcp_stream = TokioIo::new(tcp_stream); + + trace!("connection {remote_addr} accepted"); + + poll_fn(|cx| make_service.poll_ready(cx)) + .await + .unwrap_or_else(|err| match err {}); + + let tower_service = make_service + .call(IncomingStream { + tcp_stream: &tcp_stream, + remote_addr, + }) + .await + .unwrap_or_else(|err| match err {}); + + let hyper_service = TowerToHyperService { + service: tower_service, + }; + + let signal_tx = Arc::clone(&signal_tx); + + let close_rx = close_rx.clone(); + + tokio::spawn(async move { + let builder = Builder::new(TokioExecutor::new()); + let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service); + pin_mut!(conn); + + let signal_closed = signal_tx.closed().fuse(); + pin_mut!(signal_closed); + + loop { + tokio::select! { + result = conn.as_mut() => { + if let Err(_err) = result { + trace!("failed to serve connection: {_err:#}"); + } + break; + } + _ = &mut signal_closed => { + trace!("signal received in task, starting graceful shutdown"); + conn.as_mut().graceful_shutdown(); + } + } + } + + trace!("connection {remote_addr} closed"); + + drop(close_rx); + }); + } + + drop(close_rx); + drop(tcp_listener); + + trace!( + "waiting for {} task(s) to finish", + close_tx.receiver_count() + ); + close_tx.closed().await; + + Ok(()) + })) + } +} + fn is_connection_error(e: &io::Error) -> bool { matches!( e.kind(), @@ -221,6 +384,32 @@ fn is_connection_error(e: &io::Error) -> bool { ) } +async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> { + match listener.accept().await { + Ok(conn) => Some(conn), + Err(e) => { + if is_connection_error(&e) { + return None; + } + + // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186) + // + // > A possible scenario is that the process has hit the max open files + // > allowed, and so trying to accept a new connection will fail with + // > `EMFILE`. In some cases, it's preferable to just wait for some time, if + // > the application will likely close some files (or connections), and try + // > to accept the connection again. If this option is `true`, the error + // > will be logged at the `error` level, since it is still a big deal, + // > and then the listener will sleep for 1 second. + // + // hyper allowed customizing this but axum does not. + error!("accept error: {e}"); + tokio::time::sleep(Duration::from_secs(1)).await; + None + } + } +} + mod private { use std::{ future::Future, diff --git a/examples/graceful-shutdown/Cargo.toml b/examples/graceful-shutdown/Cargo.toml index fd9c09e35f..86dfd52763 100644 --- a/examples/graceful-shutdown/Cargo.toml +++ b/examples/graceful-shutdown/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" publish = false [dependencies] -axum = { path = "../../axum" } +axum = { path = "../../axum", features = ["tracing"] } hyper = { version = "1.0", features = [] } hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] } tokio = { version = "1.0", features = ["full"] } diff --git a/examples/graceful-shutdown/src/main.rs b/examples/graceful-shutdown/src/main.rs index 8932b54bee..984330715e 100644 --- a/examples/graceful-shutdown/src/main.rs +++ b/examples/graceful-shutdown/src/main.rs @@ -10,17 +10,12 @@ use std::time::Duration; -use axum::{extract::Request, routing::get, Router}; -use hyper::body::Incoming; -use hyper_util::rt::TokioIo; +use axum::{routing::get, Router}; use tokio::net::TcpListener; use tokio::signal; -use tokio::sync::watch; use tokio::time::sleep; -use tower::Service; use tower_http::timeout::TimeoutLayer; use tower_http::trace::TraceLayer; -use tracing::debug; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] @@ -28,10 +23,11 @@ async fn main() { // Enable tracing. tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_graceful_shutdown=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + "example_graceful_shutdown=debug,tower_http=debug,axum=trace".into() + }), ) - .with(tracing_subscriber::fmt::layer()) + .with(tracing_subscriber::fmt::layer().without_time()) .init(); // Create a regular axum app. @@ -48,100 +44,11 @@ async fn main() { // Create a `TcpListener` using tokio. let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap(); - // Create a watch channel to track tasks that are handling connections and wait for them to - // complete. - let (close_tx, close_rx) = watch::channel(()); - - // Continuously accept new connections. - loop { - let (socket, remote_addr) = tokio::select! { - // Either accept a new connection... - result = listener.accept() => { - result.unwrap() - } - // ...or wait to receive a shutdown signal and stop the accept loop. - _ = shutdown_signal() => { - debug!("signal received, not accepting new connections"); - break; - } - }; - - debug!("connection {remote_addr} accepted"); - - // We don't need to call `poll_ready` because `Router` is always ready. - let tower_service = app.clone(); - - // Clone the watch receiver and move it into the task. - let close_rx = close_rx.clone(); - - // Spawn a task to handle the connection. That way we can serve multiple connections - // concurrently. - tokio::spawn(async move { - // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio. - // `TokioIo` converts between them. - let socket = TokioIo::new(socket); - - // Hyper also has its own `Service` trait and doesn't use tower. We can use - // `hyper::service::service_fn` to create a hyper `Service` that calls our app through - // `tower::Service::call`. - let hyper_service = hyper::service::service_fn(move |request: Request| { - // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas - // tower's `Service` requires `&mut self`. - // - // We don't need to call `poll_ready` since `Router` is always ready. - tower_service.clone().call(request) - }); - - // `hyper_util::server::conn::auto::Builder` supports both http1 and http2 but doesn't - // support graceful so we have to use hyper directly and unfortunately pick between - // http1 and http2. - let conn = hyper::server::conn::http1::Builder::new() - .serve_connection(socket, hyper_service) - // `with_upgrades` is required for websockets. - .with_upgrades(); - - // `graceful_shutdown` requires a pinned connection. - let mut conn = std::pin::pin!(conn); - - loop { - tokio::select! { - // Poll the connection. This completes when the client has closed the - // connection, graceful shutdown has completed, or we encounter a TCP error. - result = conn.as_mut() => { - if let Err(err) = result { - debug!("failed to serve connection: {err:#}"); - } - break; - } - // Start graceful shutdown when we receive a shutdown signal. - // - // We use a loop to continue polling the connection to allow requests to finish - // after starting graceful shutdown. Our `Router` has `TimeoutLayer` so - // requests will finish after at most 10 seconds. - _ = shutdown_signal() => { - debug!("signal received, starting graceful shutdown"); - conn.as_mut().graceful_shutdown(); - } - } - } - - debug!("connection {remote_addr} closed"); - - // Drop the watch receiver to signal to `main` that this task is done. - drop(close_rx); - }); - } - - // We only care about the watch receivers that were moved into the tasks so close the residual - // receiver. - drop(close_rx); - - // Close the listener to stop accepting new connections. - drop(listener); - - // Wait for all tasks to complete. - debug!("waiting for {} tasks to finish", close_tx.receiver_count()); - close_tx.closed().await; + // Run the server with graceful shutdown + axum::serve(listener, app) + .with_graceful_shutdown(shutdown_signal()) + .await + .unwrap(); } async fn shutdown_signal() {