diff --git a/src/udp/connection/connection_id_data.rs b/src/udp/connection/connection_id_data.rs index 3acb80e3..86425bec 100644 --- a/src/udp/connection/connection_id_data.rs +++ b/src/udp/connection/connection_id_data.rs @@ -8,35 +8,13 @@ pub struct ConnectionIdData { } impl ConnectionIdData { - pub fn from_bytes(bytes: &[u8; 8]) -> Self { - let client_id = Self::extract_client_id(bytes); - let expiration_timestamp = Self::extract_timestamp(bytes); - Self { - client_id, - expiration_timestamp - } + pub fn client_id(&self) -> &ClientId { + &self.client_id } - pub fn to_bytes(&self) -> [u8; 8] { - let connection_id: Vec = [ - self.client_id.to_bytes().as_slice(), - self.expiration_timestamp.to_le_bytes().as_slice(), - ].concat(); - - let connection_as_array: [u8; 8] = connection_id.try_into().unwrap(); - - connection_as_array + pub fn expiration_timestamp(&self) -> &Timestamp32 { + &self.expiration_timestamp } - - fn extract_timestamp(decrypted_connection_id: &[u8; 8]) -> Timestamp32 { - let timestamp_bytes = &decrypted_connection_id[4..]; - let timestamp = Timestamp32::from_le_bytes(timestamp_bytes); - timestamp - } - - fn extract_client_id(decrypted_connection_id: &[u8; 8]) -> ClientId { - ClientId::from_bytes(&decrypted_connection_id[..4]) - } } #[cfg(test)] @@ -65,28 +43,4 @@ mod tests { assert_eq!(connection_id.expiration_timestamp, 0u32.into()); } - - #[test] - fn it_should_be_converted_to_a_byte_array() { - - let connection_id = ConnectionIdData { - client_id: ClientId::from_bytes(&[0u8; 4]), - expiration_timestamp: (u32::MAX).into(), - }; - - assert_eq!(connection_id.to_bytes(), [0, 0, 0, 0, 255, 255, 255, 255]); - } - - #[test] - fn it_should_be_instantiated_from_a_byte_array() { - - let connection_id = ConnectionIdData::from_bytes(&[0, 0, 0, 0, 255, 255, 255, 255]); - - let expected_connection_id = ConnectionIdData { - client_id: ClientId::from_bytes(&[0, 0, 0, 0]), - expiration_timestamp: (u32::MAX).into(), - }; - - assert_eq!(connection_id, expected_connection_id); - } } diff --git a/src/udp/connection/connection_id_issuer.rs b/src/udp/connection/connection_id_issuer.rs index a52173fe..5b6f4c4c 100644 --- a/src/udp/connection/connection_id_issuer.rs +++ b/src/udp/connection/connection_id_issuer.rs @@ -2,11 +2,9 @@ use std::{net::SocketAddr, collections::hash_map::DefaultHasher}; use aquatic_udp_protocol::ConnectionId; -use super::{cypher::{BlowfishCypher, Cypher}, secret::Secret, timestamp_64::Timestamp64, client_id::Make, timestamp_32::Timestamp32, connection_id_data::ConnectionIdData, encrypted_connection_id_data::EncryptedConnectionIdData}; +use super::{cypher::{BlowfishCypher, Cypher}, secret::Secret, timestamp_64::Timestamp64, client_id::Make, timestamp_32::Timestamp32, connection_id_data::{ConnectionIdData}, encrypted_connection_id_data::EncryptedConnectionIdData, encoded_connection_id_data::EncodedConnectionIdData}; pub trait ConnectionIdIssuer { - type Error; - fn new_connection_id(&self, remote_address: &SocketAddr, current_timestamp: Timestamp64) -> ConnectionId; fn is_connection_id_valid(&self, connection_id: &ConnectionId, remote_address: &SocketAddr, current_timestamp: Timestamp64) -> bool; @@ -18,13 +16,13 @@ pub struct EncryptedConnectionIdIssuer { } impl ConnectionIdIssuer for EncryptedConnectionIdIssuer { - type Error = &'static str; - fn new_connection_id(&self, remote_address: &SocketAddr, current_timestamp: Timestamp64) -> ConnectionId { let connection_id_data = self.generate_connection_id_data(&remote_address, current_timestamp); - let encrypted_connection_id_data = self.encrypt_connection_id_data(&connection_id_data); + let encoded_connection_id_data: EncodedConnectionIdData = connection_id_data.into(); + + let encrypted_connection_id_data = self.encrypt_connection_id_data(&encoded_connection_id_data); self.pack_connection_id(encrypted_connection_id_data) } @@ -81,15 +79,15 @@ impl EncryptedConnectionIdIssuer { fn decrypt_connection_id_data(&self, encrypted_connection_id_data: &EncryptedConnectionIdData) -> ConnectionIdData { let decrypted_raw_data = self.cypher.decrypt(&encrypted_connection_id_data.bytes); - let connection_id_data = ConnectionIdData::from_bytes(&decrypted_raw_data); + let encoded_connection_id_data = EncodedConnectionIdData::from_bytes(&decrypted_raw_data); + + let connection_id_data: ConnectionIdData = encoded_connection_id_data.into(); connection_id_data } - fn encrypt_connection_id_data(&self, connection_id_data: &ConnectionIdData) -> EncryptedConnectionIdData { - let decrypted_raw_data = connection_id_data.to_bytes(); - - let encrypted_raw_data = self.cypher.encrypt(&decrypted_raw_data); + fn encrypt_connection_id_data(&self, encoded_connection_id_data: &EncodedConnectionIdData) -> EncryptedConnectionIdData { + let encrypted_raw_data = self.cypher.encrypt(&encoded_connection_id_data.as_bytes()); let encrypted_connection_id_data = EncryptedConnectionIdData::from_encrypted_bytes(&encrypted_raw_data); diff --git a/src/udp/connection/encoded_connection_id_data.rs b/src/udp/connection/encoded_connection_id_data.rs new file mode 100644 index 00000000..185de13d --- /dev/null +++ b/src/udp/connection/encoded_connection_id_data.rs @@ -0,0 +1,46 @@ +use super::{client_id::ClientId, timestamp_32::Timestamp32, connection_id_data::ConnectionIdData}; + +/// The encoded version of ConnectionIdData to be use in the UPD tracker package field "connection_id" +pub struct EncodedConnectionIdData([u8; 8]); + +impl EncodedConnectionIdData { + pub fn from_bytes(bytes: &[u8; 8]) -> Self { + let mut sized_bytes_arr = [0u8; 8]; + sized_bytes_arr.copy_from_slice(&bytes[..8]); + Self(sized_bytes_arr) + } + + pub fn as_bytes(&self) -> &[u8; 8] { + &self.0 + } + + fn extract_client_id(&self) -> ClientId { + ClientId::from_bytes(&self.0[..4]) + } + + fn extract_expiration_timestamp(&self) -> Timestamp32 { + let timestamp_bytes = &self.0[4..]; + let timestamp = Timestamp32::from_le_bytes(timestamp_bytes); + timestamp + } +} + +impl From for ConnectionIdData { + fn from(encoded_connection_id_data: EncodedConnectionIdData) -> Self { + Self { + client_id: encoded_connection_id_data.extract_client_id(), + expiration_timestamp: encoded_connection_id_data.extract_expiration_timestamp() + } + } +} + +impl From for EncodedConnectionIdData { + fn from(connection_id_data: ConnectionIdData) -> Self { + let byte_vec: Vec = [ + connection_id_data.client_id.to_bytes().as_slice(), + connection_id_data.expiration_timestamp.to_le_bytes().as_slice(), + ].concat(); + let bytes: [u8; 8] = byte_vec.try_into().unwrap(); + EncodedConnectionIdData::from_bytes(&bytes) + } +} diff --git a/src/udp/connection/mod.rs b/src/udp/connection/mod.rs index d27a8268..294db3a8 100644 --- a/src/udp/connection/mod.rs +++ b/src/udp/connection/mod.rs @@ -89,3 +89,4 @@ pub mod cypher; pub mod connection_id_issuer; pub mod connection_id_data; pub mod encrypted_connection_id_data; +pub mod encoded_connection_id_data;