diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs index 87036b199c..e967a51044 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs @@ -688,7 +688,7 @@ where .outbound_message_service .closest_broadcast( NodeId::from_public_key(&self.dest_pubkey), - OutboundEncryption::EncryptFor(Box::new(self.dest_pubkey.clone())), + OutboundEncryption::encrypt_for(self.dest_pubkey.clone()), vec![], OutboundDomainMessage::new(TariMessageType::SenderPartialTransaction, proto_message), ) diff --git a/base_layer/wallet/src/transaction_service/tasks/send_finalized_transaction.rs b/base_layer/wallet/src/transaction_service/tasks/send_finalized_transaction.rs index cf723bbfad..6603457751 100644 --- a/base_layer/wallet/src/transaction_service/tasks/send_finalized_transaction.rs +++ b/base_layer/wallet/src/transaction_service/tasks/send_finalized_transaction.rs @@ -215,7 +215,7 @@ async fn send_transaction_finalized_message_store_and_forward( match outbound_message_service .closest_broadcast( NodeId::from_public_key(&destination_pubkey), - OutboundEncryption::EncryptFor(Box::new(destination_pubkey.clone())), + OutboundEncryption::encrypt_for(destination_pubkey.clone()), vec![], OutboundDomainMessage::new(TariMessageType::TransactionFinalized, msg.clone()), ) diff --git a/base_layer/wallet/src/transaction_service/tasks/send_transaction_cancelled.rs b/base_layer/wallet/src/transaction_service/tasks/send_transaction_cancelled.rs index f7049db69a..8f9138a5b0 100644 --- a/base_layer/wallet/src/transaction_service/tasks/send_transaction_cancelled.rs +++ b/base_layer/wallet/src/transaction_service/tasks/send_transaction_cancelled.rs @@ -48,7 +48,7 @@ pub async fn send_transaction_cancelled_message( let _ = outbound_message_service .closest_broadcast( NodeId::from_public_key(&destination_public_key), - OutboundEncryption::EncryptFor(Box::new(destination_public_key)), + OutboundEncryption::encrypt_for(destination_public_key), vec![], OutboundDomainMessage::new(TariMessageType::SenderPartialTransaction, proto_message), ) diff --git a/base_layer/wallet/src/transaction_service/tasks/send_transaction_reply.rs b/base_layer/wallet/src/transaction_service/tasks/send_transaction_reply.rs index adc26c4ab2..e8e09ad8c4 100644 --- a/base_layer/wallet/src/transaction_service/tasks/send_transaction_reply.rs +++ b/base_layer/wallet/src/transaction_service/tasks/send_transaction_reply.rs @@ -196,7 +196,7 @@ async fn send_transaction_reply_store_and_forward( match outbound_message_service .closest_broadcast( NodeId::from_public_key(&destination_pubkey), - OutboundEncryption::EncryptFor(Box::new(destination_pubkey.clone())), + OutboundEncryption::encrypt_for(destination_pubkey.clone()), vec![], OutboundDomainMessage::new(TariMessageType::ReceiverPartialTransactionReply, msg), ) diff --git a/comms/dht/examples/memory_net/utilities.rs b/comms/dht/examples/memory_net/utilities.rs index bb6cf8f55a..4b875675b8 100644 --- a/comms/dht/examples/memory_net/utilities.rs +++ b/comms/dht/examples/memory_net/utilities.rs @@ -436,7 +436,7 @@ pub async fn do_store_and_forward_message_propagation( .outbound_requester() .closest_broadcast( node_identity.node_id().clone(), - OutboundEncryption::EncryptFor(Box::new(node_identity.public_key().clone())), + OutboundEncryption::encrypt_for(node_identity.public_key().clone()), vec![], OutboundDomainMessage::new(123i32, secret_message.clone()), ) @@ -716,7 +716,7 @@ impl TestNode { loop { match conn_man_event_sub.recv().await { Ok(event) => { - events_tx.send(logger(event)).await.unwrap(); + let _ = events_tx.send(logger(event)).await; }, Err(broadcast::error::RecvError::Closed) => break, Err(err) => log::error!("{}", err), diff --git a/comms/dht/src/dedup/mod.rs b/comms/dht/src/dedup/mod.rs index 9c7ec3684d..f178b804dc 100644 --- a/comms/dht/src/dedup/mod.rs +++ b/comms/dht/src/dedup/mod.rs @@ -24,24 +24,19 @@ mod dedup_cache; pub use dedup_cache::DedupCacheDatabase; -use crate::{actor::DhtRequester, inbound::DhtInboundMessage}; -use digest::Digest; +use crate::{actor::DhtRequester, inbound::DecryptedDhtMessage}; use futures::{future::BoxFuture, task::Context}; use log::*; use std::task::Poll; -use tari_comms::{pipeline::PipelineError, types::Challenge}; +use tari_comms::pipeline::PipelineError; use tari_utilities::hex::Hex; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::dedup"; -fn hash_inbound_message(message: &DhtInboundMessage) -> Vec { - Challenge::new().chain(&message.body).finalize().to_vec() -} - /// # DHT Deduplication middleware /// -/// Takes in a `DhtInboundMessage` and checks the message signature cache for duplicates. +/// Takes in a `DecryptedDhtMessage` and checks the message signature cache for duplicates. /// If a duplicate message is detected, it is discarded. #[derive(Clone)] pub struct DedupMiddleware { @@ -60,9 +55,9 @@ impl DedupMiddleware { } } -impl Service for DedupMiddleware +impl Service for DedupMiddleware where - S: Service + Clone + Send + 'static, + S: Service + Clone + Send + 'static, S::Future: Send, { type Error = PipelineError; @@ -73,22 +68,21 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, mut message: DhtInboundMessage) -> Self::Future { + fn call(&mut self, mut message: DecryptedDhtMessage) -> Self::Future { let next_service = self.next_service.clone(); let mut dht_requester = self.dht_requester.clone(); let allowed_message_occurrences = self.allowed_message_occurrences; Box::pin(async move { - let hash = hash_inbound_message(&message); trace!( target: LOG_TARGET, "Inserting message hash {} for message {} (Trace: {})", - hash.to_hex(), + message.hash.to_hex(), message.tag, message.dht_header.message_tag ); message.dedup_hit_count = dht_requester - .add_message_to_dedup_cache(hash, message.source_peer.public_key.clone()) + .add_message_to_dedup_cache(message.hash.clone(), message.source_peer.public_key.clone()) .await?; if message.dedup_hit_count as usize > allowed_message_occurrences { @@ -144,6 +138,7 @@ mod test { envelope::DhtMessageFlags, test_utils::{create_dht_actor_mock, make_dht_inbound_message, make_node_identity, service_spy}, }; + use tari_comms::wrap_in_envelope_body; use tari_test_utils::panic_context; use tokio::runtime::Runtime; @@ -163,13 +158,14 @@ mod test { assert!(dedup.poll_ready(&mut cx).is_ready()); let node_identity = make_node_identity(); - let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false, false); + let inbound_message = make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::empty(), false, false); + let decrypted_msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, inbound_message); - rt.block_on(dedup.call(msg.clone())).unwrap(); + rt.block_on(dedup.call(decrypted_msg.clone())).unwrap(); assert_eq!(spy.call_count(), 1); mock_state.set_number_of_message_hits(4); - rt.block_on(dedup.call(msg)).unwrap(); + rt.block_on(dedup.call(decrypted_msg)).unwrap(); assert_eq!(spy.call_count(), 1); // Drop dedup so that the DhtMock will stop running drop(dedup); @@ -179,28 +175,29 @@ mod test { fn deterministic_hash() { const TEST_MSG: &[u8] = b"test123"; const EXPECTED_HASH: &str = "90cccd774db0ac8c6ea2deff0e26fc52768a827c91c737a2e050668d8c39c224"; + let node_identity = make_node_identity(); - let msg = make_dht_inbound_message( + let dht_message = make_dht_inbound_message( &node_identity, TEST_MSG.to_vec(), DhtMessageFlags::empty(), false, false, ); - let hash1 = hash_inbound_message(&msg); + let decrypted1 = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, dht_message); let node_identity = make_node_identity(); - let msg = make_dht_inbound_message( + let dht_message = make_dht_inbound_message( &node_identity, TEST_MSG.to_vec(), DhtMessageFlags::empty(), false, false, ); - let hash2 = hash_inbound_message(&msg); + let decrypted2 = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, dht_message); - assert_eq!(hash1, hash2); - let subjects = &[hash1, hash2]; + assert_eq!(decrypted1.hash, decrypted2.hash); + let subjects = &[decrypted1.hash, decrypted2.hash]; assert!(subjects.iter().all(|h| h.to_hex() == EXPECTED_HASH)); } } diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index 1b762e02bc..cae6bbbfb3 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -295,21 +295,21 @@ impl Dht { ServiceBuilder::new() .layer(MetricsLayer::new(self.metrics_collector.clone())) .layer(inbound::DeserializeLayer::new(self.peer_manager.clone())) + .layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter())) + .layer(inbound::DecryptionLayer::new( + self.config.clone(), + self.node_identity.clone(), + self.connectivity.clone(), + )) .layer(DedupLayer::new( self.dht_requester(), self.config.dedup_allowed_message_occurrences, )) - .layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter())) + .layer(filter::FilterLayer::new(filter_messages_to_rebroadcast)) .layer(MessageLoggingLayer::new(format!( "Inbound [{}]", self.node_identity.node_id().short_str() ))) - .layer(inbound::DecryptionLayer::new( - self.config.clone(), - self.node_identity.clone(), - self.connectivity.clone(), - )) - .layer(filter::FilterLayer::new(filter_messages_to_rebroadcast)) .layer(store_forward::StoreLayer::new( self.config.clone(), Arc::clone(&self.peer_manager), diff --git a/comms/dht/src/inbound/message.rs b/comms/dht/src/inbound/message.rs index a2f75b755f..048b9c65e7 100644 --- a/comms/dht/src/inbound/message.rs +++ b/comms/dht/src/inbound/message.rs @@ -21,6 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::envelope::{DhtMessageFlags, DhtMessageHeader}; +use digest::Digest; use std::{ fmt, fmt::{Display, Formatter}, @@ -29,9 +30,17 @@ use std::{ use tari_comms::{ message::{EnvelopeBody, MessageTag}, peer_manager::Peer, - types::CommsPublicKey, + types::{Challenge, CommsPublicKey}, }; +fn hash_inbound_message(message: &DhtInboundMessage) -> Vec { + Challenge::new() + .chain(&message.dht_header.origin_mac) + .chain(&message.body) + .finalize() + .to_vec() +} + #[derive(Debug, Clone)] pub struct DhtInboundMessage { pub tag: MessageTag, @@ -84,6 +93,7 @@ pub struct DecryptedDhtMessage { pub is_already_forwarded: bool, pub decryption_result: Result>, pub dedup_hit_count: u32, + pub hash: Vec, } impl DecryptedDhtMessage { @@ -104,6 +114,7 @@ impl DecryptedDhtMessage { message: DhtInboundMessage, ) -> Self { Self { + hash: hash_inbound_message(&message), tag: message.tag, source_peer: message.source_peer, authenticated_origin, @@ -118,6 +129,7 @@ impl DecryptedDhtMessage { pub fn failed(message: DhtInboundMessage) -> Self { Self { + hash: hash_inbound_message(&message), tag: message.tag, source_peer: message.source_peer, authenticated_origin: None, diff --git a/comms/dht/src/outbound/message.rs b/comms/dht/src/outbound/message.rs index 74e8b8c720..dc8bd21e85 100644 --- a/comms/dht/src/outbound/message.rs +++ b/comms/dht/src/outbound/message.rs @@ -46,6 +46,10 @@ pub enum OutboundEncryption { } impl OutboundEncryption { + pub fn encrypt_for(public_key: CommsPublicKey) -> Self { + OutboundEncryption::EncryptFor(Box::new(public_key)) + } + /// Return the correct DHT flags for the encryption setting pub fn flags(&self) -> DhtMessageFlags { match self { diff --git a/comms/dht/src/outbound/message_params.rs b/comms/dht/src/outbound/message_params.rs index 3b38272c38..81d92ab19e 100644 --- a/comms/dht/src/outbound/message_params.rs +++ b/comms/dht/src/outbound/message_params.rs @@ -40,7 +40,7 @@ use tari_comms::{message::MessageTag, peer_manager::NodeId, types::CommsPublicKe /// let dest_public_key = CommsPublicKey::default(); /// let params = SendMessageParams::new() /// .random(5) -/// .with_encryption(OutboundEncryption::EncryptFor(Box::new(dest_public_key))) +/// .with_encryption(OutboundEncryption::encrypt_for(dest_public_key)) /// .finish(); /// ``` #[derive(Debug, Clone)] diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index cebedac101..9ae42cb296 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -403,13 +403,12 @@ async fn dht_store_forward() { .await .unwrap(); - let dest_public_key = Box::new(node_C_node_identity.public_key().clone()); let params = SendMessageParams::new() .broadcast(vec![]) - .with_encryption(OutboundEncryption::EncryptFor(dest_public_key)) - .with_destination(NodeDestination::NodeId(Box::new( - node_C_node_identity.node_id().clone(), - ))) + .with_encryption(OutboundEncryption::encrypt_for( + node_C_node_identity.public_key().clone(), + )) + .with_destination(node_C_node_identity.node_id().clone().into()) .finish(); let secret_msg1 = b"NCZW VUSX PNYM INHZ XMQX SFWX WLKJ AHSH"; @@ -570,7 +569,7 @@ async fn dht_propagate_dedup() { .outbound_requester() .propagate( NodeDestination::Unknown, - OutboundEncryption::EncryptFor(Box::new(node_D.node_identity().public_key().clone())), + OutboundEncryption::encrypt_for(node_D.node_identity().public_key().clone()), vec![], out_msg, ) @@ -623,6 +622,174 @@ async fn dht_propagate_dedup() { assert_eq!(count_messages_received(&received, &[&node_C_id]), 1); } +#[tokio::test] +#[allow(non_snake_case)] +async fn dht_do_not_store_invalid_message_in_dedup() { + let mut config = dht_config(); + config.dedup_allowed_message_occurrences = 1; + + // Node C receives messages from A and B + let mut node_C = make_node("node_B", PeerFeatures::COMMUNICATION_NODE, config.clone(), None).await; + + // Node B forwards a message from A but modifies it + let mut node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + config.clone(), + Some(node_C.to_peer()), + ) + .await; + + // Node A creates a message sends it to B, B modifies it, sends it to C; Node A sends message to C + let node_A = make_node("node_A", PeerFeatures::COMMUNICATION_NODE, config.clone(), [ + node_B.to_peer(), + node_C.to_peer(), + ]) + .await; + + log::info!( + "NodeA = {}, NodeB = {}, NodeC = {}", + node_A.node_identity().node_id().short_str(), + node_B.node_identity().node_id().short_str(), + node_C.node_identity().node_id().short_str(), + ); + + // Connect the peers that should be connected + node_A + .comms + .connectivity() + .dial_peer(node_B.node_identity().node_id().clone()) + .await + .unwrap(); + + node_A + .comms + .connectivity() + .dial_peer(node_C.node_identity().node_id().clone()) + .await + .unwrap(); + + node_B + .comms + .connectivity() + .dial_peer(node_C.node_identity().node_id().clone()) + .await + .unwrap(); + + let mut node_C_messaging = node_C.messaging_events.subscribe(); + + #[derive(Clone, PartialEq, ::prost::Message)] + struct Person { + #[prost(string, tag = "1")] + name: String, + #[prost(uint32, tag = "2")] + age: u32, + } + + // Just a message to test connectivity between Node A -> Node C, and to get the header from + let out_msg = OutboundDomainMessage::new(123, Person { + name: "John Conway".into(), + age: 82, + }); + + node_A + .dht + .outbound_requester() + .send_message( + SendMessageParams::new() + .direct_node_id(node_B.node_identity().node_id().clone()) + .with_destination(node_C.node_identity().public_key().clone().into()) + .force_origin() + .finish(), + out_msg, + ) + .await + .unwrap(); + + // Get the message that was received by Node B + let mut msg = node_B.next_inbound_message(Duration::from_secs(10)).await.unwrap(); + let bytes = msg.decryption_result.unwrap().to_encoded_bytes(); + + // Clone header without modification + let header_unmodified = msg.dht_header.clone(); + + // Modify the header + msg.dht_header.message_type = DhtMessageType::from_i32(3i32).unwrap(); + + // Forward modified message to Node C - Should get us banned + node_B + .dht + .outbound_requester() + .send_raw( + SendMessageParams::new() + .direct_node_id(node_C.node_identity().node_id().clone()) + .with_dht_header(msg.dht_header) + .finish(), + bytes.clone(), + ) + .await + .unwrap(); + + async_assert_eventually!( + { + let n = node_C + .comms + .peer_manager() + .find_by_node_id(node_B.node_identity().node_id()) + .await + .unwrap(); + n.is_banned() + }, + expect = true, + max_attempts = 10, + interval = Duration::from_secs(3) + ); + + node_A + .dht + .outbound_requester() + .send_raw( + SendMessageParams::new() + .direct_node_id(node_C.node_identity().node_id().clone()) + .with_dht_header(header_unmodified) + .finish(), + bytes, + ) + .await + .unwrap(); + + // Node C receives the correct message from Node A + let msg = node_C + .next_inbound_message(Duration::from_secs(10)) + .await + .expect("Node C expected an inbound message but it never arrived"); + assert!(msg.decryption_succeeded()); + log::info!("Received message {}", msg.tag); + let person = msg + .decryption_result + .unwrap() + .decode_part::(1) + .unwrap() + .unwrap(); + assert_eq!(person.name, "John Conway"); + + let node_A_id = node_A.node_identity().node_id().clone(); + let node_B_id = node_B.node_identity().node_id().clone(); + + node_A.shutdown().await; + node_B.shutdown().await; + node_C.shutdown().await; + + // Check the message flow BEFORE deduping + let received = filter_received(collect_try_recv!(node_C_messaging, timeout = Duration::from_secs(20))); + + let received_from_a = count_messages_received(&received, &[&node_A_id]); + let received_from_b = count_messages_received(&received, &[&node_B_id]); + + assert_eq!(received_from_a, 1); + assert_eq!(received_from_b, 1); +} + #[tokio::test] #[allow(non_snake_case)] async fn dht_repropagate() {