Skip to content

Commit

Permalink
Improve websocket tests
Browse files Browse the repository at this point in the history
Signed-off-by: Lorenzo Donini <lorenzo.donini90@gmail.com>
  • Loading branch information
lorenzodonini committed May 2, 2021
1 parent 2c8eaf1 commit efffd96
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 59 deletions.
2 changes: 1 addition & 1 deletion ws/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,8 @@ func (server *Server) wsHandler(w http.ResponseWriter, r *http.Request) {
tlsConnectionState: r.TLS,
}
server.connMutex.Lock()
defer server.connMutex.Unlock()
server.connections[ws.id] = &ws
server.connMutex.Unlock()
// Read and write routines are started in separate goroutines and function will return immediately
go server.writePump(&ws)
go server.readPump(&ws)
Expand Down
173 changes: 115 additions & 58 deletions ws/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"net/http"
"net/url"
"os"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -79,39 +80,52 @@ func TestWebsocketSetConnected(t *testing.T) {

func TestWebsocketEcho(t *testing.T) {
message := []byte("Hello WebSocket!")
var wsServer *Server
wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) {
triggerC := make(chan bool, 1)
done := make(chan bool, 1)
wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) {
assert.True(t, bytes.Equal(message, data))
// Message received, notifying flow routine
triggerC <- true
return data, nil
})
go wsServer.Start(serverPort, serverPath)
time.Sleep(1 * time.Second)

// Test message
wsServer.SetNewClientHandler(func(ws Channel) {
tlsState := ws.GetTLSConnectionState()
assert.Nil(t, tlsState)
})
wsServer.SetDisconnectedClientHandler(func(ws Channel) {
// Connection closed, completing test
done <- true
})
wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) {
assert.True(t, bytes.Equal(message, data))
// Echo response received, notifying flow routine
triggerC <- true
return nil, nil
})
host := fmt.Sprintf("localhost:%v", serverPort)
u := url.URL{Scheme: "ws", Host: host, Path: testPath}
// Wait for connection to be established, then send a message
go func() {
timer := time.NewTimer(1 * time.Second)
<-timer.C
err := wsClient.Write(message)
assert.Nil(t, err)
}()
done := make(chan bool)
// Wait for messages to be exchanged, then close connection
// Start server
go wsServer.Start(serverPort, serverPath)
// Start flow routine
go func() {
timer := time.NewTimer(3 * time.Second)
<-timer.C
// Wait for messages to be exchanged, then close connection
sig, _ := <-triggerC
assert.True(t, sig)
err := wsServer.Write(testPath, message)
require.Nil(t, err)
sig, _ = <-triggerC
assert.True(t, sig)
wsClient.Stop()
done <- true
}()
time.Sleep(200 * time.Millisecond)

// Test message
host := fmt.Sprintf("localhost:%v", serverPort)
u := url.URL{Scheme: "ws", Host: host, Path: testPath}
err := wsClient.Start(u.String())
assert.Nil(t, err)
assert.True(t, wsClient.IsConnected())
require.NoError(t, err)
require.True(t, wsClient.IsConnected())
err = wsClient.Write(message)
require.NoError(t, err)
// Wait for echo result
result := <-done
assert.True(t, result)
// Cleanup
Expand All @@ -120,12 +134,23 @@ func TestWebsocketEcho(t *testing.T) {

func TestTLSWebsocketEcho(t *testing.T) {
message := []byte("Hello Secure WebSocket!")
var wsServer *Server
triggerC := make(chan bool, 1)
done := make(chan bool, 1)
// Use NewTLSServer() when in different package
wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) {
wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) {
assert.True(t, bytes.Equal(message, data))
// Message received, notifying flow routine
triggerC <- true
return data, nil
})
wsServer.SetNewClientHandler(func(ws Channel) {
tlsState := ws.GetTLSConnectionState()
assert.NotNil(t, tlsState)
})
wsServer.SetDisconnectedClientHandler(func(ws Channel) {
// Connection closed, completing test
done <- true
})
// Create self-signed TLS certificate
certFilename := "/tmp/cert.pem"
keyFilename := "/tmp/key.pem"
Expand All @@ -137,12 +162,11 @@ func TestTLSWebsocketEcho(t *testing.T) {
// Set self-signed TLS certificate
wsServer.tlsCertificatePath = certFilename
wsServer.tlsCertificateKey = keyFilename
go wsServer.Start(serverPort, serverPath)
time.Sleep(1 * time.Second)

// Create TLS client
wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) {
assert.True(t, bytes.Equal(message, data))
// Echo response received, notifying flow routine
triggerC <- true
return nil, nil
})
wsClient.AddOption(func(dialer *websocket.Dialer) {
Expand All @@ -155,37 +179,66 @@ func TestTLSWebsocketEcho(t *testing.T) {
RootCAs: certPool,
}
})
// Test message
host := fmt.Sprintf("localhost:%v", serverPort)
u := url.URL{Scheme: "wss", Host: host, Path: testPath}
// Wait for connection to be established, then send a message to server
go func() {
timer := time.NewTimer(1 * time.Second)
<-timer.C
err := wsClient.Write(message)
assert.Nil(t, err)
}()
done := make(chan bool)
// Wait for messages to be exchanged, then close connection

