Skip to content

Commit

Permalink
fix gnn_explainer work with features
Browse files Browse the repository at this point in the history
  • Loading branch information
mishabounty committed Oct 29, 2024
1 parent 4884213 commit a06a195
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
4 changes: 2 additions & 2 deletions metainfo/explainers_init_parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
"GNNExplainer(torch-geom)": {
"epochs": ["Epochs","int",100,{"min": 1},"The number of epochs to train"],
"lr": ["Learn rate","float",0.01,{"min": 0, "step": 0.0001},"The learning rate to apply"],
"node_mask_type": ["Node mask","string","object",["None","object","common_attributes","attributes"],"The type of mask to apply on nodes"],
"edge_mask_type": ["Edge mask","string","object",["None","object","common_attributes","attributes"],"The type of mask to apply on edges"],
"node_mask_type": ["Node mask","string","object",["None","object","common_attributes"],"The type of mask to apply on nodes"],
"edge_mask_type": ["Edge mask","string","object",["None","object"],"The type of mask to apply on edges"],
"mode": ["Mode","string","multiclass_classification",["binary_classification","multiclass_classification","regression"],"The mode of the model"],
"return_type": ["Model return","string","log_probs",["raw","prob","log_probs"],"Denotes the type of output from model. Valid inputs are 'log_probs' (the model returns the logarithm of probabilities), 'prob' (the model returns probabilities), 'raw' (the model returns raw scores)"],
"edge_size": ["edge_size","float",0.005,{"min": 0, "step": 0.001},""],
Expand Down
32 changes: 23 additions & 9 deletions src/explainers/GNNExplainer/torch_geom_our/out.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,14 @@ def _finalize(self):

self.explanation = AttributionExplanation(
local=mode,
edges="continuous" if edge_mask is not None else False,
features="continuous" if node_mask is not None else False)
edges="continuous" if self.edge_mask_type=="object" else False,
nodes="continuous" if self.node_mask_type=="object" else False,
features="continuous" if self.node_mask_type=="common_attributes" else False)

important_edges = {}
important_nodes = {}
important_features = {}

# TODO What if edge_mask_type or node_mask_type is None, common_attributes, attributes?
if self.edge_mask_type is not None and self.node_mask_type is not None:

# Multi graphs check is not needed: the explanation format for
Expand All @@ -136,21 +137,34 @@ def _finalize(self):
important_edges[f"{edge[0]},{edge[1]}"] = format(imp, '.4f')

# Nodes
num_nodes = node_mask.size(0)
assert num_nodes == self.x.size(0)
if self.node_mask_type=="object":
num_nodes = node_mask.size(0)
assert num_nodes == self.x.size(0)

for i in range(num_nodes):
imp = float(node_mask[i][0])
if not imp < eps:
important_nodes[i] = format(imp, '.4f')
for i in range(num_nodes):
imp = float(node_mask[i][0])
if not imp < eps:
important_nodes[i] = format(imp, '.4f')

# Features
elif self.node_mask_type=="common_attributes":
num_features = node_mask.size(1)
assert num_features == self.x.size(1)

for i in range(num_features):
imp = float(node_mask[0][i])
if not imp < eps:
important_features[i] = format(imp, '.4f')

if self.gen_dataset.is_multi():
important_edges = {self.graph_idx: important_edges}
important_nodes = {self.graph_idx: important_nodes}
important_features = {self.graph_idx: important_features}

# TODO Write functions with output threshold
self.explanation.add_edges(important_edges)
self.explanation.add_nodes(important_nodes)
self.explanation.add_features(important_features)

# print(important_edges)
# print(important_nodes)
Expand Down

0 comments on commit a06a195

Please sign in to comment.