From 75ea65325e4a9ba91f36879e217230d4f73a260c Mon Sep 17 00:00:00 2001
From: Ramith Hettiarachchi
Date: Sat, 11 Mar 2023 00:00:14 -0500
Subject: [PATCH] organize metrics

---
 modules/dataloaders.py |  4 ++--
 modules/test_utils.py  | 38 +++++++++++++++++++-------------------
 2 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/modules/dataloaders.py b/modules/dataloaders.py
index 017b7f9..1a9f398 100644
--- a/modules/dataloaders.py
+++ b/modules/dataloaders.py
@@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
 
 
-def get_bacteria_dataloaders(img_size, train_batch_size ,torch_seed=10, label_type = "class", balanced_mode = False, expand_channels = False, data_dir= '/n/holyscratch01/wadduwage_lab/D2NN_QPM_classification/datasets/bacteria_np'):
+def get_bacteria_dataloaders(img_size, train_batch_size ,torch_seed=10, label_type = "class", balanced_mode = False, expand_channels = False, data_dir= '/n/holyscratch01/wadduwage_lab/ramith/bacteria_processed'):
     '''
     Function to return train, validation QPM dataloaders
     Args:
@@ -56,7 +56,7 @@ def get_bacteria_dataloaders(img_size, train_batch_size ,torch_seed=10, label_ty
     dataset_sizes = {'train': len(train_loader)*train_batch_size, 'val': len(val_loader)*32, 'test': len(test_loader)*128}
     return train_loader, val_loader, test_loader, dataset_sizes
 
-def get_bacteria_eval_dataloaders(img_size, test_batch_size ,torch_seed=10, label_type = "class", expand_channels = False, data_dir= '/n/holyscratch01/wadduwage_lab/D2NN_QPM_classification/datasets/bacteria_np', isolate_class = False):
+def get_bacteria_eval_dataloaders(img_size, test_batch_size ,torch_seed=10, label_type = "class", expand_channels = False, data_dir= '/n/holyscratch01/wadduwage_lab/ramith/bacteria_processed', isolate_class = False):
     '''
     Function to return train, validation QPM dataloaders
     Args:
diff --git a/modules/test_utils.py b/modules/test_utils.py
index 133e6c3..2238aef 100644
--- a/modules/test_utils.py
+++ b/modules/test_utils.py
@@ -146,21 +146,21 @@ def test_model_in_groups(model, data, criterion, n_classes = 0, device = 'cpu',
 
-    if(n_classes == 2): ## Calculate *binary* classification metrics
-        test_accuracy = Accuracy(task="binary", average = None, num_classes = n_classes, compute_on_step=False).to(device)
-
-        test_f1 = F1Score(task="binary", compute_on_step=False).to(device)
-        test_precision = Precision(task="binary", compute_on_step=False).to(device)
-        test_recall = Recall(task="binary", compute_on_step=False).to(device)
-        test_specificity = Specificity(task="binary", compute_on_step=False).to(device)
-        # test_auroc = AUROC(task="binary").to(device)
-    else:
-        test_accuracy = Accuracy(task="multiclass", average = None, num_classes = n_classes, compute_on_step=False).to(device)
-
-        test_f1 = F1Score(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
-        test_precision = Precision(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
-        test_recall = Recall(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
-        test_specificity = Specificity(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
+    # if(n_classes == 2): ## Calculate *binary* classification metrics
+    #     test_accuracy = Accuracy(task="binary", average = None, num_classes = n_classes, compute_on_step=False).to(device)
+
+    #     test_f1 = F1Score(task="binary", compute_on_step=False).to(device)
+    #     test_precision = Precision(task="binary", compute_on_step=False).to(device)
+    #     test_recall = Recall(task="binary", compute_on_step=False).to(device)
+    #     test_specificity = Specificity(task="binary", compute_on_step=False).to(device)
+    #     # test_auroc = AUROC(task="binary").to(device)
+    # else:
+    test_accuracy = Accuracy(task="multiclass", average = None, num_classes = n_classes, compute_on_step=False).to(device)
+
+    test_f1 = F1Score(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
+    test_precision = Precision(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
+    test_recall = Recall(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
+    test_specificity = Specificity(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
     # test_auroc = MulticlassAUROC(num_classes = n_classes, average="macro").to(device)
 
     test_preds = torch.empty([0, ])
@@ -250,10 +250,10 @@ def test_model_in_groups(model, data, criterion, n_classes = 0, device = 'cpu',
     # print("Sklearn ROC AUC (ovo)", roc_auc_score(test_labels_.to(dtype = torch.int32), pred_probs, multi_class= 'ovo'))
 
     t_acc = test_accuracy.compute().tolist()
-    t_f1 = float(test_f1.compute()) if n_classes == 2 else test_f1.compute().tolist()
-    t_precision = float(test_precision.compute()) if n_classes == 2 else test_precision.compute().tolist()
-    t_recall = float(test_recall.compute()) if n_classes == 2 else test_recall.compute().tolist()
-    t_specificity = float(test_specificity.compute()) if n_classes == 2 else test_specificity.compute().tolist()
+    t_f1 = float(test_f1.compute()) if n_classes == 1 else test_f1.compute().tolist()
+    t_precision = float(test_precision.compute()) if n_classes == 1 else test_precision.compute().tolist()
+    t_recall = float(test_recall.compute()) if n_classes == 1 else test_recall.compute().tolist()
+    t_specificity = float(test_specificity.compute()) if n_classes == 1 else test_specificity.compute().tolist()
 
     print("test accuracy",t_acc)
     print("test f1",t_f1)
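
Note on the metric reorganization above: after this patch, test_model_in_groups always builds the torchmetrics multiclass variants with average = None, so every compute() returns a per-class tensor, and the binary special case in the reporting code is effectively disabled (the guard changes from n_classes == 2 to n_classes == 1, which the multiclass metrics never take). The following is a minimal, hypothetical sketch of the resulting metric flow, not part of the patch; it assumes torchmetrics >= 0.11 (where the task= argument exists) and omits the deprecated compute_on_step argument. The n_classes value and the random batches are placeholders.

    # Hypothetical sketch (not from the patch): per-class metric flow mirroring
    # the surviving multiclass branch. Assumes torchmetrics >= 0.11; the
    # deprecated compute_on_step argument is omitted.
    import torch
    from torchmetrics import Accuracy, F1Score, Precision, Recall, Specificity

    n_classes = 5   # placeholder class count
    device = "cpu"

    # average=None makes each metric report one score per class
    metrics = {
        "accuracy":    Accuracy(task="multiclass", num_classes=n_classes, average=None).to(device),
        "f1":          F1Score(task="multiclass", num_classes=n_classes, average=None).to(device),
        "precision":   Precision(task="multiclass", num_classes=n_classes, average=None).to(device),
        "recall":      Recall(task="multiclass", num_classes=n_classes, average=None).to(device),
        "specificity": Specificity(task="multiclass", num_classes=n_classes, average=None).to(device),
    }

    # Accumulate over batches; random tensors stand in for model logits/labels
    for _ in range(3):
        logits = torch.randn(32, n_classes)
        labels = torch.randint(0, n_classes, (32,))
        for metric in metrics.values():
            metric.update(logits.to(device), labels.to(device))

    # compute() yields a per-class tensor, so .tolist() (as in the patched
    # reporting code) always produces a list of per-class scores
    for name, metric in metrics.items():
        print(name, metric.compute().tolist())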