Skip to content

Commit

Permalink
Merge pull request #184 from csinva/revert-183-master
Browse files Browse the repository at this point in the history
Revert "Update hierarchical_shrinkage, fix bugs, change attribute name"
  • Loading branch information
csinva authored Jul 28, 2023
2 parents 7ab6510 + d96c3f2 commit f221ba5
Show file tree
Hide file tree
Showing 8 changed files with 267 additions and 280 deletions.
45 changes: 18 additions & 27 deletions imodels/experimental/figs_ensembles.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
from matplotlib import pyplot as plt
import sklearn
from sklearn import datasets
from sklearn import tree
from sklearn.base import BaseEstimator
Expand Down Expand Up @@ -73,22 +72,18 @@ def setattrs(self, **kwargs):
setattr(self, k, v)

def __str__(self):
try:
sklearn.utils.validation.check_is_fitted(self)
if self.split_or_linear == 'linear':
if self.is_root:
return f'X_{self.feature} * {self.value:0.3f} (Tree #{self.tree_num} linear root)'
else:
return f'X_{self.feature} * {self.value:0.3f} (linear)'
if self.split_or_linear == 'linear':
if self.is_root:
return f'X_{self.feature} * {self.value:0.3f} (Tree #{self.tree_num} linear root)'
else:
if self.is_root:
return f'X_{self.feature} <= {self.threshold:0.3f} (Tree #{self.tree_num} root)'
elif self.left is None and self.right is None:
return f'Val: {self.value[0][0]:0.3f} (leaf)'
else:
return f'X_{self.feature} <= {self.threshold:0.3f} (split)'
except ValueError:
return self.__class__.__name__
return f'X_{self.feature} * {self.value:0.3f} (linear)'
else:
if self.is_root:
return f'X_{self.feature} <= {self.threshold:0.3f} (Tree #{self.tree_num} root)'
elif self.left is None and self.right is None:
return f'Val: {self.value[0][0]:0.3f} (leaf)'
else:
return f'X_{self.feature} <= {self.threshold:0.3f} (split)'

def __repr__(self):
return self.__str__()
Expand Down Expand Up @@ -422,17 +417,13 @@ def _tree_to_str(self, root: Node, prefix=''):
pprefix)

def __str__(self):
try:
sklearn.utils.validation.check_is_fitted(self)
s = '------------\n' + \
'\n\t+\n'.join([self._tree_to_str(t) for t in self.trees_])
if hasattr(self, 'feature_names_') and self.feature_names_ is not None:
for i in range(len(self.feature_names_))[::-1]:
s = s.replace(f'X_{i}', self.feature_names_[i])
return s
except ValueError:
return self.__class__.__name__

s = '------------\n' + \
'\n\t+\n'.join([self._tree_to_str(t) for t in self.trees_])
if hasattr(self, 'feature_names_') and self.feature_names_ is not None:
for i in range(len(self.feature_names_))[::-1]:
s = s.replace(f'X_{i}', self.feature_names_[i])
return s

def predict(self, X):
if self.posthoc_ridge and self.weighted_model_: # note, during fitting don't use the weighted moel
X_feats = self._extract_tree_predictions(X)
Expand Down
19 changes: 7 additions & 12 deletions imodels/rule_list/corels_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import numpy as np
import pandas as pd
import sklearn
from sklearn.preprocessing import KBinsDiscretizer

