-
Notifications
You must be signed in to change notification settings - Fork 4
/
ngram_models.py
115 lines (103 loc) · 3.27 KB
/
ngram_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import pickle
import sys
import torch
class NGram():
    """Count-based n-gram model over token IDs.

    ``corpus`` maps a prior (the preceding n-1 token IDs) to a dict of
    ``{next_token_id: count}``. Bigram corpora are keyed by the bare previous
    token ID rather than a 1-tuple (``prob`` has always unwrapped bigram
    priors); higher orders are looked up with the prior as given (presumably
    a tuple -- confirm against whatever produced the pickles).
    """

    # Number of prior tokens each supported model type expects.
    _PRIOR_LENGTHS = {
        "bigram": 1,
        "trigram": 2,
        "fourgram": 3,
        "fivegram": 4,
        "sixgram": 5,
        "sevengram": 6,
    }

    def __init__(self, corpus, corpus_counts, type):
        """
        Args:
            corpus (dict): prior -> {next_token_id: count}.
            corpus_counts: stored as ``self.counts``; no method of this class reads it.
            type (str): model order name, e.g. "bigram" ... "sevengram". Unknown
                names are accepted but skip prior-length validation.
        """
        self.corpus = corpus
        self.counts = corpus_counts
        self.type = type

    def _corpus_key(self, key):
        """Map a caller-supplied prior onto the corpus key format (lenient)."""
        if self.type == "bigram":
            # Bigram corpora are keyed by the bare token ID; accept a 1-element
            # prior (as prob() does) as well as an already-bare ID.
            try:
                if len(key) == 1:
                    return key[0]
            except TypeError:
                pass  # no len(): already a bare token ID
        return key

    def prob(self, key, next):
        """Return P(next | key) estimated from raw counts.

        Args:
            key (tuple): token IDs forming the prior; must contain exactly n-1
                IDs for this model's order.
            next (int): token ID whose conditional probability is wanted.
        Returns:
            float: count(key, next) / total count of continuations of ``key``,
            or -1 if ``key`` never occurs in the corpus.
        Raises:
            ValueError: if ``len(key)`` does not match the model order.
        """
        expected = self._PRIOR_LENGTHS.get(self.type)
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if expected is not None and len(key) != expected:
            raise ValueError(
                f"{self.type} model expects a prior of {expected} token(s), got {len(key)}"
            )
        if self.type == "bigram":
            key = key[0]  # bigram corpora are keyed by the bare token ID
        if key not in self.corpus:
            return -1
        dist = self.corpus[key]
        return dist.get(next, 0) / sum(dist.values())

    def ntd(self, key, vocab_size=32000):
        """Return the full next-token distribution following ``key``.

        Args:
            key (tuple): token IDs forming the prior. For bigram models either a
                1-element tuple or the bare token ID is accepted.
            vocab_size (int): length of the returned tensor; every next-token ID
                in the corpus must be a valid index into it.
        Returns:
            torch.Tensor | None: (vocab_size,) tensor of next-token probabilities
            (zeros for tokens never observed after ``key``), or None if ``key``
            never occurs in the corpus.
        """
        key = self._corpus_key(key)
        if key not in self.corpus:
            return None
        dist = self.corpus[key]
        total = sum(dist.values())
        prob_tensor = torch.zeros(vocab_size)
        for next_token, count in dist.items():
            prob_tensor[next_token] = count / total
        return prob_tensor
def make_models(ckpt_path, bigram, trigram, fourgram, fivegram, sixgram, sevengram):
    """
    Load and return the requested n-gram models, in bigram-to-sevengram order.

    Only models whose flag is truthy are loaded. Each corpus is read from a
    pickle in ``ckpt_path`` with the filename listed in ``specs`` below.
    Progress and the (shallow) size of each loaded corpus are printed to stdout.

    Args:
        ckpt_path (str): Directory containing the n-gram pickles.
        bigram-sevengram: Flags selecting which models to load.
    Returns:
        list[NGram]: The loaded models, in bigram-to-sevengram order.
    """
    # (model type, pickle filename, caller's flag), in bigram -> sevengram order.
    specs = (
        ("bigram", "b_d_final.pkl", bigram),
        ("trigram", "t_d_final.pkl", trigram),
        ("fourgram", "fo_d_final.pkl", fourgram),
        ("fivegram", "fi_d_final.pkl", fivegram),
        ("sixgram", "si_d_final.pkl", sixgram),
        ("sevengram", "se_d_final.pkl", sevengram),
    )
    models = []
    for ngram_type, filename, wanted in specs:
        if not wanted:
            continue
        print(f"Making {ngram_type}...")
        # SECURITY: pickle.load can execute arbitrary code -- only point
        # ckpt_path at trusted checkpoint files.
        with open(f"{ckpt_path}/{filename}", "rb") as f:
            corpus = pickle.load(f)
        models.append(NGram(corpus, None, ngram_type))
        # NOTE: sys.getsizeof is shallow -- it reports the outer dict only,
        # not the nested count dicts, so it understates real memory use.
        print(sys.getsizeof(corpus))
    return models