forked from SuperMedIntel/Medical-SAM-Adapter
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
94 lines (74 loc) · 2.72 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
""" train and test dataset
author jundewu
"""
import os
import sys
import pickle
import cv2
from skimage import io
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
import pandas as pd
from skimage.transform import rotate
from utils import random_click
import random
from monai.transforms import LoadImaged, Randomizable,LoadImage
class ISIC2016(Dataset):
    """Image/mask dataset indexed by an ISBI2016 ISIC Part3B ground-truth CSV.

    The CSV ``ISBI2016_ISIC_Part3B_<mode>_GroundTruth.csv`` is read from
    ``data_path``. Its second column is used as the image path and its third
    column as the mask path, both relative to ``data_path``.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        """Read the ground-truth CSV and remember the loading options.

        Args:
            args: Object exposing ``image_size``; masks are resized to
                ``(image_size, image_size)``.
            data_path: Directory containing the CSV and the files it lists.
            transform: Optional callable applied to the RGB PIL image.
            transform_msk: Optional callable applied to the resized mask.
            mode: Inserted into the CSV file name (default ``'Training'``).
            prompt: Prompt type; a click is sampled only for ``'click'``.
            plane: Accepted for interface compatibility; not used here.
        """
        csv_name = 'ISBI2016_ISIC_Part3B_' + mode + '_GroundTruth.csv'
        df = pd.read_csv(os.path.join(data_path, csv_name), encoding='gbk')
        self.name_list = df.iloc[:, 1].tolist()
        self.label_list = df.iloc[:, 2].tolist()
        self.data_path = data_path
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        """Return the number of rows read from the CSV."""
        return len(self.name_list)

    def __getitem__(self, index):
        """Load one image/mask pair and, for click prompts, a sampled click.

        Returns:
            A dict with keys ``'image'``, ``'label'``, ``'p_label'`` and
            ``'image_meta_dict'``. The ``'pt'`` key is present only when
            ``self.prompt == 'click'``.
        """
        # Both flags are fixed; they are forwarded to random_click as-is.
        inout = 1
        point_label = 1

        name = self.name_list[index]
        img_path = os.path.join(self.data_path, name)
        msk_path = os.path.join(self.data_path, self.label_list[index])
        newsize = (self.img_size, self.img_size)

        # convert() returns an independent image, so the underlying file can
        # be closed as soon as the conversion is done.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        with Image.open(msk_path) as raw_mask:
            # NOTE(review): resize() uses PIL's default resampling filter;
            # confirm that is intended for a segmentation mask.
            mask = raw_mask.convert('L').resize(newsize)

        pt = None
        if self.prompt == 'click':
            # The click is sampled on the resized, untransformed mask scaled
            # by 1/255. NOTE(review): random_click is defined in utils;
            # presumably it returns the click coordinates - confirm.
            pt = random_click(np.array(mask) / 255, point_label, inout)

        if self.transform:
            # Restore the torch RNG state after the image transform so the
            # mask transform starts from the same state the image one did.
            state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(state)

            # The mask transform only runs together with the image transform.
            if self.transform_msk:
                mask = self.transform_msk(mask)

        name = name.split('/')[-1].split(".jpg")[0]
        image_meta_dict = {'filename_or_obj': name}

        sample = {
            'image': img,
            'label': mask,
            'p_label': point_label,
        }
        # Previously 'pt' was referenced even when no click had been sampled,
        # which raised UnboundLocalError for any non-click prompt. The key is
        # now omitted in that case; a None value would not survive the
        # default DataLoader collate.
        if pt is not None:
            sample['pt'] = pt
        sample['image_meta_dict'] = image_meta_dict
        return sample