From 2ae7ab47ba19a22b7d98d31096cae605b26c3a8a Mon Sep 17 00:00:00 2001 From: Cameron Garnham Date: Tue, 13 Sep 2022 00:32:05 +0200 Subject: [PATCH 1/2] clock: add mockable clock --- Cargo.lock | 1 + Cargo.toml | 9 +- src/lib.rs | 11 ++ src/main.rs | 5 +- src/protocol/clock/clock.rs | 248 ++++++++++++++++++++++++++++++++++++ src/protocol/clock/mod.rs | 1 + src/protocol/mod.rs | 1 + 7 files changed, 274 insertions(+), 2 deletions(-) create mode 100644 src/protocol/clock/clock.rs create mode 100644 src/protocol/clock/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 279e4a67..1a4fe8b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2341,6 +2341,7 @@ dependencies = [ "fern", "futures", "hex", + "lazy_static", "log", "openssl", "percent-encoding", diff --git a/Cargo.toml b/Cargo.toml index 9d21ed7d..89fdffa9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,13 @@ lto = "fat" strip = true [dependencies] -tokio = { version = "1.7", features = ["rt-multi-thread", "net", "sync", "macros", "signal"] } +tokio = { version = "1.7", features = [ + "rt-multi-thread", + "net", + "sync", + "macros", + "signal", +] } serde = { version = "1.0", features = ["derive"] } serde_bencode = "^0.2.3" @@ -28,6 +34,7 @@ serde_with = "2.0.0" hex = "0.4.3" percent-encoding = "2.1.0" binascii = "0.1" +lazy_static = "1.4.0" openssl = { version = "0.10.41", features = ["vendored"] } diff --git a/src/lib.rs b/src/lib.rs index 6dcc7e6d..882e126b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,3 +16,14 @@ pub mod protocol; pub mod setup; pub mod tracker; pub mod udp; + +#[macro_use] +extern crate lazy_static; + +pub mod static_time { + use std::time::SystemTime; + + lazy_static! { + pub static ref TIME_AT_APP_START: SystemTime = SystemTime::now(); + } +} diff --git a/src/main.rs b/src/main.rs index 0b406c85..01121052 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,12 +2,15 @@ use std::sync::Arc; use log::info; use torrust_tracker::tracker::tracker::TorrentTracker; -use torrust_tracker::{logging, setup, Configuration}; +use torrust_tracker::{logging, setup, static_time, Configuration}; #[tokio::main] async fn main() { const CONFIG_PATH: &str = "config.toml"; + // Set the time of Torrust app starting + lazy_static::initialize(&static_time::TIME_AT_APP_START); + // Initialize Torrust config let config = match Configuration::load_from_file(CONFIG_PATH) { Ok(config) => Arc::new(config), diff --git a/src/protocol/clock/clock.rs b/src/protocol/clock/clock.rs new file mode 100644 index 00000000..db59170b --- /dev/null +++ b/src/protocol/clock/clock.rs @@ -0,0 +1,248 @@ +use std::num::IntErrorKind; +pub use std::time::Duration; + +pub type SinceUnixEpoch = Duration; + +#[derive(Debug)] +pub enum ClockType { + WorkingClock, + StoppedClock, +} + +#[derive(Debug)] +pub struct Clock; + +pub type WorkingClock = Clock<{ ClockType::WorkingClock as usize }>; +pub type StoppedClock = Clock<{ ClockType::StoppedClock as usize }>; + +#[cfg(not(test))] +pub type DefaultClock = WorkingClock; + +#[cfg(test)] +pub type DefaultClock = StoppedClock; + +pub trait Time: Sized { + fn now() -> SinceUnixEpoch; +} + +pub trait TimeNow: Time { + fn add(add_time: &Duration) -> Option { + Self::now().checked_add(*add_time) + } + fn sub(sub_time: &Duration) -> Option { + Self::now().checked_sub(*sub_time) + } +} + +#[cfg(test)] +mod tests { + use std::any::TypeId; + + use crate::protocol::clock::clock::{DefaultClock, StoppedClock, Time, WorkingClock}; + + #[test] + fn it_should_be_the_stopped_clock_as_default_when_testing() { + // We are testing, so we should default to the fixed time. + assert_eq!(TypeId::of::(), TypeId::of::()); + assert_eq!(StoppedClock::now(), DefaultClock::now()) + } + + #[test] + fn it_should_have_different_times() { + assert_ne!(TypeId::of::(), TypeId::of::()); + assert_ne!(StoppedClock::now(), WorkingClock::now()) + } +} + +mod working_clock { + use std::time::SystemTime; + + use super::{SinceUnixEpoch, Time, TimeNow, WorkingClock}; + + impl Time for WorkingClock { + fn now() -> SinceUnixEpoch { + SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap() + } + } + + impl TimeNow for WorkingClock {} +} + +pub trait StoppedTime: TimeNow { + fn local_set(unix_time: &SinceUnixEpoch); + fn local_set_to_unix_epoch() { + Self::local_set(&SinceUnixEpoch::ZERO) + } + fn local_set_to_app_start_time(); + fn local_set_to_system_time_now(); + fn local_add(duration: &Duration) -> Result<(), IntErrorKind>; + fn local_sub(duration: &Duration) -> Result<(), IntErrorKind>; + fn local_reset(); +} + +mod stopped_clock { + use std::num::IntErrorKind; + use std::time::Duration; + + use super::{SinceUnixEpoch, StoppedClock, StoppedTime, Time, TimeNow}; + + impl Time for StoppedClock { + fn now() -> SinceUnixEpoch { + detail::FIXED_TIME.with(|time| { + return *time.borrow(); + }) + } + } + + impl TimeNow for StoppedClock {} + + impl StoppedTime for StoppedClock { + fn local_set(unix_time: &SinceUnixEpoch) { + detail::FIXED_TIME.with(|time| { + *time.borrow_mut() = *unix_time; + }) + } + + fn local_set_to_app_start_time() { + Self::local_set(&detail::get_app_start_time()) + } + + fn local_set_to_system_time_now() { + Self::local_set(&detail::get_app_start_time()) + } + + fn local_add(duration: &Duration) -> Result<(), IntErrorKind> { + detail::FIXED_TIME.with(|time| { + let time_borrowed = *time.borrow(); + *time.borrow_mut() = match time_borrowed.checked_add(*duration) { + Some(time) => time, + None => { + return Err(IntErrorKind::PosOverflow); + } + }; + Ok(()) + }) + } + + fn local_sub(duration: &Duration) -> Result<(), IntErrorKind> { + detail::FIXED_TIME.with(|time| { + let time_borrowed = *time.borrow(); + *time.borrow_mut() = match time_borrowed.checked_sub(*duration) { + Some(time) => time, + None => { + return Err(IntErrorKind::NegOverflow); + } + }; + Ok(()) + }) + } + + fn local_reset() { + Self::local_set(&detail::get_default_fixed_time()) + } + } + + #[cfg(test)] + mod tests { + use std::thread; + use std::time::Duration; + + use crate::protocol::clock::clock::{SinceUnixEpoch, StoppedClock, StoppedTime, Time, TimeNow, WorkingClock}; + + #[test] + fn it_should_default_to_zero_when_testing() { + assert_eq!(StoppedClock::now(), SinceUnixEpoch::ZERO) + } + + #[test] + fn it_should_possible_to_set_the_time() { + // Check we start with ZERO. + assert_eq!(StoppedClock::now(), Duration::ZERO); + + // Set to Current Time and Check + let timestamp = WorkingClock::now(); + StoppedClock::local_set(×tamp); + assert_eq!(StoppedClock::now(), timestamp); + + // Elapse the Current Time and Check + StoppedClock::local_add(×tamp).unwrap(); + assert_eq!(StoppedClock::now(), timestamp + timestamp); + + // Reset to ZERO and Check + StoppedClock::local_reset(); + assert_eq!(StoppedClock::now(), Duration::ZERO); + } + + #[test] + fn it_should_default_to_zero_on_thread_exit() { + assert_eq!(StoppedClock::now(), Duration::ZERO); + let after5 = WorkingClock::add(&Duration::from_secs(5)).unwrap(); + StoppedClock::local_set(&after5); + assert_eq!(StoppedClock::now(), after5); + + let t = thread::spawn(move || { + // each thread starts out with the initial value of ZERO + assert_eq!(StoppedClock::now(), Duration::ZERO); + + // and gets set to the current time. + let timestamp = WorkingClock::now(); + StoppedClock::local_set(×tamp); + assert_eq!(StoppedClock::now(), timestamp); + }); + + // wait for the thread to complete and bail out on panic + t.join().unwrap(); + + // we retain our original value of current time + 5sec despite the child thread + assert_eq!(StoppedClock::now(), after5); + + // Reset to ZERO and Check + StoppedClock::local_reset(); + assert_eq!(StoppedClock::now(), Duration::ZERO); + } + } + + mod detail { + use std::cell::RefCell; + use std::time::SystemTime; + + use crate::protocol::clock::clock::SinceUnixEpoch; + use crate::static_time; + + pub fn get_app_start_time() -> SinceUnixEpoch { + (*static_time::TIME_AT_APP_START) + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + } + + #[cfg(not(test))] + pub fn get_default_fixed_time() -> SinceUnixEpoch { + get_app_start_time() + } + + #[cfg(test)] + pub fn get_default_fixed_time() -> SinceUnixEpoch { + SinceUnixEpoch::ZERO + } + + thread_local!(pub static FIXED_TIME: RefCell = RefCell::new(get_default_fixed_time())); + + #[cfg(test)] + mod tests { + use std::time::Duration; + + use crate::protocol::clock::clock::stopped_clock::detail::{get_app_start_time, get_default_fixed_time}; + + #[test] + fn it_should_get_the_zero_start_time_when_testing() { + assert_eq!(get_default_fixed_time(), Duration::ZERO); + } + + #[test] + fn it_should_get_app_start_time() { + const TIME_AT_WRITING_THIS_TEST: Duration = Duration::new(1662983731, 000022312); + assert!(get_app_start_time() > TIME_AT_WRITING_THIS_TEST); + } + } + } +} diff --git a/src/protocol/clock/mod.rs b/src/protocol/clock/mod.rs new file mode 100644 index 00000000..159730d2 --- /dev/null +++ b/src/protocol/clock/mod.rs @@ -0,0 +1 @@ +pub mod clock; diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 99cfd91e..fcb28b3b 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,2 +1,3 @@ +pub mod clock; pub mod common; pub mod utils; From cab093c026cf253ba6f7208211953a25b3d1c91b Mon Sep 17 00:00:00 2001 From: Cameron Garnham Date: Tue, 13 Sep 2022 00:44:12 +0200 Subject: [PATCH 2/2] clock: use mockable clock in project --- src/api/server.rs | 3 ++- src/databases/mysql.rs | 7 ++++--- src/databases/sqlite.rs | 7 ++++--- src/protocol/utils.rs | 14 ++++++-------- src/tracker/key.rs | 36 ++++++++++++++++++++++++++---------- src/tracker/peer.rs | 11 ++++++----- src/tracker/torrent.rs | 6 ++++-- src/tracker/tracker.rs | 5 +++-- src/udp/errors.rs | 3 +++ 9 files changed, 58 insertions(+), 34 deletions(-) diff --git a/src/api/server.rs b/src/api/server.rs index cc6c905e..5285c9b2 100644 --- a/src/api/server.rs +++ b/src/api/server.rs @@ -2,6 +2,7 @@ use std::cmp::min; use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::Arc; +use std::time::Duration; use serde::{Deserialize, Serialize}; use warp::{filters, reply, serve, Filter}; @@ -268,7 +269,7 @@ pub fn start(socket_addr: SocketAddr, tracker: Arc) -> impl warp (seconds_valid, tracker) }) .and_then(|(seconds_valid, tracker): (u64, Arc)| async move { - match tracker.generate_auth_key(seconds_valid).await { + match tracker.generate_auth_key(Duration::from_secs(seconds_valid)).await { Ok(auth_key) => Ok(warp::reply::json(&auth_key)), Err(..) => Err(warp::reject::custom(ActionStatus::Err { reason: "failed to generate key".into(), diff --git a/src/databases/mysql.rs b/src/databases/mysql.rs index 882fb7bf..33287df6 100644 --- a/src/databases/mysql.rs +++ b/src/databases/mysql.rs @@ -1,4 +1,5 @@ use std::str::FromStr; +use std::time::Duration; use async_trait::async_trait; use log::debug; @@ -94,7 +95,7 @@ impl Database for MysqlDatabase { "SELECT `key`, valid_until FROM `keys`", |(key, valid_until): (String, i64)| AuthKey { key, - valid_until: Some(valid_until as u64), + valid_until: Some(Duration::from_secs(valid_until as u64)), }, ) .map_err(|_| database::Error::QueryReturnedNoRows)?; @@ -187,7 +188,7 @@ impl Database for MysqlDatabase { { Some((key, valid_until)) => Ok(AuthKey { key, - valid_until: Some(valid_until as u64), + valid_until: Some(Duration::from_secs(valid_until as u64)), }), None => Err(database::Error::InvalidQuery), } @@ -197,7 +198,7 @@ impl Database for MysqlDatabase { let mut conn = self.pool.get().map_err(|_| database::Error::DatabaseError)?; let key = auth_key.key.to_string(); - let valid_until = auth_key.valid_until.unwrap_or(0).to_string(); + let valid_until = auth_key.valid_until.unwrap_or(Duration::ZERO).as_secs().to_string(); match conn.exec_drop( "INSERT INTO `keys` (`key`, valid_until) VALUES (:key, :valid_until)", diff --git a/src/databases/sqlite.rs b/src/databases/sqlite.rs index 3aba3991..ff080306 100644 --- a/src/databases/sqlite.rs +++ b/src/databases/sqlite.rs @@ -7,6 +7,7 @@ use r2d2_sqlite::SqliteConnectionManager; use crate::databases::database; use crate::databases::database::{Database, Error}; +use crate::protocol::clock::clock::SinceUnixEpoch; use crate::tracker::key::AuthKey; use crate::InfoHash; @@ -85,7 +86,7 @@ impl Database for SqliteDatabase { Ok(AuthKey { key, - valid_until: Some(valid_until as u64), + valid_until: Some(SinceUnixEpoch::from_secs(valid_until as u64)), }) })?; @@ -192,7 +193,7 @@ impl Database for SqliteDatabase { Ok(AuthKey { key, - valid_until: Some(valid_until_i64 as u64), + valid_until: Some(SinceUnixEpoch::from_secs(valid_until_i64 as u64)), }) } else { Err(database::Error::QueryReturnedNoRows) @@ -204,7 +205,7 @@ impl Database for SqliteDatabase { match conn.execute( "INSERT INTO keys (key, valid_until) VALUES (?1, ?2)", - [auth_key.key.to_string(), auth_key.valid_until.unwrap().to_string()], + [auth_key.key.to_string(), auth_key.valid_until.unwrap().as_secs().to_string()], ) { Ok(updated) => { if updated > 0 { diff --git a/src/protocol/utils.rs b/src/protocol/utils.rs index e50c8b03..f2a68fdb 100644 --- a/src/protocol/utils.rs +++ b/src/protocol/utils.rs @@ -1,19 +1,17 @@ use std::net::SocketAddr; -use std::time::SystemTime; use aquatic_udp_protocol::ConnectionId; +use super::clock::clock::{DefaultClock, SinceUnixEpoch, Time}; + pub fn get_connection_id(remote_address: &SocketAddr) -> ConnectionId { - match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) { - Ok(duration) => ConnectionId(((duration.as_secs() / 3600) | ((remote_address.port() as u64) << 36)) as i64), - Err(_) => ConnectionId(0x7FFFFFFFFFFFFFFF), - } + ConnectionId(((current_time() / 3600) | ((remote_address.port() as u64) << 36)) as i64) } pub fn current_time() -> u64 { - SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + DefaultClock::now().as_secs() } -pub fn ser_instant(inst: &std::time::Instant, ser: S) -> Result { - ser.serialize_u64(inst.elapsed().as_millis() as u64) +pub fn ser_unix_time_value(unix_time_value: &SinceUnixEpoch, ser: S) -> Result { + ser.serialize_u64(unix_time_value.as_millis() as u64) } diff --git a/src/tracker/key.rs b/src/tracker/key.rs index f935dac0..8ba19ab1 100644 --- a/src/tracker/key.rs +++ b/src/tracker/key.rs @@ -1,29 +1,31 @@ +use std::time::Duration; + use derive_more::{Display, Error}; use log::debug; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; use serde::Serialize; -use crate::protocol::utils::current_time; +use crate::protocol::clock::clock::{DefaultClock, SinceUnixEpoch, Time, TimeNow}; use crate::AUTH_KEY_LENGTH; -pub fn generate_auth_key(seconds_valid: u64) -> AuthKey { +pub fn generate_auth_key(lifetime: Duration) -> AuthKey { let key: String = thread_rng() .sample_iter(&Alphanumeric) .take(AUTH_KEY_LENGTH) .map(char::from) .collect(); - debug!("Generated key: {}, valid for: {} seconds", key, seconds_valid); + debug!("Generated key: {}, valid for: {:?} seconds", key, lifetime); AuthKey { key, - valid_until: Some(current_time() + seconds_valid), + valid_until: Some(DefaultClock::add(&lifetime).unwrap()), } } pub fn verify_auth_key(auth_key: &AuthKey) -> Result<(), Error> { - let current_time = current_time(); + let current_time: SinceUnixEpoch = DefaultClock::now(); if auth_key.valid_until.is_none() { return Err(Error::KeyInvalid); } @@ -37,7 +39,7 @@ pub fn verify_auth_key(auth_key: &AuthKey) -> Result<(), Error> { #[derive(Serialize, Debug, Eq, PartialEq, Clone)] pub struct AuthKey { pub key: String, - pub valid_until: Option, + pub valid_until: Option, } impl AuthKey { @@ -81,6 +83,9 @@ impl From for Error { #[cfg(test)] mod tests { + use std::time::Duration; + + use crate::protocol::clock::clock::{DefaultClock, StoppedTime}; use crate::tracker::key; #[test] @@ -105,15 +110,26 @@ mod tests { #[test] fn generate_valid_auth_key() { - let auth_key = key::generate_auth_key(9999); + let auth_key = key::generate_auth_key(Duration::new(9999, 0)); assert!(key::verify_auth_key(&auth_key).is_ok()); } #[test] - fn generate_expired_auth_key() { - let mut auth_key = key::generate_auth_key(0); - auth_key.valid_until = Some(0); + fn generate_and_check_expired_auth_key() { + // Set the time to the current time. + DefaultClock::local_set_to_system_time_now(); + + // Make key that is valid for 19 seconds. + let auth_key = key::generate_auth_key(Duration::from_secs(19)); + + // Mock the time has passed 10 sec. + DefaultClock::local_add(&Duration::from_secs(10)).unwrap(); + + assert!(key::verify_auth_key(&auth_key).is_ok()); + + // Mock the time has passed another 10 sec. + DefaultClock::local_add(&Duration::from_secs(10)).unwrap(); assert!(key::verify_auth_key(&auth_key).is_err()); } diff --git a/src/tracker/peer.rs b/src/tracker/peer.rs index 0514f41e..b37090b8 100644 --- a/src/tracker/peer.rs +++ b/src/tracker/peer.rs @@ -5,16 +5,17 @@ use serde; use serde::Serialize; use crate::http::AnnounceRequest; +use crate::protocol::clock::clock::{DefaultClock, SinceUnixEpoch, Time}; use crate::protocol::common::{AnnounceEventDef, NumberOfBytesDef}; -use crate::protocol::utils::ser_instant; +use crate::protocol::utils::ser_unix_time_value; use crate::PeerId; #[derive(PartialEq, Eq, Debug, Clone, Serialize)] pub struct TorrentPeer { pub peer_id: PeerId, pub peer_addr: SocketAddr, - #[serde(serialize_with = "ser_instant")] - pub updated: std::time::Instant, + #[serde(serialize_with = "ser_unix_time_value")] + pub updated: SinceUnixEpoch, #[serde(with = "NumberOfBytesDef")] pub uploaded: NumberOfBytes, #[serde(with = "NumberOfBytesDef")] @@ -36,7 +37,7 @@ impl TorrentPeer { TorrentPeer { peer_id: PeerId(announce_request.peer_id.0), peer_addr, - updated: std::time::Instant::now(), + updated: DefaultClock::now(), uploaded: announce_request.bytes_uploaded, downloaded: announce_request.bytes_downloaded, left: announce_request.bytes_left, @@ -65,7 +66,7 @@ impl TorrentPeer { TorrentPeer { peer_id: announce_request.peer_id.clone(), peer_addr, - updated: std::time::Instant::now(), + updated: DefaultClock::now(), uploaded: NumberOfBytes(announce_request.uploaded as i64), downloaded: NumberOfBytes(announce_request.downloaded as i64), left: NumberOfBytes(announce_request.left as i64), diff --git a/src/tracker/torrent.rs b/src/tracker/torrent.rs index 7950ce9c..b08f0326 100644 --- a/src/tracker/torrent.rs +++ b/src/tracker/torrent.rs @@ -1,9 +1,11 @@ use std::net::{IpAddr, SocketAddr}; +use std::time::Duration; use aquatic_udp_protocol::AnnounceEvent; use serde::{Deserialize, Serialize}; use crate::peer::TorrentPeer; +use crate::protocol::clock::clock::{DefaultClock, TimeNow}; use crate::{PeerId, MAX_SCRAPE_TORRENTS}; #[derive(Serialize, Deserialize, Clone)] @@ -75,8 +77,8 @@ impl TorrentEntry { } pub fn remove_inactive_peers(&mut self, max_peer_timeout: u32) { - self.peers - .retain(|_, peer| peer.updated.elapsed() > std::time::Duration::from_secs(max_peer_timeout as u64)); + let current_cutoff = DefaultClock::sub(&Duration::from_secs(max_peer_timeout as u64)).unwrap_or_default(); + self.peers.retain(|_, peer| peer.updated > current_cutoff); } } diff --git a/src/tracker/tracker.rs b/src/tracker/tracker.rs index 51d7716f..9a242e41 100644 --- a/src/tracker/tracker.rs +++ b/src/tracker/tracker.rs @@ -2,6 +2,7 @@ use std::collections::btree_map::Entry; use std::collections::BTreeMap; use std::net::SocketAddr; use std::sync::Arc; +use std::time::Duration; use tokio::sync::mpsc::error::SendError; use tokio::sync::{RwLock, RwLockReadGuard}; @@ -60,8 +61,8 @@ impl TorrentTracker { self.mode == TrackerMode::Listed || self.mode == TrackerMode::PrivateListed } - pub async fn generate_auth_key(&self, seconds_valid: u64) -> Result { - let auth_key = key::generate_auth_key(seconds_valid); + pub async fn generate_auth_key(&self, lifetime: Duration) -> Result { + let auth_key = key::generate_auth_key(lifetime); self.database.add_key_to_keys(&auth_key).await?; self.keys.write().await.insert(auth_key.key.clone(), auth_key.clone()); Ok(auth_key) diff --git a/src/udp/errors.rs b/src/udp/errors.rs index fb29e969..8d7b04b4 100644 --- a/src/udp/errors.rs +++ b/src/udp/errors.rs @@ -8,6 +8,9 @@ pub enum ServerError { #[error("info_hash is either missing or invalid")] InvalidInfoHash, + #[error("connection id could not be verified")] + InvalidConnectionId, + #[error("could not find remote address")] AddressNotFound,