Skip to content

Commit

Permalink
add gam test
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Aug 17, 2023
1 parent feb3bb5 commit d0b602b
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 43 deletions.
53 changes: 31 additions & 22 deletions imodels/algebraic/gam.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,36 @@
from sklearn.metrics import accuracy_score, roc_auc_score
from tqdm import tqdm

import imodels


class TreeGAMClassifier(BaseEstimator):
"""Tree-based GAM classifier.
Uses cyclical boosting to fit a GAM with small trees.
Simplified version of the explainable boosting machine described in https://github.com/interpretml/interpret
Only works for binary classification.
"""

def __init__(
    self,
    max_leaf_nodes=3,
    n_boosting_rounds=100,
    random_state=None,
):
    """Configure the tree-GAM hyperparameters.

    Parameters
    ----------
    max_leaf_nodes : int, default=3
        Maximum number of leaves in each per-feature tree; keeps every
        shape function simple and interpretable.
    n_boosting_rounds : int, default=100
        Number of cyclic boosting passes over all features.
    random_state : int or None, default=None
        Seed forwarded to each underlying decision tree.
    """
    # sklearn convention: __init__ only stores the parameters verbatim;
    # all fitting work happens in fit().
    self.max_leaf_nodes = max_leaf_nodes
    self.random_state = random_state
    self.n_boosting_rounds = n_boosting_rounds

def fit(self, X, y, sample_weight=None):
def fit(self, X, y, sample_weight=None, learning_rate=0.01):
X, y = check_X_y(X, y, accept_sparse=False, multi_output=False)
check_classification_targets(y)
sample_weight = _check_sample_weight(sample_weight, X, dtype=None)

# cycle through features and fit a tree to each one
ests = []
self.estimators_ = []
self.learning_rate = learning_rate
self.bias_ = np.mean(y)
residuals = y - self.bias_
for boosting_round in tqdm(range(self.n_boosting_rounds)):
for feature_num in range(X.shape[1]):
X_ = np.zeros_like(X)
Expand All @@ -50,38 +56,33 @@ def fit(self, X, y, sample_weight=None):
max_leaf_nodes=self.max_leaf_nodes,
random_state=self.random_state,
)
est.fit(X_, y, sample_weight=sample_weight)
if not est.tree_.feature[0] == feature_num:
# failed to split on this feature
est.fit(X_, residuals, sample_weight=sample_weight)
succesfully_split_on_feature = np.all(
(est.tree_.feature[0] == feature_num) | (est.tree_.feature[0] == -2)
)
if not succesfully_split_on_feature:
continue
ests.append(est)
y = y - est.predict(X)

self.est_ = GradientBoostingRegressor()
self.est_.fit(X, y)
self.est_.n_estimators_ = len(ests)
self.est_.estimators_ = np.array(ests).reshape(-1, 1)

self.estimators_.append(est)
residuals = residuals - self.learning_rate * est.predict(X)
return self

def predict_proba(self, X):
    """Return class probabilities as an (n_samples, 2) array.

    Reconstructs the additive model: the bias term plus the
    learning-rate-scaled contribution of every per-feature tree,
    clipped to [0, 1] so the scores can be read as P(y=1).

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)

    Returns
    -------
    ndarray of shape (n_samples, 2)
        Column 0 is P(y=0), column 1 is P(y=1).
    """
    X = check_array(X, accept_sparse=False, dtype=None)
    check_is_fitted(self)
    # Additive GAM prediction: start from the bias, then add each
    # boosted shape-function (tree) contribution.
    probs1 = np.ones(X.shape[0]) * self.bias_
    for est in self.estimators_:
        probs1 += self.learning_rate * est.predict(X)
    # Raw additive scores can leave [0, 1]; clip so they behave as
    # probabilities. NOTE(review): no sigmoid/calibration is applied,
    # so scores at the clip boundary saturate exactly at 0 or 1.
    probs1 = np.clip(probs1, a_min=0, a_max=1)
    return np.array([1 - probs1, probs1]).T

def predict(self, X):
    """Predict binary class labels (0/1).

    Takes the argmax over the two columns of predict_proba, i.e. the
    label is 1 exactly when P(y=1) > 0.5.
    """
    X = check_array(X, accept_sparse=False, dtype=None)
    check_is_fitted(self)
    return np.argmax(self.predict_proba(X), axis=1)


if __name__ == "__main__":
breast = load_breast_cancer()
feature_names = list(breast.feature_names)
X, y = pd.DataFrame(breast.data, columns=feature_names), breast.target
X, y, feature_names = imodels.get_clean_dataset("heart")
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
gam = TreeGAMClassifier(max_leaf_nodes=2)
gam = TreeGAMClassifier()
gam.fit(X_train, y_train)

