diff --git a/feature_importance/subgroup/subgroup_detection.py b/feature_importance/subgroup/subgroup_detection.py index e9fca6f..e435dcf 100644 --- a/feature_importance/subgroup/subgroup_detection.py +++ b/feature_importance/subgroup/subgroup_detection.py @@ -42,7 +42,7 @@ def detect_subgroups(mdi, rankings, num_clusters, p = 0.9, k = None, condensed_distance_matrix = squareform(distance_matrix) # perform hierarchical clustering - linkage_matrix = linkage(condensed_distance_matrix, method="ward") + linkage_matrix = linkage(condensed_distance_matrix, method=linkage_method) clustergrid = sns.clustermap(mdi, row_linkage=linkage_matrix, col_cluster=False, cmap='viridis', cbar_pos = (1, 0.2, 0.05, 0.5))