diff --git a/pytest/test_func_ge.py b/pytest/test_func_ge.py index 8491812..e38efcd 100644 --- a/pytest/test_func_ge.py +++ b/pytest/test_func_ge.py @@ -6,7 +6,7 @@ def test_func_ge(): func_continuous = sim_data_func(n=500, m=30, ytype='Continuous', seed=123) func_binary = sim_data_func(n=500, m=30, ytype='Binary', seed=123) - func_survival = sim_data_func(n=500, m=30, ytype='Binary', seed=123) + func_survival = sim_data_func(n=500, m=30, ytype='Survival', seed=123) func_ge_res_1 = func_ge(func_continuous['y'], func_continuous['z'], func_continuous['location'], func_continuous['X'], 'Continuous', 'Bspline', num_hidden_layers=2, nodes_hidden_layer=[100,10], Learning_Rate2=0.035, L2=0.01, Learning_Rate1=0.02, L=0.01, Num_Epochs=50, nbasis1=5, params1=4, Bsplines=5, norder1=4,