Skip to content

Commit

Permalink
refactor: tracker_api launching and testing
Browse files Browse the repository at this point in the history
  • Loading branch information
mickvandijke authored and josecelano committed Mar 9, 2023
1 parent 5b95b5d commit d020c5a
Show file tree
Hide file tree
Showing 10 changed files with 292 additions and 271 deletions.
2 changes: 1 addition & 1 deletion src/apis/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use super::handlers::{
use super::middlewares::auth::auth;
use crate::tracker::Tracker;

pub fn router(tracker: &Arc<Tracker>) -> Router {
pub fn router(tracker: Arc<Tracker>) -> Router {
Router::new()
// Stats
.route("/api/stats", get(get_stats_handler).with_state(tracker.clone()))
Expand Down
212 changes: 115 additions & 97 deletions src/apis/server.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use std::net::{SocketAddr, TcpListener};
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;

use axum_server::tls_rustls::RustlsConfig;
use axum_server::Handle;
use futures::future::BoxFuture;
use futures::Future;
use log::info;
use tokio::task::JoinHandle;
use warp::hyper;

use super::routes::router;
use crate::signals::shutdown_signal_with_message;
use crate::signals::shutdown_signal;
use crate::tracker::Tracker;

#[derive(Debug)]
Expand All @@ -25,133 +26,150 @@ pub type RunningApiServer = ApiServer<Running>;
#[allow(clippy::module_name_repetitions)]
pub struct ApiServer<S> {
pub cfg: torrust_tracker_configuration::HttpApi,
pub tracker: Arc<Tracker>,
pub state: S,
}

pub struct Stopped;

pub struct Running {
pub bind_address: SocketAddr,
stop_job_sender: tokio::sync::oneshot::Sender<u8>,
job: JoinHandle<()>,
pub bind_addr: SocketAddr,
task_killer: tokio::sync::oneshot::Sender<u8>,
task: tokio::task::JoinHandle<()>,
}

impl ApiServer<Stopped> {
pub fn new(cfg: torrust_tracker_configuration::HttpApi, tracker: Arc<Tracker>) -> Self {
Self {
cfg,
tracker,
state: Stopped {},
}
pub fn new(cfg: torrust_tracker_configuration::HttpApi) -> Self {
Self { cfg, state: Stopped {} }
}

/// # Errors
///
/// Will return `Err` if `TcpListener` can not bind to `bind_address`.
pub fn start(self) -> Result<ApiServer<Running>, Error> {
let listener = TcpListener::bind(&self.cfg.bind_address).map_err(|e| Error::Error(e.to_string()))?;

let bind_address = listener.local_addr().map_err(|e| Error::Error(e.to_string()))?;

let cfg = self.cfg.clone();
let tracker = self.tracker.clone();

let (sender, receiver) = tokio::sync::oneshot::channel::<u8>();

let job = tokio::spawn(async move {
if let (true, Some(ssl_cert_path), Some(ssl_key_path)) = (cfg.ssl_enabled, cfg.ssl_cert_path, cfg.ssl_key_path) {
let tls_config = RustlsConfig::from_pem_file(ssl_cert_path, ssl_key_path)
.await
.expect("Could not read ssl cert and/or key.");

start_tls_from_tcp_listener_with_graceful_shutdown(listener, tls_config, &tracker, receiver)
.await
.expect("Could not start from tcp listener with tls.");
} else {
start_from_tcp_listener_with_graceful_shutdown(listener, &tracker, receiver)
.await
.expect("Could not start from tcp listener.");
}
pub async fn start(self, tracker: Arc<Tracker>) -> Result<ApiServer<Running>, Error> {
let (shutdown_sender, shutdown_receiver) = tokio::sync::oneshot::channel::<u8>();
let (addr_sender, addr_receiver) = tokio::sync::oneshot::channel::<SocketAddr>();

let configuration = self.cfg.clone();

let task = tokio::spawn(async move {
let (bind_addr, server) = Launcher::start(&configuration, tracker, shutdown_signal(shutdown_receiver));

addr_sender.send(bind_addr).unwrap();

server.await;
});

let running_api_server: ApiServer<Running> = ApiServer {
let bind_address = addr_receiver.await.expect("Could not receive bind_address.");

Ok(ApiServer {
cfg: self.cfg,
tracker: self.tracker,
state: Running {
bind_address,
stop_job_sender: sender,
job,
bind_addr: bind_address,
task_killer: shutdown_sender,
task,
},
};

Ok(running_api_server)
})
}
}

impl ApiServer<Running> {
/// # Errors
///
/// Will return `Err` if the oneshot channel to send the stop signal
/// has already been called once.
pub async fn stop(self) -> Result<ApiServer<Stopped>, Error> {
self.state.stop_job_sender.send(1).map_err(|e| Error::Error(e.to_string()))?;
self.state.task_killer.send(0).unwrap();

let _ = self.state.job.await;
let _ = self.state.task.await;

let stopped_api_server: ApiServer<Stopped> = ApiServer {
Ok(ApiServer {
cfg: self.cfg,
tracker: self.tracker,
state: Stopped {},
};

Ok(stopped_api_server)
})
}
}

pub fn start_from_tcp_listener_with_graceful_shutdown(
tcp_listener: TcpListener,
tracker: &Arc<Tracker>,
shutdown_signal: tokio::sync::oneshot::Receiver<u8>,
) -> impl Future<Output = hyper::Result<()>> {
let app = router(tracker);

let context = tcp_listener.local_addr().expect("Could not get context.");

axum::Server::from_tcp(tcp_listener)
.expect("Could not bind to tcp listener.")
.serve(app.into_make_service())
.with_graceful_shutdown(shutdown_signal_with_message(
shutdown_signal,
format!("Shutting down {context}.."),
))
}
struct Launcher;

impl Launcher {
pub fn start<F>(
cfg: &torrust_tracker_configuration::HttpApi,
tracker: Arc<Tracker>,
shutdown_signal: F,
) -> (SocketAddr, BoxFuture<'static, ()>)
where
F: Future<Output = ()> + Send + 'static,
{
let addr = SocketAddr::from_str(&cfg.bind_address).expect("bind_address is not a valid SocketAddr.");
let tcp_listener = std::net::TcpListener::bind(addr).expect("Could not bind tcp_listener to address.");
let bind_addr = tcp_listener
.local_addr()
.expect("Could not get local_addr from tcp_listener.");

if let (true, Some(ssl_cert_path), Some(ssl_key_path)) = (&cfg.ssl_enabled, &cfg.ssl_cert_path, &cfg.ssl_key_path) {
let server = Self::start_tls_with_graceful_shutdown(
tcp_listener,
(ssl_cert_path.to_string(), ssl_key_path.to_string()),
tracker,
shutdown_signal,
);

(bind_addr, server)
} else {
let server = Self::start_with_graceful_shutdown(tcp_listener, tracker, shutdown_signal);

(bind_addr, server)
}
}

pub fn start_tls_from_tcp_listener_with_graceful_shutdown(
tcp_listener: TcpListener,
tls_config: RustlsConfig,
tracker: &Arc<Tracker>,
shutdown_signal: tokio::sync::oneshot::Receiver<u8>,
) -> impl Future<Output = Result<(), std::io::Error>> {
let app = router(tracker);
pub fn start_with_graceful_shutdown<F>(
tcp_listener: std::net::TcpListener,
tracker: Arc<Tracker>,
shutdown_signal: F,
) -> BoxFuture<'static, ()>
where
F: Future<Output = ()> + Send + 'static,
{
let app = router(tracker);

Box::pin(async {
axum::Server::from_tcp(tcp_listener)
.expect("Could not bind to tcp listener.")
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.with_graceful_shutdown(shutdown_signal)
.await
.expect("Axum server crashed.");
})
}

let context = tcp_listener.local_addr().expect("Could not get context.");
pub fn start_tls_with_graceful_shutdown<F>(
tcp_listener: std::net::TcpListener,
(ssl_cert_path, ssl_key_path): (String, String),
tracker: Arc<Tracker>,
shutdown_signal: F,
) -> BoxFuture<'static, ()>
where
F: Future<Output = ()> + Send + 'static,
{
let app = router(tracker);

let handle = Handle::new();
let handle = Handle::new();

let cloned_handle = handle.clone();
let cloned_handle = handle.clone();

tokio::spawn(async move {
shutdown_signal_with_message(shutdown_signal, format!("Shutting down {context}..")).await;
cloned_handle.shutdown();
});
tokio::task::spawn_local(async move {
shutdown_signal.await;
cloned_handle.shutdown();
});

axum_server::from_tcp_rustls(tcp_listener, tls_config)
.handle(handle)
.serve(app.into_make_service())
Box::pin(async {
let tls_config = RustlsConfig::from_pem_file(ssl_cert_path, ssl_key_path)
.await
.expect("Could not read tls cert.");

axum_server::from_tcp_rustls(tcp_listener, tls_config)
.handle(handle)
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await
.expect("Axum server crashed.");
})
}
}

pub fn start(socket_addr: SocketAddr, tracker: &Arc<Tracker>) -> impl Future<Output = hyper::Result<()>> {
pub fn start(socket_addr: SocketAddr, tracker: Arc<Tracker>) -> impl Future<Output = hyper::Result<()>> {
let app = router(tracker);

let server = axum::Server::bind(&socket_addr).serve(app.into_make_service());
Expand All @@ -165,7 +183,7 @@ pub fn start(socket_addr: SocketAddr, tracker: &Arc<Tracker>) -> impl Future<Out
pub fn start_tls(
socket_addr: SocketAddr,
ssl_config: RustlsConfig,
tracker: &Arc<Tracker>,
tracker: Arc<Tracker>,
) -> impl Future<Output = Result<(), std::io::Error>> {
let app = router(tracker);

Expand Down Expand Up @@ -204,9 +222,9 @@ mod tests {

let tracker = Arc::new(tracker::Tracker::new(cfg.clone(), None, statistics::Repo::new()).unwrap());

let stopped_api_server = ApiServer::new(cfg.http_api.clone(), tracker);
let stopped_api_server = ApiServer::new(cfg.http_api.clone());

let running_api_server_result = stopped_api_server.start();
let running_api_server_result = stopped_api_server.start(tracker).await;

assert!(running_api_server_result.is_ok());

Expand Down
4 changes: 2 additions & 2 deletions src/jobs/tracker_apis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub async fn start_job(config: &HttpApi, tracker: Arc<tracker::Tracker>) -> Join
if !ssl_enabled {
info!("Starting Torrust APIs server on: http://{}", bind_addr);

let handle = server::start(bind_addr, &tracker);
let handle = server::start(bind_addr, tracker);

tx.send(ApiServerJobStarted()).expect("the API server should not be dropped");

Expand All @@ -45,7 +45,7 @@ pub async fn start_job(config: &HttpApi, tracker: Arc<tracker::Tracker>) -> Join
.await
.unwrap();

let handle = server::start_tls(bind_addr, ssl_config, &tracker);
let handle = server::start_tls(bind_addr, ssl_config, tracker);

tx.send(ApiServerJobStarted()).expect("the API server should not be dropped");

Expand Down
4 changes: 2 additions & 2 deletions src/jobs/udp_tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ pub fn start_job(config: &UdpTracker, tracker: Arc<tracker::Tracker>) -> JoinHan
let bind_addr = config.bind_address.clone();

tokio::spawn(async move {
match Udp::new(tracker, &bind_addr).await {
match Udp::new(&bind_addr).await {
Ok(udp_server) => {
info!("Starting UDP server on: udp://{}", bind_addr);
udp_server.start().await;
udp_server.start(tracker).await;
}
Err(e) => {
warn!("Could not start UDP tracker on: udp://{}", bind_addr);
Expand Down
Loading

0 comments on commit d020c5a

Please sign in to comment.