-
Notifications
You must be signed in to change notification settings - Fork 3
/
data_loader.py
71 lines (61 loc) · 2.31 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# -*- coding: utf-8 -*-
"""
Created on Wed Nov 27 22:54:19 2019
@author: aithlab
"""
import os
import pickle
import torch
class MNISTStroke(torch.utils.data.Dataset):
    """Sequence dataset built from per-sample (dx, dy, eos, eod) records.

    The raw pickle at ``<root_path>/processed/<mode>`` stores, for every
    sample, an ``'input'`` iterable of ``(dx, dy, eos, eod)`` entries and a
    ``'label'``.  The offsets are accumulated into absolute positions and the
    result is cached at ``<root_path>/sequenced/<mode>.pkl`` so later
    constructions only need to load the cache.

    NOTE(review): ``eos``/``eod`` presumably mean end-of-stroke / end-of-digit;
    only ``eod`` is used here (it terminates the sequence).

    NOTE(security): both files are read with ``pickle``; only point
    ``root_path`` at data you trust.

    Args:
        root_path: Directory containing the ``processed`` folder (and, once
            built, the ``sequenced`` cache folder).
        train: Selects the ``'train'`` split when true, ``'test'`` otherwise.
    """

    def __init__(self, root_path='./', train=True):
        self.mode = 'train' if train else 'test'
        if self.preprocessingcCheck(root_path):
            cache_path = os.path.join(root_path, "sequenced",
                                      "%s.pkl" % self.mode)
            with open(cache_path, 'rb') as f:
                self.data = pickle.load(f)
        else:
            raw_path = os.path.join(root_path, "processed", self.mode)
            # Context manager so the handle is closed even if unpickling fails.
            with open(raw_path, 'rb') as f:
                data = pickle.load(f)
            self.data = self.preprocessing(data, root_path)

    def __getitem__(self, idx):
        """Return ``(sequence, sequence_length, label)`` for entry ``idx``."""
        sample = self.data[idx]
        data = sample['input']
        return data, len(data), sample['label']

    def preprocessingcCheck(self, path):
        """Return True when the cache file for the current split exists.

        Checking the split-specific file (not just the ``sequenced``
        directory) keeps a cached 'train' split from masking a missing
        'test' cache, and vice versa.
        """
        cache_path = os.path.join(path, "sequenced", "%s.pkl" % self.mode)
        return os.path.isfile(cache_path)

    def preprocessing(self, data, root_path):
        """Convert offset records to cumulative positions and cache them.

        Args:
            data: Mapping of sample key -> ``{'input': ..., 'label': ...}``
                where ``'input'`` yields ``(dx, dy, eos, eod)`` tensors.
            root_path: Directory under which ``sequenced/<mode>.pkl`` is
                written.

        Returns:
            Dict with the same keys as ``data``; each value holds the stacked
            cumulative ``(x, y)`` positions under ``'input'`` and the
            original ``'label'``.
        """
        new_data = dict()
        total = len(data)
        for iter_, (num, data_) in enumerate(data.items()):
            temp = []
            pt = torch.zeros([1, 2])
            for dx, dy, eos, eod in data_['input']:
                # Running sum of offsets gives the absolute pen position.
                pt = pt + torch.cat([dx.view(1, 1), dy.view(1, 1)], dim=1)
                temp.append(pt)
                if eod:
                    # Entries after the end-of-data flag are ignored.
                    break
            temp = torch.cat(temp, 0)
            new_data[num] = {'input': temp, 'label': data_['label']}
            # 1-based counter so the final line reads N/N.
            print('\rData Processing...(%5d/%5d)' % (iter_ + 1, total),
                  end='' if iter_ + 1 < total else '\n')

        # The cache directory does not exist on the first run; create it
        # before writing, otherwise open() raises FileNotFoundError.
        cache_dir = os.path.join(root_path, "sequenced")
        os.makedirs(cache_dir, exist_ok=True)
        with open(os.path.join(cache_dir, '%s.pkl' % self.mode), 'wb') as f:
            pickle.dump(new_data, f)
        return new_data

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.data)
#def collate_fn(batch):
# return tuple(zip(*batch))
def collate_fn(batch):
    """Pad and pack a batch of variable-length sequences.

    Each batch element is indexed as ``(sequence, length, label)``.
    Sequences are converted with ``torch.Tensor`` by hand, zero-padded to a
    common length (batch first) and then packed; ``enforce_sorted=False``
    means the batch does not have to be ordered by length.

    Returns:
        Tuple ``(packed_sequences, lengths_tensor, labels_list)``.
    """
    rnn = torch.nn.utils.rnn

    lengths = torch.tensor([sample[1] for sample in batch])
    targets = [sample[2] for sample in batch]

    padded = rnn.pad_sequence(
        [torch.Tensor(sample[0]) for sample in batch],
        batch_first=True,
    )
    packed = rnn.pack_padded_sequence(
        padded,
        lengths,
        batch_first=True,
        enforce_sorted=False,
    )
    return packed, lengths, targets