-
Notifications
You must be signed in to change notification settings - Fork 280
/
HMM.py
84 lines (71 loc) · 2.7 KB
/
HMM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import numpy as np
from functools import partial
import sys
from pathlib import Path
from rich.console import Console
from rich.table import Table
sys.path.append(str(Path(os.path.abspath(__file__)).parent.parent))
from utils import *
sys.path.append(str(Path(os.path.abspath(__file__)).parent.parent / '10.HMM'))
from BaumWelch import baum_welch
from Viterbi import viterbi
class HMM:
    """Hidden Markov model trained with Baum-Welch and decoded with Viterbi.

    Training is delegated to ``baum_welch`` and decoding to ``viterbi``,
    both imported at the top of this file.
    """

    def __init__(self, state_size, observation_size, max_iteration=2000, verbose=False, epsilon=1e-8):
        """Store the hyper-parameters; no training happens here.

        Args:
            state_size: size of the hidden-state space, forwarded to
                ``baum_welch`` by ``fit``.
            observation_size: size of the observation alphabet, forwarded to
                ``baum_welch`` by ``fit`` (the demo in this file passes the
                vocabulary size).
            max_iteration: forwarded to ``baum_welch``; presumably an upper
                bound on training iterations -- confirm against BaumWelch.py.
            verbose: stored on the instance; not read anywhere in this class.
            epsilon: forwarded to ``baum_welch``; presumably a convergence
                tolerance -- confirm against BaumWelch.py.
        """
        self.max_iteration = max_iteration
        self.verbose = verbose
        self.state_size = state_size
        self.observation_size = observation_size
        self.epsilon = epsilon

    def fit(self, X):
        """
        When there is no label in the training data,
        HMM uses baum-welch for training.
        Otherwise just counting the probability will be fine (not implemented here)

        Args:
            X: observation sequences, passed unchanged to ``baum_welch``
                (the demo in this file passes a 2-D array of integer ids).

        The three values returned by ``baum_welch`` are stored as
        ``self.state2state``, ``self.state2observation`` and
        ``self.initial_state``.
        """
        self.state2state, self.state2observation, self.initial_state = \
            baum_welch(X, self.state_size, self.observation_size, self.epsilon, self.max_iteration)

    def predict(self, X):
        """HMM uses viterbi for predicting

        ``viterbi`` is called once per 1-D slice along the last axis of ``X``,
        with the fitted ``state2state``, ``state2observation`` and
        ``initial_state`` bound as its first three positional arguments.

        Args:
            X: array of observation sequences laid out along the last axis.

        Returns:
            The array assembled by ``np.apply_along_axis`` from the
            per-sequence ``viterbi`` results.

        Raises:
            AttributeError: if ``fit`` has not populated the model parameters.
        """
        # Fail early with an actionable message instead of the bare
        # "object has no attribute" error; the exception type is unchanged.
        missing = [name
                   for name in ('state2state', 'state2observation', 'initial_state')
                   if not hasattr(self, name)]
        if missing:
            raise AttributeError(
                "HMM.predict() requires a fitted model; call fit() first "
                "(missing: " + ", ".join(missing) + ")")

        decode = partial(viterbi, self.state2state, self.state2observation, self.initial_state)
        return np.apply_along_axis(decode, -1, X)
if __name__ == '__main__':
    def demonstrate(X, testX, desc):
        """Train an HMM on ``X`` and print the decoded result for ``testX``.

        Args:
            X: array of word strings used both to build the vocabulary and as
                training data.
            testX: array of word strings to decode. Every word must also occur
                in ``X``; an unseen word raises ``KeyError`` in the lookup.
            desc: title printed before the table.
        """
        console = Console(markup=False)

        # Sort the vocabulary so the word -> id mapping is reproducible.
        # Iterating a set of strings directly yields a different order from
        # one interpreter run to the next because of hash randomization.
        vocab = sorted(set(X.flatten()))
        vocab_size = len(vocab)
        word2num = {word: num for num, word in enumerate(vocab)}

        f_word2num = np.vectorize(lambda word: word2num[word])
        numX, num_testX = map(f_word2num, (X, testX))

        hmm = HMM(4, vocab_size)
        hmm.fit(numX)
        pred = hmm.predict(num_testX)

        # show in table: each input sentence is followed by its prediction row
        print(desc)
        table = Table()
        for x, p in zip(testX, pred):
            table.add_row(*map(str, x))
            table.add_row(*map(str, p))
        console.print(table)

    # ---------------------- Example 1 --------------------------------------------
    # Decode the training sentences themselves.
    X = np.array([s.split() for s in
                  ['i am good .',
                   'i am bad .',
                   'you are good .',
                   'you are bad .',
                   'it is good .',
                   'it is bad .',
                   ]
                  ])
    testX = X
    demonstrate(X, testX, "Example 1")

    # ---------------------- Example 2 --------------------------------------------
    # Decode the training sentences plus three new word combinations built
    # from the same vocabulary.
    testX = np.array([s.split() for s in
                      ['you is good .',
                       'i are bad .',
                       'it are good .']
                      ])
    testX = np.concatenate([X, testX])
    demonstrate(X, testX, "Example 2")