From 2beeb35f9fda418982495aa11c0d637a0c204c50 Mon Sep 17 00:00:00 2001 From: Ramith Hettiarachchi Date: Fri, 3 Mar 2023 04:45:26 -0500 Subject: [PATCH] hotfix binary classfication tasks --- modules/test_utils.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/modules/test_utils.py b/modules/test_utils.py index fdb9e2d..133e6c3 100644 --- a/modules/test_utils.py +++ b/modules/test_utils.py @@ -270,16 +270,25 @@ def test_model_in_groups(model, data, criterion, n_classes = 0, device = 'cpu', test_acc_table = wandb.Table(data=test_acc_data, columns=["class_name", "accuracy"]) ## F1 Table - test_f1_data = [[name, prec] for (name, prec) in zip(class_names, t_f1)] - test_f1_table = wandb.Table(data=test_f1_data, columns=["class_name", "f1"]) + if(n_classes != 2): + test_f1_data = [[name, prec] for (name, prec) in zip(class_names, t_f1)] + test_f1_table = wandb.Table(data=test_f1_data, columns=["class_name", "f1"]) + else: + test_f1_table = wandb.Table() ## Precision Table - test_precision_data = [[name, prec] for (name, prec) in zip(class_names, t_precision)] - test_precision_table = wandb.Table(data=test_precision_data, columns=["class_name", "precision"]) + if(n_classes != 2): + test_precision_data = [[name, prec] for (name, prec) in zip(class_names, t_precision)] + test_precision_table = wandb.Table(data=test_precision_data, columns=["class_name", "precision"]) + else: + test_precision_table = wandb.Table() ## Recall Table - test_recall_data = [[name, prec] for (name, prec) in zip(class_names, t_recall)] - test_recall_table = wandb.Table(data=test_recall_data, columns=["class_name", "recall"]) + if(n_classes != 2): + test_recall_data = [[name, prec] for (name, prec) in zip(class_names, t_recall)] + test_recall_table = wandb.Table(data=test_recall_data, columns=["class_name", "recall"]) + else: + test_recall_table = wandb.Table() test_confusion_matrix, saved_confmatrix = get_confusion_matrix(test_preds, test_labels, n_classes, class_names)