diff --git a/internal/scanner/ast/.snapshots/TestExpectedRules b/internal/scanner/ast/.snapshots/TestExpectedRules index 44a9a6598..eb5c1c998 100644 --- a/internal/scanner/ast/.snapshots/TestExpectedRules +++ b/internal/scanner/ast/.snapshots/TestExpectedRules @@ -19,7 +19,7 @@ children: id: 2 range: 3:3 - 5:6 expectedrules: - - 5 + - rule1 children: - type: '"def"' id: 3 diff --git a/internal/scanner/ast/tree/builder.go b/internal/scanner/ast/tree/builder.go index 9ac5ba9b3..7e81e6f6f 100644 --- a/internal/scanner/ast/tree/builder.go +++ b/internal/scanner/ast/tree/builder.go @@ -114,12 +114,9 @@ func (builder *Builder) AddDisabledRules(sitterNode *sitter.Node, rules []*rules func (builder *Builder) addExpectedRulesForNode(nodeID int, rules []*ruleset.Rule) { node := &builder.nodes[nodeID] - if node.expectedRuleIndices == nil { - node.expectedRuleIndices = bitset.New(uint(builder.ruleCount)) - } for _, rule := range rules { - node.expectedRuleIndices.Set(uint(rule.Index())) + node.expectedRules = append(node.expectedRules, rule.ID()) } } diff --git a/internal/scanner/ast/tree/tree.go b/internal/scanner/ast/tree/tree.go index 0260d7b3f..3b9c74d13 100644 --- a/internal/scanner/ast/tree/tree.go +++ b/internal/scanner/ast/tree/tree.go @@ -30,7 +30,7 @@ type Node struct { children, dataflowSources, aliasOf []*Node - expectedRuleIndices *bitset.BitSet + expectedRules []string disabledRuleIndices *bitset.BitSet // FIXME: remove the need for this sitterNode *sitter.Node @@ -61,6 +61,10 @@ func (tree *Tree) NodeFromSitter(sitterNode *sitter.Node) *Node { return tree.sitterToNode[sitterNode] } +func (tree *Tree) Nodes() []Node { + return tree.nodes +} + func (node *Node) Tree() *Tree { return node.tree } @@ -134,12 +138,8 @@ func (node *Node) AliasOf() []*Node { return node.aliasOf } -func (node *Node) RuleExpected(index int) bool { - if node.expectedRuleIndices == nil { - return false - } - - return node.expectedRuleIndices.Test(uint(index)) +func (node *Node) ExpectedRules() []string { + return node.expectedRules } func (node *Node) RuleDisabled(index int) bool { @@ -167,7 +167,7 @@ type nodeDump struct { AliasOf []int `yaml:"alias_of,omitempty"` Queries []int `yaml:",omitempty"` DisabledRules []int `yaml:",omitempty"` - ExpectedRules []int `yaml:",omitempty"` + ExpectedRules []string `yaml:",omitempty"` Children []nodeDump `yaml:",omitempty"` } @@ -199,12 +199,10 @@ func (node *Node) dumpValue() nodeDump { } } - var expectedRules []int - if node.expectedRuleIndices != nil { - for i := 0; i < int(node.expectedRuleIndices.Len()); i++ { - if node.expectedRuleIndices.Test(uint(i)) { - expectedRules = append(expectedRules, i) - } + var expectedRules []string + if len(node.expectedRules) > 0 { + for _, expectedRule := range node.expectedRules { + expectedRules = append(expectedRules, expectedRule) } } diff --git a/internal/scanner/detectors/customrule/filters/filters_test.go b/internal/scanner/detectors/customrule/filters/filters_test.go index 4889cd967..85987004f 100644 --- a/internal/scanner/detectors/customrule/filters/filters_test.go +++ b/internal/scanner/detectors/customrule/filters/filters_test.go @@ -51,19 +51,6 @@ func (context *MockDetectorContext) Scan( panic("unreachable") } -func (context *MockDetectorContext) ScanExpected( - rootNode *tree.Node, - rule *ruleset.Rule, - traversalStrategy traversalstrategy.Strategy, -) ([]*detectortypes.Detection, error) { - if context.scan != nil { - return context.scan(rootNode, rule, traversalStrategy) - } - - Fail("MockDetectorContext.scan called but no scan function was set") - panic("unreachable") -} - func (filter *MockFilter) Evaluate( detectorContext detectortypes.Context, patternVariables variableshape.Values, diff --git a/internal/scanner/detectors/types/types.go b/internal/scanner/detectors/types/types.go index a02a62f43..65109762d 100644 --- a/internal/scanner/detectors/types/types.go +++ b/internal/scanner/detectors/types/types.go @@ -14,11 +14,6 @@ type Detection struct { type Context interface { Filename() string - ScanExpected( - rootNode *tree.Node, - rule *ruleset.Rule, - traversalStrategy traversalstrategy.Strategy, - ) ([]*Detection, error) Scan( rootNode *tree.Node, rule *ruleset.Rule, diff --git a/internal/scanner/languagescanner/languagescanner.go b/internal/scanner/languagescanner/languagescanner.go index c1119f907..0b5fbf1a2 100644 --- a/internal/scanner/languagescanner/languagescanner.go +++ b/internal/scanner/languagescanner/languagescanner.go @@ -109,30 +109,27 @@ func (scanner *Scanner) Scan( ) detections, err := scanner.evaluateRules(ruleScanner, cache, tree) - // FIXME: Check if we are need to eval the tests or not - expectedDetections, _ := scanner.evaluateTests(ruleScanner, cache, tree) + expectedDetections, _ := scanner.ExpectedDetections(tree) return detections, expectedDetections, err } -func (scanner *Scanner) evaluateTests( - ruleScanner *rulescanner.Scanner, - cache *cache.Cache, - tree *tree.Tree, -) ([]*detectortypes.Detection, error) { +func (scanner *Scanner) ExpectedDetections(tree *tree.Tree) ([]*detectortypes.Detection, error) { var detections []*detectortypes.Detection - for _, rule := range scanner.ruleSet.Rules() { - if rule.Type() != ruleset.RuleTypeTopLevel { - continue + nodes := tree.Nodes() + for i := range tree.Nodes() { + node := &nodes[i] + if len(node.ExpectedRules()) > 0 { + for _, expectedRule := range node.ExpectedRules() { + rule, _ := scanner.ruleSet.RuleByID(expectedRule) + detections = append(detections, []*detectortypes.Detection{ + { + RuleID: rule.ID(), + MatchNode: node, + }, + }...) + } } - - cache.Clear() - expectedDetections, err := ruleScanner.ScanExpected(tree.RootNode(), rule, traversalstrategy.NestedStrict) - if err != nil { - return nil, err - } - - detections = append(detections, expectedDetections...) } return detections, nil diff --git a/internal/scanner/rulescanner/rulescanner.go b/internal/scanner/rulescanner/rulescanner.go index 9c0cfa3a0..1df49dd54 100644 --- a/internal/scanner/rulescanner/rulescanner.go +++ b/internal/scanner/rulescanner/rulescanner.go @@ -95,58 +95,10 @@ func (scanner *Scanner) Scan( return detections, nil } -func (scanner *Scanner) ScanExpected( - rootNode *tree.Node, - rule *ruleset.Rule, - traversalStrategy traversalstrategy.Strategy, -) ( - []*detectortypes.Detection, - error, -) { - var detections []*detectortypes.Detection - if err := traversalStrategy.Traverse(scanner.traversalCache, rootNode, func(node *tree.Node) (bool, error) { - if scanner.ctx.Err() != nil { - return false, scanner.ctx.Err() - } - - result, err := scanner.detectExpectedAtNode(rule, node) - if result == nil || err != nil { - return false, err - } - - detections = append(detections, result.Detections...) - return result.Expected, nil - }); err != nil { - return nil, err - } - - return detections, nil -} - func (scanner *Scanner) Filename() string { return scanner.filename } -func (scanner *Scanner) detectExpectedAtNode(rule *ruleset.Rule, node *tree.Node) (*detectorset.Result, error) { - if log.Trace().Enabled() { - log.Trace().Msgf("detect expected at node start: %s at %s", rule.ID(), node.Debug()) - } - - if node.RuleExpected(rule.Index()) { - return &detectorset.Result{ - Detections: []*detectortypes.Detection{ - { - RuleID: rule.ID(), - MatchNode: node, - }, - }, - Expected: true, - }, nil - } - - return nil, nil -} - func (scanner *Scanner) detectAtNode(rule *ruleset.Rule, node *tree.Node) (*detectorset.Result, error) { if log.Trace().Enabled() { log.Trace().Msgf("detect at node start: %s at %s", rule.ID(), node.Debug()) diff --git a/internal/scanner/ruleset/ruleset.go b/internal/scanner/ruleset/ruleset.go index 1713bd8b0..abf36ce2b 100644 --- a/internal/scanner/ruleset/ruleset.go +++ b/internal/scanner/ruleset/ruleset.go @@ -128,6 +128,10 @@ func getRuleType(triggerRuleIDs set.Set[string], settingsRule *settings.Rule) Ru } } +func (set *Set) RuleByIndex(idx uint64) (*Rule, error) { + return set.Rules()[idx], nil +} + func (set *Set) RuleByID(id string) (*Rule, error) { rule, exists := set.rulesByID[id] if !exists {