Skip to content

Commit

Permalink
Remove type alias impl trait unstable feature requirement from DHT
Browse files Browse the repository at this point in the history
Replaces the `type_alias_impl_trait` feature requirement with boxed
Futures on all services. Removes/simplifies some redundant trait bounds.

SAF handler task had to be implemented to be a less concurrent because
of the additional trait bounds required for boxing.
  • Loading branch information
sdbondi committed Jun 30, 2021
1 parent ef94134 commit 3fb3d20
Show file tree
Hide file tree
Showing 25 changed files with 365 additions and 348 deletions.
27 changes: 15 additions & 12 deletions base_layer/p2p/src/initialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::{
transport::{TorConfig, TransportType},
};
use fs2::FileExt;
use futures::{channel::mpsc, future, Sink};
use futures::{channel::mpsc, future, future::BoxFuture, Sink};
use log::*;
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use std::{
Expand Down Expand Up @@ -342,12 +342,16 @@ where
// Create outbound channel
let (outbound_tx, outbound_rx) = mpsc::channel(config.outbound_buffer_size);

let node_identity = comms.node_identity();
let peer_manager = comms.peer_manager();
let connectivity = comms.connectivity();
let shutdown_signal = comms.shutdown_signal();
let dht = DhtBuilder::new(
comms.node_identity(),
comms.peer_manager(),
node_identity.clone(),
peer_manager,
outbound_tx,
comms.connectivity(),
comms.shutdown_signal(),
connectivity,
shutdown_signal,
)
.with_config(config.dht.clone())
.build()
Expand All @@ -356,10 +360,7 @@ where
let dht_outbound_layer = dht.outbound_middleware_layer();

// DHT RPC service is only available for communication nodes
if comms
.node_identity()
.has_peer_features(PeerFeatures::COMMUNICATION_NODE)
{
if node_identity.has_peer_features(PeerFeatures::COMMUNICATION_NODE) {
comms = comms.add_rpc_server(RpcServer::new().add_service(dht.rpc_service()));
}

Expand Down Expand Up @@ -542,18 +543,20 @@ impl ServiceInitializer for P2pInitializer {
let (comms, dht) = configure_comms_and_dht(builder, &config, connector).await?;

let peers = Self::try_parse_seed_peers(&config.peer_seeds)?;
add_all_peers(&comms.peer_manager(), &comms.node_identity(), peers).await?;
let peer_manager = comms.peer_manager();
let node_identity = comms.node_identity();
add_all_peers(&peer_manager, &node_identity, peers).await?;

let peers = Self::try_resolve_dns_seeds(
config.dns_seeds_name_server,
&config.dns_seeds,
config.dns_seeds_use_dnssec,
)
.await?;
add_all_peers(&comms.peer_manager(), &comms.node_identity(), peers).await?;
add_all_peers(&peer_manager, &node_identity, peers).await?;

context.register_handle(comms.connectivity());
context.register_handle(comms.peer_manager());
context.register_handle(peer_manager);
context.register_handle(comms);
context.register_handle(dht);

Expand Down
4 changes: 2 additions & 2 deletions base_layer/service_framework/src/context/handles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl ServiceInitializerContext {

/// Insert a service handle with the given name
pub fn register_handle<H>(&self, handle: H)
where H: Any + Send + Sync {
where H: Any + Send {
self.inner.register(handle);
}

Expand Down Expand Up @@ -160,7 +160,7 @@ impl ServiceHandles {

/// Register a handle
pub fn register<H>(&self, handle: H)
where H: Any + Send + Sync {
where H: Any + Send {
acquire_lock!(self.handles).insert(TypeId::of::<H>(), Box::new(handle));
}

Expand Down
2 changes: 1 addition & 1 deletion comms/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ tokio-macros = "0.2.3"
tempfile = "3.1.0"

[build-dependencies]
tari_common = { version = "^0.8", path="../common"}
tari_common = { version = "^0.8", path="../common", features = ["build"]}

[features]
avx2 = ["tari_crypto/avx2"]
Expand Down
29 changes: 13 additions & 16 deletions comms/dht/examples/memory_net/utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -917,26 +917,23 @@ async fn setup_comms_dht(
.unwrap();

let dht_outbound_layer = dht.outbound_middleware_layer();
let pipeline = pipeline::Builder::new()
.outbound_buffer_size(10)
.with_outbound_pipeline(outbound_rx, |sink| {
ServiceBuilder::new().layer(dht_outbound_layer).service(sink)
})
.max_concurrent_inbound_tasks(10)
.with_inbound_pipeline(
ServiceBuilder::new()
.layer(dht.inbound_middleware_layer())
.service(SinkService::new(inbound_tx)),
)
.build();

let (messaging_events_tx, _) = broadcast::channel(100);

let comms = comms
.add_rpc_server(RpcServer::new().add_service(dht.rpc_service()))
.add_protocol_extension(MessagingProtocolExtension::new(
messaging_events_tx.clone(),
pipeline::Builder::new()
.outbound_buffer_size(10)
.with_outbound_pipeline(outbound_rx, |sink| {
ServiceBuilder::new().layer(dht_outbound_layer).service(sink)
})
.max_concurrent_inbound_tasks(10)
.with_inbound_pipeline(
ServiceBuilder::new()
.layer(dht.inbound_middleware_layer())
.service(SinkService::new(inbound_tx)),
)
.build(),
))
.add_protocol_extension(MessagingProtocolExtension::new(messaging_events_tx.clone(), pipeline))
.spawn_with_transport(MemoryTransport)
.await
.unwrap();
Expand Down
13 changes: 7 additions & 6 deletions comms/dht/src/dedup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

use crate::{actor::DhtRequester, inbound::DhtInboundMessage};
use digest::Input;
use futures::{task::Context, Future};
use futures::{future::BoxFuture, task::Context};
use log::*;
use std::task::Poll;
use tari_comms::{pipeline::PipelineError, types::Challenge};
Expand Down Expand Up @@ -55,21 +55,22 @@ impl<S> DedupMiddleware<S> {
}

impl<S> Service<DhtInboundMessage> for DedupMiddleware<S>
where S: Service<DhtInboundMessage, Response = (), Error = PipelineError> + Clone
where
S: Service<DhtInboundMessage, Response = (), Error = PipelineError> + Clone + Send + 'static,
S::Future: Send,
{
type Error = PipelineError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
type Response = ();

type Future = impl Future<Output = Result<Self::Response, Self::Error>>;

fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, message: DhtInboundMessage) -> Self::Future {
let next_service = self.next_service.clone();
let mut dht_requester = self.dht_requester.clone();
async move {
Box::pin(async move {
let hash = hash_inbound_message(&message);
trace!(
target: LOG_TARGET,
Expand All @@ -96,7 +97,7 @@ where S: Service<DhtInboundMessage, Response = (), Error = PipelineError> + Clon
message.dht_header.message_tag
);
next_service.oneshot(message).await
}
})
}
}

Expand Down
10 changes: 4 additions & 6 deletions comms/dht/src/dht.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,8 @@ impl Dht {
InboundMessage,
Response = (),
Error = PipelineError,
Future = impl Future<Output = Result<(), PipelineError>> + Send,
> + Clone
+ Send,
Future = impl Future<Output = Result<(), PipelineError>>,
> + Clone,
>
where
S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError> + Clone + Send + Sync + 'static,
Expand Down Expand Up @@ -341,9 +340,8 @@ impl Dht {
DhtOutboundRequest,
Response = (),
Error = PipelineError,
Future = impl Future<Output = Result<(), PipelineError>> + Send,
> + Clone
+ Send,
Future = impl Future<Output = Result<(), PipelineError>>,
> + Clone,
>
where
S: Service<OutboundMessage, Response = (), Error = PipelineError> + Clone + Send + 'static,
Expand Down
46 changes: 28 additions & 18 deletions comms/dht/src/inbound/decryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::{
proto::envelope::OriginMac,
DhtConfig,
};
use futures::{task::Context, Future};
use futures::{future::BoxFuture, task::Context};
use log::*;
use prost::Message;
use std::{sync::Arc, task::Poll, time::Duration};
Expand Down Expand Up @@ -123,25 +123,26 @@ impl<S> DecryptionService<S> {
}

impl<S> Service<DhtInboundMessage> for DecryptionService<S>
where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError> + Clone
where
S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError> + Clone + Send + 'static,
S::Future: Send,
{
type Error = PipelineError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
type Response = ();

type Future = impl Future<Output = Result<Self::Response, Self::Error>>;

fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, msg: DhtInboundMessage) -> Self::Future {
Self::handle_message(
Box::pin(Self::handle_message(
self.inner.clone(),
Arc::clone(&self.node_identity),
self.connectivity.clone(),
self.config.ban_duration,
msg,
)
))
}
}

Expand Down Expand Up @@ -416,10 +417,13 @@ mod test {

#[test]
fn decrypt_inbound_success() {
let result = Mutex::new(None);
let service = service_fn(|msg: DecryptedDhtMessage| {
*result.lock().unwrap() = Some(msg);
future::ready(Result::<(), PipelineError>::Ok(()))
let result = Arc::new(Mutex::new(None));
let service = service_fn({
let result = result.clone();
move |msg: DecryptedDhtMessage| {
*result.lock().unwrap() = Some(msg);
future::ready(Result::<(), PipelineError>::Ok(()))
}
});
let node_identity = make_node_identity();
let (connectivity, _) = create_connectivity_mock();
Expand All @@ -441,10 +445,13 @@ mod test {

#[test]
fn decrypt_inbound_fail() {
let result = Mutex::new(None);
let service = service_fn(|msg: DecryptedDhtMessage| {
*result.lock().unwrap() = Some(msg);
future::ready(Result::<(), PipelineError>::Ok(()))
let result = Arc::new(Mutex::new(None));
let service = service_fn({
let result = result.clone();
move |msg: DecryptedDhtMessage| {
*result.lock().unwrap() = Some(msg);
future::ready(Result::<(), PipelineError>::Ok(()))
}
});
let node_identity = make_node_identity();
let (connectivity, _) = create_connectivity_mock();
Expand All @@ -466,10 +473,13 @@ mod test {
async fn decrypt_inbound_fail_destination() {
let (connectivity, mock) = create_connectivity_mock();
mock.spawn();
let result = Mutex::new(None);
let service = service_fn(|msg: DecryptedDhtMessage| {
*result.lock().unwrap() = Some(msg);
future::ready(Result::<(), PipelineError>::Ok(()))
let result = Arc::new(Mutex::new(None));
let service = service_fn({
let result = result.clone();
move |msg: DecryptedDhtMessage| {
*result.lock().unwrap() = Some(msg);
future::ready(Result::<(), PipelineError>::Ok(()))
}
});
let node_identity = make_node_identity();
let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service);
Expand Down
16 changes: 10 additions & 6 deletions comms/dht/src/inbound/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use crate::{inbound::DhtInboundMessage, proto::envelope::DhtEnvelope};
use futures::{task::Context, Future};
use futures::{future::BoxFuture, task::Context};
use log::*;
use prost::Message;
use std::{convert::TryInto, sync::Arc, task::Poll};
Expand Down Expand Up @@ -51,21 +51,22 @@ impl<S> DhtDeserializeMiddleware<S> {
}

impl<S> Service<InboundMessage> for DhtDeserializeMiddleware<S>
where S: Service<DhtInboundMessage, Response = (), Error = PipelineError> + Clone + 'static
where
S: Service<DhtInboundMessage, Response = (), Error = PipelineError> + Clone + Send + 'static,
S::Future: Send,
{
type Error = PipelineError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
type Response = ();

type Future = impl Future<Output = Result<Self::Response, Self::Error>>;

fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, message: InboundMessage) -> Self::Future {
let next_service = self.next_service.clone();
let peer_manager = self.peer_manager.clone();
async move {
Box::pin(async move {
trace!(target: LOG_TARGET, "Deserializing InboundMessage {}", message.tag);

let InboundMessage {
Expand All @@ -92,14 +93,15 @@ where S: Service<DhtInboundMessage, Response = (), Error = PipelineError> + Clon
inbound_msg.dht_header.message_tag
);

let next_service = next_service.ready_oneshot().await?;
next_service.oneshot(inbound_msg).await
},
Err(err) => {
error!(target: LOG_TARGET, "DHT deserialization failed: {}", err);
Err(err.into())
},
}
}
})
}
}

Expand Down Expand Up @@ -127,6 +129,7 @@ mod test {
use crate::{
envelope::DhtMessageFlags,
test_utils::{
assert_send_static_service,
build_peer_manager,
make_comms_inbound_message,
make_dht_envelope,
Expand All @@ -144,6 +147,7 @@ mod test {
peer_manager.add_peer(node_identity.to_peer()).await.unwrap();

let mut deserialize = DeserializeLayer::new(peer_manager).layer(spy.to_service::<PipelineError>());
assert_send_static_service(&deserialize);

let dht_envelope = make_dht_envelope(
&node_identity,
Expand Down
Loading

0 comments on commit 3fb3d20

Please sign in to comment.