diff --git a/flepimop/gempyor_pkg/tests/seir/test_seir.py b/flepimop/gempyor_pkg/tests/seir/test_seir.py index 16bbd390c..9e4aa5cc2 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_seir.py +++ b/flepimop/gempyor_pkg/tests/seir/test_seir.py @@ -3,6 +3,8 @@ import pytest import warnings import shutil +from random import randint +import pandas as pd import pathlib import pyarrow as pa @@ -15,6 +17,43 @@ DATA_DIR = os.path.dirname(__file__) + "/data" os.chdir(os.path.dirname(__file__)) +def test_neg_params(): + modinf = model_info.ModelInfo( + config=config, + nslots=1, + seir_modifiers_scenario="None", + write_csv=False, + ) + + modinf = model_info.ModelInfo() + parameter_names = modinf.parameters.pnames + dates = modinf.dates + subpop_names = modinf.subpop_pop + + # Test case 1: no negative params + test_array1 = np.zeros((len(parameter_names)-1, len(dates)-1, len(subpop_names)-1)) + + # Test case 2: randomized negative params + test_array2 = np.zeros((len(parameter_names)-1,len(dates)-1,len(subpop_names)-1)) + for _ in range(5): + test_array2[randint(0,len(parameter_names)-1)][randint(0,len(dates)-1)][randint(0,len(subpop_names)-1)] = -1 + + # Test case 3: set negative params with intentional redundancy + test_array3 = np.zeros((len(parameter_names),len(dates),len(subpop_names))) + randint_first_dim = randint(0,len(parameter_names)-1) + randint_second_dim = randint(0,len(dates)-2) + randint_third_dim = randint(0,len(subpop_names)-1) + test_array3[0][0][0] = -1 + test_array3[randint(0,len(parameter_names)),randint(0,len(dates)):,randint(0,len(subpop_names))] = -1 + test_array3[randint_first_dim][randint_second_dim][randint_third_dim] = -1 + test_array3[randint_first_dim][randint_second_dim+1][randint_third_dim] = -1 + + seir.neg_params(test_array1, parameter_names, dates, subpop_names) # NoError + + with pytest.raises(ValueError): + assert seir.neg_params(test_array2, parameter_names, dates, subpop_names) # ValueError + assert seir.neg_params(test_array3, parameter_names, dates, subpop_names) # ValueError + def test_check_values(): os.chdir(os.path.dirname(__file__))