diff --git a/pkg/cmd/flags/policy.go b/pkg/cmd/flags/policy.go index 753d6b738918..4eea38e5a760 100644 --- a/pkg/cmd/flags/policy.go +++ b/pkg/cmd/flags/policy.go @@ -103,278 +103,246 @@ func PrepareFilterMapsFromPolicies(policies []k8s.PolicyInterface) (PolicyScopeM // CreatePolicies creates a Policies object from the scope and events maps. func CreatePolicies(policyScopeMap PolicyScopeMap, policyEventsMap PolicyEventMap, newBinary bool) ([]*policy.Policy, error) { - eventNamesToID := events.Core.NamesToIDs() - // remove internal events since they shouldn't be accessible by users - for event, id := range eventNamesToID { - if events.Core.GetDefinitionByID(id).IsInternal() { - delete(eventNamesToID, event) + policies := make([]*policy.Policy, 0, len(policyScopeMap)) + + for policyIdx, policyScope := range policyScopeMap { + policyEvents, ok := policyEventsMap[policyIdx] + if !ok { + return nil, InvalidFlagEmpty() + } + + pol, err := createSinglePolicy(policyIdx, policyScope, policyEvents, newBinary) + if err != nil { + return nil, err } + policies = append(policies, pol) } - policies := make([]*policy.Policy, 0, len(policyScopeMap)) - for policyIdx, policyScopeFilters := range policyScopeMap { - p := policy.NewPolicy() - p.ID = policyIdx - p.Name = policyScopeFilters.policyName - - for _, scopeFlag := range policyScopeFilters.scopeFlags { - // The filters which are more common (container, event, pid, set, uid) can be given using a prefix of them. - // Other filters should be given using their full name. - // To avoid collisions between filters that share the same prefix, put the filters which should have an exact match first! - if scopeFlag.scopeName == "comm" { - err := p.CommFilter.Parse(scopeFlag.operatorAndValues) - if err != nil { - return nil, err - } - continue + return policies, nil +} + +func createSinglePolicy(policyIdx int, policyScope policyScopes, policyEvents policyEvents, newBinary bool) (*policy.Policy, error) { + p := policy.NewPolicy() + p.ID = policyIdx + p.Name = policyScope.policyName + + if err := parseScopeFilters(p, policyScope.scopeFlags, newBinary); err != nil { + return nil, err + } + + if err := parseEventFilters(p, policyEvents.eventFlags); err != nil { + return nil, err + } + + return p, nil +} + +func parseScopeFilters(p *policy.Policy, scopeFlags []scopeFlag, newBinary bool) error { + for _, scopeFlag := range scopeFlags { + switch scopeFlag.scopeName { + case "comm": + if err := p.CommFilter.Parse(scopeFlag.operatorAndValues); err != nil { + return err } - if scopeFlag.scopeName == "exec" || scopeFlag.scopeName == "executable" || - scopeFlag.scopeName == "bin" || scopeFlag.scopeName == "binary" { - // TODO: Rename BinaryFilter to ExecutableFilter - err := p.BinaryFilter.Parse(scopeFlag.operatorAndValues) - if err != nil { - return nil, err - } - continue + case "exec", "executable", "bin", "binary": + if err := p.BinaryFilter.Parse(scopeFlag.operatorAndValues); err != nil { + return err } - if scopeFlag.scopeName == "container" { - if scopeFlag.operator == "not" { - err := p.ContFilter.Parse(scopeFlag.full) - if err != nil { - return nil, err - } - continue + case "container": + switch { + case scopeFlag.operator == "not": + if err := p.ContFilter.Parse(scopeFlag.full); err != nil { + return err } - if scopeFlag.operatorAndValues == "=new" { - err := p.NewContFilter.Parse("new") - if err != nil { - return nil, err - } - continue + case scopeFlag.operatorAndValues == "=new": + if err := p.NewContFilter.Parse("new"); err != nil { + return err } - if scopeFlag.operatorAndValues == "!=new" { - err := p.ContFilter.Parse(scopeFlag.scopeName) - if err != nil { - return nil, err - } - err = p.NewContFilter.Parse("!new") - if err != nil { - return nil, err - } - continue + case scopeFlag.operatorAndValues == "!=new": + if err := p.ContFilter.Parse(scopeFlag.scopeName); err != nil { + return err } - if scopeFlag.operator == "=" { - err := p.ContIDFilter.Parse(scopeFlag.operatorAndValues) - if err != nil { - return nil, err - } - continue + if err := p.NewContFilter.Parse("!new"); err != nil { + return err } - - err := p.ContFilter.Parse(scopeFlag.scopeName) - if err != nil { - return nil, err + case scopeFlag.operator == "=": + if err := p.ContIDFilter.Parse(scopeFlag.operatorAndValues); err != nil { + return err + } + default: + if err := p.ContFilter.Parse(scopeFlag.scopeName); err != nil { + return err } - continue } - if scopeFlag.scopeName == "mntns" { - if strings.ContainsAny(scopeFlag.operator, "<>") { - return nil, filters.InvalidExpression(scopeFlag.operatorAndValues) - } - err := p.MntNSFilter.Parse(scopeFlag.operatorAndValues) - if err != nil { - return nil, err - } - continue + case "mntns": + if strings.ContainsAny(scopeFlag.operator, "<>") { + return filters.InvalidExpression(scopeFlag.operatorAndValues) + } + if err := p.MntNSFilter.Parse(scopeFlag.operatorAndValues); err != nil { + return err } - if scopeFlag.scopeName == "pidns" { - if strings.ContainsAny(scopeFlag.operator, "<>") { - return nil, filters.InvalidExpression(scopeFlag.operatorAndValues) - } - err := p.PidNSFilter.Parse(scopeFlag.operatorAndValues) - if err != nil { - return nil, err - } - continue + case "pidns": + if strings.ContainsAny(scopeFlag.operator, "<>") { + return filters.InvalidExpression(scopeFlag.operatorAndValues) + } + if err := p.PidNSFilter.Parse(scopeFlag.operatorAndValues); err != nil { + return err } - if scopeFlag.scopeName == "tree" { - err := p.ProcessTreeFilter.Parse(scopeFlag.operatorAndValues) - if err != nil { - return nil, err - } - continue + case "tree": + if err := p.ProcessTreeFilter.Parse(scopeFlag.operatorAndValues); err != nil { + return err } - if scopeFlag.scopeName == "pid" { - if scopeFlag.operatorAndValues == "=new" { - if err := p.NewPidFilter.Parse("new"); err != nil { - return nil, err - } - continue + case "pid": + switch scopeFlag.operatorAndValues { + case "=new": + if err := p.NewPidFilter.Parse("new"); err != nil { + return err } - if scopeFlag.operatorAndValues == "!=new" { - if err := p.NewPidFilter.Parse("!new"); err != nil { - return nil, err - } - continue + case "!=new": + if err := p.NewPidFilter.Parse("!new"); err != nil { + return err } - err := p.PIDFilter.Parse(scopeFlag.operatorAndValues) - if err != nil { - return nil, err + default: + if err := p.PIDFilter.Parse(scopeFlag.operatorAndValues); err != nil { + return err } - continue } - if scopeFlag.scopeName == "uts" { - err := p.UTSFilter.Parse(scopeFlag.operatorAndValues) - if err != nil { - return nil, err - } - continue + case "uts": + if err := p.UTSFilter.Parse(scopeFlag.operatorAndValues); err != nil { + return err } - if scopeFlag.scopeName == "uid" { - err := p.UIDFilter.Parse(scopeFlag.operatorAndValues) - if err != nil { - return nil, err - } - continue + case "uid": + if err := p.UIDFilter.Parse(scopeFlag.operatorAndValues); err != nil { + return err } - if scopeFlag.scopeName == "follow" { - p.Follow = true - continue - } + case "follow": + p.Follow = true - return nil, InvalidScopeOptionError(scopeFlag.full, newBinary) + default: + return InvalidScopeOptionError(scopeFlag.full, newBinary) } + } + return nil +} - policyEvents, ok := policyEventsMap[policyIdx] - if !ok { - return nil, InvalidFlagEmpty() +func parseEventFilters(p *policy.Policy, eventFlags []eventFlag) error { + eventNamesToID := events.Core.NamesToIDs() + // remove internal events since they shouldn't be accessible by users + for event, id := range eventNamesToID { + if events.Core.GetDefinitionByID(id).IsInternal() { + delete(eventNamesToID, event) } + } - // map sets to events - setsToEvents := make(map[string][]events.ID) - for _, eventDefinition := range events.Core.GetDefinitions() { - for _, set := range eventDefinition.GetSets() { - setsToEvents[set] = append(setsToEvents[set], eventDefinition.GetID()) - } + // map sets to events + setsToEvents := make(map[string][]events.ID) + for _, eventDefinition := range events.Core.GetDefinitions() { + for _, set := range eventDefinition.GetSets() { + setsToEvents[set] = append(setsToEvents[set], eventDefinition.GetID()) } + } - excludedEvents := make([]string, 0) + excludedEvents := make([]string, 0) - for _, evtFlag := range policyEvents.eventFlags { - if evtFlag.eventOptionType == "" && evtFlag.operator == "-" { - // no event option type means that the flag contains only event names - // save excluded events/sets to be removed from the final events/sets - excludedEvents = append(excludedEvents, evtFlag.eventName) - continue - } + // Process event flags + for _, evtFlag := range eventFlags { + if evtFlag.eventOptionType == "" && evtFlag.operator == "-" { + excludedEvents = append(excludedEvents, evtFlag.eventName) + continue + } - eventIdToName := make(map[events.ID]string) - // handle event prefixes with wildcards - if strings.HasSuffix(evtFlag.eventName, "*") { - found := false - prefix := evtFlag.eventName[:len(evtFlag.eventName)-1] - for event, id := range eventNamesToID { - if strings.HasPrefix(event, prefix) { - eventIdToName[id] = event - found = true - } - } - if !found { - return nil, InvalidEventError(evtFlag.eventName) + eventIdToName := make(map[events.ID]string) + if strings.HasSuffix(evtFlag.eventName, "*") { + found := false + prefix := evtFlag.eventName[:len(evtFlag.eventName)-1] + for event, id := range eventNamesToID { + if strings.HasPrefix(event, prefix) { + eventIdToName[id] = event + found = true } - } else { - id, ok := eventNamesToID[evtFlag.eventName] + } + if !found { + return InvalidEventError(evtFlag.eventName) + } + } else { + id, ok := eventNamesToID[evtFlag.eventName] + if !ok { + // no matching event - maybe it is actually a set? + setEvents, ok := setsToEvents[evtFlag.eventName] if !ok { - // no matching event - maybe it is actually a set? - if setEvents, ok := setsToEvents[evtFlag.eventName]; ok { - // expand set to events - for _, id := range setEvents { - eventIdToName[id] = events.Core.GetDefinitionByID(id).GetName() - } - } else { - return nil, InvalidEventError(evtFlag.eventName) - } - } else { - eventIdToName[id] = evtFlag.eventName + return InvalidEventError(evtFlag.eventName) + } + for _, id := range setEvents { + eventIdToName[id] = events.Core.GetDefinitionByID(id).GetName() } + } else { + eventIdToName[id] = evtFlag.eventName } + } - for eventId, _ := range eventIdToName { - if _, ok := p.Rules[eventId]; !ok { - p.Rules[eventId] = policy.RuleData{ - EventID: eventId, - ScopeFilter: filters.NewScopeFilter(), - DataFilter: filters.NewDataFilter(), - RetFilter: filters.NewIntFilter(), - } + for eventId := range eventIdToName { + if _, ok := p.Rules[eventId]; !ok { + p.Rules[eventId] = policy.RuleData{ + EventID: eventId, + ScopeFilter: filters.NewScopeFilter(), + DataFilter: filters.NewDataFilter(), + RetFilter: filters.NewIntFilter(), } + } - if evtFlag.eventOptionType == "" { - // no event option type means that the flag contains only event names - continue - } + if evtFlag.eventOptionType == "" { + continue + } - if evtFlag.eventOptionType == "retval" { - err := p.Rules[eventId].RetFilter.Parse(evtFlag.operatorAndValues) - if err != nil { - return nil, err - } - continue + switch evtFlag.eventOptionType { + case "retval": + if err := p.Rules[eventId].RetFilter.Parse(evtFlag.operatorAndValues); err != nil { + return err } - - if evtFlag.eventOptionType == "scope" { - err := p.Rules[eventId].ScopeFilter.Parse(evtFlag.eventOptionName, evtFlag.operatorAndValues) - if err != nil { - return nil, err - } - continue + case "scope": + if err := p.Rules[eventId].ScopeFilter.Parse(evtFlag.eventOptionName, evtFlag.operatorAndValues); err != nil { + return err } - - if evtFlag.eventOptionType == "data" || evtFlag.eventOptionType == "args" { - err := p.Rules[eventId].DataFilter.Parse(eventId, evtFlag.eventOptionName, evtFlag.operatorAndValues) - if err != nil { - return nil, err - } - continue + case "data", "args": + if err := p.Rules[eventId].DataFilter.Parse(eventId, evtFlag.eventOptionName, evtFlag.operatorAndValues); err != nil { + return err } - - return nil, InvalidFilterFlagFormat(evtFlag.full) + default: + return InvalidFilterFlagFormat(evtFlag.full) } } + } - // if no events were specified, add all events from the default set - if len(p.Rules) == 0 { - for _, eventId := range setsToEvents["default"] { - if _, ok := p.Rules[eventId]; !ok { - p.Rules[eventId] = policy.RuleData{ - EventID: eventId, - ScopeFilter: filters.NewScopeFilter(), - DataFilter: filters.NewDataFilter(), - RetFilter: filters.NewIntFilter(), - } + // if no events were specified, add all events from the default set + if len(p.Rules) == 0 { + for _, eventId := range setsToEvents["default"] { + if _, ok := p.Rules[eventId]; !ok { + p.Rules[eventId] = policy.RuleData{ + EventID: eventId, + ScopeFilter: filters.NewScopeFilter(), + DataFilter: filters.NewDataFilter(), + RetFilter: filters.NewIntFilter(), } } } + } - // remove excluded events from the policy - for _, eventName := range excludedEvents { - if _, ok := eventNamesToID[eventName]; !ok { - return nil, InvalidEventExcludeError(eventName) - } - delete(p.Rules, eventNamesToID[eventName]) + // remove excluded events from the policy + for _, eventName := range excludedEvents { + if _, ok := eventNamesToID[eventName]; !ok { + return InvalidEventExcludeError(eventName) } - - policies = append(policies, p) + delete(p.Rules, eventNamesToID[eventName]) } - return policies, nil + return nil } diff --git a/pkg/cmd/flags/policy_test.go b/pkg/cmd/flags/policy_test.go index c11105fdfb3d..f5e286c17f47 100644 --- a/pkg/cmd/flags/policy_test.go +++ b/pkg/cmd/flags/policy_test.go @@ -1876,7 +1876,7 @@ func TestCreatePolicies(t *testing.T) { { testName: "invalid datafilter 1", evtFlags: []string{"open.data"}, - expectPolicyErr: filters.InvalidExpression("open."), + expectPolicyErr: filters.InvalidEventField(""), }, { testName: "invalid datafilter 2", @@ -1897,12 +1897,12 @@ func TestCreatePolicies(t *testing.T) { { testName: "invalid scope filter 1", evtFlags: []string{"open.scope"}, - expectPolicyErr: filters.InvalidExpression("open.scope"), + expectPolicyErr: filters.InvalidScopeField(""), }, { testName: "invalid scope filter 2", evtFlags: []string{"bla.scope.processName=ls"}, - expectPolicyErr: filters.InvalidEventName("bla"), + expectPolicyErr: InvalidEventError("bla"), }, { testName: "invalid scope filter 3", @@ -2040,10 +2040,6 @@ func TestCreatePolicies(t *testing.T) { testName: "wildcard filter", evtFlags: []string{"open*"}, }, - { - testName: "wildcard not filter", - evtFlags: []string{"-*"}, - }, { testName: "multiple filters", scopeFlags: []string{"uid<1", "mntns=5", "pidns!=3", "pid!=10", "comm=ps", "uts!=abc"},