From 035ffcef01a4367067c6509377bbea473df95dfc Mon Sep 17 00:00:00 2001 From: Corentin Barreau Date: Thu, 26 Sep 2024 12:43:18 +0200 Subject: [PATCH] add: AAAA lookup in addition to the A record --- dns.go | 74 ++++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 54 insertions(+), 20 deletions(-) diff --git a/dns.go b/dns.go index 6ffd2ea..fbc300c 100644 --- a/dns.go +++ b/dns.go @@ -3,6 +3,7 @@ package warc import ( "fmt" "net" + "sync" "time" "github.com/miekg/dns" @@ -30,38 +31,38 @@ func (d *customDialer) archiveDNS(address string) (resolvedIP net.IP, err error) if cached, ok := d.DNSRecords.Load(address); ok { cachedEntry := cached.(cachedIP) if time.Now().Before(cachedEntry.expiresAt) { - return resolvedIP, nil + return cachedEntry.ip, nil } // Cache entry expired, remove it d.DNSRecords.Delete(address) } - m := new(dns.Msg) - m.SetQuestion(dns.Fqdn(address), dns.TypeA) + var wg sync.WaitGroup + var ipv4, ipv6 net.IP + var errA, errAAAA error - r, _, err := d.DNSClient.Exchange(m, net.JoinHostPort(d.DNSConfig.Servers[0], d.DNSConfig.Port)) - if err != nil { - return resolvedIP, err - } + wg.Add(2) - // Record the DNS response - d.client.WriteRecord("dns:"+address, "resource", "text/dns", r.String()) + go func() { + defer wg.Done() + ipv4, errA = d.lookupIP(address, dns.TypeA) + }() - var ipv4, ipv6 net.IP + go func() { + defer wg.Done() + ipv6, errAAAA = d.lookupIP(address, dns.TypeAAAA) + }() - for _, answer := range r.Answer { - if a, ok := answer.(*dns.A); ok && !d.disableIPv4 { - ipv4 = a.A - } else if aaaa, ok := answer.(*dns.AAAA); ok && !d.disableIPv6 { - ipv6 = aaaa.AAAA - break // Prioritize IPv6 if available - } + wg.Wait() + + if errA != nil && errAAAA != nil { + return nil, fmt.Errorf("failed to resolve DNS: A error: %v, AAAA error: %v", errA, errAAAA) } // Prioritize IPv6 if both are available and enabled - if ipv6 != nil { + if ipv6 != nil && !d.disableIPv6 { resolvedIP = ipv6 - } else if ipv4 != nil { + } else if ipv4 != nil && !d.disableIPv4 { resolvedIP = ipv4 } @@ -74,5 +75,38 @@ func (d *customDialer) archiveDNS(address string) (resolvedIP net.IP, err error) return resolvedIP, nil } - return resolvedIP, fmt.Errorf("no suitable IP address found for %s", address) + return nil, fmt.Errorf("no suitable IP address found for %s", address) +} + +func (d *customDialer) lookupIP(address string, recordType uint16) (net.IP, error) { + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(address), recordType) + + r, _, err := d.DNSClient.Exchange(m, net.JoinHostPort(d.DNSConfig.Servers[0], d.DNSConfig.Port)) + if err != nil { + return nil, err + } + + // Record the DNS response + recordTypeStr := "A" + if recordType == dns.TypeAAAA { + recordTypeStr = "AAAA" + } + + d.client.WriteRecord(fmt.Sprintf("dns:%s:%s", address, recordTypeStr), "resource", "text/dns", r.String()) + + for _, answer := range r.Answer { + switch recordType { + case dns.TypeA: + if a, ok := answer.(*dns.A); ok { + return a.A, nil + } + case dns.TypeAAAA: + if aaaa, ok := answer.(*dns.AAAA); ok { + return aaaa.AAAA, nil + } + } + } + + return nil, fmt.Errorf("no %s record found", recordTypeStr) }