diff --git a/Cargo.lock b/Cargo.lock index d1f677ff1d..730b9ef7f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4746,6 +4746,7 @@ dependencies = [ "serde_derive", "serde_repr", "tari_common", + "tari_common_sqlite", "tari_comms", "tari_comms_rpc_macros", "tari_crypto", diff --git a/base_layer/core/tests/base_node_rpc.rs b/base_layer/core/tests/base_node_rpc.rs index 7643b5ca28..1b928bf911 100644 --- a/base_layer/core/tests/base_node_rpc.rs +++ b/base_layer/core/tests/base_node_rpc.rs @@ -45,6 +45,7 @@ use std::{convert::TryFrom, sync::Arc, time::Duration}; use randomx_rs::RandomXFlag; +use tari_crypto::tari_utilities::epoch_time::EpochTime; use tempfile::{tempdir, TempDir}; use tari_common::configuration::Network; @@ -303,8 +304,7 @@ async fn test_get_height_at_time() { let (_, service, base_node, request_mock, consensus_manager, block0, _utxo0, _temp_dir) = setup().await; let mut prev_block = block0.clone(); - let mut times = Vec::new(); - times.push(prev_block.header().timestamp); + let mut times: Vec = vec![prev_block.header().timestamp]; for _ in 0..10 { tokio::time::sleep(Duration::from_secs(2)).await; let new_block = base_node diff --git a/base_layer/key_manager/src/mnemonic.rs b/base_layer/key_manager/src/mnemonic.rs index 9338e82217..8df3f90734 100644 --- a/base_layer/key_manager/src/mnemonic.rs +++ b/base_layer/key_manager/src/mnemonic.rs @@ -350,7 +350,7 @@ mod test { "abandon".to_string(), "tipico".to_string(), ]; - assert_eq!(MnemonicLanguage::detect_language(&words2).is_err(), true); + assert!(MnemonicLanguage::detect_language(&words2).is_err()); // bounds check (last word is invalid) let words3 = vec![ @@ -360,7 +360,7 @@ mod test { "abandon".to_string(), "topazio".to_string(), ]; - assert_eq!(MnemonicLanguage::detect_language(&words3).is_err(), true); + assert!(MnemonicLanguage::detect_language(&words3).is_err()); // building up a word list: English/French + French -> French let mut words = Vec::with_capacity(3); diff --git a/base_layer/wallet/src/contacts_service/error.rs b/base_layer/wallet/src/contacts_service/error.rs index b06a02f9c0..3d9dd19987 100644 --- a/base_layer/wallet/src/contacts_service/error.rs +++ b/base_layer/wallet/src/contacts_service/error.rs @@ -25,7 +25,7 @@ use diesel::result::Error as DieselError; use tari_service_framework::reply_channel::TransportChannelError; use thiserror::Error; -#[derive(Debug, Error, PartialEq)] +#[derive(Debug, Error)] #[allow(clippy::large_enum_variant)] pub enum ContactsServiceError { #[error("Contact is not found")] @@ -38,7 +38,7 @@ pub enum ContactsServiceError { TransportChannelError(#[from] TransportChannelError), } -#[derive(Debug, Error, PartialEq)] +#[derive(Debug, Error)] pub enum ContactsServiceStorageError { #[error("This write operation is not supported for provided DbKey")] OperationNotSupported, diff --git a/base_layer/wallet/src/error.rs b/base_layer/wallet/src/error.rs index 61c879570f..7aba0ca426 100644 --- a/base_layer/wallet/src/error.rs +++ b/base_layer/wallet/src/error.rs @@ -170,9 +170,3 @@ impl From for ExitCodes { } } } - -impl PartialEq for WalletStorageError { - fn eq(&self, other: &Self) -> bool { - self == other - } -} diff --git a/base_layer/wallet/src/output_manager_service/error.rs b/base_layer/wallet/src/output_manager_service/error.rs index 1e7eeb84e8..afef1156c3 100644 --- a/base_layer/wallet/src/output_manager_service/error.rs +++ b/base_layer/wallet/src/output_manager_service/error.rs @@ -120,7 +120,7 @@ pub enum OutputManagerError { InvalidMessageError(String), } -#[derive(Debug, Error, PartialEq)] +#[derive(Debug, Error)] pub enum OutputManagerStorageError { #[error("Tried to insert an output that already exists in the database")] DuplicateOutput, diff --git a/base_layer/wallet/tests/contacts_service/mod.rs b/base_layer/wallet/tests/contacts_service/mod.rs index ed5ad5033c..b96aa8e8a7 100644 --- a/base_layer/wallet/tests/contacts_service/mod.rs +++ b/base_layer/wallet/tests/contacts_service/mod.rs @@ -87,18 +87,21 @@ pub fn test_contacts_service() { let (_secret_key, public_key) = PublicKey::random_keypair(&mut OsRng); let contact = runtime.block_on(contacts_service.get_contact(public_key.clone())); - assert_eq!( - contact, - Err(ContactsServiceError::ContactsServiceStorageError( - ContactsServiceStorageError::ValueNotFound(DbKey::Contact(public_key.clone())) - )) - ); - assert_eq!( - runtime.block_on(contacts_service.remove_contact(public_key.clone())), - Err(ContactsServiceError::ContactsServiceStorageError( - ContactsServiceStorageError::ValueNotFound(DbKey::Contact(public_key)) - )) - ); + match contact { + Ok(_) => panic!("There should be an error here"), + Err(ContactsServiceError::ContactsServiceStorageError(ContactsServiceStorageError::ValueNotFound(val))) => { + assert_eq!(val, DbKey::Contact(public_key.clone())) + }, + _ => panic!("There should be a specific error here"), + } + let result = runtime.block_on(contacts_service.remove_contact(public_key.clone())); + match result { + Ok(_) => panic!("There should be an error here"), + Err(ContactsServiceError::ContactsServiceStorageError(ContactsServiceStorageError::ValueNotFound(val))) => { + assert_eq!(val, DbKey::Contact(public_key)) + }, + _ => panic!("There should be a specific error here"), + } let _ = runtime .block_on(contacts_service.remove_contact(contacts[0].public_key.clone())) diff --git a/base_layer/wallet/tests/wallet/mod.rs b/base_layer/wallet/tests/wallet/mod.rs index d54c307c01..c33481f3c2 100644 --- a/base_layer/wallet/tests/wallet/mod.rs +++ b/base_layer/wallet/tests/wallet/mod.rs @@ -69,7 +69,6 @@ use tari_wallet::{ handle::TransactionEvent, storage::sqlite_db::TransactionServiceSqliteDatabase, }, - utxo_scanner_service::utxo_scanning::UtxoScannerService, Wallet, WalletConfig, WalletSqlite, diff --git a/common_sqlite/src/error.rs b/common_sqlite/src/error.rs index c47859ce16..cd2b8e967a 100644 --- a/common_sqlite/src/error.rs +++ b/common_sqlite/src/error.rs @@ -27,9 +27,3 @@ pub enum SqliteStorageError { #[error("Diesel R2d2 error")] DieselR2d2Error(String), } - -impl PartialEq for SqliteStorageError { - fn eq(&self, other: &Self) -> bool { - self == other - } -} diff --git a/comms/dht/Cargo.toml b/comms/dht/Cargo.toml index 5b15421607..6119216687 100644 --- a/comms/dht/Cargo.toml +++ b/comms/dht/Cargo.toml @@ -16,6 +16,7 @@ tari_crypto = { git = "https://github.com/tari-project/tari-crypto.git", branch tari_utilities = { version = "^0.3" } tari_shutdown = { version = "^0.21", path = "../../infrastructure/shutdown" } tari_storage = { version = "^0.21", path = "../../infrastructure/storage" } +tari_common_sqlite = { path = "../../common_sqlite" } anyhow = "1.0.32" bitflags = "1.2.0" diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index 3d93e3abc8..b1e16f18e8 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -254,7 +254,6 @@ impl DhtActor { let offline_ts = self .database .get_metadata_value::>(DhtMetadataKey::OfflineTimestamp) - .await .ok() .flatten(); info!( @@ -284,25 +283,24 @@ impl DhtActor { }, _ = dedup_cache_trim_ticker.tick() => { - if let Err(err) = self.msg_hash_dedup_cache.trim_entries().await { + if let Err(err) = self.msg_hash_dedup_cache.trim_entries() { error!(target: LOG_TARGET, "Error when trimming message dedup cache: {:?}", err); } }, _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "DhtActor is shutting down because it received a shutdown signal."); - self.mark_shutdown_time().await; + self.mark_shutdown_time(); break Ok(()); }, } } } - async fn mark_shutdown_time(&self) { + fn mark_shutdown_time(&self) { if let Err(err) = self .database .set_metadata_value(DhtMetadataKey::OfflineTimestamp, Utc::now()) - .await { warn!(target: LOG_TARGET, "Failed to mark offline time: {:?}", err); } @@ -323,7 +321,7 @@ impl DhtActor { } => { let msg_hash_cache = self.msg_hash_dedup_cache.clone(); Box::pin(async move { - match msg_hash_cache.add_body_hash(message_hash, received_from).await { + match msg_hash_cache.add_body_hash(message_hash, received_from) { Ok(hit_count) => { let _ = reply_tx.send(hit_count); }, @@ -341,7 +339,7 @@ impl DhtActor { GetMsgHashHitCount(hash, reply_tx) => { let msg_hash_cache = self.msg_hash_dedup_cache.clone(); Box::pin(async move { - let hit_count = msg_hash_cache.get_hit_count(hash).await?; + let hit_count = msg_hash_cache.get_hit_count(hash)?; let _ = reply_tx.send(hit_count); Ok(()) }) @@ -366,14 +364,14 @@ impl DhtActor { GetMetadata(key, reply_tx) => { let db = self.database.clone(); Box::pin(async move { - let _ = reply_tx.send(db.get_metadata_value_bytes(key).await.map_err(Into::into)); + let _ = reply_tx.send(db.get_metadata_value_bytes(key).map_err(Into::into)); Ok(()) }) }, SetMetadata(key, value, reply_tx) => { let db = self.database.clone(); Box::pin(async move { - match db.set_metadata_value_bytes(key, value).await { + match db.set_metadata_value_bytes(key, value) { Ok(_) => { debug!(target: LOG_TARGET, "Dht metadata '{}' set", key); let _ = reply_tx.send(Ok(())); @@ -727,8 +725,8 @@ mod test { use tari_test_utils::random; async fn db_connection() -> DbConnection { - let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); - conn.migrate().await.unwrap(); + let conn = DbConnection::connect_memory(random::string(8)).unwrap(); + conn.migrate().unwrap(); conn } @@ -838,7 +836,6 @@ mod test { let num_hits = actor .msg_hash_dedup_cache .add_body_hash(key.clone(), CommsPublicKey::default()) - .await .unwrap(); assert_eq!(num_hits, 1); } @@ -847,7 +844,6 @@ mod test { let num_hits = actor .msg_hash_dedup_cache .add_body_hash(key.clone(), CommsPublicKey::default()) - .await .unwrap(); assert_eq!(num_hits, 2); } @@ -855,7 +851,7 @@ mod test { let dedup_cache_db = actor.msg_hash_dedup_cache.clone(); // The cleanup ticker starts when the actor is spawned; the first cleanup event will fire fairly soon after the // task is running on a thread. To remove this race condition, we trim the cache in the test. - let num_trimmed = dedup_cache_db.trim_entries().await.unwrap(); + let num_trimmed = dedup_cache_db.trim_entries().unwrap(); assert_eq!(num_trimmed, 10); actor.spawn(); @@ -877,7 +873,7 @@ mod test { } // Trim the database of excess entries - dedup_cache_db.trim_entries().await.unwrap(); + dedup_cache_db.trim_entries().unwrap(); // Verify that the last half of the signatures have been removed and can be re-inserted into cache for key in signatures.iter().take(capacity * 2).skip(capacity) { diff --git a/comms/dht/src/dedup/dedup_cache.rs b/comms/dht/src/dedup/dedup_cache.rs index 2b21fd38bb..b6b2d9338c 100644 --- a/comms/dht/src/dedup/dedup_cache.rs +++ b/comms/dht/src/dedup/dedup_cache.rs @@ -58,10 +58,8 @@ impl DedupCacheDatabase { /// Adds the body hash to the cache, returning the number of hits (inclusive) that have been recorded for this body /// hash - pub async fn add_body_hash(&self, body_hash: Vec, public_key: CommsPublicKey) -> Result { - let hit_count = self - .insert_body_hash_or_update_stats(body_hash.to_hex(), public_key.to_hex()) - .await?; + pub fn add_body_hash(&self, body_hash: Vec, public_key: CommsPublicKey) -> Result { + let hit_count = self.insert_body_hash_or_update_stats(body_hash.to_hex(), public_key.to_hex())?; if hit_count == 0 { warn!( @@ -72,96 +70,80 @@ impl DedupCacheDatabase { Ok(hit_count) } - pub async fn get_hit_count(&self, body_hash: Vec) -> Result { - let hit_count = self - .connection - .with_connection_async(move |conn| { - dedup_cache::table - .select(dedup_cache::number_of_hits) - .filter(dedup_cache::body_hash.eq(&body_hash.to_hex())) - .get_result::(conn) - .optional() - .map_err(Into::into) - }) - .await?; + pub fn get_hit_count(&self, body_hash: Vec) -> Result { + let conn = self.connection.get_pooled_connection()?; + let hit_count = dedup_cache::table + .select(dedup_cache::number_of_hits) + .filter(dedup_cache::body_hash.eq(&body_hash.to_hex())) + .get_result::(&conn) + .optional()?; Ok(hit_count.unwrap_or(0) as u32) } /// Trims the dedup cache to the configured limit by removing the oldest entries - pub async fn trim_entries(&self) -> Result { + pub fn trim_entries(&self) -> Result { let capacity = self.capacity as i64; - self.connection - .with_connection_async(move |conn| { - let mut num_removed = 0; - let msg_count = dedup_cache::table - .select(dsl::count(dedup_cache::id)) - .first::(conn)?; - // Hysteresis added to minimize database impact - if msg_count > capacity { - let remove_count = msg_count - capacity; - num_removed = diesel::sql_query( - "DELETE FROM dedup_cache WHERE id IN (SELECT id FROM dedup_cache ORDER BY last_hit_at ASC \ - LIMIT $1)", - ) - .bind::(remove_count) - .execute(conn)?; - } - debug!( - target: LOG_TARGET, - "Message dedup cache: count {}, capacity {}, removed {}", msg_count, capacity, num_removed, - ); - Ok(num_removed) - }) - .await + let mut num_removed = 0; + let conn = self.connection.get_pooled_connection()?; + let msg_count = dedup_cache::table + .select(dsl::count(dedup_cache::id)) + .first::(&conn)?; + // Hysteresis added to minimize database impact + if msg_count > capacity { + let remove_count = msg_count - capacity; + num_removed = diesel::sql_query( + "DELETE FROM dedup_cache WHERE id IN (SELECT id FROM dedup_cache ORDER BY last_hit_at ASC LIMIT $1)", + ) + .bind::(remove_count) + .execute(&conn)?; + } + debug!( + target: LOG_TARGET, + "Message dedup cache: count {}, capacity {}, removed {}", msg_count, capacity, num_removed, + ); + Ok(num_removed) } /// Insert new row into the table or updates an existing row. Returns the number of hits for this body hash. - async fn insert_body_hash_or_update_stats( - &self, - body_hash: String, - public_key: String, - ) -> Result { - self.connection - .with_connection_async(move |conn| { - let insert_result = diesel::insert_into(dedup_cache::table) - .values(( - dedup_cache::body_hash.eq(&body_hash), - dedup_cache::sender_public_key.eq(&public_key), - dedup_cache::number_of_hits.eq(1), - dedup_cache::last_hit_at.eq(Utc::now().naive_utc()), - )) - .execute(conn); - match insert_result { - Ok(1) => Ok(1), - Ok(n) => Err(StorageError::UnexpectedResult(format!( - "Expected exactly one row to be inserted. Got {}", - n - ))), - Err(diesel::result::Error::DatabaseError(kind, e_info)) => match kind { - DatabaseErrorKind::UniqueViolation => { - // Update hit stats for the message - diesel::update(dedup_cache::table.filter(dedup_cache::body_hash.eq(&body_hash))) - .set(( - dedup_cache::sender_public_key.eq(&public_key), - dedup_cache::number_of_hits.eq(dedup_cache::number_of_hits + 1), - dedup_cache::last_hit_at.eq(Utc::now().naive_utc()), - )) - .execute(conn)?; - // TODO: Diesel support for RETURNING statements would remove this query, but is not - // available for Diesel + SQLite yet - let hits = dedup_cache::table - .select(dedup_cache::number_of_hits) - .filter(dedup_cache::body_hash.eq(&body_hash)) - .get_result::(conn)?; + fn insert_body_hash_or_update_stats(&self, body_hash: String, public_key: String) -> Result { + let conn = self.connection.get_pooled_connection()?; + let insert_result = diesel::insert_into(dedup_cache::table) + .values(( + dedup_cache::body_hash.eq(&body_hash), + dedup_cache::sender_public_key.eq(&public_key), + dedup_cache::number_of_hits.eq(1), + dedup_cache::last_hit_at.eq(Utc::now().naive_utc()), + )) + .execute(&conn); + match insert_result { + Ok(1) => Ok(1), + Ok(n) => Err(StorageError::UnexpectedResult(format!( + "Expected exactly one row to be inserted. Got {}", + n + ))), + Err(diesel::result::Error::DatabaseError(kind, e_info)) => match kind { + DatabaseErrorKind::UniqueViolation => { + // Update hit stats for the message + diesel::update(dedup_cache::table.filter(dedup_cache::body_hash.eq(&body_hash))) + .set(( + dedup_cache::sender_public_key.eq(&public_key), + dedup_cache::number_of_hits.eq(dedup_cache::number_of_hits + 1), + dedup_cache::last_hit_at.eq(Utc::now().naive_utc()), + )) + .execute(&conn)?; + // TODO: Diesel support for RETURNING statements would remove this query, but is not + // TODO: available for Diesel + SQLite yet + let hits = dedup_cache::table + .select(dedup_cache::number_of_hits) + .filter(dedup_cache::body_hash.eq(&body_hash)) + .get_result::(&conn)?; - Ok(hits as u32) - }, - _ => Err(diesel::result::Error::DatabaseError(kind, e_info).into()), - }, - Err(e) => Err(e.into()), - } - }) - .await + Ok(hits as u32) + }, + _ => Err(diesel::result::Error::DatabaseError(kind, e_info).into()), + }, + Err(e) => Err(e.into()), + } } } diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index 54f049a3da..06741eb256 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -132,11 +132,10 @@ impl Dht { saf_response_signal_sender, connectivity, discovery_sender, - event_publisher: event_publisher.clone(), + event_publisher, }; let conn = DbConnection::connect_and_migrate(dht.config.database_url.clone()) - .await .map_err(DhtInitializationError::DatabaseMigrationFailed)?; dht.network_discovery_service(shutdown_signal.clone()).spawn(); diff --git a/comms/dht/src/storage/connection.rs b/comms/dht/src/storage/connection.rs index ee99f8b560..3a37647ec6 100644 --- a/comms/dht/src/storage/connection.rs +++ b/comms/dht/src/storage/connection.rs @@ -21,16 +21,16 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::storage::error::StorageError; -use diesel::{Connection, SqliteConnection}; -use log::*; -use std::{ - io, - path::PathBuf, - sync::{Arc, Mutex}, +use diesel::{ + r2d2::{ConnectionManager, PooledConnection}, + SqliteConnection, }; -use tokio::task; +use log::*; +use std::{io, path::PathBuf, time::Duration}; +use tari_common_sqlite::sqlite_connection_pool::SqliteConnectionPool; const LOG_TARGET: &str = "comms::dht::storage::connection"; +const SQLITE_POOL_SIZE: usize = 16; #[derive(Clone, Debug)] pub enum DbConnectionUrl { @@ -58,64 +58,53 @@ impl DbConnectionUrl { #[derive(Clone)] pub struct DbConnection { - inner: Arc>, + pool: SqliteConnectionPool, } impl DbConnection { #[cfg(test)] - pub async fn connect_memory(name: String) -> Result { - Self::connect_url(DbConnectionUrl::MemoryShared(name)).await + pub fn connect_memory(name: String) -> Result { + Self::connect_url(DbConnectionUrl::MemoryShared(name)) } - pub async fn connect_url(db_url: DbConnectionUrl) -> Result { + pub fn connect_url(db_url: DbConnectionUrl) -> Result { debug!(target: LOG_TARGET, "Connecting to database using '{:?}'", db_url); - let conn = task::spawn_blocking(move || { - let conn = SqliteConnection::establish(&db_url.to_url_string())?; - conn.execute("PRAGMA foreign_keys = ON; PRAGMA busy_timeout = 60000;")?; - Result::<_, StorageError>::Ok(conn) - }) - .await??; - - Ok(Self::new(conn)) + + let mut pool = SqliteConnectionPool::new( + db_url.to_url_string(), + SQLITE_POOL_SIZE, + true, + true, + Duration::from_secs(60), + ); + pool.create_pool()?; + + Ok(Self::new(pool)) } - pub async fn connect_and_migrate(db_url: DbConnectionUrl) -> Result { - let conn = Self::connect_url(db_url).await?; - let output = conn.migrate().await?; + pub fn connect_and_migrate(db_url: DbConnectionUrl) -> Result { + let conn = Self::connect_url(db_url)?; + let output = conn.migrate()?; info!(target: LOG_TARGET, "DHT database migration: {}", output.trim()); Ok(conn) } - fn new(conn: SqliteConnection) -> Self { - Self { - inner: Arc::new(Mutex::new(conn)), - } + fn new(pool: SqliteConnectionPool) -> Self { + Self { pool } } - pub async fn migrate(&self) -> Result { - embed_migrations!("./migrations"); - - self.with_connection_async(|conn| { - let mut buf = io::Cursor::new(Vec::new()); - embedded_migrations::run_with_output(conn, &mut buf) - .map_err(|err| StorageError::DatabaseMigrationFailed(format!("Database migration failed {}", err)))?; - Ok(String::from_utf8_lossy(&buf.into_inner()).to_string()) - }) - .await + pub fn get_pooled_connection(&self) -> Result>, StorageError> { + self.pool.get_pooled_connection().map_err(StorageError::DieselR2d2Error) } - pub async fn with_connection_async(&self, f: F) -> Result - where - F: FnOnce(&SqliteConnection) -> Result + Send + 'static, - R: Send + 'static, - { - let conn_mutex = self.inner.clone(); - let ret = task::spawn_blocking(move || { - let lock = acquire_lock!(conn_mutex); - f(&*lock) - }) - .await??; - Ok(ret) + pub fn migrate(&self) -> Result { + embed_migrations!("./migrations"); + + let mut buf = io::Cursor::new(Vec::new()); + let conn = self.get_pooled_connection()?; + embedded_migrations::run_with_output(&conn, &mut buf) + .map_err(|err| StorageError::DatabaseMigrationFailed(format!("Database migration failed {}", err)))?; + Ok(String::from_utf8_lossy(&buf.into_inner()).to_string()) } } @@ -128,24 +117,20 @@ mod test { #[runtime::test] async fn connect_and_migrate() { - let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); - let output = conn.migrate().await.unwrap(); + let conn = DbConnection::connect_memory(random::string(8)).unwrap(); + let output = conn.migrate().unwrap(); assert!(output.starts_with("Running migration")); } #[runtime::test] async fn memory_connections() { let id = random::string(8); - let conn = DbConnection::connect_memory(id.clone()).await.unwrap(); - conn.migrate().await.unwrap(); - let conn = DbConnection::connect_memory(id).await.unwrap(); - let count: i32 = conn - .with_connection_async(|c| { - sql::("SELECT COUNT(*) FROM stored_messages") - .get_result(c) - .map_err(Into::into) - }) - .await + let conn = DbConnection::connect_memory(id.clone()).unwrap(); + conn.migrate().unwrap(); + let conn = DbConnection::connect_memory(id).unwrap(); + let conn = conn.get_pooled_connection().unwrap(); + let count: i32 = sql::("SELECT COUNT(*) FROM stored_messages") + .get_result(&conn) .unwrap(); assert_eq!(count, 0); } diff --git a/comms/dht/src/storage/database.rs b/comms/dht/src/storage/database.rs index c035eebb9a..94704de3b8 100644 --- a/comms/dht/src/storage/database.rs +++ b/comms/dht/src/storage/database.rs @@ -38,49 +38,39 @@ impl DhtDatabase { Self { connection } } - pub async fn get_metadata_value(&self, key: DhtMetadataKey) -> Result, StorageError> { - match self.get_metadata_value_bytes(key).await? { + pub fn get_metadata_value(&self, key: DhtMetadataKey) -> Result, StorageError> { + match self.get_metadata_value_bytes(key)? { Some(bytes) => T::from_binary(&bytes).map(Some).map_err(Into::into), None => Ok(None), } } - pub async fn get_metadata_value_bytes(&self, key: DhtMetadataKey) -> Result>, StorageError> { - self.connection - .with_connection_async(move |conn| { - dht_metadata::table - .filter(dht_metadata::key.eq(key.to_string())) - .first(conn) - .map(|rec: DhtMetadataEntry| Some(rec.value)) - .or_else(|err| match err { - diesel::result::Error::NotFound => Ok(None), - err => Err(err.into()), - }) + pub fn get_metadata_value_bytes(&self, key: DhtMetadataKey) -> Result>, StorageError> { + let conn = self.connection.get_pooled_connection()?; + dht_metadata::table + .filter(dht_metadata::key.eq(key.to_string())) + .first(&conn) + .map(|rec: DhtMetadataEntry| Some(rec.value)) + .or_else(|err| match err { + diesel::result::Error::NotFound => Ok(None), + err => Err(err.into()), }) - .await } - pub async fn set_metadata_value( - &self, - key: DhtMetadataKey, - value: T, - ) -> Result<(), StorageError> { + pub fn set_metadata_value(&self, key: DhtMetadataKey, value: T) -> Result<(), StorageError> { let bytes = value.to_binary()?; - self.set_metadata_value_bytes(key, bytes).await + self.set_metadata_value_bytes(key, bytes) } - pub async fn set_metadata_value_bytes(&self, key: DhtMetadataKey, value: Vec) -> Result<(), StorageError> { - self.connection - .with_connection_async(move |conn| { - diesel::replace_into(dht_metadata::table) - .values(NewDhtMetadataEntry { - key: key.to_string(), - value, - }) - .execute(conn) - .map(|_| ()) - .map_err(Into::into) + pub fn set_metadata_value_bytes(&self, key: DhtMetadataKey, value: Vec) -> Result<(), StorageError> { + let conn = self.connection.get_pooled_connection()?; + diesel::replace_into(dht_metadata::table) + .values(NewDhtMetadataEntry { + key: key.to_string(), + value, }) - .await + .execute(&conn) + .map(|_| ()) + .map_err(Into::into) } } diff --git a/comms/dht/src/storage/error.rs b/comms/dht/src/storage/error.rs index f5bf4f0596..5e89ab494d 100644 --- a/comms/dht/src/storage/error.rs +++ b/comms/dht/src/storage/error.rs @@ -20,6 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use tari_common_sqlite::error::SqliteStorageError; use tari_utilities::message_format::MessageFormatError; use thiserror::Error; use tokio::task; @@ -42,4 +43,6 @@ pub enum StorageError { MessageFormatError(#[from] MessageFormatError), #[error("Unexpected result: {0}")] UnexpectedResult(String), + #[error("Diesel R2d2 error: `{0}`")] + DieselR2d2Error(#[from] SqliteStorageError), } diff --git a/comms/dht/src/store_forward/database/mod.rs b/comms/dht/src/store_forward/database/mod.rs index 58ee06eb9c..74a004fb97 100644 --- a/comms/dht/src/store_forward/database/mod.rs +++ b/comms/dht/src/store_forward/database/mod.rs @@ -44,36 +44,30 @@ impl StoreAndForwardDatabase { } /// Inserts and returns Ok(true) if the item already existed and Ok(false) if it didn't - pub async fn insert_message_if_unique(&self, message: NewStoredMessage) -> Result { - self.connection - .with_connection_async(move |conn| { - match diesel::insert_into(stored_messages::table) - .values(message) - .execute(conn) - { - Ok(_) => Ok(false), - Err(diesel::result::Error::DatabaseError(kind, e_info)) => match kind { - DatabaseErrorKind::UniqueViolation => Ok(true), - _ => Err(diesel::result::Error::DatabaseError(kind, e_info).into()), - }, - Err(e) => Err(e.into()), - } - }) - .await + pub fn insert_message_if_unique(&self, message: NewStoredMessage) -> Result { + let conn = self.connection.get_pooled_connection()?; + match diesel::insert_into(stored_messages::table) + .values(message) + .execute(&conn) + { + Ok(_) => Ok(false), + Err(diesel::result::Error::DatabaseError(kind, e_info)) => match kind { + DatabaseErrorKind::UniqueViolation => Ok(true), + _ => Err(diesel::result::Error::DatabaseError(kind, e_info).into()), + }, + Err(e) => Err(e.into()), + } } - pub async fn remove_message(&self, message_ids: Vec) -> Result { - self.connection - .with_connection_async(move |conn| { - diesel::delete(stored_messages::table) - .filter(stored_messages::id.eq_any(message_ids)) - .execute(conn) - .map_err(Into::into) - }) - .await + pub fn remove_message(&self, message_ids: Vec) -> Result { + let conn = self.connection.get_pooled_connection()?; + diesel::delete(stored_messages::table) + .filter(stored_messages::id.eq_any(message_ids)) + .execute(&conn) + .map_err(Into::into) } - pub async fn find_messages_for_peer( + pub fn find_messages_for_peer( &self, public_key: &CommsPublicKey, node_id: &NodeId, @@ -82,85 +76,76 @@ impl StoreAndForwardDatabase { ) -> Result, StorageError> { let pk_hex = public_key.to_hex(); let node_id_hex = node_id.to_hex(); - self.connection - .with_connection_async::<_, Vec>(move |conn| { - let mut query = stored_messages::table - .select(stored_messages::all_columns) - .filter( - stored_messages::destination_pubkey - .eq(pk_hex) - .or(stored_messages::destination_node_id.eq(node_id_hex)), - ) - .filter(stored_messages::message_type.eq(DhtMessageType::None as i32)) - .into_boxed(); + let conn = self.connection.get_pooled_connection()?; + let mut query = stored_messages::table + .select(stored_messages::all_columns) + .filter( + stored_messages::destination_pubkey + .eq(pk_hex) + .or(stored_messages::destination_node_id.eq(node_id_hex)), + ) + .filter(stored_messages::message_type.eq(DhtMessageType::None as i32)) + .into_boxed(); - if let Some(since) = since { - query = query.filter(stored_messages::stored_at.gt(since.naive_utc())); - } + if let Some(since) = since { + query = query.filter(stored_messages::stored_at.gt(since.naive_utc())); + } - query - .order_by(stored_messages::stored_at.desc()) - .limit(limit) - .get_results(conn) - .map_err(Into::into) - }) - .await + query + .order_by(stored_messages::stored_at.desc()) + .limit(limit) + .get_results(&conn) + .map_err(Into::into) } - pub async fn find_anonymous_messages( + pub fn find_anonymous_messages( &self, since: Option>, limit: i64, ) -> Result, StorageError> { - self.connection - .with_connection_async(move |conn| { - let mut query = stored_messages::table - .select(stored_messages::all_columns) - .filter(stored_messages::origin_pubkey.is_null()) - .filter(stored_messages::destination_pubkey.is_null()) - .filter(stored_messages::is_encrypted.eq(true)) - .filter(stored_messages::message_type.eq(DhtMessageType::None as i32)) - .into_boxed(); + let conn = self.connection.get_pooled_connection()?; + let mut query = stored_messages::table + .select(stored_messages::all_columns) + .filter(stored_messages::origin_pubkey.is_null()) + .filter(stored_messages::destination_pubkey.is_null()) + .filter(stored_messages::is_encrypted.eq(true)) + .filter(stored_messages::message_type.eq(DhtMessageType::None as i32)) + .into_boxed(); - if let Some(since) = since { - query = query.filter(stored_messages::stored_at.gt(since.naive_utc())); - } + if let Some(since) = since { + query = query.filter(stored_messages::stored_at.gt(since.naive_utc())); + } - query - .order_by(stored_messages::stored_at.desc()) - .limit(limit) - .get_results(conn) - .map_err(Into::into) - }) - .await + query + .order_by(stored_messages::stored_at.desc()) + .limit(limit) + .get_results(&conn) + .map_err(Into::into) } - pub async fn find_join_messages( + pub fn find_join_messages( &self, since: Option>, limit: i64, ) -> Result, StorageError> { - self.connection - .with_connection_async(move |conn| { - let mut query = stored_messages::table - .select(stored_messages::all_columns) - .filter(stored_messages::message_type.eq(DhtMessageType::Join as i32)) - .into_boxed(); + let conn = self.connection.get_pooled_connection()?; + let mut query = stored_messages::table + .select(stored_messages::all_columns) + .filter(stored_messages::message_type.eq(DhtMessageType::Join as i32)) + .into_boxed(); - if let Some(since) = since { - query = query.filter(stored_messages::stored_at.gt(since.naive_utc())); - } + if let Some(since) = since { + query = query.filter(stored_messages::stored_at.gt(since.naive_utc())); + } - query - .order_by(stored_messages::stored_at.desc()) - .limit(limit) - .get_results(conn) - .map_err(Into::into) - }) - .await + query + .order_by(stored_messages::stored_at.desc()) + .limit(limit) + .get_results(&conn) + .map_err(Into::into) } - pub async fn find_messages_of_type_for_pubkey( + pub fn find_messages_of_type_for_pubkey( &self, public_key: &CommsPublicKey, message_type: DhtMessageType, @@ -168,87 +153,72 @@ impl StoreAndForwardDatabase { limit: i64, ) -> Result, StorageError> { let pk_hex = public_key.to_hex(); - self.connection - .with_connection_async(move |conn| { - let mut query = stored_messages::table - .select(stored_messages::all_columns) - .filter(stored_messages::destination_pubkey.eq(pk_hex)) - .filter(stored_messages::message_type.eq(message_type as i32)) - .into_boxed(); + let conn = self.connection.get_pooled_connection()?; + let mut query = stored_messages::table + .select(stored_messages::all_columns) + .filter(stored_messages::destination_pubkey.eq(pk_hex)) + .filter(stored_messages::message_type.eq(message_type as i32)) + .into_boxed(); - if let Some(since) = since { - query = query.filter(stored_messages::stored_at.gt(since.naive_utc())); - } + if let Some(since) = since { + query = query.filter(stored_messages::stored_at.gt(since.naive_utc())); + } - query - .order_by(stored_messages::stored_at.desc()) - .limit(limit) - .get_results(conn) - .map_err(Into::into) - }) - .await + query + .order_by(stored_messages::stored_at.desc()) + .limit(limit) + .get_results(&conn) + .map_err(Into::into) } #[cfg(test)] - pub(crate) async fn get_all_messages(&self) -> Result, StorageError> { - self.connection - .with_connection_async(|conn| { - stored_messages::table - .select(stored_messages::all_columns) - .get_results(conn) - .map_err(Into::into) - }) - .await + pub(crate) fn get_all_messages(&self) -> Result, StorageError> { + let conn = self.connection.get_pooled_connection()?; + stored_messages::table + .select(stored_messages::all_columns) + .get_results(&conn) + .map_err(Into::into) } - pub(crate) async fn delete_messages_with_priority_older_than( + pub(crate) fn delete_messages_with_priority_older_than( &self, priority: StoredMessagePriority, since: NaiveDateTime, ) -> Result { - self.connection - .with_connection_async(move |conn| { - diesel::delete(stored_messages::table) - .filter(stored_messages::stored_at.lt(since)) - .filter(stored_messages::priority.eq(priority as i32)) - .execute(conn) - .map_err(Into::into) - }) - .await + let conn = self.connection.get_pooled_connection()?; + diesel::delete(stored_messages::table) + .filter(stored_messages::stored_at.lt(since)) + .filter(stored_messages::priority.eq(priority as i32)) + .execute(&conn) + .map_err(Into::into) } - pub(crate) async fn delete_messages_older_than(&self, since: NaiveDateTime) -> Result { - self.connection - .with_connection_async(move |conn| { - diesel::delete(stored_messages::table) - .filter(stored_messages::stored_at.lt(since)) - .execute(conn) - .map_err(Into::into) - }) - .await + pub(crate) fn delete_messages_older_than(&self, since: NaiveDateTime) -> Result { + let conn = self.connection.get_pooled_connection()?; + diesel::delete(stored_messages::table) + .filter(stored_messages::stored_at.lt(since)) + .execute(&conn) + .map_err(Into::into) } - pub(crate) async fn truncate_messages(&self, max_size: usize) -> Result { - self.connection - .with_connection_async(move |conn| { - let mut num_removed = 0; - let msg_count = stored_messages::table - .select(dsl::count(stored_messages::id)) - .first::(conn)? as usize; - if msg_count > max_size { - let remove_count = msg_count - max_size; - let message_ids: Vec = stored_messages::table - .select(stored_messages::id) - .order_by(stored_messages::stored_at.asc()) - .limit(remove_count as i64) - .get_results(conn)?; - num_removed = diesel::delete(stored_messages::table) - .filter(stored_messages::id.eq_any(message_ids)) - .execute(conn)?; - } - Ok(num_removed) - }) - .await + pub(crate) fn truncate_messages(&self, max_size: usize) -> Result { + let mut num_removed = 0; + let conn = self.connection.get_pooled_connection()?; + let msg_count = stored_messages::table + .select(dsl::count(stored_messages::id)) + .first::(&conn)? as usize; + if msg_count > max_size { + let remove_count = msg_count - max_size; + let message_ids: Vec = stored_messages::table + .select(stored_messages::id) + .order_by(stored_messages::stored_at.asc()) + .limit(remove_count as i64) + .get_results(&conn)?; + num_removed = diesel::delete(stored_messages::table) + .filter(stored_messages::id.eq_any(message_ids)) + .execute(&conn)?; + } + Ok(num_removed) } } @@ -260,8 +230,8 @@ mod test { #[runtime::test] async fn insert_messages() { - let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); - conn.migrate().await.unwrap(); + let conn = DbConnection::connect_memory(random::string(8)).unwrap(); + conn.migrate().unwrap(); let db = StoreAndForwardDatabase::new(conn); let mut msg1 = NewStoredMessage::default(); msg1.body_hash.push('1'); @@ -269,10 +239,10 @@ mod test { msg2.body_hash.push('2'); let mut msg3 = NewStoredMessage::default(); msg3.body_hash.push('2'); // Duplicate message - db.insert_message_if_unique(msg1.clone()).await.unwrap(); - db.insert_message_if_unique(msg2.clone()).await.unwrap(); - db.insert_message_if_unique(msg3.clone()).await.unwrap(); - let messages = db.get_all_messages().await.unwrap(); + db.insert_message_if_unique(msg1.clone()).unwrap(); + db.insert_message_if_unique(msg2.clone()).unwrap(); + db.insert_message_if_unique(msg3.clone()).unwrap(); + let messages = db.get_all_messages().unwrap(); assert_eq!(messages.len(), 2); assert_eq!(messages[0].body_hash, msg1.body_hash); assert_eq!(messages[1].body_hash, msg2.body_hash); @@ -280,8 +250,8 @@ mod test { #[runtime::test] async fn remove_messages() { - let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); - conn.migrate().await.unwrap(); + let conn = DbConnection::connect_memory(random::string(8)).unwrap(); + conn.migrate().unwrap(); let db = StoreAndForwardDatabase::new(conn); // Create 3 unique messages let mut msg1 = NewStoredMessage::default(); @@ -290,25 +260,25 @@ mod test { msg2.body_hash.push('2'); let mut msg3 = NewStoredMessage::default(); msg3.body_hash.push('3'); - db.insert_message_if_unique(msg1.clone()).await.unwrap(); - db.insert_message_if_unique(msg2.clone()).await.unwrap(); - db.insert_message_if_unique(msg3.clone()).await.unwrap(); - let messages = db.get_all_messages().await.unwrap(); + db.insert_message_if_unique(msg1.clone()).unwrap(); + db.insert_message_if_unique(msg2.clone()).unwrap(); + db.insert_message_if_unique(msg3.clone()).unwrap(); + let messages = db.get_all_messages().unwrap(); assert_eq!(messages.len(), 3); let msg1_id = messages[0].id; let msg2_id = messages[1].id; let msg3_id = messages[2].id; - db.remove_message(vec![msg1_id, msg3_id]).await.unwrap(); - let messages = db.get_all_messages().await.unwrap(); + db.remove_message(vec![msg1_id, msg3_id]).unwrap(); + let messages = db.get_all_messages().unwrap(); assert_eq!(messages.len(), 1); assert_eq!(messages[0].id, msg2_id); } #[runtime::test] async fn truncate_messages() { - let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); - conn.migrate().await.unwrap(); + let conn = DbConnection::connect_memory(random::string(8)).unwrap(); + conn.migrate().unwrap(); let db = StoreAndForwardDatabase::new(conn); let mut msg1 = NewStoredMessage::default(); msg1.body_hash.push('1'); @@ -318,13 +288,13 @@ mod test { msg3.body_hash.push('3'); let mut msg4 = NewStoredMessage::default(); msg4.body_hash.push('4'); - db.insert_message_if_unique(msg1.clone()).await.unwrap(); - db.insert_message_if_unique(msg2.clone()).await.unwrap(); - db.insert_message_if_unique(msg3.clone()).await.unwrap(); - db.insert_message_if_unique(msg4.clone()).await.unwrap(); - let num_removed = db.truncate_messages(2).await.unwrap(); + db.insert_message_if_unique(msg1.clone()).unwrap(); + db.insert_message_if_unique(msg2.clone()).unwrap(); + db.insert_message_if_unique(msg3.clone()).unwrap(); + db.insert_message_if_unique(msg4.clone()).unwrap(); + let num_removed = db.truncate_messages(2).unwrap(); assert_eq!(num_removed, 2); - let messages = db.get_all_messages().await.unwrap(); + let messages = db.get_all_messages().unwrap(); assert_eq!(messages.len(), 2); assert_eq!(messages[0].body_hash, msg3.body_hash); assert_eq!(messages[1].body_hash, msg4.body_hash); diff --git a/comms/dht/src/store_forward/service.rs b/comms/dht/src/store_forward/service.rs index 91bff5037c..3e2e26797b 100644 --- a/comms/dht/src/store_forward/service.rs +++ b/comms/dht/src/store_forward/service.rs @@ -243,7 +243,7 @@ impl StoreAndForwardService { }, _ = cleanup_ticker.tick() => { - if let Err(err) = self.cleanup().await { + if let Err(err) = self.cleanup() { error!(target: LOG_TARGET, "Error when performing store and forward cleanup: {:?}", err); } }, @@ -267,7 +267,7 @@ impl StoreAndForwardService { use StoreAndForwardRequest::*; trace!(target: LOG_TARGET, "Request: {:?}", request); match request { - FetchMessages(query, reply_tx) => match self.handle_fetch_message_query(query).await { + FetchMessages(query, reply_tx) => match self.handle_fetch_message_query(query) { Ok(messages) => { let _ = reply_tx.send(Ok(messages)); }, @@ -282,7 +282,7 @@ impl StoreAndForwardService { InsertMessage(msg, reply_tx) => { let public_key = msg.destination_pubkey.clone(); let node_id = msg.destination_node_id.clone(); - match self.database.insert_message_if_unique(msg).await { + match self.database.insert_message_if_unique(msg) { Ok(existed) => { let pub_key = public_key .map(|p| format!("public key '{}'", p)) @@ -301,7 +301,7 @@ impl StoreAndForwardService { }, } }, - RemoveMessages(message_ids) => match self.database.remove_message(message_ids.clone()).await { + RemoveMessages(message_ids) => match self.database.remove_message(message_ids.clone()) { Ok(_) => trace!(target: LOG_TARGET, "Removed messages: {:?}", message_ids), Err(err) => error!(target: LOG_TARGET, "RemoveMessage failed because '{:?}'", err), }, @@ -319,7 +319,7 @@ impl StoreAndForwardService { } }, RemoveMessagesOlderThan(threshold) => { - match self.database.delete_messages_older_than(threshold.naive_utc()).await { + match self.database.delete_messages_older_than(threshold.naive_utc()) { Ok(_) => trace!(target: LOG_TARGET, "Removed messages older than {}", threshold), Err(err) => error!(target: LOG_TARGET, "RemoveMessage failed because '{:?}'", err), } @@ -453,7 +453,7 @@ impl StoreAndForwardService { } } - async fn handle_fetch_message_query(&self, query: FetchStoredMessageQuery) -> SafResult> { + fn handle_fetch_message_query(&self, query: FetchStoredMessageQuery) -> SafResult> { use SafResponseType::*; let limit = i64::try_from(self.config.max_returned_messages) .ok() @@ -461,47 +461,34 @@ impl StoreAndForwardService { .unwrap(); let db = &self.database; let messages = match query.response_type { - ForMe => { - db.find_messages_for_peer(&query.public_key, &query.node_id, query.since, limit) - .await? - }, - Join => db.find_join_messages(query.since, limit).await?, + ForMe => db.find_messages_for_peer(&query.public_key, &query.node_id, query.since, limit)?, + Join => db.find_join_messages(query.since, limit)?, Discovery => { - db.find_messages_of_type_for_pubkey(&query.public_key, DhtMessageType::Discovery, query.since, limit) - .await? + db.find_messages_of_type_for_pubkey(&query.public_key, DhtMessageType::Discovery, query.since, limit)? }, - Anonymous => db.find_anonymous_messages(query.since, limit).await?, + Anonymous => db.find_anonymous_messages(query.since, limit)?, }; Ok(messages) } - async fn cleanup(&mut self) -> SafResult<()> { + fn cleanup(&mut self) -> SafResult<()> { self.local_state .garbage_collect(self.config.max_inflight_request_age * 2); - let num_removed = self - .database - .delete_messages_with_priority_older_than( - StoredMessagePriority::Low, - since(self.config.low_priority_msg_storage_ttl), - ) - .await?; + let num_removed = self.database.delete_messages_with_priority_older_than( + StoredMessagePriority::Low, + since(self.config.low_priority_msg_storage_ttl), + )?; debug!(target: LOG_TARGET, "Cleaned {} old low priority messages", num_removed); - let num_removed = self - .database - .delete_messages_with_priority_older_than( - StoredMessagePriority::High, - since(self.config.high_priority_msg_storage_ttl), - ) - .await?; + let num_removed = self.database.delete_messages_with_priority_older_than( + StoredMessagePriority::High, + since(self.config.high_priority_msg_storage_ttl), + )?; debug!(target: LOG_TARGET, "Cleaned {} old high priority messages", num_removed); - let num_removed = self - .database - .truncate_messages(self.config.msg_storage_capacity) - .await?; + let num_removed = self.database.truncate_messages(self.config.msg_storage_capacity)?; if num_removed > 0 { debug!( target: LOG_TARGET,