diff --git a/tests/test_modelgen.py b/tests/test_modelgen.py index 97a6be3..332079d 100644 --- a/tests/test_modelgen.py +++ b/tests/test_modelgen.py @@ -66,7 +66,7 @@ def test_cnn_metrics(self): metrics = ['accuracy', 'mae'] model = modelgen.generate_CNN_model((None, 20, 3), 2, [32, 32], 100, metrics=metrics) model_metrics = [m.name for m in model.metrics] - assert model_metrics == metrics or model_metrics == ['acc', 'mae'] + assert model_metrics == metrics or model_metrics == ['acc', 'mean_absolute_error'] def test_CNN_hyperparameters_nrlayers(self): """ Number of Conv layers from range [4, 4] should be 4. """