From df8e50cd33fa76d441c3ce85031d11637dac4f67 Mon Sep 17 00:00:00 2001 From: Corentin Barreau Date: Thu, 26 Sep 2024 12:11:03 +0200 Subject: [PATCH] fix: don't use the result of the DNS archiving for DialContext --- dialer.go | 29 ++++++++--------------------- dns.go | 23 +++++++++++------------ go.mod | 2 +- go.sum | 6 ------ random_local_ip.go | 19 ++++++------------- 5 files changed, 26 insertions(+), 53 deletions(-) diff --git a/dialer.go b/dialer.go index f321128..86b7775 100644 --- a/dialer.go +++ b/dialer.go @@ -28,12 +28,11 @@ type customDialer struct { DNSConfig *dns.ClientConfig DNSClient *dns.Client DNSRecords *sync.Map - // This defines the TTL for DNS records in the cache - DNSRecordsTTL time.Duration net.Dialer - DNSServer string - disableIPv4 bool - disableIPv6 bool + DNSServer string + DNSRecordsTTL time.Duration + disableIPv4 bool + disableIPv6 bool } func newCustomDialer(httpClient *CustomHTTPClient, proxyURL string, DialTimeout, DNSRecordsTTL, DNSResolutionTimeout time.Duration, DNSServers []string, disableIPv4, disableIPv6 bool) (d *customDialer, err error) { @@ -120,22 +119,16 @@ func (d *customDialer) CustomDial(network, address string) (conn net.Conn, err e return nil, errors.New("no supported network type available") } - IP, port, err := d.resolveDNS(address) + IP, err := d.archiveDNS(address) if err != nil { return nil, err } - if port != "" { - address = net.JoinHostPort(IP.String(), port) - } else { - address = IP.String() + ":80" - } - if d.proxyDialer != nil { conn, err = d.proxyDialer.Dial(network, address) } else { if d.client.randomLocalIP { - localAddr := getLocalAddr(network, address) + localAddr := getLocalAddr(network, IP.String()) if localAddr != nil { if network == "tcp" || network == "tcp4" || network == "tcp6" { d.LocalAddr = localAddr.(*net.TCPAddr) @@ -162,24 +155,18 @@ func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error) return nil, errors.New("no supported network type available") } - IP, port, err := d.resolveDNS(address) + IP, err := d.archiveDNS(address) if err != nil { return nil, err } - if port != "" { - address = net.JoinHostPort(IP.String(), port) - } else { - address = IP.String() + ":443" - } - var plainConn net.Conn if d.proxyDialer != nil { plainConn, err = d.proxyDialer.Dial(network, address) } else { if d.client.randomLocalIP { - localAddr := getLocalAddr(network, address) + localAddr := getLocalAddr(network, IP.String()) if localAddr != nil { if network == "tcp" || network == "tcp4" || network == "tcp6" { d.LocalAddr = localAddr.(*net.TCPAddr) diff --git a/dns.go b/dns.go index b3c141e..6ffd2ea 100644 --- a/dns.go +++ b/dns.go @@ -9,28 +9,28 @@ import ( ) type cachedIP struct { - ip net.IP expiresAt time.Time + ip net.IP } -func (d *customDialer) resolveDNS(address string) (IP net.IP, port string, err error) { +func (d *customDialer) archiveDNS(address string) (resolvedIP net.IP, err error) { // Get the address without the port if there is one - address, port, err = net.SplitHostPort(address) + address, _, err = net.SplitHostPort(address) if err != nil { - return nil, "", err + return resolvedIP, err } // Check if the address is already an IP - IP = net.ParseIP(address) - if IP != nil { - return IP, port, nil + resolvedIP = net.ParseIP(address) + if resolvedIP != nil { + return resolvedIP, nil } // Check cache first if cached, ok := d.DNSRecords.Load(address); ok { cachedEntry := cached.(cachedIP) if time.Now().Before(cachedEntry.expiresAt) { - return cachedEntry.ip, port, nil + return resolvedIP, nil } // Cache entry expired, remove it d.DNSRecords.Delete(address) @@ -41,7 +41,7 @@ func (d *customDialer) resolveDNS(address string) (IP net.IP, port string, err e r, _, err := d.DNSClient.Exchange(m, net.JoinHostPort(d.DNSConfig.Servers[0], d.DNSConfig.Port)) if err != nil { - return nil, port, err + return resolvedIP, err } // Record the DNS response @@ -58,7 +58,6 @@ func (d *customDialer) resolveDNS(address string) (IP net.IP, port string, err e } } - var resolvedIP net.IP // Prioritize IPv6 if both are available and enabled if ipv6 != nil { resolvedIP = ipv6 @@ -72,8 +71,8 @@ func (d *customDialer) resolveDNS(address string) (IP net.IP, port string, err e ip: resolvedIP, expiresAt: time.Now().Add(d.DNSRecordsTTL), }) - return resolvedIP, port, nil + return resolvedIP, nil } - return nil, port, fmt.Errorf("no suitable IP address found for %s", address) + return resolvedIP, fmt.Errorf("no suitable IP address found for %s", address) } diff --git a/go.mod b/go.mod index 7f50698..7257961 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 github.com/google/uuid v1.6.0 github.com/klauspost/compress v1.17.10 + github.com/miekg/dns v1.1.62 github.com/paulbellamy/ratecounter v0.2.0 github.com/refraction-networking/utls v1.6.7 github.com/remeh/sizedwaitgroup v1.0.0 @@ -20,7 +21,6 @@ require ( github.com/andybalholm/brotli v1.1.0 // indirect github.com/cloudflare/circl v1.4.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/miekg/dns v1.1.62 // indirect github.com/spf13/pflag v1.0.5 // indirect golang.org/x/crypto v0.27.0 // indirect golang.org/x/mod v0.21.0 // indirect diff --git a/go.sum b/go.sum index c763134..1dfe119 100644 --- a/go.sum +++ b/go.sum @@ -11,8 +11,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= -github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/compress v1.17.10 h1:oXAz+Vh0PMUvJczoi+flxpnBEPxoER1IaAnU/NMPtT0= github.com/klauspost/compress v1.17.10/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= @@ -38,8 +36,6 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= -golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= -golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= @@ -48,8 +44,6 @@ golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= -golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE= golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/random_local_ip.go b/random_local_ip.go index 0cc7c8a..5aa0318 100644 --- a/random_local_ip.go +++ b/random_local_ip.go @@ -107,30 +107,23 @@ func getNextIP(availableIPs *availableIPs) net.IP { return ip } -func getLocalAddr(network, address string) any { - host, _, err := net.SplitHostPort(address) - if err != nil { - return nil - } - - destAddr := strings.Trim(host, "[]") - - destIP := net.ParseIP(destAddr) +func getLocalAddr(network, IP string) any { + destIP := net.ParseIP(strings.Trim(IP, "[]")) if destIP == nil { return nil } if destIP.To4() != nil { - if network == "tcp" { + if strings.Contains(network, "tcp") { return &net.TCPAddr{IP: getNextIP(IPv4)} - } else if network == "udp" { + } else if strings.Contains(network, "udp") { return &net.UDPAddr{IP: getNextIP(IPv4)} } return nil } else { - if network == "tcp" { + if strings.Contains(network, "tcp") { return &net.TCPAddr{IP: getNextIP(IPv6)} - } else if network == "udp" { + } else if strings.Contains(network, "udp") { return &net.UDPAddr{IP: getNextIP(IPv6)} } return nil