Commit
save mse_val_ in tree gam
csinva committed Aug 25, 2023
1 parent 8179aab commit a257cd2
Showing 2 changed files with 271 additions and 152 deletions.
26 changes: 11 additions & 15 deletions imodels/algebraic/tree_gam.py
@@ -125,6 +125,8 @@ def fit(self, X, y, sample_weight=None):
             sample_weight_val,
         )
 
+        self.mse_val_ = self._calc_mse(X_val, y_val, sample_weight_val)
+
         return self
 
     def _marginal_fit(
@@ -174,10 +176,7 @@ def _cyclic_boost(
         """Apply cyclic boosting, storing trees in self.estimators_"""
 
         residuals_train = y_train - self.predict_proba(X_train)[:, 1]
-        mse_val = np.average(
-            np.square(y_val - self.predict_proba(X_val)[:, 1]),
-            weights=sample_weight_val,
-        )
+        mse_val = self._calc_mse(X_val, y_val, sample_weight_val)
         for _ in range(self.n_boosting_rounds):
             boosting_round_ests = []
             boosting_round_mses = []
@@ -203,9 +202,8 @@ def _cyclic_boost(
                 if self.boosting_strategy == "cyclic":
                     residuals_train = residuals_train_new
                 elif self.boosting_strategy == "greedy":
-                    mse_train_new = np.average(
-                        np.square(residuals_train_new),
-                        weights=sample_weight_train,
+                    mse_train_new = self._calc_mse(
+                        X_train, y_train, sample_weight_train
                     )
                     # don't add each estimator for greedy
                     boosting_round_ests.append(deepcopy(self.estimators_.pop()))
@@ -219,18 +217,12 @@ def _cyclic_boost(
                 )
 
             # early stopping if validation error does not decrease
-            mse_val_new = np.average(
-                np.square(y_val - self.predict_proba(X_val)[:, 1]),
-                weights=sample_weight_val,
-            )
+            mse_val_new = self._calc_mse(X_val, y_val, sample_weight_val)
             if mse_val_new >= mse_val:
-                self.mse_val = mse_val
                 return
             else:
                 mse_val = mse_val_new
 
-        self.mse_val = mse_val
-
     def predict_proba(self, X):
         X = check_array(X, accept_sparse=False, dtype=None)
         check_is_fitted(self)
@@ -248,7 +240,11 @@ def predict(self, X):
         elif isinstance(self, ClassifierMixin):
             return np.argmax(self.predict_proba(X), axis=1)
 
-
+    def _calc_mse(self, X, y, sample_weight=None):
+        return np.average(
+            np.square(y - self.predict_proba(X)[:, 1]),
+            weights=sample_weight,
+        )
 
 
 class TreeGAMRegressor(TreeGAM, RegressorMixin):
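
With this change, fit() computes the validation MSE via the new _calc_mse helper and saves it on the fitted model as mse_val_. A minimal usage sketch, assuming the TreeGAMClassifier class name and import path from this file (constructor arguments are not shown in the diff, so the default constructor is assumed):

import numpy as np
from imodels.algebraic.tree_gam import TreeGAMClassifier  # class name and import path assumed

# toy binary-classification data
X = np.random.rand(200, 4)
y = (X[:, 0] + X[:, 1] > 1).astype(int)

gam = TreeGAMClassifier()  # default settings assumed; constructor args not part of this diff
gam.fit(X, y)

# after this commit, the validation MSE from fit()'s internal split is stored on the model
print(gam.mse_val_)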
