From 7dc48387920510427f63e964e4dbf29b56d3cf87 Mon Sep 17 00:00:00 2001 From: Jose Celano Date: Tue, 7 Feb 2023 10:08:28 +0000 Subject: [PATCH] refactor(http): [#160] extract functions for percent decoding --- src/http/filters.rs | 22 +++--- src/http/mod.rs | 1 + src/http/percent_encoding.rs | 66 ++++++++++++++++ src/protocol/info_hash.rs | 71 +++++++++++++++++ src/tracker/peer.rs | 134 ++++++++++++++++++++++++++++++++ tests/http/bencode.rs | 15 ---- tests/http/mod.rs | 23 +++++- tests/http/requests/announce.rs | 7 +- tests/http/requests/scrape.rs | 5 +- tests/http/responses/scrape.rs | 2 +- tests/http_tracker.rs | 67 +++------------- 11 files changed, 318 insertions(+), 95 deletions(-) create mode 100644 src/http/percent_encoding.rs delete mode 100644 tests/http/bencode.rs diff --git a/src/http/filters.rs b/src/http/filters.rs index 2760c995..e02eac52 100644 --- a/src/http/filters.rs +++ b/src/http/filters.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use warp::{reject, Filter, Rejection}; use super::error::Error; +use super::percent_encoding::{percent_decode_info_hash, percent_decode_peer_id}; use super::{request, WebResult}; use crate::protocol::common::MAX_SCRAPE_TORRENTS; use crate::protocol::info_hash::InfoHash; @@ -78,9 +79,11 @@ fn info_hashes(raw_query: &String) -> WebResult> { for v in split_raw_query { if v.contains("info_hash") { + // get raw percent encoded infohash let raw_info_hash = v.split('=').collect::>()[1]; - let info_hash_bytes = percent_encoding::percent_decode_str(raw_info_hash).collect::>(); - let info_hash = InfoHash::from_str(&hex::encode(info_hash_bytes)); + + let info_hash = percent_decode_info_hash(raw_info_hash); + if let Ok(ih) = info_hash { info_hashes.push(ih); } @@ -112,24 +115,17 @@ fn peer_id(raw_query: &String) -> WebResult { for v in split_raw_query { // look for the peer_id param if v.contains("peer_id") { - // get raw percent_encoded peer_id + // get raw percent encoded peer id let raw_peer_id = v.split('=').collect::>()[1]; - // decode peer_id - let peer_id_bytes = percent_encoding::percent_decode_str(raw_peer_id).collect::>(); - - // peer_id must be 20 bytes - if peer_id_bytes.len() != 20 { + if let Ok(id) = percent_decode_peer_id(raw_peer_id) { + peer_id = Some(id); + } else { return Err(reject::custom(Error::InvalidPeerId { location: Location::caller(), })); } - // clone peer_id_bytes into fixed length array - let mut byte_arr: [u8; 20] = Default::default(); - byte_arr.clone_from_slice(peer_id_bytes.as_slice()); - - peer_id = Some(peer::Id(byte_arr)); break; } } diff --git a/src/http/mod.rs b/src/http/mod.rs index 9cd21aab..15f7abb5 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -15,6 +15,7 @@ pub mod axum; pub mod error; pub mod filters; pub mod handlers; +pub mod percent_encoding; pub mod request; pub mod response; pub mod routes; diff --git a/src/http/percent_encoding.rs b/src/http/percent_encoding.rs new file mode 100644 index 00000000..9b5b79ed --- /dev/null +++ b/src/http/percent_encoding.rs @@ -0,0 +1,66 @@ +use crate::protocol::info_hash::{ConversionError, InfoHash}; +use crate::tracker::peer::{self, IdConversionError}; + +/// # Errors +/// +/// Will return `Err` if if the decoded bytes do not represent a valid `InfoHash`. +pub fn percent_decode_info_hash(raw_info_hash: &str) -> Result { + let bytes = percent_encoding::percent_decode_str(raw_info_hash).collect::>(); + InfoHash::try_from(bytes) +} + +/// # Errors +/// +/// Will return `Err` if if the decoded bytes do not represent a valid `peer::Id`. +pub fn percent_decode_peer_id(raw_peer_id: &str) -> Result { + let bytes = percent_encoding::percent_decode_str(raw_peer_id).collect::>(); + peer::Id::try_from(bytes) +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use crate::http::percent_encoding::{percent_decode_info_hash, percent_decode_peer_id}; + use crate::protocol::info_hash::InfoHash; + use crate::tracker::peer; + + #[test] + fn it_should_decode_a_percent_encoded_info_hash() { + let encoded_infohash = "%3B%24U%04%CF%5F%11%BB%DB%E1%20%1C%EAjk%F4Z%EE%1B%C0"; + + let info_hash = percent_decode_info_hash(encoded_infohash).unwrap(); + + assert_eq!( + info_hash, + InfoHash::from_str("3b245504cf5f11bbdbe1201cea6a6bf45aee1bc0").unwrap() + ); + } + + #[test] + fn it_should_fail_decoding_an_invalid_percent_encoded_info_hash() { + let invalid_encoded_infohash = "invalid percent-encoded infohash"; + + let info_hash = percent_decode_info_hash(invalid_encoded_infohash); + + assert!(info_hash.is_err()); + } + + #[test] + fn it_should_decode_a_percent_encoded_peer_id() { + let encoded_peer_id = "%2DqB00000000000000000"; + + let peer_id = percent_decode_peer_id(encoded_peer_id).unwrap(); + + assert_eq!(peer_id, peer::Id(*b"-qB00000000000000000")); + } + + #[test] + fn it_should_fail_decoding_an_invalid_percent_encoded_peer_id() { + let invalid_encoded_peer_id = "invalid percent-encoded peer id"; + + let peer_id = percent_decode_peer_id(invalid_encoded_peer_id); + + assert!(peer_id.is_err()); + } +} diff --git a/src/protocol/info_hash.rs b/src/protocol/info_hash.rs index 83a595c1..32063672 100644 --- a/src/protocol/info_hash.rs +++ b/src/protocol/info_hash.rs @@ -1,7 +1,24 @@ +use std::panic::Location; + +use thiserror::Error; + #[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)] pub struct InfoHash(pub [u8; 20]); +const INFO_HASH_BYTES_LEN: usize = 20; + impl InfoHash { + /// # Panics + /// + /// Will panic if byte slice does not contains the exact amount of bytes need for the `InfoHash`. + #[must_use] + pub fn from_bytes(bytes: &[u8]) -> Self { + assert_eq!(bytes.len(), INFO_HASH_BYTES_LEN); + let mut ret = Self([0u8; INFO_HASH_BYTES_LEN]); + ret.0.clone_from_slice(bytes); + ret + } + /// For readability, when accessing the bytes array #[must_use] pub fn bytes(&self) -> [u8; 20] { @@ -57,6 +74,40 @@ impl std::convert::From<[u8; 20]> for InfoHash { } } +#[derive(Error, Debug)] +pub enum ConversionError { + #[error("not enough bytes for infohash: {message} {location}")] + NotEnoughBytes { + location: &'static Location<'static>, + message: String, + }, + #[error("too many bytes for infohash: {message} {location}")] + TooManyBytes { + location: &'static Location<'static>, + message: String, + }, +} + +impl TryFrom> for InfoHash { + type Error = ConversionError; + + fn try_from(bytes: Vec) -> Result { + if bytes.len() < INFO_HASH_BYTES_LEN { + return Err(ConversionError::NotEnoughBytes { + location: Location::caller(), + message: format! {"got {} bytes, expected {}", bytes.len(), INFO_HASH_BYTES_LEN}, + }); + } + if bytes.len() > INFO_HASH_BYTES_LEN { + return Err(ConversionError::TooManyBytes { + location: Location::caller(), + message: format! {"got {} bytes, expected {}", bytes.len(), INFO_HASH_BYTES_LEN}, + }); + } + Ok(Self::from_bytes(&bytes)) + } +} + impl serde::ser::Serialize for InfoHash { fn serialize(&self, serializer: S) -> Result { let mut buffer = [0u8; 40]; @@ -166,6 +217,26 @@ mod tests { ); } + #[test] + fn an_info_hash_can_be_created_from_a_byte_vector() { + let info_hash: InfoHash = [255u8; 20].to_vec().try_into().unwrap(); + + assert_eq!( + info_hash, + InfoHash::from_str("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF").unwrap() + ); + } + + #[test] + fn it_should_fail_trying_to_create_an_info_hash_from_a_byte_vector_with_less_than_20_bytes() { + assert!(InfoHash::try_from([255u8; 19].to_vec()).is_err()); + } + + #[test] + fn it_should_fail_trying_to_create_an_info_hash_from_a_byte_vector_with_more_than_20_bytes() { + assert!(InfoHash::try_from([255u8; 21].to_vec()).is_err()); + } + #[test] fn an_info_hash_can_be_serialized() { let s = ContainingInfoHash { diff --git a/src/tracker/peer.rs b/src/tracker/peer.rs index 3f639f97..16c96e04 100644 --- a/src/tracker/peer.rs +++ b/src/tracker/peer.rs @@ -1,8 +1,10 @@ use std::net::{IpAddr, SocketAddr}; +use std::panic::Location; use aquatic_udp_protocol::{AnnounceEvent, NumberOfBytes}; use serde; use serde::Serialize; +use thiserror::Error; use crate::http::request::Announce; use crate::protocol::clock::{Current, DurationSinceUnixEpoch, Time}; @@ -91,6 +93,69 @@ impl Peer { #[derive(PartialEq, Eq, Hash, Clone, Debug, PartialOrd, Ord, Copy)] pub struct Id(pub [u8; 20]); +const PEER_ID_BYTES_LEN: usize = 20; + +#[derive(Error, Debug)] +pub enum IdConversionError { + #[error("not enough bytes for peer id: {message} {location}")] + NotEnoughBytes { + location: &'static Location<'static>, + message: String, + }, + #[error("too many bytes for peer id: {message} {location}")] + TooManyBytes { + location: &'static Location<'static>, + message: String, + }, +} + +impl Id { + /// # Panics + /// + /// Will panic if byte slice does not contains the exact amount of bytes need for the `Id`. + #[must_use] + pub fn from_bytes(bytes: &[u8]) -> Self { + assert_eq!(bytes.len(), PEER_ID_BYTES_LEN); + let mut ret = Self([0u8; PEER_ID_BYTES_LEN]); + ret.0.clone_from_slice(bytes); + ret + } +} + +impl From<[u8; 20]> for Id { + fn from(bytes: [u8; 20]) -> Self { + Id(bytes) + } +} + +impl TryFrom> for Id { + type Error = IdConversionError; + + fn try_from(bytes: Vec) -> Result { + if bytes.len() < PEER_ID_BYTES_LEN { + return Err(IdConversionError::NotEnoughBytes { + location: Location::caller(), + message: format! {"got {} bytes, expected {}", bytes.len(), PEER_ID_BYTES_LEN}, + }); + } + if bytes.len() > PEER_ID_BYTES_LEN { + return Err(IdConversionError::TooManyBytes { + location: Location::caller(), + message: format! {"got {} bytes, expected {}", bytes.len(), PEER_ID_BYTES_LEN}, + }); + } + Ok(Self::from_bytes(&bytes)) + } +} + +impl std::str::FromStr for Id { + type Err = IdConversionError; + + fn from_str(s: &str) -> Result { + Self::try_from(s.as_bytes().to_vec()) + } +} + impl std::fmt::Display for Id { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.to_hex_string() { @@ -239,6 +304,75 @@ mod test { mod torrent_peer_id { use crate::tracker::peer; + #[test] + fn should_be_instantiated_from_a_byte_slice() { + let id = peer::Id::from_bytes(&[ + 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, + ]); + + let expected_id = peer::Id([ + 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, + ]); + + assert_eq!(id, expected_id); + } + + #[test] + #[should_panic] + fn should_fail_trying_to_instantiate_from_a_byte_slice_with_less_than_20_bytes() { + let less_than_20_bytes = [0; 19]; + let _ = peer::Id::from_bytes(&less_than_20_bytes); + } + + #[test] + #[should_panic] + fn should_fail_trying_to_instantiate_from_a_byte_slice_with_more_than_20_bytes() { + let more_than_20_bytes = [0; 21]; + let _ = peer::Id::from_bytes(&more_than_20_bytes); + } + + #[test] + fn should_be_converted_from_a_20_byte_array() { + let id = peer::Id::from([ + 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, + ]); + + let expected_id = peer::Id([ + 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, + ]); + + assert_eq!(id, expected_id); + } + + #[test] + fn should_be_converted_from_a_byte_vector() { + let id = peer::Id::try_from( + [ + 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, + ] + .to_vec(), + ) + .unwrap(); + + let expected_id = peer::Id([ + 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, 0, 159, 146, 150, + ]); + + assert_eq!(id, expected_id); + } + + #[test] + #[should_panic] + fn should_fail_trying_to_convert_from_a_byte_vector_with_less_than_20_bytes() { + let _ = peer::Id::try_from([0; 19].to_vec()).unwrap(); + } + + #[test] + #[should_panic] + fn should_fail_trying_to_convert_from_a_byte_vector_with_more_than_20_bytes() { + let _ = peer::Id::try_from([0; 21].to_vec()).unwrap(); + } + #[test] fn should_be_converted_to_hex_string() { let id = peer::Id(*b"-qB00000000000000000"); diff --git a/tests/http/bencode.rs b/tests/http/bencode.rs deleted file mode 100644 index d107089c..00000000 --- a/tests/http/bencode.rs +++ /dev/null @@ -1,15 +0,0 @@ -pub type ByteArray20 = [u8; 20]; - -pub struct InfoHash(ByteArray20); - -impl InfoHash { - pub fn new(vec: &[u8]) -> Self { - let mut byte_array_20: ByteArray20 = Default::default(); - byte_array_20.clone_from_slice(vec); - Self(byte_array_20) - } - - pub fn bytes(&self) -> ByteArray20 { - self.0 - } -} diff --git a/tests/http/mod.rs b/tests/http/mod.rs index 87087026..8c1e3c99 100644 --- a/tests/http/mod.rs +++ b/tests/http/mod.rs @@ -1,7 +1,28 @@ pub mod asserts; -pub mod bencode; pub mod client; pub mod connection_info; pub mod requests; pub mod responses; pub mod server; + +use percent_encoding::NON_ALPHANUMERIC; + +pub type ByteArray20 = [u8; 20]; + +pub fn percent_encode_byte_array(bytes: &ByteArray20) -> String { + percent_encoding::percent_encode(bytes, NON_ALPHANUMERIC).to_string() +} + +pub struct InfoHash(ByteArray20); + +impl InfoHash { + pub fn new(vec: &[u8]) -> Self { + let mut byte_array_20: ByteArray20 = Default::default(); + byte_array_20.clone_from_slice(vec); + Self(byte_array_20) + } + + pub fn bytes(&self) -> ByteArray20 { + self.0 + } +} diff --git a/tests/http/requests/announce.rs b/tests/http/requests/announce.rs index a8ebc95f..87aa3425 100644 --- a/tests/http/requests/announce.rs +++ b/tests/http/requests/announce.rs @@ -2,12 +2,11 @@ use std::fmt; use std::net::{IpAddr, Ipv4Addr}; use std::str::FromStr; -use percent_encoding::NON_ALPHANUMERIC; use serde_repr::Serialize_repr; use torrust_tracker::protocol::info_hash::InfoHash; use torrust_tracker::tracker::peer::Id; -use crate::http::bencode::ByteArray20; +use crate::http::{percent_encode_byte_array, ByteArray20}; pub struct Query { pub info_hash: ByteArray20, @@ -211,11 +210,11 @@ impl QueryParams { let compact = announce_query.compact.as_ref().map(std::string::ToString::to_string); Self { - info_hash: Some(percent_encoding::percent_encode(&announce_query.info_hash, NON_ALPHANUMERIC).to_string()), + info_hash: Some(percent_encode_byte_array(&announce_query.info_hash)), peer_addr: Some(announce_query.peer_addr.to_string()), downloaded: Some(announce_query.downloaded.to_string()), uploaded: Some(announce_query.uploaded.to_string()), - peer_id: Some(percent_encoding::percent_encode(&announce_query.peer_id, NON_ALPHANUMERIC).to_string()), + peer_id: Some(percent_encode_byte_array(&announce_query.peer_id)), port: Some(announce_query.port.to_string()), left: Some(announce_query.left.to_string()), event, diff --git a/tests/http/requests/scrape.rs b/tests/http/requests/scrape.rs index 6ab46974..979dad54 100644 --- a/tests/http/requests/scrape.rs +++ b/tests/http/requests/scrape.rs @@ -1,10 +1,9 @@ use std::fmt; use std::str::FromStr; -use percent_encoding::NON_ALPHANUMERIC; use torrust_tracker::protocol::info_hash::InfoHash; -use crate::http::bencode::ByteArray20; +use crate::http::{percent_encode_byte_array, ByteArray20}; pub struct Query { pub info_hash: Vec, @@ -111,7 +110,7 @@ impl QueryParams { let info_hashes = scrape_query .info_hash .iter() - .map(|info_hash_bytes| percent_encoding::percent_encode(info_hash_bytes, NON_ALPHANUMERIC).to_string()) + .map(percent_encode_byte_array) .collect::>(); Self { info_hash: info_hashes } diff --git a/tests/http/responses/scrape.rs b/tests/http/responses/scrape.rs index 5bf938eb..1aea517c 100644 --- a/tests/http/responses/scrape.rs +++ b/tests/http/responses/scrape.rs @@ -4,7 +4,7 @@ use std::str; use serde::{self, Deserialize, Serialize}; use serde_bencode::value::Value; -use crate::http::bencode::{ByteArray20, InfoHash}; +use crate::http::{ByteArray20, InfoHash}; #[derive(Debug, PartialEq, Default)] pub struct Response { diff --git a/tests/http_tracker.rs b/tests/http_tracker.rs index 201f8e70..60219d9f 100644 --- a/tests/http_tracker.rs +++ b/tests/http_tracker.rs @@ -1,6 +1,14 @@ /// Integration tests for HTTP tracker server /// -/// cargo test `http_tracker_server` -- --nocapture +/// Warp version: +/// ```text +/// cargo test `warp_http_tracker_server` -- --nocapture +/// ``` +/// +/// Axum version ()WIP): +/// ```text +/// cargo test `warp_http_tracker_server` -- --nocapture +/// ``` mod common; mod http; @@ -2483,60 +2491,3 @@ mod axum_http_tracker_server { mod receiving_an_scrape_request {} } } - -mod percent_encoding { - // todo: these operations are used in the HTTP tracker but they have not been extracted into independent functions. - // These tests document the operations. This behavior could be move to some functions int he future if they are extracted. - - use std::str::FromStr; - - use percent_encoding::NON_ALPHANUMERIC; - use torrust_tracker::protocol::info_hash::InfoHash; - use torrust_tracker::tracker::peer; - - #[test] - fn how_to_encode_an_info_hash() { - let info_hash = InfoHash::from_str("3b245504cf5f11bbdbe1201cea6a6bf45aee1bc0").unwrap(); - - let encoded_info_hash = percent_encoding::percent_encode(&info_hash.0, NON_ALPHANUMERIC).to_string(); - - assert_eq!(encoded_info_hash, "%3B%24U%04%CF%5F%11%BB%DB%E1%20%1C%EAjk%F4Z%EE%1B%C0"); - } - - #[test] - fn how_to_decode_an_info_hash() { - let encoded_infohash = "%3B%24U%04%CF%5F%11%BB%DB%E1%20%1C%EAjk%F4Z%EE%1B%C0"; - - let info_hash_bytes = percent_encoding::percent_decode_str(encoded_infohash).collect::>(); - let info_hash = InfoHash::from_str(&hex::encode(info_hash_bytes)).unwrap(); - - assert_eq!( - info_hash, - InfoHash::from_str("3b245504cf5f11bbdbe1201cea6a6bf45aee1bc0").unwrap() - ); - } - - #[test] - fn how_to_encode_a_peer_id() { - let peer_id = peer::Id(*b"-qB00000000000000000"); - - let encoded_peer_id = percent_encoding::percent_encode(&peer_id.0, NON_ALPHANUMERIC).to_string(); - - assert_eq!(encoded_peer_id, "%2DqB00000000000000000"); - } - - #[test] - fn how_to_decode_a_peer_id() { - let encoded_peer_id = "%2DqB00000000000000000"; - - let bytes_vec = percent_encoding::percent_decode_str(encoded_peer_id).collect::>(); - - // Clone peer_id_bytes into fixed length array - let mut peer_id_bytes: [u8; 20] = Default::default(); - peer_id_bytes.clone_from_slice(bytes_vec.as_slice()); - - let peer_id = peer::Id(peer_id_bytes); - - assert_eq!(peer_id, peer::Id(*b"-qB00000000000000000")); - } -}