Skip to content

Commit

Permalink
webrtc: wait for fin_ack for closing datachannel
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Oct 23, 2023
1 parent b7c04f8 commit 416c934
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 37 deletions.
7 changes: 7 additions & 0 deletions core/network/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ type MuxedStream interface {
SetWriteDeadline(time.Time) error
}

// AsyncCloser is implemented by streams that need to do expensive operations on close. Closing the stream async avoids
// blocking the calling goroutine.
//
// It is an optional extension: a caller holding a stream can detect support with a type assertion
// and fall back to the stream's regular Close when the assertion fails.
type AsyncCloser interface {
// AsyncClose closes the stream and executes onDone when stream is closed.
//
// onDone may run synchronously before AsyncClose returns, or later from another goroutine,
// so it must be safe to call from either context. The returned error reports failures that are
// detected before AsyncClose returns.
// NOTE(review): whether onDone is guaranteed to run exactly once is up to the implementation — confirm.
AsyncClose(onDone func()) error
}

// MuxedConn represents a connection to a remote peer that has been
// extended to support stream multiplexing.
//
Expand Down
6 changes: 6 additions & 0 deletions p2p/net/swarm/swarm_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ func (s *Stream) Write(p []byte) (int, error) {
// Close closes the stream, closing both ends and freeing all associated
// resources.
func (s *Stream) Close() error {
if as, ok := s.stream.(network.AsyncCloser); ok {
err := as.AsyncClose(func() {
s.closeAndRemoveStream()
})
return err
}
err := s.stream.Close()
s.closeAndRemoveStream()
return err
Expand Down
45 changes: 45 additions & 0 deletions p2p/net/swarm/swarm_stream_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package swarm

import (
"context"
"sync/atomic"
"testing"

"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/stretchr/testify/require"
)

// asyncStreamWrapper is a test helper that wraps a network.MuxedStream and adds an
// AsyncClose method, so a wrapped stream takes the asynchronous close path.
type asyncStreamWrapper struct {
// MuxedStream is the wrapped stream; embedding promotes its methods onto the wrapper.
network.MuxedStream
// before is a hook invoked at the very start of AsyncClose, before the wrapped stream is closed.
before func()
}

// AsyncClose runs the before hook, closes the wrapped stream synchronously,
// invokes onDone, and finally reports whatever error the close produced.
func (s *asyncStreamWrapper) AsyncClose(onDone func()) error {
	s.before()
	closeErr := s.MuxedStream.Close()
	onDone()
	return closeErr
}

// TestStreamAsyncCloser verifies that closing a swarm Stream whose underlying
// muxed stream implements AsyncClose goes through AsyncClose rather than Close.
func TestStreamAsyncCloser(t *testing.T) {
	dialer, listener := makeSwarm(t), makeSwarm(t)

	// Teach the dialer how to reach the listener, then open a stream to it.
	dialer.Peerstore().AddAddrs(listener.LocalPeer(), listener.ListenAddresses(), peerstore.TempAddrTTL)
	opened, err := dialer.NewStream(context.Background(), listener.LocalPeer())
	require.NoError(t, err)
	swarmStream, isSwarmStream := opened.(*Stream)
	require.True(t, isSwarmStream)

	// Swap the underlying stream for a wrapper that records when AsyncClose is used.
	var hookRan atomic.Bool
	swarmStream.stream = &asyncStreamWrapper{
		MuxedStream: swarmStream.stream,
		before:      func() { hookRan.Store(true) },
	}

	swarmStream.Close()
	require.True(t, hookRan.Load())
}
14 changes: 0 additions & 14 deletions p2p/transport/webrtc/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ var _ tpt.CapableConn = &connection{}
// maxAcceptQueueLen is the number of waiting streams.
const maxAcceptQueueLen = 256

const maxDataChannelID = 1 << 10

type errConnectionTimeout struct{}

var _ net.Error = &errConnectionTimeout{}
Expand Down Expand Up @@ -195,12 +193,6 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error
if id > math.MaxUint16 {
return nil, errors.New("exhausted stream ID space")
}
// Limit the number of streams, since we're not able to actually properly close them.
// See https://github.com/libp2p/specs/issues/575 for details.
if id > maxDataChannelID {
c.Close()
return c.OpenStream(ctx)
}

streamID := uint16(id)
dc, err := c.pc.CreateDataChannel("", &webrtc.DataChannelInit{ID: &streamID})
Expand Down Expand Up @@ -329,12 +321,6 @@ func (c *connection) setRemotePublicKey(key ic.PubKey) {
func SetupDataChannelQueue(pc *webrtc.PeerConnection, queueLen int) chan DetachedDataChannel {
queue := make(chan DetachedDataChannel, queueLen)
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
// Limit the number of streams, since we're not able to actually properly close them.
// See https://github.com/libp2p/specs/issues/575 for details.
if *dc.ID() > maxDataChannelID {
dc.Close()
return
}
dc.OnOpen(func() {
rwc, err := dc.Detach()
if err != nil {
Expand Down
29 changes: 18 additions & 11 deletions p2p/transport/webrtc/pb/message.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions p2p/transport/webrtc/pb/message.proto
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ message Message {
// The sender abruptly terminates the sending part of the stream. The
// receiver can discard any data that it already received on that stream.
RESET = 2;
// Sending the FIN_ACK flag acknowledges the previous receipt of a message
// with the FIN flag set. Receiving a FIN_ACK flag gives the recipient
// confidence that the remote has received all sent messages.
FIN_ACK = 3;
}

optional Flag flag=1;
Expand Down
96 changes: 85 additions & 11 deletions p2p/transport/webrtc/stream.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package libp2pwebrtc

import (
"errors"
"os"
"sync"
"time"

Expand Down Expand Up @@ -53,12 +55,17 @@ const (
sendStateSending sendState = iota
sendStateDataSent
sendStateReset
sendStateDataReceived
)

// Package pion detached data channel into a net.Conn
// and then a network.MuxedStream
type stream struct {
mx sync.Mutex

// readerMx ensures there's only a single goroutine reading from reader as the underlying SCTP reader
// doesn't support multiple readers
readerMx sync.Mutex
// pbio.Reader is not thread safe,
// and while our Read is not promised to be thread safe,
// we ourselves internally read from multiple routines...
Expand Down Expand Up @@ -132,28 +139,92 @@ func newStream(
}

// Close shuts down both directions of the stream and blocks until the peer has
// acknowledged our FIN (or waitForFINACK gives up), then finalizes the stream.
//
// If writing the FIN fails the stream is reset and the write error is returned.
// Otherwise the result is the error from closing the read side, if any.
func (s *stream) Close() error {
	// The read side goes first so that STOP_SENDING is delivered before the
	// data channel is closed.
	readErr := s.CloseRead()
	writeErr := s.CloseWrite()
	if writeErr != nil {
		// The FIN could not be written; fall back to resetting the stream.
		s.Reset()
		return writeErr
	}

	s.waitForFINACK()

	s.mx.Lock()
	defer s.mx.Unlock()
	s.maybeDeclareStreamDone()

	return errors.Join(writeErr, readErr)
}

// AsyncClose closes the stream without blocking the caller on the FIN_ACK wait.
//
// Both directions are closed synchronously. If writing the FIN fails, the
// stream is reset, onDone is invoked before returning, and the write error is
// returned. Otherwise a goroutine waits for the FIN_ACK via waitForFINACK,
// finalizes the stream, and then invokes onDone; in that case the returned
// error is the one from closing the read side, if any.
func (s *stream) AsyncClose(onDone func()) error {
	// Close read before write to ensure that the STOP_SENDING message is delivered before
	// we close the data channel
	closeReadErr := s.CloseRead()
	closeWriteErr := s.CloseWrite()
	if closeWriteErr != nil {
		// writing FIN failed, reset the stream
		_ = s.Reset() // best effort; the write error is what the caller sees
		onDone()
		return closeWriteErr
	}
	// No early return here: the goroutine below must be started, otherwise
	// onDone would never run on the success path.
	go func() {
		s.waitForFINACK()
		s.mx.Lock()
		s.maybeDeclareStreamDone()
		s.mx.Unlock()
		// Invoke the callback without holding mx, as in the failure path above.
		onDone()
	}()
	return errors.Join(closeWriteErr, closeReadErr)
}

// Reset abruptly terminates the stream: it cancels the write side, closes the
// read side, closes the underlying data channel, and then finalizes the stream
// if it has reached a terminal state.
//
// All three steps are always attempted; their errors are combined with
// errors.Join, so the result is nil only if every step succeeded.
func (s *stream) Reset() error {
	cancelWriteErr := s.cancelWrite()
	closeReadErr := s.CloseRead()
	// Do not return early on the errors above: the data channel must still be
	// closed and the stream finalized.
	dcCloseErr := s.dataChannel.Close()
	s.mx.Lock()
	defer s.mx.Unlock()
	s.maybeDeclareStreamDone()
	return errors.Join(cancelWriteErr, closeReadErr, dcCloseErr)
}

// SetDeadline applies t as both the read and the write deadline.
// Only the error from setting the write deadline is reported; the error from
// setting the read deadline is discarded.
func (s *stream) SetDeadline(t time.Time) error {
	_ = s.SetReadDeadline(t)
	writeErr := s.SetWriteDeadline(t)
	return writeErr
}

// waitForFINACK keeps reading incoming messages until the send state leaves
// sendStateDataSent (for example because a FIN_ACK was processed) or a read fails.
//
// Locking: mx is held on entry to and exit from the loop but released around the
// blocking ReadMsg call; readerMx is held only while ReadMsg runs. The order of the
// lock and unlock calls below is deliberate; do not reorder them.
func (s *stream) waitForFINACK() {
s.mx.Lock()
defer s.mx.Unlock()
// Only wait for FIN_ACK if we are waiting for FIN_ACK and we have stopped reading from the stream
if s.sendState != sendStateDataSent || s.receiveState == receiveStateReceiving {
return
}
// First wait for any existing readers to exit
// A deadline in the past is set so that a reader currently holding readerMx gives up
// (presumably its pending ReadMsg fails with a deadline error — confirm against the reader).
s.SetReadDeadline(time.Now().Add(-1 * time.Minute))
s.readerMx.Lock()
// From here on this goroutine is the only reader; allow 10 seconds from now for the FIN_ACK.
s.SetReadDeadline(time.Now().Add(10 * time.Second))
var msg pb.Message
for {
// Loop invariant: both mx and readerMx are held here. mx is dropped for the blocking read.
s.mx.Unlock()
if err := s.reader.ReadMsg(&msg); err != nil {
s.readerMx.Unlock()
s.mx.Lock()
// 10 seconds is enough time for the message to be delivered. The peer just hasn't responded
// with FIN_ACK
// Any other read error leaves sendState unchanged.
if errors.Is(err, os.ErrDeadlineExceeded) {
s.sendState = sendStateDataReceived
}
break
}
s.readerMx.Unlock()
s.mx.Lock()
// processIncomingFlag is called with mx held; it may advance sendState.
s.processIncomingFlag(msg.Flag)
if s.sendState != sendStateDataSent {
break
}
// Still waiting: re-take readerMx to restore the loop invariant before the next read.
s.readerMx.Lock()
}
}

// processIncomingFlag process the flag on an incoming message
// It needs to be called with msg.Flag, not msg.GetFlag(),
// otherwise we'd misinterpret the default value.
Expand All @@ -168,6 +239,9 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
if s.receiveState == receiveStateReceiving {
s.receiveState = receiveStateDataRead
}
if err := s.sendControlMessage(&pb.Message{Flag: pb.Message_FIN_ACK.Enum()}); err != nil {
	// Best effort: the failure is only logged. The format string needs a verb,
	// otherwise the error is rendered as a "%!(EXTRA ...)" suffix.
	log.Debugf("failed to send FIN_ACK: %v", err)
}
case pb.Message_STOP_SENDING:
if s.sendState == sendStateSending {
s.sendState = sendStateReset
Expand All @@ -180,24 +254,24 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
if s.receiveState == receiveStateReceiving {
s.receiveState = receiveStateReset
}
case pb.Message_FIN_ACK:
if s.sendState == sendStateDataSent {
s.sendState = sendStateDataReceived
}
}
s.maybeDeclareStreamDone()
}

// maybeDeclareStreamDone is used to force reset a stream. It must be called with mx acquired
func (s *stream) maybeDeclareStreamDone() {
if (s.sendState == sendStateReset || s.sendState == sendStateDataSent) &&
if (s.sendState == sendStateReset || s.sendState == sendStateDataReceived) &&
(s.receiveState == receiveStateReset || s.receiveState == receiveStateDataRead) &&
len(s.controlMsgQueue) == 0 {

s.mx.Unlock()
defer s.mx.Lock()
_ = s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) // pion ignores zero times
// TODO: we should be closing the underlying datachannel, but this resets the stream
// See https://github.com/libp2p/specs/issues/575 for details.
// _ = s.dataChannel.Close()
// TODO: write for the spawned reader to return

s.dataChannel.Close()
s.onDone()
}
}
Expand Down
3 changes: 3 additions & 0 deletions p2p/transport/webrtc/stream_read.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ func (s *stream) Read(b []byte) (int, error) {
// load the next message
s.mx.Unlock()
var msg pb.Message
s.readerMx.Lock()
if err := s.reader.ReadMsg(&msg); err != nil {
s.readerMx.Unlock()
s.mx.Lock()
if err == io.EOF {
// if the channel was properly closed, return EOF
Expand All @@ -48,6 +50,7 @@ func (s *stream) Read(b []byte) (int, error) {
}
return 0, err
}
s.readerMx.Unlock()
s.mx.Lock()
s.nextMessage = &msg
}
Expand Down
Loading

0 comments on commit 416c934

Please sign in to comment.