Skip to content

Commit

Permalink
working 3
Browse the repository at this point in the history
  • Loading branch information
SerBorka committed Dec 21, 2024
1 parent 74ac43b commit 55822e8
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 23 deletions.
41 changes: 25 additions & 16 deletions experiments/interpretation_metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def explainer_run_config_for_node(explainer_name, node_ind, explainer_kwargs=Non
)

@timing_decorator
def run_interpretation_test(dataset_full_name, model_name):
def run_interpretation_test(explainer_name, dataset_full_name, model_name):
steps_epochs = 10
num_explaining_nodes = 1
explaining_metrics_params = {
Expand All @@ -80,9 +80,8 @@ def run_interpretation_test(dataset_full_name, model_name):
explainer_kwargs_by_explainer_name = {
'GNNExplainer(torch-geom)': {},
'SubgraphX': {"max_nodes": 5},
'Zorro': {},
}
explainer_name = 'SubgraphX'
# explainer_name = 'GNNExplainer(torch-geom)'
dataset_key_name = "_".join(dataset_full_name)
metrics_path = root_dir / "experiments" / "explainers_metrics"
dataset_metrics_path = metrics_path / f"{model_name}_{dataset_key_name}_{explainer_name}_metrics.json"
Expand All @@ -95,6 +94,7 @@ def run_interpretation_test(dataset_full_name, model_name):

restart_experiment = False
if restart_experiment:

node_indices = random.sample(range(dataset.data.x.shape[0]), num_explaining_nodes)
result_dict = {
"num_nodes": num_explaining_nodes,
Expand Down Expand Up @@ -129,6 +129,7 @@ def run_interpretation_test(dataset_full_name, model_name):
for experiment_name, calculate_fn in experiment_name_to_experiment:
if experiment_name not in result_dict:
print(f"Calculation of explanation metrics with defence: {experiment_name} started.")
explaining_metrics_params["experiment_name"] = experiment_name
metrics = calculate_fn(
explainer_name,
steps_epochs,
Expand Down Expand Up @@ -185,8 +186,9 @@ def calculate_unprotected_metrics(

warnings.warn("Start training")
try:
print("Loading model executor")
gnn_model_manager.load_model_executor()
print("Loaded model.")
print("Loaded model")
except FileNotFoundError:
gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0
train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs,
Expand Down Expand Up @@ -231,7 +233,7 @@ def calculate_jaccard_defence_metrics(
node_id_to_explainer_run_config,
model_name
):
save_model_flag = False
save_model_flag = True
device = torch.device('cpu')

data, results_dataset_path = dataset.data, dataset.results_dir
Expand Down Expand Up @@ -271,9 +273,9 @@ def calculate_jaccard_defence_metrics(
gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config)
warnings.warn("Start training")
try:
raise FileNotFoundError
print("Loading model executor")
gnn_model_manager.load_model_executor()
print("Loaded model")
except FileNotFoundError:
print("Training started.")
gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0
Expand Down Expand Up @@ -319,7 +321,7 @@ def calculate_adversial_defence_metrics(
node_id_to_explainer_run_config,
model_name
):
save_model_flag = False
save_model_flag = True
device = torch.device('cpu')

data, results_dataset_path = dataset.data, dataset.results_dir
Expand Down Expand Up @@ -373,9 +375,9 @@ def calculate_adversial_defence_metrics(

warnings.warn("Start training")
try:
raise FileNotFoundError
print("Loading model executor")
gnn_model_manager.load_model_executor()
print("Loaded model.")
print("Loaded model")
except FileNotFoundError:
gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0
train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs,
Expand Down Expand Up @@ -420,7 +422,7 @@ def calculate_gnnguard_defence_metrics(
node_id_to_explainer_run_config,
model_name
):
save_model_flag = False
save_model_flag = True
device = torch.device('cpu')

data, results_dataset_path = dataset.data, dataset.results_dir
Expand Down Expand Up @@ -464,9 +466,9 @@ def calculate_gnnguard_defence_metrics(

warnings.warn("Start training")
try:
raise FileNotFoundError
print("Loading model executor")
gnn_model_manager.load_model_executor()
print("Loaded model.")
print("Loaded model")
except FileNotFoundError:
gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0
train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs,
Expand Down Expand Up @@ -504,6 +506,13 @@ def calculate_gnnguard_defence_metrics(

if __name__ == '__main__':
# random.seed(777)

explainers = [
# 'GNNExplainer(torch-geom)',
# 'SubgraphX',
"Zorro",
]

models = [
# 'gcn_gcn',
'gat_gat',
Expand All @@ -513,9 +522,9 @@ def calculate_gnnguard_defence_metrics(
("single-graph", "Planetoid", 'Cora'),
# ("single-graph", "Amazon", 'Photo'),
]

for dataset_full_name in datasets:
for model_name in models:
run_interpretation_test(dataset_full_name, model_name)
for explainer in explainers:
for dataset_full_name in datasets:
for model_name in models:
run_interpretation_test(explainer, dataset_full_name, model_name)
# dataset_full_name = ("single-graph", "Amazon", 'Photo')
# run_interpretation_test(dataset_full_name)
14 changes: 7 additions & 7 deletions src/explainers/explainer_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
import torch
from torch_geometric.utils import subgraph
from torch_geometric.utils import subgraph, k_hop_subgraph

from aux.configs import ConfigPattern
from aux.custom_decorators import timing_decorator
Expand Down Expand Up @@ -102,16 +102,16 @@ def calculate_fidelity(self, target_nodes_indices):
@timing_decorator
def calculate_sparsity(self, node_ind):
explanation = self.get_explanations(node_ind)[0]
num_hops = self.model.get_num_hops()
local_subset, local_edge_index, _, _ = k_hop_subgraph(node_ind, num_hops, self.edge_index, relabel_nodes=False)
num = 0
den = 0
# TODO: fix me by NeighborLoader
if explanation["data"]["nodes"]:
num += len(explanation["data"]["nodes"])
den += self.x.shape[0]
den += local_subset.shape[0]
if explanation["data"]["edges"]:
num += len(explanation["data"]["edges"])
den += self.edge_index.shape[1]

den += local_edge_index.shape[1]
sparsity = 1 - num / den
print(f"Sparsity calculation for node id {node_ind} completed.")
return sparsity
Expand Down Expand Up @@ -154,11 +154,11 @@ def calculate_stability(
@timing_decorator
def calculate_consistency(self, node_ind, num_explanation_runs=10):
print(f"Consistency calculation for node id {node_ind} started.")
explanations = self.get_explanations(node_ind, num_explanations=num_explanation_runs+1)
explanations = self.get_explanations(node_ind, num_explanations=num_explanation_runs + 1)
explanation = explanations[0]
consistency = []
for ind in range(num_explanation_runs):
perturbed_explanation = explanations[ind+1]
perturbed_explanation = explanations[ind + 1]
base_explanation_vector, perturbed_explanation_vector = \
NodesExplainerMetric.calculate_explanation_vectors(explanation, perturbed_explanation)
consistency += [cosine_similarity(base_explanation_vector, perturbed_explanation_vector)]
Expand Down

0 comments on commit 55822e8

Please sign in to comment.