Skip to content
This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

Commit

Permalink
update to the new message API and refactor PreProcessor
Browse files Browse the repository at this point in the history
As part of updating the code to match the new message API, I realized
the `PreProcessor` was un-necessarily complex and contained a lot of
duplicated code (once for each message type). The service was
initially intended to perform a bunch of checks on the incoming
messages before passing them to the state machine, but it turns out
the only this it actually does is checking the eligibility of the
participant for the task it wants to take part to. As such we:

- removed the duplicated code
- update the service from `PreProcessor` to `TaskValidator`
  • Loading branch information
little-dude committed Aug 31, 2020
1 parent a8c7160 commit 9287168
Show file tree
Hide file tree
Showing 13 changed files with 210 additions and 556 deletions.
80 changes: 18 additions & 62 deletions rust/xaynet-server/src/services/messages/message_parser.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{pin::Pin, sync::Arc, task::Poll};
use std::{convert::TryInto, pin::Pin, sync::Arc, task::Poll};

use anyhow::Context as _;
use derive_more::From;
use futures::{
future::{self, Either, Future},
Expand All @@ -11,8 +12,8 @@ use tokio::sync::oneshot;
use tower::Service;
use tracing::Span;
use xaynet_core::{
crypto::{ByteObject, EncryptKeyPair, Signature},
message::{DecodeError, FromBytes, Header, Message, Payload, Sum, Sum2, Tag, ToBytes, Update},
crypto::EncryptKeyPair,
message::{DecodeError, Message, MessageBuffer, Tag},
};

use crate::{
Expand Down Expand Up @@ -167,20 +168,27 @@ impl Handler {
info!("decrypting message");
let raw = self.decrypt(&data.0.as_ref())?;

info!("parsing message header");
let header = self.parse_header(raw.as_slice())?;
let buf = MessageBuffer::new(&raw).map_err(MessageParserError::Parsing)?;

info!("filtering message based on the current phase");
self.phase_filter(header.tag)?;
let tag = buf
.tag()
.try_into()
.context("failed to parse message tag field")
.map_err(MessageParserError::Parsing)?;
self.phase_filter(tag)?;

info!("verifying the message signature");
self.verify_signature(raw.as_slice(), &header)?;
buf.check_signature().map_err(|e| {
warn!("invalid message signature: {:?}", e);
MessageParserError::InvalidMessageSignature
})?;

info!("parsing the message payload");
let payload = self.parse_payload(raw.as_slice(), &header)?;
info!("parsing the message");
let message = Message::from_bytes(&raw).map_err(MessageParserError::Parsing)?;

info!("done pre-processing the message");
Ok(Message { header, payload })
Ok(message)
}

/// Decrypt the given payload with the coordinator secret key
Expand All @@ -192,12 +200,6 @@ impl Handler {
.map_err(|_| MessageParserError::Decrypt)?)
}

/// Attempt to parse the message header from the raw message
fn parse_header(&self, raw_message: &[u8]) -> Result<Header, MessageParserError> {
Ok(Header::from_bytes(&&raw_message[Signature::LENGTH..])
.map_err(MessageParserError::Parsing)?)
}

/// Reject messages that cannot be handled by the coordinator in
/// the current phase
fn phase_filter(&self, tag: Tag) -> Result<(), MessageParserError> {
Expand All @@ -214,50 +216,4 @@ impl Handler {
}
}
}

/// Verify the integrity of the given message by checking the
/// signature embedded in the header.
fn verify_signature(
&self,
raw_message: &[u8],
header: &Header,
) -> Result<(), MessageParserError> {
// UNWRAP_SAFE: We already parsed the header, so we now the
// message is at least as big as: signature length + header
// length
let signature = Signature::from_slice(&raw_message[..Signature::LENGTH]).unwrap();
let bytes = &raw_message[Signature::LENGTH..];
if header.participant_pk.verify_detached(&signature, bytes) {
Ok(())
} else {
Err(MessageParserError::InvalidMessageSignature)
}
}

