From a1c76aa2769709924880e69e6738027dbab40743 Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 11 Jul 2024 22:52:41 +0530 Subject: [PATCH] basichost: reset new stream if rcmgr blocks (#2869) --- p2p/host/basic/basic_host.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 8fc808e6b6..6d6db91aef 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -640,7 +640,7 @@ func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) { // header with given protocol.ID. If there is no connection to p, attempts // to create one. If ProtocolID is "", writes no header. // (Thread-safe) -func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { +func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (str network.Stream, strErr error) { // If the caller wants to prevent the host from dialing, it should use the NoDial option. if nodial, _ := network.GetNoDial(ctx); !nodial { err := h.Connect(ctx, peer.AddrInfo{ID: p}) @@ -658,6 +658,11 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I } return nil, fmt.Errorf("failed to open stream: %w", err) } + defer func() { + if strErr != nil && s != nil { + s.Reset() + } + }() // Wait for any in-progress identifies on the connection to finish. This // is faster than negotiating. @@ -667,13 +672,11 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I select { case <-h.ids.IdentifyWait(s.Conn()): case <-ctx.Done(): - _ = s.Reset() return nil, fmt.Errorf("identify failed to complete: %w", ctx.Err()) } pref, err := h.preferredProtocol(p, pids) if err != nil { - _ = s.Reset() return nil, err } @@ -698,7 +701,6 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I select { case err = <-errCh: if err != nil { - s.Reset() return nil, fmt.Errorf("failed to negotiate protocol: %w", err) } case <-ctx.Done():