Skip to content

Commit

Permalink
clock: use the clock in the current implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
da2ce7 committed Sep 11, 2022
1 parent 2a58a9e commit c123da0
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 37 deletions.
3 changes: 2 additions & 1 deletion src/api/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::cmp::min;
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use warp::{filters, reply, serve, Filter};
Expand Down Expand Up @@ -268,7 +269,7 @@ pub fn start(socket_addr: SocketAddr, tracker: Arc<TorrentTracker>) -> impl warp
(seconds_valid, tracker)
})
.and_then(|(seconds_valid, tracker): (u64, Arc<TorrentTracker>)| async move {
match tracker.generate_auth_key(seconds_valid).await {
match tracker.generate_auth_key(Duration::from_secs(seconds_valid)).await {
Ok(auth_key) => Ok(warp::reply::json(&auth_key)),
Err(..) => Err(warp::reject::custom(ActionStatus::Err {
reason: "failed to generate key".into(),
Expand Down
7 changes: 4 additions & 3 deletions src/databases/mysql.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::str::FromStr;
use std::time::Duration;

use async_trait::async_trait;
use log::debug;
Expand Down Expand Up @@ -94,7 +95,7 @@ impl Database for MysqlDatabase {
"SELECT `key`, valid_until FROM `keys`",
|(key, valid_until): (String, i64)| AuthKey {
key,
valid_until: Some(valid_until as u64),
valid_until: Some(Duration::from_secs(valid_until as u64)),
},
)
.map_err(|_| database::Error::QueryReturnedNoRows)?;
Expand Down Expand Up @@ -187,7 +188,7 @@ impl Database for MysqlDatabase {
{
Some((key, valid_until)) => Ok(AuthKey {
key,
valid_until: Some(valid_until as u64),
valid_until: Some(Duration::from_secs(valid_until as u64)),
}),
None => Err(database::Error::InvalidQuery),
}
Expand All @@ -197,7 +198,7 @@ impl Database for MysqlDatabase {
let mut conn = self.pool.get().map_err(|_| database::Error::DatabaseError)?;

let key = auth_key.key.to_string();
let valid_until = auth_key.valid_until.unwrap_or(0).to_string();
let valid_until = auth_key.valid_until.unwrap_or(Duration::ZERO).as_secs().to_string();

match conn.exec_drop(
"INSERT INTO `keys` (`key`, valid_until) VALUES (:key, :valid_until)",
Expand Down
7 changes: 4 additions & 3 deletions src/databases/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use r2d2_sqlite::SqliteConnectionManager;

use crate::databases::database;
use crate::databases::database::{Database, Error};
use crate::protocol::clock::SinceUnixEpoch;
use crate::tracker::key::AuthKey;
use crate::InfoHash;

Expand Down Expand Up @@ -85,7 +86,7 @@ impl Database for SqliteDatabase {

Ok(AuthKey {
key,
valid_until: Some(valid_until as u64),
valid_until: Some(SinceUnixEpoch::from_secs(valid_until as u64)),
})
})?;

Expand Down Expand Up @@ -192,7 +193,7 @@ impl Database for SqliteDatabase {

Ok(AuthKey {
key,
valid_until: Some(valid_until_i64 as u64),
valid_until: Some(SinceUnixEpoch::from_secs(valid_until_i64 as u64)),
})
} else {
Err(database::Error::QueryReturnedNoRows)
Expand All @@ -204,7 +205,7 @@ impl Database for SqliteDatabase {

match conn.execute(
"INSERT INTO keys (key, valid_until) VALUES (?1, ?2)",
[auth_key.key.to_string(), auth_key.valid_until.unwrap().to_string()],
[auth_key.key.to_string(), auth_key.valid_until.unwrap().as_secs().to_string()],
) {
Ok(updated) => {
if updated > 0 {
Expand Down
1 change: 1 addition & 0 deletions src/protocol/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod clock;
pub mod common;
pub mod utils;
50 changes: 40 additions & 10 deletions src/protocol/utils.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,49 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::net::SocketAddr;
use std::time::SystemTime;
use std::ops::Mul;
use std::time::Duration;

use aquatic_udp_protocol::ConnectionId;

pub fn get_connection_id(remote_address: &SocketAddr) -> ConnectionId {
match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
Ok(duration) => ConnectionId(((duration.as_secs() / 3600) | ((remote_address.port() as u64) << 36)) as i64),
Err(_) => ConnectionId(0x7FFFFFFFFFFFFFFF),
}
use super::clock::{DefaultClock, SinceUnixEpoch, Time};
use crate::udp::ServerError;

pub fn make_connection_cookie(lifetime: &Duration, remote_address: &SocketAddr) -> ConnectionId {
let period = DefaultClock::now_add_periods(&lifetime, &lifetime);

let mut hasher = DefaultHasher::new();

remote_address.hash(&mut hasher);
period.hash(&mut hasher);

let connection_id_cookie = i64::from_le_bytes(hasher.finish().to_le_bytes());

ConnectionId(connection_id_cookie)
}

pub fn current_time() -> u64 {
SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs()
pub fn check_connection_cookie(
lifetime: &Duration,
remote_address: &SocketAddr,
connection_id: &ConnectionId,
) -> Result<(), ServerError> {
for n in 0..=1 {
let period = DefaultClock::now_add_periods(&lifetime.mul(n), &lifetime);

let mut hasher = DefaultHasher::new();

remote_address.hash(&mut hasher);
period.hash(&mut hasher);

let connection_id_cookie = i64::from_le_bytes(hasher.finish().to_le_bytes());

if (*connection_id).0 == connection_id_cookie {
return Ok(());
}
}
Err(ServerError::InvalidConnectionId)
}

pub fn ser_instant<S: serde::Serializer>(inst: &std::time::Instant, ser: S) -> Result<S::Ok, S::Error> {
ser.serialize_u64(inst.elapsed().as_millis() as u64)
pub fn ser_unix_time_value<S: serde::Serializer>(unix_time_value: &SinceUnixEpoch, ser: S) -> Result<S::Ok, S::Error> {
ser.serialize_u64(unix_time_value.as_millis() as u64)
}
36 changes: 26 additions & 10 deletions src/tracker/key.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
use std::time::Duration;

use derive_more::{Display, Error};
use log::debug;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use serde::Serialize;

use crate::protocol::utils::current_time;
use crate::protocol::clock::{DefaultClock, SinceUnixEpoch, Time};
use crate::AUTH_KEY_LENGTH;

pub fn generate_auth_key(seconds_valid: u64) -> AuthKey {
pub fn generate_auth_key(lifetime: Duration) -> AuthKey {
let key: String = thread_rng()
.sample_iter(&Alphanumeric)
.take(AUTH_KEY_LENGTH)
.map(char::from)
.collect();

debug!("Generated key: {}, valid for: {} seconds", key, seconds_valid);
debug!("Generated key: {}, valid for: {:?} seconds", key, lifetime);

AuthKey {
key,
valid_until: Some(current_time() + seconds_valid),
valid_until: Some(DefaultClock::now_add(&lifetime).unwrap()),
}
}

pub fn verify_auth_key(auth_key: &AuthKey) -> Result<(), Error> {
let current_time = current_time();
let current_time: SinceUnixEpoch = DefaultClock::now();
if auth_key.valid_until.is_none() {
return Err(Error::KeyInvalid);
}
Expand All @@ -37,7 +39,7 @@ pub fn verify_auth_key(auth_key: &AuthKey) -> Result<(), Error> {
#[derive(Serialize, Debug, Eq, PartialEq, Clone)]
pub struct AuthKey {
pub key: String,
pub valid_until: Option<u64>,
pub valid_until: Option<SinceUnixEpoch>,
}

impl AuthKey {
Expand Down Expand Up @@ -81,6 +83,9 @@ impl From<r2d2_sqlite::rusqlite::Error> for Error {

#[cfg(test)]
mod tests {
use std::time::Duration;

use crate::protocol::clock::{DefaultClock, StoppedTime};
use crate::tracker::key;

#[test]
Expand All @@ -105,15 +110,26 @@ mod tests {

#[test]
fn generate_valid_auth_key() {
let auth_key = key::generate_auth_key(9999);
let auth_key = key::generate_auth_key(Duration::new(9999, 0));

assert!(key::verify_auth_key(&auth_key).is_ok());
}

#[test]
fn generate_expired_auth_key() {
let mut auth_key = key::generate_auth_key(0);
auth_key.valid_until = Some(0);
fn generate_and_check_expired_auth_key() {
// Set the time to the current time.
DefaultClock::local_set_to_system_time_now();

// Make key that is valid for 19 seconds.
let auth_key = key::generate_auth_key(Duration::from_secs(19));

// Mock the time has passed 10 sec.
DefaultClock::local_add(&Duration::from_secs(10)).unwrap();

assert!(key::verify_auth_key(&auth_key).is_ok());

// Mock the time has passed another 10 sec.
DefaultClock::local_add(&Duration::from_secs(10)).unwrap();

assert!(key::verify_auth_key(&auth_key).is_err());
}
Expand Down
11 changes: 6 additions & 5 deletions src/tracker/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@ use serde;
use serde::Serialize;

use crate::http::AnnounceRequest;
use crate::protocol::clock::{DefaultClock, SinceUnixEpoch, Time};
use crate::protocol::common::{AnnounceEventDef, NumberOfBytesDef};
use crate::protocol::utils::ser_instant;
use crate::protocol::utils::ser_unix_time_value;
use crate::PeerId;

#[derive(PartialEq, Eq, Debug, Clone, Serialize)]
pub struct TorrentPeer {
pub peer_id: PeerId,
pub peer_addr: SocketAddr,
#[serde(serialize_with = "ser_instant")]
pub updated: std::time::Instant,
#[serde(serialize_with = "ser_unix_time_value")]
pub updated: SinceUnixEpoch,
#[serde(with = "NumberOfBytesDef")]
pub uploaded: NumberOfBytes,
#[serde(with = "NumberOfBytesDef")]
Expand All @@ -36,7 +37,7 @@ impl TorrentPeer {
TorrentPeer {
peer_id: PeerId(announce_request.peer_id.0),
peer_addr,
updated: std::time::Instant::now(),
updated: DefaultClock::now(),
uploaded: announce_request.bytes_uploaded,
downloaded: announce_request.bytes_downloaded,
left: announce_request.bytes_left,
Expand Down Expand Up @@ -65,7 +66,7 @@ impl TorrentPeer {
TorrentPeer {
peer_id: announce_request.peer_id.clone(),
peer_addr,
updated: std::time::Instant::now(),
updated: DefaultClock::now(),
uploaded: NumberOfBytes(announce_request.uploaded as i64),
downloaded: NumberOfBytes(announce_request.downloaded as i64),
left: NumberOfBytes(announce_request.left as i64),
Expand Down
2 changes: 1 addition & 1 deletion src/tracker/torrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ impl TorrentEntry {

pub fn remove_inactive_peers(&mut self, max_peer_timeout: u32) {
self.peers
.retain(|_, peer| peer.updated.elapsed() > std::time::Duration::from_secs(max_peer_timeout as u64));
.retain(|_, peer| peer.updated > std::time::Duration::from_secs(max_peer_timeout as u64));
}
}

Expand Down
5 changes: 3 additions & 2 deletions src/tracker/tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use tokio::sync::mpsc::error::SendError;
use tokio::sync::{RwLock, RwLockReadGuard};
Expand Down Expand Up @@ -60,8 +61,8 @@ impl TorrentTracker {
self.mode == TrackerMode::Listed || self.mode == TrackerMode::PrivateListed
}

pub async fn generate_auth_key(&self, seconds_valid: u64) -> Result<AuthKey, database::Error> {
let auth_key = key::generate_auth_key(seconds_valid);
pub async fn generate_auth_key(&self, lifetime: Duration) -> Result<AuthKey, database::Error> {
let auth_key = key::generate_auth_key(lifetime);
self.database.add_key_to_keys(&auth_key).await?;
self.keys.write().await.insert(auth_key.key.clone(), auth_key.clone());
Ok(auth_key)
Expand Down
3 changes: 3 additions & 0 deletions src/udp/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ pub enum ServerError {
#[error("info_hash is either missing or invalid")]
InvalidInfoHash,

#[error("connection id could not be verified")]
InvalidConnectionId,

#[error("could not find remote address")]
AddressNotFound,

Expand Down
12 changes: 10 additions & 2 deletions src/udp/handlers.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;

use aquatic_udp_protocol::{
AnnounceInterval, AnnounceRequest, AnnounceResponse, ConnectRequest, ConnectResponse, ErrorResponse, NumberOfDownloads,
NumberOfPeers, Port, Request, Response, ResponsePeer, ScrapeRequest, ScrapeResponse, TorrentScrapeStatistics, TransactionId,
};

use crate::peer::TorrentPeer;
use crate::protocol::utils::get_connection_id;
use crate::protocol::utils::{check_connection_cookie, make_connection_cookie};
use crate::tracker::statistics::TrackerStatisticsEvent;
use crate::tracker::torrent::TorrentError;
use crate::tracker::tracker::TorrentTracker;
Expand Down Expand Up @@ -69,7 +70,7 @@ pub async fn handle_connect(
request: &ConnectRequest,
tracker: Arc<TorrentTracker>,
) -> Result<Response, ServerError> {
let connection_id = get_connection_id(&remote_addr);
let connection_id = make_connection_cookie(&Duration::from_secs(120), &remote_addr);

let response = Response::from(ConnectResponse {
transaction_id: request.transaction_id,
Expand All @@ -94,6 +95,13 @@ pub async fn handle_announce(
announce_request: &AnnounceRequest,
tracker: Arc<TorrentTracker>,
) -> Result<Response, ServerError> {
match check_connection_cookie(&Duration::from_secs(120), &remote_addr, &announce_request.connection_id) {
Ok(_) => {}
Err(e) => {
return Err(e);
}
}

let wrapped_announce_request = AnnounceRequestWrapper::new(announce_request.clone());

authenticate(&wrapped_announce_request.info_hash, tracker.clone()).await?;
Expand Down

0 comments on commit c123da0

Please sign in to comment.