diff --git a/src/api/server.rs b/src/api/server.rs index cc6c905e4..5285c9b2b 100644 --- a/src/api/server.rs +++ b/src/api/server.rs @@ -2,6 +2,7 @@ use std::cmp::min; use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::Arc; +use std::time::Duration; use serde::{Deserialize, Serialize}; use warp::{filters, reply, serve, Filter}; @@ -268,7 +269,7 @@ pub fn start(socket_addr: SocketAddr, tracker: Arc) -> impl warp (seconds_valid, tracker) }) .and_then(|(seconds_valid, tracker): (u64, Arc)| async move { - match tracker.generate_auth_key(seconds_valid).await { + match tracker.generate_auth_key(Duration::from_secs(seconds_valid)).await { Ok(auth_key) => Ok(warp::reply::json(&auth_key)), Err(..) => Err(warp::reject::custom(ActionStatus::Err { reason: "failed to generate key".into(), diff --git a/src/databases/mysql.rs b/src/databases/mysql.rs index 882fb7bf4..33287df6d 100644 --- a/src/databases/mysql.rs +++ b/src/databases/mysql.rs @@ -1,4 +1,5 @@ use std::str::FromStr; +use std::time::Duration; use async_trait::async_trait; use log::debug; @@ -94,7 +95,7 @@ impl Database for MysqlDatabase { "SELECT `key`, valid_until FROM `keys`", |(key, valid_until): (String, i64)| AuthKey { key, - valid_until: Some(valid_until as u64), + valid_until: Some(Duration::from_secs(valid_until as u64)), }, ) .map_err(|_| database::Error::QueryReturnedNoRows)?; @@ -187,7 +188,7 @@ impl Database for MysqlDatabase { { Some((key, valid_until)) => Ok(AuthKey { key, - valid_until: Some(valid_until as u64), + valid_until: Some(Duration::from_secs(valid_until as u64)), }), None => Err(database::Error::InvalidQuery), } @@ -197,7 +198,7 @@ impl Database for MysqlDatabase { let mut conn = self.pool.get().map_err(|_| database::Error::DatabaseError)?; let key = auth_key.key.to_string(); - let valid_until = auth_key.valid_until.unwrap_or(0).to_string(); + let valid_until = auth_key.valid_until.unwrap_or(Duration::ZERO).as_secs().to_string(); match conn.exec_drop( "INSERT INTO `keys` (`key`, valid_until) VALUES (:key, :valid_until)", diff --git a/src/databases/sqlite.rs b/src/databases/sqlite.rs index 3aba39919..25e3aae14 100644 --- a/src/databases/sqlite.rs +++ b/src/databases/sqlite.rs @@ -7,6 +7,7 @@ use r2d2_sqlite::SqliteConnectionManager; use crate::databases::database; use crate::databases::database::{Database, Error}; +use crate::protocol::clock::SinceUnixEpoch; use crate::tracker::key::AuthKey; use crate::InfoHash; @@ -85,7 +86,7 @@ impl Database for SqliteDatabase { Ok(AuthKey { key, - valid_until: Some(valid_until as u64), + valid_until: Some(SinceUnixEpoch::from_secs(valid_until as u64)), }) })?; @@ -192,7 +193,7 @@ impl Database for SqliteDatabase { Ok(AuthKey { key, - valid_until: Some(valid_until_i64 as u64), + valid_until: Some(SinceUnixEpoch::from_secs(valid_until_i64 as u64)), }) } else { Err(database::Error::QueryReturnedNoRows) @@ -204,7 +205,7 @@ impl Database for SqliteDatabase { match conn.execute( "INSERT INTO keys (key, valid_until) VALUES (?1, ?2)", - [auth_key.key.to_string(), auth_key.valid_until.unwrap().to_string()], + [auth_key.key.to_string(), auth_key.valid_until.unwrap().as_secs().to_string()], ) { Ok(updated) => { if updated > 0 { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 99cfd91e4..fcb28b3b2 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,2 +1,3 @@ +pub mod clock; pub mod common; pub mod utils; diff --git a/src/protocol/utils.rs b/src/protocol/utils.rs index e50c8b036..cd015dc5e 100644 --- a/src/protocol/utils.rs +++ b/src/protocol/utils.rs @@ -1,19 +1,49 @@ +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; use std::net::SocketAddr; -use std::time::SystemTime; +use std::ops::Mul; +use std::time::Duration; use aquatic_udp_protocol::ConnectionId; -pub fn get_connection_id(remote_address: &SocketAddr) -> ConnectionId { - match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) { - Ok(duration) => ConnectionId(((duration.as_secs() / 3600) | ((remote_address.port() as u64) << 36)) as i64), - Err(_) => ConnectionId(0x7FFFFFFFFFFFFFFF), - } +use super::clock::{DefaultClock, SinceUnixEpoch, Time}; +use crate::udp::ServerError; + +pub fn make_connection_cookie(lifetime: &Duration, remote_address: &SocketAddr) -> ConnectionId { + let period = DefaultClock::now_add_periods(&lifetime, &lifetime); + + let mut hasher = DefaultHasher::new(); + + remote_address.hash(&mut hasher); + period.hash(&mut hasher); + + let connection_id_cookie = i64::from_le_bytes(hasher.finish().to_le_bytes()); + + ConnectionId(connection_id_cookie) } -pub fn current_time() -> u64 { - SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() +pub fn check_connection_cookie( + lifetime: &Duration, + remote_address: &SocketAddr, + connection_id: &ConnectionId, +) -> Result<(), ServerError> { + for n in 0..=1 { + let period = DefaultClock::now_add_periods(&lifetime.mul(n), &lifetime); + + let mut hasher = DefaultHasher::new(); + + remote_address.hash(&mut hasher); + period.hash(&mut hasher); + + let connection_id_cookie = i64::from_le_bytes(hasher.finish().to_le_bytes()); + + if (*connection_id).0 == connection_id_cookie { + return Ok(()); + } + } + Err(ServerError::InvalidConnectionId) } -pub fn ser_instant(inst: &std::time::Instant, ser: S) -> Result { - ser.serialize_u64(inst.elapsed().as_millis() as u64) +pub fn ser_unix_time_value(unix_time_value: &SinceUnixEpoch, ser: S) -> Result { + ser.serialize_u64(unix_time_value.as_millis() as u64) } diff --git a/src/tracker/key.rs b/src/tracker/key.rs index f935dac07..6276f19d3 100644 --- a/src/tracker/key.rs +++ b/src/tracker/key.rs @@ -1,29 +1,31 @@ +use std::time::Duration; + use derive_more::{Display, Error}; use log::debug; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; use serde::Serialize; -use crate::protocol::utils::current_time; +use crate::protocol::clock::{DefaultClock, SinceUnixEpoch, Time}; use crate::AUTH_KEY_LENGTH; -pub fn generate_auth_key(seconds_valid: u64) -> AuthKey { +pub fn generate_auth_key(lifetime: Duration) -> AuthKey { let key: String = thread_rng() .sample_iter(&Alphanumeric) .take(AUTH_KEY_LENGTH) .map(char::from) .collect(); - debug!("Generated key: {}, valid for: {} seconds", key, seconds_valid); + debug!("Generated key: {}, valid for: {:?} seconds", key, lifetime); AuthKey { key, - valid_until: Some(current_time() + seconds_valid), + valid_until: Some(DefaultClock::now_add(&lifetime).unwrap()), } } pub fn verify_auth_key(auth_key: &AuthKey) -> Result<(), Error> { - let current_time = current_time(); + let current_time: SinceUnixEpoch = DefaultClock::now(); if auth_key.valid_until.is_none() { return Err(Error::KeyInvalid); } @@ -37,7 +39,7 @@ pub fn verify_auth_key(auth_key: &AuthKey) -> Result<(), Error> { #[derive(Serialize, Debug, Eq, PartialEq, Clone)] pub struct AuthKey { pub key: String, - pub valid_until: Option, + pub valid_until: Option, } impl AuthKey { @@ -81,6 +83,9 @@ impl From for Error { #[cfg(test)] mod tests { + use std::time::Duration; + + use crate::protocol::clock::{DefaultClock, StoppedTime}; use crate::tracker::key; #[test] @@ -105,15 +110,26 @@ mod tests { #[test] fn generate_valid_auth_key() { - let auth_key = key::generate_auth_key(9999); + let auth_key = key::generate_auth_key(Duration::new(9999, 0)); assert!(key::verify_auth_key(&auth_key).is_ok()); } #[test] - fn generate_expired_auth_key() { - let mut auth_key = key::generate_auth_key(0); - auth_key.valid_until = Some(0); + fn generate_and_check_expired_auth_key() { + // Set the time to the current time. + DefaultClock::local_set_to_system_time_now(); + + // Make key that is valid for 19 seconds. + let auth_key = key::generate_auth_key(Duration::from_secs(19)); + + // Mock the time has passed 10 sec. + DefaultClock::local_add(&Duration::from_secs(10)).unwrap(); + + assert!(key::verify_auth_key(&auth_key).is_ok()); + + // Mock the time has passed another 10 sec. + DefaultClock::local_add(&Duration::from_secs(10)).unwrap(); assert!(key::verify_auth_key(&auth_key).is_err()); } diff --git a/src/tracker/peer.rs b/src/tracker/peer.rs index 0514f41ed..77200cd1e 100644 --- a/src/tracker/peer.rs +++ b/src/tracker/peer.rs @@ -5,16 +5,17 @@ use serde; use serde::Serialize; use crate::http::AnnounceRequest; +use crate::protocol::clock::{DefaultClock, SinceUnixEpoch, Time}; use crate::protocol::common::{AnnounceEventDef, NumberOfBytesDef}; -use crate::protocol::utils::ser_instant; +use crate::protocol::utils::ser_unix_time_value; use crate::PeerId; #[derive(PartialEq, Eq, Debug, Clone, Serialize)] pub struct TorrentPeer { pub peer_id: PeerId, pub peer_addr: SocketAddr, - #[serde(serialize_with = "ser_instant")] - pub updated: std::time::Instant, + #[serde(serialize_with = "ser_unix_time_value")] + pub updated: SinceUnixEpoch, #[serde(with = "NumberOfBytesDef")] pub uploaded: NumberOfBytes, #[serde(with = "NumberOfBytesDef")] @@ -36,7 +37,7 @@ impl TorrentPeer { TorrentPeer { peer_id: PeerId(announce_request.peer_id.0), peer_addr, - updated: std::time::Instant::now(), + updated: DefaultClock::now(), uploaded: announce_request.bytes_uploaded, downloaded: announce_request.bytes_downloaded, left: announce_request.bytes_left, @@ -65,7 +66,7 @@ impl TorrentPeer { TorrentPeer { peer_id: announce_request.peer_id.clone(), peer_addr, - updated: std::time::Instant::now(), + updated: DefaultClock::now(), uploaded: NumberOfBytes(announce_request.uploaded as i64), downloaded: NumberOfBytes(announce_request.downloaded as i64), left: NumberOfBytes(announce_request.left as i64), diff --git a/src/tracker/torrent.rs b/src/tracker/torrent.rs index 7950ce9c0..c41c01ee4 100644 --- a/src/tracker/torrent.rs +++ b/src/tracker/torrent.rs @@ -76,7 +76,7 @@ impl TorrentEntry { pub fn remove_inactive_peers(&mut self, max_peer_timeout: u32) { self.peers - .retain(|_, peer| peer.updated.elapsed() > std::time::Duration::from_secs(max_peer_timeout as u64)); + .retain(|_, peer| peer.updated > std::time::Duration::from_secs(max_peer_timeout as u64)); } } diff --git a/src/tracker/tracker.rs b/src/tracker/tracker.rs index 51d7716fb..9a242e41a 100644 --- a/src/tracker/tracker.rs +++ b/src/tracker/tracker.rs @@ -2,6 +2,7 @@ use std::collections::btree_map::Entry; use std::collections::BTreeMap; use std::net::SocketAddr; use std::sync::Arc; +use std::time::Duration; use tokio::sync::mpsc::error::SendError; use tokio::sync::{RwLock, RwLockReadGuard}; @@ -60,8 +61,8 @@ impl TorrentTracker { self.mode == TrackerMode::Listed || self.mode == TrackerMode::PrivateListed } - pub async fn generate_auth_key(&self, seconds_valid: u64) -> Result { - let auth_key = key::generate_auth_key(seconds_valid); + pub async fn generate_auth_key(&self, lifetime: Duration) -> Result { + let auth_key = key::generate_auth_key(lifetime); self.database.add_key_to_keys(&auth_key).await?; self.keys.write().await.insert(auth_key.key.clone(), auth_key.clone()); Ok(auth_key) diff --git a/src/udp/errors.rs b/src/udp/errors.rs index fb29e969e..8d7b04b4f 100644 --- a/src/udp/errors.rs +++ b/src/udp/errors.rs @@ -8,6 +8,9 @@ pub enum ServerError { #[error("info_hash is either missing or invalid")] InvalidInfoHash, + #[error("connection id could not be verified")] + InvalidConnectionId, + #[error("could not find remote address")] AddressNotFound, diff --git a/src/udp/handlers.rs b/src/udp/handlers.rs index 907dac0bc..9b578d8b6 100644 --- a/src/udp/handlers.rs +++ b/src/udp/handlers.rs @@ -1,5 +1,6 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; +use std::time::Duration; use aquatic_udp_protocol::{ AnnounceInterval, AnnounceRequest, AnnounceResponse, ConnectRequest, ConnectResponse, ErrorResponse, NumberOfDownloads, @@ -7,7 +8,7 @@ use aquatic_udp_protocol::{ }; use crate::peer::TorrentPeer; -use crate::protocol::utils::get_connection_id; +use crate::protocol::utils::{check_connection_cookie, make_connection_cookie}; use crate::tracker::statistics::TrackerStatisticsEvent; use crate::tracker::torrent::TorrentError; use crate::tracker::tracker::TorrentTracker; @@ -69,7 +70,7 @@ pub async fn handle_connect( request: &ConnectRequest, tracker: Arc, ) -> Result { - let connection_id = get_connection_id(&remote_addr); + let connection_id = make_connection_cookie(&Duration::from_secs(120), &remote_addr); let response = Response::from(ConnectResponse { transaction_id: request.transaction_id, @@ -94,6 +95,13 @@ pub async fn handle_announce( announce_request: &AnnounceRequest, tracker: Arc, ) -> Result { + match check_connection_cookie(&Duration::from_secs(120), &remote_addr, &announce_request.connection_id) { + Ok(_) => {} + Err(e) => { + return Err(e); + } + } + let wrapped_announce_request = AnnounceRequestWrapper::new(announce_request.clone()); authenticate(&wrapped_announce_request.info_hash, tracker.clone()).await?;