diff --git a/rust/xaynet-server/src/services/messages/message_parser.rs b/rust/xaynet-server/src/services/messages/message_parser.rs index 1367d48d5..0602a42bc 100644 --- a/rust/xaynet-server/src/services/messages/message_parser.rs +++ b/rust/xaynet-server/src/services/messages/message_parser.rs @@ -1,5 +1,6 @@ -use std::{pin::Pin, sync::Arc, task::Poll}; +use std::{convert::TryInto, pin::Pin, sync::Arc, task::Poll}; +use anyhow::Context as _; use derive_more::From; use futures::{ future::{self, Either, Future}, @@ -11,8 +12,8 @@ use tokio::sync::oneshot; use tower::Service; use tracing::Span; use xaynet_core::{ - crypto::{ByteObject, EncryptKeyPair, Signature}, - message::{DecodeError, FromBytes, Header, Message, Payload, Sum, Sum2, Tag, ToBytes, Update}, + crypto::EncryptKeyPair, + message::{DecodeError, Message, MessageBuffer, Tag}, }; use crate::{ @@ -167,20 +168,27 @@ impl Handler { info!("decrypting message"); let raw = self.decrypt(&data.0.as_ref())?; - info!("parsing message header"); - let header = self.parse_header(raw.as_slice())?; + let buf = MessageBuffer::new(&raw).map_err(MessageParserError::Parsing)?; info!("filtering message based on the current phase"); - self.phase_filter(header.tag)?; + let tag = buf + .tag() + .try_into() + .context("failed to parse message tag field") + .map_err(MessageParserError::Parsing)?; + self.phase_filter(tag)?; info!("verifying the message signature"); - self.verify_signature(raw.as_slice(), &header)?; + buf.check_signature().map_err(|e| { + warn!("invalid message signature: {:?}", e); + MessageParserError::InvalidMessageSignature + })?; - info!("parsing the message payload"); - let payload = self.parse_payload(raw.as_slice(), &header)?; + info!("parsing the message"); + let message = Message::from_bytes(&raw).map_err(MessageParserError::Parsing)?; info!("done pre-processing the message"); - Ok(Message { header, payload }) + Ok(message) } /// Decrypt the given payload with the coordinator secret key @@ -192,12 +200,6 @@ impl Handler { .map_err(|_| MessageParserError::Decrypt)?) } - /// Attempt to parse the message header from the raw message - fn parse_header(&self, raw_message: &[u8]) -> Result { - Ok(Header::from_bytes(&&raw_message[Signature::LENGTH..]) - .map_err(MessageParserError::Parsing)?) - } - /// Reject messages that cannot be handled by the coordinator in /// the current phase fn phase_filter(&self, tag: Tag) -> Result<(), MessageParserError> { @@ -214,50 +216,4 @@ impl Handler { } } } - - /// Verify the integrity of the given message by checking the - /// signature embedded in the header. - fn verify_signature( - &self, - raw_message: &[u8], - header: &Header, - ) -> Result<(), MessageParserError> { - // UNWRAP_SAFE: We already parsed the header, so we now the - // message is at least as big as: signature length + header - // length - let signature = Signature::from_slice(&raw_message[..Signature::LENGTH]).unwrap(); - let bytes = &raw_message[Signature::LENGTH..]; - if header.participant_pk.verify_detached(&signature, bytes) { - Ok(()) - } else { - Err(MessageParserError::InvalidMessageSignature) - } - } - - /// Parse the payload of the given message - fn parse_payload( - &self, - raw_message: &[u8], - header: &Header, - ) -> Result { - let bytes = &raw_message[header.buffer_length() + Signature::LENGTH..]; - match header.tag { - Tag::Sum => { - let parsed = Sum::from_bytes(&bytes) - .map_err(|e| MessageParserError::Parsing(e.context("invalid sum payload")))?; - Ok(Payload::Sum(parsed)) - } - Tag::Update => { - let parsed = Update::from_bytes(&bytes).map_err(|e| { - MessageParserError::Parsing(e.context("invalid update payload")) - })?; - Ok(Payload::Update(parsed)) - } - Tag::Sum2 => { - let parsed = Sum2::from_bytes(&bytes) - .map_err(|e| MessageParserError::Parsing(e.context("invalid sum2 payload")))?; - Ok(Payload::Sum2(parsed)) - } - } - } } diff --git a/rust/xaynet-server/src/services/messages/mod.rs b/rust/xaynet-server/src/services/messages/mod.rs index 5cc05852d..71568659c 100644 --- a/rust/xaynet-server/src/services/messages/mod.rs +++ b/rust/xaynet-server/src/services/messages/mod.rs @@ -4,8 +4,8 @@ //! There are multiple such services and the [`PetMessageHandler`] //! trait provides a single unifying interface for all of these. mod message_parser; -mod pre_processor; mod state_machine; +mod task_validator; pub use self::{ message_parser::{ @@ -14,18 +14,18 @@ pub use self::{ MessageParserResponse, MessageParserService, }, - pre_processor::{ - PreProcessorError, - PreProcessorRequest, - PreProcessorResponse, - PreProcessorService, - }, state_machine::{ StateMachineError, StateMachineRequest, StateMachineResponse, StateMachineService, }, + task_validator::{ + TaskValidatorError, + TaskValidatorRequest, + TaskValidatorResponse, + TaskValidatorService, + }, }; use xaynet_core::message::Message; @@ -43,7 +43,7 @@ use thiserror::Error; use tower::Service; type TracedMessageParser = TracedService>>; -type TracedPreProcessor = TracedService; +type TracedTaskValidator = TracedService; type TracedStateMachine = TracedService; /// Error returned by the [`PetMessageHandler`] methods. @@ -53,7 +53,7 @@ pub enum PetMessageError { Parser(MessageParserError), #[error("failed to pre-process message: {0}")] - PreProcessor(PreProcessorError), + TaskValidator(TaskValidatorError), #[error("state machine failed to handle message: {0}")] StateMachine(StateMachineError), @@ -63,7 +63,7 @@ pub enum PetMessageError { } /// A single interface for all the PET message processing sub-services -/// ([`MessageParserService`], [`PreProcessorService`] and +/// ([`MessageParserService`], [`TaskValidatorService`] and /// [`StateMachineService`]). #[async_trait] pub trait PetMessageHandler: Send { @@ -78,7 +78,7 @@ pub trait PetMessageHandler: Send { let message = self.call_parser(req).await?; let req = Request::from_parts(metadata.clone(), message); - let message = self.call_pre_processor(req).await?; + let message = self.call_task_validator(req).await?; let req = Request::from_parts(metadata, message); Ok(self.call_state_machine(req).await?) @@ -91,9 +91,9 @@ pub trait PetMessageHandler: Send { ) -> Result; /// Pre-process a PET message - async fn call_pre_processor( + async fn call_task_validator( &mut self, - message: PreProcessorRequest, + message: TaskValidatorRequest, ) -> Result; /// Have a PET message processed by the state machine @@ -104,7 +104,7 @@ pub trait PetMessageHandler: Send { } #[async_trait] -impl PetMessageHandler for PetMessageService +impl PetMessageHandler for PetMessageService where Self: Send + Sync + 'static, @@ -113,9 +113,9 @@ where >>>::Error: Into>, - PP: Service + Send + 'static, - >::Future: Send + 'static, - >::Error: + TV: Service + Send + 'static, + >::Future: Send + 'static, + >::Error: Into>, SM: Service + Send + 'static, @@ -146,18 +146,23 @@ where .map_err(PetMessageError::Parser) } - async fn call_pre_processor( + async fn call_task_validator( &mut self, - message: PreProcessorRequest, + message: TaskValidatorRequest, ) -> Result { - poll_fn(|cx| >::poll_ready(&mut self.pre_processor, cx)) - .await - .map_err(|e| PetMessageError::ServiceError(Into::into(e)))?; + poll_fn(|cx| { + >::poll_ready(&mut self.task_validator, cx) + }) + .await + .map_err(|e| PetMessageError::ServiceError(Into::into(e)))?; - >::call(&mut self.pre_processor, message.map(Into::into)) - .await - .map_err(|e| PetMessageError::ServiceError(Into::into(e)))? - .map_err(PetMessageError::PreProcessor) + >::call( + &mut self.task_validator, + message.map(Into::into), + ) + .await + .map_err(|e| PetMessageError::ServiceError(Into::into(e)))? + .map_err(PetMessageError::TaskValidator) } async fn call_state_machine( @@ -184,30 +189,30 @@ where /// encrypted message) goes through the `MessageParser` service, /// which decrypt the message, validates it, and parses it /// -/// 2. The message is passed to the `PreProcessor`, which depending on +/// 2. The message is passed to the `TaskValidator`, which depending on /// the message type performs some additional checks. The -/// `PreProcessor` may also discard the message +/// `TaskValidator` may also discard the message /// /// 3. Finally, the message is handled by the `StateMachine` service. #[derive(Debug, Clone)] -pub struct PetMessageService { +pub struct PetMessageService { message_parser: MessageParser, - pre_processor: PreProcessor, + task_validator: TaskValidator, state_machine: StateMachine, } -impl - PetMessageService, TracedPreProcessor, TracedStateMachine> +impl + PetMessageService, TracedTaskValidator, TracedStateMachine> where MP: Service>, Response = MessageParserResponse>, - PP: Service, + TV: Service, SM: Service, { /// Instantiate a new [`PetMessageService`] with the given sub-services - pub fn new(message_parser: MP, pre_processor: PP, state_machine: SM) -> Self { + pub fn new(message_parser: MP, task_validator: TV, state_machine: SM) -> Self { Self { message_parser: with_tracing(message_parser), - pre_processor: with_tracing(pre_processor), + task_validator: with_tracing(task_validator), state_machine: with_tracing(state_machine), } } @@ -215,7 +220,7 @@ where use crate::utils::Traceable; use tracing::Span; -use xaynet_core::message::{Payload, ToBytes}; +use xaynet_core::message::Payload; impl Traceable for Message { fn make_span(&self) -> Span { @@ -223,6 +228,7 @@ impl Traceable for Message { Payload::Sum(_) => "sum", Payload::Update(_) => "update", Payload::Sum2(_) => "sum2", + Payload::Chunk(_) => "chunk", }; error_span!( "Message", diff --git a/rust/xaynet-server/src/services/messages/pre_processor/mod.rs b/rust/xaynet-server/src/services/messages/pre_processor/mod.rs deleted file mode 100644 index de9c7efde..000000000 --- a/rust/xaynet-server/src/services/messages/pre_processor/mod.rs +++ /dev/null @@ -1,127 +0,0 @@ -mod sum; -pub use sum::SumPreProcessorService; - -mod update; -pub use update::UpdatePreProcessorService; - -mod sum2; -pub use sum2::Sum2PreProcessorService; - -use std::{pin::Pin, task::Poll}; - -use futures::{ - future::{self, Future}, - task::Context, -}; -use thiserror::Error; -use tower::Service; -use xaynet_core::{ - common::RoundParameters, - message::{Message, Payload}, -}; - -use crate::{ - state_machine::{ - events::{Event, EventListener, EventSubscriber}, - phases::PhaseName, - }, - utils::request::Request, -}; - -/// A service for performing sanity checks and preparing incoming -/// requests to be handled by the state machine. -pub struct PreProcessorService { - params_listener: EventListener, - /// A stream that receives phase updates - phase_listener: EventListener, - /// Latest phase event the service has received - latest_phase_event: Event, - /// Inner service to handle sum messages - sum: SumPreProcessorService, - /// Inner service to handle update messages - update: UpdatePreProcessorService, - /// Inner service to handle sum2 messages - sum2: Sum2PreProcessorService, -} - -impl PreProcessorService { - pub fn new(subscriber: &EventSubscriber) -> Self { - Self { - params_listener: subscriber.params_listener(), - phase_listener: subscriber.phase_listener(), - latest_phase_event: subscriber.phase_listener().get_latest(), - sum: SumPreProcessorService, - update: UpdatePreProcessorService, - sum2: Sum2PreProcessorService, - } - } -} - -/// Request type for [`PreProcessorService`] -pub type PreProcessorRequest = Request; - -/// Response type for [`PreProcessorService`] -pub type PreProcessorResponse = Result; - -impl Service for PreProcessorService { - type Response = PreProcessorResponse; - type Error = std::convert::Infallible; - - #[allow(clippy::type_complexity)] - type Future = - Pin> + 'static + Send + Sync>>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.latest_phase_event = self.phase_listener.get_latest(); - match self.latest_phase_event.event { - PhaseName::Sum => self.sum.poll_ready(cx), - PhaseName::Update => self.update.poll_ready(cx), - PhaseName::Sum2 => self.sum2.poll_ready(cx), - _ => Poll::Ready(Ok(())), - } - } - - fn call(&mut self, req: PreProcessorRequest) -> Self::Future { - let Message { header, payload } = req.into_inner(); - match (self.latest_phase_event.event, payload) { - (PhaseName::Sum, Payload::Sum(sum)) => { - let req = (header, sum, self.params_listener.get_latest().event); - let fut = self.sum.call(req); - Box::pin(fut) - } - (PhaseName::Update, Payload::Update(update)) => { - let req = (header, update, self.params_listener.get_latest().event); - let fut = self.update.call(req); - Box::pin(fut) - } - (PhaseName::Sum2, Payload::Sum2(sum2)) => { - let req = (header, sum2, self.params_listener.get_latest().event); - let fut = self.sum2.call(req); - Box::pin(fut) - } - _ => Box::pin(future::ready(Ok(Err(PreProcessorError::UnexpectedMessage)))), - } - } -} - -/// Error type for [`PreProcessorService`] -#[derive(Error, Debug)] -pub enum PreProcessorError { - #[error("Invalid sum signature")] - InvalidSumSignature, - - #[error("Invalid update signature")] - InvalidUpdateSignature, - - #[error("Not eligible for sum task")] - NotSumEligible, - - #[error("Not eligible for update task")] - NotUpdateEligible, - - #[error("The message was rejected because the coordinator did not expect it")] - UnexpectedMessage, - - #[error("Internal error")] - InternalError, -} diff --git a/rust/xaynet-server/src/services/messages/pre_processor/sum.rs b/rust/xaynet-server/src/services/messages/pre_processor/sum.rs deleted file mode 100644 index 8c7806610..000000000 --- a/rust/xaynet-server/src/services/messages/pre_processor/sum.rs +++ /dev/null @@ -1,79 +0,0 @@ -use std::task::Poll; - -use futures::{ - future::{ready, Ready}, - task::Context, -}; -use tower::Service; -use xaynet_core::{ - common::RoundParameters, - crypto::ByteObject, - message::{Header, Message, Payload, Sum}, -}; - -use crate::services::messages::pre_processor::{PreProcessorError, PreProcessorResponse}; - -/// Request type for [`SumPreProcessorService`] -pub type SumRequest = (Header, Sum, RoundParameters); - -/// A service for performing sanity checks and preparing a sum request -/// to be handled by the state machine. At the moment, this is limited -/// to verifying the participant's eligibility for the sum task. -#[derive(Debug, Clone)] -pub struct SumPreProcessorService; - -impl Service for SumPreProcessorService { - type Response = PreProcessorResponse; - type Error = ::std::convert::Infallible; - type Future = Ready>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, (header, message, params): SumRequest) -> Self::Future { - let pre_processor = SumPreProcessor { - header, - message, - params, - }; - ready(Ok(pre_processor.call())) - } -} - -struct SumPreProcessor { - header: Header, - message: Sum, - params: RoundParameters, -} - -impl SumPreProcessor { - fn call(self) -> PreProcessorResponse { - if !self.has_valid_sum_signature() { - return Err(PreProcessorError::InvalidSumSignature); - } - if !self.is_eligible_for_sum_task() { - return Err(PreProcessorError::NotSumEligible); - } - - let Self { - header, message, .. - } = self; - Ok(Message { - header, - payload: Payload::Sum(message), - }) - } - /// Check whether this request contains a valid sum signature - fn has_valid_sum_signature(&self) -> bool { - let seed = &self.params.seed; - let signature = &self.message.sum_signature; - let pk = &self.header.participant_pk; - pk.verify_detached(&signature, &[seed.as_slice(), b"sum"].concat()) - } - - /// Check whether this request comes from a participant that is eligible for the sum task. - fn is_eligible_for_sum_task(&self) -> bool { - self.message.sum_signature.is_eligible(self.params.sum) - } -} diff --git a/rust/xaynet-server/src/services/messages/pre_processor/sum2.rs b/rust/xaynet-server/src/services/messages/pre_processor/sum2.rs deleted file mode 100644 index e603719c7..000000000 --- a/rust/xaynet-server/src/services/messages/pre_processor/sum2.rs +++ /dev/null @@ -1,80 +0,0 @@ -use std::task::Poll; - -use futures::{ - future::{ready, Ready}, - task::Context, -}; -use tower::Service; -use xaynet_core::{ - common::RoundParameters, - crypto::ByteObject, - message::{Header, Message, Payload, Sum2}, -}; - -use crate::services::messages::pre_processor::{PreProcessorError, PreProcessorResponse}; - -/// Request type for [`SumPreProcessorService`] -pub type Sum2Request = (Header, Sum2, RoundParameters); - -/// A service for performing sanity checks and preparing a sum2 -/// request to be handled by the state machine. At the moment, this is -/// limited to verifying the participant's eligibility for the sum -/// task. -#[derive(Clone, Debug)] -pub struct Sum2PreProcessorService; - -impl Service for Sum2PreProcessorService { - type Response = PreProcessorResponse; - type Error = std::convert::Infallible; - type Future = Ready>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, (header, message, params): Sum2Request) -> Self::Future { - let pre_processor = Sum2PreProcessor { - header, - message, - params, - }; - ready(Ok(pre_processor.call())) - } -} - -struct Sum2PreProcessor { - header: Header, - message: Sum2, - params: RoundParameters, -} - -impl Sum2PreProcessor { - fn call(self) -> PreProcessorResponse { - if !self.has_valid_sum_signature() { - return Err(PreProcessorError::InvalidSumSignature); - } - if !self.is_eligible_for_sum_task() { - return Err(PreProcessorError::NotSumEligible); - } - - let Self { - header, message, .. - } = self; - Ok(Message { - header, - payload: Payload::Sum2(message), - }) - } - /// Check whether this request contains a valid sum signature - fn has_valid_sum_signature(&self) -> bool { - let seed = &self.params.seed; - let signature = &self.message.sum_signature; - let pk = &self.header.participant_pk; - pk.verify_detached(&signature, &[seed.as_slice(), b"sum"].concat()) - } - - /// Check whether this request comes from a participant that is eligible for the sum task. - fn is_eligible_for_sum_task(&self) -> bool { - self.message.sum_signature.is_eligible(self.params.sum) - } -} diff --git a/rust/xaynet-server/src/services/messages/pre_processor/update.rs b/rust/xaynet-server/src/services/messages/pre_processor/update.rs deleted file mode 100644 index d8a0529a6..000000000 --- a/rust/xaynet-server/src/services/messages/pre_processor/update.rs +++ /dev/null @@ -1,114 +0,0 @@ -use std::task::Poll; - -use futures::{ - future::{ready, Ready}, - task::Context, -}; -use tower::Service; -use xaynet_core::{ - common::RoundParameters, - crypto::ByteObject, - message::{Header, Message, Payload, Update}, -}; - -use crate::services::messages::pre_processor::{PreProcessorError, PreProcessorResponse}; - -/// Request type for [`UpdatePreProcessorService`] -pub type UpdateRequest = (Header, Update, RoundParameters); - -/// A service for performing sanity checks and preparing an update -/// request to be handled by the state machine. At the moment, this is -/// limited to verifying the participant's eligibility for the update -/// task. -#[derive(Clone, Debug)] -pub struct UpdatePreProcessorService; - -impl Service for UpdatePreProcessorService { - type Response = PreProcessorResponse; - type Error = std::convert::Infallible; - type Future = Ready>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, (header, message, params): UpdateRequest) -> Self::Future { - let pre_processor = UpdatePreProcessor { - header, - message, - params, - }; - ready(Ok(pre_processor.call())) - } -} - -struct UpdatePreProcessor { - header: Header, - message: Update, - params: RoundParameters, -} - -impl UpdatePreProcessor { - fn call(self) -> Result { - debug!("checking sum signature"); - if !self.has_valid_sum_signature() { - debug!("invalid sum signature"); - return Err(PreProcessorError::InvalidSumSignature); - } - - debug!("checking sum task eligibility"); - if self.is_eligible_for_sum_task() { - debug!("participant is eligible for the sum task, so is not eligible for update task"); - return Err(PreProcessorError::NotUpdateEligible); - } - - debug!("checking update signature"); - if !self.has_valid_update_signature() { - debug!("invalid update signature"); - return Err(PreProcessorError::InvalidUpdateSignature); - } - - debug!("checking update task eligibility"); - if !self.is_eligible_for_update_task() { - debug!("not eligible for update task"); - return Err(PreProcessorError::NotUpdateEligible); - } - - let Self { - header, message, .. - } = self; - Ok(Message { - header, - payload: Payload::Update(message), - }) - } - - /// Check whether this request contains a valid sum signature - fn has_valid_sum_signature(&self) -> bool { - let seed = &self.params.seed; - let signature = &self.message.sum_signature; - let pk = &self.header.participant_pk; - pk.verify_detached(&signature, &[seed.as_slice(), b"sum"].concat()) - } - - /// Check whether this request comes from a participant that is eligible for the sum task. - fn is_eligible_for_sum_task(&self) -> bool { - self.message.sum_signature.is_eligible(self.params.sum) - } - - /// Check whether this request contains a valid update signature - fn has_valid_update_signature(&self) -> bool { - let seed = &self.params.seed; - let signature = &self.message.update_signature; - let pk = &self.header.participant_pk; - pk.verify_detached(&signature, &[seed.as_slice(), b"update"].concat()) - } - - /// Check whether this request comes from a participant that is - /// eligible for the update task. - fn is_eligible_for_update_task(&self) -> bool { - self.message - .update_signature - .is_eligible(self.params.update) - } -} diff --git a/rust/xaynet-server/src/services/messages/task_validator.rs b/rust/xaynet-server/src/services/messages/task_validator.rs new file mode 100644 index 000000000..be6c70e29 --- /dev/null +++ b/rust/xaynet-server/src/services/messages/task_validator.rs @@ -0,0 +1,114 @@ +use std::task::Poll; + +use futures::{future, task::Context}; +use thiserror::Error; +use tower::Service; +use xaynet_core::{ + common::RoundParameters, + crypto::ByteObject, + message::{Message, Payload}, +}; + +use crate::{ + state_machine::events::{EventListener, EventSubscriber}, + utils::request::Request, +}; + +/// A service for performing sanity checks and preparing incoming +/// requests to be handled by the state machine. +pub struct TaskValidatorService { + params_listener: EventListener, +} + +impl TaskValidatorService { + pub fn new(subscriber: &EventSubscriber) -> Self { + Self { + params_listener: subscriber.params_listener(), + } + } +} + +/// Request type for [`TaskValidatorService`] +pub type TaskValidatorRequest = Request; + +/// Response type for [`TaskValidatorService`] +pub type TaskValidatorResponse = Result; + +impl Service for TaskValidatorService { + type Response = TaskValidatorResponse; + type Error = std::convert::Infallible; + type Future = future::Ready>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: TaskValidatorRequest) -> Self::Future { + let message = req.into_inner(); + let (sum_signature, update_signature) = match message.payload { + Payload::Sum(ref sum) => (sum.sum_signature, None), + Payload::Update(ref update) => ( + update.sum_signature, + Some(update.update_signature), + ), + Payload::Sum2(ref sum2) => (sum2.sum_signature, None), + _ => return future::ready(Ok(Err(TaskValidatorError::UnexpectedMessage))), + }; + let params = self.params_listener.get_latest().event; + let seed = params.seed.as_slice(); + + // Check whether the participant is eligible for the sum task + let has_valid_sum_signature = message + .participant_pk + .verify_detached(&sum_signature, &[seed, b"sum"].concat()); + let is_summer = has_valid_sum_signature && sum_signature.is_eligible(params.sum); + + // Check whether the participant is eligible for the update task + let has_valid_update_signature = update_signature + .map(|sig| { + message + .participant_pk + .verify_detached(&sig, &[seed, b"update"].concat()) + }) + .unwrap_or(false); + let is_updater = !is_summer + && has_valid_update_signature + && update_signature + .map(|sig| sig.is_eligible(params.update)) + .unwrap_or(false); + + match message.payload { + Payload::Sum(_) | Payload::Sum2(_) => { + if is_summer { + future::ready(Ok(Ok(message))) + } else { + future::ready(Ok(Err(TaskValidatorError::NotSumEligible))) + } + } + Payload::Update(_) => { + if is_updater { + future::ready(Ok(Ok(message))) + } else { + future::ready(Ok(Err(TaskValidatorError::NotUpdateEligible))) + } + } + _ => future::ready(Ok(Err(TaskValidatorError::UnexpectedMessage))), + } + } +} + +/// Error type for [`TaskValidatorService`] +#[derive(Error, Debug)] +pub enum TaskValidatorError { + #[error("Not eligible for sum task")] + NotSumEligible, + + #[error("Not eligible for update task")] + NotUpdateEligible, + + #[error("The message was rejected because the coordinator did not expect it")] + UnexpectedMessage, + + #[error("Internal error")] + InternalError, +} diff --git a/rust/xaynet-server/src/services/mod.rs b/rust/xaynet-server/src/services/mod.rs index 34275710a..dfa5a3a27 100644 --- a/rust/xaynet-server/src/services/mod.rs +++ b/rust/xaynet-server/src/services/mod.rs @@ -10,7 +10,7 @@ //! - [`SumDictService`]: for fetching the sum dictionary //! - the services for handling PET messages from the participant: //! - [`MessageParserService`]: decrypt and parses incoming message -//! - [`PreProcessorService`]: performs sanity checks on the messages +//! - [`TaskValidator`]: performs sanity checks on the messages //! (verify the task signatures, etc.) //! - [`StateMachineService`]: pass the messages down to the state machine //! for actual processing @@ -46,8 +46,8 @@ use crate::{ messages::{ MessageParserService, PetMessageService, - PreProcessorService, StateMachineService, + TaskValidatorService, }, }, state_machine::{events::EventSubscriber, requests::RequestSender}, @@ -121,7 +121,7 @@ pub fn message_handler( let pre_processor = ServiceBuilder::new() .buffer(100) .concurrency_limit(100) - .service(PreProcessorService::new(event_subscriber)); + .service(TaskValidatorService::new(event_subscriber)); let state_machine = ServiceBuilder::new() .buffer(100) diff --git a/rust/xaynet-server/src/services/tests/messages/message_parser.rs b/rust/xaynet-server/src/services/tests/messages/message_parser.rs index 97534365b..55fff4550 100644 --- a/rust/xaynet-server/src/services/tests/messages/message_parser.rs +++ b/rust/xaynet-server/src/services/tests/messages/message_parser.rs @@ -74,7 +74,13 @@ async fn test_valid_request() { publisher.broadcast_phase(PhaseName::Sum); // Call the service - let resp = task.call(req).await.unwrap().unwrap(); + let mut resp = task.call(req).await.unwrap().unwrap(); + // The signature should be set. However in `message` it's not been + // computed, so we just check that it's there, then set it to + // `None` in `resp` + assert!(resp.signature.is_some()); + resp.signature = None; + // Now the comparison should work assert_eq!(resp, message); } diff --git a/rust/xaynet-server/src/services/tests/messages/mod.rs b/rust/xaynet-server/src/services/tests/messages/mod.rs index 9b3b4c42b..71445c1d1 100644 --- a/rust/xaynet-server/src/services/tests/messages/mod.rs +++ b/rust/xaynet-server/src/services/tests/messages/mod.rs @@ -1,2 +1,2 @@ mod message_parser; -mod pre_processor; +mod task_validator; diff --git a/rust/xaynet-server/src/services/tests/messages/pre_processor.rs b/rust/xaynet-server/src/services/tests/messages/task_validator.rs similarity index 50% rename from rust/xaynet-server/src/services/tests/messages/pre_processor.rs rename to rust/xaynet-server/src/services/tests/messages/task_validator.rs index 94f04cab0..3bad27795 100644 --- a/rust/xaynet-server/src/services/tests/messages/pre_processor.rs +++ b/rust/xaynet-server/src/services/tests/messages/task_validator.rs @@ -4,7 +4,7 @@ use xaynet_core::message::Message; use crate::{ services::{ - messages::{PreProcessorError, PreProcessorRequest, PreProcessorService}, + messages::{TaskValidatorError, TaskValidatorRequest, TaskValidatorService}, tests::utils, }, state_machine::{ @@ -14,13 +14,13 @@ use crate::{ utils::Request, }; -fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn) { +fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn) { let (publisher, subscriber) = utils::new_event_channels(); - let task = Spawn::new(PreProcessorService::new(&subscriber)); + let task = Spawn::new(TaskValidatorService::new(&subscriber)); (publisher, subscriber, task) } -fn make_req(message: Message) -> PreProcessorRequest { +fn make_req(message: Message) -> TaskValidatorRequest { Request::new(message) } @@ -62,36 +62,7 @@ async fn test_sum_not_eligible() { assert_ready!(task.poll_ready()).unwrap(); let err = task.call(req).await.unwrap().unwrap_err(); match err { - PreProcessorError::NotSumEligible => {} - _ => panic!("expected PreProcessorError::NotSumEligible got {:?}", err), - } -} - -// This is a corner case which should almost never happen but is worth -// testing: in `poll_ready`, the service checks the current phase, and -// calls `poll_ready` on the appropriate service based on that. Then -// the request is processed by `call` but if the phase has changed in -// the meantime, we want to reject the request, because the service -// that should process it is not the one on which we called -// `poll_ready` previously. -#[tokio::test] -async fn test_phase_change_between_poll_ready_and_call() { - let (mut publisher, subscriber, mut task) = spawn_svc(); - // call poll_ready here - assert_ready!(task.poll_ready()).unwrap(); - - let round_params = subscriber.params_listener().get_latest().event; - let (message, _, _) = utils::new_sum_message(&round_params); - let req = make_req(message.clone()); - - publisher.broadcast_phase(PhaseName::Sum); - - let err = task.call(req).await.unwrap().unwrap_err(); - match err { - PreProcessorError::UnexpectedMessage => {} - _ => panic!( - "expected PreProcessorError::UnexpectedMessage got {:?}", - err - ), + TaskValidatorError::NotSumEligible => {} + _ => panic!("expected TaskValidatorError::NotSumEligible got {:?}", err), } } diff --git a/rust/xaynet-server/src/services/tests/utils.rs b/rust/xaynet-server/src/services/tests/utils.rs index ae8d5dde4..4568ff79a 100644 --- a/rust/xaynet-server/src/services/tests/utils.rs +++ b/rust/xaynet-server/src/services/tests/utils.rs @@ -1,7 +1,7 @@ use xaynet_core::{ common::{RoundParameters, RoundSeed}, crypto::{ByteObject, EncryptKeyPair, PublicEncryptKey, SigningKeyPair}, - message::{Message, MessageSeal, Sum}, + message::{Message, Sum}, SumParticipantEphemeralPublicKey, }; @@ -38,12 +38,16 @@ pub fn new_sum_message( let sum_signature = participant_signing_keys .secret .sign_detached(&[round_params.seed.as_slice(), b"sum"].concat()); - let payload = Sum { - sum_signature, - ephm_pk: participant_ephm_pk, - }; - let message = Message::new_sum(participant_signing_keys.public.clone(), payload); + let message = Message { + signature: None, + participant_pk: participant_signing_keys.public.clone(), + payload: Sum { + sum_signature, + ephm_pk: participant_ephm_pk.clone(), + } + .into(), + }; (message, participant_ephm_pk, participant_signing_keys) } @@ -55,11 +59,7 @@ pub fn encrypt_message( round_params: &RoundParameters, participant_signing_keys: &SigningKeyPair, ) -> Vec { - let seal = MessageSeal { - recipient_pk: &round_params.pk, - sender_sk: &participant_signing_keys.secret, - }; - let encrypted_message = seal.seal(&message); - - encrypted_message + let mut buf = vec![0; message.buffer_length()]; + message.to_bytes(&mut buf, &participant_signing_keys.secret); + round_params.pk.encrypt(&buf[..]) } diff --git a/rust/xaynet-server/src/state_machine/requests.rs b/rust/xaynet-server/src/state_machine/requests.rs index 24a5228ac..2278c976b 100644 --- a/rust/xaynet-server/src/state_machine/requests.rs +++ b/rust/xaynet-server/src/state_machine/requests.rs @@ -88,10 +88,10 @@ impl Traceable for StateMachineRequest { impl From for StateMachineRequest { fn from(message: Message) -> Self { - let Message { header, payload } = message; - match payload { + let participant_pk = message.participant_pk; + match message.payload { Payload::Sum(sum) => StateMachineRequest::Sum(SumRequest { - participant_pk: header.participant_pk, + participant_pk, ephm_pk: sum.ephm_pk, }), Payload::Update(update) => { @@ -102,17 +102,18 @@ impl From for StateMachineRequest { .. } = update; StateMachineRequest::Update(UpdateRequest { - participant_pk: header.participant_pk, + participant_pk, local_seed_dict, masked_model, masked_scalar, }) } Payload::Sum2(sum2) => StateMachineRequest::Sum2(Sum2Request { - participant_pk: header.participant_pk, + participant_pk, model_mask: sum2.model_mask, scalar_mask: sum2.scalar_mask, }), + Payload::Chunk(_) => unimplemented!(), } } }