diff --git a/p2p/protocol/identify/obsaddr.go b/p2p/protocol/identify/obsaddr.go index 4437c4b011..3b7480a877 100644 --- a/p2p/protocol/identify/obsaddr.go +++ b/p2p/protocol/identify/obsaddr.go @@ -32,10 +32,10 @@ type thinWaist struct { Addr, TW, Rest ma.Multiaddr } -// thinWaistWithCount is a thinWaist along with the count of the connection that have it as the local address +// thinWaistWithConns is a thinWaist along with a set ofconnections that have it as the local address type thinWaistWithCount struct { thinWaist - Count int + conns map[string]struct{} } func thinWaistForm(a ma.Multiaddr) (thinWaist, error) { @@ -84,6 +84,7 @@ func getObserver(a ma.Multiaddr) (string, error) { // connMultiaddrs provides IsClosed along with network.ConnMultiaddrs. It is easier to mock this than network.Conn type connMultiaddrs interface { network.ConnMultiaddrs + ID() string IsClosed() bool } @@ -170,7 +171,8 @@ type ObservedAddrManager struct { // NewObservedAddrManager returns a new address manager using peerstore.OwnObservedAddressTTL as the TTL. func NewObservedAddrManager(listenAddrs, hostAddrs func() []ma.Multiaddr, - interfaceListenAddrs func() ([]ma.Multiaddr, error), normalize func(ma.Multiaddr) ma.Multiaddr) (*ObservedAddrManager, error) { + interfaceListenAddrs func() ([]ma.Multiaddr, error), normalize func(ma.Multiaddr) ma.Multiaddr, +) (*ObservedAddrManager, error) { if normalize == nil { normalize = func(addr ma.Multiaddr) ma.Multiaddr { return addr } } @@ -254,7 +256,6 @@ func (o *ObservedAddrManager) getTopExternalAddrs(localTWStr string) []*observer } else { return 0 } - }) n := len(observerSets) if n > maxExternalThinWaistAddrsPerLocalAddr { @@ -291,7 +292,7 @@ func (o *ObservedAddrManager) worker() { } } -func (o *ObservedAddrManager) shouldRecordObservation(conn connMultiaddrs, observed ma.Multiaddr) (shouldRecord bool, localTW thinWaist, observedTW thinWaist) { +func (o *ObservedAddrManager) shouldRecordObservation(conn connMultiaddrs, observed ma.Multiaddr) (shouldRecord bool, localTW, observedTW thinWaist) { if conn == nil || observed == nil { return false, thinWaist{}, thinWaist{} } @@ -399,7 +400,13 @@ func (o *ObservedAddrManager) recordObservationUnlocked(conn connMultiaddrs, loc } o.localAddrs[string(localTW.Addr.Bytes())] = t } - t.Count++ + if t.conns == nil { + t.conns = map[string]struct{}{ + conn.ID(): {}, + } + } else { + t.conns[conn.ID()] = struct{}{} + } } else { if prevObservedTWAddr.Equal(observedTW.TW) { // we have received the same observation again, nothing to do @@ -462,8 +469,8 @@ func (o *ObservedAddrManager) removeConn(conn connMultiaddrs) { if !ok { return } - t.Count-- - if t.Count <= 0 { + delete(t.conns, conn.ID()) + if len(t.conns) == 0 { delete(o.localAddrs, string(localTW.Addr.Bytes())) } diff --git a/p2p/protocol/identify/obsaddr_glass_test.go b/p2p/protocol/identify/obsaddr_glass_test.go index 31fd4f5726..29bf070b50 100644 --- a/p2p/protocol/identify/obsaddr_glass_test.go +++ b/p2p/protocol/identify/obsaddr_glass_test.go @@ -35,6 +35,10 @@ func (c *mockConn) IsClosed() bool { return c.isClosed.Load() } +func (c *mockConn) ID() string { + return fmt.Sprintf("%s<-->%s", c.local, c.remote) +} + func TestShouldRecordObservationWithWebTransport(t *testing.T) { listenAddr := ma.StringCast("/ip4/0.0.0.0/udp/0/quic-v1/webtransport/certhash/uEgNmb28") ifaceAddr := ma.StringCast("/ip4/10.0.0.2/udp/9999/quic-v1/webtransport/certhash/uEgNmb28") @@ -95,7 +99,6 @@ func TestShouldRecordObservationWithNAT64Addr(t *testing.T) { require.NoError(t, err) for i, tc := range cases { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - if shouldRecord, _, _ := o.shouldRecordObservation(c, tc.addr); shouldRecord != tc.want { t.Fatalf("%s %s", tc.addr, tc.failureReason) } @@ -155,5 +158,4 @@ func TestThinWaistForm(t *testing.T) { require.Equal(t, restTW, tw.Rest, "%s %s", restTW, tw.Rest) }) } - } diff --git a/p2p/protocol/identify/obsaddr_test.go b/p2p/protocol/identify/obsaddr_test.go index 9c2d8dee57..94366f882e 100644 --- a/p2p/protocol/identify/obsaddr_test.go +++ b/p2p/protocol/identify/obsaddr_test.go @@ -153,7 +153,7 @@ func TestObservedAddrManager(t *testing.T) { var ob1, ob2 [N]connMultiaddrs for i := 0; i < N; i++ { ob1[i] = newConn(quic4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1", i))) - ob2[i] = newConn(quic4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1", i))) + ob2[i] = newConn(quic4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/2/quic-v1", i))) } for i := 0; i < N-1; i++ { o.Record(ob1[i], observedQuic) @@ -186,6 +186,7 @@ func TestObservedAddrManager(t *testing.T) { return checkAllEntriesRemoved(o) }, 2*time.Second, 100*time.Millisecond) }) + t.Run("SameObserversDifferentAddrs", func(t *testing.T) { o := newObservedAddrMgr() defer o.Close() @@ -197,7 +198,7 @@ func TestObservedAddrManager(t *testing.T) { var ob1, ob2 [N]connMultiaddrs for i := 0; i < N; i++ { ob1[i] = newConn(quic4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1", i))) - ob2[i] = newConn(quic4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1", i))) + ob2[i] = newConn(quic4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/2/quic-v1", i))) } for i := 0; i < N-1; i++ { o.Record(ob1[i], observedQuic1) @@ -238,6 +239,8 @@ func TestObservedAddrManager(t *testing.T) { c2 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.2/udp/1/quic-v1")) c3 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.3/udp/1/quic-v1/webtransport")) c4 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/webtransport")) + c5 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.5/udp/1/quic-v1")) + c6 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.6/udp/1/quic-v1")) var observedQuic, observedWebTransport ma.Multiaddr for i := 0; i < 10; i++ { // Change the IP address in each observation @@ -247,6 +250,7 @@ func TestObservedAddrManager(t *testing.T) { o.Record(c2, observedQuic) o.Record(c3, observedWebTransport) o.Record(c4, observedWebTransport) + o.Record(c5, observedQuic) time.Sleep(20 * time.Millisecond) } @@ -258,13 +262,23 @@ func TestObservedAddrManager(t *testing.T) { require.NoError(t, err) require.Less(t, len(o.externalAddrs[string(tw.TW.Bytes())]), 2) - require.Equal(t, o.AddrsFor(webTransport4ListenAddr), []ma.Multiaddr{observedWebTransport}) - require.Equal(t, o.AddrsFor(quic4ListenAddr), []ma.Multiaddr{observedQuic}) + require.Equal(t, []ma.Multiaddr{observedWebTransport}, o.AddrsFor(webTransport4ListenAddr)) + require.Equal(t, []ma.Multiaddr{observedQuic}, o.AddrsFor(quic4ListenAddr)) + require.ElementsMatch(t, []ma.Multiaddr{observedQuic, observedWebTransport}, o.Addrs()) + + for i := 0; i < 3; i++ { + // remove non-recorded connection + o.removeConn(c6) + } + require.Equal(t, []ma.Multiaddr{observedWebTransport}, o.AddrsFor(webTransport4ListenAddr)) + require.Equal(t, []ma.Multiaddr{observedQuic}, o.AddrsFor(quic4ListenAddr)) + require.ElementsMatch(t, []ma.Multiaddr{observedQuic, observedWebTransport}, o.Addrs()) o.removeConn(c1) o.removeConn(c2) o.removeConn(c3) o.removeConn(c4) + o.removeConn(c5) require.Eventually(t, func() bool { return checkAllEntriesRemoved(o) }, 1*time.Second, 100*time.Millisecond) @@ -411,7 +425,7 @@ func TestObservedAddrManager(t *testing.T) { return checkAllEntriesRemoved(o) }, 1*time.Second, 100*time.Millisecond) }) - t.Run("Nill Input", func(t *testing.T) { + t.Run("Nil Input", func(t *testing.T) { o := newObservedAddrMgr() defer o.Close() o.maybeRecordObservation(nil, nil)