From 1d18a7f46ba39e3254a74618922cba48ca19ca2a Mon Sep 17 00:00:00 2001 From: narumi Date: Fri, 6 Dec 2024 23:04:34 +0800 Subject: [PATCH] calculate feature importance --- pkg/ml/ensemble/iforest/forest.go | 12 ++++++++++++ pkg/ml/ensemble/iforest/tree.go | 18 ++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/pkg/ml/ensemble/iforest/forest.go b/pkg/ml/ensemble/iforest/forest.go index ccdf99691..83675843f 100644 --- a/pkg/ml/ensemble/iforest/forest.go +++ b/pkg/ml/ensemble/iforest/forest.go @@ -105,3 +105,15 @@ func (f *IsolationForest) Predict(samples *mat.Dense) []int { } return predictions } + +func (f *IsolationForest) FeatureImportance(sample []float64) []float64 { + o := make([]float64, len(sample)) + + for _, tree := range f.Forest { + for i, c := range tree.FeatureImportance(sample) { + o[i] += float64(c) / float64(f.NumTrees) + } + } + + return o +} diff --git a/pkg/ml/ensemble/iforest/tree.go b/pkg/ml/ensemble/iforest/tree.go index bc9beb68a..99476714d 100644 --- a/pkg/ml/ensemble/iforest/tree.go +++ b/pkg/ml/ensemble/iforest/tree.go @@ -11,3 +11,21 @@ type TreeNode struct { func (node *TreeNode) IsLeaf() bool { return node.Left == nil && node.Right == nil } + +func (node *TreeNode) trackUsedFeatures(sample []float64, featureUsed []float64) []float64 { + if node.IsLeaf() { + return featureUsed + } + + featureUsed[node.SplitIndex]++ + + if sample[node.SplitIndex] < node.SplitValue { + return node.Left.trackUsedFeatures(sample, featureUsed) + } else { + return node.Right.trackUsedFeatures(sample, featureUsed) + } +} + +func (node *TreeNode) FeatureImportance(sample []float64) []float64 { + return node.trackUsedFeatures(sample, make([]float64, len(sample))) +}