Merge pull request #24 from abhhfcgjk/adv_training
Adv training
LukyanovKirillML authored Oct 17, 2024
2 parents f6d1b5b + f385047 commit b578576
Showing 8 changed files with 309 additions and 23 deletions.
150 changes: 131 additions & 19 deletions experiments/attack_defense_test.py
@@ -2,6 +2,7 @@

import warnings


from torch import device

from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
@@ -10,11 +11,12 @@
from src.aux.configs import ModelModificationConfig, ConfigPattern
from src.base.datasets_processing import DatasetManager
from src.models_builder.models_zoo import model_configs_zoo
+from attacks.QAttack import qattack


def test_attack_defense():
-    # my_device = device('cuda' if is_available() else 'cpu')
-    my_device = device('cpu')

+    my_device = device('cuda' if torch.cuda.is_available() else 'cpu')

full_name = None

@@ -102,51 +104,85 @@ def test_attack_defense():
gnn_model_manager.gnn.to(my_device)
data = data.to(my_device)

# poison_attack_config = ConfigPattern(
# _class_name="RandomPoisonAttack",
# _import_path=POISON_ATTACK_PARAMETERS_PATH,
# _config_class="PoisonAttackConfig",
# _config_kwargs={
# "n_edges_percent": 0.1,
# }
# )

poison_attack_config = ConfigPattern(
_class_name="RandomPoisonAttack",
_class_name="MetaAttackFull",
_import_path=POISON_ATTACK_PARAMETERS_PATH,
_config_class="PoisonAttackConfig",
_config_kwargs={
"n_edges_percent": 0.1,
"num_nodes": dataset.dataset.x.shape[0]
}
)

-    # poison_defense_config = ConfigPattern(
-    #     _class_name="BadRandomPoisonDefender",
-    #     _import_path=POISON_DEFENSE_PARAMETERS_PATH,
-    #     _config_class="PoisonDefenseConfig",
+    # poison_attack_config = ConfigPattern(
+    #     _class_name="RandomPoisonAttack",
+    #     _import_path=POISON_ATTACK_PARAMETERS_PATH,
+    #     _config_class="PoisonAttackConfig",
# _config_kwargs={
# "n_edges_percent": 0.1,
# }
# )

poison_defense_config = ConfigPattern(
_class_name="EmptyPoisonDefender",
_class_name="GNNGuard",
_import_path=POISON_DEFENSE_PARAMETERS_PATH,
_config_class="PoisonDefenseConfig",
_config_kwargs={
"n_edges_percent": 0.1,
}
)


evasion_attack_config = ConfigPattern(
_class_name="FGSM",
_class_name="QAttack",
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
_config_class="EvasionAttackConfig",
_config_kwargs={
"epsilon": 0.01 * 1,
"population_size": 50,
"individual_size": 30,
"generations": 50,
"prob_cross": 0.5,
"prob_mutate": 0.02
}
)
# evasion_attack_config = ConfigPattern(
# _class_name="FGSM",
# _import_path=EVASION_ATTACK_PARAMETERS_PATH,
# _config_class="EvasionAttackConfig",
# _config_kwargs={
# "epsilon": 0.01 * 1,
# }
# )

# evasion_defense_config = ConfigPattern(
# _class_name="GradientRegularizationDefender",
# _import_path=EVASION_DEFENSE_PARAMETERS_PATH,
# _config_class="EvasionDefenseConfig",
# _config_kwargs={
# "regularization_strength": 0.1 * 10
# }
# )
evasion_defense_config = ConfigPattern(
_class_name="GradientRegularizationDefender",
_class_name="AdvTraining",
_import_path=EVASION_DEFENSE_PARAMETERS_PATH,
_config_class="EvasionDefenseConfig",
_config_kwargs={
"regularization_strength": 0.1 * 10
"attack_name": None,
"attack_config": evasion_attack_config # evasion_attack_config
}
)

-    gnn_model_manager.set_poison_attacker(poison_attack_config=poison_attack_config)
+    # gnn_model_manager.set_poison_attacker(poison_attack_config=poison_attack_config)
# gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config)
-    # gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config)
+    gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config)
# gnn_model_manager.set_evasion_defender(evasion_defense_config=evasion_defense_config)

warnings.warn("Start training")
@@ -611,9 +647,85 @@ def test_jaccard():
print("TEST", metric_loc)


def test_adv_training():
from defense.evasion_defense import AdvTraining

my_device = device('cpu')
full_name = ("single-graph", "Planetoid", 'Cora')

dataset, data, results_dataset_path = DatasetManager.get_by_full_name(
full_name=full_name,
dataset_ver_ind=0
)
gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn')
manager_config = ConfigPattern(
_config_class="ModelManagerConfig",
_config_kwargs={
"mask_features": [],
"optimizer": {
# "_config_class": "Config",
"_class_name": "Adam",
# "_import_path": OPTIMIZERS_PARAMETERS_PATH,
# "_class_import_info": ["torch.optim"],
"_config_kwargs": {},
}
}
)
steps_epochs = 200
gnn_model_manager = FrameworkGNNModelManager(
gnn=gnn,
dataset_path=results_dataset_path,
manager_config=manager_config,
modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs)
)
save_model_flag = False
gnn_model_manager.gnn.to(my_device)
data = data.to(my_device)

evasion_defense_config = ConfigPattern(
_class_name="AdvTraining",
_import_path=EVASION_DEFENSE_PARAMETERS_PATH,
_config_class="EvasionDefenseConfig",
_config_kwargs={
# "num_nodes": dataset.dataset.x.shape[0]
}
)
from defense.evasion_defense import EvasionDefender
from src.aux.utils import all_subclasses
print([e.name for e in all_subclasses(EvasionDefender)])
gnn_model_manager.set_evasion_defender(evasion_defense_config=evasion_defense_config)

warnings.warn("Start training")
dataset.train_test_split(percent_train_class=0.1)

try:
raise FileNotFoundError()
# gnn_model_manager.load_model_executor()
except FileNotFoundError:
gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0
train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs,
save_model_flag=save_model_flag,
metrics=[Metric("F1", mask='train', average=None)])

if train_test_split_path is not None:
dataset.save_train_test_mask(train_test_split_path)
train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[
:]
dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask
data.percent_train_class, data.percent_test_class = train_test_sizes

warnings.warn("Training was successful")

metric_loc = gnn_model_manager.evaluate_model(
gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro'),
Metric("Accuracy", mask='test')])
print(metric_loc)

if __name__ == '__main__':
import random
random.seed(10)
-    test_attack_defense()
+    torch.manual_seed(5000)
+    #test_meta()
+    #test_qattack()
+    #test_attack_defense()
+    test_jaccard()
+    # test_adv_training()
+    # test_gnnguard()
+    # test_jaccard()
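
The wiring above shows the central pattern of this change: AdvTraining is configured with an entire attack config ("attack_config") rather than hyperparameters of its own, so any registered evasion attack can drive the inner perturbation. As a sketch that uses only names appearing in this diff (not a guaranteed public API), the commented-out FGSM config could be swapped in for QAttack:

# Hypothetical variant of the test wiring: AdvTraining driven by FGSM
# instead of QAttack. Sketch only; names follow this test file.
fgsm_config = ConfigPattern(
    _class_name="FGSM",
    _import_path=EVASION_ATTACK_PARAMETERS_PATH,
    _config_class="EvasionAttackConfig",
    _config_kwargs={"epsilon": 0.01},
)
adv_defense_config = ConfigPattern(
    _class_name="AdvTraining",
    _import_path=EVASION_DEFENSE_PARAMETERS_PATH,
    _config_class="EvasionDefenseConfig",
    _config_kwargs={"attack_name": None, "attack_config": fgsm_config},
)
gnn_model_manager.set_evasion_defender(evasion_defense_config=adv_defense_config)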
5 changes: 2 additions & 3 deletions metainfo/evasion_attack_parameters.json
@@ -18,6 +18,5 @@
"generations" : ["Generations", "int", 50, {"min": 0, "step": 1}, "Number of generations for genetic algorithm"],
"prob_cross": ["Probability for crossover", "float", 0.5, {"min": 0, "max": 1, "step": 0.01}, "Probability of crossover between two genes"],
"prob_mutate": ["Probability for mutation", "float", 0.02, {"min": 0, "max": 1, "step": 0.01}, "Probability of gene mutation"]
-}
-}
-
+  }
+}
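
The QAttack parameters above are standard genetic-algorithm knobs. A minimal, self-contained sketch of one generation over bit-string individuals, to illustrate how the population size, prob_cross, and prob_mutate interact (QAttack's real operators act on candidate edge perturbations and differ in detail):

import random

def next_generation(population, fitness, prob_cross=0.5, prob_mutate=0.02):
    # Selection: keep the fitter half of the population as parents.
    parents = sorted(population, key=fitness, reverse=True)[:len(population) // 2]
    children = []
    while len(parents) + len(children) < len(population):
        a, b = random.sample(parents, 2)
        if random.random() < prob_cross:      # single-point crossover
            cut = random.randrange(1, len(a))
            child = a[:cut] + b[cut:]
        else:
            child = list(a)
        # Mutation: flip each gene independently with prob_mutate.
        child = [1 - g if random.random() < prob_mutate else g for g in child]
        children.append(child)
    return parents + children

Seeded with population_size random individuals of length individual_size, the loop runs for generations iterations and keeps the fittest individual seen.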
3 changes: 3 additions & 0 deletions metainfo/evasion_defense_parameters.json
@@ -6,6 +6,9 @@
},
"QuantizationDefender": {
"qbit": ["qbit", "int", 8, {"min": 1, "step": 1}, "?"]
},
"AdvTraining": {
"attack_name": ["attack_name", "str", "FGSM", {}, "?"]
}
}
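
Each entry in these metainfo files follows the layout [display name, type, default, constraints, description], so index 2 holds the default value. A hedged sketch of reading the declared defaults (illustrative; the framework's own config loader may differ):

import json

with open("metainfo/evasion_defense_parameters.json") as f:
    params = json.load(f)

# Entry layout: [display name, type, default, constraints, description].
adv_defaults = {name: spec[2] for name, spec in params["AdvTraining"].items()}
print(adv_defaults)  # -> {'attack_name': 'FGSM'}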

5 changes: 5 additions & 0 deletions metainfo/poison_defense_parameters.json
@@ -4,6 +4,11 @@
"BadRandomPoisonDefender": {
"n_edges_percent": ["n_edges_percent", "float", 0.1, {"min": 0.0001, "step": 0.01}, "?"]
},
"GNNGuard": {
"lr": ["lr", "float", 0.01, {"min": 0.0001, "step": 0.005}, "?"],
"attention": ["attention", "bool", true, {}, "?"],
"drop": ["drop", "bool", true, {}, "?"]
},
"JaccardDefender": {
"threshold": ["Edge Threshold", "float", 0.35, {"min": 0, "max": 1, "step": 0.01}, "Jaccard index threshold for dropping edges"]
}
82 changes: 82 additions & 0 deletions src/defense/evasion_defense.py
@@ -1,7 +1,15 @@
import torch

from defense.defense_base import Defender
from src.aux.utils import import_by_name
from src.aux.configs import ModelModificationConfig, ConfigPattern
from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
EVASION_DEFENSE_PARAMETERS_PATH
from attacks.evasion_attacks import FGSMAttacker
from attacks.QAttack import qattack
from torch_geometric import data

import copy

class EvasionDefender(Defender):
def __init__(self, **kwargs):
@@ -52,3 +60,77 @@ def __init__(self, qbit=8):

def pre_batch(self, **kwargs):
pass


class DataWrap:
def __init__(self, batch) -> None:
self.data = batch
self.dataset = self

class AdvTraining(EvasionDefender):
name = "AdvTraining"

def __init__(self, attack_name=None, attack_config=None, attack_type=None, device='cpu'):
super().__init__()
assert device is not None, "Please specify 'device'!"
if not attack_config:
# build default config
assert attack_name is not None
if attack_type == "POISON":
self.attack_type = "POISON"
PARAM_PATH = POISON_ATTACK_PARAMETERS_PATH
else:
self.attack_type = "EVASION"
PARAM_PATH = EVASION_ATTACK_PARAMETERS_PATH
attack_config = ConfigPattern(
_class_name=attack_name,
_import_path=PARAM_PATH,
_config_class="EvasionAttackConfig",
_config_kwargs={}
)
self.attack_config = attack_config
if self.attack_config._class_name == "FGSM":
self.attack_type = "EVASION"
# get attack params
self.epsilon = self.attack_config._config_kwargs.epsilon
# set attacker
self.attacker = FGSMAttacker(self.epsilon)
elif self.attack_config._class_name == "QAttack":
self.attack_type = "EVASION"
# get attack params
self.population_size = self.attack_config._config_kwargs["population_size"]
self.individual_size = self.attack_config._config_kwargs["individual_size"]
self.generations = self.attack_config._config_kwargs["generations"]
self.prob_cross = self.attack_config._config_kwargs["prob_cross"]
self.prob_mutate = self.attack_config._config_kwargs["prob_mutate"]
# set attacker
self.attacker = qattack.QAttacker(self.population_size, self.individual_size,
self.generations, self.prob_cross,
self.prob_mutate)
elif self.attack_config._class_name == "MetaAttackFull":
# from attacks.poison_attacks_collection.metattack import meta_gradient_attack
# self.attack_type = "POISON"
# self.num_nodes = self.attack_config._config_kwargs["num_nodes"]
# self.attacker = meta_gradient_attack.MetaAttackFull(num_nodes=self.num_nodes)
pass
else:
raise KeyError(f"There is no {self.attack_config._class_name} class")

def pre_batch(self, model_manager, batch):
super().pre_batch(model_manager=model_manager, batch=batch)
self.perturbed_gen_dataset = data.Data()
self.perturbed_gen_dataset.data = copy.deepcopy(batch)
self.perturbed_gen_dataset.dataset = self.perturbed_gen_dataset.data
self.perturbed_gen_dataset.dataset.data = self.perturbed_gen_dataset.data
if self.attack_type == "EVASION":
self.perturbed_gen_dataset = self.attacker.attack(model_manager=model_manager,
gen_dataset=self.perturbed_gen_dataset,
mask_tensor=self.perturbed_gen_dataset.data.train_mask)


def post_batch(self, model_manager, batch, loss) -> dict:
super().post_batch(model_manager=model_manager, batch=batch, loss=loss)
# Output on perturbed data
outputs = model_manager.gnn(self.perturbed_gen_dataset.data.x, self.perturbed_gen_dataset.data.edge_index)
loss_loc = model_manager.loss_function(outputs, batch.y)
return {"loss": loss + loss_loc}
2 changes: 1 addition & 1 deletion src/models_builder/gnn_models.py
@@ -447,7 +447,7 @@ def set_poison_defender(self, poison_defense_config=None, poison_defense_name: s
self.poison_defense_name = poison_defense_name
poison_defense_kwargs = getattr(self.poison_defense_config, CONFIG_OBJ).to_dict()

-        name_klass = {e.name: e for e in PoisonDefender.__subclasses__()}
+        name_klass = {e.name: e for e in all_subclasses(PoisonDefender)}
klass = name_klass[self.poison_defense_name]
self.poison_defender = klass(
# device=self.device,
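
This one-line fix matters because __subclasses__() returns only direct subclasses, so a defender that inherits from PoisonDefender through an intermediate class would never be found by name. A recursive lookup equivalent to what all_subclasses in src.aux.utils presumably does:

def all_subclasses(cls):
    # Direct subclasses plus, recursively, all of their subclasses.
    direct = set(cls.__subclasses__())
    return direct.union(*(all_subclasses(sub) for sub in direct))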
Diffs for the remaining 2 changed files are not shown.
