Skip to content

Commit

Permalink
refactor(http): [#184] move extractor to extractor mod
Browse files Browse the repository at this point in the history
  • Loading branch information
josecelano committed Feb 16, 2023
1 parent 99dbbe4 commit 30918da
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 77 deletions.
45 changes: 45 additions & 0 deletions src/http/axum_implementation/extractors/announce_request.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use std::panic::Location;

use axum::async_trait;
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::response::{IntoResponse, Response};

use crate::http::axum_implementation::query::Query;
use crate::http::axum_implementation::requests::announce::{Announce, ParseAnnounceQueryError};
use crate::http::axum_implementation::responses;

pub struct ExtractRequest(pub Announce);

#[async_trait]
impl<S> FromRequestParts<S> for ExtractRequest
where
S: Send + Sync,
{
type Rejection = Response;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let raw_query = parts.uri.query();

if raw_query.is_none() {
return Err(responses::error::Error::from(ParseAnnounceQueryError::MissingParams {
location: Location::caller(),
})
.into_response());
}

let query = raw_query.unwrap().parse::<Query>();

if let Err(error) = query {
return Err(responses::error::Error::from(error).into_response());
}

let announce_request = Announce::try_from(query.unwrap());

if let Err(error) = announce_request {
return Err(responses::error::Error::from(error).into_response());
}

Ok(ExtractRequest(announce_request.unwrap()))
}
}
1 change: 1 addition & 0 deletions src/http/axum_implementation/extractors/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod announce_request;
pub mod peer_ip;
pub mod remote_client_ip;
5 changes: 3 additions & 2 deletions src/http/axum_implementation/handlers/announce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ use axum::extract::State;
use axum::response::{IntoResponse, Response};
use log::debug;

use crate::http::axum_implementation::extractors::announce_request::ExtractRequest;
use crate::http::axum_implementation::extractors::peer_ip::assign_ip_address_to_peer;
use crate::http::axum_implementation::extractors::remote_client_ip::RemoteClientIp;
use crate::http::axum_implementation::requests::announce::{Announce, Event, ExtractAnnounceRequest};
use crate::http::axum_implementation::requests::announce::{Announce, Event};
use crate::http::axum_implementation::{responses, services};
use crate::protocol::clock::{Current, Time};
use crate::tracker::peer::Peer;
Expand All @@ -17,7 +18,7 @@ use crate::tracker::Tracker;
#[allow(clippy::unused_async)]
pub async fn handle(
State(tracker): State<Arc<Tracker>>,
ExtractAnnounceRequest(announce_request): ExtractAnnounceRequest,
ExtractRequest(announce_request): ExtractRequest,
remote_client_ip: RemoteClientIp,
) -> Response {
debug!("http announce request: {:#?}", announce_request);
Expand Down
111 changes: 36 additions & 75 deletions src/http/axum_implementation/requests/announce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@ use std::fmt;
use std::panic::Location;
use std::str::FromStr;

use axum::async_trait;
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::response::{IntoResponse, Response};
use thiserror::Error;

use crate::http::axum_implementation::query::{ParseQueryError, Query};
Expand All @@ -17,9 +13,7 @@ use crate::tracker::peer::{self, IdConversionError};

pub type NumberOfBytes = i64;

pub struct ExtractAnnounceRequest(pub Announce);

// Param names in the URL query
// Query param names
const INFO_HASH: &str = "info_hash";
const PEER_ID: &str = "peer_id";
const PORT: &str = "port";
Expand All @@ -43,6 +37,41 @@ pub struct Announce {
pub compact: Option<Compact>,
}

#[derive(Error, Debug)]
pub enum ParseAnnounceQueryError {
#[error("missing query params for announce request in {location}")]
MissingParams { location: &'static Location<'static> },
#[error("missing param {param_name} in {location}")]
MissingParam {
location: &'static Location<'static>,
param_name: String,
},
#[error("invalid param value {param_value} for {param_name} in {location}")]
InvalidParam {
param_name: String,
param_value: String,
location: &'static Location<'static>,
},
#[error("param value overflow {param_value} for {param_name} in {location}")]
NumberOfBytesOverflow {
param_name: String,
param_value: String,
location: &'static Location<'static>,
},
#[error("invalid param value {param_value} for {param_name} in {source}")]
InvalidInfoHashParam {
param_name: String,
param_value: String,
source: LocatedError<'static, ConversionError>,
},
#[error("invalid param value {param_value} for {param_name} in {source}")]
InvalidPeerIdParam {
param_name: String,
param_value: String,
source: LocatedError<'static, IdConversionError>,
},
}

#[derive(PartialEq, Debug)]
pub enum Event {
Started,
Expand Down Expand Up @@ -108,41 +137,6 @@ impl FromStr for Compact {
}
}

#[derive(Error, Debug)]
pub enum ParseAnnounceQueryError {
#[error("missing query params for announce request in {location}")]
MissingParams { location: &'static Location<'static> },
#[error("missing param {param_name} in {location}")]
MissingParam {
location: &'static Location<'static>,
param_name: String,
},
#[error("invalid param value {param_value} for {param_name} in {location}")]
InvalidParam {
param_name: String,
param_value: String,
location: &'static Location<'static>,
},
#[error("param value overflow {param_value} for {param_name} in {location}")]
NumberOfBytesOverflow {
param_name: String,
param_value: String,
location: &'static Location<'static>,
},
#[error("invalid param value {param_value} for {param_name} in {source}")]
InvalidInfoHashParam {
param_name: String,
param_value: String,
source: LocatedError<'static, ConversionError>,
},
#[error("invalid param value {param_value} for {param_name} in {source}")]
InvalidPeerIdParam {
param_name: String,
param_value: String,
source: LocatedError<'static, IdConversionError>,
},
}

impl From<ParseQueryError> for responses::error::Error {
fn from(err: ParseQueryError) -> Self {
responses::error::Error {
Expand Down Expand Up @@ -281,39 +275,6 @@ fn extract_compact(query: &Query) -> Result<Option<Compact>, ParseAnnounceQueryE
}
}

#[async_trait]
impl<S> FromRequestParts<S> for ExtractAnnounceRequest
where
S: Send + Sync,
{
type Rejection = Response;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let raw_query = parts.uri.query();

if raw_query.is_none() {
return Err(responses::error::Error::from(ParseAnnounceQueryError::MissingParams {
location: Location::caller(),
})
.into_response());
}

let query = raw_query.unwrap().parse::<Query>();

if let Err(error) = query {
return Err(responses::error::Error::from(error).into_response());
}

let announce_request = Announce::try_from(query.unwrap());

if let Err(error) = announce_request {
return Err(responses::error::Error::from(error).into_response());
}

Ok(ExtractAnnounceRequest(announce_request.unwrap()))
}
}

#[cfg(test)]
mod tests {

Expand Down
2 changes: 2 additions & 0 deletions src/tracker/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub struct Peer {
pub downloaded: NumberOfBytes,
#[serde(with = "NumberOfBytesDef")]
pub left: NumberOfBytes, // The number of bytes this peer still has to download
// code-review: aquatic_udp_protocol::request::AnnounceEvent is used also for the HTTP tracker.
// Maybe we should use our own enum and use the¡is one only for the UDP tracker.
#[serde(with = "AnnounceEventDef")]
pub event: AnnounceEvent,
}
Expand Down

0 comments on commit 30918da

Please sign in to comment.