Skip to content

Commit

Permalink
+
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeratt committed Oct 21, 2024
1 parent 6ab6b5d commit bfecf18
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 17 deletions.
2 changes: 1 addition & 1 deletion experiments/EAttack_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test():
'run_config': explainer_run_config,
'mode': 'local',
'attack_inds': attack_inds,
'random_rewire': False
'random_rewire': True
}
)

Expand Down
55 changes: 46 additions & 9 deletions experiments/Explain_Defense_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import torch
import numpy as np
import warnings
import copy
import json

from dig.sslgraph.dataset import get_node_dataset
from pyscf.fci.cistring import gen_des_str_index
from torch import device
from tqdm import tqdm

from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
EVASION_DEFENSE_PARAMETERS_PATH
Expand All @@ -17,6 +20,8 @@
from src.base.datasets_processing import DatasetManager
from src.models_builder.models_zoo import model_configs_zoo

from explainers.explainer import ProgressBar

from aux.utils import EXPLAINERS_INIT_PARAMETERS_PATH, EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, \
EXPLAINERS_GLOBAL_RUN_PARAMETERS_PATH
from explainers.explainers_manager import FrameworkExplainersManager
Expand All @@ -27,7 +32,8 @@

def test():
from attacks.EAttack.eattack_attack import EAttack
#from defense.JaccardDefense import
from defense.JaccardDefense import jaccard_def
from defense.evasion_defense import AdvTraining

my_device = device('cpu')

Expand Down Expand Up @@ -62,18 +68,24 @@ def test():

gnn_model_manager.gnn.to(my_device)

poison_defense_config = ConfigPattern(
_class_name="JaccardDefender",
_import_path=POISON_DEFENSE_PARAMETERS_PATH,
_config_class="PoisonDefenseConfig",
# poison_defense_config = ConfigPattern(
# _class_name="JaccardDefense",
# _import_path=POISON_DEFENSE_PARAMETERS_PATH,
# _config_class="PoisonDefenseConfig",
# _config_kwargs={
# }
# )

evasion_defense_config = ConfigPattern(
_class_name="AdvTraining",
_import_path=EVASION_DEFENSE_PARAMETERS_PATH,
_config_class="EvasionDefenseConfig",
_config_kwargs={
}
)

gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config)



#gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config)
gnn_model_manager.set_evasion_defender(evasion_defense_config=evasion_defense_config)

num_steps = 100
gnn_model_manager.train_model(gen_dataset=dataset,
Expand Down Expand Up @@ -135,8 +147,31 @@ def test():
# }
# )


# node_inds = np.arange(dataset.dataset.data.x.shape[0])
# explain_node_size = int((0.01 * len(node_inds)))
# explaind_inds = np.random.choice(node_inds, explain_node_size)

explaind_inds = [0, 1, 4, 5, 6, 8, 11, 13, 14, 15, 16, 17, 18, 22, 24, 25, 26, 29, 1862, 2130]

init_kwargs = getattr(explainer_init_config, CONFIG_OBJ).to_dict()
explainer = GNNExplainer(gen_dataset=dataset, model=gnn_model_manager.gnn, device=my_device, **init_kwargs)

mode = getattr(explainer_run_config, CONFIG_OBJ).mode
params = getattr(getattr(explainer_run_config, CONFIG_OBJ).kwargs, CONFIG_OBJ).to_dict()

explanations = []
for n in tqdm(explaind_inds):
params['element_idx'] = n
explainer.pbar = ProgressBar(None, "er", desc=f'{explainer.name} explaining')
explainer.run(mode, params, finalize=True)
explanations.append(copy.deepcopy(explainer.explanation))
# with open(f"/home/sazonov/PycharmProjects/GNN-AID/experiments/results/expl_{n}.json", "w") as fout:
# json.dump(explainer.explanation.)
out = {int(explaind_inds[i]): explanations[i].dictionary['data']['edges'] for i in range(len(explaind_inds))}
with open(f"/home/sazonov/PycharmProjects/GNN-AID/experiments/results/expl_No_Def.json", "w") as fout:
json.dump(out, fout)

# explainer = SubgraphXExplainer(gen_dataset=dataset, model=gnn_model_manager.gnn, device=my_device, **init_kwargs)
# explainer = ZorroExplainer(gen_dataset=dataset, model=gnn_model_manager.gnn, device=my_device, **init_kwargs)

Expand Down Expand Up @@ -186,4 +221,6 @@ def test():


if __name__ == "__main__":
# torch.manual_seed(1000)
# np.random.seed(1000)
test()
35 changes: 35 additions & 0 deletions experiments/Explain_Defese_Results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import json
import numpy as np

explaind_inds = [0, 1, 4, 5, 6, 8, 11, 13, 14, 15, 16, 17, 18, 22, 24, 25, 26, 29, 1862, 2130]

# path = f"expl_No_Def"
# for n in explaind_inds:
# for f in range(5):
# with open(path + f"_{f}.json", "r") as fin:
# res = json.load(fin)
# node_res = res[n]

def iou(s1, s2):
if not len(s1.union(s2)):
return 0
return len(s1.intersection(s2)) / len(s1.union(s2))

path_def = "/home/sazonov/PycharmProjects/GNN-AID/experiments/results/expl_Def_1.json"
path_no_def = "/home/sazonov/PycharmProjects/GNN-AID/experiments/results/v1/expl_No_Def_1.json"

with open(path_def, 'r') as fin:
def_data = json.load(fin)

with open(path_no_def, 'r') as fin:
no_def_data = json.load(fin)

no_def_data_set = {k: set(v.keys()) for k, v in no_def_data.items()}
def_data_set = {k: set(v.keys()) for k, v in def_data.items()}
out = []
for k in no_def_data.keys():
out.append(iou(no_def_data_set[k], def_data_set[k]))
print(iou(no_def_data_set[k], def_data_set[k]))

out = np.array(out)
print(np.mean(out), np.std(out))
2 changes: 1 addition & 1 deletion metainfo/explainers_init_parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},

