Commit
Artur Jurgas committed on Jul 11, 2023 · 0 parents · commit 9b16e96
Showing 28 changed files with 4,013 additions and 0 deletions.
@@ -0,0 +1,2 @@
libhaa.egg-info
**/__pycache__
@@ -0,0 +1,13 @@
This repository contains the library used in the master's thesis "Deep learning algorithms for quality assessment of whole slide images".

It contains the base library for augmentation as of publication.

Additionally, code for the classification and segmentation modules is included.

For easy installation and reproduction of results, the `libhaa` library can be installed by running

```bash
pip install -e .
```

For segmentation inference, we provide the weights in the GitHub `Releases` panel.
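As a minimal, hypothetical sketch (not the repository's actual API), a checkpoint downloaded from the `Releases` panel could be inspected with plain PyTorch; the filename below is a placeholder:

```python
import torch

# Placeholder filename: substitute the checkpoint downloaded from GitHub Releases.
checkpoint = torch.load("segmentation_weights.ckpt", map_location="cpu")

# Lightning checkpoints keep the weights under "state_dict"; a plain state dict
# saved with torch.save can be used directly.
state_dict = checkpoint.get("state_dict", checkpoint)
print(f"Loaded {len(state_dict)} weight tensors")
```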
@@ -0,0 +1,51 @@
config = {'seed': 2021,
          'trainer': {
              'max_epochs': 1000,
              # 'gpus': 1,
              'accumulate_grad_batches': 1,
              # 'progress_bar_refresh_rate': 1,
              # 'fast_dev_run': False,
              # 'num_sanity_val_steps': 0,
              # 'resume_from_checkpoint': None,
              # 'default_root_dir': "/home/jarartur/Workspaces/HAAv2/temp/pl_logs",
              'log_every_n_steps': 50
          },
          'data': {
              'dataset_name': 'haav2',
              'batch_size': 256,
              'img_size': [224, 224],
              'num_workers': 14,
              # 'train_data_csv': '/home/jarartur/Workspaces/HAAv2/classification/tests.csv',
              'val_data_csv': '/net/pr2/projects/plgrid/plggmiadl/arjurgas/Datasets/HAAv2/ResNet_classification/bp_annot_val.csv',
              # 'val_data_csv': '/net/pr2/projects/plgrid/plggmiadl/arjurgas/Datasets/HAAv2/ResNet_classification/acr_val.csv',
              # 'test_data_csv': '/net/pr2/projects/plgrid/plggmiadl/arjurgas/Datasets/HAAv2/ResNet_classification/anh_annot.csv',
              # 'test_data_csv': '/net/pr2/projects/plgrid/plggmiadl/arjurgas/Datasets/HAAv2/ResNet_classification/acr_test_annot.csv',
              'test_data_csv': '/net/pr2/projects/plgrid/plggmiadl/arjurgas/Datasets/HAAv2/ResNet_classification/bp_annot_test.csv',
              # 'test_data_csv': '/net/pr2/projects/plgrid/plggmiadl/arjurgas/Datasets/HAAv2/ResNet_classification/acr_test.csv',
          },
          'model': {
              # 'backbone_init': {
              #     'model': 'efficientnet_v2_s_in21k',
              #     'nclass': 0,  # do not change this
              #     'pretrained': True,
              # },
              'optimizer_init': {
                  # 'class_path': 'torch.optim.SGD',
                  'class_path': 'torch.optim.AdamW',
                  'init_args': {
                      'lr': 0.01,
                      # 'momentum': 0.95,
                      # 'weight_decay': 0.0005
                  }
              },
              'lr_scheduler_init': {
                  # 'class_path': 'torch.optim.lr_scheduler.CosineAnnealingLR',
                  'class_path': 'torch.optim.lr_scheduler.ExponentialLR',
                  'init_args': {
                      # 'T_max': 0  # no need to change this
                      'gamma': 0.97
                  },
                  'step': 'epoch'
              }
          }
          }
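The `optimizer_init` and `lr_scheduler_init` entries follow the `class_path` / `init_args` convention consumed by `pytorch_lightning.cli.instantiate_class`, which the model module below uses in `configure_optimizers`. A minimal sketch of that mechanism; the toy parameter is purely illustrative and not part of the repository:

```python
import torch
from pytorch_lightning.cli import instantiate_class

# A toy parameter purely for illustration.
params = [torch.nn.Parameter(torch.zeros(3))]

# Builds torch.optim.AdamW(params, lr=0.01) from the dict above; the extra
# 'step' key in lr_scheduler_init is ignored by instantiate_class and is read
# separately by the model's configure_optimizers.
optimizer = instantiate_class(params, config['model']['optimizer_init'])
scheduler = instantiate_class(optimizer, config['model']['lr_scheduler_init'])
print(type(optimizer).__name__, type(scheduler).__name__)  # AdamW ExponentialLR
```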
@@ -0,0 +1,127 @@
from typing import Tuple, Type, Any

from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS

import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

from torchvision.models import ResNet50_Weights

import pandas as pd
from PIL import Image


class HAA_Dataset(Dataset):
    def __init__(self, root: str, transform: Type[Any]):
        super().__init__()
        self.root = root
        self.transform = transform

        self.data = pd.read_csv(root)
        self.targets = []

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        # Each CSV row is: image path followed by the binary labels.
        file_path, *labels = self.data.iloc[idx, :].to_list()
        img = Image.open(file_path)

        img = self.transform(img)
        labels = torch.tensor(labels, dtype=torch.float32)

        return img, labels


class HAA_DataModule(LightningDataModule):
    def __init__(self,
                 dataset_name: str,
                 img_size: Tuple[int, int] = (224, 224),
                 batch_size: int = 2,
                 num_workers: int = 4,
                 train_data_csv: str = 'data.csv',
                 val_data_csv: str = 'data.csv',
                 test_data_csv: str = 'data.csv',):
        """
        Base data module.
        :arg
            dataset_name: dataset name used in logging
            img_size: size the input images are resized to
            batch_size: batch size
            num_workers: number of dataloader workers
            train_data_csv / val_data_csv / test_data_csv: CSV files listing image paths and labels
        """
        super().__init__()
        self.save_hyperparameters()
        self.dataset = HAA_Dataset
        self.train_transform, self.test_transform = self.get_transforms()

    def prepare_data(self) -> None:
        data = pd.read_csv(self.hparams.train_data_csv)
        file_path, *labels = data.iloc[0, :].to_list()
        num_samples = data.shape[0]
        self.num_classes = len(labels)
        self.num_step = num_samples // self.hparams.batch_size

        # Per-class counts are turned into loss weights that down-weight the
        # most frequent classes.
        self.class_quant = []
        num_samples_nonempty = 0
        for i in range(self.num_classes):
            num = data.iloc[:, i + 1].sum()
            self.class_quant += [num]
            num_samples_nonempty += num
        self.class_weights = [1 - (quant / num_samples_nonempty) for quant in self.class_quant]

        print('-' * 50)
        print('* {} dataset class num: {}'.format(self.hparams.dataset_name, self.num_classes))
        print('* {} dataset class quantity: {}'.format(self.hparams.dataset_name, self.class_quant))
        print('* {} dataset class weights: {}'.format(self.hparams.dataset_name, self.class_weights))
        print('* {} dataset len: {}'.format(self.hparams.dataset_name, num_samples))
        print('-' * 50)

    def setup(self, stage: str = None):
        if stage in (None, 'fit'):
            self.train_ds = self.dataset(root=self.hparams.train_data_csv, transform=self.train_transform)
            self.valid_ds = self.dataset(root=self.hparams.val_data_csv, transform=self.test_transform)

        if stage in (None, 'test', 'predict'):  # not elif, so stage=None sets up every split
            self.test_ds = self.dataset(root=self.hparams.test_data_csv, transform=self.test_transform)

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return DataLoader(self.train_ds, batch_size=self.hparams.batch_size, shuffle=True, num_workers=self.hparams.num_workers)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.valid_ds, batch_size=self.hparams.batch_size, shuffle=False, num_workers=self.hparams.num_workers)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.test_ds, batch_size=self.hparams.batch_size, shuffle=False, num_workers=self.hparams.num_workers)

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.test_ds, batch_size=self.hparams.batch_size, shuffle=False, num_workers=self.hparams.num_workers)

    def get_transforms(self):

        # # as we will use a net pretrained on ImageNet:
        # mean = [0.485, 0.456, 0.406]
        # std = [0.229, 0.224, 0.225]

        size = self.hparams.img_size

        # Use the default ImageNet preprocessing bundled with the ResNet-50 weights.
        train = ResNet50_Weights.DEFAULT.transforms()
        test = ResNet50_Weights.DEFAULT.transforms()

        # train = transforms.Compose([
        #     transforms.Resize(size),
        #     transforms.Pad(4, padding_mode='reflect'),
        #     transforms.RandomCrop(size),
        #     transforms.RandomHorizontalFlip(),
        #     transforms.ToTensor(),
        #     transforms.Normalize(mean=mean, std=std)
        # ])
        # test = transforms.Compose([
        #     transforms.Resize(size),
        #     transforms.CenterCrop(size),
        #     transforms.ToTensor(),
        #     transforms.Normalize(mean=mean, std=std)
        # ])
        return train, test
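A minimal, hypothetical usage sketch of `HAA_DataModule`; the CSV file names are placeholders, and the expected layout (first column an image path, remaining columns 0/1 labels) follows the dataset code above:

```python
# Placeholder CSV paths; each row is: image_path, label_1, ..., label_N (0/1).
dm = HAA_DataModule(dataset_name='haav2',
                    img_size=(224, 224),
                    batch_size=8,
                    num_workers=4,
                    train_data_csv='train.csv',
                    val_data_csv='val.csv',
                    test_data_csv='test.csv')
dm.prepare_data()   # fills in num_classes, num_step and class_weights
dm.setup('fit')

images, labels = next(iter(dm.train_dataloader()))
print(images.shape, labels.shape)  # e.g. torch.Size([8, 3, 224, 224]) torch.Size([8, N])
```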
@@ -0,0 +1,189 @@
import torch
from torch import nn

from pytorch_lightning import LightningModule
from pytorch_lightning.cli import instantiate_class
from torchmetrics import MetricCollection, Accuracy, Recall, F1Score
from torchmetrics.classification import MultilabelROC, MultilabelAUROC

from torchvision.models import resnet50, ResNet50_Weights

import matplotlib.pyplot as plt
from palettable.tableau import TableauMedium_10
import mpltex

import os
import numpy as np


class BaseVisionSystem(LightningModule):
    def __init__(self, num_labels: int, num_step: int, class_weights: list, max_epochs: int,
                 optimizer_init: dict, lr_scheduler_init: dict):
        """ Define base vision classification system
        :arg
            num_labels: number of classes in the dataset
            num_step: number of steps per epoch
            class_weights: per-class weights for the BCE loss
            max_epochs: max number of epochs
            optimizer_init: optimizer class path and init args
            lr_scheduler_init: learning rate scheduler class path and init args
        """
        super().__init__()
        self.save_hyperparameters()

        # step 2. define model
        # self.backbone = torch.hub.load('hankyul2/EfficientNetV2-pytorch', **backbone_init)
        self.backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.fc = nn.Linear(self.backbone.fc.in_features, num_labels)
        self.backbone.fc = nn.Sequential()  # quick and dirty hack to keep compatibility with other models,
                                            # have a separate fc layer,
                                            # and use a lower lr for the rest of the model
        # self.fc = nn.Linear(self.backbone.out_channels, num_labels)

        # step 2.5 freeze the backbone
        self.backbone.requires_grad_(False)
        # self.backbone.layer4[-1].conv1.requires_grad_(True)
        # self.backbone.layer4[-1].bn1.requires_grad_(True)
        # self.backbone.layer4[-1].conv2.requires_grad_(True)
        # self.backbone.layer4[-1].bn2.requires_grad_(True)
        # self.backbone.layer4[-1].conv3.requires_grad_(True)
        # self.backbone.layer4[-1].bn3.requires_grad_(True)
        # self.backbone.layer4[-1].relu.requires_grad_(True)
        # I'm not sure if this is needed, but just in case, so as not to break the grad flow:
        self.backbone.fc.requires_grad_(True)
        self.fc.requires_grad_(True)
        # self.backbone.blocks[-1].requires_grad_(False)
        # self.fc.requires_grad_(True)

        print('-' * 50)
        for name, param in self.backbone.named_parameters():
            print(name, param.requires_grad)
        print('-' * 50)

        # step 3. define lr tools (optimizer, lr scheduler)
        self.optimizer_init_config = optimizer_init
        self.lr_scheduler_init_config = lr_scheduler_init

        class_weights = torch.tensor(self.hparams.class_weights, dtype=torch.float32)
        self.criterion = nn.BCEWithLogitsLoss(weight=class_weights)

        # step 4. define metrics
        metrics = MetricCollection({
            'accuracy': Accuracy(task='multilabel', num_labels=num_labels),
            'recall': Recall(task='multilabel', num_labels=num_labels),
            'f1': F1Score(task='multilabel', num_labels=num_labels),
        })
        self.train_metric = metrics.clone(prefix='train/')
        self.valid_metric = metrics.clone(prefix='valid/')
        self.test_metric = metrics.clone(prefix='test/')

    def forward(self, x):
        return self.fc(self.backbone(x))

    def on_test_start(self):
        self.roc = MultilabelROC(num_labels=self.hparams.num_labels, thresholds=None)
        self.auc = MultilabelAUROC(num_labels=self.hparams.num_labels, average='none')
        self.preds_accum = []
        self.targets_accum = []

    def on_test_end(self):
        preds = torch.cat(self.preds_accum)
        targets = torch.cat(self.targets_accum)
        fpr, tpr, thresholds = self.roc(preds, targets.int())
        auc = self.auc(preds, targets.int())
        save_path = self.logger.log_dir
        num_labels = self.hparams.num_labels

        print('-' * 50)
        print(f'fpr: {fpr[0].shape}')
        print(f'tpr: {tpr[0].shape}')
        print('-' * 50)

        labels = ['air', 'dust', 'tissue', 'ink', 'marker', 'focus']

        # The per-class ROC curves have different lengths, so truncate them all
        # to the shortest one before averaging.
        lowest_shape = 10_000_000
        for f in fpr:
            print(f.shape[0])
            if lowest_shape > f.shape[0]:
                lowest_shape = f.shape[0]

        fpr = [f[:lowest_shape] for f in fpr]
        tpr = [t[:lowest_shape] for t in tpr]

        fpr_mean = torch.stack(fpr).mean(dim=0)
        tpr_mean = torch.stack(tpr).mean(dim=0)
        auc_mean = torch.mean(auc)

        plot(fpr, tpr, save_path, num_labels, labels, auc, auc_mean, fpr_mean, tpr_mean)

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        return self.shared_step(batch, self.train_metric, 'train', add_dataloader_idx=False)

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        return self.shared_step(batch, self.valid_metric, 'valid', add_dataloader_idx=True)

    def test_step(self, batch, batch_idx, dataloader_idx=None):
        return self.shared_step(batch, self.test_metric, 'test', add_dataloader_idx=True)

    def shared_step(self, batch, metric, mode, add_dataloader_idx):
        x, y = batch
        loss, y_hat = self.compute_loss(x, y) if mode == 'train' else self.compute_loss_eval(x, y)
        self.log_dict({f'{mode}/loss': loss}, add_dataloader_idx=add_dataloader_idx)

        metrics = metric(y_hat, y)
        self.log_dict(metrics, add_dataloader_idx=add_dataloader_idx, prog_bar=True)

        if mode == 'valid':
            self.log("hp_metric", metrics['valid/f1'])

        if mode == 'test':
            self.preds_accum += [y_hat.detach().cpu()]
            self.targets_accum += [y.detach().cpu()]

        return loss

    def compute_loss(self, x, y):
        return self.compute_loss_eval(x, y)

    def compute_loss_eval(self, x, y):
        y_hat = self.fc(self.backbone(x))
        loss = self.criterion(y_hat, y)
        return loss, y_hat

    def configure_optimizers(self):
        # The frozen backbone gets a 10x lower learning rate than the new fc head.
        optimizer = instantiate_class([
            {'params': self.backbone.parameters(), 'lr': self.optimizer_init_config['init_args']['lr'] * 0.1},
            {'params': self.fc.parameters()},
        ], self.optimizer_init_config)

        lr_scheduler = {
            'scheduler': instantiate_class(optimizer, self.update_and_get_lr_scheduler_config()),
            'interval': self.lr_scheduler_init_config['step']
        }
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

    def update_and_get_lr_scheduler_config(self):
        if 'T_max' in self.lr_scheduler_init_config['init_args']:
            self.lr_scheduler_init_config['init_args']['T_max'] = self.hparams.num_step * self.hparams.max_epochs
        return self.lr_scheduler_init_config


@mpltex.acs_decorator
def plot(fpr, tpr, save_path, num_labels, labels, auc, auc_mean, fpr_mean, tpr_mean):

    fig, ax = plt.subplots(1, 1, figsize=(5.3, 3))
    linestyles = mpltex.linestyle_generator(colors=TableauMedium_10.mpl_colors)

    for i in range(num_labels):
        ax.plot(fpr[i], tpr[i], linewidth=1, label=rf'$ {labels[i]}-{auc[i]:.3f} $', **next(linestyles), markevery=1500)

    ax.plot(fpr_mean, tpr_mean, linewidth=3, label=rf'$ all\ classes-{auc_mean:.3f} $', markevery=1500)

    ax.set_xlabel(r"$ False\ Positive\ Rate $")
    ax.set_ylabel(r"$ True\ Positive\ Rate $")
    # plt.title("ROC")
    ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
    # ax.legend()
    fig.tight_layout(pad=0.1)
    fig.savefig(os.path.join(save_path, "ROC.png"))
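A hypothetical end-to-end sketch of how these pieces fit together; `config`, `HAA_DataModule`, and `BaseVisionSystem` come from the files above, but this entry point is illustrative and is not part of the commit:

```python
import pytorch_lightning as pl

pl.seed_everything(config['seed'])

# Assumes config['data'] points at existing CSV files.
dm = HAA_DataModule(**config['data'])
dm.prepare_data()  # fills in num_classes, num_step and class_weights

model = BaseVisionSystem(num_labels=dm.num_classes,
                         num_step=dm.num_step,
                         class_weights=dm.class_weights,
                         max_epochs=config['trainer']['max_epochs'],
                         **config['model'])

trainer = pl.Trainer(**config['trainer'])
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)  # also writes ROC.png into the logger's log_dir
```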