diff --git a/comms/core/src/builder/comms_node.rs b/comms/core/src/builder/comms_node.rs index 1e36b3e890..649497c2c7 100644 --- a/comms/core/src/builder/comms_node.rs +++ b/comms/core/src/builder/comms_node.rs @@ -43,7 +43,6 @@ use crate::{ }, connectivity::{ConnectivityEventRx, ConnectivityManager, ConnectivityRequest, ConnectivityRequester}, multiaddr::Multiaddr, - noise::NoiseConfig, peer_manager::{NodeIdentity, PeerManager}, protocol::{ ProtocolExtension, @@ -188,12 +187,9 @@ impl UnspawnedCommsNode { //---------------------------------- Connection Manager --------------------------------------------// - let noise_config = NoiseConfig::new(node_identity.clone()); - let mut connection_manager = ConnectionManager::new( connection_manager_config.clone(), transport.clone(), - noise_config, dial_backoff, connection_manager_request_rx, node_identity.clone(), diff --git a/comms/core/src/connection_manager/dialer.rs b/comms/core/src/connection_manager/dialer.rs index 0276b3931d..823b386a7f 100644 --- a/comms/core/src/connection_manager/dialer.rs +++ b/comms/core/src/connection_manager/dialer.rs @@ -20,11 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::{ - collections::HashMap, - sync::Arc, - time::{Duration, Instant}, -}; +use std::{collections::HashMap, sync::Arc, time::Instant}; use futures::{ future, @@ -593,12 +589,9 @@ where .await .map_err(|_| ConnectionManagerError::WireFormatSendFailed)?; - let noise_socket = time::timeout( - Duration::from_secs(40), - noise_config.upgrade_socket(socket, ConnectionDirection::Outbound), - ) - .await - .map_err(|_| ConnectionManagerError::NoiseProtocolTimeout)??; + let noise_socket = noise_config + .upgrade_socket(socket, ConnectionDirection::Outbound) + .await?; let noise_upgrade_time = timer.elapsed(); debug!( diff --git a/comms/core/src/connection_manager/error.rs b/comms/core/src/connection_manager/error.rs index ff0585520c..80cddff827 100644 --- a/comms/core/src/connection_manager/error.rs +++ b/comms/core/src/connection_manager/error.rs @@ -78,8 +78,6 @@ pub enum ConnectionManagerError { InvalidMultiaddr(String), #[error("Failed to send wire format byte")] WireFormatSendFailed, - #[error("Noise protocol handshake timed out")] - NoiseProtocolTimeout, #[error("Listener oneshot cancelled")] ListenerOneshotCancelled, #[error("Peer validation error: {0}")] diff --git a/comms/core/src/connection_manager/listener.rs b/comms/core/src/connection_manager/listener.rs index 8699a11fc7..2266fd3027 100644 --- a/comms/core/src/connection_manager/listener.rs +++ b/comms/core/src/connection_manager/listener.rs @@ -351,12 +351,7 @@ where ); let timer = Instant::now(); - let mut noise_socket = time::timeout( - Duration::from_secs(30), - noise_config.upgrade_socket(socket, CONNECTION_DIRECTION), - ) - .await - .map_err(|_| ConnectionManagerError::NoiseProtocolTimeout)??; + let mut noise_socket = noise_config.upgrade_socket(socket, CONNECTION_DIRECTION).await?; let authenticated_public_key = noise_socket .get_remote_public_key() diff --git a/comms/core/src/connection_manager/manager.rs b/comms/core/src/connection_manager/manager.rs index 3ca4454804..78493205de 100644 --- a/comms/core/src/connection_manager/manager.rs +++ b/comms/core/src/connection_manager/manager.rs @@ -109,8 +109,12 @@ pub struct ConnectionManagerConfig { pub max_simultaneous_inbound_connects: usize, /// Version information for this node pub network_info: NodeNetworkInfo, - /// The maximum time to wait for the first byte before closing the connection. Default: 45s + /// The maximum time to wait for the first byte before closing the connection. Default: 3s pub time_to_first_byte: Duration, + /// The maximum time to wait for a noise protocol handshake message before timing out. For 1.5 RTT XX handshake, + /// the responder will wait 2 x this value (1 per receive) before timing out. + /// Default: 3s + pub noise_handshake_recv_timeout: Duration, /// The number of liveness check sessions to allow. Default: 0 pub liveness_max_sessions: usize, /// CIDR blocks that allowlist liveness checks. Default: Localhost only (127.0.0.1/32) @@ -120,6 +124,7 @@ pub struct ConnectionManagerConfig { /// If set, an additional TCP-only p2p listener will be started. This is useful for local wallet connections. /// Default: None (disabled) pub auxiliary_tcp_listener_address: Option, + /// Peer validation configuration. See [PeerValidatorConfig] pub peer_validation_config: PeerValidatorConfig, } @@ -136,11 +141,12 @@ impl Default for ConnectionManagerConfig { max_simultaneous_inbound_connects: 100, network_info: Default::default(), liveness_max_sessions: 1, - time_to_first_byte: Duration::from_secs(45), + time_to_first_byte: Duration::from_secs(3), liveness_cidr_allowlist: vec![cidr::AnyIpCidr::V4("127.0.0.1/32".parse().unwrap())], liveness_self_check_interval: None, auxiliary_tcp_listener_address: None, peer_validation_config: PeerValidatorConfig::default(), + noise_handshake_recv_timeout: Duration::from_secs(3), } } } @@ -191,7 +197,6 @@ where pub(crate) fn new( mut config: ConnectionManagerConfig, transport: TTransport, - noise_config: NoiseConfig, backoff: TBackoff, request_rx: mpsc::Receiver, node_identity: Arc, @@ -202,6 +207,9 @@ where let (internal_event_tx, internal_event_rx) = mpsc::channel(EVENT_CHANNEL_SIZE); let (dialer_tx, dialer_rx) = mpsc::channel(DIALER_REQUEST_CHANNEL_SIZE); + let noise_config = + NoiseConfig::new(node_identity.clone()).with_recv_timeout(config.noise_handshake_recv_timeout); + let listener = PeerListener::new( config.clone(), config.listener_address.clone(), diff --git a/comms/core/src/connection_manager/tests/manager.rs b/comms/core/src/connection_manager/tests/manager.rs index 292b1f4925..1abde14e21 100644 --- a/comms/core/src/connection_manager/tests/manager.rs +++ b/comms/core/src/connection_manager/tests/manager.rs @@ -34,14 +34,12 @@ use tokio::{ use crate::{ backoff::ConstantBackoff, connection_manager::{ - error::ConnectionManagerError, - manager::ConnectionManagerEvent, ConnectionManager, + ConnectionManagerError, + ConnectionManagerEvent, ConnectionManagerRequester, - PeerConnectionError, }, net_address::{MultiaddressesWithStats, PeerAddressSource}, - noise::NoiseConfig, peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags, PeerManagerError}, protocol::{ProtocolEvent, ProtocolId, Protocols}, test_utils::{ @@ -51,13 +49,13 @@ use crate::{ test_node::{build_connection_manager, TestNodeConfig}, }, transports::{MemoryTransport, TcpTransport}, + PeerConnectionError, }; #[tokio::test] async fn connect_to_nonexistent_peer() { let rt_handle = Handle::current(); let node_identity = build_node_identity(PeerFeatures::empty()); - let noise_config = NoiseConfig::new(node_identity.clone()); let (request_tx, request_rx) = mpsc::channel(1); let (event_tx, _) = broadcast::channel(1); let mut requester = ConnectionManagerRequester::new(request_tx, event_tx.clone()); @@ -68,7 +66,6 @@ async fn connect_to_nonexistent_peer() { let connection_manager = ConnectionManager::new( Default::default(), MemoryTransport, - noise_config, ConstantBackoff::new(Duration::from_secs(1)), request_rx, node_identity, @@ -80,8 +77,7 @@ async fn connect_to_nonexistent_peer() { rt_handle.spawn(connection_manager.run()); let err = requester.dial_peer(NodeId::default()).await.unwrap_err(); - unpack_enum!(ConnectionManagerError::PeerManagerError(err) = err); - unpack_enum!(PeerManagerError::PeerNotFoundError = err); + unpack_enum!(ConnectionManagerError::PeerManagerError(PeerManagerError::PeerNotFoundError) = err); shutdown.trigger(); } diff --git a/comms/core/src/noise/config.rs b/comms/core/src/noise/config.rs index c641bf8f7f..4eec198e3b 100644 --- a/comms/core/src/noise/config.rs +++ b/comms/core/src/noise/config.rs @@ -22,7 +22,7 @@ // This file is heavily influenced by the Libra Noise protocol implementation. -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use log::*; use snow::{self, params::NoiseParams}; @@ -48,6 +48,7 @@ pub(super) const NOISE_PARAMETERS: &str = "Noise_XX_25519_ChaChaPoly_BLAKE2b"; pub struct NoiseConfig { node_identity: Arc, parameters: NoiseParams, + recv_timeout: Duration, } impl NoiseConfig { @@ -57,9 +58,16 @@ impl NoiseConfig { Self { node_identity, parameters, + recv_timeout: Duration::from_secs(1), } } + /// Sets a custom receive timeout when waiting for handshake responses. + pub fn with_recv_timeout(mut self, recv_timeout: Duration) -> Self { + self.recv_timeout = recv_timeout; + self + } + /// Upgrades the given socket to using the noise protocol. The upgraded socket and the peer's static key /// is returned. #[tracing::instrument(name = "noise::upgrade_socket", skip(self, socket))] @@ -90,7 +98,7 @@ impl NoiseConfig { } }; - let handshake = Handshake::new(socket, handshake_state); + let handshake = Handshake::new(socket, handshake_state, self.recv_timeout); let socket = handshake .perform_handshake() .await diff --git a/comms/core/src/noise/socket.rs b/comms/core/src/noise/socket.rs index c59a4a367b..2554cb53d3 100644 --- a/comms/core/src/noise/socket.rs +++ b/comms/core/src/noise/socket.rs @@ -32,13 +32,17 @@ use std::{ io, pin::Pin, task::{Context, Poll}, + time::Duration, }; use futures::ready; use log::*; use snow::{error::StateProblem, HandshakeState, TransportState}; use tari_utilities::ByteArray; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}, + time, +}; use crate::types::CommsPublicKey; @@ -515,12 +519,14 @@ where TSocket: AsyncWrite + Unpin pub struct Handshake { socket: NoiseSocket, + recv_timeout: Duration, } impl Handshake { - pub fn new(socket: TSocket, state: HandshakeState) -> Self { + pub fn new(socket: TSocket, state: HandshakeState, recv_timeout: Duration) -> Self { Self { socket: NoiseSocket::new(socket, state.into()), + recv_timeout, } } } @@ -581,7 +587,9 @@ where TSocket: AsyncRead + AsyncWrite + Unpin } async fn receive(&mut self) -> io::Result { - self.socket.read(&mut []).await + time::timeout(self.recv_timeout, self.socket.read(&mut [])) + .await + .map_err(|_| io::Error::from(io::ErrorKind::TimedOut))? } fn build(self) -> io::Result> { @@ -683,8 +691,14 @@ mod test { ); Ok(( - (dialer_keypair, Handshake { socket: dialer }), - (listener_keypair, Handshake { socket: listener }), + (dialer_keypair, Handshake { + socket: dialer, + recv_timeout: Duration::from_secs(1), + }), + (listener_keypair, Handshake { + socket: listener, + recv_timeout: Duration::from_secs(1), + }), )) } diff --git a/comms/core/src/test_utils/test_node.rs b/comms/core/src/test_utils/test_node.rs index d1b4f5dcf3..9e1fd4b584 100644 --- a/comms/core/src/test_utils/test_node.rs +++ b/comms/core/src/test_utils/test_node.rs @@ -33,7 +33,6 @@ use crate::{ backoff::ConstantBackoff, connection_manager::{ConnectionManager, ConnectionManagerConfig, ConnectionManagerRequester}, multiplexing::Substream, - noise::NoiseConfig, peer_manager::{NodeIdentity, PeerFeatures, PeerManager}, peer_validator::PeerValidatorConfig, protocol::Protocols, @@ -81,7 +80,6 @@ where TTransport: Transport + Unpin + Send + Sync + Clone + 'static, TTransport::Output: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { - let noise_config = NoiseConfig::new(config.node_identity.clone()); let (request_tx, request_rx) = mpsc::channel(10); let (event_tx, _) = broadcast::channel(100); @@ -90,7 +88,6 @@ where let mut connection_manager = ConnectionManager::new( config.connection_manager_config, transport, - noise_config, ConstantBackoff::new(config.dial_backoff_duration), request_rx, config.node_identity,