diff --git a/ws/websocket.go b/ws/websocket.go index 54b216e7..d94acda5 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -407,8 +407,8 @@ func (server *Server) wsHandler(w http.ResponseWriter, r *http.Request) { tlsConnectionState: r.TLS, } server.connMutex.Lock() - defer server.connMutex.Unlock() server.connections[ws.id] = &ws + server.connMutex.Unlock() // Read and write routines are started in separate goroutines and function will return immediately go server.writePump(&ws) go server.readPump(&ws) diff --git a/ws/websocket_test.go b/ws/websocket_test.go index 00b2f37e..571a7ba0 100644 --- a/ws/websocket_test.go +++ b/ws/websocket_test.go @@ -16,6 +16,7 @@ import ( "net/http" "net/url" "os" + "strings" "testing" "time" @@ -79,39 +80,52 @@ func TestWebsocketSetConnected(t *testing.T) { func TestWebsocketEcho(t *testing.T) { message := []byte("Hello WebSocket!") - var wsServer *Server - wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) { + triggerC := make(chan bool, 1) + done := make(chan bool, 1) + wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) { assert.True(t, bytes.Equal(message, data)) + // Message received, notifying flow routine + triggerC <- true return data, nil }) - go wsServer.Start(serverPort, serverPath) - time.Sleep(1 * time.Second) - - // Test message + wsServer.SetNewClientHandler(func(ws Channel) { + tlsState := ws.GetTLSConnectionState() + assert.Nil(t, tlsState) + }) + wsServer.SetDisconnectedClientHandler(func(ws Channel) { + // Connection closed, completing test + done <- true + }) wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { assert.True(t, bytes.Equal(message, data)) + // Echo response received, notifying flow routine + triggerC <- true return nil, nil }) - host := fmt.Sprintf("localhost:%v", serverPort) - u := url.URL{Scheme: "ws", Host: host, Path: testPath} - // Wait for connection to be established, then send a message - go func() { - timer := time.NewTimer(1 * time.Second) - <-timer.C - err := wsClient.Write(message) - assert.Nil(t, err) - }() - done := make(chan bool) - // Wait for messages to be exchanged, then close connection + // Start server + go wsServer.Start(serverPort, serverPath) + // Start flow routine go func() { - timer := time.NewTimer(3 * time.Second) - <-timer.C + // Wait for messages to be exchanged, then close connection + sig, _ := <-triggerC + assert.True(t, sig) + err := wsServer.Write(testPath, message) + require.Nil(t, err) + sig, _ = <-triggerC + assert.True(t, sig) wsClient.Stop() - done <- true }() + time.Sleep(200 * time.Millisecond) + + // Test message + host := fmt.Sprintf("localhost:%v", serverPort) + u := url.URL{Scheme: "ws", Host: host, Path: testPath} err := wsClient.Start(u.String()) - assert.Nil(t, err) - assert.True(t, wsClient.IsConnected()) + require.NoError(t, err) + require.True(t, wsClient.IsConnected()) + err = wsClient.Write(message) + require.NoError(t, err) + // Wait for echo result result := <-done assert.True(t, result) // Cleanup @@ -120,12 +134,23 @@ func TestWebsocketEcho(t *testing.T) { func TestTLSWebsocketEcho(t *testing.T) { message := []byte("Hello Secure WebSocket!") - var wsServer *Server + triggerC := make(chan bool, 1) + done := make(chan bool, 1) // Use NewTLSServer() when in different package - wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) { + wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) { assert.True(t, bytes.Equal(message, data)) + // Message received, notifying flow routine + triggerC <- true return data, nil }) + wsServer.SetNewClientHandler(func(ws Channel) { + tlsState := ws.GetTLSConnectionState() + assert.NotNil(t, tlsState) + }) + wsServer.SetDisconnectedClientHandler(func(ws Channel) { + // Connection closed, completing test + done <- true + }) // Create self-signed TLS certificate certFilename := "/tmp/cert.pem" keyFilename := "/tmp/key.pem" @@ -137,12 +162,11 @@ func TestTLSWebsocketEcho(t *testing.T) { // Set self-signed TLS certificate wsServer.tlsCertificatePath = certFilename wsServer.tlsCertificateKey = keyFilename - go wsServer.Start(serverPort, serverPath) - time.Sleep(1 * time.Second) - // Create TLS client wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { assert.True(t, bytes.Equal(message, data)) + // Echo response received, notifying flow routine + triggerC <- true return nil, nil }) wsClient.AddOption(func(dialer *websocket.Dialer) { @@ -155,37 +179,66 @@ func TestTLSWebsocketEcho(t *testing.T) { RootCAs: certPool, } }) - // Test message - host := fmt.Sprintf("localhost:%v", serverPort) - u := url.URL{Scheme: "wss", Host: host, Path: testPath} - // Wait for connection to be established, then send a message to server - go func() { - timer := time.NewTimer(1 * time.Second) - <-timer.C - err := wsClient.Write(message) - assert.Nil(t, err) - }() - done := make(chan bool) - // Wait for messages to be exchanged, then close connection + + // Start server + go wsServer.Start(serverPort, serverPath) + // Start flow routine go func() { - timer := time.NewTimer(3 * time.Second) - <-timer.C + // Wait for messages to be exchanged, then close connection + sig, _ := <-triggerC + assert.True(t, sig) + err := wsServer.Write(testPath, message) + require.NoError(t, err) + sig, _ = <-triggerC + assert.True(t, sig) wsClient.Stop() - done <- true }() + time.Sleep(200 * time.Millisecond) + + // Test message + host := fmt.Sprintf("localhost:%v", serverPort) + u := url.URL{Scheme: "wss", Host: host, Path: testPath} err = wsClient.Start(u.String()) - assert.Nil(t, err) + require.NoError(t, err) + require.True(t, wsClient.IsConnected()) + err = wsClient.Write(message) + require.NoError(t, err) + // Wait for echo result result := <-done assert.True(t, result) // Cleanup wsServer.Stop() } +func TestServerStartErrors(t *testing.T) { + triggerC := make(chan bool, 1) + wsServer := newWebsocketServer(t, nil) + wsServer.SetNewClientHandler(func(ws Channel) { + triggerC <- true + }) + // Make sure http server is initialized on start + wsServer.httpServer = nil + // Listen for errors + go func() { + err, ok := <-wsServer.Errors() + assert.True(t, ok) + assert.Error(t, err) + triggerC <- true + }() + time.Sleep(100 * time.Millisecond) + go wsServer.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) + // Starting server again throws error + wsServer.Start(serverPort, serverPath) + r, _ := <-triggerC + require.True(t, r) + wsServer.Stop() +} + func TestWebsocketClientConnectionBreak(t *testing.T) { newClient := make(chan bool) disconnected := make(chan bool) - var wsServer *Server - wsServer = newWebsocketServer(t, nil) + wsServer := newWebsocketServer(t, nil) wsServer.SetNewClientHandler(func(ws Channel) { newClient <- true }) @@ -217,9 +270,8 @@ func TestWebsocketClientConnectionBreak(t *testing.T) { } func TestWebsocketServerConnectionBreak(t *testing.T) { - var wsServer *Server disconnected := make(chan bool) - wsServer = newWebsocketServer(t, nil) + wsServer := newWebsocketServer(t, nil) wsServer.SetNewClientHandler(func(ws Channel) { assert.NotNil(t, ws) conn := wsServer.connections[ws.GetID()] @@ -249,7 +301,6 @@ func TestWebsocketServerConnectionBreak(t *testing.T) { func TestValidBasicAuth(t *testing.T) { authUsername := "testUsername" authPassword := "testPassword" - var wsServer *Server // Create self-signed TLS certificate certFilename := "/tmp/cert.pem" keyFilename := "/tmp/key.pem" @@ -259,7 +310,7 @@ func TestValidBasicAuth(t *testing.T) { defer os.Remove(keyFilename) // Create TLS server with self-signed certificate - wsServer = NewTLSServer(certFilename, keyFilename, nil) + wsServer := NewTLSServer(certFilename, keyFilename, nil) // Add basic auth handler wsServer.SetBasicAuthHandler(func(username string, password string) bool { require.Equal(t, authUsername, username) @@ -300,7 +351,6 @@ func TestValidBasicAuth(t *testing.T) { func TestInvalidBasicAuth(t *testing.T) { authUsername := "testUsername" authPassword := "testPassword" - var wsServer *Server // Create self-signed TLS certificate certFilename := "/tmp/cert.pem" keyFilename := "/tmp/key.pem" @@ -310,7 +360,7 @@ func TestInvalidBasicAuth(t *testing.T) { defer os.Remove(keyFilename) // Create TLS server with self-signed certificate - wsServer = NewTLSServer(certFilename, keyFilename, nil) + wsServer := NewTLSServer(certFilename, keyFilename, nil) // Add basic auth handler wsServer.SetBasicAuthHandler(func(username string, password string) bool { validCredentials := authUsername == username && authPassword == password @@ -338,7 +388,14 @@ func TestInvalidBasicAuth(t *testing.T) { host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "wss", Host: host, Path: testPath} err = wsClient.Start(u.String()) + // Assert HTTP error assert.Error(t, err) + httpErr, ok := err.(HttpConnectionError) + require.True(t, ok) + assert.Equal(t, http.StatusUnauthorized, httpErr.HttpCode) + assert.Equal(t, "401 Unauthorized", httpErr.HttpStatus) + assert.Equal(t, "websocket: bad handshake", httpErr.Message) + assert.True(t, strings.Contains(err.Error(), "http status:")) // Add basic auth wsClient.SetBasicAuth(authUsername, "invalidPassword") // Test connection @@ -353,8 +410,7 @@ func TestInvalidBasicAuth(t *testing.T) { } func TestInvalidOriginHeader(t *testing.T) { - var wsServer *Server - wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) { + wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) { assert.Fail(t, "no message should be received from client!") return nil, nil }) @@ -386,10 +442,9 @@ func TestInvalidOriginHeader(t *testing.T) { } func TestCustomOriginHeaderHandler(t *testing.T) { - var wsServer *Server origin := "example.org" connected := make(chan bool) - wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) { + wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) { assert.Fail(t, "no message should be received from client!") return nil, nil }) @@ -519,7 +574,7 @@ func TestInvalidClientTLSCertificate(t *testing.T) { }) // Run server go wsServer.Start(serverPort, serverPath) - time.Sleep(1 * time.Second) + time.Sleep(200 * time.Millisecond) // Create TLS client certPool = x509.NewCertPool() @@ -546,9 +601,8 @@ func TestInvalidClientTLSCertificate(t *testing.T) { } func TestUnsupportedSubprotocol(t *testing.T) { - var wsServer *Server disconnected := make(chan bool) - wsServer = newWebsocketServer(t, nil) + wsServer := newWebsocketServer(t, nil) wsServer.SetNewClientHandler(func(ws Channel) { assert.Fail(t, "invalid subprotocol expected, but hit client handler instead") t.Fail() @@ -710,6 +764,9 @@ func TestServerErrors(t *testing.T) { require.NoError(t, err) r, _ = <-triggerC assert.True(t, r) + // Send message to non-existing client + err = wsServer.Write("fakeId", []byte("dummy response")) + require.Error(t, err) // Send unexpected close message and wait for error to be thrown err = wsClient.webSocket.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseUnsupportedData, "")) assert.NoError(t, err)