Skip to content

Commit

Permalink
Add balanced models.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Oct 20, 2023
1 parent bc8df48 commit 39379b5
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
9 changes: 8 additions & 1 deletion params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -156,29 +156,36 @@ datasets:

models:
logistic_regression:
model: logistic_regression
solver: liblinear

logistic_regression_balanced:
model: logistic_regression
solver: liblinear
class_weight: balanced

gradient_boosting_classifier:
model: gradient_boosting_classifier
n_estimators: 40
min_samples_split: 6
max_depth: 2

knn:
model: knn
n_neighbors: 5
weights: uniform

svm:
model: svm
kernel: rbf

svm_balanced:
model: svm
kernel: rbf
class_weight: balanced

mlp: {}
mlp:
model: mlp

valuation_methods:
random:
Expand Down
6 changes: 2 additions & 4 deletions scripts/evaluate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,8 @@ def _evaluate_metrics(
/ str(repetition_id)
/ valuation_method_name
)
- if (
- False
- and os.path.exists(output_dir / f"{metric_name}.csv")
- and os.path.exists(output_dir / f"{metric_name}.curve.csv")
+ if os.path.exists(output_dir / f"{metric_name}.csv") and os.path.exists(
+ output_dir / f"{metric_name}.curve.csv"
):
return logger.info(f"Metric data exists in '{output_dir}'. Skipping...")

Expand Down
11 changes: 6 additions & 5 deletions src/re_classwise_shapley/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,27 @@ def instantiate_model(
The instantiated model.
"""
random_state = np.random.RandomState(seed)
+ model = model_kwargs.pop("model")

- if model_name == "gradient_boosting_classifier":
+ if model == "gradient_boosting_classifier":
model = make_pipeline(
GradientBoostingClassifier(**model_kwargs, random_state=random_state)
)
- elif model_name == "logistic_regression":
+ elif model == "logistic_regression":
model = make_pipeline(
StandardScaler(),
LogisticRegression(**model_kwargs, random_state=random_state),
)
- elif model_name == "knn":
+ elif model == "knn":
model = make_pipeline(
StandardScaler(),
KNeighborsClassifier(**model_kwargs),
)
- elif model_name == "svm":
+ elif model == "svm":
model = make_pipeline(
StandardScaler(), SVC(**model_kwargs, random_state=random_state)
)
- elif model_name == "mlp":
+ elif model == "mlp":
model = make_pipeline(
StandardScaler(), MLPClassifier(**model_kwargs, random_state=random_state)
)
Expand Down

0 comments on commit 39379b5

Please sign in to comment.