diff --git a/backend/cmd/multiplexer.go b/backend/cmd/multiplexer.go index 54c6fe51e9..a3f1f326dc 100644 --- a/backend/cmd/multiplexer.go +++ b/backend/cmd/multiplexer.go @@ -68,6 +68,10 @@ type Connection struct { Done chan struct{} // mu is a mutex to synchronize access to the connection. mu sync.RWMutex + // writeMu is a mutex to synchronize access to the write operations. + writeMu sync.Mutex + // closed is a flag to indicate if the connection is closed. + closed bool } // Message represents a WebSocket message structure. @@ -81,7 +85,9 @@ type Message struct { // UserID is the ID of the user. UserID string `json:"userId"` // Data contains the message payload. - Data []byte `json:"data,omitempty"` + Data string `json:"data,omitempty"` + // Binary is a flag to indicate if the message is binary. + Binary bool `json:"binary,omitempty"` // Type is the type of the message. Type string `json:"type"` } @@ -116,41 +122,58 @@ func (c *Connection) updateStatus(state ConnectionState, err error) { c.mu.Lock() defer c.mu.Unlock() + if c.closed { + return + } + c.Status.State = state c.Status.LastMsg = time.Now() + c.Status.Error = "" if err != nil { c.Status.Error = err.Error() - } else { - c.Status.Error = "" } - if c.Client != nil { - statusData := struct { - State string `json:"state"` - Error string `json:"error"` - }{ - State: string(state), - Error: c.Status.Error, - } + if c.Client == nil { + return + } - jsonData, jsonErr := json.Marshal(statusData) - if jsonErr != nil { - logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, jsonErr, "marshaling status message") + c.writeMu.Lock() + defer c.writeMu.Unlock() - return - } + // Check if connection is closed before writing + if c.closed { + return + } - statusMsg := Message{ - ClusterID: c.ClusterID, - Path: c.Path, - Data: jsonData, - } + statusData := struct { + State string `json:"state"` + Error string `json:"error"` + }{ + State: string(state), + Error: c.Status.Error, + } - err := c.Client.WriteJSON(statusMsg) - if err != nil { + jsonData, jsonErr := json.Marshal(statusData) + if jsonErr != nil { + logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, jsonErr, "marshaling status message") + + return + } + + statusMsg := Message{ + ClusterID: c.ClusterID, + Path: c.Path, + Data: string(jsonData), + Type: "STATUS", + } + + if err := c.Client.WriteJSON(statusMsg); err != nil { + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, err, "writing status message to client") } + + c.closed = true } } @@ -190,7 +213,8 @@ func (m *Multiplexer) establishClusterConnection( connection.updateStatus(StateConnected, nil) m.mutex.Lock() - m.connections[clusterID+path] = connection + connKey := fmt.Sprintf("%s:%s:%s", clusterID, path, userID) + m.connections[connKey] = connection m.mutex.Unlock() go m.monitorConnection(connection) @@ -293,6 +317,10 @@ func (m *Multiplexer) monitorConnection(conn *Connection) { // reconnect attempts to reestablish a connection. func (m *Multiplexer) reconnect(conn *Connection) (*Connection, error) { + if conn.closed { + return nil, fmt.Errorf("cannot reconnect closed connection") + } + if conn.WSConn != nil { conn.WSConn.Close() } @@ -334,16 +362,8 @@ func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Reque } // Check if it's a close message - if msg.Data != nil && len(msg.Data) > 0 && string(msg.Data) == "close" { - err := m.CloseConnection(msg.ClusterID, msg.Path, msg.UserID) - if err != nil { - logger.Log( - logger.LevelError, - map[string]string{"clusterID": msg.ClusterID, "UserID": msg.UserID}, - err, - "closing connection", - ) - } + if msg.Type == "CLOSE" { + m.CloseConnection(msg.ClusterID, msg.Path, msg.UserID) continue } @@ -355,8 +375,8 @@ func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Reque continue } - if len(msg.Data) > 0 && conn.Status.State == StateConnected { - err = m.writeMessageToCluster(conn, msg.Data) + if msg.Type == "REQUEST" && conn.Status.State == StateConnected { + err = m.writeMessageToCluster(conn, []byte(msg.Data)) if err != nil { continue } @@ -458,100 +478,166 @@ func (m *Multiplexer) writeMessageToCluster(conn *Connection, data []byte) error // handleClusterMessages handles messages from a cluster connection. func (m *Multiplexer) handleClusterMessages(conn *Connection, clientConn *websocket.Conn) { - defer func() { - conn.updateStatus(StateClosed, nil) - conn.WSConn.Close() - }() + defer m.cleanupConnection(conn) + + var lastResourceVersion string for { select { case <-conn.Done: return default: - if err := m.processClusterMessage(conn, clientConn); err != nil { + if err := m.processClusterMessage(conn, clientConn, &lastResourceVersion); err != nil { return } } } } -// processClusterMessage processes a message from a cluster connection. -func (m *Multiplexer) processClusterMessage(conn *Connection, clientConn *websocket.Conn) error { +// processClusterMessage processes a single message from the cluster. +func (m *Multiplexer) processClusterMessage( + conn *Connection, + clientConn *websocket.Conn, + lastResourceVersion *string, +) error { messageType, message, err := conn.WSConn.ReadMessage() if err != nil { - m.handleReadError(conn, err) + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + logger.Log(logger.LevelError, + map[string]string{ + "clusterID": conn.ClusterID, + "userID": conn.UserID, + }, + err, + "reading cluster message", + ) + } return err } - wrapperMsg := m.createWrapperMessage(conn, messageType, message) + if err := m.sendIfNewResourceVersion(message, conn, clientConn, lastResourceVersion); err != nil { + return err + } - if err := clientConn.WriteJSON(wrapperMsg); err != nil { - m.handleWriteError(conn, err) + return m.sendDataMessage(conn, clientConn, messageType, message) +} - return err +// sendIfNewResourceVersion checks the version of a resource from an incoming message +// and sends a complete message to the client if the resource version has changed. +// +// This function is used to ensure that the client is always aware of the latest version +// of a resource. When a new message is received, it extracts the resource version from +// the message metadata. If the resource version has changed since the last known version, +// it sends a complete message to the client to update them with the latest resource state. +// Parameters: +// - message: The JSON-encoded message containing resource information. +// - conn: The connection object representing the current connection. +// - clientConn: The WebSocket connection to the client. +// - lastResourceVersion: A pointer to the last known resource version string. +// +// Returns: +// - An error if any issues occur while processing the message, or nil if successful. +func (m *Multiplexer) sendIfNewResourceVersion( + message []byte, + conn *Connection, + clientConn *websocket.Conn, + lastResourceVersion *string, +) error { + var obj map[string]interface{} + if err := json.Unmarshal(message, &obj); err != nil { + // Ignore unmarshalling errors for resource version check. + // The message format may vary and we only care about valid resource versions. + return nil } - conn.mu.Lock() - conn.Status.LastMsg = time.Now() - conn.mu.Unlock() + if metadata, ok := obj["metadata"].(map[string]interface{}); ok { + if rv, ok := metadata["resourceVersion"].(string); ok { + if *lastResourceVersion != "" && rv != *lastResourceVersion { + return m.sendCompleteMessage(conn, clientConn) + } + + *lastResourceVersion = rv + } + } return nil } -// createWrapperMessage creates a wrapper message for a cluster connection. -func (m *Multiplexer) createWrapperMessage(conn *Connection, messageType int, message []byte) struct { - ClusterID string `json:"clusterId"` - Path string `json:"path"` - Query string `json:"query"` - UserID string `json:"userId"` - Data string `json:"data"` - Binary bool `json:"binary"` -} { - wrapperMsg := struct { - ClusterID string `json:"clusterId"` - Path string `json:"path"` - Query string `json:"query"` - UserID string `json:"userId"` - Data string `json:"data"` - Binary bool `json:"binary"` - }{ +// sendCompleteMessage sends a COMPLETE message to the client. +func (m *Multiplexer) sendCompleteMessage(conn *Connection, clientConn *websocket.Conn) error { + completeMsg := Message{ ClusterID: conn.ClusterID, Path: conn.Path, Query: conn.Query, UserID: conn.UserID, - Binary: messageType == websocket.BinaryMessage, + Type: "COMPLETE", } - if messageType == websocket.BinaryMessage { - wrapperMsg.Data = base64.StdEncoding.EncodeToString(message) - } else { - wrapperMsg.Data = string(message) + conn.writeMu.Lock() + defer conn.writeMu.Unlock() + + return clientConn.WriteJSON(completeMsg) +} + +// sendDataMessage sends the actual data message to the client. +func (m *Multiplexer) sendDataMessage( + conn *Connection, + clientConn *websocket.Conn, + messageType int, + message []byte, +) error { + dataMsg := m.createWrapperMessage(conn, messageType, message) + + conn.writeMu.Lock() + defer conn.writeMu.Unlock() + + if err := clientConn.WriteJSON(dataMsg); err != nil { + return err } - return wrapperMsg + conn.mu.Lock() + conn.Status.LastMsg = time.Now() + conn.mu.Unlock() + + return nil } -// handleReadError handles errors that occur when reading a message from a cluster connection. -func (m *Multiplexer) handleReadError(conn *Connection, err error) { - conn.updateStatus(StateError, err) - logger.Log( - logger.LevelError, - map[string]string{"clusterID": conn.ClusterID, "UserID": conn.UserID}, - err, - "reading message from cluster", - ) +// cleanupConnection performs cleanup for a connection. +func (m *Multiplexer) cleanupConnection(conn *Connection) { + conn.mu.Lock() + defer conn.mu.Unlock() // Ensure the mutex is unlocked even if an error occurs + + conn.closed = true + + if conn.WSConn != nil { + conn.WSConn.Close() + } + + m.mutex.Lock() + connKey := fmt.Sprintf("%s:%s:%s", conn.ClusterID, conn.Path, conn.UserID) + delete(m.connections, connKey) + m.mutex.Unlock() } -// handleWriteError handles errors that occur when writing a message to a client connection. -func (m *Multiplexer) handleWriteError(conn *Connection, err error) { - conn.updateStatus(StateError, err) - logger.Log( - logger.LevelError, - map[string]string{"clusterID": conn.ClusterID, "UserID": conn.UserID}, - err, - "writing message to client", - ) +// createWrapperMessage creates a wrapper message for a cluster connection. +func (m *Multiplexer) createWrapperMessage(conn *Connection, messageType int, message []byte) Message { + var data string + if messageType == websocket.BinaryMessage { + data = base64.StdEncoding.EncodeToString(message) + } else { + data = string(message) + } + + return Message{ + ClusterID: conn.ClusterID, + Path: conn.Path, + Query: conn.Query, + UserID: conn.UserID, + Data: data, + Binary: messageType == websocket.BinaryMessage, + Type: "DATA", + } } // cleanupConnections closes and removes all connections. @@ -587,39 +673,44 @@ func (m *Multiplexer) getClusterConfig(clusterID string) (*rest.Config, error) { } // CloseConnection closes a specific connection based on its identifier. -func (m *Multiplexer) CloseConnection(clusterID, path, userID string) error { +func (m *Multiplexer) CloseConnection(clusterID, path, userID string) { connKey := fmt.Sprintf("%s:%s:%s", clusterID, path, userID) m.mutex.Lock() - defer m.mutex.Unlock() conn, exists := m.connections[connKey] if !exists { - return fmt.Errorf("connection not found for key: %s", connKey) + m.mutex.Unlock() + // Don't log error for non-existent connections during cleanup + return } - // Signal the connection to close - close(conn.Done) + // Mark as closed before releasing the lock + conn.mu.Lock() + if conn.closed { + conn.mu.Unlock() + m.mutex.Unlock() + logger.Log(logger.LevelError, map[string]string{"clusterID": conn.ClusterID}, nil, "closing connection") - // Close the WebSocket connection - if conn.WSConn != nil { - if err := conn.WSConn.Close(); err != nil { - logger.Log( - logger.LevelError, - map[string]string{"clusterID": clusterID, "userID": userID}, - err, - "closing WebSocket connection", - ) - } + return } - // Update the connection status - conn.updateStatus(StateClosed, nil) + conn.closed = true + conn.mu.Unlock() - // Remove the connection from the map delete(m.connections, connKey) + m.mutex.Unlock() - return nil + // Lock the connection mutex before accessing shared resources + conn.mu.Lock() + defer conn.mu.Unlock() // Ensure the mutex is unlocked after the operations + + // Close the Done channel and connections after removing from map + close(conn.Done) + + if conn.WSConn != nil { + conn.WSConn.Close() + } } // createWebSocketURL creates a WebSocket URL from the given parameters. diff --git a/backend/cmd/multiplexer_test.go b/backend/cmd/multiplexer_test.go index 058e01377b..0dce3bbe47 100644 --- a/backend/cmd/multiplexer_test.go +++ b/backend/cmd/multiplexer_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "time" @@ -22,6 +23,7 @@ func newTestDialer() *websocket.Dialer { return &websocket.Dialer{ NetDial: net.Dial, HandshakeTimeout: 45 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec } } @@ -36,28 +38,46 @@ func TestNewMultiplexer(t *testing.T) { } func TestHandleClientWebSocket(t *testing.T) { - store := kubeconfig.NewContextStore() - m := NewMultiplexer(store) + contextStore := kubeconfig.NewContextStore() + m := NewMultiplexer(contextStore) + // Create test server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m.HandleClientWebSocket(w, r) })) defer server.Close() - url := "ws" + strings.TrimPrefix(server.URL, "http") + // Connect to test server + dialer := websocket.Dialer{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec + } - dialer := newTestDialer() + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") - conn, resp, err := dialer.Dial(url, nil) - if err == nil { - defer conn.Close() - } + ws, _, err := dialer.Dial(wsURL, nil) //nolint:bodyclose + require.NoError(t, err) - if resp != nil && resp.Body != nil { - defer resp.Body.Close() + defer ws.Close() + + // Test WATCH message + watchMsg := Message{ + Type: "WATCH", + ClusterID: "test-cluster", + Path: "/api/v1/pods", + UserID: "test-user", } + err = ws.WriteJSON(watchMsg) + require.NoError(t, err) - assert.NoError(t, err, "Should successfully establish WebSocket connection") + // Test CLOSE message + closeMsg := Message{ + Type: "CLOSE", + ClusterID: "test-cluster", + Path: "/api/v1/pods", + UserID: "test-user", + } + err = ws.WriteJSON(closeMsg) + require.NoError(t, err) } func TestGetClusterConfigWithFallback(t *testing.T) { @@ -129,6 +149,7 @@ func TestDialWebSocket(t *testing.T) { wsURL := "ws" + strings.TrimPrefix(server.URL, "http") conn, err := m.dialWebSocket(wsURL, &tls.Config{InsecureSkipVerify: true}, server.URL) //nolint:gosec + assert.NoError(t, err) assert.NotNil(t, conn) @@ -137,6 +158,23 @@ func TestDialWebSocket(t *testing.T) { } } +func TestDialWebSocket_Errors(t *testing.T) { + contextStore := kubeconfig.NewContextStore() + m := NewMultiplexer(contextStore) + + // Test invalid URL + tlsConfig := &tls.Config{InsecureSkipVerify: true} //nolint:gosec + + ws, err := m.dialWebSocket("invalid-url", tlsConfig, "") + assert.Error(t, err) + assert.Nil(t, ws) + + // Test unreachable URL + ws, err = m.dialWebSocket("ws://localhost:12345", tlsConfig, "") + assert.Error(t, err) + assert.Nil(t, ws) +} + func TestMonitorConnection(t *testing.T) { m := NewMultiplexer(kubeconfig.NewContextStore()) clientConn, _ := createTestWebSocketConnection() @@ -158,6 +196,93 @@ func TestMonitorConnection(t *testing.T) { assert.Equal(t, StateClosed, conn.Status.State) } +func TestUpdateStatus(t *testing.T) { + conn := &Connection{ + Status: ConnectionStatus{}, + Done: make(chan struct{}), + } + + // Test different state transitions + states := []ConnectionState{ + StateConnecting, + StateConnected, + StateClosed, + StateError, + } + + for _, state := range states { + conn.Status.State = state + assert.Equal(t, state, conn.Status.State) + } + + // Test concurrent updates + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + + state := states[i%len(states)] + conn.Status.State = state + }(i) + } + wg.Wait() + + // Verify final state is valid + assert.Contains(t, states, conn.Status.State) +} + +func TestMonitorConnection_Reconnect(t *testing.T) { + contextStore := kubeconfig.NewContextStore() + m := NewMultiplexer(contextStore) + + // Create a server that will accept the connection and then close it + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + ws, err := upgrader.Upgrade(w, r, nil) + require.NoError(t, err) + + defer ws.Close() + + // Keep connection alive briefly + time.Sleep(100 * time.Millisecond) + ws.Close() + })) + defer server.Close() + + conn := &Connection{ + Status: ConnectionStatus{ + State: StateConnecting, + }, + Done: make(chan struct{}), + } + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + tlsConfig := &tls.Config{InsecureSkipVerify: true} //nolint:gosec + + ws, err := m.dialWebSocket(wsURL, tlsConfig, "") + require.NoError(t, err) + + conn.WSConn = ws + + // Start monitoring in a goroutine + go m.monitorConnection(conn) + + // Wait for state transitions + time.Sleep(300 * time.Millisecond) + + // Verify connection status, it should reconnect + assert.Equal(t, StateConnecting, conn.Status.State) + + // Clean up + close(conn.Done) +} + //nolint:funlen func TestHandleClusterMessages(t *testing.T) { m := NewMultiplexer(kubeconfig.NewContextStore()) @@ -225,7 +350,7 @@ func TestHandleClusterMessages(t *testing.T) { t.Fatal("Test timed out") } - assert.Equal(t, StateClosed, conn.Status.State) + assert.Equal(t, StateConnecting, conn.Status.State) } func TestCleanupConnections(t *testing.T) { @@ -245,42 +370,198 @@ func TestCleanupConnections(t *testing.T) { assert.Equal(t, StateClosed, conn.Status.State) } -func createTestWebSocketConnection() (*websocket.Conn, *httptest.Server) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - upgrader := websocket.Upgrader{} - c, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } +func TestCreateWebSocketURL(t *testing.T) { + tests := []struct { + name string + host string + path string + query string + expected string + }{ + { + name: "basic URL without query", + host: "http://localhost:8080", + path: "/api/v1/pods", + query: "", + expected: "wss://localhost:8080/api/v1/pods", + }, + { + name: "URL with query parameters", + host: "https://example.com", + path: "/api/v1/pods", + query: "watch=true", + expected: "wss://example.com/api/v1/pods?watch=true", + }, + { + name: "URL with path and multiple query parameters", + host: "https://k8s.example.com", + path: "/api/v1/namespaces/default/pods", + query: "watch=true&labelSelector=app%3Dnginx", + expected: "wss://k8s.example.com/api/v1/namespaces/default/pods?watch=true&labelSelector=app%3Dnginx", + }, + } - defer c.Close() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := createWebSocketURL(tt.host, tt.path, tt.query) + assert.Equal(t, tt.expected, result) + }) + } +} - for { - mt, message, err := c.ReadMessage() - if err != nil { - break - } +func TestGetOrCreateConnection(t *testing.T) { + store := kubeconfig.NewContextStore() + m := NewMultiplexer(store) - err = c.WriteMessage(mt, message) - if err != nil { - break - } - } - })) + // Create a mock Kubernetes API server + mockServer := createMockKubeAPIServer() + defer mockServer.Close() - wsURL := "ws" + strings.TrimPrefix(server.URL, "http") - dialer := newTestDialer() + // Add a mock cluster config with our test server URL + err := store.AddContext(&kubeconfig.Context{ + Name: "test-cluster", + Cluster: &api.Cluster{ + Server: mockServer.URL, + InsecureSkipTLSVerify: true, + CertificateAuthorityData: nil, + }, + }) + require.NoError(t, err) - ws, resp, err := dialer.Dial(wsURL, nil) - if err != nil { - panic(err) + clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() + + // Test getting a non-existent connection (should create new) + msg := Message{ + ClusterID: "test-cluster", + Path: "/api/v1/pods", + Query: "watch=true", + UserID: "test-user", } - if resp != nil && resp.Body != nil { - defer resp.Body.Close() + conn, err := m.getOrCreateConnection(msg, clientConn) + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.Equal(t, "test-cluster", conn.ClusterID) + assert.Equal(t, "test-user", conn.UserID) + assert.Equal(t, "/api/v1/pods", conn.Path) + assert.Equal(t, "watch=true", conn.Query) + + // Test getting an existing connection + conn2, err := m.getOrCreateConnection(msg, clientConn) + assert.NoError(t, err) + assert.Equal(t, conn, conn2, "Should return the same connection instance") + + // Test with invalid cluster + msg.ClusterID = "non-existent-cluster" + conn3, err := m.getOrCreateConnection(msg, clientConn) + assert.Error(t, err) + assert.Nil(t, conn3) +} + +func TestEstablishClusterConnection(t *testing.T) { + store := kubeconfig.NewContextStore() + m := NewMultiplexer(store) + + // Create a mock Kubernetes API server + mockServer := createMockKubeAPIServer() + defer mockServer.Close() + + // Add a mock cluster config with our test server URL + err := store.AddContext(&kubeconfig.Context{ + Name: "test-cluster", + Cluster: &api.Cluster{ + Server: mockServer.URL, + InsecureSkipTLSVerify: true, + CertificateAuthorityData: nil, + }, + }) + require.NoError(t, err) + + clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() + + // Test successful connection establishment + conn, err := m.establishClusterConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.Equal(t, "test-cluster", conn.ClusterID) + assert.Equal(t, "test-user", conn.UserID) + assert.Equal(t, "/api/v1/pods", conn.Path) + assert.Equal(t, "watch=true", conn.Query) + + // Test with invalid cluster + conn, err = m.establishClusterConnection("non-existent", "test-user", "/api/v1/pods", "watch=true", clientConn) + assert.Error(t, err) + assert.Nil(t, conn) +} + +//nolint:funlen +func TestReconnect(t *testing.T) { + store := kubeconfig.NewContextStore() + m := NewMultiplexer(store) + + // Create a mock Kubernetes API server + mockServer := createMockKubeAPIServer() + defer mockServer.Close() + + // Add a mock cluster config with our test server URL + err := store.AddContext(&kubeconfig.Context{ + Name: "test-cluster", + Cluster: &api.Cluster{ + Server: mockServer.URL, + InsecureSkipTLSVerify: true, + CertificateAuthorityData: nil, + }, + }) + require.NoError(t, err) + + clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() + + // Create initial connection + conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) + conn.Status.State = StateError // Simulate an error state + + // Test successful reconnection + newConn, err := m.reconnect(conn) + assert.NoError(t, err) + assert.NotNil(t, newConn) + assert.Equal(t, StateConnected, newConn.Status.State) + assert.Equal(t, conn.ClusterID, newConn.ClusterID) + assert.Equal(t, conn.UserID, newConn.UserID) + assert.Equal(t, conn.Path, newConn.Path) + assert.Equal(t, conn.Query, newConn.Query) + + // Test reconnection with invalid cluster + conn.ClusterID = "non-existent" + newConn, err = m.reconnect(conn) + assert.Error(t, err) + assert.Nil(t, newConn) + assert.Contains(t, err.Error(), "getting context: key not found") + + // Test reconnection with closed connection + conn = m.createConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) + clusterConn, err := m.establishClusterConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) + require.NoError(t, err) + require.NotNil(t, clusterConn) + + // Close the connection and wait for cleanup + conn.closed = true + if conn.WSConn != nil { + conn.WSConn.Close() } - return ws, server + if conn.Client != nil { + conn.Client.Close() + } + + close(conn.Done) + + // Try to reconnect the closed connection + newConn, err = m.reconnect(conn) + assert.Error(t, err) + assert.Nil(t, newConn) } func TestCloseConnection(t *testing.T) { @@ -292,14 +573,10 @@ func TestCloseConnection(t *testing.T) { connKey := "test-cluster:/api/v1/pods:test-user" m.connections[connKey] = conn - err := m.CloseConnection("test-cluster", "/api/v1/pods", "test-user") - assert.NoError(t, err) + m.CloseConnection("test-cluster", "/api/v1/pods", "test-user") assert.Empty(t, m.connections) - assert.Equal(t, StateClosed, conn.Status.State) - - // Test closing a non-existent connection - err = m.CloseConnection("non-existent", "/api/v1/pods", "test-user") - assert.Error(t, err) + // It will reconnect to the cluster + assert.Equal(t, StateConnecting, conn.Status.State) } func TestCreateWrapperMessage(t *testing.T) { @@ -424,3 +701,243 @@ func TestWriteMessageToCluster(t *testing.T) { assert.Error(t, err) assert.Equal(t, StateError, conn.Status.State) } + +func TestSendCompleteMessage(t *testing.T) { + contextStore := kubeconfig.NewContextStore() + m := NewMultiplexer(contextStore) + + // Create a server that will forcibly close the connection + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + + ws, err := upgrader.Upgrade(w, r, nil) + require.NoError(t, err) + + defer ws.Close() + + // Read one message then close + _, _, _ = ws.ReadMessage() + err = ws.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), + time.Now().Add(time.Second)) + + ws.Close() + require.NoError(t, err) + })) + + defer server.Close() + + // Connect to the server + dialer := websocket.Dialer{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec + } + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + clientConn, _, err := dialer.Dial(wsURL, nil) //nolint:bodyclose + require.NoError(t, err) + + defer clientConn.Close() + + conn := &Connection{ + Status: ConnectionStatus{}, + Done: make(chan struct{}), + WSConn: clientConn, + } + + // Test sending complete message + message := []byte(`{"type":"ADDED","object":{"metadata":{"resourceVersion":"123"}}}`) + resourceVersion := "" + + // send should succeed + err = m.sendIfNewResourceVersion(message, conn, clientConn, &resourceVersion) + require.NoError(t, err) + + // Wait for server to close connection + time.Sleep(100 * time.Millisecond) + + close(conn.Done) +} + +//nolint:funlen +func TestReadClientMessage_InvalidMessage(t *testing.T) { + contextStore := kubeconfig.NewContextStore() + m := NewMultiplexer(contextStore) + + // Create a server that will echo messages back + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + ws, err := upgrader.Upgrade(w, r, nil) + require.NoError(t, err) + defer ws.Close() + + // Echo messages back + for { + messageType, p, err := ws.ReadMessage() + if err != nil { + return + } + err = ws.WriteMessage(messageType, p) + if err != nil { + return + } + } + })) + defer server.Close() + + // Connect to the server + dialer := websocket.Dialer{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec + } + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + clientConn, _, err := dialer.Dial(wsURL, nil) //nolint:bodyclose + require.NoError(t, err) + + defer clientConn.Close() + + // Test completely invalid JSON + err = clientConn.WriteMessage(websocket.TextMessage, []byte("not json at all")) + require.NoError(t, err) + + msg, err := m.readClientMessage(clientConn) + require.Error(t, err) + assert.Equal(t, Message{}, msg) + + // Test JSON with invalid data type + err = clientConn.WriteJSON(map[string]interface{}{ + "type": "INVALID", + "data": 123, // data should be string + }) + require.NoError(t, err) + + msg, err = m.readClientMessage(clientConn) + require.Error(t, err) + assert.Equal(t, Message{}, msg) + + // Test empty JSON object + err = clientConn.WriteMessage(websocket.TextMessage, []byte("{}")) + require.NoError(t, err) + + msg, err = m.readClientMessage(clientConn) + // Empty message is valid JSON but will be unmarshaled into an empty Message struct + require.NoError(t, err) + assert.Equal(t, Message{}, msg) + + // Test missing required fields + err = clientConn.WriteJSON(map[string]interface{}{ + "data": "some data", + // Missing type field + }) + require.NoError(t, err) + + msg, err = m.readClientMessage(clientConn) + // Missing fields are allowed by json.Unmarshal + require.NoError(t, err) + assert.Equal(t, Message{Data: "some data"}, msg) +} + +// createMockKubeAPIServer creates a test TLS server that simulates a Kubernetes API server +// with WebSocket support. The server accepts WebSocket connections and echoes back any +// messages it receives. +// +// The server is configured with a self-signed TLS certificate and is set to accept +// insecure connections for testing purposes. This matches the behavior of a real +// Kubernetes API server but with simplified functionality. +// +// Returns: +// - *httptest.Server: A running test server that can be used to simulate +// Kubernetes API WebSocket connections. The caller is responsible for calling +// Close() when done. +func createMockKubeAPIServer() *httptest.Server { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + defer c.Close() + + // Echo messages back + for { + _, msg, err := c.ReadMessage() + if err != nil { + break + } + if err := c.WriteMessage(websocket.TextMessage, msg); err != nil { + break + } + } + })) + + // Configure the test client to accept the test server's TLS certificate + server.Client().Transport.(*http.Transport).TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, //nolint:gosec + } + + return server +} + +// createTestWebSocketConnection creates a test WebSocket server and establishes +// a connection to it. This helper function is used for testing WebSocket-related +// functionality without requiring a real server. +// +// The function sets up an echo server that reflects back any messages it receives +// and establishes a WebSocket connection to it using a test dialer. The connection +// is configured for testing with default options. +// +// Returns: +// - *websocket.Conn: An established WebSocket connection to the test server +// - *httptest.Server: The test server instance. The caller must call Close() +// on both the connection and server when done. +// +// Panics if the connection cannot be established, which is acceptable for test code. +func createTestWebSocketConnection() (*websocket.Conn, *httptest.Server) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + defer c.Close() + + for { + mt, message, err := c.ReadMessage() + if err != nil { + break + } + + err = c.WriteMessage(mt, message) + if err != nil { + break + } + } + })) + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + dialer := newTestDialer() + + ws, resp, err := dialer.Dial(wsURL, nil) + if err != nil { + panic(err) + } + + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + return ws, server +}