Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dht): saf storage uses constructs correct msg hash #4003

Merged
8 changes: 4 additions & 4 deletions comms/dht/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ impl DhtActor {
} => {
let msg_hash_cache = self.msg_hash_dedup_cache.clone();
Box::pin(async move {
match msg_hash_cache.add_body_hash(message_hash, &received_from) {
match msg_hash_cache.add_msg_hash(&message_hash, &received_from) {
Ok(hit_count) => {
let _ = reply_tx.send(hit_count);
},
Expand All @@ -366,7 +366,7 @@ impl DhtActor {
GetMsgHashHitCount(hash, reply_tx) => {
let msg_hash_cache = self.msg_hash_dedup_cache.clone();
Box::pin(async move {
let hit_count = msg_hash_cache.get_hit_count(hash)?;
let hit_count = msg_hash_cache.get_hit_count(&hash)?;
let _ = reply_tx.send(hit_count);
Ok(())
})
Expand Down Expand Up @@ -1043,15 +1043,15 @@ mod test {
for key in &signatures {
let num_hits = actor
.msg_hash_dedup_cache
.add_body_hash(key.clone(), &CommsPublicKey::default())
.add_msg_hash(key, &CommsPublicKey::default())
.unwrap();
assert_eq!(num_hits, 1);
}
// Try to re-insert all; all hashes should have incremented their hit count
for key in &signatures {
let num_hits = actor
.msg_hash_dedup_cache
.add_body_hash(key.clone(), &CommsPublicKey::default())
.add_msg_hash(key, &CommsPublicKey::default())
.unwrap();
assert_eq!(num_hits, 2);
}
Expand Down
13 changes: 6 additions & 7 deletions comms/dht/src/dedup/dedup_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ use chrono::{NaiveDateTime, Utc};
use diesel::{dsl, result::DatabaseErrorKind, sql_types, ExpressionMethods, OptionalExtension, QueryDsl, RunQueryDsl};
use log::*;
use tari_comms::types::CommsPublicKey;
use tari_crypto::tari_utilities::hex::Hex;
use tari_crypto::tari_utilities::hex::to_hex;
use tari_utilities::hex::Hex;

use crate::{
schema::dedup_cache,
Expand Down Expand Up @@ -59,9 +60,8 @@ impl DedupCacheDatabase {

/// Adds the body hash to the cache, returning the number of hits (inclusive) that have been recorded for this body
/// hash
#[allow(clippy::needless_pass_by_value)]
pub fn add_body_hash(&self, body_hash: Vec<u8>, public_key: &CommsPublicKey) -> Result<u32, StorageError> {
let hit_count = self.insert_body_hash_or_update_stats(&body_hash.to_hex(), &public_key.to_hex())?;
pub fn add_msg_hash(&self, msg_hash: &[u8], public_key: &CommsPublicKey) -> Result<u32, StorageError> {
let hit_count = self.insert_body_hash_or_update_stats(&to_hex(msg_hash), &public_key.to_hex())?;

if hit_count == 0 {
warn!(
Expand All @@ -72,12 +72,11 @@ impl DedupCacheDatabase {
Ok(hit_count)
}

#[allow(clippy::needless_pass_by_value)]
pub fn get_hit_count(&self, body_hash: Vec<u8>) -> Result<u32, StorageError> {
pub fn get_hit_count(&self, body_hash: &[u8]) -> Result<u32, StorageError> {
let conn = self.connection.get_pooled_connection()?;
let hit_count = dedup_cache::table
.select(dedup_cache::number_of_hits)
.filter(dedup_cache::body_hash.eq(&body_hash.to_hex()))
.filter(dedup_cache::body_hash.eq(&to_hex(body_hash)))
.get_result::<i32>(&conn)
.optional()?;

Expand Down
16 changes: 14 additions & 2 deletions comms/dht/src/dedup/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,28 @@ mod dedup_cache;
use std::task::Poll;

pub use dedup_cache::DedupCacheDatabase;
use digest::Digest;
use futures::{future::BoxFuture, task::Context};
use log::*;
use tari_comms::pipeline::PipelineError;
use tari_comms::{pipeline::PipelineError, types::Challenge};
use tari_utilities::hex::Hex;
use tower::{layer::Layer, Service, ServiceExt};

use crate::{actor::DhtRequester, inbound::DecryptedDhtMessage};
use crate::{
actor::DhtRequester,
inbound::{DecryptedDhtMessage, DhtInboundMessage},
};

const LOG_TARGET: &str = "comms::dht::dedup";

pub fn hash_inbound_message(msg: &DhtInboundMessage) -> [u8; 32] {
create_message_hash(&msg.dht_header.origin_mac, &msg.body)
}

pub fn create_message_hash(origin_mac: &[u8], body: &[u8]) -> [u8; 32] {
Challenge::new().chain(origin_mac).chain(&body).finalize().into()
}

/// # DHT Deduplication middleware
///
/// Takes in a `DecryptedDhtMessage` and checks the message signature cache for duplicates.
Expand Down
21 changes: 8 additions & 13 deletions comms/dht/src/inbound/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,17 @@ use std::{
sync::Arc,
};

use digest::Digest;
use tari_comms::{
message::{EnvelopeBody, MessageTag},
peer_manager::Peer,
types::{Challenge, CommsPublicKey},
types::CommsPublicKey,
};
use tari_utilities::ByteArray;

use crate::envelope::{DhtMessageFlags, DhtMessageHeader};

fn hash_inbound_message(message: &DhtInboundMessage) -> Vec<u8> {
Challenge::new()
.chain(&message.dht_header.origin_mac)
.chain(&message.body)
.finalize()
.to_vec()
}
use crate::{
dedup,
envelope::{DhtMessageFlags, DhtMessageHeader},
};

#[derive(Debug, Clone)]
pub struct DhtInboundMessage {
Expand Down Expand Up @@ -116,7 +111,7 @@ impl DecryptedDhtMessage {
message: DhtInboundMessage,
) -> Self {
Self {
dedup_hash: hash_inbound_message(&message),
dedup_hash: dedup::hash_inbound_message(&message).to_vec(),
tag: message.tag,
source_peer: message.source_peer,
authenticated_origin,
Expand All @@ -131,7 +126,7 @@ impl DecryptedDhtMessage {

pub fn failed(message: DhtInboundMessage) -> Self {
Self {
dedup_hash: hash_inbound_message(&message),
dedup_hash: dedup::hash_inbound_message(&message).to_vec(),
tag: message.tag,
source_peer: message.source_peer,
authenticated_origin: None,
Expand Down
13 changes: 6 additions & 7 deletions comms/dht/src/outbound/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use std::{sync::Arc, task::Poll};

use bytes::Bytes;
use chrono::{DateTime, Utc};
use digest::Digest;
use futures::{
future,
future::BoxFuture,
Expand Down Expand Up @@ -53,6 +52,7 @@ use crate::{
actor::DhtRequester,
broadcast_strategy::BroadcastStrategy,
crypt,
dedup,
discovery::DhtDiscoveryRequester,
envelope::{datetime_to_epochtime, datetime_to_timestamp, DhtMessageFlags, DhtMessageHeader, NodeDestination},
outbound::{
Expand Down Expand Up @@ -429,8 +429,8 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
)?;

if is_broadcast {
self.add_to_dedup_cache(&body, self.node_identity.public_key().clone())
.await?;
let hash = dedup::create_message_hash(origin_mac.as_deref().unwrap_or(&[]), &body);
self.add_to_dedup_cache(hash).await?;
}

// Construct a DhtOutboundMessage for each recipient
Expand Down Expand Up @@ -461,8 +461,7 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
Ok(messages.unzip())
}

async fn add_to_dedup_cache(&mut self, body: &[u8], public_key: CommsPublicKey) -> Result<(), DhtOutboundError> {
let hash = Challenge::new().chain(&body).finalize().to_vec();
async fn add_to_dedup_cache(&mut self, hash: [u8; 32]) -> Result<(), DhtOutboundError> {
trace!(
target: LOG_TARGET,
"Dedup added message hash {} to cache for message",
Expand All @@ -472,12 +471,12 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
// Do not count messages we've broadcast towards the total hit count
let hit_count = self
.dht_requester
.get_message_cache_hit_count(hash.clone())
.get_message_cache_hit_count(hash.to_vec())
.await
.map_err(|err| DhtOutboundError::FailedToInsertMessageHash(err.to_string()))?;
if hit_count == 0 {
self.dht_requester
.add_message_to_dedup_cache(hash, public_key)
.add_message_to_dedup_cache(hash.to_vec(), self.node_identity.public_key().clone())
.await
.map_err(|err| DhtOutboundError::FailedToInsertMessageHash(err.to_string()))?;
}
Expand Down
7 changes: 4 additions & 3 deletions comms/dht/src/store_forward/database/stored_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
use std::convert::TryInto;

use chrono::NaiveDateTime;
use digest::Digest;
use tari_comms::{message::MessageExt, types::Challenge};
use tari_comms::message::MessageExt;
use tari_utilities::{hex, hex::Hex};

use crate::{
dedup,
inbound::DecryptedDhtMessage,
proto::envelope::DhtHeader,
schema::stored_messages,
Expand Down Expand Up @@ -62,6 +62,7 @@ impl NewStoredMessage {
Ok(envelope_body) => envelope_body.to_encoded_bytes(),
Err(encrypted_body) => encrypted_body,
};
let body_hash = hex::to_hex(&dedup::create_message_hash(&dht_header.origin_mac, &body));

Some(Self {
version: dht_header.version.as_major().try_into().ok()?,
Expand All @@ -75,7 +76,7 @@ impl NewStoredMessage {
let dht_header: DhtHeader = dht_header.into();
dht_header.to_encoded_bytes()
},
body_hash: hex::to_hex(&Challenge::new().chain(body.clone()).finalize()),
body_hash,
body,
})
}
Expand Down
28 changes: 21 additions & 7 deletions comms/dht/src/store_forward/saf_handler/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
use std::{convert::TryInto, sync::Arc};

use chrono::{DateTime, NaiveDateTime, Utc};
use digest::Digest;
use futures::{future, stream, StreamExt};
use log::*;
use prost::Message;
Expand All @@ -41,6 +40,7 @@ use tower::{Service, ServiceExt};
use crate::{
actor::DhtRequester,
crypt,
dedup,
envelope::{timestamp_to_datetime, DhtMessageFlags, DhtMessageHeader, NodeDestination},
inbound::{DecryptedDhtMessage, DhtInboundMessage},
outbound::{OutboundMessageRequester, SendMessageParams},
Expand Down Expand Up @@ -445,6 +445,15 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
return Err(StoreAndForwardError::StoredAtWasInFuture);
}

let msg_hash = dedup::create_message_hash(
message
.dht_header
.as_ref()
.map(|h| h.origin_mac.as_slice())
.unwrap_or(&[]),
&message.body,
);

let dht_header: DhtMessageHeader = message
.dht_header
.expect("previously checked")
Expand Down Expand Up @@ -478,13 +487,19 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>

// Check that the destination is either undisclosed, for us or for our network region
Self::check_destination(config, peer_manager, node_identity, &dht_header).await?;
// Check that the message has not already been received.
Self::check_duplicate(&mut self.dht_requester, &message.body, source_peer.public_key.clone()).await?;

// Attempt to decrypt the message (if applicable), and deserialize it
let (authenticated_pk, decrypted_body) =
Self::authenticate_and_decrypt_if_required(node_identity, &dht_header, &message.body)?;

// Check that the message has not already been received.
Self::check_duplicate(
&mut self.dht_requester,
msg_hash.to_vec(),
source_peer.public_key.clone(),
)
.await?;

let mut inbound_msg =
DhtInboundMessage::new(MessageTag::new(), dht_header, Arc::clone(&source_peer), message.body);
inbound_msg.is_saf_message = true;
Expand All @@ -497,10 +512,9 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>

async fn check_duplicate(
dht_requester: &mut DhtRequester,
body: &[u8],
msg_hash: Vec<u8>,
public_key: CommsPublicKey,
) -> Result<(), StoreAndForwardError> {
let msg_hash = Challenge::new().chain(body).finalize().to_vec();
let hit_count = dht_requester.add_message_to_dedup_cache(msg_hash, public_key).await?;
if hit_count > 1 {
Err(StoreAndForwardError::DuplicateMessage)
Expand Down Expand Up @@ -642,8 +656,8 @@ mod test {
dht_header: DhtMessageHeader,
stored_at: NaiveDateTime,
) -> StoredMessage {
let msg_hash = hex::to_hex(&dedup::create_message_hash(&dht_header.origin_mac, message.as_bytes()));
let body = message.into_bytes();
let body_hash = hex::to_hex(&Challenge::new().chain(&body).finalize());
StoredMessage {
id: 1,
version: 0,
Expand All @@ -656,7 +670,7 @@ mod test {
is_encrypted: false,
priority: StoredMessagePriority::High as i32,
stored_at,
body_hash,
body_hash: msg_hash,
}
}

Expand Down
5 changes: 1 addition & 4 deletions comms/dht/src/test_utils/store_and_forward_mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,8 @@ use std::{
};

use chrono::Utc;
use digest::Digest;
use log::*;
use rand::{rngs::OsRng, RngCore};
use tari_comms::types::Challenge;
use tari_utilities::hex;
use tokio::{
runtime,
sync::{mpsc, RwLock},
Expand Down Expand Up @@ -150,7 +147,7 @@ impl StoreAndForwardMock {
is_encrypted: msg.is_encrypted,
priority: msg.priority,
stored_at: Utc::now().naive_utc(),
body_hash: hex::to_hex(&Challenge::new().chain(msg.body).finalize()),
body_hash: msg.body_hash,
});
reply_tx.send(Ok(false)).unwrap();
},
Expand Down