diff --git a/comms/dht/src/crypt.rs b/comms/dht/src/crypt.rs index 73c2906a079..9c7a92c911b 100644 --- a/comms/dht/src/crypt.rs +++ b/comms/dht/src/crypt.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::{iter, mem::size_of}; +use std::{convert::TryFrom, iter, mem::size_of}; use chacha20::{ cipher::{NewCipher, StreamCipher}, @@ -204,23 +204,27 @@ pub fn encrypt_message(message_key: &CommsMessageKey, plain_text: &mut BytesMut) } /// Encodes a prost Message, efficiently prepending the little-endian 32-bit length to the encoding -fn encode_with_prepended_length(msg: &T, additional_prefix_space: usize) -> BytesMut { +fn encode_with_prepended_length( + msg: &T, + additional_prefix_space: usize, +) -> Result { let len = msg.encoded_len(); let mut buf = BytesMut::with_capacity(size_of::() + additional_prefix_space + len); buf.extend(iter::repeat(0).take(additional_prefix_space)); - buf.put_u32_le(len as u32); - msg.encode(&mut buf).expect( - "prost::Message::encode documentation says it is infallible unless the buffer has insufficient capacity. This \ - buffer's capacity was set with encoded_len", + buf.put_u32_le( + u32::try_from(len).map_err(|_| DhtEncryptError::PaddingError(String::from("Message is too large to pad")))?, ); - buf + msg.encode(&mut buf) + .map_err(|_| DhtEncryptError::PaddingError(String::from("Unable to pad message")))?; + + Ok(buf) } -pub fn prepare_message(is_encrypted: bool, message: &T) -> BytesMut { +pub fn prepare_message(is_encrypted: bool, message: &T) -> Result { if is_encrypted { encode_with_prepended_length(message, size_of::()) } else { - message.encode_into_bytes_mut() + Ok(message.encode_into_bytes_mut()) } } @@ -304,7 +308,7 @@ mod test { fn encrypt_decrypt() { let key = CommsMessageKey::from(SafeArray::default()); let plain_text = "Last enemy position 0830h AJ 9863".to_string(); - let mut msg = prepare_message(true, &plain_text); + let mut msg = prepare_message(true, &plain_text).unwrap(); encrypt_message(&key, &mut msg).unwrap(); decrypt_message(&key, &mut msg).unwrap(); assert_eq!(String::decode(&msg[..]).unwrap(), plain_text); @@ -433,7 +437,7 @@ mod test { assert_eq!(pad, pad_message[message.len()..]); // test for large message - let message = encode_with_prepended_length(&vec![100u8; MESSAGE_BASE_LENGTH * 8 - 100], 0); + let message = encode_with_prepended_length(&vec![100u8; MESSAGE_BASE_LENGTH * 8 - 100], 0).unwrap(); let mut pad_message = message.clone(); pad_message_to_base_length_multiple(&mut pad_message, 0).unwrap(); let pad = iter::repeat(0u8) @@ -448,7 +452,7 @@ mod test { assert_eq!(pad, pad_message[message.len()..]); // test for base message of multiple base length - let message = encode_with_prepended_length(&vec![100u8; MESSAGE_BASE_LENGTH * 9 - 123], 0); + let message = encode_with_prepended_length(&vec![100u8; MESSAGE_BASE_LENGTH * 9 - 123], 0).unwrap(); let pad = std::iter::repeat(0u8) .take((9 * MESSAGE_BASE_LENGTH) - message.len()) .collect::>(); @@ -464,7 +468,7 @@ mod test { assert_eq!(pad, pad_message[message.len()..]); // test for empty message - let message = encode_with_prepended_length(&vec![], 0); + let message = encode_with_prepended_length(&vec![], 0).unwrap(); let mut pad_message = message.clone(); pad_message_to_base_length_multiple(&mut pad_message, 0).unwrap(); let pad = [0u8; MESSAGE_BASE_LENGTH - 4]; @@ -506,7 +510,7 @@ mod test { fn get_original_message_from_padded_text_successful() { // test for short message let message = vec![0u8, 10, 22, 11, 38, 74, 59, 91, 73, 82, 75, 23, 59]; - let mut pad_message = encode_with_prepended_length(&message, 0); + let mut pad_message = encode_with_prepended_length(&message, 0).unwrap(); pad_message_to_base_length_multiple(&mut pad_message, 0).unwrap(); // @@ -516,7 +520,7 @@ mod test { // test for large message let message = vec![100u8; 1024]; - let mut pad_message = encode_with_prepended_length(&message, 0); + let mut pad_message = encode_with_prepended_length(&message, 0).unwrap(); pad_message_to_base_length_multiple(&mut pad_message, 0).unwrap(); let mut output_message = pad_message.clone(); @@ -525,7 +529,7 @@ mod test { // test for base message of base length let message = vec![100u8; 984]; - let mut pad_message = encode_with_prepended_length(&message, 0); + let mut pad_message = encode_with_prepended_length(&message, 0).unwrap(); pad_message_to_base_length_multiple(&mut pad_message, 0).unwrap(); let mut output_message = pad_message.clone(); @@ -534,7 +538,7 @@ mod test { // test for empty message let message: Vec = vec![]; - let mut pad_message = encode_with_prepended_length(&message, 0); + let mut pad_message = encode_with_prepended_length(&message, 0).unwrap(); pad_message_to_base_length_multiple(&mut pad_message, 0).unwrap(); let mut output_message = pad_message.clone(); @@ -545,7 +549,7 @@ mod test { #[test] fn padding_fails_if_pad_message_prepend_length_is_bigger_than_plaintext_length() { let message = "This is my secret message, keep it secret !".as_bytes().to_vec(); - let mut pad_message = encode_with_prepended_length(&message, 0); + let mut pad_message = encode_with_prepended_length(&message, 0).unwrap(); pad_message_to_base_length_multiple(&mut pad_message, 0).unwrap(); let mut pad_message = pad_message.to_vec(); @@ -576,7 +580,7 @@ mod test { // in any way the value of the decrypted content, by applying a cipher stream let key = CommsMessageKey::from(SafeArray::default()); let message = "My secret message, keep it secret !".to_string(); - let mut msg = encode_with_prepended_length(&message, size_of::()); + let mut msg = encode_with_prepended_length(&message, size_of::()).unwrap(); encrypt_message(&key, &mut msg).unwrap(); let n = msg.len(); @@ -590,7 +594,7 @@ mod test { fn decryption_fails_if_message_body_is_modified() { let key = CommsMessageKey::from(SafeArray::default()); let message = "My secret message, keep it secret !".to_string(); - let mut msg = encode_with_prepended_length(&message, size_of::()); + let mut msg = encode_with_prepended_length(&message, size_of::()).unwrap(); encrypt_message(&key, &mut msg).unwrap(); let index = size_of::() + size_of::() + 1; diff --git a/comms/dht/src/outbound/requester.rs b/comms/dht/src/outbound/requester.rs index 0ac3e2e619e..086eec3ff27 100644 --- a/comms/dht/src/outbound/requester.rs +++ b/comms/dht/src/outbound/requester.rs @@ -261,7 +261,8 @@ impl OutboundMessageRequester { message.to_propagation_header() }; let msg = wrap_in_envelope_body!(header, message.into_inner()); - let body = prepare_message(params.encryption.is_encrypt(), &msg); + let body = prepare_message(params.encryption.is_encrypt(), &msg) + .map_err(|_| DhtOutboundError::PaddingError(String::from("Unable to pad message")))?; self.send_raw(params, body).await } @@ -278,7 +279,8 @@ impl OutboundMessageRequester { trace!(target: LOG_TARGET, "Send Message: {} {:?}", params, message); } let msg = wrap_in_envelope_body!(message); - let body = prepare_message(params.encryption.is_encrypt(), &msg); + let body = prepare_message(params.encryption.is_encrypt(), &msg) + .map_err(|_| DhtOutboundError::PaddingError(String::from("Unable to pad message")))?; self.send_raw(params, body).await } @@ -295,7 +297,8 @@ impl OutboundMessageRequester { trace!(target: LOG_TARGET, "Send Message: {} {:?}", params, message); } let msg = wrap_in_envelope_body!(message); - let body = prepare_message(params.encryption.is_encrypt(), &msg); + let body = prepare_message(params.encryption.is_encrypt(), &msg) + .map_err(|_| DhtOutboundError::PaddingError(String::from("Unable to pad message")))?; self.send_raw_no_wait(params, body).await } diff --git a/comms/dht/src/test_utils/makers.rs b/comms/dht/src/test_utils/makers.rs index 9e3428cfe06..2b62155de6f 100644 --- a/comms/dht/src/test_utils/makers.rs +++ b/comms/dht/src/test_utils/makers.rs @@ -202,11 +202,11 @@ pub fn make_dht_envelope( let message = if flags.is_encrypted() { let shared_secret = CommsDHKE::new(&e_secret_key, node_identity.public_key()); let key_message = crypt::generate_key_message(&shared_secret); - let mut message = prepare_message(true, message); + let mut message = prepare_message(true, message).unwrap(); crypt::encrypt_message(&key_message, &mut message).unwrap(); message.freeze() } else { - prepare_message(false, message).freeze() + prepare_message(false, message).unwrap().freeze() }; let header = make_dht_header( node_identity,