Skip to content

Commit

Permalink
refactor: [#171] use KeyId in auth:Key
Browse files Browse the repository at this point in the history
The struct `KeyId` was extracted to wrap the primitive type but it was not
being used in the `auth::Key` struct.
  • Loading branch information
josecelano committed Feb 27, 2023
1 parent 12a42b7 commit 7cdd63e
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 69 deletions.
14 changes: 7 additions & 7 deletions src/apis/resources/auth_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@ use std::convert::From;
use serde::{Deserialize, Serialize};

use crate::protocol::clock::DurationSinceUnixEpoch;
use crate::tracker::auth;
use crate::tracker::auth::{self, KeyId};

#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct AuthKey {
pub key: String,
pub key: String, // todo: rename to `id`
pub valid_until: Option<u64>,
}

impl From<AuthKey> for auth::Key {
fn from(auth_key_resource: AuthKey) -> Self {
auth::Key {
key: auth_key_resource.key,
id: auth_key_resource.key.parse::<KeyId>().unwrap(),
valid_until: auth_key_resource
.valid_until
.map(|valid_until| DurationSinceUnixEpoch::new(valid_until, 0)),
Expand All @@ -25,7 +25,7 @@ impl From<AuthKey> for auth::Key {
impl From<auth::Key> for AuthKey {
fn from(auth_key: auth::Key) -> Self {
AuthKey {
key: auth_key.key,
key: auth_key.id.to_string(),
valid_until: auth_key.valid_until.map(|valid_until| valid_until.as_secs()),
}
}
Expand All @@ -37,7 +37,7 @@ mod tests {

use super::AuthKey;
use crate::protocol::clock::{Current, TimeNow};
use crate::tracker::auth;
use crate::tracker::auth::{self, KeyId};

#[test]
fn it_should_be_convertible_into_an_auth_key() {
Expand All @@ -51,7 +51,7 @@ mod tests {
assert_eq!(
auth::Key::from(auth_key_resource),
auth::Key {
key: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".to_string(), // cspell:disable-line
id: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".parse::<KeyId>().unwrap(), // cspell:disable-line
valid_until: Some(Current::add(&Duration::new(duration_in_secs, 0)).unwrap())
}
);
Expand All @@ -62,7 +62,7 @@ mod tests {
let duration_in_secs = 60;

let auth_key = auth::Key {
key: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".to_string(), // cspell:disable-line
id: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".parse::<KeyId>().unwrap(), // cspell:disable-line
valid_until: Some(Current::add(&Duration::new(duration_in_secs, 0)).unwrap()),
};

Expand Down
3 changes: 3 additions & 0 deletions src/databases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,19 @@ pub trait Database: Sync + Send {

async fn save_persistent_torrent(&self, info_hash: &InfoHash, completed: u32) -> Result<(), Error>;

// todo: replace type `&str` with `&InfoHash`
async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<Option<InfoHash>, Error>;

async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error>;

async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error>;

// todo: replace type `&str` with `&KeyId`
async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::Key>, Error>;

async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result<usize, Error>;

// todo: replace type `&str` with `&KeyId`
async fn remove_key_from_keys(&self, key: &str) -> Result<usize, Error>;

async fn is_info_hash_whitelisted(&self, info_hash: &InfoHash) -> Result<bool, Error> {
Expand Down
8 changes: 4 additions & 4 deletions src/databases/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::driver::Driver;
use crate::databases::{Database, Error};
use crate::protocol::common::AUTH_KEY_LENGTH;
use crate::protocol::info_hash::InfoHash;
use crate::tracker::auth;
use crate::tracker::auth::{self, KeyId};

const DRIVER: Driver = Driver::MySQL;

Expand Down Expand Up @@ -117,7 +117,7 @@ impl Database for Mysql {
let keys = conn.query_map(
"SELECT `key`, valid_until FROM `keys`",
|(key, valid_until): (String, i64)| auth::Key {
key,
id: key.parse::<KeyId>().unwrap(),
valid_until: Some(Duration::from_secs(valid_until.unsigned_abs())),
},
)?;
Expand Down Expand Up @@ -192,15 +192,15 @@ impl Database for Mysql {
let key = query?;

Ok(key.map(|(key, expiry)| auth::Key {
key,
id: key.parse::<KeyId>().unwrap(),
valid_until: Some(Duration::from_secs(expiry.unsigned_abs())),
}))
}

async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result<usize, Error> {
let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let key = auth_key.key.to_string();
let key = auth_key.id.to_string();
let valid_until = auth_key.valid_until.unwrap_or(Duration::ZERO).as_secs().to_string();

conn.exec_drop(
Expand Down
11 changes: 6 additions & 5 deletions src/databases/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::driver::Driver;
use crate::databases::{Database, Error};
use crate::protocol::clock::DurationSinceUnixEpoch;
use crate::protocol::info_hash::InfoHash;
use crate::tracker::auth;
use crate::tracker::auth::{self, KeyId};

const DRIVER: Driver = Driver::Sqlite3;

Expand Down Expand Up @@ -108,11 +108,11 @@ impl Database for Sqlite {
let mut stmt = conn.prepare("SELECT key, valid_until FROM keys")?;

let keys_iter = stmt.query_map([], |row| {
let key = row.get(0)?;
let key: String = row.get(0)?;
let valid_until: i64 = row.get(1)?;

Ok(auth::Key {
key,
id: key.parse::<KeyId>().unwrap(),
valid_until: Some(DurationSinceUnixEpoch::from_secs(valid_until.unsigned_abs())),
})
})?;
Expand Down Expand Up @@ -211,8 +211,9 @@ impl Database for Sqlite {

Ok(key.map(|f| {
let expiry: i64 = f.get(1).unwrap();
let id: String = f.get(0).unwrap();
auth::Key {
key: f.get(0).unwrap(),
id: id.parse::<KeyId>().unwrap(),
valid_until: Some(DurationSinceUnixEpoch::from_secs(expiry.unsigned_abs())),
}
}))
Expand All @@ -223,7 +224,7 @@ impl Database for Sqlite {

let insert = conn.execute(
"INSERT INTO keys (key, valid_until) VALUES (?1, ?2)",
[auth_key.key.to_string(), auth_key.valid_until.unwrap().as_secs().to_string()],
[auth_key.id.to_string(), auth_key.valid_until.unwrap().as_secs().to_string()],
)?;

if insert == 0 {
Expand Down
16 changes: 12 additions & 4 deletions src/http/warp_implementation/filters.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr};
use std::panic::Location;
use std::str::FromStr;
use std::sync::Arc;

use warp::{reject, Filter, Rejection};
Expand All @@ -11,7 +12,8 @@ use super::{request, WebResult};
use crate::http::percent_encoding::{percent_decode_info_hash, percent_decode_peer_id};
use crate::protocol::common::MAX_SCRAPE_TORRENTS;
use crate::protocol::info_hash::InfoHash;
use crate::tracker::{self, auth, peer};
use crate::tracker::auth::KeyId;
use crate::tracker::{self, peer};

/// Pass Arc<tracker::TorrentTracker> along
#[must_use]
Expand All @@ -35,10 +37,16 @@ pub fn with_peer_id() -> impl Filter<Extract = (peer::Id,), Error = Rejection> +

/// Pass Arc<tracker::TorrentTracker> along
#[must_use]
pub fn with_auth_key() -> impl Filter<Extract = (Option<auth::Key>,), Error = Infallible> + Clone {
pub fn with_auth_key_id() -> impl Filter<Extract = (Option<KeyId>,), Error = Infallible> + Clone {
warp::path::param::<String>()
.map(|key: String| auth::Key::from_string(&key))
.or_else(|_| async { Ok::<(Option<auth::Key>,), Infallible>((None,)) })
.map(|key: String| {
let key_id = KeyId::from_str(&key);
match key_id {
Ok(id) => Some(id),
Err(_) => None,
}
})
.or_else(|_| async { Ok::<(Option<KeyId>,), Infallible>((None,)) })
}

/// Check for `PeerAddress`
Expand Down
13 changes: 7 additions & 6 deletions src/http/warp_implementation/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use super::error::Error;
use super::{request, response, WebResult};
use crate::http::warp_implementation::peer_builder;
use crate::protocol::info_hash::InfoHash;
use crate::tracker::auth::KeyId;
use crate::tracker::{self, auth, peer, statistics, torrent};

/// Authenticate `InfoHash` using optional `auth::Key`
Expand All @@ -21,11 +22,11 @@ use crate::tracker::{self, auth, peer, statistics, torrent};
/// Will return `ServerError` that wraps the `tracker::error::Error` if unable to `authenticate_request`.
pub async fn authenticate(
info_hash: &InfoHash,
auth_key: &Option<auth::Key>,
auth_key_id: &Option<auth::KeyId>,
tracker: Arc<tracker::Tracker>,
) -> Result<(), Error> {
tracker
.authenticate_request(info_hash, auth_key)
.authenticate_request(info_hash, auth_key_id)
.await
.map_err(|e| Error::TrackerError {
source: (Arc::new(e) as Arc<dyn std::error::Error + Send + Sync>).into(),
Expand All @@ -37,15 +38,15 @@ pub async fn authenticate(
/// Will return `warp::Rejection` that wraps the `ServerError` if unable to `send_announce_response`.
pub async fn handle_announce(
announce_request: request::Announce,
auth_key: Option<auth::Key>,
auth_key_id: Option<KeyId>,
tracker: Arc<tracker::Tracker>,
) -> WebResult<impl Reply> {
debug!("http announce request: {:#?}", announce_request);

let info_hash = announce_request.info_hash;
let remote_client_ip = announce_request.peer_addr;

authenticate(&info_hash, &auth_key, tracker.clone()).await?;
authenticate(&info_hash, &auth_key_id, tracker.clone()).await?;

let mut peer = peer_builder::from_request(&announce_request, &remote_client_ip);

Expand Down Expand Up @@ -77,7 +78,7 @@ pub async fn handle_announce(
/// Will return `warp::Rejection` that wraps the `ServerError` if unable to `send_scrape_response`.
pub async fn handle_scrape(
scrape_request: request::Scrape,
auth_key: Option<auth::Key>,
auth_key_id: Option<KeyId>,
tracker: Arc<tracker::Tracker>,
) -> WebResult<impl Reply> {
let mut files: HashMap<InfoHash, response::ScrapeEntry> = HashMap::new();
Expand All @@ -86,7 +87,7 @@ pub async fn handle_scrape(
for info_hash in &scrape_request.info_hashes {
let scrape_entry = match db.get(info_hash) {
Some(torrent_info) => {
if authenticate(info_hash, &auth_key, tracker.clone()).await.is_ok() {
if authenticate(info_hash, &auth_key_id, tracker.clone()).await.is_ok() {
let (seeders, completed, leechers) = torrent_info.get_stats();
response::ScrapeEntry {
complete: seeders,
Expand Down
6 changes: 3 additions & 3 deletions src/http/warp_implementation/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;

use warp::{Filter, Rejection};

use super::filters::{with_announce_request, with_auth_key, with_scrape_request, with_tracker};
use super::filters::{with_announce_request, with_auth_key_id, with_scrape_request, with_tracker};
use super::handlers::{handle_announce, handle_scrape, send_error};
use crate::tracker;

Expand All @@ -20,7 +20,7 @@ fn announce(tracker: Arc<tracker::Tracker>) -> impl Filter<Extract = impl warp::
warp::path::path("announce")
.and(warp::filters::method::get())
.and(with_announce_request(tracker.config.on_reverse_proxy))
.and(with_auth_key())
.and(with_auth_key_id())
.and(with_tracker(tracker))
.and_then(handle_announce)
}
Expand All @@ -30,7 +30,7 @@ fn scrape(tracker: Arc<tracker::Tracker>) -> impl Filter<Extract = impl warp::Re
warp::path::path("scrape")
.and(warp::filters::method::get())
.and(with_scrape_request(tracker.config.on_reverse_proxy))
.and(with_auth_key())
.and(with_auth_key_id())
.and(with_tracker(tracker))
.and_then(handle_scrape)
}
Loading

0 comments on commit 7cdd63e

Please sign in to comment.