data.py
# -*- coding: utf-8 -*-
import os
import random

import torch.utils.data as data
from PIL import Image

from config import *  # expected to provide DATASET_BASE and CATEGORIES


class Fashion_attr_prediction(data.Dataset):
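    """DeepFashion category dataset wrapper.

    Modes (the "type" argument): "train", "test", and "all" yield
    (image, label-or-path) pairs; "triplet" yields (anchor, positive,
    negative) images for metric learning; "single" wraps the one image
    given by img_path.
    """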
def __init__(self, type="train", transform=None, target_transform=None, crop=False, img_path=None):
self.transform = transform
self.target_transform = target_transform
self.crop = crop
        # Valid modes: "train", "test", "all", "triplet", "single"
        self.type = type
        if type == "single":
            print("Image path is", img_path)
self.img_path = img_path
return
        self.train_list = []
        # category index -> list of training image paths, used for triplet sampling
        self.train_dict = {i: [] for i in range(CATEGORIES)}
        self.test_list = []
        self.all_list = []
        self.bbox = dict()
        self.anno = dict()
        self.read_partition_category()
        self.read_bbox()

    def __len__(self):
if self.type == "all":
return len(self.all_list)
elif self.type == "train":
return len(self.train_list)
elif self.type == "test":
return len(self.test_list)
else:
            return 1

    def read_partition_category(self):
list_eval_partition = os.path.join(DATASET_BASE, r'Eval', r'list_eval_partition.txt')
list_category_img = os.path.join(DATASET_BASE, r'Anno', r'list_category_img.txt')
partition_pairs = self.read_lines(list_eval_partition)
category_img_pairs = self.read_lines(list_category_img)
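        # Each annotation file starts with a record count and a header line
        # (both skipped by read_lines); the remaining rows are whitespace-
        # separated columns such as "<image_path> <category_label>".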
        for k, v in category_img_pairs:
            v = int(v)
            # Labels in the annotation file are 1-indexed; keep the first
            # CATEGORIES categories and shift them to 0-indexed.
            if v <= CATEGORIES:
                self.anno[k] = v - 1
        for k, v in partition_pairs:
            if k in self.anno:
                if v == "train":
                    self.train_list.append(k)
                    self.train_dict[self.anno[k]].append(k)
                else:
                    # "val" and "test" partitions are both used as test data
                    self.test_list.append(k)
        self.all_list = self.test_list + self.train_list
        random.shuffle(self.train_list)
        random.shuffle(self.test_list)
        random.shuffle(self.all_list)

    def read_bbox(self):
list_bbox = os.path.join(DATASET_BASE, r'Anno', r'list_bbox.txt')
pairs = self.read_lines(list_bbox)
        for k, x1, y1, x2, y2 in pairs:
            # Coordinates are read as strings; cast to int so the bounds
            # check in read_crop compares numbers rather than strings.
            self.bbox[k] = [int(x1), int(y1), int(x2), int(y2)]

    def read_lines(self, path):
        with open(path) as fin:
            # The first two lines of each annotation file are the record
            # count and the column header; skip both, then drop blank lines.
            lines = fin.readlines()[2:]
            lines = list(filter(lambda x: len(x.strip()) > 0, lines))
            pairs = list(map(lambda x: x.strip().split(), lines))
        return pairs

    def read_crop(self, img_path):
        img_full_path = os.path.join(DATASET_BASE, img_path)
        with open(img_full_path, 'rb') as f:
            with Image.open(f) as img:
                img = img.convert('RGB')
                if self.crop:
                    x1, y1, x2, y2 = self.bbox[img_path]
                    # Crop only when the box is well-formed and lies inside the image
                    if x1 < x2 <= img.size[0] and y1 < y2 <= img.size[1]:
                        img = img.crop((x1, y1, x2, y2))
        return img

    def __getitem__(self, index):
        if self.type == "triplet":
            img_path = self.train_list[index]
            target = self.anno[img_path]
            # Positive: another image from the anchor's category;
            # negative: a random image from a different category.
            img_p = random.choice(self.train_dict[target])
            img_n = random.choice(self.train_dict[random.choice(list(filter(lambda x: x != target, range(CATEGORIES))))])
            img = self.read_crop(img_path)
            img_p = self.read_crop(img_p)
            img_n = self.read_crop(img_n)
            if self.transform is not None:
                img = self.transform(img)
                img_p = self.transform(img_p)
                img_n = self.transform(img_n)
            return img, img_p, img_n
if self.type == "single":
img_path = self.img_path
img = self.read_crop(img_path)
if self.transform is not None:
img = self.transform(img)
return img
if self.type == "all":
img_path = self.all_list[index]
elif self.type == "train":
img_path = self.train_list[index]
else:
img_path = self.test_list[index]
target = self.anno[img_path]
img = self.read_crop(img_path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
        # In "all" mode return the image path (useful for retrieval indexing);
        # otherwise return the category label.
        return img, (img_path if self.type == "all" else target)
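

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it assumes
    # config provides DATASET_BASE/CATEGORIES pointing at a DeepFashion
    # checkout, and that torchvision is installed. The transform below is
    # an illustrative choice, not necessarily the one this repo uses.
    from torch.utils.data import DataLoader
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Classification-style loading: each batch is (images, labels)
    train_set = Fashion_attr_prediction(type="train", transform=transform, crop=True)
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)

    # Triplet mode: each item is (anchor, positive, negative)
    triplet_set = Fashion_attr_prediction(type="triplet", transform=transform, crop=True)
    anchor, positive, negative = triplet_set[0]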