forked from Pawandeep-prog/finetuned-gpt2-convai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ChatData.py
31 lines (23 loc) · 931 Bytes
/
ChatData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from torch.utils.data import Dataset
import json
class ChatData(Dataset):
    """Torch dataset of prompt/reply samples built from a JSON dialog file.

    Every sample is a single string of the form
    ``"<startofstring> {turn} <bot>: {next turn} <endofstring>"`` made from two
    consecutive utterances of the *same* dialog, then tokenized to a fixed
    length. Items are ``(input_ids, attention_mask)`` tuples.
    """

    def __init__(self, path: str, tokenizer, max_samples=5000, max_length: int = 40):
        """Load dialogs from *path*, build turn pairs and tokenize them.

        Args:
            path: JSON file holding a list of objects, each with a ``'dialog'``
                list whose items carry a ``'text'`` field.
            tokenizer: Callable invoked as ``tokenizer(texts, max_length=...,
                truncation=True, padding="max_length", return_tensors="pt")``
                that returns a mapping with ``'input_ids'`` and
                ``'attention_mask'`` (presumably a Hugging Face tokenizer —
                confirm against callers).
            max_samples: Keep at most this many pairs, taken in file order.
                ``None`` keeps every pair.
            max_length: Token length each sample is truncated/padded to.

        Raises:
            ValueError: if the file contains no dialog with at least two turns.
        """
        # JSON is UTF-8 by spec; the context manager guarantees the handle is closed.
        with open(path, "r", encoding="utf-8") as fh:
            self.data = json.load(fh)

        self.X = self._build_pairs(self.data)[:max_samples]
        if not self.X:
            raise ValueError(f"no consecutive dialog turns found in {path!r}")
        print(self.X[0])  # sanity check of the sample formatting

        self.X_encoded = tokenizer(
            self.X,
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        self.input_ids = self.X_encoded['input_ids']
        self.attention_mask = self.X_encoded['attention_mask']

    @staticmethod
    def _build_pairs(dialogs):
        """Return formatted prompt/reply strings for consecutive turns of each dialog."""
        pairs = []
        for dialog in dialogs:
            texts = [turn['text'] for turn in dialog['dialog']]
            # zip(texts, texts[1:]) pairs every turn with its successor and
            # stops at the dialog's last turn, so no sample crosses into the
            # next conversation and no unformatted utterance is left behind.
            pairs.extend(
                f"<startofstring> {prompt} <bot>: {reply} <endofstring>"
                for prompt, reply in zip(texts, texts[1:])
            )
        return pairs

    def __len__(self):
        """Return the number of prompt/reply samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask)`` for sample *idx*."""
        return (self.input_ids[idx], self.attention_mask[idx])