/// Parse the payload of the given message
fn parse_payload(
&self,
raw_message: &[u8],
header: &Header,
) -> Result<Payload, MessageParserError> {
let bytes = &raw_message[header.buffer_length() + Signature::LENGTH..];
match header.tag {
Tag::Sum => {
let parsed = Sum::from_bytes(&bytes)
.map_err(|e| MessageParserError::Parsing(e.context("invalid sum payload")))?;
Ok(Payload::Sum(parsed))
}
Tag::Update => {
let parsed = Update::from_bytes(&bytes).map_err(|e| {
MessageParserError::Parsing(e.context("invalid update payload"))
})?;
Ok(Payload::Update(parsed))
}
Tag::Sum2 => {
let parsed = Sum2::from_bytes(&bytes)
.map_err(|e| MessageParserError::Parsing(e.context("invalid sum2 payload")))?;
Ok(Payload::Sum2(parsed))
}
}
}
}
78 changes: 42 additions & 36 deletions rust/xaynet-server/src/services/messages/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
//! There are multiple such services and the [`PetMessageHandler`]
//! trait provides a single unifying interface for all of these.
mod message_parser;
mod pre_processor;
mod state_machine;
mod task_validator;

pub use self::{
message_parser::{
Expand All @@ -14,18 +14,18 @@ pub use self::{
MessageParserResponse,
MessageParserService,
},
pre_processor::{
PreProcessorError,
PreProcessorRequest,
PreProcessorResponse,
PreProcessorService,
},
state_machine::{
StateMachineError,
StateMachineRequest,
StateMachineResponse,
StateMachineService,
},
task_validator::{
TaskValidatorError,
TaskValidatorRequest,
TaskValidatorResponse,
TaskValidatorService,
},
};

use xaynet_core::message::Message;
Expand All @@ -43,7 +43,7 @@ use thiserror::Error;
use tower::Service;

type TracedMessageParser<S> = TracedService<S, RawMessage<Vec<u8>>>;
type TracedPreProcessor<S> = TracedService<S, Message>;
type TracedTaskValidator<S> = TracedService<S, Message>;
type TracedStateMachine<S> = TracedService<S, Message>;

/// Error returned by the [`PetMessageHandler`] methods.
Expand All @@ -53,7 +53,7 @@ pub enum PetMessageError {
Parser(MessageParserError),

#[error("failed to pre-process message: {0}")]
PreProcessor(PreProcessorError),
TaskValidator(TaskValidatorError),

#[error("state machine failed to handle message: {0}")]
StateMachine(StateMachineError),
Expand All @@ -63,7 +63,7 @@ pub enum PetMessageError {
}

/// A single interface for all the PET message processing sub-services
/// ([`MessageParserService`], [`PreProcessorService`] and
/// ([`MessageParserService`], [`TaskValidatorService`] and
/// [`StateMachineService`]).
#[async_trait]
pub trait PetMessageHandler: Send {
Expand All @@ -78,7 +78,7 @@ pub trait PetMessageHandler: Send {
let message = self.call_parser(req).await?;

let req = Request::from_parts(metadata.clone(), message);
let message = self.call_pre_processor(req).await?;
let message = self.call_task_validator(req).await?;

let req = Request::from_parts(metadata, message);
Ok(self.call_state_machine(req).await?)
Expand All @@ -91,9 +91,9 @@ pub trait PetMessageHandler: Send {
) -> Result<Message, PetMessageError>;

/// Pre-process a PET message
async fn call_pre_processor(
async fn call_task_validator(
&mut self,
message: PreProcessorRequest,
message: TaskValidatorRequest,
) -> Result<Message, PetMessageError>;

/// Have a PET message processed by the state machine
Expand All @@ -104,7 +104,7 @@ pub trait PetMessageHandler: Send {
}

#[async_trait]
impl<MP, PP, SM> PetMessageHandler for PetMessageService<MP, PP, SM>
impl<MP, TV, SM> PetMessageHandler for PetMessageService<MP, TV, SM>
where
Self: Send + Sync + 'static,

Expand All @@ -113,9 +113,9 @@ where
<MP as Service<MessageParserRequest<Vec<u8>>>>::Error:
Into<Box<dyn ::std::error::Error + Send + Sync + 'static>>,

PP: Service<PreProcessorRequest, Response = PreProcessorResponse> + Send + 'static,
<PP as Service<PreProcessorRequest>>::Future: Send + 'static,
<PP as Service<PreProcessorRequest>>::Error:
TV: Service<TaskValidatorRequest, Response = TaskValidatorResponse> + Send + 'static,
<TV as Service<TaskValidatorRequest>>::Future: Send + 'static,
<TV as Service<TaskValidatorRequest>>::Error:
Into<Box<dyn ::std::error::Error + Send + Sync + 'static>>,

SM: Service<StateMachineRequest, Response = StateMachineResponse> + Send + 'static,
Expand Down Expand Up @@ -146,18 +146,23 @@ where
.map_err(PetMessageError::Parser)
}

async fn call_pre_processor(
async fn call_task_validator(
&mut self,
message: PreProcessorRequest,
message: TaskValidatorRequest,
) -> Result<Message, PetMessageError> {
poll_fn(|cx| <PP as Service<PreProcessorRequest>>::poll_ready(&mut self.pre_processor, cx))
.await
.map_err(|e| PetMessageError::ServiceError(Into::into(e)))?;
poll_fn(|cx| {
<TV as Service<TaskValidatorRequest>>::poll_ready(&mut self.task_validator, cx)
})
.await
.map_err(|e| PetMessageError::ServiceError(Into::into(e)))?;

<PP as Service<PreProcessorRequest>>::call(&mut self.pre_processor, message.map(Into::into))
.await
.map_err(|e| PetMessageError::ServiceError(Into::into(e)))?
.map_err(PetMessageError::PreProcessor)
<TV as Service<TaskValidatorRequest>>::call(
&mut self.task_validator,
message.map(Into::into),
)
.await
.map_err(|e| PetMessageError::ServiceError(Into::into(e)))?
.map_err(PetMessageError::TaskValidator)
}

async fn call_state_machine(
Expand All @@ -184,45 +189,46 @@ where
/// encrypted message) goes through the `MessageParser` service,
/// which decrypt the message, validates it, and parses it
///
/// 2. The message is passed to the `PreProcessor`, which depending on
/// 2. The message is passed to the `TaskValidator`, which depending on
/// the message type performs some additional checks. The
/// `PreProcessor` may also discard the message
/// `TaskValidator` may also discard the message
///
/// 3. Finally, the message is handled by the `StateMachine` service.
#[derive(Debug, Clone)]
pub struct PetMessageService<MessageParser, PreProcessor, StateMachine> {
pub struct PetMessageService<MessageParser, TaskValidator, StateMachine> {
message_parser: MessageParser,
pre_processor: PreProcessor,
task_validator: TaskValidator,
state_machine: StateMachine,
}

impl<MP, PP, SM>
PetMessageService<TracedMessageParser<MP>, TracedPreProcessor<PP>, TracedStateMachine<SM>>
impl<MP, TV, SM>
PetMessageService<TracedMessageParser<MP>, TracedTaskValidator<TV>, TracedStateMachine<SM>>
where
MP: Service<MessageParserRequest<Vec<u8>>, Response = MessageParserResponse>,
PP: Service<PreProcessorRequest, Response = PreProcessorResponse>,
TV: Service<TaskValidatorRequest, Response = TaskValidatorResponse>,
SM: Service<StateMachineRequest, Response = StateMachineResponse>,
{
/// Instantiate a new [`PetMessageService`] with the given sub-services
pub fn new(message_parser: MP, pre_processor: PP, state_machine: SM) -> Self {
pub fn new(message_parser: MP, task_validator: TV, state_machine: SM) -> Self {
Self {
message_parser: with_tracing(message_parser),
pre_processor: with_tracing(pre_processor),
task_validator: with_tracing(task_validator),
state_machine: with_tracing(state_machine),
}
}
}

use crate::utils::Traceable;
use tracing::Span;
use xaynet_core::message::{Payload, ToBytes};
use xaynet_core::message::Payload;

impl Traceable for Message {
fn make_span(&self) -> Span {
let message_type = match self.payload {
Payload::Sum(_) => "sum",
Payload::Update(_) => "update",
Payload::Sum2(_) => "sum2",
Payload::Chunk(_) => "chunk",
};
error_span!(
"Message",
Expand Down
Loading

0 comments on commit 9287168

Please sign in to comment.