Skip to content

Commit

Permalink
fixed bug:todo :refix
Browse files Browse the repository at this point in the history
  • Loading branch information
YAY-C committed Oct 20, 2023
1 parent 54660d0 commit 8270b92
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions qstack/regression/cross-validate_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


def cv(X, y,
sigma=defaults.sigmaarr, eta=defaults.etaarr, gkernel=defaults.gkernel, gdict=defaults.gdict,
sigmaarr=defaults.sigmaarr, etaarr=defaults.etaarr, gkernel=defaults.gkernel, gdict=defaults.gdict,
akernel=defaults.kernel, test_size=defaults.test_size, train_size=defaults.train_size, splits=defaults.splits,
printlevel=0, adaptive=False, read_kernel=False, ipywidget=None, n_rep=defaults.n_rep, save=False, preffix='uknown'):
hyper_runs = []
Expand All @@ -23,13 +23,17 @@ def cv(X, y,
if bar > 0:
progress = utt.add_progressbar(max_value=n_rep)
for seed,n in zip(seeds, range(n_rep)):
error = hyperparameters(X, y, read_kernel=False, sigma=sigma, eta=eta, akernel=akernel, test_size=test_size, splits=splits, printlevel=printlevel, adaptive=adaptive, debug=seed)
error = hyperparameters(X, y, read_kernel=False, sigma=sigmaarr, eta=etaarr, akernel=akernel, test_size=test_size, splits=splits, printlevel=printlevel, adaptive=adaptive, debug=seed)
mae, stdev, eta, sigma = zip(*error)
hyper_runs.append(list(zip([n]*len(error), error, stdev, eta, sigma)))
maes_all = regression(X, y, read_kernel=False, sigma=sigma[-1], eta=eta[-1], akernel=akernel, test_size=test_size, train_size=train_size, n_rep=1, debug=seed)
ind = np.argsort(error[:,3])
error = error[ind]
ind = np.argsort(error[:,2])
error = error[ind]
hyper_runs.append(error)
lc_runs.append(maes_all)
if bar > 0:
progress.update(n)
progress.update(n+1)
lc_runs = np.array(lc_runs)
hyper_runs = np.array(hyper_runs, dtype=object)
lc = list(zip(lc_runs[:,:,0].mean(axis=0), lc_runs[:,:,1].mean(axis=0), lc_runs[:,:,1].std(axis=0), lc_runs[:,:,3].mean(axis=0)))
Expand Down Expand Up @@ -76,7 +80,7 @@ def main():
selected = np.loadtxt(args.f_select, dtype=int)
X = X[selected]
y = y[selected]
final = cv(X, y, sigma=args.sigma, eta=args.eta, akernel=args.akernel, test_size=args.test_size, splits=args.splits, printlevel=args.printlevel, adaptive=args.adaptive, train_size=args.train_size, n_rep=args.n_rep, preffix=args.nameout, save=args.save_all)
final = cv(X, y, sigmaarr=args.sigma, etaarr=args.eta, akernel=args.akernel, test_size=args.test_size, splits=args.splits, printlevel=args.printlevel, adaptive=args.adaptive, train_size=args.train_size, n_rep=args.n_rep, preffix=args.nameout, save=args.save_all)
print(final)
np.savetxt(args.nameout, final)

Expand Down

0 comments on commit 8270b92

Please sign in to comment.