"GNNExplainer(torch-geom)": {
"epochs": ["Epochs","int",200,{"min": 1},"The number of epochs to train"],
"epochs": ["Epochs","int",100,{"min": 1},"The number of epochs to train"],
"lr": ["Learn rate","float",0.01,{"min": 0, "step": 0.0001},"The learning rate to apply"],
"node_mask_type": ["Node mask","string","object",["None","object","common_attributes","attributes"],"The type of mask to apply on nodes"],
"edge_mask_type": ["Edge mask","string","object",["None","object","common_attributes","attributes"],"The type of mask to apply on edges"],
Expand Down
2 changes: 1 addition & 1 deletion metainfo/poison_defense_parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"drop": ["drop", "bool", true, {}, "?"]
},
"JaccardDefender": {
"threshold": ["Edge Threshold", "float", 0.35, {"min": 0, "max": 1, "step": 0.01}, "Jaccard index threshold for dropping edges"]
"threshold": ["Edge Threshold", "float", 0.1, {"min": 0, "max": 1, "step": 0.01}, "Jaccard index threshold for dropping edges"]
}
}

19 changes: 15 additions & 4 deletions src/attacks/EAttack/eattack_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def attack(self, model_manager, gen_dataset, mask_tensor):
self.attack_inds = np.random.choice(node_inds, self.attacked_node_size)

# get explanations
if self.random_rewire:
if False:
# get random explanation
for i in tqdm(range(len(self.attack_inds))):
edge_index = gen_dataset.dataset.data.edge_index.tolist()
Expand Down Expand Up @@ -84,12 +84,14 @@ def attack(self, model_manager, gen_dataset, mask_tensor):
edge_index = gen_dataset.dataset.data.edge_index.tolist()
edge_index_set = set([(u, v) for u, v in zip(edge_index[0], edge_index[1])])
neighbours = {n: set() for n in self.attack_inds}
# neighbours_list = list(neighbours)
# hop_2 = {{n: set() for n in self.attack_inds}}
neighbours_list = list(neighbours)
for u, v in zip(edge_index[0], edge_index[1]):
if u in neighbours.keys():
neighbours[u].add(v)
elif v in neighbours.keys():
neighbours[v].add(u)

for i, n in enumerate(self.attack_inds):
max_rewire = self.max_rewire
important_edges = sorted(list(explanations[i].dictionary['data']['edges'].items()), key=lambda x: x[1], reverse=True)
Expand All @@ -108,10 +110,19 @@ def attack(self, model_manager, gen_dataset, mask_tensor):
else:
continue
if max_rewire:
neighbours_list = list(neighbours[n])
if self.random_rewire:
hop_2 = []
for u, v in zip(edge_index[0], edge_index[1]):
if u in neighbours[n] and v != n:
hop_2.append((u, v))
elif v in neighbours[n] and u != n:
hop_2.append((u, v))
rewire_node, neigh_node = random.sample(hop_2, 1)[0]
#neighbours_list = list(neighbours[n])
sample = random.sample(neighbours_list, 2)
new_neigh = sample[0] if sample[0] != neigh_node else sample[1]
edge_index_set.remove((u, v))
#edge_index_set.remove((u, v))
edge_index_set.remove((rewire_node, neigh_node))
edge_index_set.add((rewire_node, new_neigh))
max_rewire -= 1
edge_index_new = [[],[]]
Expand Down
2 changes: 1 addition & 1 deletion src/defense/evasion_defense.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self, attack_name=None, attack_config=None, attack_type=None, devic
if self.attack_config._class_name == "FGSM":
self.attack_type = "EVASION"
# get attack params
self.epsilon = self.attack_config._config_kwargs.epsilon
self.epsilon = self.attack_config._config_kwargs['epsilon']
# set attacker
self.attacker = FGSMAttacker(self.epsilon)
elif self.attack_config._class_name == "QAttack":
Expand Down

0 comments on commit bfecf18

Please sign in to comment.