diff --git a/rust/xaynet-client/src/api/in_memory.rs b/rust/xaynet-client/src/api/in_memory.rs index 90b61a1fb..e8148fe69 100644 --- a/rust/xaynet-client/src/api/in_memory.rs +++ b/rust/xaynet-client/src/api/in_memory.rs @@ -7,24 +7,27 @@ use xaynet_core::{ SumParticipantPublicKey, UpdateSeedDict, }; -use xaynet_server::services::{FetchError, Fetcher, PetMessageError, PetMessageHandler}; +use xaynet_server::services::{ + fetchers::{FetchError, Fetcher}, + messages::{PetMessageHandler, ServiceError}, +}; /// A client that communicates with the coordinator's API via /// in-memory channels. pub struct InMemoryApiClient { fetcher: Box, - message_handler: Box, + message_handler: PetMessageHandler, } impl InMemoryApiClient { #[allow(dead_code)] pub fn new( fetcher: impl Fetcher + 'static + Send + Sync, - message_handler: impl PetMessageHandler + 'static + Send + Sync, + message_handler: PetMessageHandler, ) -> Self { Self { fetcher: Box::new(fetcher), - message_handler: Box::new(message_handler), + message_handler: message_handler, } } } @@ -33,7 +36,7 @@ impl InMemoryApiClient { #[derive(Debug, Error)] pub enum InMemoryApiClientError { #[error("a PET message could not be processed by the coordinator: {0}")] - Message(#[from] PetMessageError), + Message(#[from] ServiceError), #[error("failed to fetch data from the coordinator: {0}")] Fetch(#[from] FetchError), diff --git a/rust/xaynet-client/src/lib.rs b/rust/xaynet-client/src/lib.rs index 52c93fb95..690ee6f6d 100644 --- a/rust/xaynet-client/src/lib.rs +++ b/rust/xaynet-client/src/lib.rs @@ -50,7 +50,7 @@ use std::time::Duration; use thiserror::Error; use tokio::time; -use xaynet_core::{crypto::ByteObject, mask::Model, CoordinatorPublicKey, InitError, PetError}; +use xaynet_core::{crypto::ByteObject, mask::Model, CoordinatorPublicKey, InitError}; #[doc(hidden)] pub mod mobile_client; @@ -72,6 +72,14 @@ pub enum CachedModel { I64(Vec), } +#[derive(Debug, Error)] +pub enum PetError { + #[error("Invalid mask")] + InvalidMask, + #[error("Invalid model")] + InvalidModel, +} + #[derive(Debug, Error)] /// Client-side errors pub enum ClientError { diff --git a/rust/xaynet-client/src/mobile_client/client.rs b/rust/xaynet-client/src/mobile_client/client.rs index 89545245a..6721f8507 100644 --- a/rust/xaynet-client/src/mobile_client/client.rs +++ b/rust/xaynet-client/src/mobile_client/client.rs @@ -12,7 +12,9 @@ use crate::{ ClientError, }; use derive_more::From; -use xaynet_core::{common::RoundParameters, crypto::ByteObject, mask::Model, InitError, PetError}; +use xaynet_core::{common::RoundParameters, crypto::ByteObject, mask::Model, InitError}; + +use crate::PetError; #[async_trait] pub trait LocalModel { diff --git a/rust/xaynet-client/src/mobile_client/participant/sum2.rs b/rust/xaynet-client/src/mobile_client/participant/sum2.rs index 641f9a142..1f6036e7c 100644 --- a/rust/xaynet-client/src/mobile_client/participant/sum2.rs +++ b/rust/xaynet-client/src/mobile_client/participant/sum2.rs @@ -5,11 +5,13 @@ use xaynet_core::{ CoordinatorPublicKey, ParticipantPublicKey, ParticipantTaskSignature, - PetError, SumParticipantEphemeralPublicKey, SumParticipantEphemeralSecretKey, UpdateSeedDict, }; + +use crate::PetError; + #[derive(Serialize, Deserialize, Clone)] pub struct Sum2 { ephm_pk: SumParticipantEphemeralPublicKey, @@ -70,7 +72,10 @@ impl Participant { fn get_seeds(&self, seed_dict: &UpdateSeedDict) -> Result, PetError> { seed_dict .values() - .map(|seed| seed.decrypt(&self.inner.ephm_pk, &self.inner.ephm_sk)) + .map(|seed| { + seed.decrypt(&self.inner.ephm_pk, &self.inner.ephm_sk) + .map_err(|_| PetError::InvalidMask) + }) .collect() } diff --git a/rust/xaynet-client/src/participant.rs b/rust/xaynet-client/src/participant.rs index eac35d12d..466ae8b53 100644 --- a/rust/xaynet-client/src/participant.rs +++ b/rust/xaynet-client/src/participant.rs @@ -27,13 +27,14 @@ use xaynet_core::{ ParticipantPublicKey, ParticipantSecretKey, ParticipantTaskSignature, - PetError, SumDict, SumParticipantEphemeralPublicKey, SumParticipantEphemeralSecretKey, UpdateSeedDict, }; +use crate::PetError; + #[derive(Debug, PartialEq, Copy, Clone)] /// Tasks of a participant. pub enum Task { @@ -224,7 +225,10 @@ impl Participant { fn get_seeds(&self, seed_dict: &UpdateSeedDict) -> Result, PetError> { seed_dict .values() - .map(|seed| seed.decrypt(&self.ephm_pk, &self.ephm_sk)) + .map(|seed| { + seed.decrypt(&self.ephm_pk, &self.ephm_sk) + .map_err(|_| PetError::InvalidMask) + }) .collect() } diff --git a/rust/xaynet-core/src/lib.rs b/rust/xaynet-core/src/lib.rs index 887f763c8..e9455a4f0 100644 --- a/rust/xaynet-core/src/lib.rs +++ b/rust/xaynet-core/src/lib.rs @@ -82,7 +82,6 @@ pub mod message; use std::collections::HashMap; -use derive_more::Display; use thiserror::Error; use self::crypto::{ @@ -95,14 +94,6 @@ use self::crypto::{ /// An error related to insufficient system entropy for secrets at program startup. pub struct InitError; -#[derive(Debug, Display, Error)] -/// Errors related to the PET protocol. -pub enum PetError { - InvalidMessage, - InvalidMask, - InvalidModel, -} - /// A public encryption key that identifies a coordinator. pub type CoordinatorPublicKey = PublicEncryptKey; diff --git a/rust/xaynet-core/src/mask/seed.rs b/rust/xaynet-core/src/mask/seed.rs index a233ff1e1..5790a987c 100644 --- a/rust/xaynet-core/src/mask/seed.rs +++ b/rust/xaynet-core/src/mask/seed.rs @@ -10,11 +10,11 @@ use derive_more::{AsMut, AsRef}; use rand::SeedableRng; use rand_chacha::ChaCha20Rng; use sodiumoxide::crypto::box_; +use thiserror::Error; use crate::{ crypto::{encrypt::SEALBYTES, prng::generate_integer, ByteObject}, mask::{config::MaskConfig, object::MaskObject}, - PetError, SumParticipantEphemeralPublicKey, SumParticipantEphemeralSecretKey, }; @@ -99,6 +99,14 @@ impl ByteObject for EncryptedMaskSeed { } } +#[derive(Debug, Error)] +pub enum InvalidMaskSeed { + #[error("the encrypted mask seed could not be decrypted")] + DecryptionFailed, + #[error("the mask seed has an invalid length")] + InvalidLength, +} + impl EncryptedMaskSeed { /// Decrypts this seed as a [`MaskSeed`]. /// @@ -108,13 +116,13 @@ impl EncryptedMaskSeed { &self, pk: &SumParticipantEphemeralPublicKey, sk: &SumParticipantEphemeralSecretKey, - ) -> Result { + ) -> Result { MaskSeed::from_slice( sk.decrypt(self.as_slice(), pk) - .or(Err(PetError::InvalidMask))? + .or(Err(InvalidMaskSeed::DecryptionFailed))? .as_slice(), ) - .ok_or(PetError::InvalidMask) + .ok_or(InvalidMaskSeed::InvalidLength) } } diff --git a/rust/xaynet-core/src/message/message.rs b/rust/xaynet-core/src/message/message.rs index f9097aec0..63b5145c0 100644 --- a/rust/xaynet-core/src/message/message.rs +++ b/rust/xaynet-core/src/message/message.rs @@ -207,6 +207,13 @@ pub struct MessageBuffer { } impl> MessageBuffer { + pub fn inner(&self) -> &T { + &self.inner + } + + pub fn as_ref(&self) -> MessageBuffer<&T> { + MessageBuffer::new_unchecked(self.inner()) + } /// Performs bound checks for the various message fields on `bytes` and returns a new /// [`MessageBuffer`]. /// diff --git a/rust/xaynet-server/src/bin/main.rs b/rust/xaynet-server/src/bin/main.rs index dd0a5d9a2..7688dd060 100644 --- a/rust/xaynet-server/src/bin/main.rs +++ b/rust/xaynet-server/src/bin/main.rs @@ -65,8 +65,9 @@ async fn main() { metrics_sender, ) .unwrap(); - let fetcher = services::fetcher(&event_subscriber); - let message_handler = services::message_handler(&event_subscriber, requests_tx); + let fetcher = services::fetchers::fetcher(&event_subscriber); + let message_handler = + services::messages::PetMessageHandler::new(&event_subscriber, requests_tx); tokio::select! { _ = state_machine.run() => { diff --git a/rust/xaynet-server/src/lib.rs b/rust/xaynet-server/src/lib.rs index eec861842..6e007fe64 100644 --- a/rust/xaynet-server/src/lib.rs +++ b/rust/xaynet-server/src/lib.rs @@ -94,8 +94,6 @@ pub mod services; pub mod settings; pub mod state_machine; pub mod storage; -pub mod utils; -pub(crate) mod vendor; #[cfg_attr(docsrs, doc(cfg(feature = "metrics")))] #[cfg(feature = "metrics")] diff --git a/rust/xaynet-server/src/rest.rs b/rust/xaynet-server/src/rest.rs index b91901df8..269c9fe76 100644 --- a/rust/xaynet-server/src/rest.rs +++ b/rust/xaynet-server/src/rest.rs @@ -1,6 +1,6 @@ //! A HTTP API for the PET protocol interactions. -use crate::services::{Fetcher, PetMessageHandler}; +use crate::services::{fetchers::Fetcher, messages::PetMessageHandler}; use bytes::{Buf, Bytes}; use std::{convert::Infallible, net::SocketAddr}; use warp::{ @@ -15,13 +15,12 @@ use xaynet_core::{crypto::ByteObject, ParticipantPublicKey}; /// * `addr`: address of the server. /// * `fetcher`: fetcher for responding to data requests. /// * `pet_message_handler`: handler for responding to PET messages. -pub async fn serve( +pub async fn serve( addr: impl Into + 'static, fetcher: F, - pet_message_handler: MH, + pet_message_handler: PetMessageHandler, ) where F: Fetcher + Sync + Send + 'static + Clone, - MH: PetMessageHandler + Sync + Send + 'static + Clone, { let message = warp::path!("message") .and(warp::post()) @@ -68,9 +67,9 @@ pub async fn serve( } /// Handles and responds to a PET message. -async fn handle_message( +async fn handle_message( body: Bytes, - mut handler: MH, + mut handler: PetMessageHandler, ) -> Result { let _ = handler.handle_message(body.to_vec()).await.map_err(|e| { warn!("failed to handle message: {:?}", e); @@ -191,9 +190,9 @@ async fn handle_params(mut fetcher: F) -> Result( - handler: MH, -) -> impl Filter + Clone { +fn with_message_handler( + handler: PetMessageHandler, +) -> impl Filter + Clone { warp::any().map(move || handler.clone()) } diff --git a/rust/xaynet-server/src/services/fetchers/mask_length.rs b/rust/xaynet-server/src/services/fetchers/mask_length.rs index b8f1343f6..b264df655 100644 --- a/rust/xaynet-server/src/services/fetchers/mask_length.rs +++ b/rust/xaynet-server/src/services/fetchers/mask_length.rs @@ -2,23 +2,14 @@ use std::task::{Context, Poll}; use futures::future::{self, Ready}; use tower::Service; -use tracing::Span; +use tracing_futures::{Instrument, Instrumented}; -use crate::{ - state_machine::events::{EventListener, EventSubscriber, MaskLengthUpdate}, - utils::Traceable, -}; +use crate::state_machine::events::{EventListener, EventSubscriber, MaskLengthUpdate}; /// [`MaskLengthService`]'s request type #[derive(Default, Clone, Eq, PartialEq, Debug)] pub struct MaskLengthRequest; -impl Traceable for MaskLengthRequest { - fn make_span(&self) -> Span { - error_span!("mask_length_fetch_request") - } -} - /// [`MaskLengthService`]'s response type. /// /// The response is `None` when the mask length is not currently @@ -37,7 +28,7 @@ impl MaskLengthService { impl Service for MaskLengthService { type Response = MaskLengthResponse; type Error = ::std::convert::Infallible; - type Future = Ready>; + type Future = Instrumented>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) @@ -48,5 +39,6 @@ impl Service for MaskLengthService { MaskLengthUpdate::Invalidate => Ok(None), MaskLengthUpdate::New(mask_length) => Ok(Some(mask_length)), }) + .instrument(error_span!("mask_length_fetch_request")) } } diff --git a/rust/xaynet-server/src/services/fetchers/mod.rs b/rust/xaynet-server/src/services/fetchers/mod.rs index fb704eea6..9f951836e 100644 --- a/rust/xaynet-server/src/services/fetchers/mod.rs +++ b/rust/xaynet-server/src/services/fetchers/mod.rs @@ -19,10 +19,9 @@ pub use self::{ use std::task::{Context, Poll}; use futures::future::poll_fn; -use tower::{layer::Layer, Service}; -use tracing_futures::{Instrument, Instrumented}; +use tower::{layer::Layer, Service, ServiceBuilder}; -use crate::utils::{Request, Traceable}; +use crate::state_machine::events::EventSubscriber; /// A single interface for retrieving data from the coordinator. #[async_trait] @@ -155,20 +154,17 @@ pub(in crate::services) struct FetcherService(S); impl Service for FetcherService where S: Service, - R: Traceable, { type Response = S::Response; type Error = S::Error; - type Future = Instrumented; + type Future = S::Future; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.0.poll_ready(cx) } fn call(&mut self, req: R) -> Self::Future { - let req = Request::new(req); - let span = req.span(); - self.0.call(req.into_inner()).instrument(span) + self.0.call(req) } } @@ -210,3 +206,38 @@ impl } } } + +/// Construct a [`Fetcher`] service +pub fn fetcher(event_subscriber: &EventSubscriber) -> impl Fetcher + Sync + Send + Clone + 'static { + let round_params = ServiceBuilder::new() + .buffer(100) + .concurrency_limit(100) + .layer(FetcherLayer) + .service(RoundParamsService::new(event_subscriber)); + + let mask_length = ServiceBuilder::new() + .buffer(100) + .concurrency_limit(100) + .layer(FetcherLayer) + .service(MaskLengthService::new(event_subscriber)); + + let model = ServiceBuilder::new() + .buffer(100) + .concurrency_limit(100) + .layer(FetcherLayer) + .service(ModelService::new(event_subscriber)); + + let sum_dict = ServiceBuilder::new() + .buffer(100) + .concurrency_limit(100) + .layer(FetcherLayer) + .service(SumDictService::new(event_subscriber)); + + let seed_dict = ServiceBuilder::new() + .buffer(100) + .concurrency_limit(100) + .layer(FetcherLayer) + .service(SeedDictService::new(event_subscriber)); + + Fetchers::new(round_params, sum_dict, seed_dict, mask_length, model) +} diff --git a/rust/xaynet-server/src/services/fetchers/model.rs b/rust/xaynet-server/src/services/fetchers/model.rs index cc6ae2634..efc0bc383 100644 --- a/rust/xaynet-server/src/services/fetchers/model.rs +++ b/rust/xaynet-server/src/services/fetchers/model.rs @@ -5,24 +5,15 @@ use std::{ use futures::future::{self, Ready}; use tower::Service; -use tracing::Span; +use tracing_futures::{Instrument, Instrumented}; use xaynet_core::mask::Model; -use crate::{ - state_machine::events::{EventListener, EventSubscriber, ModelUpdate}, - utils::Traceable, -}; +use crate::state_machine::events::{EventListener, EventSubscriber, ModelUpdate}; /// [`ModelService`]'s request type #[derive(Default, Clone, Eq, PartialEq, Debug)] pub struct ModelRequest; -impl Traceable for ModelRequest { - fn make_span(&self) -> Span { - error_span!("model_fetch_request") - } -} - /// [`ModelService`]'s response type. /// /// The response is `None` when no model is currently available. @@ -40,7 +31,7 @@ impl ModelService { impl Service for ModelService { type Response = ModelResponse; type Error = ::std::convert::Infallible; - type Future = Ready>; + type Future = Instrumented>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) @@ -51,5 +42,6 @@ impl Service for ModelService { ModelUpdate::Invalidate => Ok(None), ModelUpdate::New(model) => Ok(Some(model)), }) + .instrument(error_span!("model_fetch_request")) } } diff --git a/rust/xaynet-server/src/services/fetchers/round_parameters.rs b/rust/xaynet-server/src/services/fetchers/round_parameters.rs index 1a6042176..cfb91d15b 100644 --- a/rust/xaynet-server/src/services/fetchers/round_parameters.rs +++ b/rust/xaynet-server/src/services/fetchers/round_parameters.rs @@ -2,24 +2,15 @@ use std::task::{Context, Poll}; use futures::future::{self, Ready}; use tower::Service; -use tracing::Span; +use tracing_futures::{Instrument, Instrumented}; use xaynet_core::common::RoundParameters; -use crate::{ - state_machine::events::{EventListener, EventSubscriber}, - utils::Traceable, -}; +use crate::state_machine::events::{EventListener, EventSubscriber}; /// [`RoundParamsService`]'s request type #[derive(Default, Clone, Eq, PartialEq, Debug)] pub struct RoundParamsRequest; -impl Traceable for RoundParamsRequest { - fn make_span(&self) -> Span { - error_span!("round_params_fetch_request") - } -} - /// [`RoundParamsService`]'s response type pub type RoundParamsResponse = RoundParameters; @@ -35,7 +26,7 @@ impl RoundParamsService { impl Service for RoundParamsService { type Response = RoundParameters; type Error = ::std::convert::Infallible; - type Future = Ready>; + type Future = Instrumented>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) @@ -43,5 +34,6 @@ impl Service for RoundParamsService { fn call(&mut self, _req: RoundParamsRequest) -> Self::Future { future::ready(Ok(self.0.get_latest().event)) + .instrument(error_span!("round_params_fetch_request")) } } diff --git a/rust/xaynet-server/src/services/fetchers/seed_dict.rs b/rust/xaynet-server/src/services/fetchers/seed_dict.rs index 0321f1c32..ab4053ba2 100644 --- a/rust/xaynet-server/src/services/fetchers/seed_dict.rs +++ b/rust/xaynet-server/src/services/fetchers/seed_dict.rs @@ -5,13 +5,10 @@ use std::{ use futures::future::{self, Ready}; use tower::Service; -use tracing::Span; +use tracing_futures::{Instrument, Instrumented}; use xaynet_core::SeedDict; -use crate::{ - state_machine::events::{DictionaryUpdate, EventListener, EventSubscriber}, - utils::Traceable, -}; +use crate::state_machine::events::{DictionaryUpdate, EventListener, EventSubscriber}; /// A service that serves the seed dictionary for the current round. pub struct SeedDictService(EventListener>); @@ -26,12 +23,6 @@ impl SeedDictService { #[derive(Default, Clone, Eq, PartialEq, Debug)] pub struct SeedDictRequest; -impl Traceable for SeedDictRequest { - fn make_span(&self) -> Span { - error_span!("seed_dict_fetch_request") - } -} - /// [`SeedDictService`]'s response type. /// /// The response is `None` when no seed dictionary is currently @@ -41,7 +32,7 @@ pub type SeedDictResponse = Option>; impl Service for SeedDictService { type Response = SeedDictResponse; type Error = std::convert::Infallible; - type Future = Ready>; + type Future = Instrumented>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) @@ -52,5 +43,6 @@ impl Service for SeedDictService { DictionaryUpdate::Invalidate => Ok(None), DictionaryUpdate::New(dict) => Ok(Some(dict)), }) + .instrument(error_span!("seed_dict_fetch_request")) } } diff --git a/rust/xaynet-server/src/services/fetchers/sum_dict.rs b/rust/xaynet-server/src/services/fetchers/sum_dict.rs index 0ff54de22..85c877751 100644 --- a/rust/xaynet-server/src/services/fetchers/sum_dict.rs +++ b/rust/xaynet-server/src/services/fetchers/sum_dict.rs @@ -5,13 +5,10 @@ use std::{ use futures::future::{self, Ready}; use tower::Service; -use tracing::Span; +use tracing_futures::{Instrument, Instrumented}; use xaynet_core::SumDict; -use crate::{ - state_machine::events::{DictionaryUpdate, EventListener, EventSubscriber}, - utils::Traceable, -}; +use crate::state_machine::events::{DictionaryUpdate, EventListener, EventSubscriber}; /// A service that returns the sum dictionary for the current round. pub struct SumDictService(EventListener>); @@ -20,12 +17,6 @@ pub struct SumDictService(EventListener>); #[derive(Default, Clone, Eq, PartialEq, Debug)] pub struct SumDictRequest; -impl Traceable for SumDictRequest { - fn make_span(&self) -> Span { - error_span!("sum_dict_fetch_request") - } -} - /// [`SumDictService`]'s response type. /// /// The response is `None` when no sum dictionary is currently @@ -41,7 +32,7 @@ impl SumDictService { impl Service for SumDictService { type Response = SumDictResponse; type Error = std::convert::Infallible; - type Future = Ready>; + type Future = Instrumented>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) @@ -52,5 +43,6 @@ impl Service for SumDictService { DictionaryUpdate::Invalidate => Ok(None), DictionaryUpdate::New(dict) => Ok(Some(dict)), }) + .instrument(error_span!("sum_dict_fetch_request")) } } diff --git a/rust/xaynet-server/src/services/messages/decryptor.rs b/rust/xaynet-server/src/services/messages/decryptor.rs new file mode 100644 index 000000000..cdd3f19c9 --- /dev/null +++ b/rust/xaynet-server/src/services/messages/decryptor.rs @@ -0,0 +1,151 @@ +use std::{pin::Pin, sync::Arc, task::Poll}; + +use futures::{future::Future, task::Context}; +use rayon::ThreadPool; +use tokio::sync::oneshot; +use tower::{ + limit::concurrency::{future::ResponseFuture, ConcurrencyLimit}, + Service, +}; +use xaynet_core::crypto::EncryptKeyPair; + +use crate::{ + services::messages::{BoxedServiceFuture, ServiceError}, + state_machine::events::{EventListener, EventSubscriber}, +}; + +/// A service for decrypting PET messages. +/// +/// Since this is a CPU-intensive task for large messages, this +/// service offloads the processing to a `rayon` thread-pool to avoid +/// overloading the tokio thread-pool with blocking tasks. +#[derive(Clone)] +struct RawDecryptor { + /// A listener to retrieve the latest coordinator keys. These are + /// necessary for decrypting messages and verifying their + /// signature. + keys_events: EventListener, + + /// Thread-pool the CPU-intensive tasks are offloaded to. + thread_pool: Arc, +} + +impl Service for RawDecryptor +where + T: AsRef<[u8]> + Sync + Send + 'static, +{ + type Response = Vec; + type Error = ServiceError; + #[allow(clippy::type_complexity)] + type Future = + Pin> + 'static + Send + Sync>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, data: T) -> Self::Future { + debug!("retrieving the current keys"); + let keys = self.keys_events.get_latest().event; + let (tx, rx) = oneshot::channel::>(); + + trace!("spawning decryption task on threadpool"); + self.thread_pool.spawn(move || { + info!("decrypting message"); + let res = keys + .secret + .decrypt(&data.as_ref(), &keys.public) + .map_err(|_| ServiceError::Decrypt); + let _ = tx.send(res); + }); + Box::pin(async move { + rx.await.unwrap_or_else(|_| { + Err(ServiceError::InternalError( + "failed to receive response from thread-pool".to_string(), + )) + }) + }) + } +} + +#[derive(Clone)] +pub struct Decryptor(ConcurrencyLimit); + +impl Decryptor { + pub fn new(state_machine_events: &EventSubscriber, thread_pool: Arc) -> Self { + let limit = thread_pool.current_num_threads(); + let keys_events = state_machine_events.keys_listener(); + let service = RawDecryptor { + keys_events, + thread_pool, + }; + Self(ConcurrencyLimit::new(service, limit)) + } +} + +impl Service for Decryptor +where + T: AsRef<[u8]> + Sync + Send + 'static, +{ + type Response = Vec; + type Error = ServiceError; + type Future = ResponseFuture>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + as Service>::poll_ready(&mut self.0, cx) + } + + fn call(&mut self, data: T) -> Self::Future { + self.0.call(data) + } +} + +#[cfg(test)] +mod tests { + use rayon::ThreadPoolBuilder; + use tokio_test::assert_ready; + use tower_test::mock::Spawn; + + use crate::{ + services::tests::utils, + state_machine::events::{EventPublisher, EventSubscriber}, + }; + + use super::*; + + fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn) { + let (publisher, subscriber) = utils::new_event_channels(); + let thread_pool = Arc::new(ThreadPoolBuilder::new().build().unwrap()); + let task = Spawn::new(Decryptor::new(&subscriber, thread_pool)); + (publisher, subscriber, task) + } + + #[tokio::test] + async fn test_decrypt_fail() { + let (_publisher, _subscriber, mut task) = spawn_svc(); + assert_ready!(task.poll_ready::>()).unwrap(); + + let req = vec![0, 1, 2, 3, 4, 5, 6]; + match task.call(req).await { + Err(ServiceError::Decrypt) => {} + _ => panic!("expected decrypt error"), + } + assert_ready!(task.poll_ready::>()).unwrap(); + } + + #[tokio::test] + async fn test_decrypt_ok() { + let (_publisher, subscriber, mut task) = spawn_svc(); + assert_ready!(task.poll_ready::>()).unwrap(); + + let round_params = subscriber.params_listener().get_latest().event; + let (message, participant_signing_keys) = utils::new_sum_message(&round_params); + let serialized_message = utils::serialize_message(&message, &participant_signing_keys); + let encrypted_message = + utils::encrypt_message(&message, &round_params, &participant_signing_keys); + + // Call the service + let decrypted_message = task.call(encrypted_message).await.unwrap(); + assert_eq!(decrypted_message, serialized_message); + } +} diff --git a/rust/xaynet-server/src/services/messages/error.rs b/rust/xaynet-server/src/services/messages/error.rs new file mode 100644 index 000000000..418e892e8 --- /dev/null +++ b/rust/xaynet-server/src/services/messages/error.rs @@ -0,0 +1,37 @@ +use thiserror::Error; +use xaynet_core::message::DecodeError; + +use crate::state_machine::StateMachineError; + +/// Error type for the message parsing service +#[derive(Debug, Error)] +pub enum ServiceError { + #[error("Failed to decrypt the message with the coordinator secret key")] + Decrypt, + + #[error("Failed to parse the message: {0:?}")] + Parsing(DecodeError), + + #[error("Invalid message signature")] + InvalidMessageSignature, + + #[error("Invalid coordinator public key")] + InvalidCoordinatorPublicKey, + + #[error("The message was not expected in the current phase")] + UnexpectedMessage, + + // FIXME: we need to refine the state machine errors and the + // conversion into a service error + #[error("the state machine failed to process the request: {0:?}")] + StateMachine(StateMachineError), + + #[error("participant is not eligible for sum task")] + NotSumEligible, + + #[error("participant is not eligible for update task")] + NotUpdateEligible, + + #[error("Internal error: {0}")] + InternalError(String), +} diff --git a/rust/xaynet-server/src/services/messages/message_parser.rs b/rust/xaynet-server/src/services/messages/message_parser.rs index 0602a42bc..d03f41afa 100644 --- a/rust/xaynet-server/src/services/messages/message_parser.rs +++ b/rust/xaynet-server/src/services/messages/message_parser.rs @@ -1,219 +1,409 @@ -use std::{convert::TryInto, pin::Pin, sync::Arc, task::Poll}; +use std::{convert::TryInto, sync::Arc, task::Poll}; -use anyhow::Context as _; -use derive_more::From; -use futures::{ - future::{self, Either, Future}, - task::Context, -}; +use futures::{future, task::Context}; use rayon::ThreadPool; -use thiserror::Error; use tokio::sync::oneshot; -use tower::Service; -use tracing::Span; +use tower::{layer::Layer, limit::concurrency::ConcurrencyLimit, Service, ServiceBuilder}; use xaynet_core::{ - crypto::EncryptKeyPair, - message::{DecodeError, Message, MessageBuffer, Tag}, + crypto::{EncryptKeyPair, PublicEncryptKey}, + message::{FromBytes, Message, MessageBuffer, Tag}, }; use crate::{ + services::messages::{BoxedServiceFuture, ServiceError}, state_machine::{ events::{EventListener, EventSubscriber}, phases::PhaseName, }, - utils::{Request, Traceable}, }; -/// A service for decrypting and parsing PET messages. -/// -/// Since this is a CPU-intensive task for large messages, this -/// service offloads the processing to a `rayon` thread-pool to avoid -/// overloading the tokio thread-pool with blocking tasks. -pub struct MessageParserService { - /// A listener to retrieve the latest coordinator keys. These are - /// necessary for decrypting messages and verifying their - /// signature. - keys_events: EventListener, - - /// A listener to retrieve the current coordinator phase. Messages - /// that cannot be handled in the current phase will be - /// rejected. The idea is to perform this filtering as early as - /// possible. - phase_events: EventListener, - - /// Thread-pool the CPU-intensive tasks are offloaded to. - thread_pool: Arc, +/// A type that hold a un-parsed message +struct RawMessage { + /// The buffer that contains the message to parse + buffer: Arc>, } -impl MessageParserService { - pub fn new(subscriber: &EventSubscriber, thread_pool: Arc) -> Self { +impl Clone for RawMessage { + fn clone(&self) -> Self { Self { - keys_events: subscriber.keys_listener(), - phase_events: subscriber.phase_listener(), - thread_pool, + buffer: self.buffer.clone(), + } + } +} + +impl From> for RawMessage { + fn from(buffer: MessageBuffer) -> Self { + RawMessage { + buffer: Arc::new(buffer), } } } -/// A buffer that represents an encrypted message. -#[derive(From, Debug)] -pub struct RawMessage>(T); +/// A service that wraps a buffer `T` representing a message into a +/// [`RawMessage`] +#[derive(Debug, Clone)] +struct BufferWrapper(S); -impl Traceable for RawMessage +impl Service for BufferWrapper where - T: AsRef<[u8]>, + T: AsRef<[u8]> + Send + 'static, + S: Service, Response = Message, Error = ServiceError>, + S::Future: Sync + Send + 'static, { - fn make_span(&self) -> Span { - error_span!("raw_message", payload_len = self.0.as_ref().len()) + type Response = Message; + type Error = ServiceError; + type Future = BoxedServiceFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) + } + + fn call(&mut self, req: T) -> Self::Future { + debug!("creating a RawMessage request"); + match MessageBuffer::new(req) { + Ok(buffer) => { + let fut = self.0.call(RawMessage::from(buffer)); + Box::pin(async move { + trace!("calling inner service"); + fut.await + }) + } + Err(e) => Box::pin(future::ready(Err(ServiceError::Parsing(e)))), + } } } -/// Error type for the [`MessageParserService`] -#[derive(Debug, Error)] -pub enum MessageParserError { - #[error("Failed to decrypt the message with the coordinator secret key")] - Decrypt, +struct BufferWrapperLayer; - #[error("Parsing failed: {0:?}")] - Parsing(DecodeError), +impl Layer for BufferWrapperLayer { + type Service = BufferWrapper; - #[error("Invalid message signature")] - InvalidMessageSignature, + fn layer(&self, service: S) -> BufferWrapper { + BufferWrapper(service) + } +} - #[error("The message was rejected because the coordinator did not expect it")] - UnexpectedMessage, +/// A service that discards messages that are not expected in the current phase +#[derive(Debug, Clone)] +struct PhaseFilter { + /// A listener to retrieve the current phase + phase: EventListener, + /// Next service to be called + next_svc: S, +} + +impl Service> for PhaseFilter +where + T: AsRef<[u8]> + Send + 'static, + S: Service, Response = Message, Error = ServiceError>, + S::Future: Sync + Send + 'static, +{ + type Response = Message; + type Error = ServiceError; + type Future = BoxedServiceFuture; - // TODO: we should have a retry layer that automatically retries - // requests that fail with this error. - #[error("The request could not be processed due to a temporary internal error")] - TemporaryInternalError, + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.next_svc.poll_ready(cx) + } - #[error("Internal error: {0}")] - InternalError(String), + fn call(&mut self, req: RawMessage) -> Self::Future { + debug!("retrieving the current phase"); + let phase = self.phase.get_latest().event; + match req.buffer.tag().try_into() { + Ok(tag) => match (phase, tag) { + (PhaseName::Sum, Tag::Sum) + | (PhaseName::Update, Tag::Update) + | (PhaseName::Sum2, Tag::Sum2) => { + let fut = self.next_svc.call(req); + Box::pin(async move { fut.await }) + } + _ => Box::pin(future::ready(Err(ServiceError::UnexpectedMessage))), + }, + Err(e) => Box::pin(future::ready(Err(ServiceError::Parsing(e)))), + } + } +} + +struct PhaseFilterLayer { + phase: EventListener, } -/// Response type for the [`MessageParserService`] -pub type MessageParserResponse = Result; +impl Layer for PhaseFilterLayer { + type Service = PhaseFilter; -/// Request type for the [`MessageParserService`] -pub type MessageParserRequest = Request>; + fn layer(&self, service: S) -> PhaseFilter { + PhaseFilter { + phase: self.phase.clone(), + next_svc: service, + } + } +} -impl Service> for MessageParserService +/// A service for verifying the signature of PET messages +/// +/// Since this is a CPU-intensive task for large messages, this +/// service offloads the processing to a `rayon` thread-pool to avoid +/// overloading the tokio thread-pool with blocking tasks. +#[derive(Debug, Clone)] +struct SignatureVerifier { + /// Thread-pool the CPU-intensive tasks are offloaded to. + thread_pool: Arc, + /// The service to be called after the [`SignatureVerifier`] + next_svc: S, +} + +impl Service> for SignatureVerifier +where + T: AsRef<[u8]> + Sync + Send + 'static, + S: Service, Response = Message, Error = ServiceError> + + Clone + + Sync + + Send + + 'static, + S::Future: Sync + Send + 'static, +{ + type Response = Message; + type Error = ServiceError; + type Future = BoxedServiceFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.next_svc.poll_ready(cx) + } + + fn call(&mut self, req: RawMessage) -> Self::Future { + let (tx, rx) = oneshot::channel::>(); + + let req_clone = req.clone(); + trace!("spawning signature verification task on thread-pool"); + self.thread_pool.spawn(move || { + let res = match req.buffer.as_ref().as_ref().check_signature() { + Ok(()) => { + info!("found a valid message signature"); + Ok(()) + } + Err(e) => { + warn!("invalid message signature: {:?}", e); + Err(ServiceError::InvalidMessageSignature) + } + }; + let _ = tx.send(res); + }); + + let mut next_svc = self.next_svc.clone(); + let fut = async move { + rx.await.map_err(|_| { + ServiceError::InternalError( + "failed to receive response from thread-pool".to_string(), + ) + })??; + next_svc.call(req_clone).await + }; + Box::pin(fut) + } +} + +struct SignatureVerifierLayer { + thread_pool: Arc, +} + +impl Layer for SignatureVerifierLayer { + type Service = ConcurrencyLimit>; + + fn layer(&self, service: S) -> Self::Service { + let limit = self.thread_pool.current_num_threads(); + // FIXME: we actually want to limit the concurrency of just + // the SignatureVerifier middleware. Right now we're limiting + // the whole stack of services. + ConcurrencyLimit::new( + SignatureVerifier { + thread_pool: self.thread_pool.clone(), + next_svc: service, + }, + limit, + ) + } +} + +/// A service that verifies the coordinator public key embedded in PET +/// messsages +#[derive(Debug, Clone)] +struct CoordinatorPublicKeyValidator { + /// A listener to retrieve the latest coordinator keys + keys: EventListener, + /// Next service to be called + next_svc: S, +} + +impl Service> for CoordinatorPublicKeyValidator where T: AsRef<[u8]> + Send + 'static, + S: Service, Response = Message, Error = ServiceError>, + S::Future: Sync + Send + 'static, { - type Response = MessageParserResponse; - type Error = std::convert::Infallible; + type Response = Message; + type Error = ServiceError; + type Future = BoxedServiceFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.next_svc.poll_ready(cx) + } + + fn call(&mut self, req: RawMessage) -> Self::Future { + debug!("retrieving the current keys"); + let coord_pk = self.keys.get_latest().event.public; + match PublicEncryptKey::from_bytes(&req.buffer.as_ref().as_ref().coordinator_pk()) { + Ok(pk) => { + if pk != coord_pk { + warn!("found an invalid coordinator public key"); + Box::pin(future::ready(Err( + ServiceError::InvalidCoordinatorPublicKey, + ))) + } else { + info!("found a valid coordinator public key"); + let fut = self.next_svc.call(req); + Box::pin(async move { fut.await }) + } + } + Err(_) => Box::pin(future::ready(Err( + ServiceError::InvalidCoordinatorPublicKey, + ))), + } + } +} + +struct CoordinatorPublicKeyValidatorLayer { + keys: EventListener, +} + +impl Layer for CoordinatorPublicKeyValidatorLayer { + type Service = CoordinatorPublicKeyValidator; + + fn layer(&self, service: S) -> CoordinatorPublicKeyValidator { + CoordinatorPublicKeyValidator { + keys: self.keys.clone(), + next_svc: service, + } + } +} - #[allow(clippy::type_complexity)] - type Future = Either< - future::Ready>, - Pin> + 'static + Send + Sync>>, - >; +#[derive(Debug, Clone)] +struct Parser; + +impl Service> for Parser +where + T: AsRef<[u8]> + Send + 'static, +{ + type Response = Message; + type Error = ServiceError; + type Future = future::Ready>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, req: MessageParserRequest) -> Self::Future { - debug!("retrieving the current keys and current phase"); - let keys_ev = self.keys_events.get_latest(); - let phase_ev = self.phase_events.get_latest(); + fn call(&mut self, req: RawMessage) -> Self::Future { + let bytes = req.buffer.inner(); + future::ready(Message::from_bytes(&bytes).map_err(ServiceError::Parsing)) + } +} - // This can happen if the coordinator is switching starting a - // new phase. The error should be temporary and we should be - // able to retry the request. - if keys_ev.round_id != phase_ev.round_id { - return Either::Left(future::ready(Ok(Err( - MessageParserError::TemporaryInternalError, - )))); - } +type InnerService = BufferWrapper< + PhaseFilter>>>, +>; - let handler = Handler { - keys: keys_ev.event, - phase: phase_ev.event, - }; +#[derive(Debug, Clone)] +pub struct MessageParser(InnerService); - let (tx, rx) = oneshot::channel::(); +impl Service for MessageParser +where + T: AsRef<[u8]> + Sync + Send + 'static, +{ + type Response = Message; + type Error = ServiceError; + type Future = BoxedServiceFuture; - trace!("spawning pre-processor handler on thread-pool"); - self.thread_pool.spawn(move || { - let span = req.span(); - let _span_guard = span.enter(); - let resp = handler.call(req.into_inner()); - let _ = tx.send(resp); - }); - Either::Right(Box::pin(async move { - Ok(rx.await.unwrap_or_else(|_| { - Err(MessageParserError::InternalError( - "failed to receive response from pre-processor".to_string(), - )) - })) - })) - } -} - -/// Handler created by the [`MessageParserService`] for each request. -struct Handler { - /// Coordinator keys for the current round - keys: EncryptKeyPair, - /// Current phase of the coordinator - phase: PhaseName, -} - -impl Handler { - /// Process the request. `data` is the encrypted PET message to - /// process. - fn call>(self, data: RawMessage) -> MessageParserResponse { - info!("decrypting message"); - let raw = self.decrypt(&data.0.as_ref())?; - - let buf = MessageBuffer::new(&raw).map_err(MessageParserError::Parsing)?; - - info!("filtering message based on the current phase"); - let tag = buf - .tag() - .try_into() - .context("failed to parse message tag field") - .map_err(MessageParserError::Parsing)?; - self.phase_filter(tag)?; - - info!("verifying the message signature"); - buf.check_signature().map_err(|e| { - warn!("invalid message signature: {:?}", e); - MessageParserError::InvalidMessageSignature - })?; - - info!("parsing the message"); - let message = Message::from_bytes(&raw).map_err(MessageParserError::Parsing)?; - - info!("done pre-processing the message"); - Ok(message) - } - - /// Decrypt the given payload with the coordinator secret key - fn decrypt(&self, encrypted_message: &[u8]) -> Result, MessageParserError> { - Ok(self - .keys - .secret - .decrypt(&encrypted_message, &self.keys.public) - .map_err(|_| MessageParserError::Decrypt)?) - } - - /// Reject messages that cannot be handled by the coordinator in - /// the current phase - fn phase_filter(&self, tag: Tag) -> Result<(), MessageParserError> { - match (tag, self.phase) { - (Tag::Sum, PhaseName::Sum) - | (Tag::Update, PhaseName::Update) - | (Tag::Sum2, PhaseName::Sum2) => Ok(()), - (tag, phase) => { - warn!( - "rejecting request: message type is {:?} but phase is {:?}", - tag, phase - ); - Err(MessageParserError::UnexpectedMessage) - } + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + >::poll_ready(&mut self.0, cx) + } + + fn call(&mut self, req: T) -> Self::Future { + let fut = self.0.call(req); + Box::pin(async move { fut.await }) + } +} + +impl MessageParser { + pub fn new(events: &EventSubscriber, thread_pool: Arc) -> Self { + let inner = ServiceBuilder::new() + .layer(BufferWrapperLayer) + .layer(PhaseFilterLayer { + phase: events.phase_listener(), + }) + .layer(SignatureVerifierLayer { thread_pool }) + .layer(CoordinatorPublicKeyValidatorLayer { + keys: events.keys_listener(), + }) + .service(Parser); + Self(inner) + } +} + +#[cfg(test)] +mod tests { + use rayon::ThreadPoolBuilder; + use tokio_test::assert_ready; + use tower_test::mock::Spawn; + + use super::*; + use crate::{ + services::tests::utils, + state_machine::events::{EventPublisher, EventSubscriber}, + }; + + fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn) { + let (publisher, subscriber) = utils::new_event_channels(); + let thread_pool = Arc::new(ThreadPoolBuilder::new().build().unwrap()); + let task = Spawn::new(MessageParser::new(&subscriber, thread_pool)); + (publisher, subscriber, task) + } + + #[tokio::test] + async fn test_valid_request() { + let (mut publisher, subscriber, mut task) = spawn_svc(); + assert_ready!(task.poll_ready::>()).unwrap(); + + let round_params = subscriber.params_listener().get_latest().event; + let (message, signing_keys) = utils::new_sum_message(&round_params); + let serialized_message = utils::serialize_message(&message, &signing_keys); + + // Simulate the state machine broadcasting the sum phase + // (otherwise the request will be rejected by the phase + // filter) + publisher.broadcast_phase(PhaseName::Sum); + + // Call the service + let mut resp = task.call(serialized_message).await.unwrap(); + // The signature should be set. However in `message` it's not been + // computed, so we just check that it's there, then set it to + // `None` in `resp` + assert!(resp.signature.is_some()); + resp.signature = None; + // Now the comparison should work + assert_eq!(resp, message); + } + + #[tokio::test] + async fn test_unexpected_message() { + let (_publisher, subscriber, mut task) = spawn_svc(); + assert_ready!(task.poll_ready::>()).unwrap(); + + let round_params = subscriber.params_listener().get_latest().event; + let (message, signing_keys) = utils::new_sum_message(&round_params); + let serialized_message = utils::serialize_message(&message, &signing_keys); + let err = task.call(serialized_message).await.unwrap_err(); + match err { + ServiceError::UnexpectedMessage => {} + _ => panic!("expected ServiceError::UnexpectedMessage got {:?}", err), } } } diff --git a/rust/xaynet-server/src/services/messages/mod.rs b/rust/xaynet-server/src/services/messages/mod.rs index 71568659c..3e9f23714 100644 --- a/rust/xaynet-server/src/services/messages/mod.rs +++ b/rust/xaynet-server/src/services/messages/mod.rs @@ -3,180 +3,74 @@ //! //! There are multiple such services and the [`PetMessageHandler`] //! trait provides a single unifying interface for all of these. +mod decryptor; +mod error; mod message_parser; mod state_machine; mod task_validator; -pub use self::{ - message_parser::{ - MessageParserError, - MessageParserRequest, - MessageParserResponse, - MessageParserService, - }, - state_machine::{ - StateMachineError, - StateMachineRequest, - StateMachineResponse, - StateMachineService, - }, - task_validator::{ - TaskValidatorError, - TaskValidatorRequest, - TaskValidatorResponse, - TaskValidatorService, - }, +pub use self::error::ServiceError; +use self::{ + decryptor::Decryptor, + message_parser::MessageParser, + state_machine::StateMachine, + task_validator::TaskValidator, }; -use xaynet_core::message::Message; - -use crate::{ - services::{ - messages::message_parser::RawMessage, - utils::{with_tracing, TracedService}, - }, - utils::Request, -}; +use std::sync::Arc; use futures::future::poll_fn; -use thiserror::Error; +use rayon::ThreadPoolBuilder; use tower::Service; +use xaynet_core::message::Message; -type TracedMessageParser = TracedService>>; -type TracedTaskValidator = TracedService; -type TracedStateMachine = TracedService; - -/// Error returned by the [`PetMessageHandler`] methods. -#[derive(Debug, Error)] -pub enum PetMessageError { - #[error("failed to parse message: {0}")] - Parser(MessageParserError), - - #[error("failed to pre-process message: {0}")] - TaskValidator(TaskValidatorError), - - #[error("state machine failed to handle message: {0}")] - StateMachine(StateMachineError), - - #[error("the service failed to process the request: {0}")] - ServiceError(Box), -} - -/// A single interface for all the PET message processing sub-services -/// ([`MessageParserService`], [`TaskValidatorService`] and -/// [`StateMachineService`]). -#[async_trait] -pub trait PetMessageHandler: Send { - async fn handle_message( - &mut self, - // FIXME: this should take a `Request<_>` instead that should - // be created by the caller (in the rest layer). - req: Vec, - ) -> Result<(), PetMessageError> { - let req = Request::new(RawMessage::from(req)); - let metadata = req.metadata(); - let message = self.call_parser(req).await?; +use crate::state_machine::{events::EventSubscriber, requests::RequestSender}; - let req = Request::from_parts(metadata.clone(), message); - let message = self.call_task_validator(req).await?; +impl PetMessageHandler { + pub fn new(event_subscriber: &EventSubscriber, requests_tx: RequestSender) -> Self { + // TODO: make this configurable. Users should be able to + // choose how many threads they want etc. + // + // TODO: don't unwrap + let thread_pool = Arc::new(ThreadPoolBuilder::new().build().unwrap()); + let decryptor = Decryptor::new(event_subscriber, thread_pool.clone()); + let message_parser = MessageParser::new(event_subscriber, thread_pool); + let task_validator = TaskValidator::new(event_subscriber); + let state_machine = StateMachine::new(requests_tx); - let req = Request::from_parts(metadata, message); - Ok(self.call_state_machine(req).await?) + Self { + decryptor, + message_parser, + task_validator, + state_machine, + } } - - /// Parse an encrypted message - async fn call_parser( - &mut self, - enc_message: MessageParserRequest>, - ) -> Result; - - /// Pre-process a PET message - async fn call_task_validator( - &mut self, - message: TaskValidatorRequest, - ) -> Result; - - /// Have a PET message processed by the state machine - async fn call_state_machine( - &mut self, - message: StateMachineRequest, - ) -> Result<(), PetMessageError>; -} - -#[async_trait] -impl PetMessageHandler for PetMessageService -where - Self: Send + Sync + 'static, - - MP: Service>, Response = MessageParserResponse> + Send + 'static, - >>>::Future: Send + 'static, - >>>::Error: - Into>, - - TV: Service + Send + 'static, - >::Future: Send + 'static, - >::Error: - Into>, - - SM: Service + Send + 'static, - >::Future: Send + 'static, - >::Error: - Into>, -{ - async fn call_parser( - &mut self, - enc_message: MessageParserRequest>, - ) -> Result { - poll_fn(|cx| { - >>>::poll_ready(&mut self.message_parser, cx) - }) - .await - // FIXME: we should actually downcast the error and - // distinguish between the various services errors we can - // have. Currently, this will just turn the error into a - // Box - .map_err(|e| PetMessageError::ServiceError(Into::into(e)))?; - - >>>::call( - &mut self.message_parser, - enc_message.map(Into::into), - ) - .await - .map_err(|e| PetMessageError::ServiceError(Into::into(e)))? - .map_err(PetMessageError::Parser) + async fn decrypt(&mut self, enc_data: Vec) -> Result, ServiceError> { + poll_fn(|cx| >>::poll_ready(&mut self.decryptor, cx)).await?; + self.decryptor.call(enc_data).await } - async fn call_task_validator( - &mut self, - message: TaskValidatorRequest, - ) -> Result { - poll_fn(|cx| { - >::poll_ready(&mut self.task_validator, cx) - }) - .await - .map_err(|e| PetMessageError::ServiceError(Into::into(e)))?; + async fn parse(&mut self, data: Vec) -> Result { + poll_fn(|cx| >>::poll_ready(&mut self.message_parser, cx)) + .await?; + self.message_parser.call(data).await + } - >::call( - &mut self.task_validator, - message.map(Into::into), - ) - .await - .map_err(|e| PetMessageError::ServiceError(Into::into(e)))? - .map_err(PetMessageError::TaskValidator) + async fn validate_task(&mut self, message: Message) -> Result { + poll_fn(|cx| self.task_validator.poll_ready(cx)).await?; + self.task_validator.call(message).await } - async fn call_state_machine( - &mut self, - message: StateMachineRequest, - ) -> Result<(), PetMessageError> { - poll_fn(|cx| >::poll_ready(&mut self.state_machine, cx)) - .await - .map_err(|e| PetMessageError::ServiceError(Into::into(e)))?; + async fn process(&mut self, message: Message) -> Result<(), ServiceError> { + poll_fn(|cx| self.state_machine.poll_ready(cx)).await?; + self.state_machine.call(message).await + } - >::call(&mut self.state_machine, message.map(Into::into)) - .await - .map_err(|e| PetMessageError::ServiceError(Into::into(e)))? - .map_err(PetMessageError::StateMachine) + pub async fn handle_message(&mut self, enc_data: Vec) -> Result<(), ServiceError> { + let raw_message = self.decrypt(enc_data).await?; + let message = self.parse(raw_message).await?; + let message = self.validate_task(message).await?; + self.process(message).await } } @@ -194,46 +88,14 @@ where /// `TaskValidator` may also discard the message /// /// 3. Finally, the message is handled by the `StateMachine` service. -#[derive(Debug, Clone)] -pub struct PetMessageService { +#[derive(Clone)] +pub struct PetMessageHandler { + decryptor: Decryptor, message_parser: MessageParser, task_validator: TaskValidator, state_machine: StateMachine, } -impl - PetMessageService, TracedTaskValidator, TracedStateMachine> -where - MP: Service>, Response = MessageParserResponse>, - TV: Service, - SM: Service, -{ - /// Instantiate a new [`PetMessageService`] with the given sub-services - pub fn new(message_parser: MP, task_validator: TV, state_machine: SM) -> Self { - Self { - message_parser: with_tracing(message_parser), - task_validator: with_tracing(task_validator), - state_machine: with_tracing(state_machine), - } - } -} - -use crate::utils::Traceable; -use tracing::Span; -use xaynet_core::message::Payload; - -impl Traceable for Message { - fn make_span(&self) -> Span { - let message_type = match self.payload { - Payload::Sum(_) => "sum", - Payload::Update(_) => "update", - Payload::Sum2(_) => "sum2", - Payload::Chunk(_) => "chunk", - }; - error_span!( - "Message", - message_type = message_type, - message_length = self.buffer_length() - ) - } -} +pub type BoxedServiceFuture = std::pin::Pin< + Box> + 'static + Send + Sync>, +>; diff --git a/rust/xaynet-server/src/services/messages/state_machine.rs b/rust/xaynet-server/src/services/messages/state_machine.rs index 7e6ad739b..9d5a160d5 100644 --- a/rust/xaynet-server/src/services/messages/state_machine.rs +++ b/rust/xaynet-server/src/services/messages/state_machine.rs @@ -1,26 +1,27 @@ -use std::{pin::Pin, task::Poll}; +use std::task::Poll; -use futures::{future::Future, task::Context}; +use futures::task::Context; use tower::Service; use xaynet_core::message::Message; use crate::{ - state_machine::{requests::RequestSender, StateMachineResult}, - utils::Request, + services::messages::{BoxedServiceFuture, ServiceError}, + state_machine::requests::RequestSender, }; -pub use crate::state_machine::{StateMachineError, StateMachineResult as StateMachineResponse}; +pub use crate::state_machine::StateMachineError; /// A service that hands the requests to the state machine /// ([`StateMachine`]) that runs in the /// background. /// /// [`StateMachine`]: crate::state_machine::StateMachine -pub struct StateMachineService { +#[derive(Debug, Clone)] +pub struct StateMachine { handle: RequestSender, } -impl StateMachineService { +impl StateMachine { /// Create a new service with the given handle for forwarding /// requests to the state machine. The handle should be obtained /// via [`StateMachine::new`]. @@ -31,22 +32,22 @@ impl StateMachineService { } } -/// Request type for [`StateMachineService`] -pub type StateMachineRequest = Request; - -impl Service for StateMachineService { - type Response = StateMachineResult; - type Error = ::std::convert::Infallible; - #[allow(clippy::type_complexity)] - type Future = - Pin> + 'static + Send>>; +impl Service for StateMachine { + type Response = (); + type Error = ServiceError; + type Future = BoxedServiceFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, req: StateMachineRequest) -> Self::Future { + fn call(&mut self, req: Message) -> Self::Future { let handle = self.handle.clone(); - Box::pin(async move { Ok(handle.request(req).await) }) + Box::pin(async move { + handle + .request(req.into(), tracing::Span::none()) + .await + .map_err(ServiceError::StateMachine) + }) } } diff --git a/rust/xaynet-server/src/services/messages/task_validator.rs b/rust/xaynet-server/src/services/messages/task_validator.rs index f6093352f..a75859069 100644 --- a/rust/xaynet-server/src/services/messages/task_validator.rs +++ b/rust/xaynet-server/src/services/messages/task_validator.rs @@ -1,7 +1,6 @@ use std::task::Poll; use futures::{future, task::Context}; -use thiserror::Error; use tower::Service; use xaynet_core::{ common::RoundParameters, @@ -10,17 +9,18 @@ use xaynet_core::{ }; use crate::{ + services::messages::ServiceError, state_machine::events::{EventListener, EventSubscriber}, - utils::request::Request, }; /// A service for performing sanity checks and preparing incoming /// requests to be handled by the state machine. -pub struct TaskValidatorService { +#[derive(Clone, Debug)] +pub struct TaskValidator { params_listener: EventListener, } -impl TaskValidatorService { +impl TaskValidator { pub fn new(subscriber: &EventSubscriber) -> Self { Self { params_listener: subscriber.params_listener(), @@ -28,28 +28,21 @@ impl TaskValidatorService { } } -/// Request type for [`TaskValidatorService`] -pub type TaskValidatorRequest = Request; - -/// Response type for [`TaskValidatorService`] -pub type TaskValidatorResponse = Result; - -impl Service for TaskValidatorService { - type Response = TaskValidatorResponse; - type Error = std::convert::Infallible; +impl Service for TaskValidator { + type Response = Message; + type Error = ServiceError; type Future = future::Ready>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, req: TaskValidatorRequest) -> Self::Future { - let message = req.into_inner(); + fn call(&mut self, message: Message) -> Self::Future { let (sum_signature, update_signature) = match message.payload { Payload::Sum(ref sum) => (sum.sum_signature, None), Payload::Update(ref update) => (update.sum_signature, Some(update.update_signature)), Payload::Sum2(ref sum2) => (sum2.sum_signature, None), - _ => return future::ready(Ok(Err(TaskValidatorError::UnexpectedMessage))), + _ => return future::ready(Err(ServiceError::UnexpectedMessage)), }; let params = self.params_listener.get_latest().event; let seed = params.seed.as_slice(); @@ -77,35 +70,82 @@ impl Service for TaskValidatorService { match message.payload { Payload::Sum(_) | Payload::Sum2(_) => { if is_summer { - future::ready(Ok(Ok(message))) + future::ready(Ok(message)) } else { - future::ready(Ok(Err(TaskValidatorError::NotSumEligible))) + future::ready(Err(ServiceError::NotSumEligible)) } } Payload::Update(_) => { if is_updater { - future::ready(Ok(Ok(message))) + future::ready(Ok(message)) } else { - future::ready(Ok(Err(TaskValidatorError::NotUpdateEligible))) + future::ready(Err(ServiceError::NotUpdateEligible)) } } - _ => future::ready(Ok(Err(TaskValidatorError::UnexpectedMessage))), + _ => future::ready(Err(ServiceError::UnexpectedMessage)), } } } -/// Error type for [`TaskValidatorService`] -#[derive(Error, Debug)] -pub enum TaskValidatorError { - #[error("Not eligible for sum task")] - NotSumEligible, +#[cfg(test)] +mod tests { + use tokio_test::assert_ready; + use tower_test::mock::Spawn; + + use crate::{ + services::tests::utils, + state_machine::{ + events::{EventPublisher, EventSubscriber}, + phases::PhaseName, + }, + }; + + use super::*; + + fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn) { + let (publisher, subscriber) = utils::new_event_channels(); + let task = Spawn::new(TaskValidator::new(&subscriber)); + (publisher, subscriber, task) + } + + #[tokio::test] + async fn test_sum_ok() { + let (mut publisher, subscriber, mut task) = spawn_svc(); - #[error("Not eligible for update task")] - NotUpdateEligible, + let mut round_params = subscriber.params_listener().get_latest().event; - #[error("The message was rejected because the coordinator did not expect it")] - UnexpectedMessage, + // make sure everyone is eligible + round_params.sum = 1.0; + + publisher.broadcast_params(round_params.clone()); + publisher.broadcast_phase(PhaseName::Sum); + + let (message, _) = utils::new_sum_message(&round_params); + + assert_ready!(task.poll_ready()).unwrap(); + let resp = task.call(message.clone()).await.unwrap(); + assert_eq!(resp, message); + } - #[error("Internal error")] - InternalError, + #[tokio::test] + async fn test_sum_not_eligible() { + let (mut publisher, subscriber, mut task) = spawn_svc(); + + let mut round_params = subscriber.params_listener().get_latest().event; + + // make sure no-one is eligible + round_params.sum = 0.0; + + publisher.broadcast_params(round_params.clone()); + publisher.broadcast_phase(PhaseName::Sum); + + let (message, _) = utils::new_sum_message(&round_params); + + assert_ready!(task.poll_ready()).unwrap(); + let err = task.call(message).await.unwrap_err(); + match err { + ServiceError::NotSumEligible => {} + _ => panic!("expected ServiceError::NotSumEligible got {:?}", err), + } + } } diff --git a/rust/xaynet-server/src/services/mod.rs b/rust/xaynet-server/src/services/mod.rs index dfa5a3a27..d3b96d3e6 100644 --- a/rust/xaynet-server/src/services/mod.rs +++ b/rust/xaynet-server/src/services/mod.rs @@ -21,112 +21,6 @@ //! an interface for the second category of services. pub mod fetchers; pub mod messages; -pub(in crate::services) mod utils; #[cfg(test)] mod tests; - -pub use self::{ - fetchers::{FetchError, Fetcher}, - messages::{PetMessageError, PetMessageHandler}, - utils::TracedService, -}; - -use crate::{ - services::{ - fetchers::{ - FetcherLayer, - Fetchers, - MaskLengthService, - ModelService, - RoundParamsService, - SeedDictService, - SumDictService, - }, - messages::{ - MessageParserService, - PetMessageService, - StateMachineService, - TaskValidatorService, - }, - }, - state_machine::{events::EventSubscriber, requests::RequestSender}, -}; - -use std::sync::Arc; - -use rayon::ThreadPoolBuilder; -use tower::ServiceBuilder; - -/// Construct a [`Fetcher`] service -pub fn fetcher(event_subscriber: &EventSubscriber) -> impl Fetcher + Sync + Send + Clone + 'static { - let round_params = ServiceBuilder::new() - .buffer(100) - .concurrency_limit(100) - .layer(FetcherLayer) - .service(RoundParamsService::new(event_subscriber)); - - let mask_length = ServiceBuilder::new() - .buffer(100) - .concurrency_limit(100) - .layer(FetcherLayer) - .service(MaskLengthService::new(event_subscriber)); - - let model = ServiceBuilder::new() - .buffer(100) - .concurrency_limit(100) - .layer(FetcherLayer) - .service(ModelService::new(event_subscriber)); - - let sum_dict = ServiceBuilder::new() - .buffer(100) - .concurrency_limit(100) - .layer(FetcherLayer) - .service(SumDictService::new(event_subscriber)); - - let seed_dict = ServiceBuilder::new() - .buffer(100) - .concurrency_limit(100) - .layer(FetcherLayer) - .service(SeedDictService::new(event_subscriber)); - - Fetchers::new(round_params, sum_dict, seed_dict, mask_length, model) -} - -/// Construct a [`PetMessageHandler`] service -pub fn message_handler( - event_subscriber: &EventSubscriber, - requests_tx: RequestSender, -) -> impl PetMessageHandler + Sync + Send + 'static + Clone { - // TODO: make this configurable. Users should be able to - // choose how many threads they want etc. - // - // TODO: don't unwrap - let thread_pool = Arc::new(ThreadPoolBuilder::new().build().unwrap()); - - let message_parser = ServiceBuilder::new() - // allow processing 100 request concurrently, and allow up to - // 10 requests to be pending. Once 100 requests are being - // processed and 100 are queued, the service will report that - // it's not ready. - // - // FIXME: what's a good concurrency limit? Should we limit - // the number of concurrent messages being processed to - // the number of threads in the rayon thread-pool? Or is - // rayon smart enough to avoid too many context switches? - .buffer(100) - .concurrency_limit(100) - .service(MessageParserService::new(event_subscriber, thread_pool)); - - let pre_processor = ServiceBuilder::new() - .buffer(100) - .concurrency_limit(100) - .service(TaskValidatorService::new(event_subscriber)); - - let state_machine = ServiceBuilder::new() - .buffer(100) - .concurrency_limit(100) - .service(StateMachineService::new(requests_tx)); - - PetMessageService::new(message_parser, pre_processor, state_machine) -} diff --git a/rust/xaynet-server/src/services/tests/messages/message_parser.rs b/rust/xaynet-server/src/services/tests/messages/message_parser.rs deleted file mode 100644 index 55fff4550..000000000 --- a/rust/xaynet-server/src/services/tests/messages/message_parser.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::sync::Arc; - -use rayon::ThreadPoolBuilder; -use tokio_test::assert_ready; -use tower_test::mock::Spawn; -use xaynet_core::{common::RoundParameters, message::Message}; - -use crate::{ - services::{ - messages::{ - MessageParserError, - MessageParserRequest, - MessageParserResponse, - MessageParserService, - }, - tests::utils, - }, - state_machine::{ - events::{EventPublisher, EventSubscriber}, - phases::PhaseName, - }, - utils::Request, -}; - -fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn) { - let (publisher, subscriber) = utils::new_event_channels(); - let thread_pool = Arc::new(ThreadPoolBuilder::new().build().unwrap()); - let task = Spawn::new(MessageParserService::new(&subscriber, thread_pool)); - (publisher, subscriber, task) -} - -fn make_req(bytes: Vec) -> MessageParserRequest> { - Request::new(bytes.into()) -} - -fn new_sum_message(round_params: &RoundParameters) -> (Message, Vec) { - let (message, _, participant_signing_keys) = utils::new_sum_message(round_params); - let encrypted_message = - utils::encrypt_message(&message, round_params, &participant_signing_keys); - (message, encrypted_message) -} - -fn assert_ready(task: &mut Spawn) { - assert_ready!(task.poll_ready::>>()).unwrap(); -} - -#[tokio::test] -async fn test_decrypt_fail() { - let (_publisher, _subscriber, mut task) = spawn_svc(); - assert_ready(&mut task); - - let req = make_req(vec![0, 1, 2, 3, 4, 5, 6]); - let resp: Result = task.call(req).await; - // this is a bit weird because MessageParserError doesn't impl Eq - // and PartialEq - match resp { - Ok(Err(MessageParserError::Decrypt)) => {} - _ => panic!("expected decrypt error"), - } - assert_ready(&mut task); -} - -#[tokio::test] -async fn test_valid_request() { - let (mut publisher, subscriber, mut task) = spawn_svc(); - assert_ready(&mut task); - - let round_params = subscriber.params_listener().get_latest().event; - let (message, encrypted_message) = new_sum_message(&round_params); - let req = make_req(encrypted_message); - - // Simulate the state machine broadcasting the sum phase - // (otherwise the request will be rejected) - publisher.broadcast_phase(PhaseName::Sum); - - // Call the service - let mut resp = task.call(req).await.unwrap().unwrap(); - // The signature should be set. However in `message` it's not been - // computed, so we just check that it's there, then set it to - // `None` in `resp` - assert!(resp.signature.is_some()); - resp.signature = None; - // Now the comparison should work - assert_eq!(resp, message); -} - -#[tokio::test] -async fn test_unexpected_message() { - let (_publisher, subscriber, mut task) = spawn_svc(); - assert_ready(&mut task); - - let round_params = subscriber.params_listener().get_latest().event; - let (_, encrypted_message) = new_sum_message(&round_params); - let req = make_req(encrypted_message); - - let err = task.call(req).await.unwrap().unwrap_err(); - match err { - MessageParserError::UnexpectedMessage => {} - _ => panic!( - "expected MessageParserError::UnexpectedMessage got {:?}", - err - ), - } -} diff --git a/rust/xaynet-server/src/services/tests/messages/mod.rs b/rust/xaynet-server/src/services/tests/messages/mod.rs deleted file mode 100644 index 71445c1d1..000000000 --- a/rust/xaynet-server/src/services/tests/messages/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod message_parser; -mod task_validator; diff --git a/rust/xaynet-server/src/services/tests/messages/task_validator.rs b/rust/xaynet-server/src/services/tests/messages/task_validator.rs deleted file mode 100644 index 3bad27795..000000000 --- a/rust/xaynet-server/src/services/tests/messages/task_validator.rs +++ /dev/null @@ -1,68 +0,0 @@ -use tokio_test::assert_ready; -use tower_test::mock::Spawn; -use xaynet_core::message::Message; - -use crate::{ - services::{ - messages::{TaskValidatorError, TaskValidatorRequest, TaskValidatorService}, - tests::utils, - }, - state_machine::{ - events::{EventPublisher, EventSubscriber}, - phases::PhaseName, - }, - utils::Request, -}; - -fn spawn_svc() -> (EventPublisher, EventSubscriber, Spawn) { - let (publisher, subscriber) = utils::new_event_channels(); - let task = Spawn::new(TaskValidatorService::new(&subscriber)); - (publisher, subscriber, task) -} - -fn make_req(message: Message) -> TaskValidatorRequest { - Request::new(message) -} - -#[tokio::test] -async fn test_sum_ok() { - let (mut publisher, subscriber, mut task) = spawn_svc(); - - let mut round_params = subscriber.params_listener().get_latest().event; - - // make sure everyone is eligible - round_params.sum = 1.0; - - publisher.broadcast_params(round_params.clone()); - publisher.broadcast_phase(PhaseName::Sum); - - let (message, _, _) = utils::new_sum_message(&round_params); - let req = make_req(message.clone()); - - assert_ready!(task.poll_ready()).unwrap(); - let resp = task.call(req).await.unwrap().unwrap(); - assert_eq!(resp, message); -} - -#[tokio::test] -async fn test_sum_not_eligible() { - let (mut publisher, subscriber, mut task) = spawn_svc(); - - let mut round_params = subscriber.params_listener().get_latest().event; - - // make sure no-one is eligible - round_params.sum = 0.0; - - publisher.broadcast_params(round_params.clone()); - publisher.broadcast_phase(PhaseName::Sum); - - let (message, _, _) = utils::new_sum_message(&round_params); - let req = make_req(message.clone()); - - assert_ready!(task.poll_ready()).unwrap(); - let err = task.call(req).await.unwrap().unwrap_err(); - match err { - TaskValidatorError::NotSumEligible => {} - _ => panic!("expected TaskValidatorError::NotSumEligible got {:?}", err), - } -} diff --git a/rust/xaynet-server/src/services/tests/mod.rs b/rust/xaynet-server/src/services/tests/mod.rs index 90d8d51b1..b5614dd82 100644 --- a/rust/xaynet-server/src/services/tests/mod.rs +++ b/rust/xaynet-server/src/services/tests/mod.rs @@ -1,3 +1 @@ -mod fetchers; -mod messages; -mod utils; +pub mod utils; diff --git a/rust/xaynet-server/src/services/tests/utils.rs b/rust/xaynet-server/src/services/tests/utils.rs index 17dcd2854..b176055f1 100644 --- a/rust/xaynet-server/src/services/tests/utils.rs +++ b/rust/xaynet-server/src/services/tests/utils.rs @@ -2,7 +2,6 @@ use xaynet_core::{ common::{RoundParameters, RoundSeed}, crypto::{ByteObject, EncryptKeyPair, PublicEncryptKey, SigningKeyPair}, message::{Message, Sum}, - SumParticipantEphemeralPublicKey, }; use crate::state_machine::{ @@ -29,28 +28,22 @@ pub fn new_event_channels() -> (EventPublisher, EventSubscriber) { /// Simulate a participant generating keys and crafting a valid sum /// message for the given round parameters. The keys generated by the /// participants are returned along with the message. -pub fn new_sum_message( - round_params: &RoundParameters, -) -> (Message, SumParticipantEphemeralPublicKey, SigningKeyPair) { - let participant_ephm_pk = PublicEncryptKey::generate(); +pub fn new_sum_message(round_params: &RoundParameters) -> (Message, SigningKeyPair) { let participant_signing_keys = SigningKeyPair::generate(); - let sum_signature = participant_signing_keys .secret .sign_detached(&[round_params.seed.as_slice(), b"sum"].concat()); - let message = Message { signature: None, participant_pk: participant_signing_keys.public.clone(), coordinator_pk: round_params.pk, payload: Sum { sum_signature, - ephm_pk: participant_ephm_pk.clone(), + ephm_pk: PublicEncryptKey::generate(), } .into(), }; - - (message, participant_ephm_pk, participant_signing_keys) + (message, participant_signing_keys) } /// Sign and encrypt the given message using the given round @@ -60,7 +53,12 @@ pub fn encrypt_message( round_params: &RoundParameters, participant_signing_keys: &SigningKeyPair, ) -> Vec { + let serialized = serialize_message(message, participant_signing_keys); + round_params.pk.encrypt(&serialized[..]) +} + +pub fn serialize_message(message: &Message, participant_signing_keys: &SigningKeyPair) -> Vec { let mut buf = vec![0; message.buffer_length()]; message.to_bytes(&mut buf, &participant_signing_keys.secret); - round_params.pk.encrypt(&buf[..]) + buf } diff --git a/rust/xaynet-server/src/services/utils.rs b/rust/xaynet-server/src/services/utils.rs deleted file mode 100644 index dd347b8cc..000000000 --- a/rust/xaynet-server/src/services/utils.rs +++ /dev/null @@ -1,27 +0,0 @@ -use tower::{Service, ServiceBuilder}; - -use crate::{ - utils::{Request, Traceable}, - vendor::tracing_tower, -}; - -/// Return the [`tracing::Span`] associated to the given request. -pub(in crate::services) fn req_span(req: &Request) -> tracing::Span { - req.span() -} - -/// Decorate the given service with a tracing middleware. -pub(in crate::services) fn with_tracing(service: S) -> TracedService -where - S: Service>, - T: Traceable, -{ - ServiceBuilder::new() - .layer(tracing_tower::layer(req_span as for<'r> fn(&'r _) -> _)) - .service(service) -} - -/// A service `S` that handles `Request` requests, decorated with a -/// tracing middleware that automatically enters the request's span. -pub type TracedService = - tracing_tower::Service, fn(&Request) -> tracing::Span>; diff --git a/rust/xaynet-server/src/state_machine/mod.rs b/rust/xaynet-server/src/state_machine/mod.rs index e467cd95c..64995fd16 100644 --- a/rust/xaynet-server/src/state_machine/mod.rs +++ b/rust/xaynet-server/src/state_machine/mod.rs @@ -109,31 +109,30 @@ pub mod events; pub mod phases; pub mod requests; -use crate::{ - settings::{MaskSettings, ModelSettings, PetSettings}, - state_machine::{ - coordinator::CoordinatorState, - events::{EventPublisher, EventSubscriber}, - phases::{ - Idle, - Phase, - PhaseName, - PhaseState, - Shared, - Shutdown, - StateError, - Sum, - Sum2, - Unmask, - Update, - }, - requests::{RequestReceiver, RequestSender}, +use self::{ + coordinator::CoordinatorState, + events::{EventPublisher, EventSubscriber}, + phases::{ + Idle, + Phase, + PhaseName, + PhaseState, + Shared, + Shutdown, + StateError, + Sum, + Sum2, + Unmask, + Update, }, + requests::{RequestReceiver, RequestSender}, }; -use xaynet_core::{mask::UnmaskingError, InitError, PetError}; use derive_more::From; use thiserror::Error; +use xaynet_core::{mask::UnmaskingError, InitError}; + +use crate::settings::{MaskSettings, ModelSettings, PetSettings}; #[cfg(feature = "metrics")] use crate::metrics::MetricsSender; @@ -141,8 +140,15 @@ use crate::metrics::MetricsSender; /// Error returned when the state machine fails to handle a request #[derive(Debug, Error)] pub enum StateMachineError { - #[error("the request failed")] - RequestFailed(#[from] PetError), + #[error("the message was rejected")] + MessageRejected, + + #[error("invalid update: the model or scalar sent by the participant could not be aggregated")] + AggregationFailed, + + #[error("invalid update: the seed dictionary sent by the participant is invalid")] + InvalidLocalSeedDict, + #[error("the request could not be processed due to an internal error")] InternalError, } diff --git a/rust/xaynet-server/src/state_machine/phases/idle.rs b/rust/xaynet-server/src/state_machine/phases/idle.rs index f3a9dcbbc..1f8db5ae2 100644 --- a/rust/xaynet-server/src/state_machine/phases/idle.rs +++ b/rust/xaynet-server/src/state_machine/phases/idle.rs @@ -1,7 +1,6 @@ use xaynet_core::{ common::RoundSeed, crypto::{ByteObject, EncryptKeyPair, SigningKeySeed}, - PetError, }; use crate::state_machine::{ @@ -10,6 +9,7 @@ use crate::state_machine::{ requests::StateMachineRequest, StateError, StateMachine, + StateMachineError, }; #[cfg(feature = "metrics")] @@ -22,9 +22,9 @@ use sodiumoxide::crypto::hash::sha256; pub struct Idle; impl Handler for PhaseState { - /// Reject the request with a [`PetError::InvalidMessage`] - fn handle_request(&mut self, _req: StateMachineRequest) -> Result<(), PetError> { - Err(PetError::InvalidMessage) + /// Reject the request with a [`StateMachineError::MessageRejected`] + fn handle_request(&mut self, _req: StateMachineRequest) -> Result<(), StateMachineError> { + Err(StateMachineError::MessageRejected) } } diff --git a/rust/xaynet-server/src/state_machine/phases/mod.rs b/rust/xaynet-server/src/state_machine/phases/mod.rs index 2e4f6fd11..dce29cb9b 100644 --- a/rust/xaynet-server/src/state_machine/phases/mod.rs +++ b/rust/xaynet-server/src/state_machine/phases/mod.rs @@ -8,8 +8,6 @@ mod sum2; mod unmask; mod update; -use xaynet_core::PetError; - pub use self::{ error::StateError, idle::Idle, @@ -20,20 +18,19 @@ pub use self::{ update::Update, }; -use crate::{ - state_machine::{ - coordinator::CoordinatorState, - events::EventPublisher, - requests::{RequestReceiver, ResponseSender, StateMachineRequest}, - StateMachine, - }, - utils::Request, +use crate::state_machine::{ + coordinator::CoordinatorState, + events::EventPublisher, + requests::{RequestReceiver, ResponseSender, StateMachineRequest}, + StateMachine, + StateMachineError, }; #[cfg(feature = "metrics")] use crate::{metrics, metrics::MetricsSender}; use futures::StreamExt; +use tracing::Span; use tracing_futures::Instrument; /// Name of the current phase @@ -64,7 +61,7 @@ pub trait Phase { /// A trait that must be implemented by a state to handle a request. pub trait Handler { /// Handles a request. - fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), PetError>; + fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), StateMachineError>; } /// I/O interfaces. @@ -158,10 +155,9 @@ where /// Processes the next available request. async fn process_single(&mut self) -> Result<(), StateError> { - let (req, resp_tx) = self.next_request().await?; - let span = req.span(); + let (req, span, resp_tx) = self.next_request().await?; let _span_guard = span.enter(); - let res = self.handle_request(req.into_inner()); + let res = self.handle_request(req); if res.is_err() { metrics!( @@ -227,11 +223,10 @@ where fn purge_outdated_requests(&mut self) -> Result<(), StateError> { loop { match self.try_next_request()? { - Some((req, resp_tx)) => { - let span = req.span(); + Some((_req, span, resp_tx)) => { let _span_guard = span.enter(); info!("rejecting request"); - let _ = resp_tx.send(Err(PetError::InvalidMessage.into())); + let _ = resp_tx.send(Err(StateMachineError::MessageRejected)); metrics!( self.shared.io.metrics_tx, @@ -255,7 +250,7 @@ impl PhaseState { /// Returns [`StateError::ChannelError`] when all sender halves have been dropped. async fn next_request( &mut self, - ) -> Result<(Request, ResponseSender), StateError> { + ) -> Result<(StateMachineRequest, Span, ResponseSender), StateError> { debug!("waiting for the next incoming request"); self.shared.io.request_rx.next().await.ok_or_else(|| { error!("request receiver broken: senders have been dropped"); @@ -265,7 +260,7 @@ impl PhaseState { fn try_next_request( &mut self, - ) -> Result, ResponseSender)>, StateError> { + ) -> Result, StateError> { match self.shared.io.request_rx.try_recv() { Ok(item) => Ok(Some(item)), Err(tokio::sync::mpsc::error::TryRecvError::Empty) => { diff --git a/rust/xaynet-server/src/state_machine/phases/sum.rs b/rust/xaynet-server/src/state_machine/phases/sum.rs index 7c74dec45..80a655301 100644 --- a/rust/xaynet-server/src/state_machine/phases/sum.rs +++ b/rust/xaynet-server/src/state_machine/phases/sum.rs @@ -1,12 +1,13 @@ use std::sync::Arc; -use xaynet_core::{LocalSeedDict, PetError, SeedDict, SumDict}; +use xaynet_core::{LocalSeedDict, SeedDict, SumDict}; use crate::state_machine::{ events::DictionaryUpdate, phases::{Handler, Phase, PhaseName, PhaseState, Shared, StateError, Update}, requests::{StateMachineRequest, SumRequest}, StateMachine, + StateMachineError, }; #[cfg(feature = "metrics")] @@ -35,17 +36,18 @@ impl Handler for PhaseState { /// /// If the request is a [`StateMachineRequest::Update`] or /// [`StateMachineRequest::Sum2`] request, the request sender will receive a - /// [`PetError::InvalidMessage`]. - fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), PetError> { + /// [`StateMachineError::MessageRejected`]. + fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), StateMachineError> { match req { StateMachineRequest::Sum(sum_req) => { metrics!( self.shared.io.metrics_tx, metrics::message::sum::increment(self.shared.state.round_id, Self::NAME) ); - self.handle_sum(sum_req) + self.handle_sum(sum_req); + Ok(()) } - _ => Err(PetError::InvalidMessage), + _ => Err(StateMachineError::MessageRejected), } } } @@ -132,13 +134,12 @@ impl PhaseState { } /// Handles a sum request. - fn handle_sum(&mut self, req: SumRequest) -> Result<(), PetError> { + fn handle_sum(&mut self, req: SumRequest) { let SumRequest { participant_pk, ephm_pk, } = req; self.inner.sum_dict.insert(participant_pk, ephm_pk); - Ok(()) } /// Freezes the sum dictionary. diff --git a/rust/xaynet-server/src/state_machine/phases/sum2.rs b/rust/xaynet-server/src/state_machine/phases/sum2.rs index 71d06a198..59e50f484 100644 --- a/rust/xaynet-server/src/state_machine/phases/sum2.rs +++ b/rust/xaynet-server/src/state_machine/phases/sum2.rs @@ -1,6 +1,5 @@ use xaynet_core::{ mask::{Aggregation, MaskObject}, - PetError, SumDict, SumParticipantPublicKey, }; @@ -10,6 +9,7 @@ use crate::state_machine::{ phases::{Handler, Phase, PhaseName, PhaseState, Shared, StateError, Unmask}, requests::{StateMachineRequest, Sum2Request}, StateMachine, + StateMachineError, }; #[cfg(feature = "metrics")] @@ -125,8 +125,8 @@ impl Handler for PhaseState { /// /// If the request is a [`StateMachineRequest::Sum`] or /// [`StateMachineRequest::Update`] request, the request sender - /// will receive a [`PetError::InvalidMessage`]. - fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), PetError> { + /// will receive a [`StateMachineError::MessageRejected`]. + fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), StateMachineError> { match req { StateMachineRequest::Sum2(sum2_req) => { metrics!( @@ -135,7 +135,7 @@ impl Handler for PhaseState { ); self.handle_sum2(sum2_req) } - _ => Err(PetError::InvalidMessage), + _ => Err(StateMachineError::MessageRejected), } } } @@ -163,7 +163,7 @@ impl PhaseState { /// Handles a sum2 request. /// If the handling of the sum2 message fails, an error is returned to the request sender. - fn handle_sum2(&mut self, req: Sum2Request) -> Result<(), PetError> { + fn handle_sum2(&mut self, req: Sum2Request) -> Result<(), StateMachineError> { let Sum2Request { participant_pk, model_mask, @@ -181,11 +181,11 @@ impl PhaseState { pk: &SumParticipantPublicKey, model_mask: MaskObject, scalar_mask: MaskObject, - ) -> Result<(), PetError> { + ) -> Result<(), StateMachineError> { // We remove the participant key here to make sure a participant // cannot submit a mask multiple times if self.inner.sum_dict.remove(pk).is_none() { - return Err(PetError::InvalidMessage); + return Err(StateMachineError::MessageRejected); } if let Some(count) = self.inner.model_mask_dict.get_mut(&model_mask) { diff --git a/rust/xaynet-server/src/state_machine/phases/update.rs b/rust/xaynet-server/src/state_machine/phases/update.rs index 3c96ee2d6..721e00967 100644 --- a/rust/xaynet-server/src/state_machine/phases/update.rs +++ b/rust/xaynet-server/src/state_machine/phases/update.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use xaynet_core::{ mask::{Aggregation, MaskObject}, LocalSeedDict, - PetError, SeedDict, SumDict, UpdateParticipantPublicKey, @@ -14,6 +13,7 @@ use crate::state_machine::{ phases::{Handler, Phase, PhaseName, PhaseState, Shared, StateError, Sum2}, requests::{StateMachineRequest, UpdateRequest}, StateMachine, + StateMachineError, }; #[cfg(feature = "metrics")] @@ -127,8 +127,8 @@ impl Handler for PhaseState { /// /// If the request is a [`StateMachineRequest::Sum`] or /// [`StateMachineRequest::Sum2`] request, the request sender will - /// receive a [`PetError::InvalidMessage`]. - fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), PetError> { + /// receive a [`StateMachineError::MessageRejected`]. + fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), StateMachineError> { match req { StateMachineRequest::Update(update_req) => { metrics!( @@ -137,7 +137,7 @@ impl Handler for PhaseState { ); self.handle_update(update_req) } - _ => Err(PetError::InvalidMessage), + _ => Err(StateMachineError::MessageRejected), } } } @@ -160,7 +160,7 @@ impl PhaseState { /// Handles an update request. /// If the handling of the update message fails, an error is returned to the request sender. - fn handle_update(&mut self, req: UpdateRequest) -> Result<(), PetError> { + fn handle_update(&mut self, req: UpdateRequest) -> Result<(), StateMachineError> { let UpdateRequest { participant_pk, local_seed_dict, @@ -182,7 +182,7 @@ impl PhaseState { local_seed_dict: &LocalSeedDict, masked_model: MaskObject, masked_scalar: MaskObject, - ) -> Result<(), PetError> { + ) -> Result<(), StateMachineError> { // Check if aggregation can be performed. It is important to // do that _before_ updating the seed dictionary, because we // don't want to add the local seed dict if the corresponding @@ -193,7 +193,7 @@ impl PhaseState { .validate_aggregation(&masked_model) .map_err(|e| { warn!("model aggregation error: {}", e); - PetError::InvalidMessage + StateMachineError::AggregationFailed })?; debug!("checking whether the masked scalar can be aggregated"); @@ -202,7 +202,7 @@ impl PhaseState { .validate_aggregation(&masked_scalar) .map_err(|e| { warn!("scalar aggregation error: {}", e); - PetError::InvalidMessage + StateMachineError::AggregationFailed })?; // Try to update local seed dict first. If this fail, we do @@ -228,7 +228,7 @@ impl PhaseState { &mut self, pk: &UpdateParticipantPublicKey, local_seed_dict: &LocalSeedDict, - ) -> Result<(), PetError> { + ) -> Result<(), StateMachineError> { if local_seed_dict.keys().len() == self.inner.frozen_sum_dict.keys().len() && local_seed_dict .keys() @@ -245,13 +245,16 @@ impl PhaseState { self.inner .seed_dict .get_mut(sum_pk) - .ok_or(PetError::InvalidMessage)? + // FIXME: the error is not very adapted here, it's + // more an internal error. Could we not unwrap + // here per the checks above? + .ok_or(StateMachineError::InvalidLocalSeedDict)? .insert(*pk, seed.clone()); } Ok(()) } else { warn!("invalid seed dictionary"); - Err(PetError::InvalidMessage) + Err(StateMachineError::InvalidLocalSeedDict) } } diff --git a/rust/xaynet-server/src/state_machine/requests.rs b/rust/xaynet-server/src/state_machine/requests.rs index 2278c976b..1016ad80a 100644 --- a/rust/xaynet-server/src/state_machine/requests.rs +++ b/rust/xaynet-server/src/state_machine/requests.rs @@ -27,10 +27,7 @@ use xaynet_core::{ #[error("the RequestSender cannot be used because the state machine shut down")] pub struct StateMachineShutdown; -use crate::{ - state_machine::{StateMachineError, StateMachineResult}, - utils::{Request, Traceable}, -}; +use crate::state_machine::{StateMachineError, StateMachineResult}; /// A sum request. #[derive(Debug)] @@ -75,17 +72,6 @@ pub enum StateMachineRequest { Sum2(Sum2Request), } -impl Traceable for StateMachineRequest { - fn make_span(&self) -> Span { - let request_type = match self { - Self::Sum(_) => "sum", - Self::Update(_) => "update", - Self::Sum2(_) => "sum2", - }; - error_span!("StateMachineRequest", request_type = request_type) - } -} - impl From for StateMachineRequest { fn from(message: Message) -> Self { let participant_pk = message.participant_pk; @@ -122,7 +108,7 @@ impl From for StateMachineRequest { /// /// [`StateMachine`]: crate::state_machine #[derive(Clone, From, Debug)] -pub struct RequestSender(mpsc::UnboundedSender<(Request, ResponseSender)>); +pub struct RequestSender(mpsc::UnboundedSender<(StateMachineRequest, Span, ResponseSender)>); impl RequestSender { /// Sends a request to the [`StateMachine`]. @@ -132,12 +118,9 @@ impl RequestSender { /// closed as a result. /// /// [`StateMachine`]: crate::state_machine - pub async fn request + Traceable>( - &self, - req: Request, - ) -> StateMachineResult { + pub async fn request(&self, req: StateMachineRequest, span: Span) -> StateMachineResult { let (resp_tx, resp_rx) = oneshot::channel::(); - self.0.send((req.map(Into::into), resp_tx)).map_err(|_| { + self.0.send((req, span, resp_tx)).map_err(|_| { warn!("failed to send request to the state machine: state machine is shutting down"); StateMachineError::InternalError })?; @@ -159,10 +142,10 @@ pub(in crate::state_machine) type ResponseSender = oneshot::Sender, ResponseSender)>); +pub struct RequestReceiver(mpsc::UnboundedReceiver<(StateMachineRequest, Span, ResponseSender)>); impl Stream for RequestReceiver { - type Item = (Request, ResponseSender); + type Item = (StateMachineRequest, Span, ResponseSender); fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { trace!("RequestReceiver: polling"); @@ -174,7 +157,7 @@ impl RequestReceiver { /// Creates a new `Request` channel and returns the [`RequestReceiver`] as well as the /// [`RequestSender`] half. pub fn new() -> (Self, RequestSender) { - let (tx, rx) = mpsc::unbounded_channel::<(Request, ResponseSender)>(); + let (tx, rx) = mpsc::unbounded_channel::<(StateMachineRequest, Span, ResponseSender)>(); let receiver = RequestReceiver::from(rx); let handle = RequestSender::from(tx); (receiver, handle) @@ -192,7 +175,7 @@ impl RequestReceiver { /// See [the `tokio` documentation][receive] for more information. /// /// [receive]: https://docs.rs/tokio/0.2.21/tokio/sync/mpsc/struct.UnboundedReceiver.html#method.recv - pub async fn recv(&mut self) -> Option<(Request, ResponseSender)> { + pub async fn recv(&mut self) -> Option<(StateMachineRequest, Span, ResponseSender)> { self.0.recv().await } @@ -202,10 +185,8 @@ impl RequestReceiver { /// [try_receive]: https://docs.rs/tokio/0.2.21/tokio/sync/mpsc/struct.UnboundedReceiver.html#method.try_recv pub fn try_recv( &mut self, - ) -> Result< - (Request, ResponseSender), - tokio::sync::mpsc::error::TryRecvError, - > { + ) -> Result<(StateMachineRequest, Span, ResponseSender), tokio::sync::mpsc::error::TryRecvError> + { self.0.try_recv() } } diff --git a/rust/xaynet-server/src/state_machine/tests/impls.rs b/rust/xaynet-server/src/state_machine/tests/impls.rs index f316c1e8d..2849e54cd 100644 --- a/rust/xaynet-server/src/state_machine/tests/impls.rs +++ b/rust/xaynet-server/src/state_machine/tests/impls.rs @@ -1,19 +1,17 @@ +use tracing::Span; use xaynet_core::message::Message; -use crate::{ - state_machine::{ - events::{DictionaryUpdate, MaskLengthUpdate}, - phases::{self, PhaseState}, - requests::RequestSender, - StateMachine, - StateMachineResult, - }, - utils::Request, +use crate::state_machine::{ + events::{DictionaryUpdate, MaskLengthUpdate}, + phases::{self, PhaseState}, + requests::RequestSender, + StateMachine, + StateMachineResult, }; impl RequestSender { pub async fn msg(&self, msg: &Message) -> StateMachineResult { - self.request(Request::new(msg.clone())).await + self.request(msg.clone().into(), Span::none()).await } } diff --git a/rust/xaynet-server/src/utils/mod.rs b/rust/xaynet-server/src/utils/mod.rs deleted file mode 100644 index 0cab0442e..000000000 --- a/rust/xaynet-server/src/utils/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod request; -pub use self::request::{Request, Traceable}; diff --git a/rust/xaynet-server/src/utils/request.rs b/rust/xaynet-server/src/utils/request.rs deleted file mode 100644 index 70e8a5eea..000000000 --- a/rust/xaynet-server/src/utils/request.rs +++ /dev/null @@ -1,141 +0,0 @@ -use std::time::{Duration, SystemTime, UNIX_EPOCH}; - -use tracing::Span; -use uuid::Uuid; - -/// A type that can be associated to a span, making it traceable. -pub trait Traceable { - fn make_span(&self) -> Span; -} - -impl<'a, T> Traceable for &'a T -where - T: Traceable, -{ - fn make_span(&self) -> Span { - ::make_span(*self) - } -} - -// NOTE: currently `id` and `timestamp` are immutable. `span` is -// mutable, but when it is changed other copies of the RequestMetadata -// are not affected. In the future, we can have shared mutable fields -// if we want to, by adding an Arc> field. -#[derive(Debug, Clone, PartialEq)] -pub struct RequestMetadata { - /// A random UUID associated to the request - id: Uuid, - /// Time the request was created - timestamp: SystemTime, - /// Current span associated to this request - span: Span, -} - -impl RequestMetadata { - fn new() -> Self { - let id = Uuid::new_v4(); - let timestamp = SystemTime::now(); - let span = error_span!("request", id = %id, timestamp = %timestamp.duration_since(UNIX_EPOCH).unwrap_or_else(|_| Duration::new(0, 0)).as_millis()); - Self { - id, - timestamp, - span, - } - } - - /// Time elapsed since this request was created, in milli seconds - fn elapsed(&self) -> u128 { - SystemTime::now() - .duration_since(self.timestamp) - .unwrap_or_else(|_| Duration::new(0, 0)) - .as_millis() - } - - /// Return the span associated with the metadata - fn span(&self) -> Span { - self.span.clone() - } -} - -#[derive(Debug, Clone, PartialEq)] -/// A request that can be handled by a service -pub struct Request { - /// Content of the request - inner: T, - /// Metadata associated to this request - metadata: RequestMetadata, -} - -impl Request -where - T: Traceable, -{ - /// Create a new request - pub fn new(t: T) -> Self { - Self { - inner: t, - metadata: RequestMetadata::new(), - } - } - - /// Create a [`Request`] with the given metadata and inner - /// request value. - pub fn from_parts(metadata: RequestMetadata, inner: T) -> Self { - Self { metadata, inner } - } - - /// Return the metadata attached to this [`Request`] - pub fn metadata(&self) -> RequestMetadata { - self.metadata.clone() - } - - /// Turn this `Request` into a `Request`. A new span is - /// created with `::make_span` and attached to the - /// request. - pub fn map(self, f: F) -> Request - where - F: ::std::ops::FnOnce(T) -> U, - U: Traceable, - { - let Request { - mut metadata, - inner, - } = self; - let mapped = f(inner); - - // self.span() is the parent of the span associated to the - // inner type - let new_span = metadata.span().in_scope(|| mapped.make_span()); - metadata.span = new_span; - - Request { - metadata, - inner: mapped, - } - } - - /// Span associated with this request - pub fn span(&self) -> Span { - self.metadata.span() - } - - /// Time elapsed since this request was created, in milli seconds - pub fn elapsed(&self) -> u128 { - self.metadata.elapsed() - } - - /// Get a reference to the request's inner value - pub fn inner(&self) -> &T { - &self.inner - } - - /// Get a mutable reference to the request's inner value - pub fn inner_mut(&mut self) -> &mut T { - &mut self.inner - } - - /// Consume this request and return its inner value - pub fn into_inner(self) -> T { - self.inner - } -} diff --git a/rust/xaynet-server/src/vendor/mod.rs b/rust/xaynet-server/src/vendor/mod.rs deleted file mode 100644 index a0769e69b..000000000 --- a/rust/xaynet-server/src/vendor/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod tracing_tower; diff --git a/rust/xaynet-server/src/vendor/tracing_tower/mod.rs b/rust/xaynet-server/src/vendor/tracing_tower/mod.rs deleted file mode 100644 index 015b91601..000000000 --- a/rust/xaynet-server/src/vendor/tracing_tower/mod.rs +++ /dev/null @@ -1,172 +0,0 @@ -//! This module contains the bits of -//! https://github.com/tokio-rs/tracing/blob/master/tracing-tower that -//! we're using. We copied them here because without a release of -//! tracing-tower, we cannot publish to crates.io ourself. - -// Copyright (c) 2019 Tokio Contributors -// -// Permission is hereby granted, free of charge, to any person -// obtaining a copy of this software and associated documentation -// files (the "Software"), to deal in the Software without -// restriction, including without limitation the rights to use, copy, -// modify, merge, publish, distribute, sublicense, and/or sell copies -// of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS -// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN -// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -use std::{ - marker::PhantomData, - task::{Context, Poll}, -}; -use tracing_futures::Instrument; - -pub trait GetSpan: sealed::Sealed { - fn span_for(&self, target: &T) -> tracing::Span; -} - -impl sealed::Sealed for F where F: Fn(&T) -> tracing::Span {} - -impl GetSpan for F -where - F: Fn(&T) -> tracing::Span, -{ - #[inline] - fn span_for(&self, target: &T) -> tracing::Span { - (self)(target) - } -} - -impl sealed::Sealed for tracing::Span {} - -impl GetSpan for tracing::Span { - #[inline] - fn span_for(&self, _: &T) -> tracing::Span { - self.clone() - } -} - -mod sealed { - pub trait Sealed {} -} - -#[derive(Debug)] -pub struct Service tracing::Span> -where - S: tower::Service, - G: GetSpan, -{ - get_span: G, - inner: S, - _p: PhantomData, -} - -pub use self::layer::*; - -mod layer { - use super::*; - - #[derive(Debug)] - pub struct Layer tracing::Span> - where - G: GetSpan + Clone, - { - get_span: G, - _p: PhantomData, - } - - pub fn layer(get_span: G) -> Layer - where - G: GetSpan + Clone, - { - Layer { - get_span, - _p: PhantomData, - } - } - - // === impl Layer === - impl tower::layer::Layer for Layer - where - S: tower::Service, - G: GetSpan + Clone, - { - type Service = Service; - - fn layer(&self, service: S) -> Self::Service { - Service::new(service, self.get_span.clone()) - } - } - - impl Clone for Layer - where - G: GetSpan + Clone, - { - fn clone(&self) -> Self { - Self { - get_span: self.get_span.clone(), - _p: PhantomData, - } - } - } -} - -// === impl Service === - -impl tower::Service for Service -where - S: tower::Service, - G: GetSpan + Clone, -{ - type Response = S::Response; - type Error = S::Error; - type Future = tracing_futures::Instrumented; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, request: R) -> Self::Future { - let span = self.get_span.span_for(&request); - let _enter = span.enter(); - self.inner.call(request).instrument(span.clone()) - } -} - -impl Clone for Service -where - S: tower::Service + Clone, - G: GetSpan + Clone, -{ - fn clone(&self) -> Self { - Service { - get_span: self.get_span.clone(), - inner: self.inner.clone(), - _p: PhantomData, - } - } -} - -impl Service -where - S: tower::Service, - G: GetSpan + Clone, -{ - pub fn new(inner: S, get_span: G) -> Self { - Service { - get_span, - inner, - _p: PhantomData, - } - } -}