Skip to content

Commit

Permalink
refactor: make it less costly to find expected rules
Browse files Browse the repository at this point in the history
  • Loading branch information
cfabianski committed Nov 21, 2023
1 parent 3cd041f commit 7a8f040
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 103 deletions.
2 changes: 1 addition & 1 deletion internal/scanner/ast/.snapshots/TestExpectedRules
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ children:
id: 2
range: 3:3 - 5:6
expectedrules:
- 5
- rule1
children:
- type: '"def"'
id: 3
Expand Down
5 changes: 1 addition & 4 deletions internal/scanner/ast/tree/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,9 @@ func (builder *Builder) AddDisabledRules(sitterNode *sitter.Node, rules []*rules

func (builder *Builder) addExpectedRulesForNode(nodeID int, rules []*ruleset.Rule) {
node := &builder.nodes[nodeID]
if node.expectedRuleIndices == nil {
node.expectedRuleIndices = bitset.New(uint(builder.ruleCount))
}

for _, rule := range rules {
node.expectedRuleIndices.Set(uint(rule.Index()))
node.expectedRules = append(node.expectedRules, rule.ID())
}
}

Expand Down
26 changes: 12 additions & 14 deletions internal/scanner/ast/tree/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type Node struct {
children,
dataflowSources,
aliasOf []*Node
expectedRuleIndices *bitset.BitSet
expectedRules []string
disabledRuleIndices *bitset.BitSet
// FIXME: remove the need for this
sitterNode *sitter.Node
Expand Down Expand Up @@ -61,6 +61,10 @@ func (tree *Tree) NodeFromSitter(sitterNode *sitter.Node) *Node {
return tree.sitterToNode[sitterNode]
}

func (tree *Tree) Nodes() []Node {
return tree.nodes
}

func (node *Node) Tree() *Tree {
return node.tree
}
Expand Down Expand Up @@ -134,12 +138,8 @@ func (node *Node) AliasOf() []*Node {
return node.aliasOf
}

func (node *Node) RuleExpected(index int) bool {
if node.expectedRuleIndices == nil {
return false
}

return node.expectedRuleIndices.Test(uint(index))
func (node *Node) ExpectedRules() []string {
return node.expectedRules
}

func (node *Node) RuleDisabled(index int) bool {
Expand Down Expand Up @@ -167,7 +167,7 @@ type nodeDump struct {
AliasOf []int `yaml:"alias_of,omitempty"`
Queries []int `yaml:",omitempty"`
DisabledRules []int `yaml:",omitempty"`
ExpectedRules []int `yaml:",omitempty"`
ExpectedRules []string `yaml:",omitempty"`
Children []nodeDump `yaml:",omitempty"`
}

Expand Down Expand Up @@ -199,12 +199,10 @@ func (node *Node) dumpValue() nodeDump {
}
}

var expectedRules []int
if node.expectedRuleIndices != nil {
for i := 0; i < int(node.expectedRuleIndices.Len()); i++ {
if node.expectedRuleIndices.Test(uint(i)) {
expectedRules = append(expectedRules, i)
}
var expectedRules []string
if len(node.expectedRules) > 0 {
for _, expectedRule := range node.expectedRules {

Check failure on line 204 in internal/scanner/ast/tree/tree.go

View workflow job for this annotation

GitHub Actions / lint

S1011: should replace loop with `expectedRules = append(expectedRules, node.expectedRules...)` (gosimple)
expectedRules = append(expectedRules, expectedRule)
}
}

Expand Down
13 changes: 0 additions & 13 deletions internal/scanner/detectors/customrule/filters/filters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,6 @@ func (context *MockDetectorContext) Scan(
panic("unreachable")
}

func (context *MockDetectorContext) ScanExpected(
rootNode *tree.Node,
rule *ruleset.Rule,
traversalStrategy traversalstrategy.Strategy,
) ([]*detectortypes.Detection, error) {
if context.scan != nil {
return context.scan(rootNode, rule, traversalStrategy)
}

Fail("MockDetectorContext.scan called but no scan function was set")
panic("unreachable")
}

func (filter *MockFilter) Evaluate(
detectorContext detectortypes.Context,
patternVariables variableshape.Values,
Expand Down
5 changes: 0 additions & 5 deletions internal/scanner/detectors/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@ type Detection struct {

type Context interface {
Filename() string
ScanExpected(
rootNode *tree.Node,
rule *ruleset.Rule,
traversalStrategy traversalstrategy.Strategy,
) ([]*Detection, error)
Scan(
rootNode *tree.Node,
rule *ruleset.Rule,
Expand Down
33 changes: 15 additions & 18 deletions internal/scanner/languagescanner/languagescanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,30 +109,27 @@ func (scanner *Scanner) Scan(
)

detections, err := scanner.evaluateRules(ruleScanner, cache, tree)
// FIXME: Check if we are need to eval the tests or not
expectedDetections, _ := scanner.evaluateTests(ruleScanner, cache, tree)
expectedDetections, _ := scanner.ExpectedDetections(tree)

return detections, expectedDetections, err
}

func (scanner *Scanner) evaluateTests(
ruleScanner *rulescanner.Scanner,
cache *cache.Cache,
tree *tree.Tree,
) ([]*detectortypes.Detection, error) {
func (scanner *Scanner) ExpectedDetections(tree *tree.Tree) ([]*detectortypes.Detection, error) {
var detections []*detectortypes.Detection
for _, rule := range scanner.ruleSet.Rules() {
if rule.Type() != ruleset.RuleTypeTopLevel {
continue
nodes := tree.Nodes()
for i := range tree.Nodes() {
node := &nodes[i]
if len(node.ExpectedRules()) > 0 {
for _, expectedRule := range node.ExpectedRules() {
rule, _ := scanner.ruleSet.RuleByID(expectedRule)
detections = append(detections, []*detectortypes.Detection{
{
RuleID: rule.ID(),
MatchNode: node,
},
}...)
}
}

cache.Clear()
expectedDetections, err := ruleScanner.ScanExpected(tree.RootNode(), rule, traversalstrategy.NestedStrict)
if err != nil {
return nil, err
}

detections = append(detections, expectedDetections...)
}

return detections, nil
Expand Down
48 changes: 0 additions & 48 deletions internal/scanner/rulescanner/rulescanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,58 +95,10 @@ func (scanner *Scanner) Scan(
return detections, nil
}

func (scanner *Scanner) ScanExpected(
rootNode *tree.Node,
rule *ruleset.Rule,
traversalStrategy traversalstrategy.Strategy,
) (
[]*detectortypes.Detection,
error,
) {
var detections []*detectortypes.Detection
if err := traversalStrategy.Traverse(scanner.traversalCache, rootNode, func(node *tree.Node) (bool, error) {
if scanner.ctx.Err() != nil {
return false, scanner.ctx.Err()
}

result, err := scanner.detectExpectedAtNode(rule, node)
if result == nil || err != nil {
return false, err
}

detections = append(detections, result.Detections...)
return result.Expected, nil
}); err != nil {
return nil, err
}

return detections, nil
}

func (scanner *Scanner) Filename() string {
return scanner.filename
}

func (scanner *Scanner) detectExpectedAtNode(rule *ruleset.Rule, node *tree.Node) (*detectorset.Result, error) {
if log.Trace().Enabled() {
log.Trace().Msgf("detect expected at node start: %s at %s", rule.ID(), node.Debug())
}

if node.RuleExpected(rule.Index()) {
return &detectorset.Result{
Detections: []*detectortypes.Detection{
{
RuleID: rule.ID(),
MatchNode: node,
},
},
Expected: true,
}, nil
}

return nil, nil
}

func (scanner *Scanner) detectAtNode(rule *ruleset.Rule, node *tree.Node) (*detectorset.Result, error) {
if log.Trace().Enabled() {
log.Trace().Msgf("detect at node start: %s at %s", rule.ID(), node.Debug())
Expand Down
4 changes: 4 additions & 0 deletions internal/scanner/ruleset/ruleset.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ func getRuleType(triggerRuleIDs set.Set[string], settingsRule *settings.Rule) Ru
}
}

func (set *Set) RuleByIndex(idx uint64) (*Rule, error) {
return set.Rules()[idx], nil
}

func (set *Set) RuleByID(id string) (*Rule, error) {
rule, exists := set.rulesByID[id]
if !exists {
Expand Down

0 comments on commit 7a8f040

Please sign in to comment.