diff --git a/client/client.go b/client/client.go index 8d9c162..e67f79b 100644 --- a/client/client.go +++ b/client/client.go @@ -4,6 +4,7 @@ import ( "errors" "net" "strings" + "sync" "time" "github.com/miekg/dns" @@ -18,7 +19,7 @@ type Client struct { // Response stores a DNS response. type Response struct { - Server string + Server Server Addr string Msg *dns.Msg RTT time.Duration @@ -29,12 +30,16 @@ type Responses []Response // Fastest returns the fastest success response or nil. func (rs Responses) Fastest() *Response { + var fr Response for _, r := range rs { - if r.Err == nil { - return &r + if r.Err != nil { + continue + } + if fr.Msg == nil || ((r.RTT + r.Server.LookupRTT) < (fr.RTT + fr.Server.LookupRTT)) { + fr = r } } - return nil + return &fr } type Tracer struct { @@ -58,14 +63,14 @@ func (c *Client) ParallelQuery(m *dns.Msg, servers []Server) Responses { for _, s := range servers { for _, addr := range s.Addrs { cnt++ - go func(name, addr string) { + go func(s Server, addr string) { r := Response{ - Server: name, + Server: s, Addr: addr, } r.Msg, r.RTT, r.Err = c.Exchange(m, net.JoinHostPort(addr, "53")) rc <- r - }(s.Name, addr) + }(s, addr) } } rs := make([]Response, 0, cnt) @@ -103,7 +108,7 @@ func (c *Client) RecursiveQuery(m *dns.Msg, tracer Tracer) (r *dns.Msg, rtt time } return nil, rtt, errors.New("no response") } - rtt += fr.RTT + rtt += fr.Server.LookupRTT + fr.RTT var done bool var deleg bool @@ -129,8 +134,7 @@ func (c *Client) RecursiveQuery(m *dns.Msg, tracer Tracer) (r *dns.Msg, rtt time } if deleg { - lrttc := make(chan time.Duration) - lc := 0 + wg := &sync.WaitGroup{} for _, ns := range r.Ns { ns, ok := ns.(*dns.NS) if !ok { @@ -155,7 +159,7 @@ func (c *Client) RecursiveQuery(m *dns.Msg, tracer Tracer) (r *dns.Msg, rtt time Addrs: addrs, } if !s.HasGlue { - lc++ + wg.Add(1) go func() { var err error lm := m.Copy() @@ -165,7 +169,7 @@ func (c *Client) RecursiveQuery(m *dns.Msg, tracer Tracer) (r *dns.Msg, rtt time s.LookupErr = err } c.DCache.Add(name, s) - lrttc <- s.LookupRTT + wg.Done() }() continue } @@ -176,14 +180,7 @@ func (c *Client) RecursiveQuery(m *dns.Msg, tracer Tracer) (r *dns.Msg, rtt time break } } - var lrtt time.Duration - for ; lc > 0; lc-- { - d := <-lrttc - if lrtt == 0 || lrtt > d { - lrtt = d - } - } - rtt += lrtt + wg.Wait() } if tracer.GotDelegateResponses != nil { diff --git a/main.go b/main.go index 4a33b1e..bebec8a 100644 --- a/main.go +++ b/main.go @@ -103,7 +103,8 @@ func main() { ln = pr.Msg.Len() } rtt := float64(pr.RTT) / float64(time.Millisecond) - fmt.Printf(col(" - %d bytes in %6.2fms on %s(%s)", cDarkGray), ln, rtt, pr.Addr, pr.Server) + lrtt := float64(pr.Server.LookupRTT) / float64(time.Millisecond) + fmt.Printf(col(" - %3d bytes in %6.2fms + %6.2fms on %s(%s)", cDarkGray), ln, rtt, lrtt, pr.Addr, pr.Server.Name) if pr.Err != nil { err := pr.Err if oerr, ok := err.(*net.OpError); ok { @@ -146,13 +147,13 @@ func main() { }, } r, rtt, err := c.RecursiveQuery(m, t) - if err != nil { fmt.Printf(col("*** error: %v\n", cRed), err) os.Exit(1) } + fmt.Println() - fmt.Printf(col(";; Cold best path RTT: %s\n\n", cGray), rtt) + fmt.Printf(col(";; Cold best path time: %s\n\n", cGray), rtt) for _, rr := range r.Answer { fmt.Println(rr) }