From 7cdd63ee42b4868734038280c6a4f83e07c511ad Mon Sep 17 00:00:00 2001 From: Jose Celano Date: Mon, 27 Feb 2023 18:52:25 +0000 Subject: [PATCH] refactor: [#171] use KeyId in auth:Key The struct `KeyId` was extracted to wrap the primitive type but it was not being used in the `auth::Key` struct. --- src/apis/resources/auth_key.rs | 14 +++---- src/databases/mod.rs | 3 ++ src/databases/mysql.rs | 8 ++-- src/databases/sqlite.rs | 11 +++--- src/http/warp_implementation/filters.rs | 16 ++++++-- src/http/warp_implementation/handlers.rs | 13 ++++--- src/http/warp_implementation/routes.rs | 6 +-- src/tracker/auth.rs | 49 +++++++++++++----------- src/tracker/error.rs | 4 +- src/tracker/mod.rs | 24 +++++++----- tests/tracker_api.rs | 12 +++--- 11 files changed, 91 insertions(+), 69 deletions(-) diff --git a/src/apis/resources/auth_key.rs b/src/apis/resources/auth_key.rs index d5c08f49..207a0c48 100644 --- a/src/apis/resources/auth_key.rs +++ b/src/apis/resources/auth_key.rs @@ -3,18 +3,18 @@ use std::convert::From; use serde::{Deserialize, Serialize}; use crate::protocol::clock::DurationSinceUnixEpoch; -use crate::tracker::auth; +use crate::tracker::auth::{self, KeyId}; #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct AuthKey { - pub key: String, + pub key: String, // todo: rename to `id` pub valid_until: Option, } impl From for auth::Key { fn from(auth_key_resource: AuthKey) -> Self { auth::Key { - key: auth_key_resource.key, + id: auth_key_resource.key.parse::().unwrap(), valid_until: auth_key_resource .valid_until .map(|valid_until| DurationSinceUnixEpoch::new(valid_until, 0)), @@ -25,7 +25,7 @@ impl From for auth::Key { impl From for AuthKey { fn from(auth_key: auth::Key) -> Self { AuthKey { - key: auth_key.key, + key: auth_key.id.to_string(), valid_until: auth_key.valid_until.map(|valid_until| valid_until.as_secs()), } } @@ -37,7 +37,7 @@ mod tests { use super::AuthKey; use crate::protocol::clock::{Current, TimeNow}; - use crate::tracker::auth; + use crate::tracker::auth::{self, KeyId}; #[test] fn it_should_be_convertible_into_an_auth_key() { @@ -51,7 +51,7 @@ mod tests { assert_eq!( auth::Key::from(auth_key_resource), auth::Key { - key: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".to_string(), // cspell:disable-line + id: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".parse::().unwrap(), // cspell:disable-line valid_until: Some(Current::add(&Duration::new(duration_in_secs, 0)).unwrap()) } ); @@ -62,7 +62,7 @@ mod tests { let duration_in_secs = 60; let auth_key = auth::Key { - key: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".to_string(), // cspell:disable-line + id: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".parse::().unwrap(), // cspell:disable-line valid_until: Some(Current::add(&Duration::new(duration_in_secs, 0)).unwrap()), }; diff --git a/src/databases/mod.rs b/src/databases/mod.rs index 809decc2..70cc9eb7 100644 --- a/src/databases/mod.rs +++ b/src/databases/mod.rs @@ -63,16 +63,19 @@ pub trait Database: Sync + Send { async fn save_persistent_torrent(&self, info_hash: &InfoHash, completed: u32) -> Result<(), Error>; + // todo: replace type `&str` with `&InfoHash` async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result, Error>; async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result; async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result; + // todo: replace type `&str` with `&KeyId` async fn get_key_from_keys(&self, key: &str) -> Result, Error>; async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result; + // todo: replace type `&str` with `&KeyId` async fn remove_key_from_keys(&self, key: &str) -> Result; async fn is_info_hash_whitelisted(&self, info_hash: &InfoHash) -> Result { diff --git a/src/databases/mysql.rs b/src/databases/mysql.rs index ac54ebb8..532ba1dc 100644 --- a/src/databases/mysql.rs +++ b/src/databases/mysql.rs @@ -12,7 +12,7 @@ use super::driver::Driver; use crate::databases::{Database, Error}; use crate::protocol::common::AUTH_KEY_LENGTH; use crate::protocol::info_hash::InfoHash; -use crate::tracker::auth; +use crate::tracker::auth::{self, KeyId}; const DRIVER: Driver = Driver::MySQL; @@ -117,7 +117,7 @@ impl Database for Mysql { let keys = conn.query_map( "SELECT `key`, valid_until FROM `keys`", |(key, valid_until): (String, i64)| auth::Key { - key, + id: key.parse::().unwrap(), valid_until: Some(Duration::from_secs(valid_until.unsigned_abs())), }, )?; @@ -192,7 +192,7 @@ impl Database for Mysql { let key = query?; Ok(key.map(|(key, expiry)| auth::Key { - key, + id: key.parse::().unwrap(), valid_until: Some(Duration::from_secs(expiry.unsigned_abs())), })) } @@ -200,7 +200,7 @@ impl Database for Mysql { async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result { let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; - let key = auth_key.key.to_string(); + let key = auth_key.id.to_string(); let valid_until = auth_key.valid_until.unwrap_or(Duration::ZERO).as_secs().to_string(); conn.exec_drop( diff --git a/src/databases/sqlite.rs b/src/databases/sqlite.rs index 3425b15c..d6915c85 100644 --- a/src/databases/sqlite.rs +++ b/src/databases/sqlite.rs @@ -9,7 +9,7 @@ use super::driver::Driver; use crate::databases::{Database, Error}; use crate::protocol::clock::DurationSinceUnixEpoch; use crate::protocol::info_hash::InfoHash; -use crate::tracker::auth; +use crate::tracker::auth::{self, KeyId}; const DRIVER: Driver = Driver::Sqlite3; @@ -108,11 +108,11 @@ impl Database for Sqlite { let mut stmt = conn.prepare("SELECT key, valid_until FROM keys")?; let keys_iter = stmt.query_map([], |row| { - let key = row.get(0)?; + let key: String = row.get(0)?; let valid_until: i64 = row.get(1)?; Ok(auth::Key { - key, + id: key.parse::().unwrap(), valid_until: Some(DurationSinceUnixEpoch::from_secs(valid_until.unsigned_abs())), }) })?; @@ -211,8 +211,9 @@ impl Database for Sqlite { Ok(key.map(|f| { let expiry: i64 = f.get(1).unwrap(); + let id: String = f.get(0).unwrap(); auth::Key { - key: f.get(0).unwrap(), + id: id.parse::().unwrap(), valid_until: Some(DurationSinceUnixEpoch::from_secs(expiry.unsigned_abs())), } })) @@ -223,7 +224,7 @@ impl Database for Sqlite { let insert = conn.execute( "INSERT INTO keys (key, valid_until) VALUES (?1, ?2)", - [auth_key.key.to_string(), auth_key.valid_until.unwrap().as_secs().to_string()], + [auth_key.id.to_string(), auth_key.valid_until.unwrap().as_secs().to_string()], )?; if insert == 0 { diff --git a/src/http/warp_implementation/filters.rs b/src/http/warp_implementation/filters.rs index fc8ef20b..eb7abcd4 100644 --- a/src/http/warp_implementation/filters.rs +++ b/src/http/warp_implementation/filters.rs @@ -1,6 +1,7 @@ use std::convert::Infallible; use std::net::{IpAddr, SocketAddr}; use std::panic::Location; +use std::str::FromStr; use std::sync::Arc; use warp::{reject, Filter, Rejection}; @@ -11,7 +12,8 @@ use super::{request, WebResult}; use crate::http::percent_encoding::{percent_decode_info_hash, percent_decode_peer_id}; use crate::protocol::common::MAX_SCRAPE_TORRENTS; use crate::protocol::info_hash::InfoHash; -use crate::tracker::{self, auth, peer}; +use crate::tracker::auth::KeyId; +use crate::tracker::{self, peer}; /// Pass Arc along #[must_use] @@ -35,10 +37,16 @@ pub fn with_peer_id() -> impl Filter + /// Pass Arc along #[must_use] -pub fn with_auth_key() -> impl Filter,), Error = Infallible> + Clone { +pub fn with_auth_key_id() -> impl Filter,), Error = Infallible> + Clone { warp::path::param::() - .map(|key: String| auth::Key::from_string(&key)) - .or_else(|_| async { Ok::<(Option,), Infallible>((None,)) }) + .map(|key: String| { + let key_id = KeyId::from_str(&key); + match key_id { + Ok(id) => Some(id), + Err(_) => None, + } + }) + .or_else(|_| async { Ok::<(Option,), Infallible>((None,)) }) } /// Check for `PeerAddress` diff --git a/src/http/warp_implementation/handlers.rs b/src/http/warp_implementation/handlers.rs index 400cc576..6019bf01 100644 --- a/src/http/warp_implementation/handlers.rs +++ b/src/http/warp_implementation/handlers.rs @@ -12,6 +12,7 @@ use super::error::Error; use super::{request, response, WebResult}; use crate::http::warp_implementation::peer_builder; use crate::protocol::info_hash::InfoHash; +use crate::tracker::auth::KeyId; use crate::tracker::{self, auth, peer, statistics, torrent}; /// Authenticate `InfoHash` using optional `auth::Key` @@ -21,11 +22,11 @@ use crate::tracker::{self, auth, peer, statistics, torrent}; /// Will return `ServerError` that wraps the `tracker::error::Error` if unable to `authenticate_request`. pub async fn authenticate( info_hash: &InfoHash, - auth_key: &Option, + auth_key_id: &Option, tracker: Arc, ) -> Result<(), Error> { tracker - .authenticate_request(info_hash, auth_key) + .authenticate_request(info_hash, auth_key_id) .await .map_err(|e| Error::TrackerError { source: (Arc::new(e) as Arc).into(), @@ -37,7 +38,7 @@ pub async fn authenticate( /// Will return `warp::Rejection` that wraps the `ServerError` if unable to `send_announce_response`. pub async fn handle_announce( announce_request: request::Announce, - auth_key: Option, + auth_key_id: Option, tracker: Arc, ) -> WebResult { debug!("http announce request: {:#?}", announce_request); @@ -45,7 +46,7 @@ pub async fn handle_announce( let info_hash = announce_request.info_hash; let remote_client_ip = announce_request.peer_addr; - authenticate(&info_hash, &auth_key, tracker.clone()).await?; + authenticate(&info_hash, &auth_key_id, tracker.clone()).await?; let mut peer = peer_builder::from_request(&announce_request, &remote_client_ip); @@ -77,7 +78,7 @@ pub async fn handle_announce( /// Will return `warp::Rejection` that wraps the `ServerError` if unable to `send_scrape_response`. pub async fn handle_scrape( scrape_request: request::Scrape, - auth_key: Option, + auth_key_id: Option, tracker: Arc, ) -> WebResult { let mut files: HashMap = HashMap::new(); @@ -86,7 +87,7 @@ pub async fn handle_scrape( for info_hash in &scrape_request.info_hashes { let scrape_entry = match db.get(info_hash) { Some(torrent_info) => { - if authenticate(info_hash, &auth_key, tracker.clone()).await.is_ok() { + if authenticate(info_hash, &auth_key_id, tracker.clone()).await.is_ok() { let (seeders, completed, leechers) = torrent_info.get_stats(); response::ScrapeEntry { complete: seeders, diff --git a/src/http/warp_implementation/routes.rs b/src/http/warp_implementation/routes.rs index c46c502e..2ee60e8c 100644 --- a/src/http/warp_implementation/routes.rs +++ b/src/http/warp_implementation/routes.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use warp::{Filter, Rejection}; -use super::filters::{with_announce_request, with_auth_key, with_scrape_request, with_tracker}; +use super::filters::{with_announce_request, with_auth_key_id, with_scrape_request, with_tracker}; use super::handlers::{handle_announce, handle_scrape, send_error}; use crate::tracker; @@ -20,7 +20,7 @@ fn announce(tracker: Arc) -> impl Filter) -> impl Filter Key { - let key: String = thread_rng() + let random_id: String = thread_rng() .sample_iter(&Alphanumeric) .take(AUTH_KEY_LENGTH) .map(char::from) .collect(); - debug!("Generated key: {}, valid for: {:?} seconds", key, lifetime); + debug!("Generated key: {}, valid for: {:?} seconds", random_id, lifetime); Key { - key, + id: random_id.parse::().unwrap(), valid_until: Some(Current::add(&lifetime).unwrap()), } } @@ -54,16 +54,14 @@ pub fn verify(auth_key: &Key) -> Result<(), Error> { } None => Err(Error::UnableToReadKey { location: Location::caller(), - key: Box::new(auth_key.clone()), + key_id: Box::new(auth_key.id.clone()), }), } } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] pub struct Key { - // todo: replace key field definition with: - // pub key: KeyId, - pub key: String, + pub id: KeyId, pub valid_until: Option, } @@ -72,7 +70,7 @@ impl std::fmt::Display for Key { write!( f, "key: `{}`, valid until `{}`", - self.key, + self.id, match self.valid_until { Some(duration) => format!( "{}", @@ -91,20 +89,29 @@ impl std::fmt::Display for Key { } impl Key { + /// # Panics + /// + /// Will panic if bytes cannot be converted into a valid `KeyId`. #[must_use] pub fn from_buffer(key_buffer: [u8; AUTH_KEY_LENGTH]) -> Option { if let Ok(key) = String::from_utf8(Vec::from(key_buffer)) { - Some(Key { key, valid_until: None }) + Some(Key { + id: key.parse::().unwrap(), + valid_until: None, + }) } else { None } } + /// # Panics + /// + /// Will panic if string cannot be converted into a valid `KeyId`. #[must_use] pub fn from_string(key: &str) -> Option { if key.len() == AUTH_KEY_LENGTH { Some(Key { - key: key.to_string(), + id: key.parse::().unwrap(), valid_until: None, }) } else { @@ -112,18 +119,13 @@ impl Key { } } - /// # Panics - /// - /// Will fail if the key id is not a valid key id. #[must_use] pub fn id(&self) -> KeyId { - // todo: replace the type of field `key` with type `KeyId`. - // The constructor should fail if an invalid KeyId is provided. - KeyId::from_str(&self.key).unwrap() + self.id.clone() } } -#[derive(Debug, Display, PartialEq, Clone)] +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone, Display, Hash)] pub struct KeyId(String); #[derive(Debug, PartialEq, Eq)] @@ -148,10 +150,10 @@ pub enum Error { KeyVerificationError { source: LocatedError<'static, dyn std::error::Error + Send + Sync>, }, - #[error("Failed to read key: {key}, {location}")] + #[error("Failed to read key: {key_id}, {location}")] UnableToReadKey { location: &'static Location<'static>, - key: Box, + key_id: Box, }, #[error("Key has expired, {location}")] KeyExpired { location: &'static Location<'static> }, @@ -171,7 +173,7 @@ mod tests { use std::time::Duration; use crate::protocol::clock::{Current, StoppedTime}; - use crate::tracker::auth; + use crate::tracker::auth::{self, KeyId}; #[test] fn auth_key_from_buffer() { @@ -181,7 +183,10 @@ mod tests { ]); assert!(auth_key.is_some()); - assert_eq!(auth_key.unwrap().key, "YZSl4lMZupRuOpSRC3krIKR5BPB14nrJ"); + assert_eq!( + auth_key.unwrap().id, + "YZSl4lMZupRuOpSRC3krIKR5BPB14nrJ".parse::().unwrap() + ); } #[test] @@ -190,7 +195,7 @@ mod tests { let auth_key = auth::Key::from_string(key_string); assert!(auth_key.is_some()); - assert_eq!(auth_key.unwrap().key, key_string); + assert_eq!(auth_key.unwrap().id, key_string.parse::().unwrap()); } #[test] diff --git a/src/tracker/error.rs b/src/tracker/error.rs index 51bcbf3b..acc85a1c 100644 --- a/src/tracker/error.rs +++ b/src/tracker/error.rs @@ -4,9 +4,9 @@ use crate::located_error::LocatedError; #[derive(thiserror::Error, Debug, Clone)] pub enum Error { - #[error("The supplied key: {key:?}, is not valid: {source}")] + #[error("The supplied key: {key_id:?}, is not valid: {source}")] PeerKeyNotValid { - key: super::auth::Key, + key_id: super::auth::KeyId, source: LocatedError<'static, dyn std::error::Error + Send + Sync>, }, #[error("The peer is not authenticated, {location}")] diff --git a/src/tracker/mod.rs b/src/tracker/mod.rs index 3e5e9743..147c889a 100644 --- a/src/tracker/mod.rs +++ b/src/tracker/mod.rs @@ -16,6 +16,7 @@ use std::time::Duration; use tokio::sync::mpsc::error::SendError; use tokio::sync::{RwLock, RwLockReadGuard}; +use self::auth::KeyId; use self::error::Error; use self::peer::Peer; use self::torrent::{SwamStats, SwarmMetadata}; @@ -27,7 +28,7 @@ use crate::protocol::info_hash::InfoHash; pub struct Tracker { pub config: Arc, mode: mode::Mode, - keys: RwLock>, + keys: RwLock>, whitelist: RwLock>, torrents: RwLock>, stats_event_sender: Option>, @@ -155,28 +156,31 @@ impl Tracker { pub async fn generate_auth_key(&self, lifetime: Duration) -> Result { let auth_key = auth::generate(lifetime); self.database.add_key_to_keys(&auth_key).await?; - self.keys.write().await.insert(auth_key.key.clone(), auth_key.clone()); + self.keys.write().await.insert(auth_key.id.clone(), auth_key.clone()); Ok(auth_key) } /// # Errors /// /// Will return a `database::Error` if unable to remove the `key` to the database. + /// + /// # Panics + /// + /// Will panic if key cannot be converted into a valid `KeyId`. pub async fn remove_auth_key(&self, key: &str) -> Result<(), databases::error::Error> { self.database.remove_key_from_keys(key).await?; - self.keys.write().await.remove(key); + self.keys.write().await.remove(&key.parse::().unwrap()); Ok(()) } /// # Errors /// /// Will return a `key::Error` if unable to get any `auth_key`. - pub async fn verify_auth_key(&self, auth_key: &auth::Key) -> Result<(), auth::Error> { - // todo: use auth::KeyId for the function argument `auth_key` - match self.keys.read().await.get(&auth_key.key) { + pub async fn verify_auth_key(&self, key_id: &KeyId) -> Result<(), auth::Error> { + match self.keys.read().await.get(key_id) { None => Err(auth::Error::UnableToReadKey { location: Location::caller(), - key: Box::new(auth_key.clone()), + key_id: Box::new(key_id.clone()), }), Some(key) => auth::verify(key), } @@ -192,7 +196,7 @@ impl Tracker { keys.clear(); for key in keys_from_database { - keys.insert(key.key.clone(), key); + keys.insert(key.id.clone(), key); } Ok(()) @@ -283,7 +287,7 @@ impl Tracker { /// Will return a `torrent::Error::PeerNotAuthenticated` if the `key` is `None`. /// /// Will return a `torrent::Error::TorrentNotWhitelisted` if the the Tracker is in listed mode and the `info_hash` is not whitelisted. - pub async fn authenticate_request(&self, info_hash: &InfoHash, key: &Option) -> Result<(), Error> { + pub async fn authenticate_request(&self, info_hash: &InfoHash, key: &Option) -> Result<(), Error> { // no authentication needed in public mode if self.is_public() { return Ok(()); @@ -295,7 +299,7 @@ impl Tracker { Some(key) => { if let Err(e) = self.verify_auth_key(key).await { return Err(Error::PeerKeyNotValid { - key: key.clone(), + key_id: key.clone(), source: (Arc::new(e) as Arc).into(), }); } diff --git a/tests/tracker_api.rs b/tests/tracker_api.rs index 193c6487..bec22e2b 100644 --- a/tests/tracker_api.rs +++ b/tests/tracker_api.rs @@ -638,7 +638,7 @@ mod tracker_apis { mod for_key_resources { use std::time::Duration; - use torrust_tracker::tracker::auth::Key; + use torrust_tracker::tracker::auth::KeyId; use crate::api::asserts::{ assert_auth_key_utf8, assert_failed_to_delete_key, assert_failed_to_generate_key, assert_failed_to_reload_keys, @@ -665,7 +665,7 @@ mod tracker_apis { // Verify the key with the tracker assert!(api_server .tracker - .verify_auth_key(&Key::from(auth_key_resource)) + .verify_auth_key(&auth_key_resource.key.parse::().unwrap()) .await .is_ok()); } @@ -734,7 +734,7 @@ mod tracker_apis { .unwrap(); let response = Client::new(api_server.get_connection_info()) - .delete_auth_key(&auth_key.key) + .delete_auth_key(&auth_key.id.to_string()) .await; assert_ok(response).await; @@ -777,7 +777,7 @@ mod tracker_apis { force_database_error(&api_server.tracker); let response = Client::new(api_server.get_connection_info()) - .delete_auth_key(&auth_key.key) + .delete_auth_key(&auth_key.id.to_string()) .await; assert_failed_to_delete_key(response).await; @@ -797,7 +797,7 @@ mod tracker_apis { .unwrap(); let response = Client::new(connection_with_invalid_token(&api_server.get_bind_address())) - .delete_auth_key(&auth_key.key) + .delete_auth_key(&auth_key.id.to_string()) .await; assert_token_not_valid(response).await; @@ -810,7 +810,7 @@ mod tracker_apis { .unwrap(); let response = Client::new(connection_with_no_token(&api_server.get_bind_address())) - .delete_auth_key(&auth_key.key) + .delete_auth_key(&auth_key.id.to_string()) .await; assert_unauthorized(response).await;