-
Notifications
You must be signed in to change notification settings - Fork 0
/
simple_folder_dataloader.py
93 lines (80 loc) · 3.33 KB
/
simple_folder_dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os, sys
import importlib
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import cv2
from glob import glob
class SimpleFolderDataset(Dataset):
    """
    Images and labels are arranged in different folders and are paired by
    position after sorting, so the sorted file names must correspond:
        dataroot_img/
            000.jpg
            001.jpg
            ...
        dataroot_lbl/
            000_lbl.png
            001_lbl.png
            ...
    Only files directly inside each folder are collected (no recursion).

    Args:
        opt_dataset: dict with keys ``dataroot_img``, ``dataroot_lbl``,
            ``phase`` ('train' or 'valid'), ``img_exts`` / ``lbl_exts``
            (lists of file extensions without the leading dot) and
            ``augment`` (a dict whose ``augment_type`` names a module under
            ``data_augment``; the remaining entries are passed as keyword
            arguments to that module's ``train_augment`` / ``val_augment``).
            The dict is not modified.

    Raises:
        ValueError: if ``phase`` is not 'train' or 'valid', or if the number
            of images and labels found differs.

    Returns (per item):
        {"img": image, "label": label, "img_path": str, "lbl_path": str}
    """

    def __init__(self, opt_dataset) -> None:
        super().__init__()
        # parse used arguments, explicit parsing is easier for debug
        self.dataroot_img = opt_dataset['dataroot_img']
        self.dataroot_lbl = opt_dataset['dataroot_lbl']
        self.phase = opt_dataset['phase']
        self.img_exts = opt_dataset['img_exts']  # list, 'cause input may be different formats
        self.lbl_exts = opt_dataset['lbl_exts']

        # Work on a copy: popping from the caller's dict would make a second
        # construction from the same config fail with KeyError.
        augment_opt = dict(opt_dataset['augment'])
        augment_type = augment_opt.pop('augment_type')
        if self.phase == 'train':
            factory_name = 'train_augment'
        elif self.phase == 'valid':
            factory_name = 'val_augment'
        else:
            raise ValueError(
                f"Unsupported phase {self.phase!r}; expected 'train' or 'valid'"
            )
        augment_module = importlib.import_module(f'data_augment.{augment_type}')
        self.augment = getattr(augment_module, factory_name)(**augment_opt)

        # collect images and labels directly under the given folders
        self.img_paths = self._collect_paths(self.dataroot_img, self.img_exts)
        self.lbl_paths = self._collect_paths(self.dataroot_lbl, self.lbl_exts)
        # Pairing is positional, so differing counts mean misaligned pairs.
        if len(self.img_paths) != len(self.lbl_paths):
            raise ValueError(
                f"Found {len(self.img_paths)} images in {self.dataroot_img!r} "
                f"but {len(self.lbl_paths)} labels in {self.dataroot_lbl!r}"
            )

    @staticmethod
    def _collect_paths(root, exts):
        """Return the sorted paths of files directly under ``root`` whose
        extension is one of ``exts``."""
        from glob import escape  # root may contain glob metacharacters like '['

        pattern_root = escape(root)
        paths = []
        for ext in exts:
            paths.extend(glob(os.path.join(pattern_root, f'*.{ext}')))
        return sorted(paths)

    def __getitem__(self, index):
        cur_img_path = self.img_paths[index]
        cur_lbl_path = self.lbl_paths[index]
        # cv2.imread loads color images in BGR channel order.
        cur_img = cv2.imread(cur_img_path)
        # TODO: currently only support 1-channel int map
        # should support color label map handling also
        cur_lbl = cv2.imread(cur_lbl_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None rather than raising.
        if cur_img is None:
            raise OSError(f"Failed to read image: {cur_img_path}")
        if cur_lbl is None:
            raise OSError(f"Failed to read label: {cur_lbl_path}")
        img_lbl_aug = self.augment(image=cur_img, mask=cur_lbl)
        # The augment output's mask must support .to(dtype), i.e. be a torch
        # tensor; it is cast to int64 class indices.
        img_aug = img_lbl_aug['image']
        lbl_aug = img_lbl_aug['mask'].to(torch.int64)
        # simple dataset cannot cover mixup/contrast etc. which need 2 or more images to return
        output_dict = {
            "img": img_aug,
            "label": lbl_aug,
            "img_path": cur_img_path,
            "lbl_path": cur_lbl_path,
        }
        return output_dict

    def __len__(self):
        return len(self.img_paths)
def SimpleFolderDataloader(opt_dataloader):
    """Build a DataLoader over a SimpleFolderDataset for the configured phase.

    Args:
        opt_dataloader: options dict. ``phase`` must be 'train' or 'valid';
            for 'train', ``batch_size`` and ``num_workers`` are also read.
            The same dict is passed to ``SimpleFolderDataset``, so it must
            also carry that class's keys.

    Returns:
        torch.utils.data.DataLoader: for 'train', shuffled with the configured
        batch size and worker count; for 'valid', unshuffled with batch size 1
        and no worker processes. Both use ``pin_memory=True`` and
        ``drop_last=True``.

    Raises:
        ValueError: if ``phase`` is not 'train' or 'valid'.
    """
    phase = opt_dataloader['phase']
    if phase == 'train':
        batch_size = opt_dataloader['batch_size']
        num_workers = opt_dataloader['num_workers']
        shuffle = True
    elif phase == 'valid':
        # one sample at a time, loaded in the main process, fixed order
        batch_size = 1
        num_workers = 0
        shuffle = False
    else:
        # Fail before building the dataset instead of hitting an unbound
        # batch_size afterwards.
        raise ValueError(
            f"Unsupported phase {phase!r}; expected 'train' or 'valid'"
        )
    folder_dataset = SimpleFolderDataset(opt_dataloader)
    dataloader = DataLoader(
        folder_dataset,
        batch_size=batch_size,
        pin_memory=True,
        drop_last=True,  # no-op for 'valid' since batch_size is 1
        shuffle=shuffle,
        num_workers=num_workers,
    )
    return dataloader