-
Notifications
You must be signed in to change notification settings - Fork 0
/
DecisionTreeClassifier.js
88 lines (78 loc) · 2.53 KB
/
DecisionTreeClassifier.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import { Matrix } from './indexMatrix.js';
import Tree from './TreeNode.js';
/**
 * Default training options. The DecisionTreeClassifier constructor copies these
 * into a fresh object and overlays the caller's options on top, so this shared
 * object is frozen to guard against accidental mutation.
 */
const defaultOptions = Object.freeze({
  gainFunction: 'gini',
  splitFunction: 'mean',
  minNumSamples: 3,
  maxDepth: Infinity,
  gainThreshold: 0.01,
});
export class DecisionTreeClassifier {
  /**
   * Create a new Decision Tree Classifier (CART implementation), or restore one
   * from a serialized model.
   *
   * @param {object|boolean} [options] - training options, overlaid on the module
   *   defaults. Passing the literal `true` switches to load mode, in which case
   *   `model` is required (this is how `DecisionTreeClassifier.load` calls it).
   * @param {string} [options.gainFunction="gini"] - gain function used to pick the best split; "gini" is the only one supported.
   * @param {string} [options.splitFunction="mean"] - how the split value is derived from two values of the split feature; "mean" is the only one supported.
   * @param {number} [options.minNumSamples=3] - minimum number of samples to create a leaf node that decides a class.
   * @param {number} [options.maxDepth=Infinity] - maximum depth of the tree.
   * @param {number} [options.gainThreshold=0.01] - gain threshold forwarded to the tree
   *   nodes (presumably the minimum gain a split must reach — confirm in TreeNode.js).
   * @param {object} [model] - serialized model as produced by `toJSON`; used only in load mode.
   * @constructor
   */
  constructor(options, model) {
    if (options === true) {
      // Load mode: rebuild the tree from a previously exported model.
      this.options = model.options;
      this.root = new Tree(model.options);
      this.root.setNodeParameters(model.root);
    } else {
      // Copy into a fresh object so neither the defaults nor the caller's object is mutated.
      this.options = Object.assign({}, defaultOptions, options);
      this.options.kind = 'classifier';
    }
  }

  /**
   * Train the decision tree with the given training set and labels.
   * On success, any previously trained tree is replaced. If validation or
   * training throws, the previous tree (if any) is kept.
   * @param {Matrix|MatrixTransposeView|Array} trainingSet - one sample per row.
   * @param {Array} trainingLabels - class label for each row of `trainingSet`.
   */
  train(trainingSet, trainingLabels) {
    const samples = Matrix.checkMatrix(trainingSet);
    // Build into a local first so a failure cannot leave a half-built or empty tree in `this.root`.
    const root = new Tree(this.options);
    root.train(samples, trainingLabels, 0, null);
    this.root = root;
  }

  /**
   * Predicts the class of each row of the given matrix.
   * @param {Matrix|MatrixTransposeView|Array} toPredict - one sample per row.
   * @return {Array} predictions - one predicted class index per row.
   * @throws {Error} if the classifier has been neither trained nor loaded.
   */
  predict(toPredict) {
    if (this.root === undefined) {
      throw new Error(
        'DecisionTreeClassifier: the model must be trained or loaded before calling predict',
      );
    }
    const samples = Matrix.checkMatrix(toPredict);
    const predictions = new Array(samples.rows);
    for (let i = 0; i < samples.rows; ++i) {
      // NOTE(review): assumes `classify` returns a one-row matrix of per-class scores and that
      // `maxRowIndex(0)` returns a [row, column] pair (as in ml-matrix); element [1] is then the
      // column of the highest score, i.e. the predicted class.
      predictions[i] = this.root.classify(samples.getRow(i)).maxRowIndex(0)[1];
    }
    return predictions;
  }

  /**
   * Export the current model to a JSON-serializable object.
   * Note: on an untrained classifier `root` is undefined, so `JSON.stringify`
   * omits it.
   * @return {object} - current model: `{ options, root, name: 'DTClassifier' }`.
   */
  toJSON() {
    return {
      options: this.options,
      root: this.root,
      name: 'DTClassifier',
    };
  }

  /**
   * Load a decision tree classifier from a model produced by `toJSON`.
   * @param {object} model
   * @return {DecisionTreeClassifier}
   * @throws {RangeError} if `model.name` is not 'DTClassifier'.
   */
  static load(model) {
    if (model.name !== 'DTClassifier') {
      throw new RangeError(`Invalid model: ${model.name}`);
    }
    return new DecisionTreeClassifier(true, model);
  }
}