// Start server
go wsServer.Start(serverPort, serverPath)
// Start flow routine
go func() {
timer := time.NewTimer(3 * time.Second)
<-timer.C
// Wait for messages to be exchanged, then close connection
sig, _ := <-triggerC
assert.True(t, sig)
err := wsServer.Write(testPath, message)
require.NoError(t, err)
sig, _ = <-triggerC
assert.True(t, sig)
wsClient.Stop()
done <- true
}()
time.Sleep(200 * time.Millisecond)

// Test message
host := fmt.Sprintf("localhost:%v", serverPort)
u := url.URL{Scheme: "wss", Host: host, Path: testPath}
err = wsClient.Start(u.String())
assert.Nil(t, err)
require.NoError(t, err)
require.True(t, wsClient.IsConnected())
err = wsClient.Write(message)
require.NoError(t, err)
// Wait for echo result
result := <-done
assert.True(t, result)
// Cleanup
wsServer.Stop()
}

func TestServerStartErrors(t *testing.T) {
triggerC := make(chan bool, 1)
wsServer := newWebsocketServer(t, nil)
wsServer.SetNewClientHandler(func(ws Channel) {
triggerC <- true
})
// Make sure http server is initialized on start
wsServer.httpServer = nil
// Listen for errors
go func() {
err, ok := <-wsServer.Errors()
assert.True(t, ok)
assert.Error(t, err)
triggerC <- true
}()
time.Sleep(100 * time.Millisecond)
go wsServer.Start(serverPort, serverPath)
time.Sleep(100 * time.Millisecond)
// Starting server again throws error
wsServer.Start(serverPort, serverPath)
r, _ := <-triggerC
require.True(t, r)
wsServer.Stop()
}

