Skip to content

Commit

Permalink
Added RMSE to regression script
Browse files Browse the repository at this point in the history
TODO: add arguments to wrappers
  • Loading branch information
YAY-C committed May 1, 2024
1 parent 0e402aa commit 8546f2b
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions qstack/regression/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from qstack.regression.kernel_utils import get_kernel, defaults, ParseKwargs
from qstack.tools import correct_num_threads

def regression(X, y, read_kernel=False, sigma=defaults.sigma, eta=defaults.eta, akernel=defaults.kernel, gkernel=defaults.gkernel, gdict=defaults.gdict, test_size=defaults.test_size, train_size=defaults.train_size, n_rep=defaults.n_rep, debug=0, ipywidget=None, save_pred=False):
def regression(X, y, read_kernel=False, sigma=defaults.sigma, eta=defaults.eta, akernel=defaults.kernel, gkernel=defaults.gkernel, gdict=defaults.gdict, test_size=defaults.test_size, train_size=defaults.train_size, n_rep=defaults.n_rep, debug=0, ipywidget=None, save_pred=False, rmse=False):
if read_kernel is False:
kernel = get_kernel(akernel, [gkernel, gdict])
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=debug)
Expand Down Expand Up @@ -37,7 +37,10 @@ def regression(X, y, read_kernel=False, sigma=defaults.sigma, eta=defaults.eta,
Ks = Ks_all[:,train_idx]
alpha = scipy.linalg.solve(K, y_kf_train, assume_a='pos')
y_kf_predict = np.dot(Ks, alpha)
maes.append(np.mean(np.abs(y_test-y_kf_predict)))
if rmse:
maes.append(np.sqrt(np.mean((y_test-y_kf_predict)**2)))
else:
maes.append(np.mean(np.abs(y_test-y_kf_predict)))
r2_scores.append(r2_score(y_test, y_kf_predict))
if ipywidget != None : ipywidget.value += 1
maes_all.append((size_train, np.mean(maes), np.std(maes), np.mean(r2_scores)))
Expand Down

0 comments on commit 8546f2b

Please sign in to comment.