Skip to content

Commit

Permalink
websocket: set the HTTP host header in WSS
Browse files Browse the repository at this point in the history
* Send host header

Co-authored-by: Thibault Meunier <thibault@cloudflare.com>

* Add comment and use splithostport

* Return error

* Defer the close

Co-authored-by: Thibault Meunier <thibault@cloudflare.com>
  • Loading branch information
2 people authored and marten-seemann committed Nov 2, 2022
1 parent 90b2d5d commit 771a814
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 1 deletion.
2 changes: 1 addition & 1 deletion p2p/transport/websocket/addrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func parseMultiaddr(maddr ma.Multiaddr) (*url.URL, error) {

type parsedWebsocketMultiaddr struct {
isWSS bool
// sni is the SNI value for the TLS handshake
// sni is the SNI value for the TLS handshake, and for setting HTTP Host header
sni *ma.Component
// the rest of the multiaddr before the /tls/sni/example.com/ws or /ws or /wss
restMultiaddr ma.Multiaddr
Expand Down
12 changes: 12 additions & 0 deletions p2p/transport/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package websocket
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"

Expand Down Expand Up @@ -186,6 +187,17 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
copytlsClientConf := t.tlsClientConf.Clone()
copytlsClientConf.ServerName = sni
dialer.TLSClientConfig = copytlsClientConf
ipAddr := wsurl.Host
// Setting the NetDial because we already have the resolved IP address, so we don't want to do another resolution.
// We set the `.Host` to the sni field so that the host header gets properly set.
dialer.NetDial = func(network, address string) (net.Conn, error) {
tcpAddr, err := net.ResolveTCPAddr(network, ipAddr)
if err != nil {
return nil, err
}
return net.DialTCP("tcp", nil, tcpAddr)
}
wsurl.Host = sni + ":" + wsurl.Port()
} else {
dialer.TLSClientConfig = t.tlsClientConf
}
Expand Down
41 changes: 41 additions & 0 deletions p2p/transport/websocket/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"io"
"math/big"
"net"
"net/http"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -218,6 +221,44 @@ func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config {
}
}

func TestHostHeaderWss(t *testing.T) {
server := &http.Server{}
l, err := net.Listen("tcp", ":0")
require.NoError(t, err)
defer server.Close()

errChan := make(chan error, 1)
go func() {
server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer close(errChan)
if !strings.Contains(r.Host, "example.com") {
errChan <- errors.New("Didn't see host header")
}
w.WriteHeader(http.StatusNotFound)
})
server.TLSConfig = getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(time.Hour))
server.ServeTLS(l, "", "")
}()

_, port, err := net.SplitHostPort(l.Addr().String())
require.NoError(t, err)
serverMA := ma.StringCast("/ip4/127.0.0.1/tcp/" + port + "/tls/sni/example.com/ws")

tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
_, u := newSecureUpgrader(t)
tpt, err := New(u, network.NullResourceManager, WithTLSClientConfig(tlsConfig))
require.NoError(t, err)

masToDial, err := tpt.Resolve(context.Background(), serverMA)
require.NoError(t, err)

_, err = tpt.Dial(context.Background(), masToDial[0], test.RandPeerIDFatal(t))
require.Error(t, err)

err = <-errChan
require.NoError(t, err)
}

func TestDialWss(t *testing.T) {
serverMA, rid, errChan := testWSSServer(t, ma.StringCast("/ip4/127.0.0.1/tcp/0/tls/sni/example.com/ws"))
require.Contains(t, serverMA.String(), "tls")
Expand Down

0 comments on commit 771a814

Please sign in to comment.