Skip to content

Commit

Permalink
clock: implement the new clock, basic connection_cookie
Browse files Browse the repository at this point in the history
  • Loading branch information
da2ce7 committed Sep 10, 2022
1 parent d70b51a commit b3b6fc1
Show file tree
Hide file tree
Showing 11 changed files with 143 additions and 66 deletions.
3 changes: 2 additions & 1 deletion src/api/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::cmp::min;
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use warp::{filters, reply, serve, Filter};
Expand Down Expand Up @@ -268,7 +269,7 @@ pub fn start(socket_addr: SocketAddr, tracker: Arc<TorrentTracker>) -> impl warp
(seconds_valid, tracker)
})
.and_then(|(seconds_valid, tracker): (u64, Arc<TorrentTracker>)| async move {
match tracker.generate_auth_key(seconds_valid).await {
match tracker.generate_auth_key(Duration::from_secs(seconds_valid)).await {
Ok(auth_key) => Ok(warp::reply::json(&auth_key)),
Err(..) => Err(warp::reject::custom(ActionStatus::Err {
reason: "failed to generate key".into(),
Expand Down
7 changes: 4 additions & 3 deletions src/databases/mysql.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::str::FromStr;
use std::time::Duration;

use async_trait::async_trait;
use log::debug;
Expand Down Expand Up @@ -94,7 +95,7 @@ impl Database for MysqlDatabase {
"SELECT `key`, valid_until FROM `keys`",
|(key, valid_until): (String, i64)| AuthKey {
key,
valid_until: Some(valid_until as u64),
valid_until: Some(Duration::from_secs(valid_until as u64)),
},
)
.map_err(|_| database::Error::QueryReturnedNoRows)?;
Expand Down Expand Up @@ -187,7 +188,7 @@ impl Database for MysqlDatabase {
{
Some((key, valid_until)) => Ok(AuthKey {
key,
valid_until: Some(valid_until as u64),
valid_until: Some(Duration::from_secs(valid_until as u64)),
}),
None => Err(database::Error::InvalidQuery),
}
Expand All @@ -197,7 +198,7 @@ impl Database for MysqlDatabase {
let mut conn = self.pool.get().map_err(|_| database::Error::DatabaseError)?;

let key = auth_key.key.to_string();
let valid_until = auth_key.valid_until.unwrap_or(0).to_string();
let valid_until = auth_key.valid_until.unwrap_or(Duration::ZERO).as_secs().to_string();

match conn.exec_drop(
"INSERT INTO `keys` (`key`, valid_until) VALUES (:key, :valid_until)",
Expand Down
7 changes: 4 additions & 3 deletions src/databases/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use r2d2_sqlite::SqliteConnectionManager;

use crate::databases::database;
use crate::databases::database::{Database, Error};
use crate::protocol::clock::UnixTimeValue;
use crate::tracker::key::AuthKey;
use crate::InfoHash;

Expand Down Expand Up @@ -85,7 +86,7 @@ impl Database for SqliteDatabase {

Ok(AuthKey {
key,
valid_until: Some(valid_until as u64),
valid_until: Some(UnixTimeValue::from_secs(valid_until as u64)),
})
})?;

Expand Down Expand Up @@ -192,7 +193,7 @@ impl Database for SqliteDatabase {

Ok(AuthKey {
key,
valid_until: Some(valid_until_i64 as u64),
valid_until: Some(UnixTimeValue::from_secs(valid_until_i64 as u64)),
})
} else {
Err(database::Error::QueryReturnedNoRows)
Expand All @@ -204,7 +205,7 @@ impl Database for SqliteDatabase {

match conn.execute(
"INSERT INTO keys (key, valid_until) VALUES (?1, ?2)",
[auth_key.key.to_string(), auth_key.valid_until.unwrap().to_string()],
[auth_key.key.to_string(), auth_key.valid_until.unwrap().as_secs().to_string()],
) {
Ok(updated) => {
if updated > 0 {
Expand Down
74 changes: 45 additions & 29 deletions src/protocol/clock.rs
Original file line number Diff line number Diff line change
@@ -1,41 +1,47 @@
use std::{
cell::RefCell,
convert::TryInto,
ops::Div,
time::{Duration, SystemTime},
};

pub trait Time: Sized + Div + Into<u64> + TryInto<u32> {
use std::cell::RefCell;
use std::convert::TryInto;
use std::time::{Duration, SystemTime};

pub trait Time: Sized + Into<u128> + Into<u64> + TryInto<u32> {
fn now() -> Self;
fn after(period: &Duration) -> Self;
fn elapse(elapse: &Duration) -> Self;

fn elapse_sec(elapse: u64) -> Self {
Self::elapse(&Duration::new(elapse, 0))
}

fn after_sec(period: u64) -> Self {
Self::after(&Duration::new(period, 0))
fn periods_from_now(period: &Duration) -> u128 {
<Self as Into<u128>>::into(Self::now()) / period.as_nanos()
}

fn periods_after_elapse(elapse: &Duration, period: &Duration) -> u128 {
<Self as Into<u128>>::into(Self::elapse(elapse)) / period.as_nanos()
}
}
pub type UnixTimeValue = Duration;

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct UnixTime<const T: usize>(pub Duration);
pub struct UnixTime<const T: usize>(pub UnixTimeValue);

pub enum ClockType {
SystemTime,
FixedTime,
WorkingClock,
StoppedClock,
}

#[cfg(not(test))]
pub type DefaultTime = UnixTime<{ ClockType::SystemTime as usize }>;
pub type DefaultTime = UnixTime<{ ClockType::WorkingClock as usize }>;

#[cfg(test)]
pub type DefaultTime = UnixTime<{ ClockType::FixedTime as usize }>;
pub type DefaultTime = UnixTime<{ ClockType::StoppedClock as usize }>;

pub type CurrentTime = UnixTime<{ ClockType::SystemTime as usize }>;
pub type CurrentTime = UnixTime<{ ClockType::WorkingClock as usize }>;

impl Time for UnixTime<{ ClockType::SystemTime as usize }> {
impl Time for UnixTime<{ ClockType::WorkingClock as usize }> {
fn now() -> Self {
Self(SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap())
}

fn after(period: &Duration) -> Self {
fn elapse(period: &Duration) -> Self {
Self(
SystemTime::now()
.checked_add(*period)
Expand All @@ -48,7 +54,7 @@ impl Time for UnixTime<{ ClockType::SystemTime as usize }> {

thread_local!(static FIXED_TIME: RefCell<Duration> = RefCell::new(Duration::ZERO));

impl Time for UnixTime<{ ClockType::FixedTime as usize }> {
impl Time for UnixTime<{ ClockType::StoppedClock as usize }> {
fn now() -> Self {
let mut now: Duration = Duration::default();
FIXED_TIME.with(|time| {
Expand All @@ -58,7 +64,7 @@ impl Time for UnixTime<{ ClockType::FixedTime as usize }> {
Self(now)
}

fn after(period: &Duration) -> Self {
fn elapse(period: &Duration) -> Self {
let mut now: Duration = Duration::default();
FIXED_TIME.with(|time| {
now = *time.borrow();
Expand All @@ -68,21 +74,27 @@ impl Time for UnixTime<{ ClockType::FixedTime as usize }> {
}
}

impl UnixTime<{ ClockType::FixedTime as usize }> {
impl UnixTime<{ ClockType::StoppedClock as usize }> {
pub fn set_time(new_time: &Duration) {
FIXED_TIME.with(|time| {
*time.borrow_mut() = *new_time;
});
}

pub fn elapse_time(elapse: &Duration) {
FIXED_TIME.with(|time| {
*time.borrow_mut() += *elapse;
});
}

pub fn reset_time() {
Self::set_time(&Duration::ZERO)
}
}

impl<const T: usize> Div for UnixTime<{ T }> {
type Output = u128;
fn div(self, rhs: Self) -> Self::Output {
self.0.as_nanos() / rhs.0.as_nanos()
impl<const T: usize> Into<u128> for UnixTime<{ T }> {
fn into(self) -> u128 {
self.0.as_nanos()
}
}

Expand All @@ -106,7 +118,7 @@ mod tests {

use super::*;

type FixedTime = UnixTime<{ ClockType::FixedTime as usize }>;
type FixedTime = UnixTime<{ ClockType::StoppedClock as usize }>;

#[test]
fn fixed_time_and_default_time_should_be_the_same() {
Expand All @@ -125,7 +137,7 @@ mod tests {
}

#[test]
fn fixed_time_should_be_settable_and_resettable() {
fn fixed_time_should_be_settable_elapseable_and_resettable() {
// Check we start with ZERO.
assert_eq!(FixedTime::now().0, Duration::ZERO);

Expand All @@ -134,6 +146,10 @@ mod tests {
FixedTime::set_time(&timestamp.0);
assert_eq!(FixedTime::now().0, timestamp.0);

// Elapse the Current Time and Check
FixedTime::elapse_time(&timestamp.0);
assert_eq!(FixedTime::now().0, timestamp.0 + timestamp.0);

// Reset to ZERO and Check
FixedTime::reset_time();
assert_eq!(FixedTime::now().0, Duration::ZERO);
Expand All @@ -142,7 +158,7 @@ mod tests {
#[test]
fn fixed_time_should_default_to_zero_on_new_and_thread_exit() {
assert_eq!(FixedTime::now().0, Duration::ZERO);
let after5 = CurrentTime::after_sec(5);
let after5 = CurrentTime::elapse_sec(5);
FixedTime::set_time(&after5.0);
assert_eq!(FixedTime::now().0, after5.0);

Expand Down
49 changes: 39 additions & 10 deletions src/protocol/utils.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,48 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::net::SocketAddr;
use std::time::SystemTime;
use std::time::Duration;

use aquatic_udp_protocol::ConnectionId;

pub fn get_connection_id(remote_address: &SocketAddr) -> ConnectionId {
match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
Ok(duration) => ConnectionId(((duration.as_secs() / 3600) | ((remote_address.port() as u64) << 36)) as i64),
Err(_) => ConnectionId(0x7FFFFFFFFFFFFFFF),
}
use super::clock::{DefaultTime, Time, UnixTimeValue};
use crate::udp::ServerError;

pub fn make_connection_cookie(lifetime: &Duration, remote_address: &SocketAddr) -> ConnectionId {
let period = DefaultTime::periods_after_elapse(&lifetime, &lifetime);

let mut hasher = DefaultHasher::new();

remote_address.hash(&mut hasher);
hasher.write_u128(period);

let connection_id_cookie = i64::from_le_bytes(hasher.finish().to_le_bytes());

ConnectionId(connection_id_cookie)
}

pub fn current_time() -> u64 {
SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs()
pub fn check_connection_cookie(
lifetime: &Duration,
remote_address: &SocketAddr,
connection_id: &ConnectionId,
) -> Result<(), ServerError> {
let period = DefaultTime::periods_after_elapse(&lifetime, &lifetime);

for n in 0..=1 {
let mut hasher = DefaultHasher::new();

remote_address.hash(&mut hasher);
hasher.write_u128(period.saturating_sub(n));

let connection_id_cookie = i64::from_le_bytes(hasher.finish().to_le_bytes());

if (*connection_id).0 == connection_id_cookie {
return Ok(());
}
}
Err(ServerError::InvalidConnectionId)
}

pub fn ser_instant<S: serde::Serializer>(inst: &std::time::Instant, ser: S) -> Result<S::Ok, S::Error> {
ser.serialize_u64(inst.elapsed().as_millis() as u64)
pub fn ser_unix_time_value<S: serde::Serializer>(unix_time_value: &UnixTimeValue, ser: S) -> Result<S::Ok, S::Error> {
ser.serialize_u64(unix_time_value.as_millis() as u64)
}
36 changes: 26 additions & 10 deletions src/tracker/key.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
use std::time::Duration;

use derive_more::{Display, Error};
use log::debug;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use serde::Serialize;

use crate::protocol::utils::current_time;
use crate::protocol::clock::{DefaultTime, Time, UnixTimeValue};
use crate::AUTH_KEY_LENGTH;

pub fn generate_auth_key(seconds_valid: u64) -> AuthKey {
pub fn generate_auth_key(lifetime: Duration) -> AuthKey {
let key: String = thread_rng()
.sample_iter(&Alphanumeric)
.take(AUTH_KEY_LENGTH)
.map(char::from)
.collect();

debug!("Generated key: {}, valid for: {} seconds", key, seconds_valid);
debug!("Generated key: {}, valid for: {:?} seconds", key, lifetime);

AuthKey {
key,
valid_until: Some(current_time() + seconds_valid),
valid_until: Some(DefaultTime::elapse(&lifetime).0),
}
}

pub fn verify_auth_key(auth_key: &AuthKey) -> Result<(), Error> {
let current_time = current_time();
let current_time: UnixTimeValue = DefaultTime::now().0;
if auth_key.valid_until.is_none() {
return Err(Error::KeyInvalid);
}
Expand All @@ -37,7 +39,7 @@ pub fn verify_auth_key(auth_key: &AuthKey) -> Result<(), Error> {
#[derive(Serialize, Debug, Eq, PartialEq, Clone)]
pub struct AuthKey {
pub key: String,
pub valid_until: Option<u64>,
pub valid_until: Option<UnixTimeValue>,
}

impl AuthKey {
Expand Down Expand Up @@ -81,6 +83,9 @@ impl From<r2d2_sqlite::rusqlite::Error> for Error {

#[cfg(test)]
mod tests {
use std::time::Duration;

use crate::protocol::clock::{CurrentTime, DefaultTime, Time, UnixTimeValue};
use crate::tracker::key;

#[test]
Expand All @@ -105,15 +110,26 @@ mod tests {

#[test]
fn generate_valid_auth_key() {
let auth_key = key::generate_auth_key(9999);
let auth_key = key::generate_auth_key(Duration::new(9999, 0));

assert!(key::verify_auth_key(&auth_key).is_ok());
}

#[test]
fn generate_expired_auth_key() {
let mut auth_key = key::generate_auth_key(0);
auth_key.valid_until = Some(0);
fn generate_and_check_expired_auth_key() {
// Set the time to the current time.
DefaultTime::set_time(&CurrentTime::now().0);

// Make key that is valid for 19 seconds.
let auth_key = key::generate_auth_key(Duration::from_secs(19));

// Mock the time has passed 10 sec.
DefaultTime::elapse_time(&UnixTimeValue::from_secs(10));

assert!(key::verify_auth_key(&auth_key).is_ok());

// Mock the time has passed another 10 sec.
DefaultTime::elapse_time(&UnixTimeValue::from_secs(10));

assert!(key::verify_auth_key(&auth_key).is_err());
}
Expand Down
Loading

0 comments on commit b3b6fc1

Please sign in to comment.