# inference.py
import os
import cv2
import torch
import pickle
import numpy as np
from PIL import Image
import albumentations as A
from models import Transformer
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from albumentations.pytorch import ToTensorV2
from utils.cam_helper import get_swin_cam
from pytorch_grad_cam.utils.image import show_cam_on_image
class TCGADataset4Inf(Dataset):
    """Inference dataset spanning both the train and test parts of one split.

    The image paths and grade labels stored under ``data['train']`` and
    ``data['test']`` are concatenated (train entries first), so a single pass
    over the dataset visits every sample of both parts.
    """

    def __init__(self, data, image_size=1024):
        """Gather paths and labels and build the fixed preprocessing pipeline.

        Args:
            data: Mapping with ``'train'`` and ``'test'`` entries, each
                providing ``'x_path'`` (image paths) and ``'grade'`` (labels).
            image_size: Side length, in pixels, every image is resized to.
        """
        def _merged(key):
            # Train entries first, then test entries.
            return np.concatenate([data[part][key] for part in ('train', 'test')])

        self.img = _merged('x_path')
        self.grade = _merged('grade')
        # One class per distinct grade value found across both parts.
        self.num_classes = len(set(self.grade))

        # Deterministic pipeline only: resize, normalise, convert to tensor.
        resize = A.Resize(height=image_size, width=image_size)
        normalize = A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        self.test_transform = A.Compose([resize, normalize, ToTensorV2()])

    def __len__(self):
        """Return the number of samples (one per grade label)."""
        return len(self.grade)

    def __getitem__(self, index):
        """Return ``(image_tensor, grade)`` for the sample at ``index``."""
        rgb_pixels = np.array(Image.open(self.img[index]).convert('RGB'))
        label = torch.tensor(self.grade[index]).long()
        transformed = self.test_transform(image=rgb_pixels)
        return transformed['image'], label
def save_img(img, cam, root, label, idx, suffix):
    """Save an image, its CAM overlay and the CAM-masked regions as JPEGs.

    Files are written to ``<root>/grade_<label>/`` under the names
    ``{idx}_cam_{suffix}.jpg``, ``{idx}_pos_{suffix}.jpg``,
    ``{idx}_neg_{suffix}.jpg`` and ``{idx}_img.jpg``. Only the first element
    along the leading (batch) axis of ``img``, ``cam`` and ``label`` is used.

    Args:
        img: Image tensor, permuted here from channels-second to
            channels-last. Assumed to be (batch, channel, H, W) in RGB
            channel order, as produced by the dataset in this file.
        cam: Activation-map tensor. Assumed to be (batch, H, W) with values
            in [0, 1] -- TODO confirm against ``get_swin_cam``.
        root: Directory under which the per-grade folder is created.
        label: Tensor whose first element names the ``grade_<label>`` folder.
        idx: Sample index used as the filename prefix.
        suffix: Tag appended to the cam/pos/neg filenames.
    """
    # define the path
    grade = label.detach().cpu().numpy()[0]
    out_dir = os.path.join(root, f"grade_{grade}")
    os.makedirs(out_dir, exist_ok=True)

    # to numpy: first sample, channels-last, min-max rescaled to [0, 1]
    img = img.permute(0, 2, 3, 1).detach().cpu().numpy()[0]
    value_range = np.ptp(img)
    if value_range > 0:
        img = (img - np.min(img)) / value_range
    else:
        # A constant image would give 0/0 = NaN; fall back to an all-zero image.
        img = np.zeros_like(img)
    # cv2.imwrite interprets the last axis as BGR and show_cam_on_image with
    # use_rgb=False builds a BGR heatmap, so flip the RGB input to BGR once
    # here; otherwise red and blue are swapped in every file written below.
    img = np.ascontiguousarray(img[..., ::-1])

    cam = cam.detach().cpu().numpy()[0]
    # Pixels whose activation exceeds 0.5 form the "positive" region.
    pixel_mask = (cam > 0.5)[..., np.newaxis]
    pos_region = img * pixel_mask
    neg_region = img * ~pixel_mask
    img_with_cam = show_cam_on_image(img, cam, use_rgb=False)

    # save the image
    cv2.imwrite(os.path.join(out_dir, f"{idx}_cam_{suffix}.jpg"), img_with_cam)
    cv2.imwrite(os.path.join(out_dir, f"{idx}_pos_{suffix}.jpg"), np.uint8(255 * pos_region))
    cv2.imwrite(os.path.join(out_dir, f"{idx}_neg_{suffix}.jpg"), np.uint8(255 * neg_region))
    cv2.imwrite(os.path.join(out_dir, f"{idx}_img.jpg"), np.uint8(255 * img))
def main():
    """Run CAM inference for fold 0 and write the visualisations to disk.

    Loads the first split from ``./dataset/my_split_dropGradeNaN.pkl``,
    restores the model weights from ``./weights/fold_0.pth``, and for every
    sample computes a class activation map for its ground-truth grade plus the
    model prediction. Results are saved under ``./results/Ours`` via
    ``save_img``, tagged with whether the prediction matched the grade.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    save_path = "./results/Ours"
    os.makedirs(save_path, exist_ok=True)

    # load data file
    # NOTE(review): pickle can execute arbitrary code while loading; only
    # point this at a split file from a trusted source.
    with open("./dataset/my_split_dropGradeNaN.pkl", 'rb') as split_file:
        data_cv = pickle.load(split_file)
    data_cv_split = data_cv['splits'][0]

    # dataset
    dataset = TCGADataset4Inf(data_cv_split)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    # model init
    model = Transformer(image_size=1024, num_classes=dataset.num_classes, pretrained="WinKawaks/vit-tiny-patch16-224", patch_size=16)

    # load model
    # Load onto the CPU first so the checkpoint does not depend on the device
    # it was saved from; the model is moved to the GPU right after.
    checkpoint = torch.load("./weights/fold_0.pth", map_location="cpu")
    model.load_state_dict(checkpoint)
    model = model.cuda()
    # Inference only: switch dropout / normalisation layers to eval behaviour.
    model.eval()

    # inference
    for idx, (img, grade) in enumerate(loader):
        print(f"\rProcessing {idx}th image", end='', flush=True)
        img, grade = img.cuda(non_blocking=True), grade.cuda(non_blocking=True)
        # Kept outside no_grad on the assumption that the CAM helper needs
        # gradients -- TODO confirm against utils.cam_helper.
        cam = get_swin_cam(model, img, grade, smooth=True)
        # The plain prediction does not need an autograd graph.
        with torch.no_grad():
            _, pred, _ = model(img)
        # argmax over the raw scores equals argmax over their softmax.
        pred = pred.argmax(dim=1)
        suffix = pred.eq(grade).item()
        save_img(img, cam, save_path, grade, idx, suffix)
if __name__ == "__main__":
    # Set CUDA_VISIBLE_DEVICES to "0" before main() makes any CUDA call, so
    # only the first GPU is visible to this process.
    os.environ.update(CUDA_VISIBLE_DEVICES="0")
    main()