From 55822e8679f674928ddb4bd4b1a6835d6ae474ae Mon Sep 17 00:00:00 2001 From: serafim Date: Sat, 21 Dec 2024 17:37:32 +0300 Subject: [PATCH] working 3 --- experiments/interpretation_metrics_test.py | 41 +++++++++++++--------- src/explainers/explainer_metrics.py | 14 ++++---- 2 files changed, 32 insertions(+), 23 deletions(-) diff --git a/experiments/interpretation_metrics_test.py b/experiments/interpretation_metrics_test.py index e3bacbd..4886299 100644 --- a/experiments/interpretation_metrics_test.py +++ b/experiments/interpretation_metrics_test.py @@ -60,7 +60,7 @@ def explainer_run_config_for_node(explainer_name, node_ind, explainer_kwargs=Non ) @timing_decorator -def run_interpretation_test(dataset_full_name, model_name): +def run_interpretation_test(explainer_name, dataset_full_name, model_name): steps_epochs = 10 num_explaining_nodes = 1 explaining_metrics_params = { @@ -80,9 +80,8 @@ def run_interpretation_test(dataset_full_name, model_name): explainer_kwargs_by_explainer_name = { 'GNNExplainer(torch-geom)': {}, 'SubgraphX': {"max_nodes": 5}, + 'Zorro': {}, } - explainer_name = 'SubgraphX' - # explainer_name = 'GNNExplainer(torch-geom)' dataset_key_name = "_".join(dataset_full_name) metrics_path = root_dir / "experiments" / "explainers_metrics" dataset_metrics_path = metrics_path / f"{model_name}_{dataset_key_name}_{explainer_name}_metrics.json" @@ -95,6 +94,7 @@ def run_interpretation_test(dataset_full_name, model_name): restart_experiment = False if restart_experiment: + node_indices = random.sample(range(dataset.data.x.shape[0]), num_explaining_nodes) result_dict = { "num_nodes": num_explaining_nodes, @@ -129,6 +129,7 @@ def run_interpretation_test(dataset_full_name, model_name): for experiment_name, calculate_fn in experiment_name_to_experiment: if experiment_name not in result_dict: print(f"Calculation of explanation metrics with defence: {experiment_name} started.") + explaining_metrics_params["experiment_name"] = experiment_name metrics = calculate_fn( explainer_name, steps_epochs, @@ -185,8 +186,9 @@ def calculate_unprotected_metrics( warnings.warn("Start training") try: + print("Loading model executor") gnn_model_manager.load_model_executor() - print("Loaded model.") + print("Loaded model") except FileNotFoundError: gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0 train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs, @@ -231,7 +233,7 @@ def calculate_jaccard_defence_metrics( node_id_to_explainer_run_config, model_name ): - save_model_flag = False + save_model_flag = True device = torch.device('cpu') data, results_dataset_path = dataset.data, dataset.results_dir @@ -271,9 +273,9 @@ def calculate_jaccard_defence_metrics( gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config) warnings.warn("Start training") try: - raise FileNotFoundError print("Loading model executor") gnn_model_manager.load_model_executor() + print("Loaded model") except FileNotFoundError: print("Training started.") gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0 @@ -319,7 +321,7 @@ def calculate_adversial_defence_metrics( node_id_to_explainer_run_config, model_name ): - save_model_flag = False + save_model_flag = True device = torch.device('cpu') data, results_dataset_path = dataset.data, dataset.results_dir @@ -373,9 +375,9 @@ def calculate_adversial_defence_metrics( warnings.warn("Start training") try: - raise FileNotFoundError + print("Loading model executor") gnn_model_manager.load_model_executor() - print("Loaded model.") + print("Loaded model") except FileNotFoundError: gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0 train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs, @@ -420,7 +422,7 @@ def calculate_gnnguard_defence_metrics( node_id_to_explainer_run_config, model_name ): - save_model_flag = False + save_model_flag = True device = torch.device('cpu') data, results_dataset_path = dataset.data, dataset.results_dir @@ -464,9 +466,9 @@ def calculate_gnnguard_defence_metrics( warnings.warn("Start training") try: - raise FileNotFoundError + print("Loading model executor") gnn_model_manager.load_model_executor() - print("Loaded model.") + print("Loaded model") except FileNotFoundError: gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0 train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs, @@ -504,6 +506,13 @@ def calculate_gnnguard_defence_metrics( if __name__ == '__main__': # random.seed(777) + + explainers = [ + # 'GNNExplainer(torch-geom)', + # 'SubgraphX', + "Zorro", + ] + models = [ # 'gcn_gcn', 'gat_gat', @@ -513,9 +522,9 @@ def calculate_gnnguard_defence_metrics( ("single-graph", "Planetoid", 'Cora'), # ("single-graph", "Amazon", 'Photo'), ] - - for dataset_full_name in datasets: - for model_name in models: - run_interpretation_test(dataset_full_name, model_name) + for explainer in explainers: + for dataset_full_name in datasets: + for model_name in models: + run_interpretation_test(explainer, dataset_full_name, model_name) # dataset_full_name = ("single-graph", "Amazon", 'Photo') # run_interpretation_test(dataset_full_name) diff --git a/src/explainers/explainer_metrics.py b/src/explainers/explainer_metrics.py index c091487..32cf4b8 100644 --- a/src/explainers/explainer_metrics.py +++ b/src/explainers/explainer_metrics.py @@ -5,7 +5,7 @@ import numpy as np import torch -from torch_geometric.utils import subgraph +from torch_geometric.utils import subgraph, k_hop_subgraph from aux.configs import ConfigPattern from aux.custom_decorators import timing_decorator @@ -102,16 +102,16 @@ def calculate_fidelity(self, target_nodes_indices): @timing_decorator def calculate_sparsity(self, node_ind): explanation = self.get_explanations(node_ind)[0] + num_hops = self.model.get_num_hops() + local_subset, local_edge_index, _, _ = k_hop_subgraph(node_ind, num_hops, self.edge_index, relabel_nodes=False) num = 0 den = 0 - # TODO: fix me by NeighborLoader if explanation["data"]["nodes"]: num += len(explanation["data"]["nodes"]) - den += self.x.shape[0] + den += local_subset.shape[0] if explanation["data"]["edges"]: num += len(explanation["data"]["edges"]) - den += self.edge_index.shape[1] - + den += local_edge_index.shape[1] sparsity = 1 - num / den print(f"Sparsity calculation for node id {node_ind} completed.") return sparsity @@ -154,11 +154,11 @@ def calculate_stability( @timing_decorator def calculate_consistency(self, node_ind, num_explanation_runs=10): print(f"Consistency calculation for node id {node_ind} started.") - explanations = self.get_explanations(node_ind, num_explanations=num_explanation_runs+1) + explanations = self.get_explanations(node_ind, num_explanations=num_explanation_runs + 1) explanation = explanations[0] consistency = [] for ind in range(num_explanation_runs): - perturbed_explanation = explanations[ind+1] + perturbed_explanation = explanations[ind + 1] base_explanation_vector, perturbed_explanation_vector = \ NodesExplainerMetric.calculate_explanation_vectors(explanation, perturbed_explanation) consistency += [cosine_similarity(base_explanation_vector, perturbed_explanation_vector)]