data.py
import spacy
from torchtext.functional import to_tensor
from torchtext.datasets import multi30k, Multi30k
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader

# Patch the Multi30k dataset URLs; PyTorch's original Google Drive links are no longer active
multi30k.URL['train'] = 'https://raw.githubusercontent.com/tanjeffreyz/pytorch-multi30k/main/training.tar.gz'
multi30k.URL['valid'] = 'https://raw.githubusercontent.com/tanjeffreyz/pytorch-multi30k/main/validation.tar.gz'
multi30k.URL['test'] = 'https://raw.githubusercontent.com/tanjeffreyz/pytorch-multi30k/main/mmt16_task1_test.tar.gz'
multi30k.MD5['test'] = 'd914ec964e2c5f0534e5cdd3926cd2fe628d591dad9423c3ae953d93efdb27a6'

class Dataset:
    pipelines = {
        'en': spacy.load('en_core_web_sm'),
        'de': spacy.load('de_core_news_sm')
    }

    def __init__(self,
                 language_pair,
                 sos_token='<sos>',
                 eos_token='<eos>',
                 unk_token='<unk>',
                 pad_token='<pad>',
                 columns=('source', 'target'),
                 batch_size=64):
        self.src_lang, self.trg_lang = language_pair
        assert self.src_lang in Dataset.pipelines, 'Unrecognized source language'
        assert self.trg_lang in Dataset.pipelines, 'Unrecognized target language'

        self.sos_token = sos_token
        self.eos_token = eos_token
        self.unk_token = unk_token
        self.pad_token = pad_token

        # Load datasets
        train_data, valid_data, test_data = Multi30k(root='data', language_pair=language_pair)

        # Build vocabs from dataset
        self.src_vocab = build_vocab_from_iterator(
            map(lambda x: x[0], train_data.map(self.tokenize)),
            specials=[self.sos_token, self.eos_token, self.unk_token, self.pad_token]
        )
        self.src_vocab.set_default_index(self.src_vocab[self.unk_token])

        self.trg_vocab = build_vocab_from_iterator(
            map(lambda x: x[1], train_data.map(self.tokenize)),
            specials=[self.sos_token, self.eos_token, self.unk_token, self.pad_token]
        )
        self.trg_vocab.set_default_index(self.trg_vocab[self.unk_token])
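        # Any token not seen during vocab construction now maps to <unk> in both vocabs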

        # Tokenize, encode, and batch the data
        train_processed = self.process(train_data, batch_size, columns)
        valid_processed = self.process(valid_data, batch_size, columns)
        test_processed = self.process(test_data, batch_size, columns)
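        # batch_size=None disables the DataLoader's automatic batching;
        # batches are already formed by the datapipe built in `process`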
        self.train_loader = DataLoader(train_processed, batch_size=None, shuffle=True)
        self.valid_loader = DataLoader(valid_processed, batch_size=None, shuffle=False)
        self.test_loader = DataLoader(test_processed, batch_size=None, shuffle=False)

    def process(self, data, batch_size, columns):
        """Applies the preprocessing pipeline to DATA."""
        return (
            data
            .map(self.tokenize)
            .map(self.encode)
            .batch(batch_size)
            .rows2columnar(columns)
            .map(self.pad)
        )

    def tokenize(self, pair):
        """Splits source and target sentences into tokens based on the configured language pair."""
        src, trg = pair
        src_tokens = [self.sos_token] + [x.text.lower() for x in Dataset.pipelines[self.src_lang].tokenizer(src)] + [self.eos_token]
        trg_tokens = [self.sos_token] + [y.text.lower() for y in Dataset.pipelines[self.trg_lang].tokenizer(trg)] + [self.eos_token]
        return src_tokens, trg_tokens

    def encode(self, pair):
        """Replaces tokens with their corresponding vocabulary indices."""
        src, trg = pair
        return self.src_vocab.lookup_indices(src), self.trg_vocab.lookup_indices(trg)

    def pad(self, batch):
        """Pads every sequence in the batch to the length of the longest sequence, using the padding token."""
        # Convert each column's list of index sequences into a single padded tensor
        batch['source'] = to_tensor(batch['source'], padding_value=self.src_vocab[self.pad_token])
        batch['target'] = to_tensor(batch['target'], padding_value=self.trg_vocab[self.pad_token])
        return batch
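

# A minimal usage sketch (not part of the original module), assuming the spaCy
# models loaded above are installed and the patched Multi30k mirrors are
# reachable; the language pair and batch size here are arbitrary examples.
if __name__ == '__main__':
    dataset = Dataset(language_pair=('en', 'de'), batch_size=32)
    print('Source vocab size:', len(dataset.src_vocab))
    print('Target vocab size:', len(dataset.trg_vocab))

    # Each batch is a dict of padded index tensors, shaped (batch, max_seq_len)
    batch = next(iter(dataset.train_loader))
    print('Source batch shape:', tuple(batch['source'].shape))
    print('Target batch shape:', tuple(batch['target'].shape))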