diff --git a/client.go b/client.go index 4e722bc..7b66d09 100644 --- a/client.go +++ b/client.go @@ -28,6 +28,8 @@ type HTTPClientSettings struct { FullOnDisk bool VerifyCerts bool RandomLocalIP bool + DisableIPv4 bool + DisableIPv6 bool } type CustomHTTPClient struct { @@ -147,7 +149,7 @@ func NewWARCWritingHTTPClient(HTTPClientSettings HTTPClientSettings) (httpClient httpClient.TLSHandshakeTimeout = HTTPClientSettings.TLSHandshakeTimeout // Configure custom dialer / transport - customDialer, err := newCustomDialer(httpClient, HTTPClientSettings.Proxy, HTTPClientSettings.DialTimeout) + customDialer, err := newCustomDialer(httpClient, HTTPClientSettings.Proxy, HTTPClientSettings.DialTimeout, HTTPClientSettings.DisableIPv4, HTTPClientSettings.DisableIPv6) if err != nil { return nil, err } diff --git a/client_test.go b/client_test.go index f4e6ce4..447e181 100644 --- a/client_test.go +++ b/client_test.go @@ -1,6 +1,7 @@ package warc import ( + "context" "io" "net" "net/http" @@ -1267,6 +1268,148 @@ func TestHTTPClientWithZStandardDictionary(t *testing.T) { } } +func setupIPv4Server(t *testing.T) (string, func()) { + listener, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to set up IPv4 server: %v", err) + } + + server := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("IPv4 Server")) + }), + } + + go server.Serve(listener) + + return "http://" + listener.Addr().String(), func() { + server.Shutdown(context.Background()) + } +} + +func setupIPv6Server(t *testing.T) (string, func()) { + listener, err := net.Listen("tcp6", "[::1]:0") + if err != nil { + t.Fatalf("Failed to set up IPv6 server: %v", err) + } + + server := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("IPv6 Server")) + }), + } + + go server.Serve(listener) + + return "http://" + listener.Addr().String(), func() { + server.Shutdown(context.Background()) + } +} + +func TestHTTPClientWithIPv4Disabled(t *testing.T) { + defer goleak.VerifyNone(t) + + ipv4URL, closeIPv4 := setupIPv4Server(t) + defer closeIPv4() + + ipv6URL, closeIPv6 := setupIPv6Server(t) + defer closeIPv6() + + rotatorSettings := NewRotatorSettings() + rotatorSettings.OutputDirectory, _ = os.MkdirTemp("", "warc-tests-") + defer os.RemoveAll(rotatorSettings.OutputDirectory) + rotatorSettings.Prefix = "TESTIPv6Only" + + httpClient, err := NewWARCWritingHTTPClient(HTTPClientSettings{ + RotatorSettings: rotatorSettings, + DisableIPv4: true, + }) + if err != nil { + t.Fatalf("Unable to init WARC writing HTTP client: %s", err) + } + + // Try IPv4 - should fail + _, err = httpClient.Get(ipv4URL) + if err == nil { + t.Fatalf("Expected error when connecting to IPv4 server, but got none") + } + + // Try IPv6 - should succeed + resp, err := httpClient.Get(ipv6URL) + if err != nil { + t.Fatalf("Failed to connect to IPv6 server: %v", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if string(body) != "IPv6 Server" { + t.Fatalf("Unexpected response from IPv6 server: %s", string(body)) + } + + httpClient.Close() + + files, err := filepath.Glob(rotatorSettings.OutputDirectory + "/*") + if err != nil { + t.Fatal(err) + } + + for _, path := range files { + testFileSingleHashCheck(t, path, "sha1:RTK62UJNR5UCIPX2J64LMV7J4JJ6EXCJ", []string{"147"}, 1) + } +} + +func TestHTTPClientWithIPv6Disabled(t *testing.T) { + defer goleak.VerifyNone(t) + + ipv4URL, closeIPv4 := setupIPv4Server(t) + defer closeIPv4() + + ipv6URL, closeIPv6 := setupIPv6Server(t) + defer closeIPv6() + + rotatorSettings := NewRotatorSettings() + rotatorSettings.OutputDirectory, _ = os.MkdirTemp("", "warc-tests-") + defer os.RemoveAll(rotatorSettings.OutputDirectory) + rotatorSettings.Prefix = "TESTIPv4Only" + + httpClient, err := NewWARCWritingHTTPClient(HTTPClientSettings{ + RotatorSettings: rotatorSettings, + DisableIPv6: true, + }) + if err != nil { + t.Fatalf("Unable to init WARC writing HTTP client: %s", err) + } + + // Try IPv6 - should fail + _, err = httpClient.Get(ipv6URL) + if err == nil { + t.Fatalf("Expected error when connecting to IPv6 server, but got none") + } + + // Try IPv4 - should succeed + resp, err := httpClient.Get(ipv4URL) + if err != nil { + t.Fatalf("Failed to connect to IPv4 server: %v", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if string(body) != "IPv4 Server" { + t.Fatalf("Unexpected response from IPv4 server: %s", string(body)) + } + + httpClient.Close() + + files, err := filepath.Glob(rotatorSettings.OutputDirectory + "/*") + if err != nil { + t.Fatal(err) + } + + for _, path := range files { + testFileSingleHashCheck(t, path, "sha1:JZIRQ2YRCQ55F6SSNPTXHKMDSKJV6QFM", []string{"147"}, 1) + } +} + func BenchmarkConcurrentUnder2MB(b *testing.B) { var ( rotatorSettings = NewRotatorSettings() diff --git a/dialer.go b/dialer.go index 12d318a..85f508a 100644 --- a/dialer.go +++ b/dialer.go @@ -24,14 +24,18 @@ import ( type customDialer struct { proxyDialer proxy.Dialer client *CustomHTTPClient + disableIPv4 bool + disableIPv6 bool net.Dialer } -func newCustomDialer(httpClient *CustomHTTPClient, proxyURL string, DialTimeout time.Duration) (d *customDialer, err error) { +func newCustomDialer(httpClient *CustomHTTPClient, proxyURL string, DialTimeout time.Duration, disableIPv4, disableIPv6 bool) (d *customDialer, err error) { d = new(customDialer) d.Timeout = DialTimeout d.client = httpClient + d.disableIPv4 = disableIPv4 + d.disableIPv6 = disableIPv6 if proxyURL != "" { u, err := url.Parse(proxyURL) @@ -87,59 +91,65 @@ func (d *customDialer) wrapConnection(c net.Conn, scheme string) net.Conn { } func (d *customDialer) CustomDial(network, address string) (conn net.Conn, err error) { + // Determine the network based on IPv4/IPv6 settings + network = d.getNetworkType(network) + if network == "" { + return nil, errors.New("no supported network type available") + } + if d.proxyDialer != nil { conn, err = d.proxyDialer.Dial(network, address) - if err != nil { - return nil, err - } } else { if d.client.randomLocalIP { localAddr := getLocalAddr(network, address) if localAddr != nil { - if network == "tcp" { + if network == "tcp" || network == "tcp4" || network == "tcp6" { d.LocalAddr = localAddr.(*net.TCPAddr) - } else if network == "udp" { + } else if network == "udp" || network == "udp4" || network == "udp6" { d.LocalAddr = localAddr.(*net.UDPAddr) } } } conn, err = d.Dial(network, address) - if err != nil { - return nil, err - } + } + + if err != nil { + return nil, err } return d.wrapConnection(conn, "http"), nil } func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error) { - var ( - plainConn net.Conn - err error - ) + // Determine the network based on IPv4/IPv6 settings + network = d.getNetworkType(network) + if network == "" { + return nil, errors.New("no supported network type available") + } + + var plainConn net.Conn + var err error if d.proxyDialer != nil { plainConn, err = d.proxyDialer.Dial(network, address) - if err != nil { - return nil, err - } } else { if d.client.randomLocalIP { localAddr := getLocalAddr(network, address) if localAddr != nil { - if network == "tcp" { + if network == "tcp" || network == "tcp4" || network == "tcp6" { d.LocalAddr = localAddr.(*net.TCPAddr) - } else if network == "udp" { + } else if network == "udp" || network == "udp4" || network == "udp6" { d.LocalAddr = localAddr.(*net.UDPAddr) } } } plainConn, err = d.Dial(network, address) - if err != nil { - return nil, err - } + } + + if err != nil { + return nil, err } cfg := new(tls.Config) @@ -171,6 +181,31 @@ func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error) return d.wrapConnection(tlsConn, "https"), nil } +func (d *customDialer) getNetworkType(network string) string { + switch network { + case "tcp", "udp": + if d.disableIPv4 && !d.disableIPv6 { + return network + "6" + } + if !d.disableIPv4 && d.disableIPv6 { + return network + "4" + } + return network // Both enabled or both disabled, use default + case "tcp4", "udp4": + if d.disableIPv4 { + return "" + } + return network + case "tcp6", "udp6": + if d.disableIPv6 { + return "" + } + return network + default: + return "" // Unsupported network type + } +} + func (d *customDialer) writeWARCFromConnection(reqPipe, respPipe *io.PipeReader, scheme string, conn net.Conn) { defer d.client.WaitGroup.Done()