Skip to content

Commit

Permalink
Merge #241: Refactor: use domain struts in DB trait
Browse files Browse the repository at this point in the history
fd50bb0 refactor(tracker): use domain struts in DB trait (Jose Celano)

Pull request description:

  Instead of primitive types.

Top commit has no ACKs.

Tree-SHA512: 101c40551f04c0e538351bbe54dd7d030a4441f3061762f3751cbfe8c7452a05c29a69de7aad1db09d87578d3ee1b617e2758c6b2e9adc269d7376bd19dc35c6
  • Loading branch information
josecelano committed Mar 13, 2023
2 parents be110bc + fd50bb0 commit c3cd623
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/apis/v1/context/auth_key/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub async fn delete_auth_key_handler(
) -> Response {
match Key::from_str(&seconds_valid_or_key.0) {
Err(_) => invalid_auth_key_param_response(&seconds_valid_or_key.0),
Ok(key) => match tracker.remove_auth_key(&key.to_string()).await {
Ok(key) => match tracker.remove_auth_key(&key).await {
Ok(_) => ok_response(),
Err(e) => failed_to_delete_key_response(e),
},
Expand Down
16 changes: 5 additions & 11 deletions src/databases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use async_trait::async_trait;

use self::error::Error;
use crate::protocol::info_hash::InfoHash;
use crate::tracker::auth;
use crate::tracker::auth::{self, Key};

pub(self) struct Builder<T>
where
Expand Down Expand Up @@ -63,25 +63,19 @@ pub trait Database: Sync + Send {

async fn save_persistent_torrent(&self, info_hash: &InfoHash, completed: u32) -> Result<(), Error>;

// todo: replace type `&str` with `&InfoHash`
async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<Option<InfoHash>, Error>;
async fn get_info_hash_from_whitelist(&self, info_hash: &InfoHash) -> Result<Option<InfoHash>, Error>;

async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error>;

async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error>;

// todo: replace type `&str` with `&Key`
async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::ExpiringKey>, Error>;
async fn get_key_from_keys(&self, key: &Key) -> Result<Option<auth::ExpiringKey>, Error>;

async fn add_key_to_keys(&self, auth_key: &auth::ExpiringKey) -> Result<usize, Error>;

// todo: replace type `&str` with `&Key`
async fn remove_key_from_keys(&self, key: &str) -> Result<usize, Error>;
async fn remove_key_from_keys(&self, key: &Key) -> Result<usize, Error>;

async fn is_info_hash_whitelisted(&self, info_hash: &InfoHash) -> Result<bool, Error> {
Ok(self
.get_info_hash_from_whitelist(&info_hash.clone().to_string())
.await?
.is_some())
Ok(self.get_info_hash_from_whitelist(info_hash).await?.is_some())
}
}
16 changes: 9 additions & 7 deletions src/databases/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,12 @@ impl Database for Mysql {
Ok(conn.exec_drop(COMMAND, params! { info_hash_str, completed })?)
}

async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<Option<InfoHash>, Error> {
async fn get_info_hash_from_whitelist(&self, info_hash: &InfoHash) -> Result<Option<InfoHash>, Error> {
let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let select = conn.exec_first::<String, _, _>(
"SELECT info_hash FROM whitelist WHERE info_hash = :info_hash",
params! { info_hash },
params! { "info_hash" => info_hash.to_hex_string() },
)?;

let info_hash = select.map(|f| InfoHash::from_str(&f).expect("Failed to decode InfoHash String from DB!"));
Expand Down Expand Up @@ -183,11 +183,13 @@ impl Database for Mysql {
Ok(1)
}

async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::ExpiringKey>, Error> {
async fn get_key_from_keys(&self, key: &Key) -> Result<Option<auth::ExpiringKey>, Error> {
let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let query =
conn.exec_first::<(String, i64), _, _>("SELECT `key`, valid_until FROM `keys` WHERE `key` = :key", params! { key });
let query = conn.exec_first::<(String, i64), _, _>(
"SELECT `key`, valid_until FROM `keys` WHERE `key` = :key",
params! { "key" => key.to_string() },
);

let key = query?;

Expand All @@ -211,10 +213,10 @@ impl Database for Mysql {
Ok(1)
}

async fn remove_key_from_keys(&self, key: &str) -> Result<usize, Error> {
async fn remove_key_from_keys(&self, key: &Key) -> Result<usize, Error> {
let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;

conn.exec_drop("DELETE FROM `keys` WHERE key = :key", params! { key })?;
conn.exec_drop("DELETE FROM `keys` WHERE key = :key", params! { "key" => key.to_string() })?;

Ok(1)
}
Expand Down
14 changes: 7 additions & 7 deletions src/databases/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,12 @@ impl Database for Sqlite {
}
}

async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<Option<InfoHash>, Error> {
async fn get_info_hash_from_whitelist(&self, info_hash: &InfoHash) -> Result<Option<InfoHash>, Error> {
let conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let mut stmt = conn.prepare("SELECT info_hash FROM whitelist WHERE info_hash = ?")?;

let mut rows = stmt.query([info_hash])?;
let mut rows = stmt.query([info_hash.to_hex_string()])?;

let query = rows.next()?;

Expand Down Expand Up @@ -200,7 +200,7 @@ impl Database for Sqlite {
}
}

async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::ExpiringKey>, Error> {
async fn get_key_from_keys(&self, key: &Key) -> Result<Option<auth::ExpiringKey>, Error> {
let conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let mut stmt = conn.prepare("SELECT key, valid_until FROM keys WHERE key = ?")?;
Expand All @@ -211,9 +211,9 @@ impl Database for Sqlite {

Ok(key.map(|f| {
let expiry: i64 = f.get(1).unwrap();
let id: String = f.get(0).unwrap();
let key: String = f.get(0).unwrap();
auth::ExpiringKey {
key: id.parse::<Key>().unwrap(),
key: key.parse::<Key>().unwrap(),
valid_until: DurationSinceUnixEpoch::from_secs(expiry.unsigned_abs()),
}
}))
Expand All @@ -237,10 +237,10 @@ impl Database for Sqlite {
}
}

async fn remove_key_from_keys(&self, key: &str) -> Result<usize, Error> {
async fn remove_key_from_keys(&self, key: &Key) -> Result<usize, Error> {
let conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let deleted = conn.execute("DELETE FROM keys WHERE key = ?", [key])?;
let deleted = conn.execute("DELETE FROM keys WHERE key = ?", [key.to_string()])?;

if deleted == 1 {
// should only remove a single record.
Expand Down
12 changes: 12 additions & 0 deletions src/protocol/info_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ impl InfoHash {
pub fn bytes(&self) -> [u8; 20] {
self.0
}

#[must_use]
pub fn to_hex_string(&self) -> String {
self.to_string()
}
}

impl std::fmt::Display for InfoHash {
Expand Down Expand Up @@ -197,6 +202,13 @@ mod tests {
assert_eq!(output, "ffffffffffffffffffffffffffffffffffffffff");
}

#[test]
fn an_info_hash_should_return_its_a_40_utf8_lowercased_char_hex_representations_as_string() {
let info_hash = InfoHash::from_str("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF").unwrap();

assert_eq!(info_hash.to_hex_string(), "ffffffffffffffffffffffffffffffffffffffff");
}

#[test]
fn an_info_hash_can_be_created_from_a_valid_20_byte_array_slice() {
let info_hash: InfoHash = [255u8; 20].as_slice().into();
Expand Down
11 changes: 5 additions & 6 deletions src/tracker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,9 @@ impl Tracker {
/// # Panics
///
/// Will panic if key cannot be converted into a valid `Key`.
pub async fn remove_auth_key(&self, key: &str) -> Result<(), databases::error::Error> {
// todo: change argument `key: &str` to `key: &Key`
pub async fn remove_auth_key(&self, key: &Key) -> Result<(), databases::error::Error> {
self.database.remove_key_from_keys(key).await?;
self.keys.write().await.remove(&key.parse::<Key>().unwrap());
self.keys.write().await.remove(key);
Ok(())
}

Expand Down Expand Up @@ -1175,12 +1174,12 @@ mod tests {
async fn it_should_remove_an_authentication_key() {
let tracker = private_tracker();

let key = tracker.generate_auth_key(Duration::from_secs(100)).await.unwrap();
let expiring_key = tracker.generate_auth_key(Duration::from_secs(100)).await.unwrap();

let result = tracker.remove_auth_key(&key.id().to_string()).await;
let result = tracker.remove_auth_key(&expiring_key.id()).await;

assert!(result.is_ok());
assert!(tracker.verify_auth_key(&key.id()).await.is_err());
assert!(tracker.verify_auth_key(&expiring_key.id()).await.is_err());
}

#[tokio::test]
Expand Down

0 comments on commit c3cd623

Please sign in to comment.