Skip to content

Commit

Permalink
refactor(relay): use oneshot's to track requested streams
Browse files Browse the repository at this point in the history
This is much cleaner as it allows us to construct a single `Future` that expresses the entire outbound protocol from stream opening to finish.

Pull-Request: #4900.
  • Loading branch information
thomaseizinger authored Nov 28, 2023
1 parent ee17df9 commit 4d7a535
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 125 deletions.
220 changes: 97 additions & 123 deletions protocols/relay/src/priv_client/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,29 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use crate::client::Connection;
use crate::priv_client::transport;
use crate::priv_client::transport::ToListenerMsg;
use crate::protocol::{self, inbound_stop, outbound_hop};
use crate::{priv_client, proto, HOP_PROTOCOL_NAME, STOP_PROTOCOL_NAME};
use futures::channel::mpsc::Sender;
use futures::channel::{mpsc, oneshot};
use futures::future::FutureExt;
use futures_timer::Delay;
use libp2p_core::multiaddr::Protocol;
use libp2p_core::upgrade::ReadyUpgrade;
use libp2p_core::Multiaddr;
use libp2p_identity::PeerId;
use libp2p_swarm::handler::{
ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound,
};
use libp2p_swarm::handler::{ConnectionEvent, FullyNegotiatedInbound};
use libp2p_swarm::{
ConnectionHandler, ConnectionHandlerEvent, StreamProtocol, StreamUpgradeError,
ConnectionHandler, ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError,
SubstreamProtocol,
};
use std::collections::VecDeque;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{fmt, io};
use void::Void;

/// The maximum number of circuits being denied concurrently.
///
Expand Down Expand Up @@ -104,8 +106,7 @@ pub struct Handler {
>,
>,

/// We issue a stream upgrade for each pending request.
pending_requests: VecDeque<PendingRequest>,
pending_streams: VecDeque<oneshot::Sender<Result<Stream, StreamUpgradeError<Void>>>>,

inflight_reserve_requests: futures_bounded::FuturesTupleSet<
Result<outbound_hop::Reservation, outbound_hop::ReserveError>,
Expand Down Expand Up @@ -133,7 +134,7 @@ impl Handler {
remote_peer_id,
remote_addr,
queued_events: Default::default(),
pending_requests: Default::default(),
pending_streams: Default::default(),
inflight_reserve_requests: futures_bounded::FuturesTupleSet::new(
STREAM_TIMEOUT,
MAX_CONCURRENT_STREAMS_PER_CONNECTION,
Expand All @@ -154,57 +155,6 @@ impl Handler {
}
}

fn on_dial_upgrade_error(
&mut self,
DialUpgradeError { error, .. }: DialUpgradeError<
<Self as ConnectionHandler>::OutboundOpenInfo,
<Self as ConnectionHandler>::OutboundProtocol,
>,
) {
let pending_request = self
.pending_requests
.pop_front()
.expect("got a stream error without a pending request");

match pending_request {
PendingRequest::Reserve { mut to_listener } => {
let error = match error {
StreamUpgradeError::Timeout => {
outbound_hop::ReserveError::Io(io::ErrorKind::TimedOut.into())
}
StreamUpgradeError::Apply(never) => void::unreachable(never),
StreamUpgradeError::NegotiationFailed => {
outbound_hop::ReserveError::Unsupported
}
StreamUpgradeError::Io(e) => outbound_hop::ReserveError::Io(e),
};

if let Err(e) =
to_listener.try_send(transport::ToListenerMsg::Reservation(Err(error)))
{
tracing::debug!("Unable to send error to listener: {}", e.into_send_error())
}
self.reservation.failed();
}
PendingRequest::Connect {
to_dial: send_back, ..
} => {
let error = match error {
StreamUpgradeError::Timeout => {
outbound_hop::ConnectError::Io(io::ErrorKind::TimedOut.into())
}
StreamUpgradeError::NegotiationFailed => {
outbound_hop::ConnectError::Unsupported
}
StreamUpgradeError::Io(e) => outbound_hop::ConnectError::Io(e),
StreamUpgradeError::Apply(v) => void::unreachable(v),
};

let _ = send_back.send(Err(error));
}
}
}

fn insert_to_deny_futs(&mut self, circuit: inbound_stop::Circuit) {
let src_peer_id = circuit.src_peer_id();

Expand All @@ -219,6 +169,62 @@ impl Handler {
)
}
}

fn make_new_reservation(&mut self, to_listener: Sender<ToListenerMsg>) {
let (sender, receiver) = oneshot::channel();

self.pending_streams.push_back(sender);
self.queued_events
.push_back(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()),
});
let result = self.inflight_reserve_requests.try_push(
async move {
let stream = receiver
.await
.map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))?
.map_err(into_reserve_error)?;

let reservation = outbound_hop::make_reservation(stream).await?;

Ok(reservation)
},
to_listener,
);

if result.is_err() {
tracing::warn!("Dropping in-flight reservation request because we are at capacity");
}
}

fn establish_new_circuit(
&mut self,
to_dial: oneshot::Sender<Result<Connection, outbound_hop::ConnectError>>,
dst_peer_id: PeerId,
) {
let (sender, receiver) = oneshot::channel();

self.pending_streams.push_back(sender);
self.queued_events
.push_back(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()),
});
let result = self.inflight_outbound_connect_requests.try_push(
async move {
let stream = receiver
.await
.map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))?
.map_err(into_connect_error)?;

