Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/ispras/GNN-AID into pgd_…
Browse files Browse the repository at this point in the history
…attack
  • Loading branch information
mishabounty committed Oct 29, 2024
2 parents 28896b9 + 781c01a commit 0deb805
Show file tree
Hide file tree
Showing 4 changed files with 417 additions and 21 deletions.
119 changes: 119 additions & 0 deletions experiments/interpretation_metrics_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import random
import warnings

import torch

from aux.custom_decorators import timing_decorator
from aux.utils import EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, EXPLAINERS_INIT_PARAMETERS_PATH
from explainers.explainers_manager import FrameworkExplainersManager
from models_builder.gnn_models import FrameworkGNNModelManager, Metric
from src.aux.configs import ModelModificationConfig, ConfigPattern
from src.base.datasets_processing import DatasetManager
from src.models_builder.models_zoo import model_configs_zoo


@timing_decorator
def run_interpretation_test():
full_name = ("single-graph", "Planetoid", 'Cora')
steps_epochs = 10
save_model_flag = False
my_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset, data, results_dataset_path = DatasetManager.get_by_full_name(
full_name=full_name,
dataset_ver_ind=0
)
gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn')
manager_config = ConfigPattern(
_config_class="ModelManagerConfig",
_config_kwargs={
"mask_features": [],
"optimizer": {
# "_config_class": "Config",
"_class_name": "Adam",
# "_import_path": OPTIMIZERS_PARAMETERS_PATH,
# "_class_import_info": ["torch.optim"],
"_config_kwargs": {},
}
}
)
gnn_model_manager = FrameworkGNNModelManager(
gnn=gnn,
dataset_path=results_dataset_path,
manager_config=manager_config,
modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs)
)
gnn_model_manager.gnn.to(my_device)
data.x = data.x.float()
data = data.to(my_device)

warnings.warn("Start training")
try:
raise FileNotFoundError()
except FileNotFoundError:
gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0
train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs,
save_model_flag=save_model_flag,
metrics=[Metric("F1", mask='train', average=None)])

if train_test_split_path is not None:
dataset.save_train_test_mask(train_test_split_path)
train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[
:]
dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask
data.percent_train_class, data.percent_test_class = train_test_sizes
warnings.warn("Training was successful")

metric_loc = gnn_model_manager.evaluate_model(
gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro')])
print(metric_loc)

explainer_init_config = ConfigPattern(
_class_name="GNNExplainer(torch-geom)",
_import_path=EXPLAINERS_INIT_PARAMETERS_PATH,
_config_class="ExplainerInitConfig",
_config_kwargs={
"epochs": 10
}
)
explainer_metrics_run_config = ConfigPattern(
_config_class="ExplainerRunConfig",
_config_kwargs={
"mode": "local",
"kwargs": {
"_class_name": "GNNExplainer(torch-geom)",
"_import_path": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH,
"_config_class": "Config",
"_config_kwargs": {
"stability_graph_perturbations_nums": 10,
"stability_feature_change_percent": 0.05,
"stability_node_removal_percent": 0.05,
"consistency_num_explanation_runs": 10
},
}
}
)

explainer_GNNExpl = FrameworkExplainersManager(
init_config=explainer_init_config,
dataset=dataset, gnn_manager=gnn_model_manager,
explainer_name='GNNExplainer(torch-geom)',
)

num_explaining_nodes = 10
node_indices = random.sample(range(dataset.data.x.shape[0]), num_explaining_nodes)

# explainer_GNNExpl.explainer.pbar = ProgressBar(socket, "er", desc=f'{explainer_GNNExpl.explainer.name} explaining')
# explanation_metric = NodesExplainerMetric(
# model=explainer_GNNExpl.gnn,
# graph=explainer_GNNExpl.gen_dataset.data,
# explainer=explainer_GNNExpl.explainer
# )
# res = explanation_metric.evaluate(node_indices)
explanation_metrics = explainer_GNNExpl.evaluate_metrics(node_indices, explainer_metrics_run_config)
print(explanation_metrics)


if __name__ == '__main__':
random.seed(11)
run_interpretation_test()
68 changes: 49 additions & 19 deletions src/explainers/GNNExplainer/torch_geom_our/out.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,17 @@ def run(self, mode, kwargs, finalize=True):
self.raw_explanation = self.explainer(self.x, self.edge_index, index=self.node_idx)
self.pbar.close()

@finalize_decorator
def evaluate_tensor_graph(self, x, edge_index, node_idx, **kwargs):
self._run_mode = "local"
self.node_idx = node_idx
self.x = x
self.edge_index = edge_index
self.pbar.reset(total=self.epochs, mode=self._run_mode)
self.explainer.algorithm.pbar = self.pbar
self.raw_explanation = self.explainer(self.x, self.edge_index, index=self.node_idx, **kwargs)
self.pbar.close()

def _finalize(self):
mode = self._run_mode
assert mode == "local"
Expand All @@ -111,46 +122,65 @@ def _finalize(self):

self.explanation = AttributionExplanation(
local=mode,
edges="continuous" if edge_mask is not None else False,
features="continuous" if node_mask is not None else False)
edges="continuous" if self.edge_mask_type=="object" else False,
nodes="continuous" if self.node_mask_type=="object" else False,
features="continuous" if self.node_mask_type=="common_attributes" else False)

