diff --git a/metainfo/explainers_init_parameters.json b/metainfo/explainers_init_parameters.json index 66596f8..a4f93be 100644 --- a/metainfo/explainers_init_parameters.json +++ b/metainfo/explainers_init_parameters.json @@ -20,8 +20,8 @@ "GNNExplainer(torch-geom)": { "epochs": ["Epochs","int",100,{"min": 1},"The number of epochs to train"], "lr": ["Learn rate","float",0.01,{"min": 0, "step": 0.0001},"The learning rate to apply"], - "node_mask_type": ["Node mask","string","object",["None","object","common_attributes","attributes"],"The type of mask to apply on nodes"], - "edge_mask_type": ["Edge mask","string","object",["None","object","common_attributes","attributes"],"The type of mask to apply on edges"], + "node_mask_type": ["Node mask","string","object",["None","object","common_attributes"],"The type of mask to apply on nodes"], + "edge_mask_type": ["Edge mask","string","object",["None","object"],"The type of mask to apply on edges"], "mode": ["Mode","string","multiclass_classification",["binary_classification","multiclass_classification","regression"],"The mode of the model"], "return_type": ["Model return","string","log_probs",["raw","prob","log_probs"],"Denotes the type of output from model. Valid inputs are 'log_probs' (the model returns the logarithm of probabilities), 'prob' (the model returns probabilities), 'raw' (the model returns raw scores)"], "edge_size": ["edge_size","float",0.005,{"min": 0, "step": 0.001},""], diff --git a/src/explainers/GNNExplainer/torch_geom_our/out.py b/src/explainers/GNNExplainer/torch_geom_our/out.py index e0a10ce..fa6497d 100644 --- a/src/explainers/GNNExplainer/torch_geom_our/out.py +++ b/src/explainers/GNNExplainer/torch_geom_our/out.py @@ -111,13 +111,14 @@ def _finalize(self): self.explanation = AttributionExplanation( local=mode, - edges="continuous" if edge_mask is not None else False, - features="continuous" if node_mask is not None else False) + edges="continuous" if self.edge_mask_type=="object" else False, + nodes="continuous" if self.node_mask_type=="object" else False, + features="continuous" if self.node_mask_type=="common_attributes" else False) important_edges = {} important_nodes = {} + important_features = {} - # TODO What if edge_mask_type or node_mask_type is None, common_attributes, attributes? if self.edge_mask_type is not None and self.node_mask_type is not None: # Multi graphs check is not needed: the explanation format for @@ -136,21 +137,34 @@ def _finalize(self): important_edges[f"{edge[0]},{edge[1]}"] = format(imp, '.4f') # Nodes - num_nodes = node_mask.size(0) - assert num_nodes == self.x.size(0) + if self.node_mask_type=="object": + num_nodes = node_mask.size(0) + assert num_nodes == self.x.size(0) - for i in range(num_nodes): - imp = float(node_mask[i][0]) - if not imp < eps: - important_nodes[i] = format(imp, '.4f') + for i in range(num_nodes): + imp = float(node_mask[i][0]) + if not imp < eps: + important_nodes[i] = format(imp, '.4f') + + # Features + elif self.node_mask_type=="common_attributes": + num_features = node_mask.size(1) + assert num_features == self.x.size(1) + + for i in range(num_features): + imp = float(node_mask[0][i]) + if not imp < eps: + important_features[i] = format(imp, '.4f') if self.gen_dataset.is_multi(): important_edges = {self.graph_idx: important_edges} important_nodes = {self.graph_idx: important_nodes} + important_features = {self.graph_idx: important_features} # TODO Write functions with output threshold self.explanation.add_edges(important_edges) self.explanation.add_nodes(important_nodes) + self.explanation.add_features(important_features) # print(important_edges) # print(important_nodes)