-
Notifications
You must be signed in to change notification settings - Fork 280
/
ID3.py
147 lines (133 loc) · 5.42 KB
/
ID3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from pprint import pprint
from rich.console import Console
from rich.table import Table
from collections import Counter
import sys
import os
from pathlib import Path
sys.path.append(str(Path(os.path.abspath(__file__)).parent.parent))
from utils import argmax, information_gain
class ID3:
class Node:
def __init__(self, col, Y):
self.col = col
self.children = {}
self.cnt = Counter(Y)
self.label = self.cnt.most_common(1)[0][0]
def __init__(self, information_gain_threshold=0., verbose=False):
self.information_gain_threshold = information_gain_threshold
self.verbose = verbose
def build(self, X, Y, selected):
cur = self.Node(None, Y)
if self.verbose:
print("Cur selected columns:", selected)
print("Cur data:")
pprint(X)
print(Y)
split = False
# check if there is no attribute to choose
# or there is no need for spilt
if len(selected) != self.column_cnt and len(set(Y)) > 1:
left_columns = list(set(range(self.column_cnt)) - selected)
col_ind, best_information_gain = argmax(left_columns,
key=lambda col: information_gain(X, Y, col))
col = left_columns[col_ind]
# if this split is better than not splitting
if best_information_gain > self.information_gain_threshold:
if self.verbose:
print(f"Split by {col}th column")
split = True
cur.col = col
for val in set(x[col] for x in X):
ind = [x[col] == val for x in X]
child_X = [x for i, x in zip(ind, X) if i]
child_Y = [y for i, y in zip(ind, Y) if i]
cur.children[val] = self.build(child_X, child_Y, selected | {col})
if not split and self.verbose:
print("No split")
return cur
def query(self, root, x):
if root.col is None or x[root.col] not in root.children:
return root.label
return self.query(root.children[x[root.col]], x)
def fit(self, X, Y):
self.column_cnt = len(X[0])
self.root = self.build(X, Y, set())
def _predict(self, x):
return self.query(self.root, x)
def predict(self, X):
return [self._predict(x) for x in X]
if __name__ == "__main__":
console = Console(markup=False)
id3 = ID3(verbose=False)
# -------------------------- Example 1 ----------------------------------------
# unpruned decision tree predict correctly for all training data
print("Example 1:")
X = [
['青年', '否', '否', '一般'],
['青年', '否', '否', '好'],
['青年', '是', '否', '好'],
['青年', '是', '是', '一般'],
['青年', '否', '否', '一般'],
['老年', '否', '否', '一般'],
['老年', '否', '否', '好'],
['老年', '是', '是', '好'],
['老年', '否', '是', '非常好'],
['老年', '否', '是', '非常好'],
['老年', '否', '是', '非常好'],
['老年', '否', '是', '好'],
['老年', '是', '否', '好'],
['老年', '是', '否', '非常好'],
['老年', '否', '否', '一般'],
]
Y = ['否', '否', '是', '是', '否', '否', '否', '是', '是', '是', '是', '是', '是', '是', '否']
id3.fit(X, Y)
# show in table
pred = id3.predict(X)
table = Table('x', 'y', 'pred')
for x, y, y_hat in zip(X, Y, pred):
table.add_row(*map(str, [x, y, y_hat]))
console.print(table)
# -------------------------- Example 2 ----------------------------------------
# but unpruned decision tree doesn't generalize well for test data
print("Example 2:")
X = [
['青年', '否', '否', '一般'],
['青年', '否', '否', '好'],
['青年', '是', '是', '一般'],
['青年', '否', '否', '一般'],
['老年', '否', '否', '一般'],
['老年', '否', '否', '好'],
['老年', '是', '是', '好'],
['老年', '否', '是', '非常好'],
['老年', '否', '是', '非常好'],
['老年', '否', '是', '非常好'],
['老年', '否', '是', '好'],
['老年', '否', '否', '一般'],
]
Y = ['否', '否', '是', '否', '否', '否', '是', '是', '是', '是', '是', '否']
id3.fit(X, Y)
testX = [
['青年', '否', '否', '一般'],
['青年', '否', '否', '好'],
['青年', '是', '否', '好'],
['青年', '是', '是', '一般'],
['青年', '否', '否', '一般'],
['老年', '否', '否', '一般'],
['老年', '否', '否', '好'],
['老年', '是', '是', '好'],
['老年', '否', '是', '非常好'],
['老年', '否', '是', '非常好'],
['老年', '否', '是', '非常好'],
['老年', '否', '是', '好'],
['老年', '是', '否', '好'],
['老年', '是', '否', '非常好'],
['老年', '否', '否', '一般'],
]
testY = ['否', '否', '是', '是', '否', '否', '否', '是', '是', '是', '是', '是', '是', '是', '否']
# show in table
pred = id3.predict(testX)
table = Table('x', 'y', 'pred')
for x, y, y_hat in zip(testX, testY, pred):
table.add_row(*map(str, [x, y, y_hat]))
console.print(table)