diff --git a/p2p/test/websocket/websocket_test.go b/p2p/test/websocket/websocket_test.go new file mode 100644 index 0000000000..ccadb6fffc --- /dev/null +++ b/p2p/test/websocket/websocket_test.go @@ -0,0 +1,66 @@ +package websocket_test + +import ( + "context" + "io" + "testing" + + "github.com/libp2p/go-libp2p" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/require" +) + +func TestReadLimit(t *testing.T) { + h1, err := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/ws")) + require.NoError(t, err) + defer h1.Close() + + ctx := context.Background() + h2, err := libp2p.New(libp2p.NoListenAddrs) + require.NoError(t, err) + defer h2.Close() + + err = h2.Connect(ctx, peer.AddrInfo{ID: h1.ID(), Addrs: h1.Addrs()}) + require.NoError(t, err) + + buf := make([]byte, 256<<10) + buf2 := make([]byte, 256<<10) + copyBuf := make([]byte, 8<<10) + + errCh := make(chan error, 1) + // TODO perf would be perfect here, but not yet merged. + h1.SetStreamHandler("/big-blocks", func(s network.Stream) { + defer s.Close() + _, err := io.CopyBuffer(io.Discard, s, copyBuf) + if err != nil { + errCh <- err + return + } + _, err = s.Write(buf) + if err != nil { + errCh <- err + return + } + errCh <- nil + }) + + allocs := testing.AllocsPerRun(100, func() { + s, err := h2.NewStream(ctx, h1.ID(), "/big-blocks") + require.NoError(t, err) + defer s.Close() + _, err = s.Write(buf2) + require.NoError(t, err) + require.NoError(t, s.CloseWrite()) + + _, err = io.ReadFull(s, buf2) + require.NoError(t, err) + + _, err = s.Read([]byte{0}) + require.ErrorIs(t, err, io.EOF) + require.NoError(t, <-errCh) + }) + + // Make sure we aren't doing some crazy allocs when transferring big blocks + require.Less(t, allocs, 8*1024.0) +} diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index 70d4abbdcd..c3918c3811 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -31,7 +31,6 @@ func (c conn) Read(b []byte) (int, error) { if err == nil && n == 0 && c.readAttempts < maxReadAttempts { c.readAttempts++ // Nothing happened, let's read again. We reached the end of the frame - // we have // (https://github.com/nhooyr/websocket/blob/master/netconn.go#L118). // The next read will block until we get // the next frame. We limit here to avoid looping in case of a bunch of diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index 70dbce0fea..8f2f02c31c 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "fmt" + "math" "net" "net/http" "net/url" @@ -109,6 +110,10 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + // Set an arbitrarily large read limit since we don't actually want to limit the message size here. + // See https://github.com/nhooyr/websocket/issues/382 for details. + c.SetReadLimit(math.MaxInt64 - 1) // -1 because the library adds a byte for the fin frame + select { case l.incoming <- conn{ Conn: ws.NetConn(context.Background(), c, ws.MessageBinary), diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index f3a3c8709b..a78add9782 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "fmt" + "math" "net" "net/http" "net/url" @@ -250,6 +251,8 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma return nil, fmt.Errorf("failed to get local address") } + // Set an arbitrarily large read limit since we don't actually want to limit the message size here. + wscon.SetReadLimit(math.MaxInt64 - 1) // -1 because the library adds a byte for the fin frame mnc, err := manet.WrapNetConn( conn{ Conn: ws.NetConn(context.Background(), wscon, ws.MessageBinary),