diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 0bffab99a2..194d012758 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -340,12 +340,7 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, } c.streams.m = make(map[*Stream]struct{}) - if len(s.conns.m[p]) == 0 { // first connection - s.emitter.Emit(event.EvtPeerConnectednessChanged{ - Peer: p, - Connectedness: network.Connected, - }) - } + isFirstConnection := len(s.conns.m[p]) == 0 s.conns.m[p] = append(s.conns.m[p], c) // Add two swarm refs: @@ -358,6 +353,15 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, c.notifyLk.Lock() s.conns.Unlock() + // Emit event after releasing `s.conns` lock so that a consumer can still + // use swarm methods that need the `s.conns` lock. + if isFirstConnection { + s.emitter.Emit(event.EvtPeerConnectednessChanged{ + Peer: p, + Connectedness: network.Connected, + }) + } + s.notifyAll(func(f network.Notifiee) { f.Connected(s, c) }) @@ -637,25 +641,32 @@ func (s *Swarm) removeConn(c *Conn) { p := c.RemotePeer() s.conns.Lock() - defer s.conns.Unlock() cs := s.conns.m[p] + + if len(cs) == 1 { + delete(s.conns.m, p) + s.conns.Unlock() + + // Emit event after releasing `s.conns` lock so that a consumer can still + // use swarm methods that need the `s.conns` lock. + s.emitter.Emit(event.EvtPeerConnectednessChanged{ + Peer: p, + Connectedness: network.NotConnected, + }) + return + } + + defer s.conns.Unlock() + for i, ci := range cs { if ci == c { - if len(cs) == 1 { - delete(s.conns.m, p) - s.emitter.Emit(event.EvtPeerConnectednessChanged{ - Peer: p, - Connectedness: network.NotConnected, - }) - } else { - // NOTE: We're intentionally preserving order. - // This way, connections to a peer are always - // sorted oldest to newest. - copy(cs[i:], cs[i+1:]) - cs[len(cs)-1] = nil - s.conns.m[p] = cs[:len(cs)-1] - } + // NOTE: We're intentionally preserving order. + // This way, connections to a peer are always + // sorted oldest to newest. + copy(cs[i:], cs[i+1:]) + cs[len(cs)-1] = nil + s.conns.m[p] = cs[:len(cs)-1] break } } diff --git a/p2p/net/swarm/swarm_event_test.go b/p2p/net/swarm/swarm_event_test.go index 7d4fb6bd5d..86d698d611 100644 --- a/p2p/net/swarm/swarm_event_test.go +++ b/p2p/net/swarm/swarm_event_test.go @@ -64,3 +64,52 @@ func TestConnectednessEventsSingleConn(t *testing.T) { checkEvent(t, sub1, event.EvtPeerConnectednessChanged{Peer: s2.LocalPeer(), Connectedness: network.NotConnected}) checkEvent(t, sub2, event.EvtPeerConnectednessChanged{Peer: s1.LocalPeer(), Connectedness: network.NotConnected}) } + +func TestNoDeadlockWhenConsumingConnectednessEvents(t *testing.T) { + dialerEventBus := eventbus.NewBus() + dialer := swarmt.GenSwarm(t, swarmt.OptDialOnly, swarmt.EventBus(dialerEventBus)) + defer dialer.Close() + + listener := swarmt.GenSwarm(t, swarmt.OptDialOnly) + addrsToListen := []ma.Multiaddr{ + ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), + } + + if err := listener.Listen(addrsToListen...); err != nil { + t.Fatal(err) + } + listenedAddrs := listener.ListenAddresses() + + dialer.Peerstore().AddAddrs(listener.LocalPeer(), listenedAddrs, time.Hour) + + sub, err := dialerEventBus.Subscribe(new(event.EvtPeerConnectednessChanged)) + require.NoError(t, err) + + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // A slow consumer + go func() { + for { + select { + case <-ctx.Done(): + return + case <-sub.Out(): + time.Sleep(100 * time.Millisecond) + // Do something with the swarm that needs the conns lock + _ = dialer.ConnsToPeer(listener.LocalPeer()) + time.Sleep(100 * time.Millisecond) + } + } + }() + + for i := 0; i < 10; i++ { + // Connect and disconnect to trigger a bunch of events + _, err := dialer.DialPeer(context.Background(), listener.LocalPeer()) + require.NoError(t, err) + dialer.ClosePeer(listener.LocalPeer()) + } + + // The test should finish without deadlocking +}