From 2042d5f6548b140fb894a342e766f98d05c56f23 Mon Sep 17 00:00:00 2001 From: serafim Date: Tue, 29 Oct 2024 00:13:43 +0300 Subject: [PATCH 1/5] explainer metrics draft --- experiments/interpretation_metrics_test.py | 116 ++++++++++++ .../GNNExplainer/torch_geom_our/out.py | 11 ++ src/explainers/explainer_metrics.py | 171 ++++++++++++++++++ src/explainers/explainers_manager.py | 43 ++++- 4 files changed, 339 insertions(+), 2 deletions(-) create mode 100644 experiments/interpretation_metrics_test.py create mode 100644 src/explainers/explainer_metrics.py diff --git a/experiments/interpretation_metrics_test.py b/experiments/interpretation_metrics_test.py new file mode 100644 index 0000000..03a8f1a --- /dev/null +++ b/experiments/interpretation_metrics_test.py @@ -0,0 +1,116 @@ +import random +import warnings + +import torch + +from aux.custom_decorators import timing_decorator +from aux.utils import EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, EXPLAINERS_INIT_PARAMETERS_PATH +from explainers.explainers_manager import FrameworkExplainersManager +from models_builder.gnn_models import FrameworkGNNModelManager, Metric +from src.aux.configs import ModelModificationConfig, ConfigPattern +from src.base.datasets_processing import DatasetManager +from src.models_builder.models_zoo import model_configs_zoo + + +@timing_decorator +def run_interpretation_test(): + full_name = ("single-graph", "Planetoid", 'Cora') + steps_epochs = 10 + save_model_flag = False + my_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + dataset, data, results_dataset_path = DatasetManager.get_by_full_name( + full_name=full_name, + dataset_ver_ind=0 + ) + gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn') + manager_config = ConfigPattern( + _config_class="ModelManagerConfig", + _config_kwargs={ + "mask_features": [], + "optimizer": { + # "_config_class": "Config", + "_class_name": "Adam", + # "_import_path": OPTIMIZERS_PARAMETERS_PATH, + # "_class_import_info": ["torch.optim"], + "_config_kwargs": {}, + } + } + ) + gnn_model_manager = FrameworkGNNModelManager( + gnn=gnn, + dataset_path=results_dataset_path, + manager_config=manager_config, + modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs) + ) + gnn_model_manager.gnn.to(my_device) + data.x = data.x.float() + data = data.to(my_device) + + warnings.warn("Start training") + try: + raise FileNotFoundError() + except FileNotFoundError: + gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0 + train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs, + save_model_flag=save_model_flag, + metrics=[Metric("F1", mask='train', average=None)]) + + if train_test_split_path is not None: + dataset.save_train_test_mask(train_test_split_path) + train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[ + :] + dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask + data.percent_train_class, data.percent_test_class = train_test_sizes + warnings.warn("Training was successful") + + metric_loc = gnn_model_manager.evaluate_model( + gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro')]) + print(metric_loc) + + explainer_init_config = ConfigPattern( + _class_name="GNNExplainer(torch-geom)", + _import_path=EXPLAINERS_INIT_PARAMETERS_PATH, + _config_class="ExplainerInitConfig", + _config_kwargs={ + "epochs": 10 + } + ) + explainer_run_config = ConfigPattern( + _config_class="ExplainerRunConfig", + _config_kwargs={ + "mode": "local", + "kwargs": { + "_class_name": "GNNExplainer(torch-geom)", + "_import_path": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, + "_config_class": "Config", + "_config_kwargs": { + "element_idx": 5 + }, + } + } + ) + + explainer_GNNExpl = FrameworkExplainersManager( + init_config=explainer_init_config, + dataset=dataset, gnn_manager=gnn_model_manager, + explainer_name='GNNExplainer(torch-geom)', + ) + + num_explaining_nodes = 10 + node_indices = random.sample(range(dataset.data.x.shape[0]), num_explaining_nodes) + + # explainer_GNNExpl.explainer.pbar = ProgressBar(socket, "er", desc=f'{explainer_GNNExpl.explainer.name} explaining') + # explanation_metric = NodesExplainerMetric( + # model=explainer_GNNExpl.gnn, + # graph=explainer_GNNExpl.gen_dataset.data, + # explainer=explainer_GNNExpl.explainer + # ) + # res = explanation_metric.evaluate(node_indices) + explanation_metrics = explainer_GNNExpl.evaluate_metrics(node_indices) + print(explanation_metrics) + + +if __name__ == '__main__': + random.seed(11) + run_interpretation_test() diff --git a/src/explainers/GNNExplainer/torch_geom_our/out.py b/src/explainers/GNNExplainer/torch_geom_our/out.py index e0a10ce..cfd3d5f 100644 --- a/src/explainers/GNNExplainer/torch_geom_our/out.py +++ b/src/explainers/GNNExplainer/torch_geom_our/out.py @@ -102,6 +102,17 @@ def run(self, mode, kwargs, finalize=True): self.raw_explanation = self.explainer(self.x, self.edge_index, index=self.node_idx) self.pbar.close() + @finalize_decorator + def evaluate_tensor_graph(self, x, edge_index, node_idx, **kwargs): + self._run_mode = "local" + self.node_idx = node_idx + self.x = x + self.edge_index = edge_index + self.pbar.reset(total=self.epochs, mode=self._run_mode) + self.explainer.algorithm.pbar = self.pbar + self.raw_explanation = self.explainer(self.x, self.edge_index, index=self.node_idx, **kwargs) + self.pbar.close() + def _finalize(self): mode = self._run_mode assert mode == "local" diff --git a/src/explainers/explainer_metrics.py b/src/explainers/explainer_metrics.py new file mode 100644 index 0000000..bb6e467 --- /dev/null +++ b/src/explainers/explainer_metrics.py @@ -0,0 +1,171 @@ +import numpy as np +import torch +from torch_geometric.utils import subgraph + + +class NodesExplainerMetric: + def __init__(self, model, graph, explainer): + self.model = model + self.explainer = explainer + self.graph = graph + self.x = self.graph.x + self.edge_index = self.graph.edge_index + self.nodes_explanations = {} # explanations cache. node_ind -> explanation + self.dictionary = { + } + + def evaluate(self, target_nodes_indices): + num_targets = len(target_nodes_indices) + sparsity = 0 + stability = 0 + consistency = 0 + for node_ind in target_nodes_indices: + self.get_explanation(node_ind) + sparsity += self.calculate_sparsity(node_ind) + stability += self.calculate_stability(node_ind) + consistency += self.calculate_consistency(node_ind) + fidelity = self.calculate_fidelity(target_nodes_indices) + self.dictionary["sparsity"] = sparsity / num_targets + self.dictionary["stability"] = stability / num_targets + self.dictionary["consistency"] = consistency / num_targets + self.dictionary["fidelity"] = fidelity + return self.dictionary + + def calculate_fidelity(self, target_nodes_indices): + original_answer = self.model.get_answer(self.x, self.edge_index) + same_answers_count = 0 + for node_ind in target_nodes_indices: + node_explanation = self.get_explanation(node_ind) + new_x, new_edge_index, new_target_node = self.filter_graph_by_explanation( + self.x, self.edge_index, node_explanation, node_ind + ) + filtered_answer = self.model.get_answer(new_x, new_edge_index) + matched = filtered_answer[new_target_node] == original_answer[node_ind] + print(f"Processed fidelity calculation for node id {node_ind}. Matched: {matched}") + if matched: + same_answers_count += 1 + fidelity = same_answers_count / len(target_nodes_indices) + return fidelity + + def calculate_sparsity(self, node_ind): + explanation = self.get_explanation(node_ind) + sparsity = 1 - (len(explanation["data"]["nodes"]) + len(explanation["data"]["edges"])) / ( + len(self.x) + len(self.edge_index)) + return sparsity + + def calculate_stability(self, node_ind, feature_change_percent=0.05, node_removal_percent=0.05): + base_explanation = self.get_explanation(node_ind) + new_x, new_edge_index = self.perturb_graph( + self.x, self.edge_index, node_ind, feature_change_percent, node_removal_percent + ) + perturbed_explanation = self.calculate_explanation(new_x, new_edge_index, node_ind) + + base_explanation_vector, perturbed_explanation_vector = \ + NodesExplainerMetric.calculate_explanation_vectors(base_explanation, perturbed_explanation) + + return euclidean_distance(base_explanation_vector, perturbed_explanation_vector) + + def calculate_consistency(self, node_ind): + base_explanation = self.get_explanation(node_ind) + perturbed_explanation = self.calculate_explanation(self.x, self.edge_index, node_ind) + base_explanation_vector, perturbed_explanation_vector = \ + NodesExplainerMetric.calculate_explanation_vectors(base_explanation, perturbed_explanation) + return cosine_similarity(base_explanation_vector, perturbed_explanation_vector) + + def calculate_explanation(self, x, edge_index, node_idx, **kwargs): + self.explainer.evaluate_tensor_graph(x, edge_index, node_idx, **kwargs) + return self.explainer.explanation.dictionary + + def get_explanation(self, node_ind): + if node_ind in self.nodes_explanations: + node_explanation = self.nodes_explanations[node_ind] + else: + node_explanation = self.calculate_explanation(self.x, self.edge_index, node_ind) + self.nodes_explanations[node_ind] = node_explanation + print(f"Processed explanation calculation for node id {node_ind}.") + return node_explanation + + @staticmethod + def parse_explanation(explanation): + important_nodes = { + int(node): float(weight) for node, weight in explanation["data"]["nodes"].items() + } + important_edges = { + tuple(map(int, edge.split(','))): float(weight) + for edge, weight in explanation["data"]["edges"].items() + } + return important_nodes, important_edges + + @staticmethod + def filter_graph_by_explanation(x, edge_index, explanation, target_node): + important_nodes, important_edges = NodesExplainerMetric.parse_explanation(explanation) + all_important_nodes = set(important_nodes.keys()) + all_important_nodes.add(target_node) + for u, v in important_edges.keys(): + all_important_nodes.add(u) + all_important_nodes.add(v) + + important_node_indices = list(all_important_nodes) + node_mask = torch.zeros(x.size(0), dtype=torch.bool) + node_mask[important_node_indices] = True + + new_edge_index, new_edge_weight = subgraph(node_mask, edge_index, relabel_nodes=True) + new_x = x[node_mask] + new_target_node = important_node_indices.index(target_node) + return new_x, new_edge_index, new_target_node + + @staticmethod + def calculate_explanation_vectors(base_explanation, perturbed_explanation): + base_important_nodes, base_important_edges = NodesExplainerMetric.parse_explanation( + base_explanation + ) + perturbed_important_nodes, perturbed_important_edges = NodesExplainerMetric.parse_explanation( + perturbed_explanation + ) + union_nodes = set(base_important_nodes.keys()) | set(perturbed_important_nodes.keys()) + union_edges = set(base_important_edges.keys()) | set(perturbed_important_edges.keys()) + explain_vector_len = len(union_nodes) + len(union_edges) + base_explanation_vector = np.zeros(explain_vector_len) + perturbed_explanation_vector = np.zeros(explain_vector_len) + i = 0 + for expl_node_ind in union_nodes: + base_explanation_vector[i] = base_important_nodes.get(expl_node_ind, 0) + perturbed_explanation_vector[i] = perturbed_important_nodes.get(expl_node_ind, 0) + i += 1 + for expl_edge in union_edges: + base_explanation_vector[i] = base_important_edges.get(expl_edge, 0) + perturbed_explanation_vector[i] = perturbed_important_edges.get(expl_edge, 0) + i += 1 + return base_explanation_vector, perturbed_explanation_vector + + @staticmethod + def perturb_graph(x, edge_index, node_ind, feature_change_percent, node_removal_percent): + new_x = x.clone() + num_nodes = x.shape[0] + num_features = x.shape[1] + num_features_to_change = int(feature_change_percent * num_nodes * num_features) + indices = torch.randint(0, num_nodes * num_features, (num_features_to_change,), device=x.device) + new_x.view(-1)[indices] = 1.0 - new_x.view(-1)[indices] + + neighbors = edge_index[1][edge_index[0] == node_ind].unique() + num_nodes_to_remove = int(node_removal_percent * neighbors.shape[0]) + + if num_nodes_to_remove > 0: + nodes_to_remove = neighbors[ + torch.randperm(neighbors.size(0), device=edge_index.device)[:num_nodes_to_remove] + ] + mask = ~((edge_index[0] == node_ind).unsqueeze(1) & (edge_index[1].unsqueeze(0) == nodes_to_remove).any( + dim=0)) + new_edge_index = edge_index[:, mask] + else: + new_edge_index = edge_index + + return new_x, new_edge_index + + +def cosine_similarity(a, b): + return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) + + +def euclidean_distance(a, b): + return np.linalg.norm(a - b) diff --git a/src/explainers/explainers_manager.py b/src/explainers/explainers_manager.py index 933771d..b9e3580 100644 --- a/src/explainers/explainers_manager.py +++ b/src/explainers/explainers_manager.py @@ -1,10 +1,10 @@ import json -from aux.configs import ExplainerInitConfig, ExplainerModificationConfig, ExplainerRunConfig, \ - CONFIG_CLASS_NAME, CONFIG_OBJ, ConfigPattern +from aux.configs import ExplainerInitConfig, ExplainerModificationConfig, CONFIG_OBJ, ConfigPattern from aux.declaration import Declare from aux.utils import EXPLAINERS_INIT_PARAMETERS_PATH from explainers.explainer import Explainer, ProgressBar +from explainers.explainer_metrics import NodesExplainerMetric # TODO misha can we do it not manually? # Need to import all modules with subclasses of Explainer, otherwise python can't see them @@ -176,6 +176,45 @@ def conduct_experiment(self, run_config, socket=None): return result + def evaluate_metrics(self, target_nodes_indices, socket=None): + """ + Evaluates explanation metrics between given node indices + """ + # TODO: Refactor this method for framework design + self.explainer.pbar = ProgressBar( + socket, "er", desc=f'{self.explainer.name} explaining metrics calculation' + ) # progress bar + try: + print("Evaluating explanation metrics...") + if self.gen_dataset.is_multi(): + raise NotImplementedError("Explanation metrics for graph classification") + else: + explanation_metrics_calculator = NodesExplainerMetric( + model=self.gnn, + graph=self.gen_dataset.data, + explainer=self.explainer + ) + result = explanation_metrics_calculator.evaluate(target_nodes_indices) + print("Explanation metrics are ready") + + if socket: + # TODO: Handle this on frontend + socket.send("er", { + "status": "OK", + "explanation_metrics": result + }) + + # TODO what if save_explanation_flag=False? + if self.save_explanation_flag: + # self.save_explanation_metrics(run_config) + self.model_manager.save_model_executor() + except Exception as e: + if socket: + socket.send("er", {"status": "FAILED"}) + raise e + + return result + @staticmethod def available_explainers(gen_dataset, model_manager): """ Get a list of explainers applicable for current model and dataset. From a06a1952a6d2591ac9eea9dce34b3accee67780b Mon Sep 17 00:00:00 2001 From: mikhail Date: Tue, 29 Oct 2024 15:32:06 +0300 Subject: [PATCH 2/5] fix gnn_explainer work with features --- metainfo/explainers_init_parameters.json | 4 +-- .../GNNExplainer/torch_geom_our/out.py | 32 +++++++++++++------ 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/metainfo/explainers_init_parameters.json b/metainfo/explainers_init_parameters.json index 66596f8..a4f93be 100644 --- a/metainfo/explainers_init_parameters.json +++ b/metainfo/explainers_init_parameters.json @@ -20,8 +20,8 @@ "GNNExplainer(torch-geom)": { "epochs": ["Epochs","int",100,{"min": 1},"The number of epochs to train"], "lr": ["Learn rate","float",0.01,{"min": 0, "step": 0.0001},"The learning rate to apply"], - "node_mask_type": ["Node mask","string","object",["None","object","common_attributes","attributes"],"The type of mask to apply on nodes"], - "edge_mask_type": ["Edge mask","string","object",["None","object","common_attributes","attributes"],"The type of mask to apply on edges"], + "node_mask_type": ["Node mask","string","object",["None","object","common_attributes"],"The type of mask to apply on nodes"], + "edge_mask_type": ["Edge mask","string","object",["None","object"],"The type of mask to apply on edges"], "mode": ["Mode","string","multiclass_classification",["binary_classification","multiclass_classification","regression"],"The mode of the model"], "return_type": ["Model return","string","log_probs",["raw","prob","log_probs"],"Denotes the type of output from model. Valid inputs are 'log_probs' (the model returns the logarithm of probabilities), 'prob' (the model returns probabilities), 'raw' (the model returns raw scores)"], "edge_size": ["edge_size","float",0.005,{"min": 0, "step": 0.001},""], diff --git a/src/explainers/GNNExplainer/torch_geom_our/out.py b/src/explainers/GNNExplainer/torch_geom_our/out.py index e0a10ce..fa6497d 100644 --- a/src/explainers/GNNExplainer/torch_geom_our/out.py +++ b/src/explainers/GNNExplainer/torch_geom_our/out.py @@ -111,13 +111,14 @@ def _finalize(self): self.explanation = AttributionExplanation( local=mode, - edges="continuous" if edge_mask is not None else False, - features="continuous" if node_mask is not None else False) + edges="continuous" if self.edge_mask_type=="object" else False, + nodes="continuous" if self.node_mask_type=="object" else False, + features="continuous" if self.node_mask_type=="common_attributes" else False) important_edges = {} important_nodes = {} + important_features = {} - # TODO What if edge_mask_type or node_mask_type is None, common_attributes, attributes? if self.edge_mask_type is not None and self.node_mask_type is not None: # Multi graphs check is not needed: the explanation format for @@ -136,21 +137,34 @@ def _finalize(self): important_edges[f"{edge[0]},{edge[1]}"] = format(imp, '.4f') # Nodes - num_nodes = node_mask.size(0) - assert num_nodes == self.x.size(0) + if self.node_mask_type=="object": + num_nodes = node_mask.size(0) + assert num_nodes == self.x.size(0) - for i in range(num_nodes): - imp = float(node_mask[i][0]) - if not imp < eps: - important_nodes[i] = format(imp, '.4f') + for i in range(num_nodes): + imp = float(node_mask[i][0]) + if not imp < eps: + important_nodes[i] = format(imp, '.4f') + + # Features + elif self.node_mask_type=="common_attributes": + num_features = node_mask.size(1) + assert num_features == self.x.size(1) + + for i in range(num_features): + imp = float(node_mask[0][i]) + if not imp < eps: + important_features[i] = format(imp, '.4f') if self.gen_dataset.is_multi(): important_edges = {self.graph_idx: important_edges} important_nodes = {self.graph_idx: important_nodes} + important_features = {self.graph_idx: important_features} # TODO Write functions with output threshold self.explanation.add_edges(important_edges) self.explanation.add_nodes(important_nodes) + self.explanation.add_features(important_features) # print(important_edges) # print(important_nodes) From a62a3e1da50a42f5e6f3d137e6fcfa6fd64013ab Mon Sep 17 00:00:00 2001 From: serafim Date: Tue, 29 Oct 2024 15:47:59 +0300 Subject: [PATCH 3/5] explainer metrics - metrics params --- experiments/interpretation_metrics_test.py | 9 ++- src/explainers/explainer_metrics.py | 68 ++++++++++++++++------ src/explainers/explainers_manager.py | 9 ++- 3 files changed, 63 insertions(+), 23 deletions(-) diff --git a/experiments/interpretation_metrics_test.py b/experiments/interpretation_metrics_test.py index 03a8f1a..8b01d9e 100644 --- a/experiments/interpretation_metrics_test.py +++ b/experiments/interpretation_metrics_test.py @@ -76,7 +76,7 @@ def run_interpretation_test(): "epochs": 10 } ) - explainer_run_config = ConfigPattern( + explainer_metrics_run_config = ConfigPattern( _config_class="ExplainerRunConfig", _config_kwargs={ "mode": "local", @@ -85,7 +85,10 @@ def run_interpretation_test(): "_import_path": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, "_config_class": "Config", "_config_kwargs": { - "element_idx": 5 + "stability_graph_perturbations_nums": 10, + "stability_feature_change_percent": 0.05, + "stability_node_removal_percent": 0.05, + "consistency_num_explanation_runs": 10 }, } } @@ -107,7 +110,7 @@ def run_interpretation_test(): # explainer=explainer_GNNExpl.explainer # ) # res = explanation_metric.evaluate(node_indices) - explanation_metrics = explainer_GNNExpl.evaluate_metrics(node_indices) + explanation_metrics = explainer_GNNExpl.evaluate_metrics(node_indices, explainer_metrics_run_config) print(explanation_metrics) diff --git a/src/explainers/explainer_metrics.py b/src/explainers/explainer_metrics.py index bb6e467..32e0eb5 100644 --- a/src/explainers/explainer_metrics.py +++ b/src/explainers/explainer_metrics.py @@ -4,12 +4,19 @@ class NodesExplainerMetric: - def __init__(self, model, graph, explainer): + def __init__(self, model, graph, explainer, kwargs_dict): self.model = model self.explainer = explainer self.graph = graph self.x = self.graph.x self.edge_index = self.graph.edge_index + self.kwargs_dict = { + "stability_graph_perturbations_nums": 10, + "stability_feature_change_percent": 0.05, + "stability_node_removal_percent": 0.05, + "consistency_num_explanation_runs": 10 + } + self.kwargs_dict.update(kwargs_dict) self.nodes_explanations = {} # explanations cache. node_ind -> explanation self.dictionary = { } @@ -22,8 +29,16 @@ def evaluate(self, target_nodes_indices): for node_ind in target_nodes_indices: self.get_explanation(node_ind) sparsity += self.calculate_sparsity(node_ind) - stability += self.calculate_stability(node_ind) - consistency += self.calculate_consistency(node_ind) + stability += self.calculate_stability( + node_ind, + graph_perturbations_nums=self.kwargs_dict["stability_graph_perturbations_nums"], + feature_change_percent=self.kwargs_dict["stability_feature_change_percent"], + node_removal_percent=self.kwargs_dict["stability_node_removal_percent"] + ) + consistency += self.calculate_consistency( + node_ind, + num_explanation_runs=self.kwargs_dict["consistency_num_explanation_runs"] + ) fidelity = self.calculate_fidelity(target_nodes_indices) self.dictionary["sparsity"] = sparsity / num_targets self.dictionary["stability"] = stability / num_targets @@ -53,27 +68,45 @@ def calculate_sparsity(self, node_ind): len(self.x) + len(self.edge_index)) return sparsity - def calculate_stability(self, node_ind, feature_change_percent=0.05, node_removal_percent=0.05): + def calculate_stability( + self, + node_ind, + graph_perturbations_nums=10, + feature_change_percent=0.05, + node_removal_percent=0.05 + ): base_explanation = self.get_explanation(node_ind) - new_x, new_edge_index = self.perturb_graph( - self.x, self.edge_index, node_ind, feature_change_percent, node_removal_percent - ) - perturbed_explanation = self.calculate_explanation(new_x, new_edge_index, node_ind) + stability = 0 + for _ in range(graph_perturbations_nums): + new_x, new_edge_index = self.perturb_graph( + self.x, self.edge_index, node_ind, feature_change_percent, node_removal_percent + ) + perturbed_explanation = self.calculate_explanation(new_x, new_edge_index, node_ind) + base_explanation_vector, perturbed_explanation_vector = \ + NodesExplainerMetric.calculate_explanation_vectors(base_explanation, perturbed_explanation) - base_explanation_vector, perturbed_explanation_vector = \ - NodesExplainerMetric.calculate_explanation_vectors(base_explanation, perturbed_explanation) + stability += euclidean_distance(base_explanation_vector, perturbed_explanation_vector) - return euclidean_distance(base_explanation_vector, perturbed_explanation_vector) + stability = stability / graph_perturbations_nums + return stability - def calculate_consistency(self, node_ind): - base_explanation = self.get_explanation(node_ind) - perturbed_explanation = self.calculate_explanation(self.x, self.edge_index, node_ind) - base_explanation_vector, perturbed_explanation_vector = \ - NodesExplainerMetric.calculate_explanation_vectors(base_explanation, perturbed_explanation) - return cosine_similarity(base_explanation_vector, perturbed_explanation_vector) + def calculate_consistency(self, node_ind, num_explanation_runs=10): + explanation = self.get_explanation(node_ind) + consistency = 0 + for _ in range(num_explanation_runs): + perturbed_explanation = self.calculate_explanation(self.x, self.edge_index, node_ind) + base_explanation_vector, perturbed_explanation_vector = \ + NodesExplainerMetric.calculate_explanation_vectors(explanation, perturbed_explanation) + consistency += cosine_similarity(base_explanation_vector, perturbed_explanation_vector) + explanation = perturbed_explanation + + consistency = consistency / num_explanation_runs + return consistency def calculate_explanation(self, x, edge_index, node_idx, **kwargs): + print(f"Processing explanation calculation for node id {node_idx}.") self.explainer.evaluate_tensor_graph(x, edge_index, node_idx, **kwargs) + print(f"Explanation calculation for node id {node_idx} completed.") return self.explainer.explanation.dictionary def get_explanation(self, node_ind): @@ -82,7 +115,6 @@ def get_explanation(self, node_ind): else: node_explanation = self.calculate_explanation(self.x, self.edge_index, node_ind) self.nodes_explanations[node_ind] = node_explanation - print(f"Processed explanation calculation for node id {node_ind}.") return node_explanation @staticmethod diff --git a/src/explainers/explainers_manager.py b/src/explainers/explainers_manager.py index b9e3580..cf93aae 100644 --- a/src/explainers/explainers_manager.py +++ b/src/explainers/explainers_manager.py @@ -176,11 +176,15 @@ def conduct_experiment(self, run_config, socket=None): return result - def evaluate_metrics(self, target_nodes_indices, socket=None): + def evaluate_metrics(self, target_nodes_indices, run_config=None, socket=None): """ Evaluates explanation metrics between given node indices """ # TODO: Refactor this method for framework design + if run_config: + params = getattr(getattr(run_config, CONFIG_OBJ).kwargs, CONFIG_OBJ).to_dict() + else: + params = {} self.explainer.pbar = ProgressBar( socket, "er", desc=f'{self.explainer.name} explaining metrics calculation' ) # progress bar @@ -192,7 +196,8 @@ def evaluate_metrics(self, target_nodes_indices, socket=None): explanation_metrics_calculator = NodesExplainerMetric( model=self.gnn, graph=self.gen_dataset.data, - explainer=self.explainer + explainer=self.explainer, + kwargs_dict=params ) result = explanation_metrics_calculator.evaluate(target_nodes_indices) print("Explanation metrics are ready") From 14034feb8897688c308286945d73532f9693095d Mon Sep 17 00:00:00 2001 From: mikhail Date: Tue, 29 Oct 2024 17:20:06 +0300 Subject: [PATCH 4/5] add raise --- metainfo/explainers_init_parameters.json | 4 ++-- .../GNNExplainer/torch_geom_our/out.py | 23 +++++++++++-------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/metainfo/explainers_init_parameters.json b/metainfo/explainers_init_parameters.json index a4f93be..0280f8e 100644 --- a/metainfo/explainers_init_parameters.json +++ b/metainfo/explainers_init_parameters.json @@ -20,8 +20,8 @@ "GNNExplainer(torch-geom)": { "epochs": ["Epochs","int",100,{"min": 1},"The number of epochs to train"], "lr": ["Learn rate","float",0.01,{"min": 0, "step": 0.0001},"The learning rate to apply"], - "node_mask_type": ["Node mask","string","object",["None","object","common_attributes"],"The type of mask to apply on nodes"], - "edge_mask_type": ["Edge mask","string","object",["None","object"],"The type of mask to apply on edges"], + "node_mask_type": ["Node mask","string","object",["None","object","common_attributes","attributes"],"The type of mask to apply on nodes"], + "edge_mask_type": ["Edge mask","string","common_attributes",["None","object","common_attributes","attributes"],"The type of mask to apply on edges"], "mode": ["Mode","string","multiclass_classification",["binary_classification","multiclass_classification","regression"],"The mode of the model"], "return_type": ["Model return","string","log_probs",["raw","prob","log_probs"],"Denotes the type of output from model. Valid inputs are 'log_probs' (the model returns the logarithm of probabilities), 'prob' (the model returns probabilities), 'raw' (the model returns raw scores)"], "edge_size": ["edge_size","float",0.005,{"min": 0, "step": 0.001},""], diff --git a/src/explainers/GNNExplainer/torch_geom_our/out.py b/src/explainers/GNNExplainer/torch_geom_our/out.py index d773af6..5595d5a 100644 --- a/src/explainers/GNNExplainer/torch_geom_our/out.py +++ b/src/explainers/GNNExplainer/torch_geom_our/out.py @@ -137,15 +137,18 @@ def _finalize(self): eps = 0.001 # Edges - num_edges = edge_mask.size(0) - assert num_edges == self.edge_index.size(1) - edges = self.edge_index + if self.edge_mask_type=="object": + num_edges = edge_mask.size(0) + assert num_edges == self.edge_index.size(1) + edges = self.edge_index - for i in range(num_edges): - imp = float(edge_mask[i]) - if not imp < eps: - edge = edges[0][i], edges[1][i] - important_edges[f"{edge[0]},{edge[1]}"] = format(imp, '.4f') + for i in range(num_edges): + imp = float(edge_mask[i]) + if not imp < eps: + edge = edges[0][i], edges[1][i] + important_edges[f"{edge[0]},{edge[1]}"] = format(imp, '.4f') + else: # if "common_attributes" or "attributes" + raise NotImplementedError(f"Edge mask type '{self.edge_mask_type}' is not yet implemented.") # Nodes if self.node_mask_type=="object": @@ -156,7 +159,6 @@ def _finalize(self): imp = float(node_mask[i][0]) if not imp < eps: important_nodes[i] = format(imp, '.4f') - # Features elif self.node_mask_type=="common_attributes": num_features = node_mask.size(1) @@ -166,6 +168,9 @@ def _finalize(self): imp = float(node_mask[0][i]) if not imp < eps: important_features[i] = format(imp, '.4f') + else: # if "attributes" + # TODO add functional if node_mask_type=="attributes" + raise NotImplementedError(f"Node mask type '{self.node_mask_type}' is not yet implemented.") if self.gen_dataset.is_multi(): important_edges = {self.graph_idx: important_edges} From 231326fb461998e9682f279a9d6f566a656e8182 Mon Sep 17 00:00:00 2001 From: mikhail Date: Tue, 29 Oct 2024 17:43:43 +0300 Subject: [PATCH 5/5] fix --- metainfo/explainers_init_parameters.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metainfo/explainers_init_parameters.json b/metainfo/explainers_init_parameters.json index 0280f8e..66596f8 100644 --- a/metainfo/explainers_init_parameters.json +++ b/metainfo/explainers_init_parameters.json @@ -21,7 +21,7 @@ "epochs": ["Epochs","int",100,{"min": 1},"The number of epochs to train"], "lr": ["Learn rate","float",0.01,{"min": 0, "step": 0.0001},"The learning rate to apply"], "node_mask_type": ["Node mask","string","object",["None","object","common_attributes","attributes"],"The type of mask to apply on nodes"], - "edge_mask_type": ["Edge mask","string","common_attributes",["None","object","common_attributes","attributes"],"The type of mask to apply on edges"], + "edge_mask_type": ["Edge mask","string","object",["None","object","common_attributes","attributes"],"The type of mask to apply on edges"], "mode": ["Mode","string","multiclass_classification",["binary_classification","multiclass_classification","regression"],"The mode of the model"], "return_type": ["Model return","string","log_probs",["raw","prob","log_probs"],"Denotes the type of output from model. Valid inputs are 'log_probs' (the model returns the logarithm of probabilities), 'prob' (the model returns probabilities), 'raw' (the model returns raw scores)"], "edge_size": ["edge_size","float",0.005,{"min": 0, "step": 0.001},""],