-
Notifications
You must be signed in to change notification settings - Fork 76
/
BayesNet.py
228 lines (194 loc) · 8.73 KB
/
BayesNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
from typing import List, Tuple, Dict
import networkx as nx
import matplotlib.pyplot as plt
from pgmpy.readwrite import XMLBIFReader
import math
import itertools
import pandas as pd
from copy import deepcopy
class BayesNet:
    """
    Bayesian network backed by a networkx.DiGraph.

    Every variable is a node of ``self.structure``; its conditional probability table is stored as a
    pandas DataFrame in the node attribute ``'cpt'``. The CPT handling in this class assumes the
    probability column is named ``'p'`` and is the last column of the table.
    """

    def __init__(self) -> None:
        # initialize graph structure
        self.structure = nx.DiGraph()

    # LOADING FUNCTIONS ------------------------------------------------------------------------------------------------
    def create_bn(self, variables: List[str], edges: List[Tuple[str, str]], cpts: Dict[str, pd.DataFrame]) -> None:
        """
        Creates the BN according to the python objects passed in.

        :param variables: List of names of the variables.
        :param edges: List of the directed edges.
        :param cpts: Dictionary of conditional probability tables, indexed by variable name.
        :raises KeyError: If a variable has no entry in 'cpts'.
        :raises Exception: If a variable or edge is passed twice, or if the resulting graph is not acyclic.
        """
        # add nodes
        for variable in variables:
            self.add_var(variable, cpt=cpts[variable])

        # add edges
        for edge in edges:
            self.add_edge(edge)

        # check for cycles
        if not nx.is_directed_acyclic_graph(self.structure):
            raise Exception('The provided graph is not acyclic.')

    def load_from_bifxml(self, file_path: str) -> None:
        """
        Load a BayesNet from a file in BIFXML file format. See description of BIFXML here:
        http://www.cs.cmu.edu/afs/cs/user/fgcozman/www/Research/InterchangeFormat/

        Only boolean variables are supported: every table is expanded over [False, True] worlds.

        :param file_path: Path to the BIFXML file.
        :raises ValueError: If the number of entries of a table is not a power of two.
        """
        # Read and parse the bifxml file
        with open(file_path) as f:
            bn_file = f.read()
        bif_reader = XMLBIFReader(string=bn_file)

        # the parents do not change while iterating, so query them only once
        parents = bif_reader.get_parents()

        # load cpts
        cpts = {}
        # iterating through vars
        for key, values in bif_reader.get_values().items():
            values = values.transpose().flatten()
            n_vars = int(math.log2(len(values)))
            if 2 ** n_vars != len(values):
                # the worlds below are enumerated over [False, True] only
                raise ValueError(
                    'The table of variable {} has {} entries, which is not a power of two.'.format(key, len(values))
                )

            # pair every possible world with its probability
            worlds = itertools.product([False, True], repeat=n_vars)
            cpt = [list(world) + [p] for world, p in zip(worlds, values)]

            # determine column names: reversed parents, the variable itself and the probability column
            # (build a new list so the list handed out by the reader is not modified)
            columns = list(reversed(parents[key])) + [key, 'p']
            cpts[key] = pd.DataFrame(cpt, columns=columns)

        # load vars
        variables = bif_reader.get_variables()

        # load edges
        edges = bif_reader.get_edges()

        self.create_bn(variables, edges, cpts)

    # METHODS THAT MIGHT BE USEFUL -------------------------------------------------------------------------------------
    def get_children(self, variable: str) -> List[str]:
        """
        Returns the children of the variable in the graph.

        :param variable: Variable to get the children from
        :return: List of children
        """
        return list(self.structure.successors(variable))

    def get_cpt(self, variable: str) -> pd.DataFrame:
        """
        Returns the conditional probability table of a variable in the BN.

        :param variable: Variable of which the CPT should be returned.
        :return: Conditional probability table of 'variable' as a pandas DataFrame.
        :raises Exception: If the lookup of the variable or of its 'cpt' attribute fails.
        """
        try:
            return self.structure.nodes[variable]['cpt']
        except KeyError as err:
            # keep the original KeyError attached as the cause for easier debugging
            raise Exception('Variable not in the BN') from err

    def get_all_variables(self) -> List[str]:
        """
        Returns a list of all variables in the structure.

        :return: list of all variables.
        """
        return list(self.structure.nodes)

    def get_all_cpts(self) -> Dict[str, pd.DataFrame]:
        """
        Returns a dictionary of all CPTs in the network indexed by the variable they belong to.

        :return: Dictionary of all CPTs
        """
        return {var: self.get_cpt(var) for var in self.get_all_variables()}

    def get_interaction_graph(self):
        """
        Returns a networkx.Graph as interaction graph of the current BN.

        :return: The interaction graph based on the factors of the current BN.
        """
        # Create the graph and add all variables
        int_graph = nx.Graph()
        int_graph.add_nodes_from(self.get_all_variables())

        # connect all variables with an edge which are mentioned in a CPT together
        for var in self.get_all_variables():
            # the last column holds the probabilities, all others are variables
            involved_vars = list(self.get_cpt(var).columns)[:-1]
            # adding an edge that already exists leaves an undirected graph unchanged
            int_graph.add_edges_from(itertools.combinations(involved_vars, 2))
        return int_graph

    @staticmethod
    def get_compatible_instantiations_table(instantiation: pd.Series, cpt: pd.DataFrame):
        """
        Get all the entries of a CPT which are compatible with the instantiation.

        :param instantiation: a series of assignments as tuples. E.g.: pd.Series({"A": True, "B": False})
        :param cpt: cpt to be filtered
        :return: table with compatible instantiations and their probability value. If none of the
                 instantiated variables is a column of the cpt, every row is compatible.
        """
        # get rid of excess variables names
        var_names = [v for v in instantiation.index.values if v in cpt.columns]
        if not var_names:
            # nothing to compare against, so no row can be ruled out
            return cpt.copy()

        # a row is compatible if it agrees with the instantiation on every shared variable
        compat_rows = (cpt[var_names] == instantiation[var_names].values).all(axis=1)
        # use a plain boolean array so the selection is positional
        return cpt.loc[compat_rows.to_numpy()]

    def update_cpt(self, variable: str, cpt: pd.DataFrame) -> None:
        """
        Replace the conditional probability table of a variable.

        :param variable: Variable to be modified
        :param cpt: new CPT
        """
        self.structure.nodes[variable]["cpt"] = cpt

    @staticmethod
    def reduce_factor(instantiation: pd.Series, cpt: pd.DataFrame) -> pd.DataFrame:
        """
        Creates and returns a new factor in which all probabilities which are incompatible with the
        instantiation passed to the method are set to 0. The cpt passed in is never modified.

        :param instantiation: a series of assignments as tuples. E.g.: pd.Series({"A": True, "B": False})
        :param cpt: cpt to be reduced
        :return: copy of the cpt with the original probability values and zero probability for
                 incompatible instantiations
        """
        # always hand out a copy, so that changing the result can never change the cpt passed in
        new_cpt = cpt.copy()

        # get rid of excess variables names
        var_names = [v for v in instantiation.index.values if v in cpt.columns]
        if not var_names:
            # the evidence does not appear in this factor, nothing to reduce
            return new_cpt

        # a row is incompatible if it disagrees with the instantiation on at least one shared variable
        incompat_rows = (cpt[var_names] != instantiation[var_names].values).any(axis=1)
        new_cpt.loc[incompat_rows.to_numpy(), 'p'] = 0.0
        return new_cpt

    def draw_structure(self) -> None:
        """
        Visualize structure of the BN.
        """
        nx.draw(self.structure, with_labels=True, node_size=3000)
        plt.show()

    # BASIC HOUSEKEEPING METHODS ---------------------------------------------------------------------------------------
    def add_var(self, variable: str, cpt: pd.DataFrame) -> None:
        """
        Add a variable to the BN.

        :param variable: variable to be added.
        :param cpt: conditional probability table of the variable.
        :raises Exception: If the variable is already part of the structure.
        """
        if variable in self.structure.nodes:
            raise Exception('Variable already exists.')
        self.structure.add_node(variable, cpt=cpt)

    def add_edge(self, edge: Tuple[str, str]) -> None:
        """
        Add a directed edge to the BN.

        :param edge: Tuple of the directed edge to be added (e.g. ('A', 'B')).
        :raises Exception: If the edge already exists.
        :raises ValueError: If added edge introduces a cycle in the structure.
        """
        if edge in self.structure.edges:
            raise Exception('Edge already exists.')

        self.structure.add_edge(edge[0], edge[1])

        # check for cycles
        if not nx.is_directed_acyclic_graph(self.structure):
            # undo the insertion so the structure stays acyclic
            self.structure.remove_edge(edge[0], edge[1])
            raise ValueError('Edge would make graph cyclic.')

    def del_var(self, variable: str) -> None:
        """
        Delete a variable from the BN.

        :param variable: Variable to be deleted.
        """
        self.structure.remove_node(variable)

    def del_edge(self, edge: Tuple[str, str]) -> None:
        """
        Delete an edge from the structure of the BN.

        :param edge: Edge to be deleted (e.g. ('A', 'B')).
        """
        self.structure.remove_edge(edge[0], edge[1])