diff --git a/p2p/transport/quicreuse/connmgr_test.go b/p2p/transport/quicreuse/connmgr_test.go index 16677c50fe..b5399e8318 100644 --- a/p2p/transport/quicreuse/connmgr_test.go +++ b/p2p/transport/quicreuse/connmgr_test.go @@ -25,7 +25,7 @@ func checkClosed(t *testing.T, cm *ConnManager) { continue } r.mutex.Lock() - for _, conn := range r.global { + for _, conn := range r.globalListeners { require.Zero(t, conn.GetCount()) } for _, conns := range r.unicast { diff --git a/p2p/transport/quicreuse/reuse.go b/p2p/transport/quicreuse/reuse.go index 684a7e0b10..cc90038efe 100644 --- a/p2p/transport/quicreuse/reuse.go +++ b/p2p/transport/quicreuse/reuse.go @@ -74,16 +74,21 @@ type reuse struct { routes routing.Router unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn - // global contains connections that are listening on 0.0.0.0 / :: - global map[int]*reuseConn + // globalListeners contains connections that are listening on 0.0.0.0 / :: + globalListeners map[int]*reuseConn + // globalDialers contains connections that we've dialed out from. These connections are listening on 0.0.0.0 / :: + // On Dial, connections are reused from this map if no connection is available in the globalListeners + // On Listen, connections are reused from this map if the requested port is 0, and then moved to globalListeners + globalDialers map[int]*reuseConn } func newReuse() *reuse { r := &reuse{ - unicast: make(map[string]map[int]*reuseConn), - global: make(map[int]*reuseConn), - closeChan: make(chan struct{}), - gcStopChan: make(chan struct{}), + unicast: make(map[string]map[int]*reuseConn), + globalListeners: make(map[int]*reuseConn), + globalDialers: make(map[int]*reuseConn), + closeChan: make(chan struct{}), + gcStopChan: make(chan struct{}), } go r.gc() return r @@ -92,7 +97,10 @@ func newReuse() *reuse { func (r *reuse) gc() { defer func() { r.mutex.Lock() - for _, conn := range r.global { + for _, conn := range r.globalListeners { + conn.Close() + } + for _, conn := range r.globalDialers { conn.Close() } for _, conns := range r.unicast { @@ -113,10 +121,16 @@ func (r *reuse) gc() { case <-ticker.C: now := time.Now() r.mutex.Lock() - for key, conn := range r.global { + for key, conn := range r.globalListeners { + if conn.ShouldGarbageCollect(now) { + conn.Close() + delete(r.globalListeners, key) + } + } + for key, conn := range r.globalDialers { if conn.ShouldGarbageCollect(now) { conn.Close() - delete(r.global, key) + delete(r.globalDialers, key) } } for ukey, conns := range r.unicast { @@ -185,7 +199,12 @@ func (r *reuse) dialLocked(network string, source *net.IP) (*reuseConn, error) { // Use a connection listening on 0.0.0.0 (or ::). // Again, we don't care about the port number. - for _, conn := range r.global { + for _, conn := range r.globalListeners { + return conn, nil + } + + // Use a connection we've previously dialed from + for _, conn := range r.globalDialers { return conn, nil } @@ -203,29 +222,59 @@ func (r *reuse) dialLocked(network string, source *net.IP) (*reuseConn, error) { return nil, err } rconn := newReuseConn(conn) - r.global[conn.LocalAddr().(*net.UDPAddr).Port] = rconn + r.globalDialers[conn.LocalAddr().(*net.UDPAddr).Port] = rconn return rconn, nil } func (r *reuse) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) { + r.mutex.Lock() + defer r.mutex.Unlock() + + // Check if we can reuse a connection we have already dialed out from. + // We reuse a connection from globalDialers when the requested port is 0 or the requested + // port is already in the globalDialers. + // If we are reusing a connection from globalDialers, we move the globalDialers entry to + // globalListeners + if laddr.IP.IsUnspecified() { + var rconn *reuseConn + var localAddr *net.UDPAddr + + if laddr.Port == 0 { + // the requested port is 0, we can reuse any connection + for _, conn := range r.globalDialers { + rconn = conn + localAddr = rconn.UDPConn.LocalAddr().(*net.UDPAddr) + delete(r.globalDialers, localAddr.Port) + break + } + } else if _, ok := r.globalDialers[laddr.Port]; ok { + rconn = r.globalDialers[laddr.Port] + localAddr = rconn.UDPConn.LocalAddr().(*net.UDPAddr) + delete(r.globalDialers, localAddr.Port) + } + // found a match + if rconn != nil { + rconn.IncreaseCount() + r.globalListeners[localAddr.Port] = rconn + return rconn, nil + } + } + conn, err := net.ListenUDP(network, laddr) if err != nil { return nil, err } localAddr := conn.LocalAddr().(*net.UDPAddr) - rconn := newReuseConn(conn) - rconn.IncreaseCount() - r.mutex.Lock() - defer r.mutex.Unlock() + rconn.IncreaseCount() // Deal with listen on a global address if localAddr.IP.IsUnspecified() { // The kernel already checked that the laddr is not already listen // so we need not check here (when we create ListenUDP). - r.global[localAddr.Port] = rconn - return rconn, err + r.globalListeners[localAddr.Port] = rconn + return rconn, nil } // Deal with listen on a unicast address @@ -239,7 +288,7 @@ func (r *reuse) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) { // The kernel already checked that the laddr is not already listen // so we need not check here (when we create ListenUDP). r.unicast[localAddr.IP.String()][localAddr.Port] = rconn - return rconn, err + return rconn, nil } func (r *reuse) Close() error { diff --git a/p2p/transport/quicreuse/reuse_test.go b/p2p/transport/quicreuse/reuse_test.go index 36a109c80d..0cd62f0d51 100644 --- a/p2p/transport/quicreuse/reuse_test.go +++ b/p2p/transport/quicreuse/reuse_test.go @@ -21,7 +21,12 @@ func (c *reuseConn) GetCount() int { func closeAllConns(reuse *reuse) { reuse.mutex.Lock() - for _, conn := range reuse.global { + for _, conn := range reuse.globalListeners { + for conn.GetCount() > 0 { + conn.DecreaseCount() + } + } + for _, conn := range reuse.globalDialers { for conn.GetCount() > 0 { conn.DecreaseCount() } @@ -110,6 +115,52 @@ func TestReuseConnectionWhenDialing(t *testing.T) { require.Equal(t, conn.GetCount(), 2) } +func TestReuseConnectionWhenListening(t *testing.T) { + reuse := newReuse() + cleanup(t, reuse) + + raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") + require.NoError(t, err) + conn, err := reuse.Dial("udp4", raddr) + require.NoError(t, err) + laddr := &net.UDPAddr{IP: net.IPv4zero, Port: conn.UDPConn.LocalAddr().(*net.UDPAddr).Port} + lconn, err := reuse.Listen("udp4", laddr) + require.NoError(t, err) + require.Equal(t, lconn.GetCount(), 2) + require.Equal(t, conn.GetCount(), 2) +} + +func TestReuseConnectionWhenDialBeforeListen(t *testing.T) { + reuse := newReuse() + cleanup(t, reuse) + + // dial any address + raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") + require.NoError(t, err) + rconn, err := reuse.Dial("udp4", raddr) + require.NoError(t, err) + + // open a listener + laddr := &net.UDPAddr{IP: net.IPv4zero, Port: 1234} + lconn, err := reuse.Listen("udp4", laddr) + require.NoError(t, err) + + // new dials should go via the listener connection + raddr, err = net.ResolveUDPAddr("udp4", "1.1.1.1:1235") + require.NoError(t, err) + conn, err := reuse.Dial("udp4", raddr) + require.NoError(t, err) + require.Equal(t, conn, lconn) + require.Equal(t, conn.GetCount(), 2) + + // a listener on an unspecified port should reuse the dialer + laddr2 := &net.UDPAddr{IP: net.IPv4zero, Port: 0} + lconn2, err := reuse.Listen("udp4", laddr2) + require.NoError(t, err) + require.Equal(t, lconn2, rconn) + require.Equal(t, lconn2.GetCount(), 2) +} + func TestReuseListenOnSpecificInterface(t *testing.T) { if platformHasRoutingTables() { t.Skip("this test only works on platforms that support routing tables") @@ -157,10 +208,16 @@ func TestReuseGarbageCollect(t *testing.T) { numGlobals := func() int { reuse.mutex.Lock() defer reuse.mutex.Unlock() - return len(reuse.global) + return len(reuse.globalListeners) + len(reuse.globalDialers) } - addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0") + raddr, err := net.ResolveUDPAddr("udp4", "1.2.3.4:1234") + require.NoError(t, err) + dconn, err := reuse.Dial("udp4", raddr) + require.NoError(t, err) + require.Equal(t, dconn.GetCount(), 1) + + addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:1234") require.NoError(t, err) lconn, err := reuse.Listen("udp4", addr) require.NoError(t, err) @@ -168,13 +225,14 @@ func TestReuseGarbageCollect(t *testing.T) { closeTime := time.Now() lconn.DecreaseCount() + dconn.DecreaseCount() for { num := numGlobals() if closeTime.Add(maxUnusedDuration).Before(time.Now()) { break } - require.Equal(t, num, 1) + require.Equal(t, num, 2) time.Sleep(2 * time.Millisecond) } require.Eventually(t, func() bool { return numGlobals() == 0 }, 4*garbageCollectInterval, 10*time.Millisecond)