diff --git a/pkg/auditserver/server_test.go b/pkg/auditserver/server_test.go index 6d1db5f..3d95d19 100644 --- a/pkg/auditserver/server_test.go +++ b/pkg/auditserver/server_test.go @@ -3,12 +3,15 @@ package auditserver import ( "bytes" "encoding/json" - "github.com/panjf2000/gnet" - "github.com/redis/go-redis/v9" - "log/slog" + "io/ioutil" "net" "os" "testing" + "time" + + "github.com/expr-lang/expr" + "github.com/spf13/viper" + "log/slog" ) // mockConn is a mock implementation of gnet.Conn @@ -37,194 +40,216 @@ func (m *mockConn) Peek(n int) (buf []byte, err error) { return nil, nil } func (m *mockConn) Next(n int) (buf []byte, err error) { return nil, nil } func TestAuditServer_React(t *testing.T) { - tests := []struct { - name string - input AuditLog - expectedAction gnet.Action - expectedLog bool - }{ + // Create a temporary directory for log files + tempDir := t.TempDir() + + // Define rule group configurations + ruleGroupConfigs := []RuleGroupConfig{ { - name: "Valid KV update operation", - input: AuditLog{ - Type: "audit", - Time: "2023-07-31T12:34:56Z", - Auth: Auth{ - PolicyResults: struct { - Allowed bool `json:"allowed"` - }{Allowed: true}, - }, - Request: Request{ - Operation: "update", - MountType: "kv", - Path: "/secret/data/test", - }, - Response: Response{ - MountType: "kv", - }, - RemoteAddr: "192.168.1.1", + Name: "normal_operations", + Rules: []string{ + `Request.Operation in ["read", "update"] && Request.Path startsWith "secret/data/" && Auth.PolicyResults.Allowed == true`, }, - expectedAction: gnet.None, - expectedLog: true, - }, - { - name: "Valid KV create operation", - input: AuditLog{ - Type: "audit", - Time: "2023-07-31T12:34:56Z", - Auth: Auth{ - PolicyResults: struct { - Allowed bool `json:"allowed"` - }{Allowed: true}, - }, - Request: Request{ - Operation: "create", - MountType: "kv", - Path: "/secret/data/test", - }, - Response: Response{ - MountType: "kv", - }, - RemoteAddr: "192.168.1.1", + LogFile: LogFileConfig{ + FilePath: tempDir + "/normal_operations.log", + MaxSize: 1, + MaxBackups: 1, + MaxAge: 1, + Compress: false, }, - expectedAction: gnet.None, - expectedLog: true, }, { - name: "Valid KV delete operation", - input: AuditLog{ - Type: "audit", - Time: "2023-07-31T12:34:56Z", - Auth: Auth{ - PolicyResults: struct { - Allowed bool `json:"allowed"` - }{Allowed: true}, - }, - Request: Request{ - Operation: "delete", - MountType: "kv", - Path: "/secret/data/test", - }, - Response: Response{ - MountType: "kv", - }, - RemoteAddr: "192.168.1.1", + Name: "critical_events", + Rules: []string{ + `Request.Operation == "delete" && Auth.PolicyResults.Allowed == true`, + `Request.Path startsWith "secret/metadata/" && Auth.PolicyResults.Allowed == true`, }, - expectedAction: gnet.None, - expectedLog: true, - }, - { - name: "Non-KV operation", - input: AuditLog{ - Type: "audit", - Time: "2023-07-31T12:34:56Z", - Auth: Auth{ - PolicyResults: struct { - Allowed bool `json:"allowed"` - }{Allowed: true}, - }, - Request: Request{ - Operation: "update", - MountType: "transit", - Path: "/transit/keys/test", - }, - Response: Response{ - MountType: "transit", - }, - RemoteAddr: "192.168.1.1", + LogFile: LogFileConfig{ + FilePath: tempDir + "/critical_events.log", + MaxSize: 1, + MaxBackups: 1, + MaxAge: 1, + Compress: false, }, - expectedAction: gnet.Close, - expectedLog: false, }, + } + + // Initialize viper with the rule group configurations + viper.Set("rule_groups", ruleGroupConfigs) + + // Create the AuditServer + logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})) + as := New(logger) + + tests := []struct { + name string + input AuditLog + expectedLogs map[string]bool // Map of log file names to whether they should contain the log + }{ { - name: "Disallowed operation", + name: "Normal operation - update", input: AuditLog{ - Type: "audit", - Time: "2023-07-31T12:34:56Z", + Type: "request", + Time: "2024-09-17T13:00:00Z", Auth: Auth{ + DisplayName: "user1", + Policies: []string{"default", "writer"}, PolicyResults: struct { - Allowed bool `json:"allowed"` - }{Allowed: false}, + Allowed bool `json:"allowed"` + GrantingPolicies []struct { + Name string `json:"name"` + NamespaceID string `json:"namespace_id"` + Type string `json:"type"` + } `json:"granting_policies"` + }{ + Allowed: true, + }, }, Request: Request{ Operation: "update", - MountType: "kv", - Path: "/secret/data/test", + Path: "secret/data/myapp/config", }, - Response: Response{ - MountType: "kv", - }, - RemoteAddr: "192.168.1.1", }, - expectedAction: gnet.Close, - expectedLog: false, + expectedLogs: map[string]bool{ + tempDir + "/normal_operations.log": true, + tempDir + "/critical_events.log": false, + }, }, + // Add more test cases as needed } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var logBuffer bytes.Buffer - logger := slog.New(slog.NewJSONHandler(&logBuffer, &slog.HandlerOptions{Level: slog.LevelInfo})) - - server := New(logger, nil) - - inputJSON, err := json.Marshal(tt.input) + // Serialize the audit log to JSON + frame, err := json.Marshal(tt.input) if err != nil { - t.Fatalf("Failed to marshal input: %v", err) + t.Fatalf("Failed to marshal audit log: %v", err) } - _, action := server.React(inputJSON, &mockConn{}) - - t.Logf("Test case: %s", tt.name) - t.Logf("Input: %s", string(inputJSON)) - t.Logf("Action: %v", action) - t.Logf("Log buffer: %s", logBuffer.String()) + // Call React + as.React(frame, &mockConn{}) - if action != tt.expectedAction { - t.Errorf("Expected action %v, but got %v", tt.expectedAction, action) - } + // Give some time for the log to be written + time.Sleep(100 * time.Millisecond) - if tt.expectedLog { - if logBuffer.Len() == 0 { - t.Errorf("Expected log output, but got none") - } else { - var logEntry map[string]interface{} - err := json.Unmarshal(logBuffer.Bytes(), &logEntry) - if err != nil { - t.Fatalf("Failed to parse log output: %v", err) + // Check log files + for logFile, shouldContain := range tt.expectedLogs { + content, err := ioutil.ReadFile(logFile) + if err != nil { + if os.IsNotExist(err) && !shouldContain { + // File doesn't exist as expected + continue } + t.Fatalf("Failed to read log file '%s': %v", logFile, err) + } - expectedFields := []string{"operation", "path"} - for _, field := range expectedFields { - if _, ok := logEntry[field]; !ok { - t.Errorf("Expected '%s' field in log, but it was missing", field) - } + if shouldContain { + if !bytes.Contains(content, frame) { + t.Errorf("Expected log file '%s' to contain the audit log", logFile) + } + } else { + if len(content) > 0 { + t.Errorf("Expected log file '%s' to be empty", logFile) } } - } else if logBuffer.Len() > 0 { - t.Errorf("Expected no log output, but got: %s", logBuffer.String()) + } + + // Clean up log files for next test + for logFile := range tt.expectedLogs { + os.Remove(logFile) } }) } } func TestNew(t *testing.T) { - // Test with nil logger and nil publisher - server := New(nil, nil) + // Define rule group configurations + ruleGroupConfigs := []RuleGroupConfig{ + { + Name: "test_group", + Rules: []string{ + `Request.Operation == "update"`, + }, + LogFile: LogFileConfig{ + FilePath: "test.log", + MaxSize: 1, + MaxBackups: 1, + MaxAge: 1, + Compress: false, + }, + }, + } + + // Initialize viper with the rule group configurations + viper.Set("rule_groups", ruleGroupConfigs) + + // Test with nil logger + server := New(nil) if server.logger == nil { t.Errorf("Expected non-nil logger when initialized with nil") } - if server.publisher == nil { - t.Errorf("Expected non-nil publisher when initialized with nil") + + if len(server.ruleGroups) != len(ruleGroupConfigs) { + t.Errorf("Expected %d rule groups, got %d", len(ruleGroupConfigs), len(server.ruleGroups)) } - // Test with custom logger and publisher + // Test with custom logger customLogger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) - customPublisher := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) - server = New(customLogger, customPublisher) + server = New(customLogger) if server.logger != customLogger { t.Errorf("Expected custom logger to be used") } - if server.publisher != customPublisher { - t.Errorf("Expected custom publisher to be used") +} + +func TestRuleGroup_shouldLog(t *testing.T) { + // Define a sample audit log + auditLog := &AuditLog{ + Type: "request", + Time: "2024-09-17T13:00:00Z", + Auth: Auth{ + DisplayName: "user1", + Policies: []string{"default", "writer"}, + PolicyResults: struct { + Allowed bool `json:"allowed"` + GrantingPolicies []struct { + Name string `json:"name"` + NamespaceID string `json:"namespace_id"` + Type string `json:"type"` + } `json:"granting_policies"` + }{ + Allowed: true, + }, + }, + Request: Request{ + Operation: "update", + Path: "secret/data/myapp/config", + }, + } + + // Compile a rule + ruleStr := `Request.Operation == "update" && Request.Path startsWith "secret/data/" && Auth.PolicyResults.Allowed == true` + program, err := expr.Compile(ruleStr, expr.Env(&AuditLog{})) + if err != nil { + t.Fatalf("Failed to compile rule: %v", err) + } + + // Create a RuleGroup + rg := &RuleGroup{ + Name: "test_group", + CompiledRules: []CompiledRule{ + {Program: program}, + }, + } + + // Test shouldLog + if !rg.shouldLog(auditLog) { + t.Errorf("Expected shouldLog to return true, got false") + } + + // Modify audit log to not match + auditLog.Request.Operation = "read" + + if rg.shouldLog(auditLog) { + t.Errorf("Expected shouldLog to return false, got true") } }