from imodels.rule_list.greedy_rule_list import GreedyRuleListClassifier
Expand Down Expand Up @@ -234,18 +233,14 @@ def _traverse_rule(self, X: np.ndarray, y: np.ndarray, feature_names: List[str],
self.str_print = str_print

def __str__(self):
try:
sklearn.utils.validation.check_is_fitted(self)
if corels_supported:
if self.str_print is not None:
return 'OptimalRuleList:\n\n' + self.str_print
else:
return 'OptimalRuleList:\n\n' + self.rl_.__str__()
if corels_supported:
if self.str_print is not None:
return 'OptimalRuleList:\n\n' + self.str_print
else:
return super().__str__()
except ValueError:
return self.__class__.__name__
return 'OptimalRuleList:\n\n' + self.rl_.__str__()
else:
return super().__str__()

def _get_complexity(self):
return sum([len(corule['antecedents']) for corule in self.rl_.rules])

Expand Down
70 changes: 37 additions & 33 deletions imodels/rule_list/greedy_rule_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from copy import deepcopy

import numpy as np
import sklearn
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_array, check_is_fitted
Expand Down Expand Up @@ -141,43 +140,48 @@ def predict(self, X):
X = check_array(X)
return np.argmax(self.predict_proba(X), axis=1)

"""
def __str__(self):
# s = ''
# for rule in self.rules_:
# s += f"mean {rule['val'].round(3)} ({rule['num_pts']} pts)\n"
# if 'col' in rule:
# s += f"if {rule['col']} >= {rule['cutoff']} then {rule['val_right'].round(3)} ({rule['num_pts_right']} pts)\n"
# return s
"""

def __str__(self):
'''Print out the list in a nice way
'''
try:
sklearn.utils.validation.check_is_fitted(self)
s = '> ------------------------------\n> Greedy Rule List\n> ------------------------------\n'

def red(s):
# return f"\033[91m{s}\033[00m"
return s

def cyan(s):
# return f"\033[96m{s}\033[00m"
return s

def rule_name(rule):
if rule['flip']:
return '~' + rule['col']
return rule['col']

# rule = self.rules_[0]
# s += f"{red((100 * rule['val']).round(3))}% IwI ({rule['num_pts']} pts)\n"
for rule in self.rules_:
s += u'\u2193\n' + f"{cyan((100 * rule['val']).round(2))}% risk ({rule['num_pts']} pts)\n"
# s += f"\t{'Else':>45} => {cyan((100 * rule['val']).round(2)):>6}% IwI ({rule['val'] * rule['num_pts']:.0f}/{rule['num_pts']} pts)\n"
if 'col' in rule:
# prefix = f"if {rule['col']} >= {rule['cutoff']}"
prefix = f"if {rule_name(rule)}"
val = f"{100 * rule['val_right'].round(3)}"
s += f"\t{prefix} ==> {red(val)}% risk ({rule['num_pts_right']} pts)\n"
# rule = self.rules_[-1]
# s += f"{red((100 * rule['val']).round(3))}% IwI ({rule['num_pts']} pts)\n"
s = '> ------------------------------\n> Greedy Rule List\n> ------------------------------\n'

def red(s):
# return f"\033[91m{s}\033[00m"
return s

def cyan(s):
# return f"\033[96m{s}\033[00m"
return s
except ValueError:
return self.__class__.__name__


def rule_name(rule):
if rule['flip']:
return '~' + rule['col']
return rule['col']

# rule = self.rules_[0]
# s += f"{red((100 * rule['val']).round(3))}% IwI ({rule['num_pts']} pts)\n"
for rule in self.rules_:
s += u'\u2193\n' + f"{cyan((100 * rule['val']).round(2))}% risk ({rule['num_pts']} pts)\n"
# s += f"\t{'Else':>45} => {cyan((100 * rule['val']).round(2)):>6}% IwI ({rule['val'] * rule['num_pts']:.0f}/{rule['num_pts']} pts)\n"
if 'col' in rule:
# prefix = f"if {rule['col']} >= {rule['cutoff']}"
prefix = f"if {rule_name(rule)}"
val = f"{100 * rule['val_right'].round(3)}"
s += f"\t{prefix} ==> {red(val)}% risk ({rule['num_pts_right']} pts)\n"
# rule = self.rules_[-1]
# s += f"{red((100 * rule['val']).round(3))}% IwI ({rule['num_pts']} pts)\n"
return s

######## HERE ONWARDS CUSTOM SPLITTING (DEPRECATED IN FAVOR OF SKLEARN STUMP) ########
######################################################################################
def _find_best_split(self, x, y):
Expand Down
9 changes: 2 additions & 7 deletions imodels/rule_set/brs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from numpy.random import random
from pandas import read_csv
from scipy.sparse import csc_matrix
import sklearn
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils.multiclass import check_classification_targets
Expand Down Expand Up @@ -193,12 +192,8 @@ def fit(self, X, y, feature_names: list = None, init=[], verbose=False):
return self

def __str__(self):
try:
sklearn.utils.validation.check_is_fitted(self)
return ' '.join(str(r) for r in self.rules_)
except ValueError:
return self.__class__.__name__

return ' '.join(str(r) for r in self.rules_)

def predict(self, X):
check_is_fitted(self)
if isinstance(X, np.ndarray):
Expand Down
17 changes: 6 additions & 11 deletions imodels/rule_set/rule_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import pandas as pd
import scipy
from scipy.special import softmax
import sklearn
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.base import TransformerMixin
from sklearn.utils.multiclass import unique_labels
Expand Down Expand Up @@ -243,16 +242,12 @@ def visualize(self, decimals=2):
return rules[['rule', 'coef']].round(decimals)

def __str__(self):
try:
sklearn.utils.validation.check_is_fitted(self)
s = '> ------------------------------\n'
s += '> RuleFit:\n'
s += '> \tPredictions are made by summing the coefficients of each rule\n'
s += '> ------------------------------\n'
return s + self.visualize().to_string(index=False) + '\n'
except ValueError:
return self.__class__.__name__

s = '> ------------------------------\n'
s += '> RuleFit:\n'
s += '> \tPredictions are made by summing the coefficients of each rule\n'
s += '> ------------------------------\n'
return s + self.visualize().to_string(index=False) + '\n'

def _extract_rules(self, X, y) -> List[str]:
return extract_rulefit(X, y,
feature_names=self.feature_placeholders,
Expand Down
34 changes: 13 additions & 21 deletions imodels/tree/cart_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# This is just a simple wrapper around sklearn decisiontree
# https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html

import sklearn
from sklearn.tree import DecisionTreeClassifier, export_text, DecisionTreeRegressor
from imodels.util.arguments import check_fit_arguments

Expand Down Expand Up @@ -49,18 +48,15 @@ def _set_complexity(self):
self.complexity_ = compute_tree_complexity(self.tree_)

def __str__(self):
try:
sklearn.utils.validation.check_is_fitted(self)
s = '> ------------------------------\n'
s += '> Greedy CART Tree:\n'
s += '> \tPrediction is made by looking at the value in the appropriate leaf of the tree\n'
s += '> ------------------------------' + '\n'
if hasattr(self, 'feature_names') and self.feature_names is not None:
return s + export_text(self, feature_names=self.feature_names, show_weights=True)
else:
return s + export_text(self, show_weights=True)
except ValueError:
return self.__class__.__name__
s = '> ------------------------------\n'
s += '> Greedy CART Tree:\n'
s += '> \tPrediction is made by looking at the value in the appropriate leaf of the tree\n'
s += '> ------------------------------' + '\n'
if hasattr(self, 'feature_names') and self.feature_names is not None:
return s + export_text(self, feature_names=self.feature_names, show_weights=True)
else:
return s + export_text(self, show_weights=True)


class GreedyTreeRegressor(DecisionTreeRegressor):
"""Wrapper around sklearn greedy tree regressor
Expand Down Expand Up @@ -102,11 +98,7 @@ def _set_complexity(self):
self.complexity_ = compute_tree_complexity(self.tree_)

def __str__(self):
try:
sklearn.utils.validation.check_is_fitted(self)
if hasattr(self, 'feature_names') and self.feature_names is not None:
return 'GreedyTree:\n' + export_text(self, feature_names=self.feature_names, show_weights=True)
else:
return 'GreedyTree:\n' + export_text(self, show_weights=True)
except ValueError:
return self.__class__.__name__
if hasattr(self, 'feature_names') and self.feature_names is not None:
return 'GreedyTree:\n' + export_text(self, feature_names=self.feature_names, show_weights=True)
else:
return 'GreedyTree:\n' + export_text(self, show_weights=True)
47 changes: 20 additions & 27 deletions imodels/tree/figs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
import pandas as pd
from scipy.special import expit
import sklearn
from sklearn import datasets
from sklearn import tree
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
Expand Down Expand Up @@ -52,17 +51,13 @@ def setattrs(self, **kwargs):
setattr(self, k, v)

def __str__(self):
try:
sklearn.utils.validation.check_is_fitted(self)
if self.is_root:
return f'X_{self.feature} <= {self.threshold:0.3f} (Tree #{self.tree_num} root)'
elif self.left is None and self.right is None:
return f'Val: {self.value[0][0]:0.3f} (leaf)'
else:
return f'X_{self.feature} <= {self.threshold:0.3f} (split)'
except ValueError:
return self.__class__.__name__

if self.is_root:
return f'X_{self.feature} <= {self.threshold:0.3f} (Tree #{self.tree_num} root)'
elif self.left is None and self.right is None:
return f'Val: {self.value[0][0]:0.3f} (leaf)'
else:
return f'X_{self.feature} <= {self.threshold:0.3f} (split)'

def print_root(self, y):
try:
one_count = pd.Series(y).value_counts()[1.0]
Expand All @@ -77,6 +72,8 @@ def print_root(self, y):
else:
return f'X_{self.feature} <= {self.threshold:0.3f}' + one_proportion

def __repr__(self):
return self.__str__()


class FIGS(BaseEstimator):
Expand Down Expand Up @@ -414,21 +411,17 @@ def _tree_to_str_with_data(self, X, y, root: Node, prefix=''):
self._tree_to_str_with_data(X[~left], y[~left], root.right, pprefix))

def __str__(self):
try:
sklearn.utils.validation.check_is_fitted(self)
s = '> ------------------------------\n'
s += '> FIGS-Fast Interpretable Greedy-Tree Sums:\n'
s += '> \tPredictions are made by summing the "Val" reached by traversing each tree.\n'
s += '> \tFor classifiers, a sigmoid function is then applied to the sum.\n'
s += '> ------------------------------\n'
s += '\n\t+\n'.join([self._tree_to_str(t) for t in self.trees_])
if hasattr(self, 'feature_names_') and self.feature_names_ is not None:
for i in range(len(self.feature_names_))[::-1]:
s = s.replace(f'X_{i}', self.feature_names_[i])
return s
except ValueError:
return self.__class__.__name__

s = '> ------------------------------\n'
s += '> FIGS-Fast Interpretable Greedy-Tree Sums:\n'
s += '> \tPredictions are made by summing the "Val" reached by traversing each tree.\n'
s += '> \tFor classifiers, a sigmoid function is then applied to the sum.\n'
s += '> ------------------------------\n'
s += '\n\t+\n'.join([self._tree_to_str(t) for t in self.trees_])
if hasattr(self, 'feature_names_') and self.feature_names_ is not None:
for i in range(len(self.feature_names_))[::-1]:
s = s.replace(f'X_{i}', self.feature_names_[i])
return s

def print_tree(self, X, y, feature_names=None):
s = '------------\n' + \
'\n\t+\n'.join([self._tree_to_str_with_data(X, y, t)
Expand Down
Loading

0 comments on commit f221ba5

Please sign in to comment.