From 39379b5b7f1171def5dc9832f1b53b167789f307 Mon Sep 17 00:00:00 2001 From: Markus Semmler Date: Fri, 20 Oct 2023 18:22:23 +0200 Subject: [PATCH] Add balanced models. --- params.yaml | 9 ++++++++- scripts/evaluate_metrics.py | 6 ++---- src/re_classwise_shapley/model.py | 11 ++++++----- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/params.yaml b/params.yaml index 74b69a1e..182cd506 100644 --- a/params.yaml +++ b/params.yaml @@ -156,29 +156,36 @@ datasets: models: logistic_regression: + model: logistic_regression solver: liblinear logistic_regression_balanced: + model: logistic_regression solver: liblinear class_weight: balanced gradient_boosting_classifier: + model: gradient_boosting_classifier n_estimators: 40 min_samples_split: 6 max_depth: 2 knn: + model: knn n_neighbors: 5 weights: uniform svm: + model: svm kernel: rbf svm_balanced: + model: svm kernel: rbf class_weight: balanced - mlp: {} + mlp: + model: mlp valuation_methods: random: diff --git a/scripts/evaluate_metrics.py b/scripts/evaluate_metrics.py index 7047c49e..8661e4e3 100644 --- a/scripts/evaluate_metrics.py +++ b/scripts/evaluate_metrics.py @@ -97,10 +97,8 @@ def _evaluate_metrics( / str(repetition_id) / valuation_method_name ) - if ( - False - and os.path.exists(output_dir / f"{metric_name}.csv") - and os.path.exists(output_dir / f"{metric_name}.curve.csv") + if os.path.exists(output_dir / f"{metric_name}.csv") and os.path.exists( + output_dir / f"{metric_name}.curve.csv" ): return logger.info(f"Metric data exists in '{output_dir}'. Skipping...") diff --git a/src/re_classwise_shapley/model.py b/src/re_classwise_shapley/model.py index d1a79091..b2bd6491 100644 --- a/src/re_classwise_shapley/model.py +++ b/src/re_classwise_shapley/model.py @@ -33,26 +33,27 @@ def instantiate_model( The instantiated model. """ random_state = np.random.RandomState(seed) + model = model_kwargs.pop("model") - if model_name == "gradient_boosting_classifier": + if model == "gradient_boosting_classifier": model = make_pipeline( GradientBoostingClassifier(**model_kwargs, random_state=random_state) ) - elif model_name == "logistic_regression": + elif model == "logistic_regression": model = make_pipeline( StandardScaler(), LogisticRegression(**model_kwargs, random_state=random_state), ) - elif model_name == "knn": + elif model == "knn": model = make_pipeline( StandardScaler(), KNeighborsClassifier(**model_kwargs), ) - elif model_name == "svm": + elif model == "svm": model = make_pipeline( StandardScaler(), SVC(**model_kwargs, random_state=random_state) ) - elif model_name == "mlp": + elif model == "mlp": model = make_pipeline( StandardScaler(), MLPClassifier(**model_kwargs, random_state=random_state) )