Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DisableIPv4 & DisableIPv6 #47

Merged
merged 3 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type HTTPClientSettings struct {
FullOnDisk bool
VerifyCerts bool
RandomLocalIP bool
DisableIPv4 bool
DisableIPv6 bool
}

type CustomHTTPClient struct {
Expand Down Expand Up @@ -147,7 +149,7 @@ func NewWARCWritingHTTPClient(HTTPClientSettings HTTPClientSettings) (httpClient
httpClient.TLSHandshakeTimeout = HTTPClientSettings.TLSHandshakeTimeout

// Configure custom dialer / transport
customDialer, err := newCustomDialer(httpClient, HTTPClientSettings.Proxy, HTTPClientSettings.DialTimeout)
customDialer, err := newCustomDialer(httpClient, HTTPClientSettings.Proxy, HTTPClientSettings.DialTimeout, HTTPClientSettings.DisableIPv4, HTTPClientSettings.DisableIPv6)
if err != nil {
return nil, err
}
Expand Down
143 changes: 143 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package warc

import (
"context"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -1267,6 +1268,148 @@ func TestHTTPClientWithZStandardDictionary(t *testing.T) {
}
}

func setupIPv4Server(t *testing.T) (string, func()) {
listener, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to set up IPv4 server: %v", err)
}

server := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("IPv4 Server"))
}),
}

go server.Serve(listener)

return "http://" + listener.Addr().String(), func() {
server.Shutdown(context.Background())
}
}

func setupIPv6Server(t *testing.T) (string, func()) {
listener, err := net.Listen("tcp6", "[::1]:0")
if err != nil {
t.Fatalf("Failed to set up IPv6 server: %v", err)
}

server := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("IPv6 Server"))
}),
}

go server.Serve(listener)

return "http://" + listener.Addr().String(), func() {
server.Shutdown(context.Background())
}
}

func TestHTTPClientWithIPv4Disabled(t *testing.T) {
defer goleak.VerifyNone(t)

ipv4URL, closeIPv4 := setupIPv4Server(t)
defer closeIPv4()

ipv6URL, closeIPv6 := setupIPv6Server(t)
defer closeIPv6()

rotatorSettings := NewRotatorSettings()
rotatorSettings.OutputDirectory, _ = os.MkdirTemp("", "warc-tests-")
defer os.RemoveAll(rotatorSettings.OutputDirectory)
rotatorSettings.Prefix = "TESTIPv6Only"

httpClient, err := NewWARCWritingHTTPClient(HTTPClientSettings{
RotatorSettings: rotatorSettings,
DisableIPv4: true,
})
if err != nil {
t.Fatalf("Unable to init WARC writing HTTP client: %s", err)
}

// Try IPv4 - should fail
_, err = httpClient.Get(ipv4URL)
if err == nil {
t.Fatalf("Expected error when connecting to IPv4 server, but got none")
}

// Try IPv6 - should succeed
resp, err := httpClient.Get(ipv6URL)
if err != nil {
t.Fatalf("Failed to connect to IPv6 server: %v", err)
}
defer resp.Body.Close()

body, _ := io.ReadAll(resp.Body)
if string(body) != "IPv6 Server" {
t.Fatalf("Unexpected response from IPv6 server: %s", string(body))
}

httpClient.Close()

files, err := filepath.Glob(rotatorSettings.OutputDirectory + "/*")
if err != nil {
t.Fatal(err)
}

for _, path := range files {
testFileSingleHashCheck(t, path, "sha1:RTK62UJNR5UCIPX2J64LMV7J4JJ6EXCJ", []string{"147"}, 1)
}
}

func TestHTTPClientWithIPv6Disabled(t *testing.T) {
defer goleak.VerifyNone(t)

ipv4URL, closeIPv4 := setupIPv4Server(t)
defer closeIPv4()

ipv6URL, closeIPv6 := setupIPv6Server(t)
defer closeIPv6()

rotatorSettings := NewRotatorSettings()
rotatorSettings.OutputDirectory, _ = os.MkdirTemp("", "warc-tests-")
defer os.RemoveAll(rotatorSettings.OutputDirectory)
rotatorSettings.Prefix = "TESTIPv4Only"

httpClient, err := NewWARCWritingHTTPClient(HTTPClientSettings{
RotatorSettings: rotatorSettings,
DisableIPv6: true,
})
if err != nil {
t.Fatalf("Unable to init WARC writing HTTP client: %s", err)
}

// Try IPv6 - should fail
_, err = httpClient.Get(ipv6URL)
if err == nil {
t.Fatalf("Expected error when connecting to IPv6 server, but got none")
}

// Try IPv4 - should succeed
resp, err := httpClient.Get(ipv4URL)
if err != nil {
t.Fatalf("Failed to connect to IPv4 server: %v", err)
}
defer resp.Body.Close()

body, _ := io.ReadAll(resp.Body)
if string(body) != "IPv4 Server" {
t.Fatalf("Unexpected response from IPv4 server: %s", string(body))
}

httpClient.Close()

files, err := filepath.Glob(rotatorSettings.OutputDirectory + "/*")
if err != nil {
t.Fatal(err)
}

for _, path := range files {
testFileSingleHashCheck(t, path, "sha1:JZIRQ2YRCQ55F6SSNPTXHKMDSKJV6QFM", []string{"147"}, 1)
}
}

