-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_helpers.py
26 lines (22 loc) · 1.05 KB
/
data_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import pickle
import numpy as np
def removeUnk(x, n_words, unk_id=1):
    """Replace out-of-vocabulary word ids with a fixed replacement id.

    Args:
        x: Iterable of sentences, each an iterable of integer word ids.
        n_words: Vocabulary size; any id ``>= n_words`` is treated as out of
            vocabulary.
        unk_id: Id substituted for out-of-vocabulary words. Defaults to 1,
            which matches the previously hard-coded value (presumably the
            UNK token id in this project's vocabulary — confirm).

    Returns:
        A new list of lists with the same shape as ``x``; ``x`` itself is
        not modified.
    """
    return [[unk_id if w >= n_words else w for w in sen] for sen in x]
def loadData(path, n, is_train=True):
    """Load a pickled dataset and map out-of-vocabulary word ids via removeUnk.

    The pickle at *path* must hold a 3-tuple ``(dataText, dataLabel, dataId)``.
    Each element is indexed by split: 0 = train, 1 = dev, 2 = test.

    Args:
        path: Filesystem path of the pickle file.
        n: Vocabulary size; word ids ``>= n`` in the returned text are
            replaced by ``removeUnk``.
        is_train: If True, return the train and dev splits; otherwise return
            the test split.

    Returns:
        If ``is_train``:
            ``([trnText, trnLabel, trnID], [devText, devLabel, devID])``.
        Otherwise:
            ``[tstText, tstLabel, tstID]`` (also prints the split sizes).
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only load pickles from a trusted source.
    with open(path, 'rb') as fh:
        dataText, dataLabel, dataId = pickle.load(fh)

    if is_train:
        trnText, trnLabel, trnID = dataText[0], dataLabel[0], dataId[0]
        # Split 1 is the dev set; its ids are named devID to match.
        devText, devLabel, devID = dataText[1], dataLabel[1], dataId[1]
        trnText = removeUnk(trnText, n)
        devText = removeUnk(devText, n)
        return [trnText, trnLabel, trnID], [devText, devLabel, devID]

    tstText, tstLabel, tstID = dataText[2], dataLabel[2], dataId[2]
    tstText = removeUnk(tstText, n)
    print('load --> id -- {}, label -- {}, text -- {}'.format(len(tstID), len(tstLabel), len(tstText)))
    return [tstText, tstLabel, tstID]
def prepareData(text, label):
    """Convert tokenized sentences and their labels into NumPy arrays.

    Args:
        text: Sequence of sentences, each a sequence of word ids.
        label: Sequence of labels, one per sentence, castable to int32.

    Returns:
        ``[text_array, labels, lengths]``:
        - ``text_array``: a 2-D array when every sentence has the same length.
          Otherwise it is a 1-D object array whose elements are the original
          sentences.
        - ``labels``: an int32 array.
        - ``lengths``: an array of ``len()`` for each sentence.
    """
    length = [len(s) for s in text]
    labels = np.array(label).astype('int32')
    try:
        text_arr = np.array(text)
    except ValueError:
        # NumPy >= 1.24 refuses to build an array from ragged nested
        # sequences. Fall back to the 1-D object array that older NumPy
        # produced. Assign element by element so NumPy does not try to
        # broadcast the inner lists.
        text_arr = np.empty(len(text), dtype=object)
        for i, sen in enumerate(text):
            text_arr[i] = sen
    return [text_arr, labels, np.array(length)]