diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 3eeebee588..e3dd183c60 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -268,8 +268,6 @@ func (s *Swarm) Close() error { func (s *Swarm) close() { s.ctxCancel() - s.emitter.Close() - // Prevents new connections and/or listeners from being added to the swarm. s.listeners.Lock() listeners := s.listeners.m @@ -279,7 +277,6 @@ func (s *Swarm) close() { s.conns.Lock() conns := s.conns.m s.conns.m = nil - s.conns.connectedness = nil s.conns.Unlock() // Lots of goroutines but we might as well do this in parallel. We want to shut down as fast as @@ -305,6 +302,13 @@ func (s *Swarm) close() { // Wait for everything to finish. s.refs.Wait() + s.emitter.Close() + + // Remove the connectedness map only after we have closed the connection and sent all the disconnection + // events + s.conns.Lock() + s.conns.connectedness = nil + s.conns.Unlock() // Now close out any transports (if necessary). Do this after closing // all connections/listeners. @@ -793,14 +797,13 @@ func (s *Swarm) removeConn(c *Conn) { } newState := s.connectednessUnlocked(p) - if s.conns.connectedness != nil { // swarm is not closing + if s.conns.connectedness != nil { // This shoud always be non nil but a check doesn't hurt if newState == network.NotConnected { delete(s.conns.connectedness, p) } else { s.conns.connectedness[p] = newState } } - s.conns.Unlock() if oldState != newState { diff --git a/p2p/net/swarm/swarm_event_test.go b/p2p/net/swarm/swarm_event_test.go index 86d698d611..8d2b2d79ce 100644 --- a/p2p/net/swarm/swarm_event_test.go +++ b/p2p/net/swarm/swarm_event_test.go @@ -113,3 +113,91 @@ func TestNoDeadlockWhenConsumingConnectednessEvents(t *testing.T) { // The test should finish without deadlocking } + +func TestConnectednessEvents(t *testing.T) { + s1, sub1 := newSwarmWithSubscription(t) + const N = 100 + peers := make([]*Swarm, N) + for i := 0; i < N; i++ { + peers[i] = swarmt.GenSwarm(t) + } + + // First check all connected events + done := make(chan struct{}) + go func() { + defer close(done) + for i := 0; i < N; i++ { + e := <-sub1.Out() + evt, ok := e.(event.EvtPeerConnectednessChanged) + if !ok { + t.Error("invalid event received", e) + return + } + if evt.Connectedness != network.Connected { + t.Errorf("invalid event received: expected: Connected, got: %s", evt) + return + } + } + }() + for i := 0; i < N; i++ { + s1.Peerstore().AddAddrs(peers[i].LocalPeer(), []ma.Multiaddr{peers[i].ListenAddresses()[0]}, time.Hour) + _, err := s1.DialPeer(context.Background(), peers[i].LocalPeer()) + require.NoError(t, err) + } + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatal("expected all connectedness events to be completed") + } + + // Disconnect some peers + done = make(chan struct{}) + go func() { + defer close(done) + for i := 0; i < N/2; i++ { + e := <-sub1.Out() + evt, ok := e.(event.EvtPeerConnectednessChanged) + if !ok { + t.Error("invalid event received", e) + return + } + if evt.Connectedness != network.NotConnected { + t.Errorf("invalid event received: expected: NotConnected, got: %s", evt) + return + } + } + }() + for i := 0; i < N/2; i++ { + err := s1.ClosePeer(peers[i].LocalPeer()) + require.NoError(t, err) + } + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatal("expected all disconnected events to be completed") + } + + // Check for disconnected events on swarm close + done = make(chan struct{}) + go func() { + defer close(done) + for i := N / 2; i < N; i++ { + e := <-sub1.Out() + evt, ok := e.(event.EvtPeerConnectednessChanged) + if !ok { + t.Error("invalid event received", e) + return + } + if evt.Connectedness != network.NotConnected { + t.Errorf("invalid event received: expected: NotConnected, got: %s", evt) + return + } + } + }() + s1.Close() + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatal("expected all disconnected events after swarm close to be completed") + } +}