From 0fe12851b276d4cf2bcce75bce290bdd6e9c0b8d Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 1 Dec 2023 21:37:20 +0530 Subject: [PATCH] webrtc: wait for fin_ack for closing datachannel --- core/network/mux.go | 7 + p2p/net/swarm/swarm_stream.go | 6 + p2p/net/swarm/swarm_stream_test.go | 45 +++++ p2p/test/swarm/swarm_test.go | 83 +++++++++ p2p/test/transport/transport_test.go | 12 +- p2p/transport/webrtc/connection.go | 110 ++++++------ p2p/transport/webrtc/pb/message.pb.go | 29 ++-- p2p/transport/webrtc/pb/message.proto | 4 + p2p/transport/webrtc/stream.go | 227 +++++++++++++++++-------- p2p/transport/webrtc/stream_read.go | 38 ++--- p2p/transport/webrtc/stream_test.go | 178 +++++++++++++++++-- p2p/transport/webrtc/stream_write.go | 69 ++------ p2p/transport/webrtc/transport_test.go | 49 ++++++ 13 files changed, 625 insertions(+), 232 deletions(-) create mode 100644 p2p/net/swarm/swarm_stream_test.go diff --git a/core/network/mux.go b/core/network/mux.go index d12e2ea34b..fdda55365a 100644 --- a/core/network/mux.go +++ b/core/network/mux.go @@ -61,6 +61,13 @@ type MuxedStream interface { SetWriteDeadline(time.Time) error } +// AsyncCloser is implemented by streams that need to do expensive operations on close before +// releasing the resources. Closing the stream async avoids blocking the calling goroutine. +type AsyncCloser interface { + // AsyncClose closes the stream and executes onDone after the stream is closed + AsyncClose(onDone func()) error +} + // MuxedConn represents a connection to a remote peer that has been // extended to support stream multiplexing. // diff --git a/p2p/net/swarm/swarm_stream.go b/p2p/net/swarm/swarm_stream.go index b7846adec2..1339709db2 100644 --- a/p2p/net/swarm/swarm_stream.go +++ b/p2p/net/swarm/swarm_stream.go @@ -78,6 +78,12 @@ func (s *Stream) Write(p []byte) (int, error) { // Close closes the stream, closing both ends and freeing all associated // resources. 
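+// If the underlying muxed stream implements network.AsyncCloser, the potentially slow part
+// of the close (for WebRTC, waiting for the remote's FIN_ACK) runs off the calling
+// goroutine, and the swarm's stream bookkeeping happens in the onDone callback instead.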
func (s *Stream) Close() error { + if as, ok := s.stream.(network.AsyncCloser); ok { + err := as.AsyncClose(func() { + s.closeAndRemoveStream() + }) + return err + } err := s.stream.Close() s.closeAndRemoveStream() return err diff --git a/p2p/net/swarm/swarm_stream_test.go b/p2p/net/swarm/swarm_stream_test.go new file mode 100644 index 0000000000..653489fe8f --- /dev/null +++ b/p2p/net/swarm/swarm_stream_test.go @@ -0,0 +1,45 @@ +package swarm + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/stretchr/testify/require" +) + +type asyncStreamWrapper struct { + network.MuxedStream + beforeClose func() +} + +func (s *asyncStreamWrapper) AsyncClose(onDone func()) error { + s.beforeClose() + err := s.Close() + onDone() + return err +} + +func TestStreamAsyncCloser(t *testing.T) { + s1 := makeSwarm(t) + s2 := makeSwarm(t) + + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.TempAddrTTL) + s, err := s1.NewStream(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + ss, ok := s.(*Stream) + require.True(t, ok) + + var called atomic.Bool + as := &asyncStreamWrapper{ + MuxedStream: ss.stream, + beforeClose: func() { + called.Store(true) + }, + } + ss.stream = as + ss.Close() + require.True(t, called.Load()) +} diff --git a/p2p/test/swarm/swarm_test.go b/p2p/test/swarm/swarm_test.go index 9874431441..7f2e731f25 100644 --- a/p2p/test/swarm/swarm_test.go +++ b/p2p/test/swarm/swarm_test.go @@ -2,6 +2,7 @@ package swarm_test import ( "context" + "fmt" "io" "sync" "testing" @@ -14,6 +15,7 @@ import ( rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" + libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -243,3 +245,84 @@ func TestLimitStreamsWhenHangingHandlers(t *testing.T) { return false }, 5*time.Second, 100*time.Millisecond) } + +func TestLimitStreamsWhenHangingHandlersWebRTC(t *testing.T) { + var partial rcmgr.PartialLimitConfig + const streamLimit = 10 + partial.System.Streams = streamLimit + mgr, err := rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(partial.Build(rcmgr.InfiniteLimits))) + require.NoError(t, err) + + maddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/webrtc-direct") + require.NoError(t, err) + + receiver, err := libp2p.New( + libp2p.ResourceManager(mgr), + libp2p.ListenAddrs(maddr), + libp2p.Transport(libp2pwebrtc.New), + ) + require.NoError(t, err) + t.Cleanup(func() { receiver.Close() }) + + var wg sync.WaitGroup + wg.Add(1) + + const pid = "/test" + receiver.SetStreamHandler(pid, func(s network.Stream) { + defer s.Close() + s.Write([]byte{42}) + wg.Wait() + }) + + // Open streamLimit streams + success := 0 + // we make a lot of tries because identify and identify push take up a few streams + for i := 0; i < 1000 && success < streamLimit; i++ { + mgr, err = rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(rcmgr.InfiniteLimits)) + require.NoError(t, err) + + sender, err := libp2p.New(libp2p.ResourceManager(mgr), libp2p.Transport(libp2pwebrtc.New)) + require.NoError(t, err) + t.Cleanup(func() { sender.Close() }) + + sender.Peerstore().AddAddrs(receiver.ID(), receiver.Addrs(), 
peerstore.PermanentAddrTTL) + + s, err := sender.NewStream(context.Background(), receiver.ID(), pid) + if err != nil { + continue + } + + var b [1]byte + _, err = io.ReadFull(s, b[:]) + if err == nil { + success++ + } + sender.Close() + } + require.Equal(t, streamLimit, success) + // We have the maximum number of streams open. Next call should fail. + mgr, err = rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(rcmgr.InfiniteLimits)) + require.NoError(t, err) + + sender, err := libp2p.New(libp2p.ResourceManager(mgr), libp2p.Transport(libp2pwebrtc.New)) + require.NoError(t, err) + t.Cleanup(func() { sender.Close() }) + + sender.Peerstore().AddAddrs(receiver.ID(), receiver.Addrs(), peerstore.PermanentAddrTTL) + + _, err = sender.NewStream(context.Background(), receiver.ID(), pid) + require.Error(t, err) + // Close the open streams + wg.Done() + + // Next call should succeed + require.Eventually(t, func() bool { + s, err := sender.NewStream(context.Background(), receiver.ID(), pid) + if err == nil { + s.Close() + return true + } + fmt.Println(err) + return false + }, 5*time.Second, 1*time.Second) +} diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index a7e98a0d85..b078cf7b04 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -382,9 +382,6 @@ func TestMoreStreamsThanOurLimits(t *testing.T) { const streamCount = 1024 for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { - if strings.Contains(tc.Name, "WebRTC") { - t.Skip("This test potentially exhausts the uint16 WebRTC stream ID space.") - } listenerLimits := rcmgr.PartialLimitConfig{ PeerDefault: rcmgr.ResourceLimits{ Streams: 32, @@ -428,7 +425,9 @@ func TestMoreStreamsThanOurLimits(t *testing.T) { workerCount := 4 var startWorker func(workerIdx int) + var wCount atomic.Int32 startWorker = func(workerIdx int) { + fmt.Println("worker count", wCount.Add(1)) wg.Add(1) defer wg.Done() for { @@ -440,7 +439,10 @@ func TestMoreStreamsThanOurLimits(t *testing.T) { // Inline function so we can use defer func() { var didErr bool - defer completedStreams.Add(1) + defer func() { + x := completedStreams.Add(1) + fmt.Println("completed streams", x) + }() defer func() { // Only the first worker adds more workers if workerIdx == 0 && !didErr && !sawFirstErr.Load() { @@ -483,7 +485,6 @@ func TestMoreStreamsThanOurLimits(t *testing.T) { return } err = func(s network.Stream) error { - defer s.Close() err = s.SetDeadline(time.Now().Add(100 * time.Millisecond)) if err != nil { return err @@ -511,6 +512,7 @@ func TestMoreStreamsThanOurLimits(t *testing.T) { return nil }(s) if err != nil && shouldRetry(err) { + fmt.Println("failed to write stream!", err) time.Sleep(50 * time.Millisecond) continue } diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index fd31f8351a..3241ce46cd 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -4,10 +4,8 @@ import ( "context" "errors" "fmt" - "math" "net" "sync" - "sync/atomic" ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" @@ -25,9 +23,7 @@ import ( var _ tpt.CapableConn = &connection{} -const maxAcceptQueueLen = 10 - -const maxDataChannelID = 1 << 10 +const maxAcceptQueueLen = 256 type errConnectionTimeout struct{} @@ -47,7 +43,8 @@ type connection struct { transport *WebRTCTransport scope network.ConnManagementScope - closeErr error + closeOnce sync.Once + closeErr error localPeer peer.ID 
localMultiaddr ma.Multiaddr @@ -56,9 +53,8 @@ type connection struct { remoteKey ic.PubKey remoteMultiaddr ma.Multiaddr - m sync.Mutex - streams map[uint16]*stream - nextStreamID atomic.Int32 + m sync.Mutex + streams map[uint16]*stream acceptQueue chan dataChannel @@ -97,25 +93,12 @@ func newConnection( acceptQueue: make(chan dataChannel, maxAcceptQueueLen), } - switch direction { - case network.DirInbound: - c.nextStreamID.Store(1) - case network.DirOutbound: - // stream ID 0 is used for the Noise handshake stream - c.nextStreamID.Store(2) - } pc.OnConnectionStateChange(c.onConnectionStateChange) pc.OnDataChannel(func(dc *webrtc.DataChannel) { if c.IsClosed() { return } - // Limit the number of streams, since we're not able to actually properly close them. - // See https://github.com/libp2p/specs/issues/575 for details. - if *dc.ID() > maxDataChannelID { - c.Close() - return - } dc.OnOpen(func() { rwc, err := dc.Detach() if err != nil { @@ -133,7 +116,6 @@ func newConnection( } }) }) - return c, nil } @@ -144,16 +126,41 @@ func (c *connection) ConnState() network.ConnectionState { // Close closes the underlying peerconnection. func (c *connection) Close() error { - if c.IsClosed() { - return nil - } + c.closeOnce.Do(func() { + c.closeErr = errors.New("connection closed") + // cancel must be called after closeErr is set. This ensures interested goroutines waiting on + // ctx.Done can read closeErr without holding the conn lock. + c.cancel() + c.m.Lock() + streams := c.streams + c.streams = nil + c.m.Unlock() + for _, str := range streams { + str.Reset() + } + c.pc.Close() + c.scope.Done() + }) + return nil +} - c.m.Lock() - defer c.m.Unlock() - c.scope.Done() - c.closeErr = errors.New("connection closed") - c.cancel() - return c.pc.Close() +func (c *connection) closeTimedOut() error { + c.closeOnce.Do(func() { + c.closeErr = errConnectionTimeout{} + // cancel must be called after closeErr is set. This ensures interested goroutines waiting on + // ctx.Done can read closeErr without holding the conn lock. + c.cancel() + c.m.Lock() + streams := c.streams + c.streams = nil + c.m.Unlock() + for _, str := range streams { + str.closeWithError(errConnectionTimeout{}) + } + c.pc.Close() + c.scope.Done() + }) + return nil } func (c *connection) IsClosed() bool { @@ -170,19 +177,7 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error return nil, c.closeErr } - id := c.nextStreamID.Add(2) - 2 - if id > math.MaxUint16 { - return nil, errors.New("exhausted stream ID space") - } - // Limit the number of streams, since we're not able to actually properly close them. - // See https://github.com/libp2p/specs/issues/575 for details. 
- if id > maxDataChannelID { - c.Close() - return c.OpenStream(ctx) - } - - streamID := uint16(id) - dc, err := c.pc.CreateDataChannel("", &webrtc.DataChannelInit{ID: &streamID}) + dc, err := c.pc.CreateDataChannel("", nil) if err != nil { return nil, err } @@ -190,9 +185,10 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error if err != nil { return nil, fmt.Errorf("open stream: %w", err) } - str := newStream(dc, rwc, func() { c.removeStream(streamID) }) + fmt.Println("opened dc with ID: ", *dc.ID()) + str := newStream(dc, rwc, func() { c.removeStream(*dc.ID()) }) if err := c.addStream(str); err != nil { - str.Close() + str.Reset() return nil, err } return str, nil @@ -205,7 +201,7 @@ func (c *connection) AcceptStream() (network.MuxedStream, error) { case dc := <-c.acceptQueue: str := newStream(dc.channel, dc.stream, func() { c.removeStream(*dc.channel.ID()) }) if err := c.addStream(str); err != nil { - str.Close() + str.Reset() return nil, err } return str, nil @@ -223,6 +219,9 @@ func (c *connection) Transport() tpt.Transport { return c.transport } func (c *connection) addStream(str *stream) error { c.m.Lock() defer c.m.Unlock() + if c.IsClosed() { + return fmt.Errorf("connection closed: %w", c.closeErr) + } if _, ok := c.streams[str.id]; ok { return errors.New("stream ID already exists") } @@ -238,20 +237,7 @@ func (c *connection) removeStream(id uint16) { func (c *connection) onConnectionStateChange(state webrtc.PeerConnectionState) { if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { - // reset any streams - if c.IsClosed() { - return - } - c.m.Lock() - defer c.m.Unlock() - c.closeErr = errConnectionTimeout{} - for k, str := range c.streams { - str.setCloseError(c.closeErr) - delete(c.streams, k) - } - c.cancel() - c.scope.Done() - c.pc.Close() + c.closeTimedOut() } } diff --git a/p2p/transport/webrtc/pb/message.pb.go b/p2p/transport/webrtc/pb/message.pb.go index fffc025f7f..384bddd289 100644 --- a/p2p/transport/webrtc/pb/message.pb.go +++ b/p2p/transport/webrtc/pb/message.pb.go @@ -31,6 +31,10 @@ const ( // The sender abruptly terminates the sending part of the stream. The // receiver can discard any data that it already received on that stream. Message_RESET Message_Flag = 2 + // Sending the FIN_ACK flag acknowledges the previous receipt of a message + // with the FIN flag set. Receiving a FIN_ACK flag gives the recipient + // confidence that the remote has received all sent messages. + Message_FIN_ACK Message_Flag = 3 ) // Enum value maps for Message_Flag. 
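For orientation, the close handshake that the new FIN_ACK flag enables boils down to the exchange sketched below. This is illustrative only, not part of the patch: closeHandshake, w and r are placeholder names, and pbio is assumed to be go-msgio's pbio package, which is what stream.go uses for its delimited reader and writer. The real implementation bounds the wait with a read deadline (controlMessageReaderEndTime).

package libp2pwebrtc

import (
	"github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb"
	"github.com/libp2p/go-msgio/pbio"
)

// closeHandshake: the closing side writes FIN and then keeps reading control messages
// until the peer's FIN_ACK arrives, confirming delivery of everything sent before the FIN.
func closeHandshake(w pbio.Writer, r pbio.Reader) error {
	if err := w.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil {
		return err
	}
	for {
		var msg pb.Message
		if err := r.ReadMsg(&msg); err != nil {
			return err // read deadline exceeded or datachannel closed: stop waiting
		}
		if msg.Flag != nil && *msg.Flag == pb.Message_FIN_ACK {
			return nil // the remote has received all data we sent
		}
		// other flags (STOP_SENDING, RESET) are handled by the stream state machine
	}
}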
@@ -39,11 +43,13 @@ var ( 0: "FIN", 1: "STOP_SENDING", 2: "RESET", + 3: "FIN_ACK", } Message_Flag_value = map[string]int32{ "FIN": 0, "STOP_SENDING": 1, "RESET": 2, + "FIN_ACK": 3, } ) @@ -143,17 +149,18 @@ var File_message_proto protoreflect.FileDescriptor var file_message_proto_rawDesc = []byte{ 0x0a, 0x0d, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, - 0x74, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x21, 0x0a, 0x04, 0x66, 0x6c, - 0x61, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0d, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x12, 0x18, 0x0a, - 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, - 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x2c, 0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, 0x12, - 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, 0x50, - 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, - 0x53, 0x45, 0x54, 0x10, 0x02, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, - 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x67, 0x6f, 0x2d, 0x6c, 0x69, - 0x62, 0x70, 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, - 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62, + 0x81, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x21, 0x0a, 0x04, 0x66, + 0x6c, 0x61, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0d, 0x2e, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x12, 0x18, + 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x39, 0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, + 0x12, 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, + 0x50, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, + 0x45, 0x53, 0x45, 0x54, 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07, 0x46, 0x49, 0x4e, 0x5f, 0x41, 0x43, + 0x4b, 0x10, 0x03, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x67, 0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, + 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, + 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62, } var ( diff --git a/p2p/transport/webrtc/pb/message.proto b/p2p/transport/webrtc/pb/message.proto index d6b1957beb..aab885b0da 100644 --- a/p2p/transport/webrtc/pb/message.proto +++ b/p2p/transport/webrtc/pb/message.proto @@ -12,6 +12,10 @@ message Message { // The sender abruptly terminates the sending part of the stream. The // receiver can discard any data that it already received on that stream. RESET = 2; + // Sending the FIN_ACK flag acknowledges the previous receipt of a message + // with the FIN flag set. Receiving a FIN_ACK flag gives the recipient + // confidence that the remote has received all sent messages. 
+ FIN_ACK = 3; } optional Flag flag=1; diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go index 0358dce56c..db7109e4bf 100644 --- a/p2p/transport/webrtc/stream.go +++ b/p2p/transport/webrtc/stream.go @@ -1,6 +1,8 @@ package libp2pwebrtc import ( + "errors" + "os" "sync" "time" @@ -52,6 +54,7 @@ type sendState uint8 const ( sendStateSending sendState = iota sendStateDataSent + sendStateDataReceived sendStateReset ) @@ -59,27 +62,38 @@ const ( // and then a network.MuxedStream type stream struct { mx sync.Mutex - // pbio.Reader is not thread safe, - // and while our Read is not promised to be thread safe, - // we ourselves internally read from multiple routines... - reader pbio.Reader + + // readerOnce ensures that only a single goroutine reads from the reader. Read is not threadsafe + // But we may need to read from reader for control messages from a different goroutine. + readerOnce chan struct{} + reader pbio.Reader + // this buffer is limited up to a single message. Reason we need it // is because a reader might read a message midway, and so we need a // wait to buffer that for as long as the remaining part is not (yet) read nextMessage *pb.Message receiveState receiveState - // The public Write API is not promised to be thread safe, - // but we need to be able to write control messages. + // writerMx ensures that only a single goroutine is calling WriteMsg on writer. writer is a + // pbio.uvarintWriter which is not thread safe. The public Write API is not promised to be + // thread safe, but we need to be able to write control messages concurrently + writerMx sync.Mutex writer pbio.Writer sendStateChanged chan struct{} sendState sendState - controlMsgQueue []*pb.Message writeDeadline time.Time writeDeadlineUpdated chan struct{} writeAvailable chan struct{} - readLoopOnce sync.Once + controlMessageReaderOnce sync.Once + // controlMessageReaderEndTime is the end time for reading FIN_ACK from the control + // message reader. We cannot rely on SetReadDeadline to do this since that is prone to + // race condition where a previous deadline timer fires after the latest call to + // SetReadDeadline + // See: https://github.com/pion/sctp/pull/290 + controlMessageReaderEndTime time.Time + controlMessageReaderStarted chan struct{} + controlMessageReaderDone chan struct{} onDone func() id uint16 // for logging purposes @@ -95,13 +109,17 @@ func newStream( onDone func(), ) *stream { s := &stream{ - reader: pbio.NewDelimitedReader(rwc, maxMessageSize), - writer: pbio.NewDelimitedWriter(rwc), + readerOnce: make(chan struct{}, 1), + reader: pbio.NewDelimitedReader(rwc, maxMessageSize), + writer: pbio.NewDelimitedWriter(rwc), sendStateChanged: make(chan struct{}, 1), writeDeadlineUpdated: make(chan struct{}, 1), writeAvailable: make(chan struct{}, 1), + controlMessageReaderStarted: make(chan struct{}), + controlMessageReaderDone: make(chan struct{}), + id: *channel.ID(), dataChannel: rwc.(*datachannel.DataChannel), onDone: onDone, @@ -111,35 +129,6 @@ func newStream( channel.OnBufferedAmountLow(func() { s.mx.Lock() defer s.mx.Unlock() - // first send out queued control messages - for len(s.controlMsgQueue) > 0 { - msg := s.controlMsgQueue[0] - available := s.availableSendSpace() - if controlMsgSize < available { - s.writer.WriteMsg(msg) // TODO: handle error - s.controlMsgQueue = s.controlMsgQueue[1:] - } else { - return - } - } - - if s.isDone() { - // onDone removes the stream from the connection and requires the connection lock. 
- // This callback(onBufferedAmountLow) is executing in the sctp readLoop goroutine. - // If Connection.Close is called concurrently, the closing goroutine will acquire - // the connection lock and wait for sctp readLoop to exit, the sctp readLoop will - // wait for the connection lock before exiting, causing a deadlock. - // Run this in a different goroutine to avoid the deadlock. - go func() { - s.mx.Lock() - defer s.mx.Unlock() - // TODO: we should be closing the underlying datachannel, but this resets the stream - // See https://github.com/libp2p/specs/issues/575 for details. - // _ = s.dataChannel.Close() - // TODO: write for the spawned reader to return - s.onDone() - }() - } select { case s.writeAvailable <- struct{}{}: @@ -150,15 +139,50 @@ func newStream( } func (s *stream) Close() error { + defer s.cleanup() + closeWriteErr := s.CloseWrite() closeReadErr := s.CloseRead() - if closeWriteErr != nil { - return closeWriteErr + if closeWriteErr != nil || closeReadErr != nil { + s.Reset() + return errors.Join(closeWriteErr, closeReadErr) } - return closeReadErr + + s.mx.Lock() + s.controlMessageReaderEndTime = time.Now().Add(10 * time.Second) + s.mx.Unlock() + s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) + <-s.controlMessageReaderDone + return nil +} + +func (s *stream) AsyncClose(onDone func()) error { + closeWriteErr := s.CloseWrite() + closeReadErr := s.CloseRead() + if closeWriteErr != nil || closeReadErr != nil { + s.Reset() + if onDone != nil { + onDone() + } + return errors.Join(closeWriteErr, closeReadErr) + } + s.mx.Lock() + s.controlMessageReaderEndTime = time.Now().Add(10 * time.Second) + s.mx.Unlock() + s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) + go func() { + <-s.controlMessageReaderDone + s.cleanup() + if onDone != nil { + onDone() + } + }() + return nil } func (s *stream) Reset() error { + defer s.cleanup() + cancelWriteErr := s.cancelWrite() closeReadErr := s.CloseRead() if cancelWriteErr != nil { @@ -167,14 +191,20 @@ func (s *stream) Reset() error { return closeReadErr } +func (s *stream) closeWithError(e error) { + defer s.cleanup() + + s.mx.Lock() + defer s.mx.Unlock() + s.closeErr = e +} + func (s *stream) SetDeadline(t time.Time) error { _ = s.SetReadDeadline(t) return s.SetWriteDeadline(t) } // processIncomingFlag process the flag on an incoming message -// It needs to be called with msg.Flag, not msg.GetFlag(), -// otherwise we'd misinterpret the default value. // It needs to be called while the mutex is locked. func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { if flag == nil { @@ -182,50 +212,101 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { } switch *flag { - case pb.Message_FIN: - if s.receiveState == receiveStateReceiving { - s.receiveState = receiveStateDataRead - } case pb.Message_STOP_SENDING: - if s.sendState == sendStateSending { + // We must process STOP_SENDING after sending a FIN(sendStateDataSent). 
Remote peer + // may not send a FIN_ACK once it has sent a STOP_SENDING + if s.sendState == sendStateSending || s.sendState == sendStateDataSent { s.sendState = sendStateReset } select { case s.sendStateChanged <- struct{}{}: default: } + case pb.Message_FIN_ACK: + s.sendState = sendStateDataReceived + select { + case s.sendStateChanged <- struct{}{}: + default: + } + case pb.Message_FIN: + if s.receiveState == receiveStateReceiving { + s.receiveState = receiveStateDataRead + } + if err := s.writeMsgOnWriter(&pb.Message{Flag: pb.Message_FIN_ACK.Enum()}); err != nil { + log.Debugf("failed to send FIN_ACK: %s", err) + // Remote has finished writing all the data It'll stop waiting for the + // FIN_ACK eventually or will be notified when we close the datachannel + } + s.controlMessageReaderOnce.Do(s.spawnControlMessageReader) case pb.Message_RESET: if s.receiveState == receiveStateReceiving { s.receiveState = receiveStateReset } + s.controlMessageReaderOnce.Do(s.spawnControlMessageReader) } - s.maybeDeclareStreamDone() } -// maybeDeclareStreamDone is used to force reset a stream. It should be called with -// the stream lock acquired. It calls stream.onDone which requires the connection lock. -func (s *stream) maybeDeclareStreamDone() { - if s.isDone() { - _ = s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) // pion ignores zero times - // TODO: we should be closing the underlying datachannel, but this resets the stream - // See https://github.com/libp2p/specs/issues/575 for details. - // _ = s.dataChannel.Close() - // TODO: write for the spawned reader to return - s.onDone() - } -} +// spawnControlMessageReader is used for processing control messages after the reader is closed. +func (s *stream) spawnControlMessageReader() { -// isDone indicates whether the stream is completed and all the control messages have also been -// flushed. It must be called with the stream lock acquired. -func (s *stream) isDone() bool { - return (s.sendState == sendStateReset || s.sendState == sendStateDataSent) && - (s.receiveState == receiveStateReset || s.receiveState == receiveStateDataRead) && - len(s.controlMsgQueue) == 0 -} + // Spawn a goroutine to ensure that we're not holding any locks + go func() { + defer close(s.controlMessageReaderDone) + // cleanup the sctp deadline timer goroutine + defer s.SetReadDeadline(time.Time{}) -func (s *stream) setCloseError(e error) { - s.mx.Lock() - defer s.mx.Unlock() + isSendCompleted := func() bool { + s.mx.Lock() + defer s.mx.Unlock() + return s.sendState == sendStateDataReceived || s.sendState == sendStateReset + } - s.closeErr = e + setDeadline := func() bool { + s.mx.Lock() + if s.controlMessageReaderEndTime.IsZero() || time.Now().Before(s.controlMessageReaderEndTime) { + s.SetReadDeadline(s.controlMessageReaderEndTime) + s.mx.Unlock() + return true + } + s.mx.Unlock() + return false + } + + // Unblock any Read call waiting on reader.ReadMsg + s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) + // We have the lock, any waiting reader has exited. 
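+		// readerOnce is a 1-slot channel acting as a mutex around the pbio reader: Read
+		// acquires it by sending and releases it by receiving in a defer. The send below
+		// therefore completes only once the Read we just unblocked via the past deadline
+		// has returned; later Reads bail out early (EOF, reset or close error) and never
+		// reach reader.ReadMsg.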
+ s.readerOnce <- struct{}{} + <-s.readerOnce + // From this point onwards only this goroutine can do reader.ReadMsg + + s.mx.Lock() + if s.nextMessage != nil { + s.processIncomingFlag(s.nextMessage.Flag) + s.nextMessage = nil + } + s.mx.Unlock() + + for !isSendCompleted() { + var msg pb.Message + if !setDeadline() { + return + } + if err := s.reader.ReadMsg(&msg); err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + continue + } + return + } + s.mx.Lock() + s.processIncomingFlag(msg.Flag) + s.mx.Unlock() + } + }() +} + +func (s *stream) cleanup() { + s.dataChannel.Close() + if s.onDone != nil { + s.onDone() + } } diff --git a/p2p/transport/webrtc/stream_read.go b/p2p/transport/webrtc/stream_read.go index e064c8558b..209af0a634 100644 --- a/p2p/transport/webrtc/stream_read.go +++ b/p2p/transport/webrtc/stream_read.go @@ -1,7 +1,6 @@ package libp2pwebrtc import ( - "errors" "io" "time" @@ -10,9 +9,8 @@ import ( ) func (s *stream) Read(b []byte) (int, error) { - if len(b) == 0 { - return 0, nil - } + s.readerOnce <- struct{}{} + defer func() { <-s.readerOnce }() s.mx.Lock() defer s.mx.Unlock() @@ -27,6 +25,10 @@ func (s *stream) Read(b []byte) (int, error) { return 0, network.ErrReset } + if len(b) == 0 { + return 0, nil + } + var read int for { if s.nextMessage == nil { @@ -40,13 +42,19 @@ func (s *stream) Read(b []byte) (int, error) { if s.receiveState == receiveStateDataRead { return 0, io.EOF } - // This case occurs when the remote node closes the stream without writing a FIN message - // There's little we can do here - return 0, errors.New("didn't receive final state for stream") + // This case occurs when remote closes the datachannel without writing a FIN + // message. Some implementations discard the buffered data on closing the + // datachannel. For these implementations a stream reset will be observed as an + // abrupt closing of the datachannel. 
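+				// Surface the abrupt close as a reset: latch the receive state so this and
+				// any later Read report network.ErrReset.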
+ s.receiveState = receiveStateReset + return 0, network.ErrReset } if s.receiveState == receiveStateReset { return 0, network.ErrReset } + if s.receiveState == receiveStateDataRead { + return 0, io.EOF + } return 0, err } s.mx.Lock() @@ -70,7 +78,6 @@ func (s *stream) Read(b []byte) (int, error) { case receiveStateDataRead: return read, io.EOF case receiveStateReset: - s.dataChannel.SetReadDeadline(time.Time{}) return read, network.ErrReset } } @@ -81,20 +88,11 @@ func (s *stream) SetReadDeadline(t time.Time) error { return s.dataChannel.SetRe func (s *stream) CloseRead() error { s.mx.Lock() defer s.mx.Unlock() - - if s.nextMessage != nil { - s.processIncomingFlag(s.nextMessage.Flag) - s.nextMessage = nil - } var err error if s.receiveState == receiveStateReceiving && s.closeErr == nil { - err = s.sendControlMessage(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()}) + err = s.writeMsgOnWriter(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()}) + s.receiveState = receiveStateReset } - s.receiveState = receiveStateReset - s.maybeDeclareStreamDone() - - // make any calls to Read blocking on ReadMsg return immediately - s.dataChannel.SetReadDeadline(time.Now()) - + s.controlMessageReaderOnce.Do(s.spawnControlMessageReader) return err } diff --git a/p2p/transport/webrtc/stream_test.go b/p2p/transport/webrtc/stream_test.go index f1442b9bfd..851171c2cc 100644 --- a/p2p/transport/webrtc/stream_test.go +++ b/p2p/transport/webrtc/stream_test.go @@ -5,6 +5,7 @@ import ( "errors" "io" "os" + "sync/atomic" "testing" "time" @@ -14,6 +15,7 @@ import ( "github.com/pion/datachannel" "github.com/pion/webrtc/v3" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -24,6 +26,7 @@ type detachedChan struct { func getDetachedDataChannels(t *testing.T) (detachedChan, detachedChan) { s := webrtc.SettingEngine{} + s.SetIncludeLoopbackCandidate(true) s.DetachDataChannels() api := webrtc.NewAPI(webrtc.WithSettingEngine(s)) @@ -97,9 +100,9 @@ func getDetachedDataChannels(t *testing.T) (detachedChan, detachedChan) { func TestStreamSimpleReadWriteClose(t *testing.T) { client, server := getDetachedDataChannels(t) - var clientDone, serverDone bool - clientStr := newStream(client.dc, client.rwc, func() { clientDone = true }) - serverStr := newStream(server.dc, server.rwc, func() { serverDone = true }) + var clientDone, serverDone atomic.Bool + clientStr := newStream(client.dc, client.rwc, func() { clientDone.Store(true) }) + serverStr := newStream(server.dc, server.rwc, func() { serverDone.Store(true) }) // send a foobar from the client n, err := clientStr.Write([]byte("foobar")) @@ -109,7 +112,7 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { // writing after closing should error _, err = clientStr.Write([]byte("foobar")) require.Error(t, err) - require.False(t, clientDone) + require.False(t, clientDone.Load()) // now read all the data on the server side b, err := io.ReadAll(serverStr) @@ -119,19 +122,26 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { n, err = serverStr.Read(make([]byte, 10)) require.Zero(t, n) require.ErrorIs(t, err, io.EOF) - require.False(t, serverDone) + require.False(t, serverDone.Load()) // send something back _, err = serverStr.Write([]byte("lorem ipsum")) require.NoError(t, err) require.NoError(t, serverStr.CloseWrite()) - require.True(t, serverDone) + // and read it at the client - require.False(t, clientDone) + require.False(t, clientDone.Load()) b, err = io.ReadAll(clientStr) require.NoError(t, err) require.Equal(t, 
[]byte("lorem ipsum"), b) - require.True(t, clientDone) + + // stream is only cleaned up on calling Close or AsyncClose or Reset + clientStr.AsyncClose(nil) + serverStr.AsyncClose(nil) + require.Eventually(t, func() bool { return clientDone.Load() }, 10*time.Second, 100*time.Millisecond) + // Need to call Close for cleanup. Otherwise the FIN_ACK is never read + require.NoError(t, serverStr.Close()) + require.Eventually(t, func() bool { return serverDone.Load() }, 10*time.Second, 100*time.Millisecond) } func TestStreamPartialReads(t *testing.T) { @@ -201,14 +211,17 @@ func TestStreamReadReturnsOnClose(t *testing.T) { _, err := clientStr.Read([]byte{0}) errChan <- err }() - time.Sleep(50 * time.Millisecond) // give the Read call some time to hit the loop - require.NoError(t, clientStr.Close()) + time.Sleep(100 * time.Millisecond) // give the Read call some time to hit the loop + require.NoError(t, clientStr.AsyncClose(nil)) select { case err := <-errChan: require.ErrorIs(t, err, network.ErrReset) case <-time.After(500 * time.Millisecond): t.Fatal("timeout") } + + _, err := clientStr.Read([]byte{0}) + require.ErrorIs(t, err, network.ErrReset) } func TestStreamResets(t *testing.T) { @@ -242,6 +255,7 @@ func TestStreamResets(t *testing.T) { _, err := serverStr.Write([]byte("foobar")) return errors.Is(err, network.ErrReset) }, time.Second, 50*time.Millisecond) + serverStr.Close() require.True(t, serverDone) } @@ -305,3 +319,147 @@ func TestStreamWriteDeadlineAsync(t *testing.T) { require.GreaterOrEqual(t, took, timeout) require.LessOrEqual(t, took, timeout*3/2) } + +func TestStreamReadAfterClose(t *testing.T) { + client, server := getDetachedDataChannels(t) + + clientStr := newStream(client.dc, client.rwc, func() {}) + serverStr := newStream(server.dc, server.rwc, func() {}) + + serverStr.AsyncClose(nil) + b := make([]byte, 1) + _, err := clientStr.Read(b) + require.Equal(t, io.EOF, err) + _, err = clientStr.Read(nil) + require.Equal(t, io.EOF, err) + + client, server = getDetachedDataChannels(t) + + clientStr = newStream(client.dc, client.rwc, func() {}) + serverStr = newStream(server.dc, server.rwc, func() {}) + + serverStr.Reset() + b = make([]byte, 1) + _, err = clientStr.Read(b) + require.ErrorIs(t, err, network.ErrReset) + _, err = clientStr.Read(nil) + require.ErrorIs(t, err, network.ErrReset) +} + +func TestStreamCloseAfterFINACK(t *testing.T) { + client, server := getDetachedDataChannels(t) + + done := make(chan bool, 1) + clientStr := newStream(client.dc, client.rwc, func() { done <- true }) + serverStr := newStream(server.dc, server.rwc, func() {}) + + go func() { + done <- true + err := clientStr.Close() + assert.NoError(t, err) + }() + <-done + + select { + case <-done: + t.Fatalf("Close should not have completed without processing FIN_ACK") + case <-time.After(2 * time.Second): + } + + b := make([]byte, 1) + _, err := serverStr.Read(b) + require.Error(t, err) + require.ErrorIs(t, err, io.EOF) + select { + case <-done: + case <-time.After(3 * time.Second): + t.Errorf("Close should have completed") + } +} + +func TestStreamFinAckAfterStopSending(t *testing.T) { + client, server := getDetachedDataChannels(t) + + done := make(chan bool, 1) + clientStr := newStream(client.dc, client.rwc, func() { done <- true }) + serverStr := newStream(server.dc, server.rwc, func() {}) + + go func() { + clientStr.CloseRead() + clientStr.Write([]byte("hello world")) + done <- true + err := clientStr.Close() + assert.NoError(t, err) + }() + <-done + + select { + case <-done: + t.Errorf("Close should not 
have completed without processing FIN_ACK") + case <-time.After(500 * time.Millisecond): + } + + // serverStr has write half of the stream closed but the read half should + // respond correctly + b := make([]byte, 24) + _, err := serverStr.Read(b) + require.NoError(t, err) + serverStr.Close() // Sends stop_sending, fin + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatalf("Close should have completed") + } +} + +func TestStreamConcurrentClose(t *testing.T) { + client, server := getDetachedDataChannels(t) + + start := make(chan bool, 1) + done := make(chan bool, 2) + clientStr := newStream(client.dc, client.rwc, func() { done <- true }) + serverStr := newStream(server.dc, server.rwc, func() { done <- true }) + + go func() { + start <- true + clientStr.Close() + }() + go func() { + start <- true + serverStr.Close() + }() + <-start + <-start + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("concurrent close should succeed quickly") + } + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("concurrent close should succeed quickly") + } +} + +func TestStreamResetAfterAsyncClose(t *testing.T) { + client, _ := getDetachedDataChannels(t) + + done := make(chan bool, 1) + clientStr := newStream(client.dc, client.rwc, func() { done <- true }) + clientStr.AsyncClose(nil) + + select { + case <-done: + t.Fatalf("AsyncClose shouldn't run cleanup immediately") + case <-time.After(500 * time.Millisecond): + } + + clientStr.Reset() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("Reset should run callback immediately") + } +} diff --git a/p2p/transport/webrtc/stream_write.go b/p2p/transport/webrtc/stream_write.go index 698af9c4d6..7a99957288 100644 --- a/p2p/transport/webrtc/stream_write.go +++ b/p2p/transport/webrtc/stream_write.go @@ -7,6 +7,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb" + "google.golang.org/protobuf/proto" ) var errWriteAfterClose = errors.New("write after close") @@ -25,16 +26,10 @@ func (s *stream) Write(b []byte) (int, error) { switch s.sendState { case sendStateReset: return 0, network.ErrReset - case sendStateDataSent: + case sendStateDataSent, sendStateDataReceived: return 0, errWriteAfterClose } - // Check if there is any message on the wire. This is used for control - // messages only when the read side of the stream is closed - if s.receiveState != receiveStateReceiving { - s.readLoopOnce.Do(s.spawnControlMessageReader) - } - if !s.writeDeadline.IsZero() && time.Now().After(s.writeDeadline) { return 0, os.ErrDeadlineExceeded } @@ -54,7 +49,7 @@ func (s *stream) Write(b []byte) (int, error) { switch s.sendState { case sendStateReset: return n, network.ErrReset - case sendStateDataSent: + case sendStateDataSent, sendStateDataReceived: return n, errWriteAfterClose } @@ -100,7 +95,7 @@ func (s *stream) Write(b []byte) (int, error) { end = len(b) } msg := &pb.Message{Message: b[:end]} - if err := s.writer.WriteMsg(msg); err != nil { + if err := s.writeMsgOnWriter(msg); err != nil { return n, err } n += end @@ -109,30 +104,6 @@ func (s *stream) Write(b []byte) (int, error) { return n, nil } -// used for reading control messages while writing, in case the reader is closed, -// as to ensure we do still get control messages. This is important as according to the spec -// our data and control channels are intermixed on the same conn. 
-func (s *stream) spawnControlMessageReader() { - if s.nextMessage != nil { - s.processIncomingFlag(s.nextMessage.Flag) - s.nextMessage = nil - } - - go func() { - // no deadline needed, Read will return once there's a new message, or an error occurred - _ = s.dataChannel.SetReadDeadline(time.Time{}) - for { - var msg pb.Message - if err := s.reader.ReadMsg(&msg); err != nil { - return - } - s.mx.Lock() - s.processIncomingFlag(msg.Flag) - s.mx.Unlock() - } - }() -} - func (s *stream) SetWriteDeadline(t time.Time) error { s.mx.Lock() defer s.mx.Unlock() @@ -153,24 +124,12 @@ func (s *stream) availableSendSpace() int { return availableSpace } -// There's no way to determine the size of a Protobuf message in the pbio package. -// Setting the size to 100 works as long as the control messages (incl. the varint prefix) are smaller than that value. -const controlMsgSize = 100 - -func (s *stream) sendControlMessage(msg *pb.Message) error { - available := s.availableSendSpace() - if controlMsgSize < available { - return s.writer.WriteMsg(msg) - } - s.controlMsgQueue = append(s.controlMsgQueue, msg) - return nil -} - func (s *stream) cancelWrite() error { s.mx.Lock() defer s.mx.Unlock() - if s.sendState != sendStateSending { + // Don't wait for FIN_ACK on reset + if s.sendState != sendStateSending && s.sendState != sendStateDataSent { return nil } s.sendState = sendStateReset @@ -178,10 +137,9 @@ func (s *stream) cancelWrite() error { case s.sendStateChanged <- struct{}{}: default: } - if err := s.sendControlMessage(&pb.Message{Flag: pb.Message_RESET.Enum()}); err != nil { + if err := s.writeMsgOnWriter(&pb.Message{Flag: pb.Message_RESET.Enum()}); err != nil { return err } - s.maybeDeclareStreamDone() return nil } @@ -193,9 +151,18 @@ func (s *stream) CloseWrite() error { return nil } s.sendState = sendStateDataSent - if err := s.sendControlMessage(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil { + select { + case s.sendStateChanged <- struct{}{}: + default: + } + if err := s.writeMsgOnWriter(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil { return err } - s.maybeDeclareStreamDone() return nil } + +func (s *stream) writeMsgOnWriter(msg proto.Message) error { + s.writerMx.Lock() + defer s.writerMx.Unlock() + return s.writer.WriteMsg(msg) +} diff --git a/p2p/transport/webrtc/transport_test.go b/p2p/transport/webrtc/transport_test.go index 7f4df94fc1..c596851007 100644 --- a/p2p/transport/webrtc/transport_test.go +++ b/p2p/transport/webrtc/transport_test.go @@ -481,6 +481,7 @@ func TestTransportWebRTC_RemoteReadsAfterClose(t *testing.T) { return } err = stream.Close() + fmt.Println("closed!") if err != nil { done <- err return @@ -496,7 +497,55 @@ func TestTransportWebRTC_RemoteReadsAfterClose(t *testing.T) { require.NoError(t, err) // require write and close to complete require.NoError(t, <-done) + stream.SetReadDeadline(time.Now().Add(5 * time.Second)) + + buf := make([]byte, 10) + n, err := stream.Read(buf) + require.NoError(t, err) + require.Equal(t, n, 4) +} + +func TestTransportWebRTC_RemoteReadsAfterAsyncClose(t *testing.T) { + tr, listeningPeer := getTransport(t) + listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct") + listener, err := tr.Listen(listenMultiaddr) + require.NoError(t, err) + + tr1, _ := getTransport(t) + + done := make(chan error) + go func() { + lconn, err := listener.Accept() + if err != nil { + done <- err + return + } + s, err := lconn.AcceptStream() + if err != nil { + done <- err + return + } + _, err = s.Write([]byte{1, 2, 3, 4}) + if err != 
nil { + done <- err + return + } + err = s.(*stream).AsyncClose(nil) + if err != nil { + done <- err + return + } + close(done) + }() + conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer) + require.NoError(t, err) + // create a stream + stream, err := conn.OpenStream(context.Background()) + + require.NoError(t, err) + // require write and close to complete + require.NoError(t, <-done) stream.SetReadDeadline(time.Now().Add(5 * time.Second)) buf := make([]byte, 10)
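Taken together, the caller-side difference between the two close paths added in this patch can be sketched as follows. This is an illustration, not part of the change: asyncCloseAndWait and done are placeholder names, and str is assumed to be a *stream returned by newStream, as in the tests above.

package libp2pwebrtc

// Close blocks for the FIN_ACK round trip itself; AsyncClose hands that wait (bounded to
// roughly 10s via controlMessageReaderEndTime) to a background goroutine and signals
// completion through the callback.
func asyncCloseAndWait(str *stream) error {
	done := make(chan struct{})
	if err := str.AsyncClose(func() { close(done) }); err != nil {
		// one of the two halves failed to close cleanly; the stream was reset instead
		return err
	}
	<-done // the remote acked our FIN, or the bounded wait lapsed and cleanup ran
	return nil
}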