func TestWebsocketClientConnectionBreak(t *testing.T) {
newClient := make(chan bool)
disconnected := make(chan bool)
var wsServer *Server
wsServer = newWebsocketServer(t, nil)
wsServer := newWebsocketServer(t, nil)
wsServer.SetNewClientHandler(func(ws Channel) {
newClient <- true
})
Expand Down Expand Up @@ -217,9 +270,8 @@ func TestWebsocketClientConnectionBreak(t *testing.T) {
}

func TestWebsocketServerConnectionBreak(t *testing.T) {
var wsServer *Server
disconnected := make(chan bool)
wsServer = newWebsocketServer(t, nil)
wsServer := newWebsocketServer(t, nil)
wsServer.SetNewClientHandler(func(ws Channel) {
assert.NotNil(t, ws)
conn := wsServer.connections[ws.GetID()]
Expand Down Expand Up @@ -249,7 +301,6 @@ func TestWebsocketServerConnectionBreak(t *testing.T) {
func TestValidBasicAuth(t *testing.T) {
authUsername := "testUsername"
authPassword := "testPassword"
var wsServer *Server
// Create self-signed TLS certificate
certFilename := "/tmp/cert.pem"
keyFilename := "/tmp/key.pem"
Expand All @@ -259,7 +310,7 @@ func TestValidBasicAuth(t *testing.T) {
defer os.Remove(keyFilename)

// Create TLS server with self-signed certificate
wsServer = NewTLSServer(certFilename, keyFilename, nil)
wsServer := NewTLSServer(certFilename, keyFilename, nil)
// Add basic auth handler
wsServer.SetBasicAuthHandler(func(username string, password string) bool {
require.Equal(t, authUsername, username)
Expand Down Expand Up @@ -300,7 +351,6 @@ func TestValidBasicAuth(t *testing.T) {
func TestInvalidBasicAuth(t *testing.T) {
authUsername := "testUsername"
authPassword := "testPassword"
var wsServer *Server
// Create self-signed TLS certificate
certFilename := "/tmp/cert.pem"
keyFilename := "/tmp/key.pem"
Expand All @@ -310,7 +360,7 @@ func TestInvalidBasicAuth(t *testing.T) {
defer os.Remove(keyFilename)

// Create TLS server with self-signed certificate
wsServer = NewTLSServer(certFilename, keyFilename, nil)
wsServer := NewTLSServer(certFilename, keyFilename, nil)
// Add basic auth handler
wsServer.SetBasicAuthHandler(func(username string, password string) bool {
validCredentials := authUsername == username && authPassword == password
Expand Down Expand Up @@ -338,7 +388,14 @@ func TestInvalidBasicAuth(t *testing.T) {
host := fmt.Sprintf("localhost:%v", serverPort)
u := url.URL{Scheme: "wss", Host: host, Path: testPath}
err = wsClient.Start(u.String())
// Assert HTTP error
assert.Error(t, err)
httpErr, ok := err.(HttpConnectionError)
require.True(t, ok)
assert.Equal(t, http.StatusUnauthorized, httpErr.HttpCode)
assert.Equal(t, "401 Unauthorized", httpErr.HttpStatus)
assert.Equal(t, "websocket: bad handshake", httpErr.Message)
assert.True(t, strings.Contains(err.Error(), "http status:"))
// Add basic auth
wsClient.SetBasicAuth(authUsername, "invalidPassword")
// Test connection
Expand All @@ -353,8 +410,7 @@ func TestInvalidBasicAuth(t *testing.T) {
}

func TestInvalidOriginHeader(t *testing.T) {
var wsServer *Server
wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) {
wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) {
assert.Fail(t, "no message should be received from client!")
return nil, nil
})
Expand Down Expand Up @@ -386,10 +442,9 @@ func TestInvalidOriginHeader(t *testing.T) {
}

func TestCustomOriginHeaderHandler(t *testing.T) {
var wsServer *Server
origin := "example.org"
connected := make(chan bool)
wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) {
wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) {
assert.Fail(t, "no message should be received from client!")
return nil, nil
})
Expand Down Expand Up @@ -519,7 +574,7 @@ func TestInvalidClientTLSCertificate(t *testing.T) {
})
// Run server
go wsServer.Start(serverPort, serverPath)
time.Sleep(1 * time.Second)
time.Sleep(200 * time.Millisecond)

// Create TLS client
certPool = x509.NewCertPool()
Expand All @@ -546,9 +601,8 @@ func TestInvalidClientTLSCertificate(t *testing.T) {
}

func TestUnsupportedSubprotocol(t *testing.T) {
var wsServer *Server
disconnected := make(chan bool)
wsServer = newWebsocketServer(t, nil)
wsServer := newWebsocketServer(t, nil)
wsServer.SetNewClientHandler(func(ws Channel) {
assert.Fail(t, "invalid subprotocol expected, but hit client handler instead")
t.Fail()
Expand Down Expand Up @@ -710,6 +764,9 @@ func TestServerErrors(t *testing.T) {
require.NoError(t, err)
r, _ = <-triggerC
assert.True(t, r)
// Send message to non-existing client
err = wsServer.Write("fakeId", []byte("dummy response"))
require.Error(t, err)
// Send unexpected close message and wait for error to be thrown
err = wsClient.webSocket.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseUnsupportedData, ""))
assert.NoError(t, err)
Expand Down

0 comments on commit efffd96

Please sign in to comment.