Skip to content

Commit

Permalink
backend: Fix panic of websockets
Browse files Browse the repository at this point in the history
Now websocket has clear type that is needs to sends. This also fixes
panic of websocket in various edge cases.

Signed-off-by: Kautilya Tripathi <[email protected]>
  • Loading branch information
knrt10 committed Nov 29, 2024
1 parent 0cf0c99 commit d72ca38
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 106 deletions.
278 changes: 178 additions & 100 deletions backend/cmd/multiplexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ type Connection struct {
Done chan struct{}
// mu is a mutex to synchronize access to the connection.
mu sync.RWMutex
// writeMu is a mutex to synchronize access to the write operations.
writeMu sync.Mutex
// closed is a flag to indicate if the connection is closed.
closed bool
}

// Message represents a WebSocket message structure.
Expand All @@ -81,7 +85,9 @@ type Message struct {
// UserID is the ID of the user.
UserID string `json:"userId"`
// Data contains the message payload.
Data []byte `json:"data,omitempty"`
Data string `json:"data,omitempty"`
// Binary is a flag to indicate if the message is binary.
Binary bool `json:"binary,omitempty"`
// Type is the type of the message.
Type string `json:"type"`
}
Expand Down Expand Up @@ -116,41 +122,58 @@ func (c *Connection) updateStatus(state ConnectionState, err error) {
c.mu.Lock()
defer c.mu.Unlock()

if c.closed {
return
}

c.Status.State = state
c.Status.LastMsg = time.Now()
c.Status.Error = ""

if err != nil {
c.Status.Error = err.Error()
} else {
c.Status.Error = ""
}

if c.Client != nil {
statusData := struct {
State string `json:"state"`
Error string `json:"error"`
}{
State: string(state),
Error: c.Status.Error,
}
if c.Client == nil {
return
}

jsonData, jsonErr := json.Marshal(statusData)
if jsonErr != nil {
logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, jsonErr, "marshaling status message")
c.writeMu.Lock()
defer c.writeMu.Unlock()

return
}
// Check if connection is closed before writing
if c.closed {
return
}

statusMsg := Message{
ClusterID: c.ClusterID,
Path: c.Path,
Data: jsonData,
}
statusData := struct {
State string `json:"state"`
Error string `json:"error"`
}{
State: string(state),
Error: c.Status.Error,
}

err := c.Client.WriteJSON(statusMsg)
if err != nil {
jsonData, jsonErr := json.Marshal(statusData)
if jsonErr != nil {
logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, jsonErr, "marshaling status message")

return
}

statusMsg := Message{
ClusterID: c.ClusterID,
Path: c.Path,
Data: string(jsonData),
Type: "STATUS",
}

if err := c.Client.WriteJSON(statusMsg); err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, err, "writing status message to client")
}

c.closed = true
}
}

Expand Down Expand Up @@ -190,7 +213,8 @@ func (m *Multiplexer) establishClusterConnection(
connection.updateStatus(StateConnected, nil)

m.mutex.Lock()
m.connections[clusterID+path] = connection
connKey := fmt.Sprintf("%s:%s:%s", clusterID, path, userID)
m.connections[connKey] = connection
m.mutex.Unlock()

go m.monitorConnection(connection)
Expand Down Expand Up @@ -334,7 +358,7 @@ func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Reque
}

