diff --git a/src/apis/v1/context/auth_key/handlers.rs b/src/apis/v1/context/auth_key/handlers.rs index d21f0829..cb1cd111 100644 --- a/src/apis/v1/context/auth_key/handlers.rs +++ b/src/apis/v1/context/auth_key/handlers.rs @@ -31,7 +31,7 @@ pub async fn delete_auth_key_handler( ) -> Response { match Key::from_str(&seconds_valid_or_key.0) { Err(_) => invalid_auth_key_param_response(&seconds_valid_or_key.0), - Ok(key) => match tracker.remove_auth_key(&key.to_string()).await { + Ok(key) => match tracker.remove_auth_key(&key).await { Ok(_) => ok_response(), Err(e) => failed_to_delete_key_response(e), }, diff --git a/src/databases/mod.rs b/src/databases/mod.rs index 247f571d..0af6f572 100644 --- a/src/databases/mod.rs +++ b/src/databases/mod.rs @@ -9,7 +9,7 @@ use async_trait::async_trait; use self::error::Error; use crate::protocol::info_hash::InfoHash; -use crate::tracker::auth; +use crate::tracker::auth::{self, Key}; pub(self) struct Builder where @@ -63,25 +63,19 @@ pub trait Database: Sync + Send { async fn save_persistent_torrent(&self, info_hash: &InfoHash, completed: u32) -> Result<(), Error>; - // todo: replace type `&str` with `&InfoHash` - async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result, Error>; + async fn get_info_hash_from_whitelist(&self, info_hash: &InfoHash) -> Result, Error>; async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result; async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result; - // todo: replace type `&str` with `&Key` - async fn get_key_from_keys(&self, key: &str) -> Result, Error>; + async fn get_key_from_keys(&self, key: &Key) -> Result, Error>; async fn add_key_to_keys(&self, auth_key: &auth::ExpiringKey) -> Result; - // todo: replace type `&str` with `&Key` - async fn remove_key_from_keys(&self, key: &str) -> Result; + async fn remove_key_from_keys(&self, key: &Key) -> Result; async fn is_info_hash_whitelisted(&self, info_hash: &InfoHash) -> Result { - Ok(self - .get_info_hash_from_whitelist(&info_hash.clone().to_string()) - .await? - .is_some()) + Ok(self.get_info_hash_from_whitelist(info_hash).await?.is_some()) } } diff --git a/src/databases/mysql.rs b/src/databases/mysql.rs index f0c7ec1d..f6918974 100644 --- a/src/databases/mysql.rs +++ b/src/databases/mysql.rs @@ -147,12 +147,12 @@ impl Database for Mysql { Ok(conn.exec_drop(COMMAND, params! { info_hash_str, completed })?) } - async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result, Error> { + async fn get_info_hash_from_whitelist(&self, info_hash: &InfoHash) -> Result, Error> { let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; let select = conn.exec_first::( "SELECT info_hash FROM whitelist WHERE info_hash = :info_hash", - params! { info_hash }, + params! { "info_hash" => info_hash.to_hex_string() }, )?; let info_hash = select.map(|f| InfoHash::from_str(&f).expect("Failed to decode InfoHash String from DB!")); @@ -183,11 +183,13 @@ impl Database for Mysql { Ok(1) } - async fn get_key_from_keys(&self, key: &str) -> Result, Error> { + async fn get_key_from_keys(&self, key: &Key) -> Result, Error> { let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; - let query = - conn.exec_first::<(String, i64), _, _>("SELECT `key`, valid_until FROM `keys` WHERE `key` = :key", params! { key }); + let query = conn.exec_first::<(String, i64), _, _>( + "SELECT `key`, valid_until FROM `keys` WHERE `key` = :key", + params! { "key" => key.to_string() }, + ); let key = query?; @@ -211,10 +213,10 @@ impl Database for Mysql { Ok(1) } - async fn remove_key_from_keys(&self, key: &str) -> Result { + async fn remove_key_from_keys(&self, key: &Key) -> Result { let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; - conn.exec_drop("DELETE FROM `keys` WHERE key = :key", params! { key })?; + conn.exec_drop("DELETE FROM `keys` WHERE key = :key", params! { "key" => key.to_string() })?; Ok(1) } diff --git a/src/databases/sqlite.rs b/src/databases/sqlite.rs index 4bf2931d..adb201de 100644 --- a/src/databases/sqlite.rs +++ b/src/databases/sqlite.rs @@ -156,12 +156,12 @@ impl Database for Sqlite { } } - async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result, Error> { + async fn get_info_hash_from_whitelist(&self, info_hash: &InfoHash) -> Result, Error> { let conn = self.pool.get().map_err(|e| (e, DRIVER))?; let mut stmt = conn.prepare("SELECT info_hash FROM whitelist WHERE info_hash = ?")?; - let mut rows = stmt.query([info_hash])?; + let mut rows = stmt.query([info_hash.to_hex_string()])?; let query = rows.next()?; @@ -200,7 +200,7 @@ impl Database for Sqlite { } } - async fn get_key_from_keys(&self, key: &str) -> Result, Error> { + async fn get_key_from_keys(&self, key: &Key) -> Result, Error> { let conn = self.pool.get().map_err(|e| (e, DRIVER))?; let mut stmt = conn.prepare("SELECT key, valid_until FROM keys WHERE key = ?")?; @@ -211,9 +211,9 @@ impl Database for Sqlite { Ok(key.map(|f| { let expiry: i64 = f.get(1).unwrap(); - let id: String = f.get(0).unwrap(); + let key: String = f.get(0).unwrap(); auth::ExpiringKey { - key: id.parse::().unwrap(), + key: key.parse::().unwrap(), valid_until: DurationSinceUnixEpoch::from_secs(expiry.unsigned_abs()), } })) @@ -237,10 +237,10 @@ impl Database for Sqlite { } } - async fn remove_key_from_keys(&self, key: &str) -> Result { + async fn remove_key_from_keys(&self, key: &Key) -> Result { let conn = self.pool.get().map_err(|e| (e, DRIVER))?; - let deleted = conn.execute("DELETE FROM keys WHERE key = ?", [key])?; + let deleted = conn.execute("DELETE FROM keys WHERE key = ?", [key.to_string()])?; if deleted == 1 { // should only remove a single record. diff --git a/src/protocol/info_hash.rs b/src/protocol/info_hash.rs index 32063672..fd7602cd 100644 --- a/src/protocol/info_hash.rs +++ b/src/protocol/info_hash.rs @@ -24,6 +24,11 @@ impl InfoHash { pub fn bytes(&self) -> [u8; 20] { self.0 } + + #[must_use] + pub fn to_hex_string(&self) -> String { + self.to_string() + } } impl std::fmt::Display for InfoHash { @@ -197,6 +202,13 @@ mod tests { assert_eq!(output, "ffffffffffffffffffffffffffffffffffffffff"); } + #[test] + fn an_info_hash_should_return_its_a_40_utf8_lowercased_char_hex_representations_as_string() { + let info_hash = InfoHash::from_str("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF").unwrap(); + + assert_eq!(info_hash.to_hex_string(), "ffffffffffffffffffffffffffffffffffffffff"); + } + #[test] fn an_info_hash_can_be_created_from_a_valid_20_byte_array_slice() { let info_hash: InfoHash = [255u8; 20].as_slice().into(); diff --git a/src/tracker/mod.rs b/src/tracker/mod.rs index 326afbf0..8a973979 100644 --- a/src/tracker/mod.rs +++ b/src/tracker/mod.rs @@ -202,10 +202,9 @@ impl Tracker { /// # Panics /// /// Will panic if key cannot be converted into a valid `Key`. - pub async fn remove_auth_key(&self, key: &str) -> Result<(), databases::error::Error> { - // todo: change argument `key: &str` to `key: &Key` + pub async fn remove_auth_key(&self, key: &Key) -> Result<(), databases::error::Error> { self.database.remove_key_from_keys(key).await?; - self.keys.write().await.remove(&key.parse::().unwrap()); + self.keys.write().await.remove(key); Ok(()) } @@ -1175,12 +1174,12 @@ mod tests { async fn it_should_remove_an_authentication_key() { let tracker = private_tracker(); - let key = tracker.generate_auth_key(Duration::from_secs(100)).await.unwrap(); + let expiring_key = tracker.generate_auth_key(Duration::from_secs(100)).await.unwrap(); - let result = tracker.remove_auth_key(&key.id().to_string()).await; + let result = tracker.remove_auth_key(&expiring_key.id()).await; assert!(result.is_ok()); - assert!(tracker.verify_auth_key(&key.id()).await.is_err()); + assert!(tracker.verify_auth_key(&expiring_key.id()).await.is_err()); } #[tokio::test]