-
Notifications
You must be signed in to change notification settings - Fork 0
/
DataProvider.py
81 lines (69 loc) · 2.8 KB
/
DataProvider.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from torch.utils.data import Dataset, DataLoader
import Transforms_utils
import config
import warnings
warnings.filterwarnings('ignore')
import os
import scipy
from scipy import io
import skimage
import matplotlib.pyplot as plt
import pdb
from skimage.transform import resize
from skimage.segmentation import slic
from skimage.segmentation import mark_boundaries
from skimage.util import img_as_float
from skimage import color
import pdb
import PIL
from sklearn import cluster
import pdb
import glob
class DataProvider(Dataset):
    """Map-style dataset pairing ``*.jpg`` images with ``*.png`` label maps.

    For the chosen split, every ``*.jpg`` file in the split's image directory
    becomes one sample.  Its ground-truth map is looked up in the split's
    pixel-map directory under the same file stem with a ``.png`` extension.
    Both directories are read from ``config.data_ade``.

    Args:
        transformList: Optional list of transforms.  When non-empty, each
            sample dict is wrapped in ``Transforms_utils.transformImage`` and
            that wrapper is called with this list.
        data: Split to load, ``'train'`` or ``'val'`` (default ``'val'``).

    Raises:
        ValueError: If ``data`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, transformList=None, data='val'):
        self.data = data
        if data == 'train':
            self.image_paths = config.data_ade.trainData
            self.gt_map_paths = config.data_ade.train_pixel_map
        elif data == 'val':
            self.image_paths = config.data_ade.valData
            self.gt_map_paths = config.data_ade.val_pixel_map
        else:
            # Fail fast: any other value would leave image_paths unset and
            # surface later as a confusing AttributeError.
            raise ValueError(
                "data must be 'train' or 'val', got {!r}".format(data))
        # glob's result order depends on the filesystem; sort it so that the
        # index -> file mapping is reproducible across runs and machines.
        self.image_name_list = sorted(
            glob.glob(os.path.join(self.image_paths, '*.jpg')))
        self.transformList = transformList

    def __len__(self):
        """Return the number of images found for this split."""
        return len(self.image_name_list)

    @staticmethod
    def _gt_name_for(image_name, gt_dir):
        """Return the ground-truth map path for ``image_name``.

        The map lives in ``gt_dir`` and shares the image's file stem, with a
        ``.png`` extension (``imgs/img_001.jpg`` -> ``gt_dir/img_001.png``).
        """
        # basename/splitext understand every separator valid on the current
        # platform, unlike splitting on os.sep alone.
        stem = os.path.splitext(os.path.basename(image_name))[0]
        return os.path.join(gt_dir, stem + '.png')

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Builds ``{'image': <RGB PIL image>, 'gt': <PIL label map>,
        'segments': None}``.  If ``transformList`` is set, the result of the
        transform pipeline applied to that dict is returned instead.
        """
        im_name = self.image_name_list[idx]
        gt_name = self._gt_name_for(im_name, self.gt_map_paths)
        # NOTE(review): ``import PIL`` alone does not import the PIL.Image
        # submodule; this presumably works only because another import
        # (e.g. matplotlib) already loaded it -- consider ``import PIL.Image``.
        image = PIL.Image.open(im_name).convert('RGB')
        gt = PIL.Image.open(gt_name)
        # Image.open is lazy.  Decode now so a truncated/corrupt map fails
        # here rather than deep inside the transforms, and so Pillow closes
        # the file handle (it does so after load() for single-frame images).
        gt.load()
        sample = {'image': image, 'gt': gt, 'segments': None}
        if self.transformList:
            trans = Transforms_utils.transformImage(sample)
            sample = trans(self.transformList)
        return sample
class DataProviderUtil():
    """Build the 'train' and 'val' DataLoaders over ``DataProvider``.

    Both splits share one transform list:
    ``RandomHorizontalFlip(p=0.5)``, ``RandomVerticalFlip(p=0.5)``,
    ``Rescale(256)``, ``RandomCrop(224)``,
    ``imagespixels(n_segments=num_spix, sigma=2)`` (presumably superpixel
    segmentation) and ``ToTensor()``.

    Args:
        num_spix: ``n_segments`` passed to ``Transforms_utils.imagespixels``.
        batch_size: Batch size used by both loaders.
        num_workers: Worker processes per loader (default 8).
        pin_memory: Forwarded to ``DataLoader`` (default True).
        shuffle_val: Whether the 'val' loader shuffles (default True).  The
            'train' loader always shuffles.

    Attributes:
        transformList: The transform list shared by both splits.
        getData: Dict mapping ``'train'``/``'val'`` to its ``DataLoader``.
    """

    def __init__(self, num_spix=500, batch_size=3, num_workers=8,
                 pin_memory=True, shuffle_val=True):
        self.transformList = [
            Transforms_utils.RandomHorizontalFlip(p=0.5),
            Transforms_utils.RandomVerticalFlip(p=0.5),
            Transforms_utils.Rescale(256),
            Transforms_utils.RandomCrop(224),
            Transforms_utils.imagespixels(n_segments=num_spix, sigma=2),
            Transforms_utils.ToTensor(),
        ]
        # NOTE(review): 'val' reuses the random flips/crop above, so its
        # inputs change from epoch to epoch -- confirm that is intended.
        shuffle_by_split = {'train': True, 'val': shuffle_val}
        # Built in the order 'train', then 'val'.
        self.getData = {
            split: DataLoader(
                DataProvider(transformList=self.transformList, data=split),
                batch_size=batch_size,
                shuffle=shuffle_by_split[split],
                num_workers=num_workers,
                pin_memory=pin_memory,
            )
            for split in ('train', 'val')
        }
if __name__ == '__main__':
    # Visual sanity check: for each training batch, show the first image,
    # its ground-truth map, and the image with superpixel boundaries drawn.
    # plt.show() blocks until the window is closed (non-interactive mode).
    inst = DataProviderUtil()
    for sample in inst.getData['train']:
        # Channels-first tensor -> channels-last array for imshow; computed
        # once and reused for both the raw view and the boundary overlay.
        image = sample['image'][0].numpy().transpose(1, 2, 0)
        gt = sample['gt'][0].numpy()
        segments = sample['segments'][0].numpy()
        plt.subplot(1, 3, 1)
        plt.imshow(image)
        plt.subplot(1, 3, 2)
        plt.imshow(gt)
        plt.subplot(1, 3, 3)
        plt.imshow(mark_boundaries(image, segments))
        plt.show()