Skip to content

Commit

Permalink
fix: don't use the result of the DNS archiving for DialContext
Browse files Browse the repository at this point in the history
  • Loading branch information
CorentinB committed Sep 26, 2024
1 parent eb3b2b1 commit df8e50c
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 53 deletions.
29 changes: 8 additions & 21 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@ type customDialer struct {
DNSConfig *dns.ClientConfig
DNSClient *dns.Client
DNSRecords *sync.Map
// This defines the TTL for DNS records in the cache
DNSRecordsTTL time.Duration
net.Dialer
DNSServer string
disableIPv4 bool
disableIPv6 bool
DNSServer string
DNSRecordsTTL time.Duration
disableIPv4 bool
disableIPv6 bool
}

func newCustomDialer(httpClient *CustomHTTPClient, proxyURL string, DialTimeout, DNSRecordsTTL, DNSResolutionTimeout time.Duration, DNSServers []string, disableIPv4, disableIPv6 bool) (d *customDialer, err error) {
Expand Down Expand Up @@ -120,22 +119,16 @@ func (d *customDialer) CustomDial(network, address string) (conn net.Conn, err e
return nil, errors.New("no supported network type available")
}

IP, port, err := d.resolveDNS(address)
IP, err := d.archiveDNS(address)
if err != nil {
return nil, err
}

if port != "" {
address = net.JoinHostPort(IP.String(), port)
} else {
address = IP.String() + ":80"
}

if d.proxyDialer != nil {
conn, err = d.proxyDialer.Dial(network, address)
} else {
if d.client.randomLocalIP {
localAddr := getLocalAddr(network, address)
localAddr := getLocalAddr(network, IP.String())
if localAddr != nil {
if network == "tcp" || network == "tcp4" || network == "tcp6" {
d.LocalAddr = localAddr.(*net.TCPAddr)
Expand All @@ -162,24 +155,18 @@ func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error)
return nil, errors.New("no supported network type available")
}

IP, port, err := d.resolveDNS(address)
IP, err := d.archiveDNS(address)
if err != nil {
return nil, err
}

if port != "" {
address = net.JoinHostPort(IP.String(), port)
} else {
address = IP.String() + ":443"
}

var plainConn net.Conn

if d.proxyDialer != nil {
plainConn, err = d.proxyDialer.Dial(network, address)
} else {
if d.client.randomLocalIP {
localAddr := getLocalAddr(network, address)
localAddr := getLocalAddr(network, IP.String())
if localAddr != nil {
if network == "tcp" || network == "tcp4" || network == "tcp6" {
d.LocalAddr = localAddr.(*net.TCPAddr)
Expand Down
23 changes: 11 additions & 12 deletions dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,28 @@ import (
)

type cachedIP struct {
ip net.IP
expiresAt time.Time
ip net.IP
}

func (d *customDialer) resolveDNS(address string) (IP net.IP, port string, err error) {
func (d *customDialer) archiveDNS(address string) (resolvedIP net.IP, err error) {
// Get the address without the port if there is one
address, port, err = net.SplitHostPort(address)
address, _, err = net.SplitHostPort(address)
if err != nil {
return nil, "", err
return resolvedIP, err
}

// Check if the address is already an IP
IP = net.ParseIP(address)
if IP != nil {
return IP, port, nil
resolvedIP = net.ParseIP(address)
if resolvedIP != nil {
return resolvedIP, nil
}

// Check cache first
if cached, ok := d.DNSRecords.Load(address); ok {
cachedEntry := cached.(cachedIP)
if time.Now().Before(cachedEntry.expiresAt) {
return cachedEntry.ip, port, nil
return resolvedIP, nil
}
// Cache entry expired, remove it
d.DNSRecords.Delete(address)
Expand All @@ -41,7 +41,7 @@ func (d *customDialer) resolveDNS(address string) (IP net.IP, port string, err e

r, _, err := d.DNSClient.Exchange(m, net.JoinHostPort(d.DNSConfig.Servers[0], d.DNSConfig.Port))
if err != nil {
return nil, port, err
return resolvedIP, err
}

// Record the DNS response
Expand All @@ -58,7 +58,6 @@ func (d *customDialer) resolveDNS(address string) (IP net.IP, port string, err e
}
}

var resolvedIP net.IP
// Prioritize IPv6 if both are available and enabled
if ipv6 != nil {
resolvedIP = ipv6
Expand All @@ -72,8 +71,8 @@ func (d *customDialer) resolveDNS(address string) (IP net.IP, port string, err e
ip: resolvedIP,
expiresAt: time.Now().Add(d.DNSRecordsTTL),
})
return resolvedIP, port, nil
return resolvedIP, nil
}

return nil, port, fmt.Errorf("no suitable IP address found for %s", address)
return resolvedIP, fmt.Errorf("no suitable IP address found for %s", address)
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
github.com/google/uuid v1.6.0
github.com/klauspost/compress v1.17.10
github.com/miekg/dns v1.1.62
github.com/paulbellamy/ratecounter v0.2.0
github.com/refraction-networking/utls v1.6.7
github.com/remeh/sizedwaitgroup v1.0.0
Expand All @@ -20,7 +21,6 @@ require (
github.com/andybalholm/brotli v1.1.0 // indirect
github.com/cloudflare/circl v1.4.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/miekg/dns v1.1.62 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/crypto v0.27.0 // indirect
golang.org/x/mod v0.21.0 // indirect
Expand Down
6 changes: 0 additions & 6 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/klauspost/compress v1.17.10 h1:oXAz+Vh0PMUvJczoi+flxpnBEPxoER1IaAnU/NMPtT0=
github.com/klauspost/compress v1.17.10/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
Expand All @@ -38,8 +36,6 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
Expand All @@ -48,8 +44,6 @@ golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE=
golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
19 changes: 6 additions & 13 deletions random_local_ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,30 +107,23 @@ func getNextIP(availableIPs *availableIPs) net.IP {
return ip
}

func getLocalAddr(network, address string) any {
host, _, err := net.SplitHostPort(address)
if err != nil {
return nil
}

destAddr := strings.Trim(host, "[]")

destIP := net.ParseIP(destAddr)
func getLocalAddr(network, IP string) any {
destIP := net.ParseIP(strings.Trim(IP, "[]"))
if destIP == nil {
return nil
}

if destIP.To4() != nil {
if network == "tcp" {
if strings.Contains(network, "tcp") {
return &net.TCPAddr{IP: getNextIP(IPv4)}
} else if network == "udp" {
} else if strings.Contains(network, "udp") {
return &net.UDPAddr{IP: getNextIP(IPv4)}
}
return nil
} else {
if network == "tcp" {
if strings.Contains(network, "tcp") {
return &net.TCPAddr{IP: getNextIP(IPv6)}
} else if network == "udp" {
} else if strings.Contains(network, "udp") {
return &net.UDPAddr{IP: getNextIP(IPv6)}
}
return nil
Expand Down

0 comments on commit df8e50c

Please sign in to comment.