Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix races in websocket and ocppj #233

Merged
merged 6 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions ocppj/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ type DefaultClientDispatcher struct {
timeout time.Duration
}

const defaultTimeoutTick = 24 * time.Hour
const defaultMessageTimeout = 30 * time.Second
const (
defaultTimeoutTick = 24 * time.Hour
defaultMessageTimeout = 30 * time.Second
)

// NewDefaultClientDispatcher creates a new DefaultClientDispatcher struct.
func NewDefaultClientDispatcher(queue RequestQueue) *DefaultClientDispatcher {
Expand All @@ -121,20 +123,22 @@ func (d *DefaultClientDispatcher) SetTimeout(timeout time.Duration) {
}

func (d *DefaultClientDispatcher) Start() {
d.mutex.Lock()
defer d.mutex.Unlock()
d.requestChannel = make(chan bool, 1)
d.timer = time.NewTimer(defaultTimeoutTick) // Default to 24 hours tick
go d.messagePump()
}

func (d *DefaultClientDispatcher) IsRunning() bool {
d.mutex.Lock()
defer d.mutex.Unlock()
d.mutex.RLock()
defer d.mutex.RUnlock()
return d.requestChannel != nil
}

func (d *DefaultClientDispatcher) IsPaused() bool {
d.mutex.Lock()
defer d.mutex.Unlock()
d.mutex.RLock()
defer d.mutex.RUnlock()
return d.paused
}

Expand All @@ -160,19 +164,30 @@ func (d *DefaultClientDispatcher) SendRequest(req RequestBundle) error {
if err := d.requestQueue.Push(req); err != nil {
return err
}
d.mutex.RLock()
d.requestChannel <- true
d.mutex.RUnlock()
return nil
}

func (d *DefaultClientDispatcher) messagePump() {
rdy := true // Ready to transmit at the beginning

reqChan := func() chan bool {
d.mutex.RLock()
defer d.mutex.RUnlock()
return d.requestChannel
}

for {
select {
case _, ok := <-d.requestChannel:
case _, ok := <-reqChan():
// New request was posted
if !ok {
d.requestQueue.Init()
d.mutex.Lock()
d.requestChannel = nil
d.mutex.Unlock()
return
}
case _, ok := <-d.timer.C:
Expand All @@ -195,14 +210,13 @@ func (d *DefaultClientDispatcher) messagePump() {
case rdy = <-d.readyForDispatch:
// Ready flag set, keep going
}

// Check if dispatcher is paused
d.mutex.Lock()
paused := d.paused
d.mutex.Unlock()
if paused {
if d.IsPaused() {
// Ignore dispatch events as long as dispatcher is paused
continue
}

// Only dispatch request if able to send and request queue isn't empty
if rdy && !d.requestQueue.IsEmpty() {
d.dispatchNextRequest()
Expand All @@ -225,7 +239,7 @@ func (d *DefaultClientDispatcher) dispatchNextRequest() {
// Attempt to send over network
err := d.network.Write(jsonMessage)
if err != nil {
//TODO: handle retransmission instead of skipping request altogether
// TODO: handle retransmission instead of skipping request altogether
d.CompleteRequest(bundle.Call.GetUniqueId())
if d.onRequestCancel != nil {
d.onRequestCancel(bundle.Call.UniqueId, bundle.Call.Payload,
Expand Down Expand Up @@ -422,7 +436,9 @@ func (d *DefaultServerDispatcher) CreateClient(clientID string) {
func (d *DefaultServerDispatcher) DeleteClient(clientID string) {
d.queueMap.Remove(clientID)
if d.IsRunning() {
andig marked this conversation as resolved.
Show resolved Hide resolved
d.mutex.RLock()
d.requestChannel <- clientID
d.mutex.RUnlock()
}
}

Expand All @@ -449,7 +465,9 @@ func (d *DefaultServerDispatcher) SendRequest(clientID string, req RequestBundle
if err := q.Push(req); err != nil {
return err
}
d.mutex.RLock()
d.requestChannel <- clientID
d.mutex.RUnlock()
return nil
}

Expand All @@ -462,6 +480,13 @@ func (d *DefaultServerDispatcher) messagePump() {
var clientCtx clientTimeoutContext
var clientQueue RequestQueue
clientContextMap := map[string]clientTimeoutContext{} // Empty at the beginning

reqChan := func() chan string {
d.mutex.RLock()
defer d.mutex.RUnlock()
return d.requestChannel
}

// Dispatcher Loop
for {
select {
Expand All @@ -470,7 +495,7 @@ func (d *DefaultServerDispatcher) messagePump() {
d.queueMap.Init()
log.Info("stopped processing requests")
return
case clientID = <-d.requestChannel:
case clientID = <-reqChan():
// Check whether there is a request queue for the specified client
clientQueue, ok = d.queueMap.Get(clientID)
if !ok {
Expand Down Expand Up @@ -530,6 +555,7 @@ func (d *DefaultServerDispatcher) messagePump() {
}
log.Debugf("%v ready to transmit again", clientID)
}

// Only dispatch request if able to send and request queue isn't empty
if rdy && clientQueue != nil && !clientQueue.IsEmpty() {
// Send request & set new context
Expand Down Expand Up @@ -559,7 +585,7 @@ func (d *DefaultServerDispatcher) dispatchNextRequest(clientID string) (clientCt
err := d.network.Write(clientID, jsonMessage)
if err != nil {
log.Errorf("error while sending message: %v", err)
//TODO: handle retransmission instead of removing pending request
// TODO: handle retransmission instead of removing pending request
d.CompleteRequest(clientID, callID)
if d.onRequestCancel != nil {
d.onRequestCancel(clientID, bundle.Call.UniqueId, bundle.Call.Payload,
Expand Down
14 changes: 8 additions & 6 deletions ws/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,10 @@ func (server *Server) AddHttpHandler(listenPath string, handler func(w http.Resp
}

func (server *Server) Start(port int, listenPath string) {

server.connMutex.Lock()
server.connections = make(map[string]*WebSocket)
server.connMutex.Unlock()

if server.httpServer == nil {
server.httpServer = &http.Server{}
}
Expand Down Expand Up @@ -440,16 +442,16 @@ func (server *Server) StopConnection(id string, closeError websocket.CloseError)
}

func (server *Server) stopConnections() {
server.connMutex.Lock()
defer server.connMutex.Unlock()
server.connMutex.RLock()
defer server.connMutex.RUnlock()
for _, conn := range server.connections {
conn.closeC <- websocket.CloseError{Code: websocket.CloseNormalClosure, Text: ""}
}
}

func (server *Server) Write(webSocketId string, data []byte) error {
server.connMutex.Lock()
defer server.connMutex.Unlock()
server.connMutex.RLock()
defer server.connMutex.RUnlock()
ws, ok := server.connections[webSocketId]
if !ok {
return fmt.Errorf("couldn't write to websocket. No socket with id %v is open", webSocketId)
Expand Down Expand Up @@ -1082,7 +1084,7 @@ func (client *Client) Start(urlStr string) error {
log.Infof("connected to server as %s", id)
client.reconnectC = make(chan struct{})
client.setConnected(true)
//Start reader and write routine
// Start reader and write routine
go client.writePump()
go client.readPump()
return nil
Expand Down