// Check if it's a close message
if msg.Data != nil && len(msg.Data) > 0 && string(msg.Data) == "close" {
if msg.Type == "CLOSE" {
err := m.CloseConnection(msg.ClusterID, msg.Path, msg.UserID)
if err != nil {
logger.Log(
Expand All @@ -355,8 +379,8 @@ func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Reque
continue
}

if len(msg.Data) > 0 && conn.Status.State == StateConnected {
err = m.writeMessageToCluster(conn, msg.Data)
if msg.Type == "REQUEST" && conn.Status.State == StateConnected {
err = m.writeMessageToCluster(conn, []byte(msg.Data))
if err != nil {
continue
}
Expand Down Expand Up @@ -458,100 +482,149 @@ func (m *Multiplexer) writeMessageToCluster(conn *Connection, data []byte) error

// handleClusterMessages handles messages from a cluster connection.
func (m *Multiplexer) handleClusterMessages(conn *Connection, clientConn *websocket.Conn) {
defer func() {
conn.updateStatus(StateClosed, nil)
conn.WSConn.Close()
}()
defer m.cleanupConnection(conn)

var lastResourceVersion string

for {
select {
case <-conn.Done:
return
default:
if err := m.processClusterMessage(conn, clientConn); err != nil {
if err := m.processClusterMessage(conn, clientConn, &lastResourceVersion); err != nil {
return
}
}
}
}

// processClusterMessage processes a message from a cluster connection.
func (m *Multiplexer) processClusterMessage(conn *Connection, clientConn *websocket.Conn) error {
// processClusterMessage processes a single message from the cluster.
func (m *Multiplexer) processClusterMessage(
conn *Connection,
clientConn *websocket.Conn,
lastResourceVersion *string,
) error {
messageType, message, err := conn.WSConn.ReadMessage()
if err != nil {
m.handleReadError(conn, err)
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
logger.Log(logger.LevelError,
map[string]string{
"clusterID": conn.ClusterID,
"userID": conn.UserID,
},
err,
"reading cluster message",
)
}

return err
}

wrapperMsg := m.createWrapperMessage(conn, messageType, message)
if err := m.checkResourceVersion(message, conn, clientConn, lastResourceVersion); err != nil {
return err
}

if err := clientConn.WriteJSON(wrapperMsg); err != nil {
m.handleWriteError(conn, err)
return m.sendDataMessage(conn, clientConn, messageType, message)
}

return err
// checkResourceVersion checks and handles resource version changes.
func (m *Multiplexer) checkResourceVersion(
message []byte,
conn *Connection,
clientConn *websocket.Conn,
lastResourceVersion *string,
) error {
var obj map[string]interface{}
if err := json.Unmarshal(message, &obj); err != nil {
return nil // Ignore unmarshalling errors for resource version check
}

conn.mu.Lock()
conn.Status.LastMsg = time.Now()
conn.mu.Unlock()
if metadata, ok := obj["metadata"].(map[string]interface{}); ok {
if rv, ok := metadata["resourceVersion"].(string); ok {
if *lastResourceVersion != "" && rv != *lastResourceVersion {
return m.sendCompleteMessage(conn, clientConn)
}

*lastResourceVersion = rv
}
}

return nil
}

// createWrapperMessage creates a wrapper message for a cluster connection.
func (m *Multiplexer) createWrapperMessage(conn *Connection, messageType int, message []byte) struct {
ClusterID string `json:"clusterId"`
Path string `json:"path"`
Query string `json:"query"`
UserID string `json:"userId"`
Data string `json:"data"`
Binary bool `json:"binary"`
} {
wrapperMsg := struct {
ClusterID string `json:"clusterId"`
Path string `json:"path"`
Query string `json:"query"`
UserID string `json:"userId"`
Data string `json:"data"`
Binary bool `json:"binary"`
}{
// sendCompleteMessage sends a COMPLETE message to the client.
func (m *Multiplexer) sendCompleteMessage(conn *Connection, clientConn *websocket.Conn) error {
completeMsg := Message{
ClusterID: conn.ClusterID,
Path: conn.Path,
Query: conn.Query,
UserID: conn.UserID,
Binary: messageType == websocket.BinaryMessage,
Type: "COMPLETE",
}

if messageType == websocket.BinaryMessage {
wrapperMsg.Data = base64.StdEncoding.EncodeToString(message)
} else {
wrapperMsg.Data = string(message)
conn.writeMu.Lock()
defer conn.writeMu.Unlock()

return clientConn.WriteJSON(completeMsg)
}

// sendDataMessage sends the actual data message to the client.
func (m *Multiplexer) sendDataMessage(
conn *Connection,
clientConn *websocket.Conn,
messageType int,
message []byte,
) error {
dataMsg := m.createWrapperMessage(conn, messageType, message)

conn.writeMu.Lock()
defer conn.writeMu.Unlock()

if err := clientConn.WriteJSON(dataMsg); err != nil {
return err
}

return wrapperMsg
conn.mu.Lock()
conn.Status.LastMsg = time.Now()
conn.mu.Unlock()

return nil
}

// handleReadError handles errors that occur when reading a message from a cluster connection.
func (m *Multiplexer) handleReadError(conn *Connection, err error) {
conn.updateStatus(StateError, err)
logger.Log(
logger.LevelError,
map[string]string{"clusterID": conn.ClusterID, "UserID": conn.UserID},
err,
"reading message from cluster",
)
// cleanupConnection performs cleanup for a connection.
func (m *Multiplexer) cleanupConnection(conn *Connection) {
conn.mu.Lock()
conn.closed = true
conn.mu.Unlock()

if conn.WSConn != nil {
conn.WSConn.Close()
}

m.mutex.Lock()
connKey := fmt.Sprintf("%s:%s:%s", conn.ClusterID, conn.Path, conn.UserID)
delete(m.connections, connKey)
m.mutex.Unlock()
}

// handleWriteError handles errors that occur when writing a message to a client connection.
func (m *Multiplexer) handleWriteError(conn *Connection, err error) {
conn.updateStatus(StateError, err)
logger.Log(
logger.LevelError,
map[string]string{"clusterID": conn.ClusterID, "UserID": conn.UserID},
err,
"writing message to client",
)
// createWrapperMessage creates a wrapper message for a cluster connection.
func (m *Multiplexer) createWrapperMessage(conn *Connection, messageType int, message []byte) Message {
var data string
if messageType == websocket.BinaryMessage {
data = base64.StdEncoding.EncodeToString(message)
} else {
data = string(message)
}

return Message{
ClusterID: conn.ClusterID,
Path: conn.Path,
Query: conn.Query,
UserID: conn.UserID,
Data: data,
Binary: messageType == websocket.BinaryMessage,
Type: "DATA",
}
}

// cleanupConnections closes and removes all connections.
Expand Down Expand Up @@ -587,37 +660,42 @@ func (m *Multiplexer) getClusterConfig(clusterID string) (*rest.Config, error) {
}

// CloseConnection closes a specific connection based on its identifier.
//
//nolint:unparam
func (m *Multiplexer) CloseConnection(clusterID, path, userID string) error {
connKey := fmt.Sprintf("%s:%s:%s", clusterID, path, userID)

m.mutex.Lock()
defer m.mutex.Unlock()

conn, exists := m.connections[connKey]
if !exists {
return fmt.Errorf("connection not found for key: %s", connKey)
m.mutex.Unlock()
// Don't log error for non-existent connections during cleanup
return nil
}

// Signal the connection to close
close(conn.Done)
// Mark as closed before releasing the lock
conn.mu.Lock()
if conn.closed {
conn.mu.Unlock()
m.mutex.Unlock()
logger.Log(logger.LevelError, map[string]string{"clusterID": conn.ClusterID}, nil, "closing connection")

// Close the WebSocket connection
if conn.WSConn != nil {
if err := conn.WSConn.Close(); err != nil {
logger.Log(
logger.LevelError,
map[string]string{"clusterID": clusterID, "userID": userID},
err,
"closing WebSocket connection",
)
}
return nil
}

// Update the connection status
conn.updateStatus(StateClosed, nil)
conn.closed = true
conn.mu.Unlock()

// Remove the connection from the map
delete(m.connections, connKey)
m.mutex.Unlock()

// Close the Done channel and connections after removing from map
close(conn.Done)

if conn.WSConn != nil {
conn.WSConn.Close()
}

return nil
}
Expand Down
Loading

0 comments on commit d72ca38

Please sign in to comment.