diff --git a/test_tnt.py b/test_tnt.py index 9387887..1bee931 100644 --- a/test_tnt.py +++ b/test_tnt.py @@ -11,7 +11,6 @@ def main(): if not line.strip(): continue words = map(lambda x: x.split('/'), line.split(' ')) - words = map(lambda x: (x[1], x[0]), words) data.append(words) model = tnt.TnT() model.train(data) @@ -21,7 +20,7 @@ def main(): for c1, sent in enumerate(data): for c2, wd in enumerate(sent): total += 1 - if wd[0] != ret[c1][c2][0]: + if wd[1] != ret[c1][c2][1]: error += 1 print 'total: %d, error: %d, precision: %f' % (total, error, float(error).total) diff --git a/tnt.py b/tnt.py index b9bdffb..bfc39d0 100755 --- a/tnt.py +++ b/tnt.py @@ -11,7 +11,7 @@ class TnT(object): def __init__(self, N=1000): - self.N = 1000 + self.N = N self.l1 = 0.0 self.l2 = 0.0 self.l3 = 0.0 @@ -70,6 +70,7 @@ def train(self, data): def tag(self, data): now = [(('BOS', 'BOS'), 0.0, [])] + print self.status for w in data: stage = {} for pre in now: @@ -85,4 +86,4 @@ def tag(self, data): stage = map(lambda x: (x[0], x[1][0], x[1][1]), stage.items()) now = sorted(stage, key=lambda x:-x[1])[:self.N] print len(now) - return zip(now[2], data) + return zip(data, now[0][2])