Skip to content

Commit

Permalink
Update test_grid_func_ge.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Barry57 authored Oct 26, 2024
1 parent 440a463 commit 7da036d
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions pytest/test_grid_func_ge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from GENetLib.grid_func_ge import grid_func_ge

def test_grid_func_ge():
func_continuous = sim_data_func(n=100, m=30, ytype='Continuous', seed=123)
grid_func_ge_res = grid_func_ge(func_continuous['y'], func_continuous['z'], func_continuous['location'], func_continuous['X'], 'Continuous', 'Bspline', num_hidden_layers=2, nodes_hidden_layer=[100, 10], Learning_Rate2=[0.008, 0.009, 0.01], L2=[0.002, 0.003, 0.004, 0.005, 0.006], Learning_Rate1=[0.02, 0.03, 0.04, 0.05], L=[0.05, 0.06, 0.07, 0.08], Num_Epochs=100, nbasis1=7, params1=4, Bsplines=15, norder1=4, model=None, split_type=0, ratio=[7,3], plot_res=True)
func_continuous = sim_data_func(n=50, m=30, ytype='Continuous', seed=123)
grid_func_ge_res = grid_func_ge(func_continuous['y'], func_continuous['z'], func_continuous['location'],
func_continuous['X'], 'Continuous', 'Bspline', num_hidden_layers=2, nodes_hidden_layer=[20, 5],
Learning_Rate2=[0.009], L2=[0.004],
Learning_Rate1=[0.04], L=[0.05, 0.06],
Num_Epochs=1, nbasis1=7, params1=4, Bsplines=15, norder1=4, model=None,
split_type=0, ratio=[7,3], plot_res=False)
assert grid_func_ge_res is not None

0 comments on commit 7da036d

Please sign in to comment.