func BenchmarkConcurrentUnder2MB(b *testing.B) {
var (
rotatorSettings = NewRotatorSettings()
Expand Down
77 changes: 56 additions & 21 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@ import (
type customDialer struct {
proxyDialer proxy.Dialer
client *CustomHTTPClient
disableIPv4 bool
disableIPv6 bool
net.Dialer
}

func newCustomDialer(httpClient *CustomHTTPClient, proxyURL string, DialTimeout time.Duration) (d *customDialer, err error) {
func newCustomDialer(httpClient *CustomHTTPClient, proxyURL string, DialTimeout time.Duration, disableIPv4, disableIPv6 bool) (d *customDialer, err error) {
d = new(customDialer)

d.Timeout = DialTimeout
d.client = httpClient
d.disableIPv4 = disableIPv4
d.disableIPv6 = disableIPv6

if proxyURL != "" {
u, err := url.Parse(proxyURL)
Expand Down Expand Up @@ -87,59 +91,65 @@ func (d *customDialer) wrapConnection(c net.Conn, scheme string) net.Conn {
}

func (d *customDialer) CustomDial(network, address string) (conn net.Conn, err error) {
// Determine the network based on IPv4/IPv6 settings
network = d.getNetworkType(network)
if network == "" {
return nil, errors.New("no supported network type available")
}

if d.proxyDialer != nil {
conn, err = d.proxyDialer.Dial(network, address)
if err != nil {
return nil, err
}
} else {
if d.client.randomLocalIP {
localAddr := getLocalAddr(network, address)
if localAddr != nil {
if network == "tcp" {
if network == "tcp" || network == "tcp4" || network == "tcp6" {
d.LocalAddr = localAddr.(*net.TCPAddr)
} else if network == "udp" {
} else if network == "udp" || network == "udp4" || network == "udp6" {
d.LocalAddr = localAddr.(*net.UDPAddr)
}
}
}

conn, err = d.Dial(network, address)
if err != nil {
return nil, err
}
}

if err != nil {
return nil, err
}

return d.wrapConnection(conn, "http"), nil
}

func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error) {
var (
plainConn net.Conn
err error
)
// Determine the network based on IPv4/IPv6 settings
network = d.getNetworkType(network)
if network == "" {
return nil, errors.New("no supported network type available")
}

var plainConn net.Conn
var err error

if d.proxyDialer != nil {
plainConn, err = d.proxyDialer.Dial(network, address)
if err != nil {
return nil, err
}
} else {
if d.client.randomLocalIP {
localAddr := getLocalAddr(network, address)
if localAddr != nil {
if network == "tcp" {
if network == "tcp" || network == "tcp4" || network == "tcp6" {
d.LocalAddr = localAddr.(*net.TCPAddr)
} else if network == "udp" {
} else if network == "udp" || network == "udp4" || network == "udp6" {
d.LocalAddr = localAddr.(*net.UDPAddr)
}
}
}

plainConn, err = d.Dial(network, address)
if err != nil {
return nil, err
}
}

if err != nil {
return nil, err
}

cfg := new(tls.Config)
Expand Down Expand Up @@ -171,6 +181,31 @@ func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error)
return d.wrapConnection(tlsConn, "https"), nil
}

func (d *customDialer) getNetworkType(network string) string {
switch network {
case "tcp", "udp":
if d.disableIPv4 && !d.disableIPv6 {
return network + "6"
}
if !d.disableIPv4 && d.disableIPv6 {
return network + "4"
}
return network // Both enabled or both disabled, use default
case "tcp4", "udp4":
if d.disableIPv4 {
return ""
}
return network
case "tcp6", "udp6":
if d.disableIPv6 {
return ""
}
return network
default:
return "" // Unsupported network type
}
}

func (d *customDialer) writeWARCFromConnection(reqPipe, respPipe *io.PipeReader, scheme string, conn net.Conn) {
defer d.client.WaitGroup.Done()

Expand Down
Loading