diff --git a/src/apis/routes.rs b/src/apis/routes.rs index 281979aa..6e321860 100644 --- a/src/apis/routes.rs +++ b/src/apis/routes.rs @@ -10,7 +10,7 @@ use super::handlers::{ use super::middlewares::auth::auth; use crate::tracker::Tracker; -pub fn router(tracker: &Arc) -> Router { +pub fn router(tracker: Arc) -> Router { Router::new() // Stats .route("/api/stats", get(get_stats_handler).with_state(tracker.clone())) diff --git a/src/apis/server.rs b/src/apis/server.rs index 4c8fbaad..0a501316 100644 --- a/src/apis/server.rs +++ b/src/apis/server.rs @@ -1,15 +1,16 @@ -use std::net::{SocketAddr, TcpListener}; +use std::net::SocketAddr; +use std::str::FromStr; use std::sync::Arc; use axum_server::tls_rustls::RustlsConfig; use axum_server::Handle; +use futures::future::BoxFuture; use futures::Future; use log::info; -use tokio::task::JoinHandle; use warp::hyper; use super::routes::router; -use crate::signals::shutdown_signal_with_message; +use crate::signals::shutdown_signal; use crate::tracker::Tracker; #[derive(Debug)] @@ -25,133 +26,150 @@ pub type RunningApiServer = ApiServer; #[allow(clippy::module_name_repetitions)] pub struct ApiServer { pub cfg: torrust_tracker_configuration::HttpApi, - pub tracker: Arc, pub state: S, } pub struct Stopped; pub struct Running { - pub bind_address: SocketAddr, - stop_job_sender: tokio::sync::oneshot::Sender, - job: JoinHandle<()>, + pub bind_addr: SocketAddr, + task_killer: tokio::sync::oneshot::Sender, + task: tokio::task::JoinHandle<()>, } impl ApiServer { - pub fn new(cfg: torrust_tracker_configuration::HttpApi, tracker: Arc) -> Self { - Self { - cfg, - tracker, - state: Stopped {}, - } + pub fn new(cfg: torrust_tracker_configuration::HttpApi) -> Self { + Self { cfg, state: Stopped {} } } - /// # Errors - /// - /// Will return `Err` if `TcpListener` can not bind to `bind_address`. - pub fn start(self) -> Result, Error> { - let listener = TcpListener::bind(&self.cfg.bind_address).map_err(|e| Error::Error(e.to_string()))?; - - let bind_address = listener.local_addr().map_err(|e| Error::Error(e.to_string()))?; - - let cfg = self.cfg.clone(); - let tracker = self.tracker.clone(); - - let (sender, receiver) = tokio::sync::oneshot::channel::(); - - let job = tokio::spawn(async move { - if let (true, Some(ssl_cert_path), Some(ssl_key_path)) = (cfg.ssl_enabled, cfg.ssl_cert_path, cfg.ssl_key_path) { - let tls_config = RustlsConfig::from_pem_file(ssl_cert_path, ssl_key_path) - .await - .expect("Could not read ssl cert and/or key."); - - start_tls_from_tcp_listener_with_graceful_shutdown(listener, tls_config, &tracker, receiver) - .await - .expect("Could not start from tcp listener with tls."); - } else { - start_from_tcp_listener_with_graceful_shutdown(listener, &tracker, receiver) - .await - .expect("Could not start from tcp listener."); - } + pub async fn start(self, tracker: Arc) -> Result, Error> { + let (shutdown_sender, shutdown_receiver) = tokio::sync::oneshot::channel::(); + let (addr_sender, addr_receiver) = tokio::sync::oneshot::channel::(); + + let configuration = self.cfg.clone(); + + let task = tokio::spawn(async move { + let (bind_addr, server) = Launcher::start(&configuration, tracker, shutdown_signal(shutdown_receiver)); + + addr_sender.send(bind_addr).unwrap(); + + server.await; }); - let running_api_server: ApiServer = ApiServer { + let bind_address = addr_receiver.await.expect("Could not receive bind_address."); + + Ok(ApiServer { cfg: self.cfg, - tracker: self.tracker, state: Running { - bind_address, - stop_job_sender: sender, - job, + bind_addr: bind_address, + task_killer: shutdown_sender, + task, }, - }; - - Ok(running_api_server) + }) } } impl ApiServer { - /// # Errors - /// - /// Will return `Err` if the oneshot channel to send the stop signal - /// has already been called once. pub async fn stop(self) -> Result, Error> { - self.state.stop_job_sender.send(1).map_err(|e| Error::Error(e.to_string()))?; + self.state.task_killer.send(0).unwrap(); - let _ = self.state.job.await; + let _ = self.state.task.await; - let stopped_api_server: ApiServer = ApiServer { + Ok(ApiServer { cfg: self.cfg, - tracker: self.tracker, state: Stopped {}, - }; - - Ok(stopped_api_server) + }) } } -pub fn start_from_tcp_listener_with_graceful_shutdown( - tcp_listener: TcpListener, - tracker: &Arc, - shutdown_signal: tokio::sync::oneshot::Receiver, -) -> impl Future> { - let app = router(tracker); - - let context = tcp_listener.local_addr().expect("Could not get context."); - - axum::Server::from_tcp(tcp_listener) - .expect("Could not bind to tcp listener.") - .serve(app.into_make_service()) - .with_graceful_shutdown(shutdown_signal_with_message( - shutdown_signal, - format!("Shutting down {context}.."), - )) -} +struct Launcher; + +impl Launcher { + pub fn start( + cfg: &torrust_tracker_configuration::HttpApi, + tracker: Arc, + shutdown_signal: F, + ) -> (SocketAddr, BoxFuture<'static, ()>) + where + F: Future + Send + 'static, + { + let addr = SocketAddr::from_str(&cfg.bind_address).expect("bind_address is not a valid SocketAddr."); + let tcp_listener = std::net::TcpListener::bind(addr).expect("Could not bind tcp_listener to address."); + let bind_addr = tcp_listener + .local_addr() + .expect("Could not get local_addr from tcp_listener."); + + if let (true, Some(ssl_cert_path), Some(ssl_key_path)) = (&cfg.ssl_enabled, &cfg.ssl_cert_path, &cfg.ssl_key_path) { + let server = Self::start_tls_with_graceful_shutdown( + tcp_listener, + (ssl_cert_path.to_string(), ssl_key_path.to_string()), + tracker, + shutdown_signal, + ); + + (bind_addr, server) + } else { + let server = Self::start_with_graceful_shutdown(tcp_listener, tracker, shutdown_signal); + + (bind_addr, server) + } + } -pub fn start_tls_from_tcp_listener_with_graceful_shutdown( - tcp_listener: TcpListener, - tls_config: RustlsConfig, - tracker: &Arc, - shutdown_signal: tokio::sync::oneshot::Receiver, -) -> impl Future> { - let app = router(tracker); + pub fn start_with_graceful_shutdown( + tcp_listener: std::net::TcpListener, + tracker: Arc, + shutdown_signal: F, + ) -> BoxFuture<'static, ()> + where + F: Future + Send + 'static, + { + let app = router(tracker); + + Box::pin(async { + axum::Server::from_tcp(tcp_listener) + .expect("Could not bind to tcp listener.") + .serve(app.into_make_service_with_connect_info::()) + .with_graceful_shutdown(shutdown_signal) + .await + .expect("Axum server crashed."); + }) + } - let context = tcp_listener.local_addr().expect("Could not get context."); + pub fn start_tls_with_graceful_shutdown( + tcp_listener: std::net::TcpListener, + (ssl_cert_path, ssl_key_path): (String, String), + tracker: Arc, + shutdown_signal: F, + ) -> BoxFuture<'static, ()> + where + F: Future + Send + 'static, + { + let app = router(tracker); - let handle = Handle::new(); + let handle = Handle::new(); - let cloned_handle = handle.clone(); + let cloned_handle = handle.clone(); - tokio::spawn(async move { - shutdown_signal_with_message(shutdown_signal, format!("Shutting down {context}..")).await; - cloned_handle.shutdown(); - }); + tokio::task::spawn_local(async move { + shutdown_signal.await; + cloned_handle.shutdown(); + }); - axum_server::from_tcp_rustls(tcp_listener, tls_config) - .handle(handle) - .serve(app.into_make_service()) + Box::pin(async { + let tls_config = RustlsConfig::from_pem_file(ssl_cert_path, ssl_key_path) + .await + .expect("Could not read tls cert."); + + axum_server::from_tcp_rustls(tcp_listener, tls_config) + .handle(handle) + .serve(app.into_make_service_with_connect_info::()) + .await + .expect("Axum server crashed."); + }) + } } -pub fn start(socket_addr: SocketAddr, tracker: &Arc) -> impl Future> { +pub fn start(socket_addr: SocketAddr, tracker: Arc) -> impl Future> { let app = router(tracker); let server = axum::Server::bind(&socket_addr).serve(app.into_make_service()); @@ -165,7 +183,7 @@ pub fn start(socket_addr: SocketAddr, tracker: &Arc) -> impl Future, + tracker: Arc, ) -> impl Future> { let app = router(tracker); @@ -204,9 +222,9 @@ mod tests { let tracker = Arc::new(tracker::Tracker::new(cfg.clone(), None, statistics::Repo::new()).unwrap()); - let stopped_api_server = ApiServer::new(cfg.http_api.clone(), tracker); + let stopped_api_server = ApiServer::new(cfg.http_api.clone()); - let running_api_server_result = stopped_api_server.start(); + let running_api_server_result = stopped_api_server.start(tracker).await; assert!(running_api_server_result.is_ok()); diff --git a/src/jobs/tracker_apis.rs b/src/jobs/tracker_apis.rs index 85bb1b59..939b5863 100644 --- a/src/jobs/tracker_apis.rs +++ b/src/jobs/tracker_apis.rs @@ -31,7 +31,7 @@ pub async fn start_job(config: &HttpApi, tracker: Arc) -> Join if !ssl_enabled { info!("Starting Torrust APIs server on: http://{}", bind_addr); - let handle = server::start(bind_addr, &tracker); + let handle = server::start(bind_addr, tracker); tx.send(ApiServerJobStarted()).expect("the API server should not be dropped"); @@ -45,7 +45,7 @@ pub async fn start_job(config: &HttpApi, tracker: Arc) -> Join .await .unwrap(); - let handle = server::start_tls(bind_addr, ssl_config, &tracker); + let handle = server::start_tls(bind_addr, ssl_config, tracker); tx.send(ApiServerJobStarted()).expect("the API server should not be dropped"); diff --git a/src/jobs/udp_tracker.rs b/src/jobs/udp_tracker.rs index 468f6dbb..57232855 100644 --- a/src/jobs/udp_tracker.rs +++ b/src/jobs/udp_tracker.rs @@ -12,10 +12,10 @@ pub fn start_job(config: &UdpTracker, tracker: Arc) -> JoinHan let bind_addr = config.bind_address.clone(); tokio::spawn(async move { - match Udp::new(tracker, &bind_addr).await { + match Udp::new(&bind_addr).await { Ok(udp_server) => { info!("Starting UDP server on: udp://{}", bind_addr); - udp_server.start().await; + udp_server.start(tracker).await; } Err(e) => { warn!("Could not start UDP tracker on: udp://{}", bind_addr); diff --git a/src/udp/handlers.rs b/src/udp/handlers.rs index 211a0d1b..e47a89dd 100644 --- a/src/udp/handlers.rs +++ b/src/udp/handlers.rs @@ -11,12 +11,12 @@ use log::debug; use super::connection_cookie::{check, from_connection_id, into_connection_id, make}; use crate::protocol::common::MAX_SCRAPE_TORRENTS; use crate::protocol::info_hash::InfoHash; -use crate::tracker::{self, statistics}; +use crate::tracker::{statistics, Tracker}; use crate::udp::error::Error; use crate::udp::peer_builder; use crate::udp::request::AnnounceWrapper; -pub async fn handle_packet(remote_addr: SocketAddr, payload: Vec, tracker: Arc) -> Response { +pub async fn handle_packet(remote_addr: SocketAddr, payload: Vec, tracker: &Tracker) -> Response { match Request::from_bytes(&payload[..payload.len()], MAX_SCRAPE_TORRENTS).map_err(|e| Error::InternalServer { message: format!("{e:?}"), location: Location::caller(), @@ -46,11 +46,7 @@ pub async fn handle_packet(remote_addr: SocketAddr, payload: Vec, tracker: A /// # Errors /// /// If a error happens in the `handle_request` function, it will just return the `ServerError`. -pub async fn handle_request( - request: Request, - remote_addr: SocketAddr, - tracker: Arc, -) -> Result { +pub async fn handle_request(request: Request, remote_addr: SocketAddr, tracker: &Tracker) -> Result { match request { Request::Connect(connect_request) => handle_connect(remote_addr, &connect_request, tracker).await, Request::Announce(announce_request) => handle_announce(remote_addr, &announce_request, tracker).await, @@ -61,11 +57,7 @@ pub async fn handle_request( /// # Errors /// /// This function dose not ever return an error. -pub async fn handle_connect( - remote_addr: SocketAddr, - request: &ConnectRequest, - tracker: Arc, -) -> Result { +pub async fn handle_connect(remote_addr: SocketAddr, request: &ConnectRequest, tracker: &Tracker) -> Result { let connection_cookie = make(&remote_addr); let connection_id = into_connection_id(&connection_cookie); @@ -90,7 +82,7 @@ pub async fn handle_connect( /// # Errors /// /// Will return `Error` if unable to `authenticate_request`. -pub async fn authenticate(info_hash: &InfoHash, tracker: Arc) -> Result<(), Error> { +pub async fn authenticate(info_hash: &InfoHash, tracker: &Tracker) -> Result<(), Error> { tracker .authenticate_request(info_hash, &None) .await @@ -105,7 +97,7 @@ pub async fn authenticate(info_hash: &InfoHash, tracker: Arc) pub async fn handle_announce( remote_addr: SocketAddr, announce_request: &AnnounceRequest, - tracker: Arc, + tracker: &Tracker, ) -> Result { debug!("udp announce request: {:#?}", announce_request); @@ -116,7 +108,7 @@ pub async fn handle_announce( let info_hash = wrapped_announce_request.info_hash; let remote_client_ip = remote_addr.ip(); - authenticate(&info_hash, tracker.clone()).await?; + authenticate(&info_hash, tracker).await?; let mut peer = peer_builder::from_request(&wrapped_announce_request, &remote_client_ip); @@ -182,11 +174,7 @@ pub async fn handle_announce( /// # Errors /// /// This function dose not ever return an error. -pub async fn handle_scrape( - remote_addr: SocketAddr, - request: &ScrapeRequest, - tracker: Arc, -) -> Result { +pub async fn handle_scrape(remote_addr: SocketAddr, request: &ScrapeRequest, tracker: &Tracker) -> Result { // Convert from aquatic infohashes let mut info_hashes = vec![]; for info_hash in &request.info_hashes { @@ -392,7 +380,7 @@ mod tests { transaction_id: TransactionId(0i32), }; - let response = handle_connect(sample_ipv4_remote_addr(), &request, initialized_public_tracker()) + let response = handle_connect(sample_ipv4_remote_addr(), &request, &initialized_public_tracker()) .await .unwrap(); @@ -411,7 +399,7 @@ mod tests { transaction_id: TransactionId(0i32), }; - let response = handle_connect(sample_ipv4_remote_addr(), &request, initialized_public_tracker()) + let response = handle_connect(sample_ipv4_remote_addr(), &request, &initialized_public_tracker()) .await .unwrap(); @@ -439,7 +427,7 @@ mod tests { let torrent_tracker = Arc::new( tracker::Tracker::new(tracker_configuration(), Some(stats_event_sender), statistics::Repo::new()).unwrap(), ); - handle_connect(client_socket_address, &sample_connect_request(), torrent_tracker) + handle_connect(client_socket_address, &sample_connect_request(), &torrent_tracker) .await .unwrap(); } @@ -457,7 +445,7 @@ mod tests { let torrent_tracker = Arc::new( tracker::Tracker::new(tracker_configuration(), Some(stats_event_sender), statistics::Repo::new()).unwrap(), ); - handle_connect(sample_ipv6_remote_addr(), &sample_connect_request(), torrent_tracker) + handle_connect(sample_ipv6_remote_addr(), &sample_connect_request(), &torrent_tracker) .await .unwrap(); } @@ -573,7 +561,7 @@ mod tests { .with_port(client_port) .into(); - handle_announce(remote_addr, &request, tracker.clone()).await.unwrap(); + handle_announce(remote_addr, &request, &tracker).await.unwrap(); let peers = tracker.get_all_torrent_peers(&info_hash.0.into()).await; @@ -593,11 +581,11 @@ mod tests { .with_connection_id(into_connection_id(&make(&remote_addr))) .into(); - let response = handle_announce(remote_addr, &request, initialized_public_tracker()) + let response = handle_announce(remote_addr, &request, &initialized_public_tracker()) .await .unwrap(); - let empty_peer_vector: Vec> = vec![]; + let empty_peer_vector: Vec> = vec![]; assert_eq!( response, Response::from(AnnounceResponse { @@ -636,7 +624,7 @@ mod tests { .with_port(client_port) .into(); - handle_announce(remote_addr, &request, tracker.clone()).await.unwrap(); + handle_announce(remote_addr, &request, &tracker).await.unwrap(); let peers = tracker.get_all_torrent_peers(&info_hash.0.into()).await; @@ -667,7 +655,7 @@ mod tests { .with_connection_id(into_connection_id(&make(&remote_addr))) .into(); - handle_announce(remote_addr, &request, tracker.clone()).await.unwrap() + handle_announce(remote_addr, &request, &tracker).await.unwrap() } #[tokio::test] @@ -704,7 +692,7 @@ mod tests { handle_announce( sample_ipv4_socket_address(), &AnnounceRequestBuilder::default().into(), - tracker.clone(), + &tracker, ) .await .unwrap(); @@ -740,7 +728,7 @@ mod tests { .with_port(client_port) .into(); - handle_announce(remote_addr, &request, tracker.clone()).await.unwrap(); + handle_announce(remote_addr, &request, &tracker).await.unwrap(); let peers = tracker.get_all_torrent_peers(&info_hash.0.into()).await; @@ -797,7 +785,7 @@ mod tests { .with_port(client_port) .into(); - handle_announce(remote_addr, &request, tracker.clone()).await.unwrap(); + handle_announce(remote_addr, &request, &tracker).await.unwrap(); let peers = tracker.get_all_torrent_peers(&info_hash.0.into()).await; @@ -820,11 +808,11 @@ mod tests { .with_connection_id(into_connection_id(&make(&remote_addr))) .into(); - let response = handle_announce(remote_addr, &request, initialized_public_tracker()) + let response = handle_announce(remote_addr, &request, &initialized_public_tracker()) .await .unwrap(); - let empty_peer_vector: Vec> = vec![]; + let empty_peer_vector: Vec> = vec![]; assert_eq!( response, Response::from(AnnounceResponse { @@ -863,7 +851,7 @@ mod tests { .with_port(client_port) .into(); - handle_announce(remote_addr, &request, tracker.clone()).await.unwrap(); + handle_announce(remote_addr, &request, &tracker).await.unwrap(); let peers = tracker.get_all_torrent_peers(&info_hash.0.into()).await; @@ -897,7 +885,7 @@ mod tests { .with_connection_id(into_connection_id(&make(&remote_addr))) .into(); - handle_announce(remote_addr, &request, tracker.clone()).await.unwrap() + handle_announce(remote_addr, &request, &tracker).await.unwrap() } #[tokio::test] @@ -937,9 +925,7 @@ mod tests { .with_connection_id(into_connection_id(&make(&remote_addr))) .into(); - handle_announce(remote_addr, &announce_request, tracker.clone()) - .await - .unwrap(); + handle_announce(remote_addr, &announce_request, &tracker).await.unwrap(); } mod from_a_loopback_ip { @@ -982,7 +968,7 @@ mod tests { .with_port(client_port) .into(); - handle_announce(remote_addr, &request, tracker.clone()).await.unwrap(); + handle_announce(remote_addr, &request, &tracker).await.unwrap(); let peers = tracker.get_all_torrent_peers(&info_hash.0.into()).await; @@ -1036,7 +1022,7 @@ mod tests { info_hashes, }; - let response = handle_scrape(remote_addr, &request, initialized_public_tracker()) + let response = handle_scrape(remote_addr, &request, &initialized_public_tracker()) .await .unwrap(); @@ -1083,7 +1069,7 @@ mod tests { let request = build_scrape_request(&remote_addr, &info_hash); - handle_scrape(remote_addr, &request, tracker.clone()).await.unwrap() + handle_scrape(remote_addr, &request, &tracker).await.unwrap() } fn match_scrape_response(response: Response) -> Option { @@ -1134,8 +1120,7 @@ mod tests { let request = build_scrape_request(&remote_addr, &non_existing_info_hash); - let torrent_stats = - match_scrape_response(handle_scrape(remote_addr, &request, tracker.clone()).await.unwrap()).unwrap(); + let torrent_stats = match_scrape_response(handle_scrape(remote_addr, &request, &tracker).await.unwrap()).unwrap(); let expected_torrent_stats = vec![zeroed_torrent_statistics()]; @@ -1177,8 +1162,7 @@ mod tests { let request = build_scrape_request(&remote_addr, &info_hash); - let torrent_stats = - match_scrape_response(handle_scrape(remote_addr, &request, tracker.clone()).await.unwrap()).unwrap(); + let torrent_stats = match_scrape_response(handle_scrape(remote_addr, &request, &tracker).await.unwrap()).unwrap(); let expected_torrent_stats = vec![TorrentScrapeStatistics { seeders: NumberOfPeers(1), @@ -1200,8 +1184,7 @@ mod tests { let request = build_scrape_request(&remote_addr, &info_hash); - let torrent_stats = - match_scrape_response(handle_scrape(remote_addr, &request, tracker.clone()).await.unwrap()).unwrap(); + let torrent_stats = match_scrape_response(handle_scrape(remote_addr, &request, &tracker).await.unwrap()).unwrap(); let expected_torrent_stats = vec![zeroed_torrent_statistics()]; @@ -1246,7 +1229,7 @@ mod tests { tracker::Tracker::new(tracker_configuration(), Some(stats_event_sender), statistics::Repo::new()).unwrap(), ); - handle_scrape(remote_addr, &sample_scrape_request(&remote_addr), tracker.clone()) + handle_scrape(remote_addr, &sample_scrape_request(&remote_addr), &tracker) .await .unwrap(); } @@ -1278,7 +1261,7 @@ mod tests { tracker::Tracker::new(tracker_configuration(), Some(stats_event_sender), statistics::Repo::new()).unwrap(), ); - handle_scrape(remote_addr, &sample_scrape_request(&remote_addr), tracker.clone()) + handle_scrape(remote_addr, &sample_scrape_request(&remote_addr), &tracker) .await .unwrap(); } diff --git a/src/udp/server.rs b/src/udp/server.rs index f7446818..f3f90362 100644 --- a/src/udp/server.rs +++ b/src/udp/server.rs @@ -27,7 +27,6 @@ pub type RunningUdpServer = UdpServer; #[allow(clippy::module_name_repetitions)] pub struct UdpServer { pub cfg: torrust_tracker_configuration::UdpTracker, - pub tracker: Arc, pub state: S, } @@ -40,19 +39,15 @@ pub struct Running { } impl UdpServer { - pub fn new(cfg: torrust_tracker_configuration::UdpTracker, tracker: Arc) -> Self { - Self { - cfg, - tracker, - state: Stopped {}, - } + pub fn new(cfg: torrust_tracker_configuration::UdpTracker) -> Self { + Self { cfg, state: Stopped {} } } /// # Errors /// /// Will return `Err` if UDP can't bind to given bind address. - pub async fn start(self) -> Result, Error> { - let udp = Udp::new(self.tracker.clone(), &self.cfg.bind_address) + pub async fn start(self, tracker: Arc) -> Result, Error> { + let udp = Udp::new(&self.cfg.bind_address) .await .map_err(|e| Error::Error(e.to_string()))?; @@ -61,12 +56,11 @@ impl UdpServer { let (sender, receiver) = tokio::sync::oneshot::channel::(); let job = tokio::spawn(async move { - udp.start_with_graceful_shutdown(shutdown_signal(receiver)).await; + udp.start_with_graceful_shutdown(tracker, shutdown_signal(receiver)).await; }); let running_udp_server: UdpServer = UdpServer { cfg: self.cfg, - tracker: self.tracker, state: Running { bind_address, stop_job_sender: sender, @@ -90,7 +84,6 @@ impl UdpServer { let stopped_api_server: UdpServer = UdpServer { cfg: self.cfg, - tracker: self.tracker, state: Stopped {}, }; @@ -100,30 +93,27 @@ impl UdpServer { pub struct Udp { socket: Arc, - tracker: Arc, } impl Udp { /// # Errors /// /// Will return `Err` unable to bind to the supplied `bind_address`. - pub async fn new(tracker: Arc, bind_address: &str) -> tokio::io::Result { + pub async fn new(bind_address: &str) -> tokio::io::Result { let socket = UdpSocket::bind(bind_address).await?; Ok(Udp { socket: Arc::new(socket), - tracker, }) } /// # Panics /// /// It would panic if unable to resolve the `local_addr` from the supplied ´socket´. - pub async fn start(&self) { + pub async fn start(&self, tracker: Arc) { loop { let mut data = [0; MAX_PACKET_SIZE]; let socket = self.socket.clone(); - let tracker = self.tracker.clone(); tokio::select! { _ = tokio::signal::ctrl_c() => { @@ -137,7 +127,7 @@ impl Udp { debug!("From: {}", &remote_addr); debug!("Payload: {:?}", payload); - let response = handle_packet(remote_addr, payload, tracker).await; + let response = handle_packet(remote_addr, payload, &tracker).await; Udp::send_response(socket, remote_addr, response).await; } @@ -148,7 +138,7 @@ impl Udp { /// # Panics /// /// It would panic if unable to resolve the `local_addr` from the supplied ´socket´. - async fn start_with_graceful_shutdown(&self, shutdown_signal: F) + async fn start_with_graceful_shutdown(&self, tracker: Arc, shutdown_signal: F) where F: Future, { @@ -158,7 +148,6 @@ impl Udp { loop { let mut data = [0; MAX_PACKET_SIZE]; let socket = self.socket.clone(); - let tracker = self.tracker.clone(); tokio::select! { _ = &mut shutdown_signal => { @@ -172,7 +161,7 @@ impl Udp { debug!("From: {}", &remote_addr); debug!("Payload: {:?}", payload); - let response = handle_packet(remote_addr, payload, tracker).await; + let response = handle_packet(remote_addr, payload, &tracker).await; Udp::send_response(socket, remote_addr, response).await; } diff --git a/tests/api/test_environment.rs b/tests/api/test_environment.rs index 1f870865..b6f5ca99 100644 --- a/tests/api/test_environment.rs +++ b/tests/api/test_environment.rs @@ -4,7 +4,6 @@ use torrust_tracker::apis::server::{ApiServer, RunningApiServer, StoppedApiServe use torrust_tracker::protocol::info_hash::InfoHash; use torrust_tracker::tracker::peer::Peer; use torrust_tracker::tracker::Tracker; -use torrust_tracker_test_helpers::configuration; use super::connection_info::ConnectionInfo; use crate::common::tracker::new_tracker; @@ -15,6 +14,7 @@ pub type StoppedTestEnvironment = TestEnvironment; pub type RunningTestEnvironment = TestEnvironment; pub struct TestEnvironment { + pub cfg: Arc, pub tracker: Arc, pub state: S, } @@ -36,39 +36,45 @@ impl TestEnvironment { } impl TestEnvironment { - #[allow(dead_code)] - pub fn new_stopped() -> Self { - let api_server = api_server(); + pub fn new_stopped(cfg: torrust_tracker_configuration::Configuration) -> Self { + let cfg = Arc::new(cfg); + + let tracker = new_tracker(cfg.clone()); + + let api_server = api_server(cfg.http_api.clone()); Self { - tracker: api_server.tracker.clone(), + cfg, + tracker, state: Stopped { api_server }, } } - #[allow(dead_code)] - pub fn start(self) -> TestEnvironment { + pub async fn start(self) -> TestEnvironment { TestEnvironment { - tracker: self.tracker, + cfg: self.cfg, + tracker: self.tracker.clone(), state: Running { - api_server: self.state.api_server.start().unwrap(), + api_server: self.state.api_server.start(self.tracker).await.unwrap(), }, } } + + pub fn config_mut(&mut self) -> &mut torrust_tracker_configuration::HttpApi { + &mut self.state.api_server.cfg + } } impl TestEnvironment { - pub fn new_running() -> Self { - let api_server = running_api_server(); + pub async fn new_running(cfg: torrust_tracker_configuration::Configuration) -> Self { + let test_env = StoppedTestEnvironment::new_stopped(cfg); - Self { - tracker: api_server.tracker.clone(), - state: Running { api_server }, - } + test_env.start().await } pub async fn stop(self) -> TestEnvironment { TestEnvironment { + cfg: self.cfg, tracker: self.tracker, state: Stopped { api_server: self.state.api_server.stop().await.unwrap(), @@ -78,25 +84,22 @@ impl TestEnvironment { pub fn get_connection_info(&self) -> ConnectionInfo { ConnectionInfo { - bind_address: self.state.api_server.state.bind_address.to_string(), + bind_address: self.state.api_server.state.bind_addr.to_string(), api_token: self.state.api_server.cfg.access_tokens.get("admin").cloned(), } } } #[allow(clippy::module_name_repetitions)] -pub fn running_test_environment() -> RunningTestEnvironment { - TestEnvironment::new_running() +pub fn stopped_test_environment(cfg: torrust_tracker_configuration::Configuration) -> StoppedTestEnvironment { + TestEnvironment::new_stopped(cfg) } -pub fn api_server() -> StoppedApiServer { - let config = Arc::new(configuration::ephemeral()); - - let tracker = new_tracker(config.clone()); - - ApiServer::new(config.http_api.clone(), tracker) +#[allow(clippy::module_name_repetitions)] +pub async fn running_test_environment(cfg: torrust_tracker_configuration::Configuration) -> RunningTestEnvironment { + TestEnvironment::new_running(cfg).await } -pub fn running_api_server() -> RunningApiServer { - api_server().start().unwrap() +pub fn api_server(cfg: torrust_tracker_configuration::HttpApi) -> StoppedApiServer { + ApiServer::new(cfg) } diff --git a/tests/tracker_api.rs b/tests/tracker_api.rs index ccdcded5..d00c7d68 100644 --- a/tests/tracker_api.rs +++ b/tests/tracker_api.rs @@ -9,7 +9,6 @@ mod api; mod common; mod tracker_apis { - use crate::common::fixtures::invalid_info_hashes; // When these infohashes are used in URL path params @@ -24,7 +23,29 @@ mod tracker_apis { [String::new(), " ".to_string()].to_vec() } + mod configuration { + use torrust_tracker_test_helpers::configuration; + + use crate::api::test_environment::stopped_test_environment; + + #[tokio::test] + #[should_panic] + async fn should_fail_with_ssl_enabled_and_bad_ssl_config() { + let mut test_env = stopped_test_environment(configuration::ephemeral()); + + let cfg = test_env.config_mut(); + + cfg.ssl_enabled = true; + cfg.ssl_key_path = Some("bad key path".to_string()); + cfg.ssl_cert_path = Some("bad cert path".to_string()); + + test_env.start().await; + } + } + mod authentication { + use torrust_tracker_test_helpers::configuration; + use crate::api::asserts::{assert_token_not_valid, assert_unauthorized}; use crate::api::client::Client; use crate::api::test_environment::running_test_environment; @@ -32,7 +53,7 @@ mod tracker_apis { #[tokio::test] async fn should_authenticate_requests_by_using_a_token_query_param() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let token = test_env.get_connection_info().api_token.unwrap(); @@ -47,7 +68,7 @@ mod tracker_apis { #[tokio::test] async fn should_not_authenticate_requests_when_the_token_is_missing() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let response = Client::new(test_env.get_connection_info()) .get_request_with_query("stats", Query::default()) @@ -60,7 +81,7 @@ mod tracker_apis { #[tokio::test] async fn should_not_authenticate_requests_when_the_token_is_empty() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let response = Client::new(test_env.get_connection_info()) .get_request_with_query("stats", Query::params([QueryParam::new("token", "")].to_vec())) @@ -73,7 +94,7 @@ mod tracker_apis { #[tokio::test] async fn should_not_authenticate_requests_when_the_token_is_invalid() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let response = Client::new(test_env.get_connection_info()) .get_request_with_query("stats", Query::params([QueryParam::new("token", "INVALID TOKEN")].to_vec())) @@ -86,7 +107,7 @@ mod tracker_apis { #[tokio::test] async fn should_allow_the_token_query_param_to_be_at_any_position_in_the_url_query() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let token = test_env.get_connection_info().api_token.unwrap(); @@ -113,6 +134,7 @@ mod tracker_apis { use torrust_tracker::apis::resources::stats::Stats; use torrust_tracker::protocol::info_hash::InfoHash; + use torrust_tracker_test_helpers::configuration; use crate::api::asserts::{assert_stats, assert_token_not_valid, assert_unauthorized}; use crate::api::client::Client; @@ -122,7 +144,7 @@ mod tracker_apis { #[tokio::test] async fn should_allow_getting_tracker_statistics() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; test_env .add_torrent_peer( @@ -161,7 +183,7 @@ mod tracker_apis { #[tokio::test] async fn should_not_allow_getting_tracker_statistics_for_unauthenticated_users() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let response = Client::new(connection_with_invalid_token( test_env.get_connection_info().bind_address.as_str(), @@ -187,6 +209,7 @@ mod tracker_apis { use torrust_tracker::apis::resources::torrent::Torrent; use torrust_tracker::apis::resources::{self, torrent}; use torrust_tracker::protocol::info_hash::InfoHash; + use torrust_tracker_test_helpers::configuration; use super::{invalid_infohashes_returning_bad_request, invalid_infohashes_returning_not_found}; use crate::api::asserts::{ @@ -201,7 +224,7 @@ mod tracker_apis { #[tokio::test] async fn should_allow_getting_torrents() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let info_hash = InfoHash::from_str("9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d").unwrap(); @@ -226,7 +249,7 @@ mod tracker_apis { #[tokio::test] async fn should_allow_limiting_the_torrents_in_the_result() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; // torrents are ordered alphabetically by infohashes let info_hash_1 = InfoHash::from_str("9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d").unwrap(); @@ -256,7 +279,7 @@ mod tracker_apis { #[tokio::test] async fn should_allow_the_torrents_result_pagination() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; // torrents are ordered alphabetically by infohashes let info_hash_1 = InfoHash::from_str("9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d").unwrap(); @@ -286,7 +309,7 @@ mod tracker_apis { #[tokio::test] async fn should_fail_getting_torrents_when_the_offset_query_parameter_cannot_be_parsed() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let invalid_offsets = [" ", "-1", "1.1", "INVALID OFFSET"]; @@ -303,7 +326,7 @@ mod tracker_apis { #[tokio::test] async fn should_fail_getting_torrents_when_the_limit_query_parameter_cannot_be_parsed() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let invalid_limits = [" ", "-1", "1.1", "INVALID LIMIT"]; @@ -320,7 +343,7 @@ mod tracker_apis { #[tokio::test] async fn should_not_allow_getting_torrents_for_unauthenticated_users() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let response = Client::new(connection_with_invalid_token( test_env.get_connection_info().bind_address.as_str(), @@ -341,7 +364,7 @@ mod tracker_apis { #[tokio::test] async fn should_allow_getting_a_torrent_info() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let info_hash = InfoHash::from_str("9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d").unwrap(); @@ -370,7 +393,7 @@ mod tracker_apis { #[tokio::test] async fn should_fail_while_getting_a_torrent_info_when_the_torrent_does_not_exist() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let info_hash = InfoHash::from_str("9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d").unwrap(); @@ -385,7 +408,7 @@ mod tracker_apis { #[tokio::test] async fn should_fail_getting_a_torrent_info_when_the_provided_infohash_is_invalid() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; for invalid_infohash in &invalid_infohashes_returning_bad_request() { let response = Client::new(test_env.get_connection_info()) @@ -408,7 +431,7 @@ mod tracker_apis { #[tokio::test] async fn should_not_allow_getting_a_torrent_info_for_unauthenticated_users() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let info_hash = InfoHash::from_str("9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d").unwrap(); @@ -436,6 +459,7 @@ mod tracker_apis { use std::str::FromStr; use torrust_tracker::protocol::info_hash::InfoHash; + use torrust_tracker_test_helpers::configuration; use super::{invalid_infohashes_returning_bad_request, invalid_infohashes_returning_not_found}; use crate::api::asserts::{ @@ -450,7 +474,7 @@ mod tracker_apis { #[tokio::test] async fn should_allow_whitelisting_a_torrent() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let info_hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned(); @@ -471,7 +495,7 @@ mod tracker_apis { #[tokio::test] async fn should_allow_whitelisting_a_torrent_that_has_been_already_whitelisted() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let info_hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned(); @@ -488,7 +512,7 @@ mod tracker_apis { #[tokio::test] async fn should_not_allow_whitelisting_a_torrent_for_unauthenticated_users() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let info_hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned(); @@ -511,7 +535,7 @@ mod tracker_apis { #[tokio::test] async fn should_fail_when_the_torrent_cannot_be_whitelisted() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let info_hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned(); @@ -528,7 +552,7 @@ mod tracker_apis { #[tokio::test] async fn should_fail_whitelisting_a_torrent_when_the_provided_infohash_is_invalid() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; for invalid_infohash in &invalid_infohashes_returning_bad_request() { let response = Client::new(test_env.get_connection_info()) @@ -551,7 +575,7 @@ mod tracker_apis { #[tokio::test] async fn should_allow_removing_a_torrent_from_the_whitelist() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned(); let info_hash = InfoHash::from_str(&hash).unwrap(); @@ -569,7 +593,7 @@ mod tracker_apis { #[tokio::test] async fn should_not_fail_trying_to_remove_a_non_whitelisted_torrent_from_the_whitelist() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let non_whitelisted_torrent_hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned(); @@ -584,7 +608,7 @@ mod tracker_apis { #[tokio::test] async fn should_fail_removing_a_torrent_from_the_whitelist_when_the_provided_infohash_is_invalid() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; for invalid_infohash in &invalid_infohashes_returning_bad_request() { let response = Client::new(test_env.get_connection_info()) @@ -607,7 +631,7 @@ mod tracker_apis { #[tokio::test] async fn should_fail_when_the_torrent_cannot_be_removed_from_the_whitelist() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned(); let info_hash = InfoHash::from_str(&hash).unwrap(); @@ -626,7 +650,7 @@ mod tracker_apis { #[tokio::test] async fn should_not_allow_removing_a_torrent_from_the_whitelist_for_unauthenticated_users() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned(); let info_hash = InfoHash::from_str(&hash).unwrap(); @@ -652,7 +676,7 @@ mod tracker_apis { #[tokio::test] async fn should_allow_reload_the_whitelist_from_the_database() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned(); let info_hash = InfoHash::from_str(&hash).unwrap(); @@ -677,7 +701,7 @@ mod tracker_apis { #[tokio::test] async fn should_fail_when_the_whitelist_cannot_be_reloaded_from_the_database() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned(); let info_hash = InfoHash::from_str(&hash).unwrap(); @@ -697,6 +721,7 @@ mod tracker_apis { use std::time::Duration; use torrust_tracker::tracker::auth::Key; + use torrust_tracker_test_helpers::configuration; use crate::api::asserts::{ assert_auth_key_utf8, assert_failed_to_delete_key, assert_failed_to_generate_key, assert_failed_to_reload_keys, @@ -710,7 +735,7 @@ mod tracker_apis { #[tokio::test] async fn should_allow_generating_a_new_auth_key() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let seconds_valid = 60; @@ -732,7 +757,7 @@ mod tracker_apis { #[tokio::test] async fn should_not_allow_generating_a_new_auth_key_for_unauthenticated_users() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let seconds_valid = 60; @@ -755,7 +780,7 @@ mod tracker_apis { #[tokio::test] async fn should_fail_generating_a_new_auth_key_when_the_key_duration_is_invalid() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let invalid_key_durations = [ // "", it returns 404 @@ -763,9 +788,9 @@ mod tracker_apis { "-1", "text", ]; - for invalid_key_duration in &invalid_key_durations { + for invalid_key_duration in invalid_key_durations { let response = Client::new(test_env.get_connection_info()) - .post(&format!("key/{}", &invalid_key_duration)) + .post(&format!("key/{}", invalid_key_duration)) .await; assert_invalid_key_duration_param(response, invalid_key_duration).await; @@ -776,7 +801,7 @@ mod tracker_apis { #[tokio::test] async fn should_fail_when_the_auth_key_cannot_be_generated() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; force_database_error(&test_env.tracker); @@ -792,7 +817,7 @@ mod tracker_apis { #[tokio::test] async fn should_allow_deleting_an_auth_key() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let seconds_valid = 60; let auth_key = test_env @@ -812,7 +837,7 @@ mod tracker_apis { #[tokio::test] async fn should_fail_deleting_an_auth_key_when_the_key_id_is_invalid() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let invalid_auth_keys = [ // "", it returns a 404 @@ -837,7 +862,7 @@ mod tracker_apis { #[tokio::test] async fn should_fail_when_the_auth_key_cannot_be_deleted() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let seconds_valid = 60; let auth_key = test_env @@ -859,7 +884,7 @@ mod tracker_apis { #[tokio::test] async fn should_not_allow_deleting_an_auth_key_for_unauthenticated_users() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let seconds_valid = 60; @@ -896,7 +921,7 @@ mod tracker_apis { #[tokio::test] async fn should_allow_reloading_keys() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let seconds_valid = 60; test_env @@ -914,7 +939,7 @@ mod tracker_apis { #[tokio::test] async fn should_fail_when_keys_cannot_be_reloaded() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let seconds_valid = 60; test_env @@ -934,7 +959,7 @@ mod tracker_apis { #[tokio::test] async fn should_not_allow_reloading_keys_for_unauthenticated_users() { - let test_env = running_test_environment(); + let test_env = running_test_environment(configuration::ephemeral()).await; let seconds_valid = 60; test_env diff --git a/tests/udp/test_environment.rs b/tests/udp/test_environment.rs index 02d51c4b..f729777a 100644 --- a/tests/udp/test_environment.rs +++ b/tests/udp/test_environment.rs @@ -5,7 +5,6 @@ use torrust_tracker::protocol::info_hash::InfoHash; use torrust_tracker::tracker::peer::Peer; use torrust_tracker::tracker::Tracker; use torrust_tracker::udp::server::{RunningUdpServer, StoppedUdpServer, UdpServer}; -use torrust_tracker_test_helpers::configuration; use crate::common::tracker::new_tracker; @@ -15,6 +14,7 @@ pub type StoppedTestEnvironment = TestEnvironment; pub type RunningTestEnvironment = TestEnvironment; pub struct TestEnvironment { + pub cfg: Arc, pub tracker: Arc, pub state: S, } @@ -38,39 +38,41 @@ impl TestEnvironment { impl TestEnvironment { #[allow(dead_code)] - pub fn new_stopped() -> Self { - let udp_server = udp_server(); + pub fn new_stopped(cfg: torrust_tracker_configuration::Configuration) -> Self { + let cfg = Arc::new(cfg); + + let tracker = new_tracker(cfg.clone()); + + let udp_server = udp_server(cfg.udp_trackers[0].clone()); Self { - tracker: udp_server.tracker.clone(), - state: Stopped { udp_server: udp_server }, + cfg, + tracker, + state: Stopped { udp_server }, } } #[allow(dead_code)] pub async fn start(self) -> TestEnvironment { TestEnvironment { - tracker: self.tracker, + cfg: self.cfg, + tracker: self.tracker.clone(), state: Running { - udp_server: self.state.udp_server.start().await.unwrap(), + udp_server: self.state.udp_server.start(self.tracker).await.unwrap(), }, } } } impl TestEnvironment { - pub async fn new_running() -> Self { - let udp_server = running_udp_server().await; - - Self { - tracker: udp_server.tracker.clone(), - state: Running { udp_server: udp_server }, - } + pub async fn new_running(cfg: torrust_tracker_configuration::Configuration) -> Self { + StoppedTestEnvironment::new_stopped(cfg).start().await } #[allow(dead_code)] pub async fn stop(self) -> TestEnvironment { TestEnvironment { + cfg: self.cfg, tracker: self.tracker, state: Stopped { udp_server: self.state.udp_server.stop().await.unwrap(), @@ -83,19 +85,16 @@ impl TestEnvironment { } } -#[allow(clippy::module_name_repetitions)] -pub async fn running_test_environment() -> RunningTestEnvironment { - TestEnvironment::new_running().await +#[allow(clippy::module_name_repetitions, dead_code)] +pub fn stopped_test_environment(cfg: torrust_tracker_configuration::Configuration) -> StoppedTestEnvironment { + TestEnvironment::new_stopped(cfg) } -pub fn udp_server() -> StoppedUdpServer { - let config = Arc::new(configuration::ephemeral()); - - let tracker = new_tracker(config.clone()); - - UdpServer::new(config.udp_trackers[0].clone(), tracker) +#[allow(clippy::module_name_repetitions)] +pub async fn running_test_environment(cfg: torrust_tracker_configuration::Configuration) -> RunningTestEnvironment { + TestEnvironment::new_running(cfg).await } -pub async fn running_udp_server() -> RunningUdpServer { - udp_server().start().await.unwrap() +pub fn udp_server(cfg: torrust_tracker_configuration::UdpTracker) -> StoppedUdpServer { + UdpServer::new(cfg) } diff --git a/tests/udp_tracker.rs b/tests/udp_tracker.rs index b7cc3bd6..0f9283a8 100644 --- a/tests/udp_tracker.rs +++ b/tests/udp_tracker.rs @@ -17,6 +17,7 @@ mod udp_tracker_server { use aquatic_udp_protocol::{ConnectRequest, ConnectionId, Response, TransactionId}; use torrust_tracker::udp::MAX_PACKET_SIZE; + use torrust_tracker_test_helpers::configuration; use crate::udp::asserts::is_error_response; use crate::udp::client::{new_udp_client_connected, UdpTrackerClient}; @@ -45,7 +46,7 @@ mod udp_tracker_server { #[tokio::test] async fn should_return_a_bad_request_response_when_the_client_sends_an_empty_request() { - let test_env = running_test_environment().await; + let test_env = running_test_environment(configuration::ephemeral()).await; let client = new_udp_client_connected(&test_env.bind_address().to_string()).await; @@ -60,6 +61,7 @@ mod udp_tracker_server { mod receiving_a_connection_request { use aquatic_udp_protocol::{ConnectRequest, TransactionId}; + use torrust_tracker_test_helpers::configuration; use crate::udp::asserts::is_connect_response; use crate::udp::client::new_udp_tracker_client_connected; @@ -67,7 +69,7 @@ mod udp_tracker_server { #[tokio::test] async fn should_return_a_connect_response() { - let test_env = running_test_environment().await; + let test_env = running_test_environment(configuration::ephemeral()).await; let client = new_udp_tracker_client_connected(&test_env.bind_address().to_string()).await; @@ -90,6 +92,7 @@ mod udp_tracker_server { AnnounceEvent, AnnounceRequest, ConnectionId, InfoHash, NumberOfBytes, NumberOfPeers, PeerId, PeerKey, Port, TransactionId, }; + use torrust_tracker_test_helpers::configuration; use crate::udp::asserts::is_ipv4_announce_response; use crate::udp::client::new_udp_tracker_client_connected; @@ -98,7 +101,7 @@ mod udp_tracker_server { #[tokio::test] async fn should_return_an_announce_response() { - let test_env = running_test_environment().await; + let test_env = running_test_environment(configuration::ephemeral()).await; let client = new_udp_tracker_client_connected(&test_env.bind_address().to_string()).await; @@ -131,6 +134,7 @@ mod udp_tracker_server { mod receiving_an_scrape_request { use aquatic_udp_protocol::{ConnectionId, InfoHash, ScrapeRequest, TransactionId}; + use torrust_tracker_test_helpers::configuration; use crate::udp::asserts::is_scrape_response; use crate::udp::client::new_udp_tracker_client_connected; @@ -139,7 +143,7 @@ mod udp_tracker_server { #[tokio::test] async fn should_return_a_scrape_response() { - let test_env = running_test_environment().await; + let test_env = running_test_environment(configuration::ephemeral()).await; let client = new_udp_tracker_client_connected(&test_env.bind_address().to_string()).await;