outbound_hop::open_circuit(stream, dst_peer_id).await
},
to_dial,
);

if result.is_err() {
tracing::warn!("Dropping in-flight connect request because we are at capacity")
}
}
}

impl ConnectionHandler for Handler {
Expand All @@ -236,25 +242,13 @@ impl ConnectionHandler for Handler {
fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
match event {
In::Reserve { to_listener } => {
self.pending_requests
.push_back(PendingRequest::Reserve { to_listener });
self.queued_events
.push_back(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()),
});
self.make_new_reservation(to_listener);
}
In::EstablishCircuit {
to_dial: send_back,
to_dial,
dst_peer_id,
} => {
self.pending_requests.push_back(PendingRequest::Connect {
dst_peer_id,
to_dial: send_back,
});
self.queued_events
.push_back(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()),
});
self.establish_new_circuit(to_dial, dst_peer_id);
}
}
}
Expand Down Expand Up @@ -402,12 +396,8 @@ impl ConnectionHandler for Handler {
}

if let Poll::Ready(Some(to_listener)) = self.reservation.poll(cx) {
self.pending_requests
.push_back(PendingRequest::Reserve { to_listener });

return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()),
});
self.make_new_reservation(to_listener);
continue;
}

// Deny incoming circuit requests.
Expand Down Expand Up @@ -450,42 +440,16 @@ impl ConnectionHandler for Handler {
tracing::warn!("Dropping inbound stream because we are at capacity")
}
}
ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
protocol: stream,
..
}) => {
let pending_request = self.pending_requests.pop_front().expect(
"opened a stream without a pending connection command or a reserve listener",
);
match pending_request {
PendingRequest::Reserve { to_listener } => {
if self
.inflight_reserve_requests
.try_push(outbound_hop::make_reservation(stream), to_listener)
.is_err()
{
tracing::warn!("Dropping outbound stream because we are at capacity");
}
}
PendingRequest::Connect {
dst_peer_id,
to_dial: send_back,
} => {
if self
.inflight_outbound_connect_requests
.try_push(outbound_hop::open_circuit(stream, dst_peer_id), send_back)
.is_err()
{
tracing::warn!("Dropping outbound stream because we are at capacity");
}
}
ConnectionEvent::FullyNegotiatedOutbound(ev) => {
if let Some(next) = self.pending_streams.pop_front() {
let _ = next.send(Ok(ev.protocol));
}
}
ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
void::unreachable(listen_upgrade_error.error)
}
ConnectionEvent::DialUpgradeError(dial_upgrade_error) => {
self.on_dial_upgrade_error(dial_upgrade_error)
ConnectionEvent::ListenUpgradeError(ev) => void::unreachable(ev.error),
ConnectionEvent::DialUpgradeError(ev) => {
if let Some(next) = self.pending_streams.pop_front() {
let _ = next.send(Err(ev.error));
}
}
_ => {}
}
Expand Down Expand Up @@ -614,14 +578,24 @@ impl Reservation {
}
}

pub(crate) enum PendingRequest {
Reserve {
/// A channel into the [`Transport`](priv_client::Transport).
to_listener: mpsc::Sender<transport::ToListenerMsg>,
},
Connect {
dst_peer_id: PeerId,
/// A channel into the future returned by [`Transport::dial`](libp2p_core::Transport::dial).
to_dial: oneshot::Sender<Result<priv_client::Connection, outbound_hop::ConnectError>>,
},
fn into_reserve_error(e: StreamUpgradeError<Void>) -> outbound_hop::ReserveError {
match e {
StreamUpgradeError::Timeout => {
outbound_hop::ReserveError::Io(io::ErrorKind::TimedOut.into())
}
StreamUpgradeError::Apply(never) => void::unreachable(never),
StreamUpgradeError::NegotiationFailed => outbound_hop::ReserveError::Unsupported,
StreamUpgradeError::Io(e) => outbound_hop::ReserveError::Io(e),
}
}

fn into_connect_error(e: StreamUpgradeError<Void>) -> outbound_hop::ConnectError {
match e {
StreamUpgradeError::Timeout => {
outbound_hop::ConnectError::Io(io::ErrorKind::TimedOut.into())
}
StreamUpgradeError::Apply(never) => void::unreachable(never),
StreamUpgradeError::NegotiationFailed => outbound_hop::ConnectError::Unsupported,
StreamUpgradeError::Io(e) => outbound_hop::ConnectError::Io(e),
}
}
4 changes: 2 additions & 2 deletions protocols/relay/src/protocol/outbound_hop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub enum ConnectError {
#[error("Remote does not support the `{HOP_PROTOCOL_NAME}` protocol")]
Unsupported,
#[error("IO error")]
Io(#[source] io::Error),
Io(#[from] io::Error),
#[error("Protocol error")]
Protocol(#[from] ProtocolViolation),
}
Expand All @@ -61,7 +61,7 @@ pub enum ReserveError {
#[error("Remote does not support the `{HOP_PROTOCOL_NAME}` protocol")]
Unsupported,
#[error("IO error")]
Io(#[source] io::Error),
Io(#[from] io::Error),
#[error("Protocol error")]
Protocol(#[from] ProtocolViolation),
}
Expand Down

0 comments on commit 4d7a535

Please sign in to comment.