Commit 84bad82

Jeratt committed Oct 28, 2024
1 parent d34224a
Showing 3 changed files with 69 additions and 46 deletions.
101 changes: 58 additions & 43 deletions experiments/EAttack_experiment.py
@@ -1,6 +1,7 @@
"""
Experiment on attacking a GNN via GNNExplainer's explanations
"""
import copy

import torch
import numpy as np
@@ -9,6 +10,7 @@
from dig.sslgraph.dataset import get_node_dataset
from torch import device
from tqdm import tqdm

from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
EVASION_DEFENSE_PARAMETERS_PATH
@@ -80,8 +82,31 @@ def test():
print(f"BEFORE ATTACK\nAccuracy on train: {acc_train}. Accuracy on test: {acc_test}")
# print(f"Accuracy on test: {acc_test}")

explainer_init_config = ConfigPattern(
_class_name="GNNExplainer(torch-geom)",
_import_path=EXPLAINERS_INIT_PARAMETERS_PATH,
_config_class="ExplainerInitConfig",
_config_kwargs={
"node_mask_type": "attributes"
}
)
explainer_run_config = ConfigPattern(
_config_class="ExplainerRunConfig",
_config_kwargs={
"mode": "local",
"kwargs": {
"_class_name": "GNNExplainer(torch-geom)",
"_import_path": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH,
"_config_class": "Config",
"_config_kwargs": {

},
}
}
)

# explainer_init_config = ConfigPattern(
# _class_name="GNNExplainer(torch-geom)",
# _class_name="SubgraphX",
# _import_path=EXPLAINERS_INIT_PARAMETERS_PATH,
# _config_class="ExplainerInitConfig",
# _config_kwargs={
@@ -92,7 +117,7 @@ def test():
# _config_kwargs={
# "mode": "local",
# "kwargs": {
# "_class_name": "GNNExplainer(torch-geom)",
# "_class_name": "SubgraphX",
# "_import_path": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH,
# "_config_class": "Config",
# "_config_kwargs": {
@@ -103,7 +128,7 @@ def test():
# )

# explainer_init_config = ConfigPattern(
# _class_name="SubgraphX",
# _class_name="PGMExplainer",
# _import_path=EXPLAINERS_INIT_PARAMETERS_PATH,
# _config_class="ExplainerInitConfig",
# _config_kwargs={
@@ -114,7 +139,7 @@ def test():
# _config_kwargs={
# "mode": "local",
# "kwargs": {
# "_class_name": "SubgraphX",
# "_class_name": "PGMExplainer",
# "_import_path": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH,
# "_config_class": "Config",
# "_config_kwargs": {
@@ -124,33 +149,11 @@ def test():
# }
# )

explainer_init_config = ConfigPattern(
_class_name="PGMExplainer",
_import_path=EXPLAINERS_INIT_PARAMETERS_PATH,
_config_class="ExplainerInitConfig",
_config_kwargs={
}
)
explainer_run_config = ConfigPattern(
_config_class="ExplainerRunConfig",
_config_kwargs={
"mode": "local",
"kwargs": {
"_class_name": "PGMExplainer",
"_import_path": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH,
"_config_class": "Config",
"_config_kwargs": {

},
}
}
)

init_kwargs = getattr(explainer_init_config, CONFIG_OBJ).to_dict()
# explainer = GNNExplainer(gen_dataset=dataset, model=gnn_model_manager.gnn, device=my_device, **init_kwargs)
explainer = GNNExplainer(gen_dataset=dataset, model=gnn_model_manager.gnn, device=my_device, **init_kwargs)
# explainer = SubgraphXExplainer(gen_dataset=dataset, model=gnn_model_manager.gnn, device=my_device, **init_kwargs)
# explainer = ZorroExplainer(gen_dataset=dataset, model=gnn_model_manager.gnn, device=my_device, **init_kwargs)
explainer = PGMExplainer(gen_dataset=dataset, model=gnn_model_manager.gnn, device=my_device, **init_kwargs)
# explainer = PGMExplainer(gen_dataset=dataset, model=gnn_model_manager.gnn, device=my_device, **init_kwargs)

# node_inds = np.arange(dataset.dataset.data.x.shape[0])
# dataset = gen_dataset.dataset.data[mask_tensor]
@@ -169,7 +172,7 @@ def test():
if u not in adj_list[v]:
adj_list[v].append(u)
node_inds = [n for n in adj_list.keys() if len(adj_list[n]) > 1]
attacked_node_size = int((0.002 * len(node_inds)))
attacked_node_size = int((0.04 * len(node_inds)))
attack_inds = np.random.choice(node_inds, attacked_node_size)

evasion_attack_config = ConfigPattern(
@@ -185,25 +188,37 @@ def test():
}
)

gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config)
dataset_copy = copy.deepcopy(dataset)

mask = Metric.create_mask_by_target_list(y_true=dataset.labels, target_list=attack_inds)
succ_attack = 0

# explainer_GNNExpl.conduct_experiment(explainer_run_config)
for i in tqdm(attack_inds):

# Evaluate model
mask = Metric.create_mask_by_target_list(y_true=dataset.labels, target_list=[i])

acc_attack = gnn_model_manager.evaluate_model(gen_dataset=dataset,
metrics=[Metric("Accuracy", mask=mask)])[mask]['Accuracy']
print(f"AFTER ATTACK\nAccuracy: {acc_attack}")
evasion_attack_config = ConfigPattern(
_class_name="EAttack",
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
_config_class="EvasionAttackConfig",
_config_kwargs={
'explainer': explainer,
'run_config': explainer_run_config,
'mode': 'local',
'attack_inds': [i],
'random_rewire': True
}
)

# acc_train = gnn_model_manager.evaluate_model(gen_dataset=dataset,
# metrics=[Metric("Accuracy", mask='train')])['train']['Accuracy']
#
#
# acc_test = gnn_model_manager.evaluate_model(gen_dataset=dataset,
# metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy']
# print(f"AFTER ATTACK\nAccuracy on train: {acc_train}. Accuracy on test: {acc_test}")
gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config)

acc_attack = gnn_model_manager.evaluate_model(gen_dataset=dataset,
metrics=[Metric("Accuracy", mask=mask)])[mask]['Accuracy']

succ_attack += acc_attack
# print(f"AFTER ATTACK\nAccuracy: {acc_attack}")

dataset = copy.deepcopy(dataset_copy)
print(f"ACCURACY ON ATTACKED: {succ_attack / len(attack_inds)}")



2 changes: 1 addition & 1 deletion metainfo/evasion_attack_parameters.json
@@ -19,7 +19,7 @@
"random_rewire": ["Random rewire", "bool", false, {}, "Rewire based on random, not on explanation (for comparison)"],
"attack_features": ["Attack features", "bool", false, {}, "Whether features to be attacked or not"],
"attack_edges": ["Attack edges", "bool", true, {}, "Whether edges to be attacked or not"],
"edge_mode": ["Edge attack type", "string", "add", ["remove", "add", "rewire"], "What to do with edges: remove or add or rewire (add one and remove another)"],
"edge_mode": ["Edge attack type", "string", "rewire", ["remove", "add", "rewire"], "What to do with edges: remove or add or rewire (add one and remove another)"],
"features_mode": ["Feature attack type", "string", "reverse", ["reverse","drop"], "What to do with features: drop or reverse (binary)"]
},
"QAttack": {
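Each entry in this file follows the layout [display name, type, default, allowed values, description], so the change above flips the shipped default for edge_mode from "add" to "rewire". A minimal sketch of reading that default with the standard library, assuming the hunk shown belongs to the EAttack section (which the surrounding keys suggest but the excerpt does not prove):

    import json

    # Parameter entries are stored as [label, type, default, options, description].
    with open("metainfo/evasion_attack_parameters.json") as f:
        params = json.load(f)["EAttack"]   # assumption: the edited block is the EAttack section
    print(params["edge_mode"][2])          # "rewire" after this commit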
12 changes: 10 additions & 2 deletions src/attacks/EAttack/experimental_code.py
@@ -105,15 +105,23 @@ def attack(self, model_manager, gen_dataset, mask_tensor):
edge_index_set.add((u, new_node[0]))
cnt += 1
elif self.edge_mode == 'rewire':
max_rewire = self.max_rewire
for (u, v) in zip(edge_index[0], edge_index[1]):
if u != n and v != n and f"{u},{v}" in explanations[i]['edges'].keys():
edge_index_set.discard((u, v))
edge_index_set.discard((v, u))
if (u, n) not in edge_index_set:
cnt += 1
edge_index_set.add((u, n))
# edge_index_set.add((u, n))
# edge_index_set.add((n, u))
max_rewire -= 1
elif (v, n) not in edge_index_set:
cnt += 1
edge_index_set.add((v, n))
# edge_index_set.add((v, n))
# edge_index_set.add((n, v))
max_rewire -= 1
if max_rewire <= 0:
break

# Update dataset edges
edge_index_new = [[], []]
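As described in the JSON above, rewire is meant to add one edge and remove another; in this commit the add() calls are commented out and a new max_rewire budget caps the loop, so the branch currently only deletes explained edges. A self-contained sketch of the full rewire primitive as described, where the re-attachment lines correspond to the calls the diff comments out:

    def rewire_explained_edges(edge_set, explained_edges, n, max_rewire):
        # Sketch of the described behavior, not the committed code: the committed
        # version keeps the discard() calls but comments out the add() calls.
        # edge_set holds (u, v) tuples in both orientations (undirected graph);
        # explained_edges are the pairs flagged by the explainer for target node n.
        budget = max_rewire
        for u, v in explained_edges:
            if u == n or v == n:
                continue                       # skip edges incident to the target itself
            edge_set.discard((u, v))           # remove the explained edge
            edge_set.discard((v, u))
            if (u, n) not in edge_set:         # re-attach one endpoint to n
                edge_set.add((u, n))
                edge_set.add((n, u))
                budget -= 1
            elif (v, n) not in edge_set:
                edge_set.add((v, n))
                edge_set.add((n, v))
                budget -= 1
            if budget <= 0:
                break                          # respect the rewiring budget
        return edge_set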
