Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
isnowfy committed May 12, 2013
1 parent 63fdbbf commit c9791d4
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 48 deletions.
13 changes: 11 additions & 2 deletions frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ def samples(self):
return self.d.keys()


class NormalProb(BaseProb):

def add(self, key, value):
if not self.exists(key):
self.d[key] = 0
self.d[key] += value
self.total += value


class AddOneProb(BaseProb):

def __init__(self):
Expand All @@ -37,9 +46,9 @@ def __init__(self):
def add(self, key, value):
self.total += value
if not self.exists(key):
self.d[key] = 0
self.d[key] = 1
self.total += 1
self.d[key] += value+1
self.d[key] += value


class GoodTuringProb(BaseProb):
Expand Down
11 changes: 8 additions & 3 deletions test_tnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@

import tnt

def main():
def getdata(filename='brown.txt'):
data = []
f = codecs.open('brown.txt', 'r', 'utf-8')
f = codecs.open(filename, 'r', 'utf-8')
for line in f:
if not line.strip():
line = line.strip()
if not line:
continue
words = map(lambda x: x.split('/'), line.split(' '))
data.append(words)
return data

def main():
data = getdata()
model = tnt.TnT()
model.train(data)
ret = [model.tag(map(lambda x:x[0], sent)) for sent in data]
Expand Down
89 changes: 46 additions & 43 deletions tnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Implementation of 'TnT - A Statisical Part of Speech Tagger'
'''

import heapq
from math import log

import frequency
Expand All @@ -15,81 +16,83 @@ def __init__(self, N=1000):
self.l1 = 0.0
self.l2 = 0.0
self.l3 = 0.0
self.status = {'BOS', 'EOS'}
self.status = set()
self.wd = frequency.AddOneProb()
self.uni = frequency.AddOneProb()
self.bi = frequency.AddOneProb()
self.tri = frequency.AddOneProb()
self.eos = frequency.AddOneProb()
self.eosd = frequency.AddOneProb()
self.uni = frequency.NormalProb()
self.bi = frequency.NormalProb()
self.tri = frequency.NormalProb()
self.word = {}
self.trans = {}

def _safe_div(self, v1, v2):
def tnt_div(self, v1, v2):
if v2 == 0:
return -1
return float(v1) / float(v2)
return 0
return float(v1)/v2

def train(self, data):
now = ['BOS', 'BOS']
for sentence in data:
self.bi.add(('BOS', 'BOS'), 1)
self.uni.add('BOS', 2)
for word, tag in sentence:
now.append(tag)
self.status.add(tag)
self.wd.add((tag, word), 1)
self.eos.add(tuple(now[1:]), 1)
self.eosd.add(tag, 1)
self.uni.add(tag, 1)
self.bi.add(tuple(now[1:]), 1)
self.tri.add(tuple(now), 1)
if word not in self.word:
self.word[word] = set()
self.word[word].add(tag)
now.pop(0)
self.eos.add(now[-1], 1)
self.eos.add((now[-1], 'EOS'), 1)
tl1 = 0.0
tl2 = 0.0
tl3 = 0.0
for now in self.tri.samples():
c1 = self._safe_div(self.tri.get(now)[1]-1,
self.bi.get(now[:2])[1]-1)
c2 = self._safe_div(self.bi.get(now[1:])[1]-1,
self.uni.get(now[1])[1]-1)
c3 = self._safe_div(self.uni.get(now[2])[1]-1,
self.uni.getsum()-1)
if c1 > c2 and c1 > c3:
tl1 += self.tri.get(now)[1]
elif c2 > c1 and c2 > c3:
tl2 += self.tri.get(now)[1]
elif c3 > c1 and c3 > c2:
c3 = self.tnt_div(self.tri.get(now)[1]-1, self.bi.get(now[:2])[1]-1)
c2 = self.tnt_div(self.bi.get(now[1:])[1]-1, self.uni.get(now[1])[1]-1)
c1 = self.tnt_div(self.uni.get(now[2])[1]-1, self.uni.getsum()-1)
if c3 >= c1 and c3 >= c2:
tl3 += self.tri.get(now)[1]
elif c1 == c2 and c1 > c3:
tl1 += self.tri.get(now)[1]/2.0
tl2 += self.tri.get(now)[1]/2.0
elif c2 == c3 and c2 > c1:
tl2 += self.tri.get(now)[1]/2.0
tl3 += self.tri.get(now)[1]/2.0
elif c1 == c3 and c1 > c2:
tl3 += self.tri.get(now)[1]/2.0
tl1 += self.tri.get(now)[1]/2.0
self.l1 = self._safe_div(tl1, tl1+tl2+tl3)
self.l2 = self._safe_div(tl2, tl1+tl2+tl3)
self.l3 = self._safe_div(tl3, tl1+tl2+tl3)
for s1 in self.status:
for s2 in self.status:
elif c2 >= c1 and c2 >= c3:
tl2 += self.tri.get(now)[1]
elif c1 >= c2 and c1 >= c3:
tl1 += self.tri.get(now)[1]
self.l1 = float(tl1)/(tl1+tl2+tl3)
self.l2 = float(tl2)/(tl1+tl2+tl3)
self.l3 = float(tl3)/(tl1+tl2+tl3)
for s1 in self.status|set(('BOS',)):
for s2 in self.status|set(('BOS',)):
for s3 in self.status:
uni = self.l3*self.uni.freq(s3)
bi = self.l2*self.bi.freq((s2, s3))
tri = self.l1*self.tri.freq((s1, s2, s3))
uni = self.l1*self.uni.freq(s3)
bi = self.tnt_div(self.l2*self.bi.get((s2, s3))[1],
self.uni.get(s2)[1])
tri = self.tnt_div(self.l3*self.tri.get((s1, s2, s3))[1],
self.bi.get((s1, s2))[1])
self.trans[(s1, s2, s3)] = log(uni+bi+tri)

def tag(self, data):
now = [(('BOS', 'BOS'), 0.0, [])]
for w in data:
stage = {}
for s in self.status:
wd = log(self.wd.get((s, w))[1])
samples = self.status
if w in self.word:
samples = self.word[w]
for s in samples:
wd = log(self.wd.get((s, w))[1])-log(self.uni.get(s)[1])
for pre in now:
p = pre[1]+wd+self.trans[(pre[0][0], pre[0][1], s)]
if (pre[0][1], s) not in stage or p > stage[(pre[0][1], s)][0]:
stage[(pre[0][1], s)] = (p, pre[2]+[s])
stage = map(lambda x: (x[0], x[1][0], x[1][1]), stage.items())
now = sorted(stage, key=lambda x:-x[1])[:self.N]
print len(now)
for cnt, item in enumerate(now):
now[cnt] = (item[0], item[1]+log(self.eos.freq(item[0][1])), item[2])
now = sorted(stage, key=lambda x:-x[1])[:self.N]
now = heapq.nlargest(self.N, stage, key=lambda x:x[1])
now = heapq.nlargest(1, stage,
key=lambda x: x[1]+\
log(self.eos.get((x[0][1], 'EOS'))[1])\
-log(self.eosd.get(x[0][1])[1]))
return zip(data, now[0][2])

0 comments on commit c9791d4

Please sign in to comment.