important_edges = {}
important_nodes = {}
important_features = {}

# TODO What if edge_mask_type or node_mask_type is None, common_attributes, attributes?
if self.edge_mask_type is not None and self.node_mask_type is not None:

# Multi graphs check is not needed: the explanation format for
# graph classification and node classification is the same
eps = 0.001

# Edges
num_edges = edge_mask.size(0)
assert num_edges == self.edge_index.size(1)
edges = self.edge_index

for i in range(num_edges):
imp = float(edge_mask[i])
if not imp < eps:
edge = edges[0][i], edges[1][i]
important_edges[f"{edge[0]},{edge[1]}"] = format(imp, '.4f')
if self.edge_mask_type=="object":
num_edges = edge_mask.size(0)
assert num_edges == self.edge_index.size(1)
edges = self.edge_index

for i in range(num_edges):
imp = float(edge_mask[i])
if not imp < eps:
edge = edges[0][i], edges[1][i]
important_edges[f"{edge[0]},{edge[1]}"] = format(imp, '.4f')
else: # if "common_attributes" or "attributes"
raise NotImplementedError(f"Edge mask type '{self.edge_mask_type}' is not yet implemented.")

# Nodes
num_nodes = node_mask.size(0)
assert num_nodes == self.x.size(0)

for i in range(num_nodes):
imp = float(node_mask[i][0])
if not imp < eps:
important_nodes[i] = format(imp, '.4f')
if self.node_mask_type=="object":
num_nodes = node_mask.size(0)
assert num_nodes == self.x.size(0)

for i in range(num_nodes):
imp = float(node_mask[i][0])
if not imp < eps:
important_nodes[i] = format(imp, '.4f')
# Features
elif self.node_mask_type=="common_attributes":
num_features = node_mask.size(1)
assert num_features == self.x.size(1)

for i in range(num_features):
imp = float(node_mask[0][i])
if not imp < eps:
important_features[i] = format(imp, '.4f')
else: # if "attributes"
# TODO add functional if node_mask_type=="attributes"
raise NotImplementedError(f"Node mask type '{self.node_mask_type}' is not yet implemented.")

if self.gen_dataset.is_multi():
important_edges = {self.graph_idx: important_edges}
important_nodes = {self.graph_idx: important_nodes}
important_features = {self.graph_idx: important_features}

# TODO Write functions with output threshold
self.explanation.add_edges(important_edges)
self.explanation.add_nodes(important_nodes)
self.explanation.add_features(important_features)

# print(important_edges)
# print(important_nodes)
Expand Down
Loading

0 comments on commit 0deb805

Please sign in to comment.