Skip to content

Commit

Permalink
hotfix binary classfication tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
Ramith Hettiarachchi committed Mar 3, 2023
1 parent 79c18f4 commit 2beeb35
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions modules/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,16 +270,25 @@ def test_model_in_groups(model, data, criterion, n_classes = 0, device = 'cpu',
test_acc_table = wandb.Table(data=test_acc_data, columns=["class_name", "accuracy"])

## F1 Table
test_f1_data = [[name, prec] for (name, prec) in zip(class_names, t_f1)]
test_f1_table = wandb.Table(data=test_f1_data, columns=["class_name", "f1"])
if(n_classes != 2):
test_f1_data = [[name, prec] for (name, prec) in zip(class_names, t_f1)]
test_f1_table = wandb.Table(data=test_f1_data, columns=["class_name", "f1"])
else:
test_f1_table = wandb.Table()

## Precision Table
test_precision_data = [[name, prec] for (name, prec) in zip(class_names, t_precision)]
test_precision_table = wandb.Table(data=test_precision_data, columns=["class_name", "precision"])
if(n_classes != 2):
test_precision_data = [[name, prec] for (name, prec) in zip(class_names, t_precision)]
test_precision_table = wandb.Table(data=test_precision_data, columns=["class_name", "precision"])
else:
test_precision_table = wandb.Table()

## Recall Table
test_recall_data = [[name, prec] for (name, prec) in zip(class_names, t_recall)]
test_recall_table = wandb.Table(data=test_recall_data, columns=["class_name", "recall"])
if(n_classes != 2):
test_recall_data = [[name, prec] for (name, prec) in zip(class_names, t_recall)]
test_recall_table = wandb.Table(data=test_recall_data, columns=["class_name", "recall"])
else:
test_recall_table = wandb.Table()

test_confusion_matrix, saved_confmatrix = get_confusion_matrix(test_preds, test_labels, n_classes, class_names)

Expand Down

0 comments on commit 2beeb35

Please sign in to comment.