diff --git a/p2p/net/swarm/dial_sync.go b/p2p/net/swarm/dial_sync.go index edb6c89821..3179016661 100644 --- a/p2p/net/swarm/dial_sync.go +++ b/p2p/net/swarm/dial_sync.go @@ -5,132 +5,122 @@ import ( "errors" "sync" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" ) // TODO: change this text when we fix the bug var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW") -// DialFunc is the type of function expected by DialSync. -type DialFunc func(context.Context, peer.ID) (*Conn, error) +// dialWorkerFunc is used by DialSync to spawn a new dial worker +type dialWorkerFunc func(context.Context, peer.ID, <-chan dialRequest) error -// NewDialSync constructs a new DialSync -func NewDialSync(dfn DialFunc) *DialSync { +// newDialSync constructs a new DialSync +func newDialSync(worker dialWorkerFunc) *DialSync { return &DialSync{ - dials: make(map[peer.ID]*activeDial), - dialFunc: dfn, + dials: make(map[peer.ID]*activeDial), + dialWorker: worker, } } // DialSync is a dial synchronization helper that ensures that at most one dial // to any given peer is active at any given time. type DialSync struct { - dials map[peer.ID]*activeDial - dialsLk sync.Mutex - dialFunc DialFunc + dials map[peer.ID]*activeDial + dialsLk sync.Mutex + dialWorker dialWorkerFunc } type activeDial struct { - id peer.ID - refCnt int - refCntLk sync.Mutex - cancel func() + id peer.ID + refCnt int - err error - conn *Conn - waitch chan struct{} + ctx context.Context + cancel func() + + reqch chan dialRequest ds *DialSync } -func (ad *activeDial) wait(ctx context.Context) (*Conn, error) { - defer ad.decref() - select { - case <-ad.waitch: - return ad.conn, ad.err - case <-ctx.Done(): - return nil, ctx.Err() +func (ad *activeDial) decref() { + ad.ds.dialsLk.Lock() + ad.refCnt-- + if ad.refCnt == 0 { + ad.cancel() + close(ad.reqch) + delete(ad.ds.dials, ad.id) } + ad.ds.dialsLk.Unlock() } -func (ad *activeDial) incref() { - ad.refCntLk.Lock() - defer ad.refCntLk.Unlock() - ad.refCnt++ -} +func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) { + dialCtx := ad.ctx -func (ad *activeDial) decref() { - ad.refCntLk.Lock() - ad.refCnt-- - maybeZero := (ad.refCnt <= 0) - ad.refCntLk.Unlock() - - // make sure to always take locks in correct order. - if maybeZero { - ad.ds.dialsLk.Lock() - ad.refCntLk.Lock() - // check again after lock swap drop to make sure nobody else called incref - // in between locks - if ad.refCnt <= 0 { - ad.cancel() - delete(ad.ds.dials, ad.id) - } - ad.refCntLk.Unlock() - ad.ds.dialsLk.Unlock() + if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect { + dialCtx = network.WithForceDirectDial(dialCtx, reason) + } + if simConnect, reason := network.GetSimultaneousConnect(ctx); simConnect { + dialCtx = network.WithSimultaneousConnect(dialCtx, reason) } -} -func (ad *activeDial) start(ctx context.Context) { - ad.conn, ad.err = ad.ds.dialFunc(ctx, ad.id) - - // This isn't the user's context so we should fix the error. - switch ad.err { - case context.Canceled: - // The dial was canceled with `CancelDial`. - ad.err = errDialCanceled - case context.DeadlineExceeded: - // We hit an internal timeout, not a context timeout.
- ad.err = ErrDialTimeout + resch := make(chan dialResponse, 1) + select { + case ad.reqch <- dialRequest{ctx: dialCtx, resch: resch}: + case <-ctx.Done(): + return nil, ctx.Err() + } + + select { + case res := <-resch: + return res.conn, res.err + case <-ctx.Done(): + return nil, ctx.Err() } - close(ad.waitch) - ad.cancel() } -func (ds *DialSync) getActiveDial(ctx context.Context, p peer.ID) *activeDial { +func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) { ds.dialsLk.Lock() defer ds.dialsLk.Unlock() actd, ok := ds.dials[p] if !ok { - adctx, cancel := context.WithCancel(ctx) + // This code intentionally uses the background context. Otherwise, if the first call + // to Dial is canceled, subsequent dial calls will also be canceled. + // XXX: this also breaks direct connection logic. We will need to pipe the + // information through some other way. + adctx, cancel := context.WithCancel(context.Background()) actd = &activeDial{ id: p, + ctx: adctx, cancel: cancel, - waitch: make(chan struct{}), + reqch: make(chan dialRequest), ds: ds, } - ds.dials[p] = actd - go actd.start(adctx) + err := ds.dialWorker(adctx, p, actd.reqch) + if err != nil { + cancel() + return nil, err + } + + ds.dials[p] = actd } // increase ref count before dropping dialsLk - actd.incref() + actd.refCnt++ - return actd + return actd, nil } // DialLock initiates a dial to the given peer if there are none in progress // then waits for the dial to that peer to complete. func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { - return ds.getActiveDial(ctx, p).wait(ctx) -} - -// CancelDial cancels all in-progress dials to the given peer. -func (ds *DialSync) CancelDial(p peer.ID) { - ds.dialsLk.Lock() - defer ds.dialsLk.Unlock() - if ad, ok := ds.dials[p]; ok { - ad.cancel() + ad, err := ds.getActiveDial(p) + if err != nil { + return nil, err } + + defer ad.decref() + return ad.dial(ctx, p) } diff --git a/p2p/net/swarm/dial_sync_test.go b/p2p/net/swarm/dial_sync_test.go index 485d1a3171..59ace9ae67 100644 --- a/p2p/net/swarm/dial_sync_test.go +++ b/p2p/net/swarm/dial_sync_test.go @@ -1,4 +1,4 @@ -package swarm_test +package swarm import ( "context" @@ -7,24 +7,37 @@ import ( "testing" "time" - . 
"github.com/libp2p/go-libp2p-swarm" - "github.com/libp2p/go-libp2p-core/peer" ) -func getMockDialFunc() (DialFunc, func(), context.Context, <-chan struct{}) { +func getMockDialFunc() (dialWorkerFunc, func(), context.Context, <-chan struct{}) { dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care dialctx, cancel := context.WithCancel(context.Background()) ch := make(chan struct{}) - f := func(ctx context.Context, p peer.ID) (*Conn, error) { + f := func(ctx context.Context, p peer.ID, reqch <-chan dialRequest) error { dfcalls <- struct{}{} - defer cancel() - select { - case <-ch: - return new(Conn), nil - case <-ctx.Done(): - return nil, ctx.Err() - } + go func() { + defer cancel() + for { + select { + case req, ok := <-reqch: + if !ok { + return + } + + select { + case <-ch: + req.resch <- dialResponse{conn: new(Conn)} + case <-ctx.Done(): + req.resch <- dialResponse{err: ctx.Err()} + return + } + case <-ctx.Done(): + return + } + } + }() + return nil } o := new(sync.Once) @@ -35,7 +48,7 @@ func getMockDialFunc() (DialFunc, func(), context.Context, <-chan struct{}) { func TestBasicDialSync(t *testing.T) { df, done, _, callsch := getMockDialFunc() - dsync := NewDialSync(df) + dsync := newDialSync(df) p := peer.ID("testpeer") @@ -73,7 +86,7 @@ func TestBasicDialSync(t *testing.T) { func TestDialSyncCancel(t *testing.T) { df, done, _, dcall := getMockDialFunc() - dsync := NewDialSync(df) + dsync := newDialSync(df) p := peer.ID("testpeer") @@ -124,7 +137,7 @@ func TestDialSyncCancel(t *testing.T) { func TestDialSyncAllCancel(t *testing.T) { df, done, dctx, _ := getMockDialFunc() - dsync := NewDialSync(df) + dsync := newDialSync(df) p := peer.ID("testpeer") @@ -174,15 +187,31 @@ func TestDialSyncAllCancel(t *testing.T) { func TestFailFirst(t *testing.T) { var count int - f := func(ctx context.Context, p peer.ID) (*Conn, error) { - if count > 0 { - return new(Conn), nil - } - count++ - return nil, fmt.Errorf("gophers ate the modem") + f := func(ctx context.Context, p peer.ID, reqch <-chan dialRequest) error { + go func() { + for { + select { + case req, ok := <-reqch: + if !ok { + return + } + + if count > 0 { + req.resch <- dialResponse{conn: new(Conn)} + } else { + req.resch <- dialResponse{err: fmt.Errorf("gophers ate the modem")} + } + count++ + + case <-ctx.Done(): + return + } + } + }() + return nil } - ds := NewDialSync(f) + ds := newDialSync(f) p := peer.ID("testing") @@ -205,8 +234,22 @@ func TestFailFirst(t *testing.T) { } func TestStressActiveDial(t *testing.T) { - ds := NewDialSync(func(ctx context.Context, p peer.ID) (*Conn, error) { - return nil, nil + ds := newDialSync(func(ctx context.Context, p peer.ID, reqch <-chan dialRequest) error { + go func() { + for { + select { + case req, ok := <-reqch: + if !ok { + return + } + + req.resch <- dialResponse{} + case <-ctx.Done(): + return + } + } + }() + return nil }) wg := sync.WaitGroup{} @@ -227,3 +270,24 @@ func TestStressActiveDial(t *testing.T) { wg.Wait() } + +func TestDialSelf(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + self := peer.ID("ABC") + s := NewSwarm(ctx, self, nil, nil) + defer s.Close() + + // this should fail + _, err := s.dsync.DialLock(ctx, self) + if err != ErrDialToSelf { + t.Fatal("expected error from self dial") + } + + // do it twice to make sure we get a new active dial object that fails again + _, err = s.dsync.DialLock(ctx, self) + if err != ErrDialToSelf { + t.Fatal("expected error from self dial") + } +} diff 
--git a/p2p/net/swarm/dial_test.go b/p2p/net/swarm/dial_test.go index 9fc5df4189..2a966a4662 100644 --- a/p2p/net/swarm/dial_test.go +++ b/p2p/net/swarm/dial_test.go @@ -9,6 +9,7 @@ import ( addrutil "github.com/libp2p/go-addr-util" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" "github.com/libp2p/go-libp2p-core/transport" @@ -524,3 +525,143 @@ func TestDialPeerFailed(t *testing.T) { t.Errorf("expected %d errors, got %d", expectedErrorsCount, len(dialErr.DialErrors)) } } + +func TestDialExistingConnection(t *testing.T) { + ctx := context.Background() + + swarms := makeSwarms(ctx, t, 2) + defer closeSwarms(swarms) + s1 := swarms[0] + s2 := swarms[1] + + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL) + + c1, err := s1.DialPeer(ctx, s2.LocalPeer()) + if err != nil { + t.Fatal(err) + } + + c2, err := s1.DialPeer(ctx, s2.LocalPeer()) + if err != nil { + t.Fatal(err) + } + + if c1 != c2 { + t.Fatal("expecting the same connection from both dials") + } +} + +func newSilentListener(t *testing.T) ([]ma.Multiaddr, net.Listener) { + lst, err := net.Listen("tcp4", "localhost:0") + if err != nil { + t.Fatal(err) + } + addr, err := manet.FromNetAddr(lst.Addr()) + if err != nil { + t.Fatal(err) + } + addrs := []ma.Multiaddr{addr} + addrs, err = addrutil.ResolveUnspecifiedAddresses(addrs, nil) + if err != nil { + t.Fatal(err) + } + return addrs, lst + +} + +func TestDialSimultaneousJoin(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + swarms := makeSwarms(ctx, t, 2) + s1 := swarms[0] + s2 := swarms[1] + defer s1.Close() + defer s2.Close() + + s2silentAddrs, s2silentListener := newSilentListener(t) + go acceptAndHang(s2silentListener) + + connch := make(chan network.Conn, 512) + + // start a dial to s2 through the silent addr + go func() { + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2silentAddrs, peerstore.PermanentAddrTTL) + + c, err := s1.DialPeer(ctx, s2.LocalPeer()) + if err != nil { + t.Fatal(err) + } + + t.Logf("first dial succeeded; conn: %+v", c) + + connch <- c + }() + + // wait a bit for the dial to take hold + time.Sleep(100 * time.Millisecond) + + // start a second dial to s2 that uses the real s2 addrs + go func() { + s2addrs, err := s2.InterfaceListenAddresses() + if err != nil { + t.Fatal(err) + } + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2addrs[:1], peerstore.PermanentAddrTTL) + + c, err := s1.DialPeer(ctx, s2.LocalPeer()) + if err != nil { + t.Fatal(err) + } + + t.Logf("second dial succeeded; conn: %+v", c) + + connch <- c + }() + + // wait for the second dial to finish + c2 := <-connch + + // start a third dial to s2, this should get the existing connection from the successful dial + go func() { + c, err := s1.DialPeer(ctx, s2.LocalPeer()) + if err != nil { + t.Fatal(err) + } + + t.Logf("third dial succeeded; conn: %+v", c) + + connch <- c + }() + + c3 := <-connch + + if c2 != c3 { + t.Fatal("expected c2 and c3 to be the same") + } + + // next, the first dial to s2, using the silent addr should time out; at this point the dial + // will error but the last chance check will see the existing connection and return it + select { + case c1 := <-connch: + if c1 != c2 { + t.Fatal("expected c1 and c2 to be the same") + } + case <-time.After(2 * transport.DialTimeout): + t.Fatal("no connection from first dial") + } +} + +func TestDialSelf2(t *testing.T) { + ctx, cancel := 
context.WithCancel(context.Background()) + defer cancel() + + swarms := makeSwarms(ctx, t, 2) + s1 := swarms[0] + defer s1.Close() + + _, err := s1.DialPeer(ctx, s1.LocalPeer()) + if err != ErrDialToSelf { + t.Fatal("expected error from self dial") + } +} diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 8dde559e7e..d42020a89d 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -122,8 +122,8 @@ func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc } } - s.dsync = NewDialSync(s.doDial) - s.limiter = newDialLimiter(s.dialAddr, s.IsFdConsumingAddr) + s.dsync = newDialSync(s.startDialWorker) + s.limiter = newDialLimiter(s.dialAddr, isFdConsumingAddr) s.proc = goprocessctx.WithContext(ctx) s.ctx = goprocessctx.OnClosingContext(s.proc) s.backf.init(s.ctx) @@ -281,12 +281,6 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, c.notifyLk.Lock() s.conns.Unlock() - // We have a connection now. Cancel all other in-progress dials. - // This should be fast, no reason to wait till later. - if dir == network.DirOutbound { - s.dsync.CancelDial(p) - } - s.notifyAll(func(f network.Notifiee) { f.Connected(s, c) }) @@ -407,6 +401,38 @@ func (s *Swarm) ConnsToPeer(p peer.ID) []network.Conn { return output } +func isBetterConn(a, b *Conn) bool { + // If one is transient and not the other, prefer the non-transient connection. + aTransient := a.Stat().Transient + bTransient := b.Stat().Transient + if aTransient != bTransient { + return !aTransient + } + + // If one is direct and not the other, prefer the direct connection. + aDirect := isDirectConn(a) + bDirect := isDirectConn(b) + if aDirect != bDirect { + return aDirect + } + + // Otherwise, prefer the connection with more open streams. + a.streams.Lock() + aLen := len(a.streams.m) + a.streams.Unlock() + + b.streams.Lock() + bLen := len(b.streams.m) + b.streams.Unlock() + + if aLen != bLen { + return aLen > bLen + } + + // finally, pick the last connection. + return true +} + // bestConnToPeer returns the best connection to peer. func (s *Swarm) bestConnToPeer(p peer.ID) *Conn { @@ -417,31 +443,29 @@ func (s *Swarm) bestConnToPeer(p peer.ID) *Conn { defer s.conns.RUnlock() var best *Conn - bestLen := 0 for _, c := range s.conns.m[p] { if c.conn.IsClosed() { // We *will* garbage collect this soon anyways. continue } - c.streams.Lock() - cLen := len(c.streams.m) - c.streams.Unlock() - - // We will never prefer a Relayed connection over a direct connection. - if isDirectConn(best) && !isDirectConn(c) { - continue - } - - // 1. Always prefer a direct connection over a relayed connection. - // 2. If both conns are direct or relayed, pick the one with as many or more streams. 
- if (!isDirectConn(best) && isDirectConn(c)) || (cLen >= bestLen) { + if best == nil || isBetterConn(c, best) { best = c - bestLen = cLen } } return best } +func (s *Swarm) bestAcceptableConnToPeer(ctx context.Context, p peer.ID) *Conn { + conn := s.bestConnToPeer(p) + if conn != nil { + forceDirect, _ := network.GetForceDirectDial(ctx) + if !forceDirect || isDirectConn(conn) { + return conn + } + } + return nil +} + func isDirectConn(c *Conn) bool { return c != nil && !c.conn.Transport().Proxy() } diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 052da67fe5..14129257be 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -14,7 +14,6 @@ import ( addrutil "github.com/libp2p/go-addr-util" lgbl "github.com/libp2p/go-libp2p-loggables" - logging "github.com/ipfs/go-log" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" ) @@ -251,14 +250,9 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) { defer log.EventBegin(ctx, "swarmDialAttemptSync", p).Done() - conn := s.bestConnToPeer(p) - forceDirect, _ := network.GetForceDirectDial(ctx) - if forceDirect { - if isDirectConn(conn) { - return conn, nil - } - } else if conn != nil { - // check if we already have an open connection first + // check if we already have an open (usable) connection first + conn := s.bestAcceptableConnToPeer(ctx, p) + if conn != nil { return conn, nil } @@ -286,172 +280,375 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) { return nil, err } -// doDial is an ugly shim method to retain all the logging and backoff logic -// of the old dialsync code -func (s *Swarm) doDial(ctx context.Context, p peer.ID) (*Conn, error) { - // Short circuit. - // By the time we take the dial lock, we may already *have* a connection - // to the peer. - forceDirect, _ := network.GetForceDirectDial(ctx) - c := s.bestConnToPeer(p) - if forceDirect { - if isDirectConn(c) { - return c, nil - } - } else if c != nil { - return c, nil +/////////////////////////////////////////////////////////////////////////////////// +// lo and behold, The Dialer +// TODO explain how all this works +////////////////////////////////////////////////////////////////////////////////// + +type dialRequest struct { + ctx context.Context + resch chan dialResponse +} + +type dialResponse struct { + conn *Conn + err error +} + +// startDialWorker starts an active dial goroutine that synchronizes and executes concurrent dials +func (s *Swarm) startDialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequest) error { + if p == s.local { + return ErrDialToSelf } - logdial := lgbl.Dial("swarm", s.LocalPeer(), p, nil, nil) + go s.dialWorkerLoop(ctx, p, reqch) + return nil +} - // ok, we have been charged to dial! let's do it. - // if it succeeds, dial will add the conn to the swarm itself. 
- defer log.EventBegin(ctx, "swarmDialAttemptStart", logdial).Done() +func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan dialRequest) { + defer s.limiter.clearAllPeerDials(p) - conn, err := s.dial(ctx, p) - if err != nil { - conn = s.bestConnToPeer(p) - if forceDirect { - if isDirectConn(conn) { - log.Debugf("ignoring dial error because we already have a direct connection: %s", err) - return conn, nil + type pendRequest struct { + req dialRequest // the original request + err *DialError // dial error accumulator + addrs map[ma.Multiaddr]struct{} // pending addr dials + } + + type addrDial struct { + addr ma.Multiaddr + ctx context.Context + conn *Conn + err error + requests []int + dialed bool + } + + reqno := 0 + requests := make(map[int]*pendRequest) + pending := make(map[ma.Multiaddr]*addrDial) + + dispatchError := func(ad *addrDial, err error) { + ad.err = err + for _, reqno := range ad.requests { + pr, ok := requests[reqno] + if !ok { + // has already been dispatched + continue + } + + // accumulate the error + pr.err.recordErr(ad.addr, err) + + delete(pr.addrs, ad.addr) + if len(pr.addrs) == 0 { + // all addrs have erred, dispatch dial error + // but first do one last check in case an acceptable connection has landed from + // a simultaneous dial that started later and added new acceptable addrs + c := s.bestAcceptableConnToPeer(pr.req.ctx, p) + if c != nil { + pr.req.resch <- dialResponse{conn: c} + } else { + pr.req.resch <- dialResponse{err: pr.err} + } + delete(requests, reqno) } - } else if conn != nil { - // Hm? What error? - // Could have canceled the dial because we received a - // connection or some other random reason. - // Just ignore the error and return the connection. - log.Debugf("ignoring dial error because we already have a connection: %s", err) - return conn, nil } - // ok, we failed. - return nil, err + ad.requests = nil + + // if it was a backoff, clear the address dial so that it doesn't inhibit new dial requests. + // this is necessary to support active listen scenarios, where a new dial comes in while + // another dial is in progress, and needs to do a direct connection without inhibitions from + // dial backoff. + // it is also necessary to preserve consistent behaviour with the old dialer -- TestDialBackoff + regresses without this.
+ if err == ErrDialBackoff { + delete(pending, ad.addr) + } } - return conn, nil -} -func (s *Swarm) canDial(addr ma.Multiaddr) bool { - t := s.TransportForDialing(addr) - return t != nil && t.CanDial(addr) -} + var triggerDial <-chan struct{} + triggerNow := make(chan struct{}) + close(triggerNow) -func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool { - t := s.TransportForDialing(addr) - return !t.Proxy() -} + var nextDial []ma.Multiaddr + active := 0 + done := false // true when the request channel has been closed + connected := false // true when a connection has been successfully established -// ranks addresses in descending order of preference for dialing -// Private UDP > Public UDP > Private TCP > Public TCP > UDP Relay server > TCP Relay server -func (s *Swarm) rankAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { - var localUdpAddrs []ma.Multiaddr // private udp - var relayUdpAddrs []ma.Multiaddr // relay udp - var othersUdp []ma.Multiaddr // public udp + resch := make(chan dialResult) - var localFdAddrs []ma.Multiaddr // private fd consuming - var relayFdAddrs []ma.Multiaddr // relay fd consuming - var othersFd []ma.Multiaddr // public fd consuming +loop: + for { + select { + case req, ok := <-reqch: + if !ok { + // request channel has been closed, wait for pending dials to complete + if active > 0 { + done = true + reqch = nil + triggerDial = nil + continue loop + } - for _, a := range addrs { - if _, err := a.ValueForProtocol(ma.P_CIRCUIT); err == nil { - if s.IsFdConsumingAddr(a) { - relayFdAddrs = append(relayFdAddrs, a) - continue + // no active dials, we are done + return } - relayUdpAddrs = append(relayUdpAddrs, a) - } else if manet.IsPrivateAddr(a) { - if s.IsFdConsumingAddr(a) { - localFdAddrs = append(localFdAddrs, a) - continue + + c := s.bestAcceptableConnToPeer(req.ctx, p) + if c != nil { + req.resch <- dialResponse{conn: c} + continue loop } - localUdpAddrs = append(localUdpAddrs, a) - } else { - if s.IsFdConsumingAddr(a) { - othersFd = append(othersFd, a) - continue + + addrs, err := s.addrsForDial(req.ctx, p) + if err != nil { + req.resch <- dialResponse{err: err} + continue loop } - othersUdp = append(othersUdp, a) - } - } - relays := append(relayUdpAddrs, relayFdAddrs...) - fds := append(localFdAddrs, othersFd...) + // at this point, len(addrs) > 0 or else addrsForDial would have returned an error + // rank them to process in order + addrs = s.rankAddrs(addrs) - return append(append(append(localUdpAddrs, othersUdp...), fds...), relays...) -} + // create the pending request object + pr := &pendRequest{ + req: req, + err: &DialError{Peer: p}, + addrs: make(map[ma.Multiaddr]struct{}), + } + for _, a := range addrs { + pr.addrs[a] = struct{}{} + } -// dial is the actual swarm's dial logic, gated by Dial. -func (s *Swarm) dial(ctx context.Context, p peer.ID) (*Conn, error) { - forceDirect, _ := network.GetForceDirectDial(ctx) - var logdial = lgbl.Dial("swarm", s.LocalPeer(), p, nil, nil) - if p == s.local { - log.Event(ctx, "swarmDialDoDialSelf", logdial) - return nil, ErrDialToSelf - } - defer log.EventBegin(ctx, "swarmDialDo", logdial).Done() - logdial["dial"] = "failure" // start off with failure. set to "success" at the end.
+ // check if any of the addrs has been successfully dialed and accumulate + // errors from complete dials while collecting new addrs to dial/join + var todial []ma.Multiaddr + var tojoin []*addrDial + + for _, a := range addrs { + ad, ok := pending[a] + if !ok { + todial = append(todial, a) + continue + } + + if ad.conn != nil { + // dial to this addr was successful, complete the request + req.resch <- dialResponse{conn: ad.conn} + continue loop + } + + if ad.err != nil { + // dial to this addr errored, accumulate the error + pr.err.recordErr(a, ad.err) + delete(pr.addrs, a) + continue + } + + // dial is still pending, add to the join list + tojoin = append(tojoin, ad) + } + + if len(todial) == 0 && len(tojoin) == 0 { + // all addrs applicable to this request have been dialed, so we must have errored + req.resch <- dialResponse{err: pr.err} + continue loop + } - sk := s.peers.PrivKey(s.local) - logdial["encrypted"] = sk != nil // log whether this will be an encrypted dial or not. - if sk == nil { - // fine for sk to be nil, just log. - log.Debug("Dial not given PrivateKey, so WILL NOT SECURE conn.") + // the request has some pending or new dials, track it and schedule new dials + reqno++ + requests[reqno] = pr + + for _, ad := range tojoin { + if !ad.dialed { + ad.ctx = s.mergeDialContexts(ad.ctx, req.ctx) + } + ad.requests = append(ad.requests, reqno) + } + + if len(todial) > 0 { + for _, a := range todial { + pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{reqno}} + } + + nextDial = append(nextDial, todial...) + nextDial = s.rankAddrs(nextDial) + + // trigger a new dial now to account for the new addrs we added + triggerDial = triggerNow + } + + case <-triggerDial: + for _, addr := range nextDial { + // spawn the dial + ad := pending[addr] + err := s.dialNextAddr(ad.ctx, p, addr, resch) + if err != nil { + dispatchError(ad, err) + } + } + + nextDial = nil + triggerDial = nil + + case res := <-resch: + active-- + + if res.Conn != nil { + connected = true + } + + if done && active == 0 { + if res.Conn != nil { + // we got an actual connection, but the dial has been canceled + // Should we close it? I think not, we should just add it to the swarm + _, err := s.addConn(res.Conn, network.DirOutbound) + if err != nil { + // well duh, now we have to close it + res.Conn.Close() + } + } + return + } + + ad := pending[res.Addr] + + if res.Conn != nil { + // we got a connection, add it to the swarm + conn, err := s.addConn(res.Conn, network.DirOutbound) + if err != nil { + // oops no, we failed to add it to the swarm + res.Conn.Close() + dispatchError(ad, err) + if active == 0 && len(nextDial) > 0 { + triggerDial = triggerNow + } + continue loop + } + + // dispatch to still pending requests + for _, reqno := range ad.requests { + pr, ok := requests[reqno] + if !ok { + // it has already dispatched a connection + continue + } + + pr.req.resch <- dialResponse{conn: conn} + delete(requests, reqno) + } + + ad.conn = conn + ad.requests = nil + + continue loop + } + + // it must be an error -- add backoff if applicable and dispatch + if res.Err != context.Canceled && !connected { + // we only add backoff if there has not been a successful connection + // for consistency with the old dialer behavior.
+ s.backf.AddBackoff(p, res.Addr) + } + + dispatchError(ad, res.Err) + if active == 0 && len(nextDial) > 0 { + triggerDial = triggerNow + } + } } +} - ////// +func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, error) { peerAddrs := s.peers.Addrs(p) if len(peerAddrs) == 0 { - return nil, &DialError{Peer: p, Cause: ErrNoAddresses} + return nil, ErrNoAddresses } + goodAddrs := s.filterKnownUndialables(p, peerAddrs) - if forceDirect { + if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect { goodAddrs = addrutil.FilterAddrs(goodAddrs, s.nonProxyAddr) } + if len(goodAddrs) == 0 { - return nil, &DialError{Peer: p, Cause: ErrNoGoodAddresses} + return nil, ErrNoGoodAddresses } - if !forceDirect { - /////// Check backoff andnRank addresses - var nonBackoff bool - for _, a := range goodAddrs { - // skip addresses in back-off - if !s.backf.Backoff(p, a) { - nonBackoff = true - } + return goodAddrs, nil +} + +func (s *Swarm) mergeDialContexts(a, b context.Context) context.Context { + dialCtx := a + + if simConnect, reason := network.GetSimultaneousConnect(b); simConnect { + if simConnect, _ := network.GetSimultaneousConnect(a); !simConnect { + dialCtx = network.WithSimultaneousConnect(dialCtx, reason) } - if !nonBackoff { - return nil, ErrDialBackoff + } + + return dialCtx +} + +func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, resch chan dialResult) error { + // check the dial backoff + if forceDirect, _ := network.GetForceDirectDial(ctx); !forceDirect { + if s.backf.Backoff(p, addr) { + return ErrDialBackoff } } - connC, dialErr := s.dialAddrs(ctx, p, s.rankAddrs(goodAddrs)) - if dialErr != nil { - logdial["error"] = dialErr.Cause.Error() - switch dialErr.Cause { - case context.Canceled, context.DeadlineExceeded: - // Always prefer the context errors as we rely on being - // able to check them. - // - // Removing this will BREAK backoff (causing us to - // backoff when canceling dials). - return nil, dialErr.Cause + // start the dial + s.limitedDial(ctx, p, addr, resch) + + return nil +} + +func (s *Swarm) canDial(addr ma.Multiaddr) bool { + t := s.TransportForDialing(addr) + return t != nil && t.CanDial(addr) +} + +func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool { + t := s.TransportForDialing(addr) + return !t.Proxy() +} + +// ranks addresses in descending order of preference for dialing, with the following rules: +// NonRelay > Relay +// NonWS > WS +// Private > Public +// UDP > TCP +func (s *Swarm) rankAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { + addrTier := func(a ma.Multiaddr) (tier int) { + if isRelayAddr(a) { + tier |= 0b1000 + } + if isExpensiveAddr(a) { + tier |= 0b0100 } - return nil, dialErr + if !manet.IsPrivateAddr(a) { + tier |= 0b0010 + } + if isFdConsumingAddr(a) { + tier |= 0b0001 + } + + return tier } - logdial["conn"] = logging.Metadata{ - "localAddr": connC.LocalMultiaddr(), - "remoteAddr": connC.RemoteMultiaddr(), + + tiers := make([][]ma.Multiaddr, 16) + for _, a := range addrs { + tier := addrTier(a) + tiers[tier] = append(tiers[tier], a) } - swarmC, err := s.addConn(connC, network.DirOutbound) - if err != nil { - logdial["error"] = err.Error() - connC.Close() // close the connection. didn't work out :( - return nil, &DialError{Peer: p, Cause: err} + + result := make([]ma.Multiaddr, 0, len(addrs)) + for _, tier := range tiers { + result = append(result, tier...) 
} - logdial["dial"] = "success" - return swarmC, nil + return result } // filterKnownUndialables takes a list of multiaddrs, and removes those @@ -481,98 +678,6 @@ func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Mul ) } -func (s *Swarm) dialAddrs(ctx context.Context, p peer.ID, remoteAddrs []ma.Multiaddr) (transport.CapableConn, *DialError) { - /* - This slice-to-chan code is temporary, the peerstore can currently provide - a channel as an interface for receiving addresses, but more thought - needs to be put into the execution. For now, this allows us to use - the improved rate limiter, while maintaining the outward behaviour - that we previously had (halting a dial when we run out of addrs) - */ - var remoteAddrChan chan ma.Multiaddr - if len(remoteAddrs) > 0 { - remoteAddrChan = make(chan ma.Multiaddr, len(remoteAddrs)) - for i := range remoteAddrs { - remoteAddrChan <- remoteAddrs[i] - } - close(remoteAddrChan) - } - - log.Debugf("%s swarm dialing %s", s.local, p) - - ctx, cancel := context.WithCancel(ctx) - defer cancel() // cancel work when we exit func - - // use a single response type instead of errs and conns, reduces complexity *a ton* - respch := make(chan dialResult) - err := &DialError{Peer: p} - - defer s.limiter.clearAllPeerDials(p) - - var active int -dialLoop: - for remoteAddrChan != nil || active > 0 { - // Check for context cancellations and/or responses first. - select { - case <-ctx.Done(): - break dialLoop - case resp := <-respch: - active-- - if resp.Err != nil { - // Errors are normal, lots of dials will fail - if resp.Err != context.Canceled { - s.backf.AddBackoff(p, resp.Addr) - } - - log.Infof("got error on dial: %s", resp.Err) - err.recordErr(resp.Addr, resp.Err) - } else if resp.Conn != nil { - return resp.Conn, nil - } - - // We got a result, try again from the top. - continue - default: - } - - // Now, attempt to dial. - select { - case addr, ok := <-remoteAddrChan: - if !ok { - remoteAddrChan = nil - continue - } - - s.limitedDial(ctx, p, addr, respch) - active++ - case <-ctx.Done(): - break dialLoop - case resp := <-respch: - active-- - if resp.Err != nil { - // Errors are normal, lots of dials will fail - if resp.Err != context.Canceled { - s.backf.AddBackoff(p, resp.Addr) - } - - log.Infof("got error on dial: %s", resp.Err) - err.recordErr(resp.Addr, resp.Err) - } else if resp.Conn != nil { - return resp.Conn, nil - } - } - } - - if ctxErr := ctx.Err(); ctxErr != nil { - err.Cause = ctxErr - } else if len(err.DialErrors) == 0 { - err.Cause = network.ErrNoRemoteAddrs - } else { - err.Cause = ErrAllDialsFailed - } - return nil, err -} - // limitedDial will start a dial to the given peer when // it is able, respecting the various different types of rate // limiting that occur without using extra goroutines per addr @@ -585,6 +690,7 @@ func (s *Swarm) limitedDial(ctx context.Context, p peer.ID, a ma.Multiaddr, resp }) } +// dialAddr is the actual dial for an addr, indirectly invoked through the limiter func (s *Swarm) dialAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr) (transport.CapableConn, error) { // Just to double check. Costs nothing. if s.local == p { @@ -620,7 +726,7 @@ func (s *Swarm) dialAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr) (tra // A Non-circuit address which has the TCP/UNIX protocol is deemed FD consuming. // For a circuit-relay address, we look at the address of the relay server/proxy // and use the same logic as above to decide. 
-func (s *Swarm) IsFdConsumingAddr(addr ma.Multiaddr) bool { +func isFdConsumingAddr(addr ma.Multiaddr) bool { first, _ := ma.SplitFunc(addr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CIRCUIT }) @@ -634,3 +740,14 @@ func (s *Swarm) IsFdConsumingAddr(addr ma.Multiaddr) bool { _, err2 := first.ValueForProtocol(ma.P_UNIX) return err1 == nil || err2 == nil } + +func isExpensiveAddr(addr ma.Multiaddr) bool { + _, err1 := addr.ValueForProtocol(ma.P_WS) + _, err2 := addr.ValueForProtocol(ma.P_WSS) + return err1 == nil || err2 == nil +} + +func isRelayAddr(addr ma.Multiaddr) bool { + _, err := addr.ValueForProtocol(ma.P_CIRCUIT) + return err == nil +} diff --git a/p2p/net/swarm/swarm_net_test.go b/p2p/net/swarm/swarm_net_test.go index 2ba64edb96..64121bb1b4 100644 --- a/p2p/net/swarm/swarm_net_test.go +++ b/p2p/net/swarm/swarm_net_test.go @@ -57,8 +57,8 @@ func TestConnectednessCorrect(t *testing.T) { t.Fatal("expected net 0 to have two peers") } - if len(nets[2].Conns()) != 2 { - t.Fatal("expected net 2 to have two conns") + if len(nets[2].Peers()) != 2 { + t.Fatal("expected net 2 to have two peers") } if len(nets[1].ConnsToPeer(nets[3].LocalPeer())) != 0 { diff --git a/p2p/net/swarm/swarm_test.go b/p2p/net/swarm/swarm_test.go index 9b1e9c42d7..4e9801ad08 100644 --- a/p2p/net/swarm/swarm_test.go +++ b/p2p/net/swarm/swarm_test.go @@ -15,7 +15,6 @@ import ( "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" - "github.com/libp2p/go-libp2p-core/test" . "github.com/libp2p/go-libp2p-swarm" . "github.com/libp2p/go-libp2p-swarm/testing" @@ -387,53 +386,6 @@ func TestConnectionGating(t *testing.T) { } } -func TestIsFdConsuming(t *testing.T) { - tcs := map[string]struct { - addr string - isFdConsuming bool - }{ - "tcp": { - addr: "/ip4/127.0.0.1/tcp/20", - isFdConsuming: true, - }, - "quic": { - addr: "/ip4/127.0.0.1/udp/0/quic", - isFdConsuming: false, - }, - "addr-without-registered-transport": { - addr: "/ip4/127.0.0.1/tcp/20/ws", - isFdConsuming: true, - }, - "relay-tcp": { - addr: fmt.Sprintf("/ip4/127.0.0.1/tcp/20/p2p-circuit/p2p/%s", test.RandPeerIDFatal(t)), - isFdConsuming: true, - }, - "relay-quic": { - addr: fmt.Sprintf("/ip4/127.0.0.1/udp/20/quic/p2p-circuit/p2p/%s", test.RandPeerIDFatal(t)), - isFdConsuming: false, - }, - "relay-without-serveraddr": { - addr: fmt.Sprintf("/p2p-circuit/p2p/%s", test.RandPeerIDFatal(t)), - isFdConsuming: true, - }, - "relay-without-registered-transport-server": { - addr: fmt.Sprintf("/ip4/127.0.0.1/tcp/20/ws/p2p-circuit/p2p/%s", test.RandPeerIDFatal(t)), - isFdConsuming: true, - }, - } - - ctx := context.Background() - sw := GenSwarm(t, ctx) - sk := sw.Peerstore().PrivKey(sw.LocalPeer()) - require.NotNil(t, sk) - - for name := range tcs { - maddr, err := ma.NewMultiaddr(tcs[name].addr) - require.NoError(t, err, name) - require.Equal(t, tcs[name].isFdConsuming, sw.IsFdConsumingAddr(maddr), name) - } -} - func TestNoDial(t *testing.T) { ctx := context.Background() swarms := makeSwarms(ctx, t, 2) diff --git a/p2p/net/swarm/util_test.go b/p2p/net/swarm/util_test.go new file mode 100644 index 0000000000..11124adb27 --- /dev/null +++ b/p2p/net/swarm/util_test.go @@ -0,0 +1,53 @@ +package swarm + +import ( + "fmt" + "testing" + + "github.com/libp2p/go-libp2p-core/test" + ma "github.com/multiformats/go-multiaddr" + + "github.com/stretchr/testify/require" +) + +func TestIsFdConsuming(t *testing.T) 
{ + tcs := map[string]struct { + addr string + isFdConsuming bool + }{ + "tcp": { + addr: "/ip4/127.0.0.1/tcp/20", + isFdConsuming: true, + }, + "quic": { + addr: "/ip4/127.0.0.1/udp/0/quic", + isFdConsuming: false, + }, + "addr-without-registered-transport": { + addr: "/ip4/127.0.0.1/tcp/20/ws", + isFdConsuming: true, + }, + "relay-tcp": { + addr: fmt.Sprintf("/ip4/127.0.0.1/tcp/20/p2p-circuit/p2p/%s", test.RandPeerIDFatal(t)), + isFdConsuming: true, + }, + "relay-quic": { + addr: fmt.Sprintf("/ip4/127.0.0.1/udp/20/quic/p2p-circuit/p2p/%s", test.RandPeerIDFatal(t)), + isFdConsuming: false, + }, + "relay-without-serveraddr": { + addr: fmt.Sprintf("/p2p-circuit/p2p/%s", test.RandPeerIDFatal(t)), + isFdConsuming: true, + }, + "relay-without-registered-transport-server": { + addr: fmt.Sprintf("/ip4/127.0.0.1/tcp/20/ws/p2p-circuit/p2p/%s", test.RandPeerIDFatal(t)), + isFdConsuming: true, + }, + } + + for name := range tcs { + maddr, err := ma.NewMultiaddr(tcs[name].addr) + require.NoError(t, err, name) + require.Equal(t, tcs[name].isFdConsuming, isFdConsumingAddr(maddr), name) + } +}
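Reviewer note (not part of the patch): a minimal sketch of how the new pieces compose, assuming the package-internal types introduced in this diff (dialWorkerFunc, dialRequest, dialResponse, newDialSync, DialSync.DialLock). The real worker is Swarm.startDialWorker, which rejects self-dials and spawns dialWorkerLoop; the exampleWorker/exampleUsage names below are hypothetical and mirror the test mocks.

package swarm

import (
	"context"

	"github.com/libp2p/go-libp2p-core/peer"
)

// exampleWorker is a toy dialWorkerFunc: it answers every dialRequest with a
// fresh *Conn until DialSync closes the request channel, which happens when
// the last DialLock caller releases its reference and refCnt drops to zero.
func exampleWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequest) error {
	go func() {
		for {
			select {
			case req, ok := <-reqch:
				if !ok {
					return // reqch closed: no more requesters for this peer
				}
				req.resch <- dialResponse{conn: new(Conn)}
			case <-ctx.Done():
				return
			}
		}
	}()
	return nil
}

func exampleUsage(ctx context.Context) (*Conn, error) {
	ds := newDialSync(exampleWorker)
	// Concurrent DialLock calls for the same peer share a single worker;
	// each caller receives its answer on its own response channel.
	return ds.DialLock(ctx, peer.ID("example-peer"))
}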
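Also for reference: the new rankAddrs replaces the old category-slice ranking with a 4-bit tier per address, emitted in ascending tier order. A sketch of the arithmetic, assuming the bit meanings from this diff; the sample addresses are illustrative only:

// tier bits, as in rankAddrs above:
//   0b1000 relay (p2p-circuit)
//   0b0100 expensive (ws/wss)
//   0b0010 public (not a private addr)
//   0b0001 fd-consuming (tcp/unix)
//
// "/ip4/192.168.1.5/udp/4001/quic"         -> 0b0000 (dialed first)
// "/ip4/203.0.113.7/tcp/4001"              -> 0b0011
// "/dns4/example.com/tcp/443/wss"          -> 0b0111
// "/ip4/203.0.113.7/tcp/4001/p2p-circuit"  -> 0b1011 (dialed last of these)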