# check roc auc score
Expand All @@ -90,3 +91,11 @@ def predict(self, X):
"train roc auc score:", roc_auc_score(y_train, gam.predict_proba(X_train)[:, 1])
)
print("test roc auc score:", roc_auc_score(y_test, y_pred))
print(
"accs",
accuracy_score(y_train, gam.predict(X_train)),
accuracy_score(y_test, gam.predict(X_test)),
"imb",
np.mean(y_train),
np.mean(y_test),
)
48 changes: 29 additions & 19 deletions tests/classification_binary_inputs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,62 +3,72 @@
import numpy as np

from imodels import (
OptimalRuleListClassifier, OptimalTreeClassifier, FPLassoClassifier, FPSkopeClassifier)
OptimalRuleListClassifier,
OptimalTreeClassifier,
FPLassoClassifier,
FPSkopeClassifier,
TreeGAMClassifier,
)


class TestClassClassificationBinary:
'''Tests simple classification for different models. Note: still doesn't test all the models!
'''
"""Tests simple classification for different models. Note: still doesn't test all the models!"""

def setup(self):
    """Build a small, seeded binary-classification fixture.

    X is an (n, p) matrix of 0/1 features; y equals the first feature,
    except the last two labels are flipped so no model can reach 100%
    training accuracy (guards against tests that only pass on a
    trivially separable task).
    """
    np.random.seed(13)
    random.seed(13)
    self.n = 40
    self.p = 2
    self.X_classification_binary = (np.random.randn(self.n, self.p) > 0).astype(int)

    # y = x0 > 0 (label equals the first binary feature)
    self.y_classification_binary = (
        self.X_classification_binary[:, 0] > 0
    ).astype(int)

    # flip labels for last few
    self.y_classification_binary[-2:] = 1 - self.y_classification_binary[-2:]

def test_classification_binary(self):
    """Test imodels on a basic binary classification task.

    For each model class: fit on the fixture from setup(), then check
    predict() output size, predict_proba() shape/range/consistency with
    predict(), and a minimum training accuracy of 0.8.
    """
    for model_type in [
        OptimalRuleListClassifier,
        OptimalTreeClassifier,
        FPLassoClassifier,
        FPSkopeClassifier,
        TreeGAMClassifier,
    ]:
        # per-model constructor arguments needed to make the test
        # meaningful (FPSkope) or fast (TreeGAM)
        init_kwargs = {}
        if model_type == FPSkopeClassifier:
            init_kwargs["recall_min"] = 0.5
        if model_type == TreeGAMClassifier:
            init_kwargs["n_boosting_rounds"] = 10
        m = model_type(**init_kwargs)

        X = self.X_classification_binary
        m.fit(X, self.y_classification_binary)

        # test predict()
        preds = m.predict(X)
        assert preds.size == self.n, "predict() yields right size"

        # test preds_proba() — skipped for the models that do not
        # implement predict_proba
        if model_type not in {OptimalRuleListClassifier, OptimalTreeClassifier}:
            preds_proba = m.predict_proba(X)
            assert len(preds_proba.shape) == 2, "preds_proba has 2 columns"
            assert preds_proba.shape[1] == 2, "preds_proba has 2 columns"
            assert np.max(preds_proba) < 1.1, "preds_proba has no values over 1"
            assert (np.argmax(preds_proba, axis=1) == preds).all(), (
                "predict_proba and predict correspond"
            )

        # test acc
        acc_train = np.mean(preds == self.y_classification_binary)
        assert acc_train > 0.8, "acc greater than 0.8"


if __name__ == "__main__":
    # Allow running this test module directly (outside pytest).
    t = TestClassClassificationBinary()
    t.setup()
    t.test_classification_binary()
6 changes: 4 additions & 2 deletions tests/classification_continuous_inputs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,17 @@ def test_classification_binary(self):
OneRClassifier, SlipperClassifier,
GreedyTreeClassifier, OptimalTreeClassifier,
C45TreeClassifier, FIGSClassifier,
# TreeGAMClassifier,
TreeGAMClassifier,
]: # IRFClassifier, SLIMClassifier, BayesianRuleSetClassifier,

init_kwargs = {}
if model_type == SkopeRulesClassifier or model_type == FPSkopeClassifier:
init_kwargs['random_state'] = 0
init_kwargs['max_samples_features'] = 1.
if model_type == SlipperClassifier:
elif model_type == SlipperClassifier:
init_kwargs['n_estimators'] = 1
elif model_type == TreeGAMClassifier:
init_kwargs['n_boosting_rounds'] = 10
m = model_type(**init_kwargs)

X = self.X_classification_binary
Expand Down

0 comments on commit d0b602b

Please sign in to comment.