From 81f01d228cc425fb859fd07e08dfb34f03e1bd22 Mon Sep 17 00:00:00 2001 From: Stan Bondi Date: Tue, 19 Oct 2021 00:50:21 +0400 Subject: [PATCH] fix: validate dht header before dedup cache (#3468) Description --- - reorders the DHT messaging layers to validate the message before entering the dedup store. - adds the origin_mac to the dedup hash This PR was written by @Impala123, I finished off a rust integration test Motivation and Context --- From original PR: #3450 > With the current order of layers, a malicious node could tamper with a message which would then be discarded by the > validation layer. However the dedup cache currently stores this before it is discarded by validate. Thus any un-tampered > version of the same message would no longer be processed. A valid origin mac means the message comes from the possessor of the private key and has not been altered. The valid origin mac bytes are included in the dedup hash preimage so that the origin of the message (if any) is tied to the dedup entry. Previously, an attacker could craft a message `A'` that had no/different valid origin MAC but the same body and cause a subsequent message `A` to not to be discarded as a duplicate. How Has This Been Tested? --- Rust integration test memorynet --- .../protocols/transaction_send_protocol.rs | 2 +- .../tasks/send_finalized_transaction.rs | 2 +- .../tasks/send_transaction_cancelled.rs | 2 +- .../tasks/send_transaction_reply.rs | 2 +- comms/dht/examples/memory_net/utilities.rs | 4 +- comms/dht/src/dedup/mod.rs | 43 ++--- comms/dht/src/dht.rs | 14 +- comms/dht/src/inbound/message.rs | 14 +- comms/dht/src/outbound/message.rs | 4 + comms/dht/src/outbound/message_params.rs | 2 +- comms/dht/tests/dht.rs | 179 +++++++++++++++++- 11 files changed, 224 insertions(+), 44 deletions(-) diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs index 87036b199c..e967a51044 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs @@ -688,7 +688,7 @@ where .outbound_message_service .closest_broadcast( NodeId::from_public_key(&self.dest_pubkey), - OutboundEncryption::EncryptFor(Box::new(self.dest_pubkey.clone())), + OutboundEncryption::encrypt_for(self.dest_pubkey.clone()), vec![], OutboundDomainMessage::new(TariMessageType::SenderPartialTransaction, proto_message), ) diff --git a/base_layer/wallet/src/transaction_service/tasks/send_finalized_transaction.rs b/base_layer/wallet/src/transaction_service/tasks/send_finalized_transaction.rs index cf723bbfad..6603457751 100644 --- a/base_layer/wallet/src/transaction_service/tasks/send_finalized_transaction.rs +++ b/base_layer/wallet/src/transaction_service/tasks/send_finalized_transaction.rs @@ -215,7 +215,7 @@ async fn send_transaction_finalized_message_store_and_forward( match outbound_message_service .closest_broadcast( NodeId::from_public_key(&destination_pubkey), - OutboundEncryption::EncryptFor(Box::new(destination_pubkey.clone())), + OutboundEncryption::encrypt_for(destination_pubkey.clone()), vec![], OutboundDomainMessage::new(TariMessageType::TransactionFinalized, msg.clone()), ) diff --git a/base_layer/wallet/src/transaction_service/tasks/send_transaction_cancelled.rs b/base_layer/wallet/src/transaction_service/tasks/send_transaction_cancelled.rs index f7049db69a..8f9138a5b0 100644 --- a/base_layer/wallet/src/transaction_service/tasks/send_transaction_cancelled.rs +++ b/base_layer/wallet/src/transaction_service/tasks/send_transaction_cancelled.rs @@ -48,7 +48,7 @@ pub async fn send_transaction_cancelled_message( let _ = outbound_message_service .closest_broadcast( NodeId::from_public_key(&destination_public_key), - OutboundEncryption::EncryptFor(Box::new(destination_public_key)), + OutboundEncryption::encrypt_for(destination_public_key), vec![], OutboundDomainMessage::new(TariMessageType::SenderPartialTransaction, proto_message), ) diff --git a/base_layer/wallet/src/transaction_service/tasks/send_transaction_reply.rs b/base_layer/wallet/src/transaction_service/tasks/send_transaction_reply.rs index adc26c4ab2..e8e09ad8c4 100644 --- a/base_layer/wallet/src/transaction_service/tasks/send_transaction_reply.rs +++ b/base_layer/wallet/src/transaction_service/tasks/send_transaction_reply.rs @@ -196,7 +196,7 @@ async fn send_transaction_reply_store_and_forward( match outbound_message_service .closest_broadcast( NodeId::from_public_key(&destination_pubkey), - OutboundEncryption::EncryptFor(Box::new(destination_pubkey.clone())), + OutboundEncryption::encrypt_for(destination_pubkey.clone()), vec![], OutboundDomainMessage::new(TariMessageType::ReceiverPartialTransactionReply, msg), ) diff --git a/comms/dht/examples/memory_net/utilities.rs b/comms/dht/examples/memory_net/utilities.rs index bb6cf8f55a..4b875675b8 100644 --- a/comms/dht/examples/memory_net/utilities.rs +++ b/comms/dht/examples/memory_net/utilities.rs @@ -436,7 +436,7 @@ pub async fn do_store_and_forward_message_propagation( .outbound_requester() .closest_broadcast( node_identity.node_id().clone(), - OutboundEncryption::EncryptFor(Box::new(node_identity.public_key().clone())), + OutboundEncryption::encrypt_for(node_identity.public_key().clone()), vec![], OutboundDomainMessage::new(123i32, secret_message.clone()), ) @@ -716,7 +716,7 @@ impl TestNode { loop { match conn_man_event_sub.recv().await { Ok(event) => { - events_tx.send(logger(event)).await.unwrap(); + let _ = events_tx.send(logger(event)).await; }, Err(broadcast::error::RecvError::Closed) => break, Err(err) => log::error!("{}", err), diff --git a/comms/dht/src/dedup/mod.rs b/comms/dht/src/dedup/mod.rs index 9c7ec3684d..f178b804dc 100644 --- a/comms/dht/src/dedup/mod.rs +++ b/comms/dht/src/dedup/mod.rs @@ -24,24 +24,19 @@ mod dedup_cache; pub use dedup_cache::DedupCacheDatabase; -use crate::{actor::DhtRequester, inbound::DhtInboundMessage}; -use digest::Digest; +use crate::{actor::DhtRequester, inbound::DecryptedDhtMessage}; use futures::{future::BoxFuture, task::Context}; use log::*; use std::task::Poll; -use tari_comms::{pipeline::PipelineError, types::Challenge}; +use tari_comms::pipeline::PipelineError; use tari_utilities::hex::Hex; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::dedup"; -fn hash_inbound_message(message: &DhtInboundMessage) -> Vec { - Challenge::new().chain(&message.body).finalize().to_vec() -} - /// # DHT Deduplication middleware /// -/// Takes in a `DhtInboundMessage` and checks the message signature cache for duplicates. +/// Takes in a `DecryptedDhtMessage` and checks the message signature cache for duplicates. /// If a duplicate message is detected, it is discarded. #[derive(Clone)] pub struct DedupMiddleware { @@ -60,9 +55,9 @@ impl DedupMiddleware { } } -impl Service for DedupMiddleware +impl Service for DedupMiddleware where - S: Service + Clone + Send + 'static, + S: Service + Clone + Send + 'static, S::Future: Send, { type Error = PipelineError; @@ -73,22 +68,21 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, mut message: DhtInboundMessage) -> Self::Future { + fn call(&mut self, mut message: DecryptedDhtMessage) -> Self::Future { let next_service = self.next_service.clone(); let mut dht_requester = self.dht_requester.clone(); let allowed_message_occurrences = self.allowed_message_occurrences; Box::pin(async move { - let hash = hash_inbound_message(&message); trace!( target: LOG_TARGET, "Inserting message hash {} for message {} (Trace: {})", - hash.to_hex(), + message.hash.to_hex(), message.tag, message.dht_header.message_tag ); message.dedup_hit_count = dht_requester - .add_message_to_dedup_cache(hash, message.source_peer.public_key.clone()) + .add_message_to_dedup_cache(message.hash.clone(), message.source_peer.public_key.clone()) .await?; if message.dedup_hit_count as usize > allowed_message_occurrences { @@ -144,6 +138,7 @@ mod test { envelope::DhtMessageFlags, test_utils::{create_dht_actor_mock, make_dht_inbound_message, make_node_identity, service_spy}, }; + use tari_comms::wrap_in_envelope_body; use tari_test_utils::panic_context; use tokio::runtime::Runtime; @@ -163,13 +158,14 @@ mod test { assert!(dedup.poll_ready(&mut cx).is_ready()); let node_identity = make_node_identity(); - let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false, false); + let inbound_message = make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::empty(), false, false); + let decrypted_msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, inbound_message); - rt.block_on(dedup.call(msg.clone())).unwrap(); + rt.block_on(dedup.call(decrypted_msg.clone())).unwrap(); assert_eq!(spy.call_count(), 1); mock_state.set_number_of_message_hits(4); - rt.block_on(dedup.call(msg)).unwrap(); + rt.block_on(dedup.call(decrypted_msg)).unwrap(); assert_eq!(spy.call_count(), 1); // Drop dedup so that the DhtMock will stop running drop(dedup); @@ -179,28 +175,29 @@ mod test { fn deterministic_hash() { const TEST_MSG: &[u8] = b"test123"; const EXPECTED_HASH: &str = "90cccd774db0ac8c6ea2deff0e26fc52768a827c91c737a2e050668d8c39c224"; + let node_identity = make_node_identity(); - let msg = make_dht_inbound_message( + let dht_message = make_dht_inbound_message( &node_identity, TEST_MSG.to_vec(), DhtMessageFlags::empty(), false, false, ); - let hash1 = hash_inbound_message(&msg); + let decrypted1 = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, dht_message); let node_identity = make_node_identity(); - let msg = make_dht_inbound_message( + let dht_message = make_dht_inbound_message( &node_identity, TEST_MSG.to_vec(), DhtMessageFlags::empty(), false, false, ); - let hash2 = hash_inbound_message(&msg); + let decrypted2 = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, dht_message); - assert_eq!(hash1, hash2); - let subjects = &[hash1, hash2]; + assert_eq!(decrypted1.hash, decrypted2.hash); + let subjects = &[decrypted1.hash, decrypted2.hash]; assert!(subjects.iter().all(|h| h.to_hex() == EXPECTED_HASH)); } } diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index 1b762e02bc..cae6bbbfb3 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -295,21 +295,21 @@ impl Dht { ServiceBuilder::new() .layer(MetricsLayer::new(self.metrics_collector.clone())) .layer(inbound::DeserializeLayer::new(self.peer_manager.clone())) + .layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter())) + .layer(inbound::DecryptionLayer::new( + self.config.clone(), + self.node_identity.clone(), + self.connectivity.clone(), + )) .layer(DedupLayer::new( self.dht_requester(), self.config.dedup_allowed_message_occurrences, )) - .layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter())) + .layer(filter::FilterLayer::new(filter_messages_to_rebroadcast)) .layer(MessageLoggingLayer::new(format!( "Inbound [{}]", self.node_identity.node_id().short_str() ))) - .layer(inbound::DecryptionLayer::new( - self.config.clone(), - self.node_identity.clone(), - self.connectivity.clone(), - )) - .layer(filter::FilterLayer::new(filter_messages_to_rebroadcast)) .layer(store_forward::StoreLayer::new( self.config.clone(), Arc::clone(&self.peer_manager), diff --git a/comms/dht/src/inbound/message.rs b/comms/dht/src/inbound/message.rs index a2f75b755f..048b9c65e7 100644 --- a/comms/dht/src/inbound/message.rs +++ b/comms/dht/src/inbound/message.rs @@ -21,6 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::envelope::{DhtMessageFlags, DhtMessageHeader}; +use digest::Digest; use std::{ fmt, fmt::{Display, Formatter}, @@ -29,9 +30,17 @@ use std::{ use tari_comms::{ message::{EnvelopeBody, MessageTag}, peer_manager::Peer, - types::CommsPublicKey, + types::{Challenge, CommsPublicKey}, }; +fn hash_inbound_message(message: &DhtInboundMessage) -> Vec { + Challenge::new() + .chain(&message.dht_header.origin_mac) + .chain(&message.body) + .finalize() + .to_vec() +} + #[derive(Debug, Clone)] pub struct DhtInboundMessage { pub tag: MessageTag, @@ -84,6 +93,7 @@ pub struct DecryptedDhtMessage { pub is_already_forwarded: bool, pub decryption_result: Result>, pub dedup_hit_count: u32, + pub hash: Vec, } impl DecryptedDhtMessage { @@ -104,6 +114,7 @@ impl DecryptedDhtMessage { message: DhtInboundMessage, ) -> Self { Self { + hash: hash_inbound_message(&message), tag: message.tag, source_peer: message.source_peer, authenticated_origin, @@ -118,6 +129,7 @@ impl DecryptedDhtMessage { pub fn failed(message: DhtInboundMessage) -> Self { Self { + hash: hash_inbound_message(&message), tag: message.tag, source_peer: message.source_peer, authenticated_origin: None, diff --git a/comms/dht/src/outbound/message.rs b/comms/dht/src/outbound/message.rs index 74e8b8c720..dc8bd21e85 100644 --- a/comms/dht/src/outbound/message.rs +++ b/comms/dht/src/outbound/message.rs @@ -46,6 +46,10 @@ pub enum OutboundEncryption { } impl OutboundEncryption { + pub fn encrypt_for(public_key: CommsPublicKey) -> Self { + OutboundEncryption::EncryptFor(Box::new(public_key)) + } + /// Return the correct DHT flags for the encryption setting pub fn flags(&self) -> DhtMessageFlags { match self { diff --git a/comms/dht/src/outbound/message_params.rs b/comms/dht/src/outbound/message_params.rs index 3b38272c38..81d92ab19e 100644 --- a/comms/dht/src/outbound/message_params.rs +++ b/comms/dht/src/outbound/message_params.rs @@ -40,7 +40,7 @@ use tari_comms::{message::MessageTag, peer_manager::NodeId, types::CommsPublicKe /// let dest_public_key = CommsPublicKey::default(); /// let params = SendMessageParams::new() /// .random(5) -/// .with_encryption(OutboundEncryption::EncryptFor(Box::new(dest_public_key))) +/// .with_encryption(OutboundEncryption::encrypt_for(dest_public_key)) /// .finish(); /// ``` #[derive(Debug, Clone)] diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index cebedac101..9ae42cb296 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -403,13 +403,12 @@ async fn dht_store_forward() { .await .unwrap(); - let dest_public_key = Box::new(node_C_node_identity.public_key().clone()); let params = SendMessageParams::new() .broadcast(vec![]) - .with_encryption(OutboundEncryption::EncryptFor(dest_public_key)) - .with_destination(NodeDestination::NodeId(Box::new( - node_C_node_identity.node_id().clone(), - ))) + .with_encryption(OutboundEncryption::encrypt_for( + node_C_node_identity.public_key().clone(), + )) + .with_destination(node_C_node_identity.node_id().clone().into()) .finish(); let secret_msg1 = b"NCZW VUSX PNYM INHZ XMQX SFWX WLKJ AHSH"; @@ -570,7 +569,7 @@ async fn dht_propagate_dedup() { .outbound_requester() .propagate( NodeDestination::Unknown, - OutboundEncryption::EncryptFor(Box::new(node_D.node_identity().public_key().clone())), + OutboundEncryption::encrypt_for(node_D.node_identity().public_key().clone()), vec![], out_msg, ) @@ -623,6 +622,174 @@ async fn dht_propagate_dedup() { assert_eq!(count_messages_received(&received, &[&node_C_id]), 1); } +#[tokio::test] +#[allow(non_snake_case)] +async fn dht_do_not_store_invalid_message_in_dedup() { + let mut config = dht_config(); + config.dedup_allowed_message_occurrences = 1; + + // Node C receives messages from A and B + let mut node_C = make_node("node_B", PeerFeatures::COMMUNICATION_NODE, config.clone(), None).await; + + // Node B forwards a message from A but modifies it + let mut node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + config.clone(), + Some(node_C.to_peer()), + ) + .await; + + // Node A creates a message sends it to B, B modifies it, sends it to C; Node A sends message to C + let node_A = make_node("node_A", PeerFeatures::COMMUNICATION_NODE, config.clone(), [ + node_B.to_peer(), + node_C.to_peer(), + ]) + .await; + + log::info!( + "NodeA = {}, NodeB = {}, NodeC = {}", + node_A.node_identity().node_id().short_str(), + node_B.node_identity().node_id().short_str(), + node_C.node_identity().node_id().short_str(), + ); + + // Connect the peers that should be connected + node_A + .comms + .connectivity() + .dial_peer(node_B.node_identity().node_id().clone()) + .await + .unwrap(); + + node_A + .comms + .connectivity() + .dial_peer(node_C.node_identity().node_id().clone()) + .await + .unwrap(); + + node_B + .comms + .connectivity() + .dial_peer(node_C.node_identity().node_id().clone()) + .await + .unwrap(); + + let mut node_C_messaging = node_C.messaging_events.subscribe(); + + #[derive(Clone, PartialEq, ::prost::Message)] + struct Person { + #[prost(string, tag = "1")] + name: String, + #[prost(uint32, tag = "2")] + age: u32, + } + + // Just a message to test connectivity between Node A -> Node C, and to get the header from + let out_msg = OutboundDomainMessage::new(123, Person { + name: "John Conway".into(), + age: 82, + }); + + node_A + .dht + .outbound_requester() + .send_message( + SendMessageParams::new() + .direct_node_id(node_B.node_identity().node_id().clone()) + .with_destination(node_C.node_identity().public_key().clone().into()) + .force_origin() + .finish(), + out_msg, + ) + .await + .unwrap(); + + // Get the message that was received by Node B + let mut msg = node_B.next_inbound_message(Duration::from_secs(10)).await.unwrap(); + let bytes = msg.decryption_result.unwrap().to_encoded_bytes(); + + // Clone header without modification + let header_unmodified = msg.dht_header.clone(); + + // Modify the header + msg.dht_header.message_type = DhtMessageType::from_i32(3i32).unwrap(); + + // Forward modified message to Node C - Should get us banned + node_B + .dht + .outbound_requester() + .send_raw( + SendMessageParams::new() + .direct_node_id(node_C.node_identity().node_id().clone()) + .with_dht_header(msg.dht_header) + .finish(), + bytes.clone(), + ) + .await + .unwrap(); + + async_assert_eventually!( + { + let n = node_C + .comms + .peer_manager() + .find_by_node_id(node_B.node_identity().node_id()) + .await + .unwrap(); + n.is_banned() + }, + expect = true, + max_attempts = 10, + interval = Duration::from_secs(3) + ); + + node_A + .dht + .outbound_requester() + .send_raw( + SendMessageParams::new() + .direct_node_id(node_C.node_identity().node_id().clone()) + .with_dht_header(header_unmodified) + .finish(), + bytes, + ) + .await + .unwrap(); + + // Node C receives the correct message from Node A + let msg = node_C + .next_inbound_message(Duration::from_secs(10)) + .await + .expect("Node C expected an inbound message but it never arrived"); + assert!(msg.decryption_succeeded()); + log::info!("Received message {}", msg.tag); + let person = msg + .decryption_result + .unwrap() + .decode_part::(1) + .unwrap() + .unwrap(); + assert_eq!(person.name, "John Conway"); + + let node_A_id = node_A.node_identity().node_id().clone(); + let node_B_id = node_B.node_identity().node_id().clone(); + + node_A.shutdown().await; + node_B.shutdown().await; + node_C.shutdown().await; + + // Check the message flow BEFORE deduping + let received = filter_received(collect_try_recv!(node_C_messaging, timeout = Duration::from_secs(20))); + + let received_from_a = count_messages_received(&received, &[&node_A_id]); + let received_from_b = count_messages_received(&received, &[&node_B_id]); + + assert_eq!(received_from_a, 1); + assert_eq!(received_from_b, 1); +} + #[tokio::test] #[allow(non_snake_case)] async fn dht_repropagate() {