diff --git a/experiments/attack_defense_metric_test.py b/experiments/attack_defense_metric_test.py
index 169b625..30f584a 100644
--- a/experiments/attack_defense_metric_test.py
+++ b/experiments/attack_defense_metric_test.py
@@ -143,7 +143,13 @@ def attack_defense_metrics():
         gen_dataset=dataset,
         gnn_manager=gnn_model_manager,
     )
-    adm.evasion_attack_pipeline(
+    # adm.evasion_attack_pipeline(
+    #     steps=steps_epochs,
+    #     save_model_flag=save_model_flag,
+    #     metrics_attack=[AttackMetric("ASR")],
+    #     mask='test'
+    # )
+    adm.poison_attack_pipeline(
         steps=steps_epochs,
         save_model_flag=save_model_flag,
         metrics_attack=[AttackMetric("ASR")],
diff --git a/src/models_builder/attack_defense_manager.py b/src/models_builder/attack_defense_manager.py
index 5374f9f..934d1c7 100644
--- a/src/models_builder/attack_defense_manager.py
+++ b/src/models_builder/attack_defense_manager.py
@@ -124,6 +124,58 @@ def evasion_attack_pipeline(
 
         return metrics_values
 
+    def poison_attack_pipeline(
+            self,
+            metrics_attack,
+            steps: int,
+            save_model_flag: bool = True,
+            mask: Union[str, List[bool], torch.Tensor] = 'test',
+    ):
+        metrics_values = {}
+        if self.available_attacks["poison"]:
+            self.set_clear_model()
+            self.gnn_manager.modification.epochs = 0
+            self.gnn_manager.gnn.reset_parameters()
+            from models_builder.gnn_models import Metric
+            self.gnn_manager.train_model(
+                gen_dataset=self.gen_dataset,
+                steps=steps,
+                save_model_flag=save_model_flag,
+                metrics=[Metric("F1", mask='train', average=None)]
+            )
+            y_predict_clean = self.gnn_manager.run_model(
+                gen_dataset=self.gen_dataset,
+                mask=mask,
+                out='logits',
+            )
+
+            self.gnn_manager.poison_attack_flag = True
+            self.gnn_manager.modification.epochs = 0
+            self.gnn_manager.gnn.reset_parameters()
+            self.gnn_manager.train_model(
+                gen_dataset=self.gen_dataset,
+                steps=steps,
+                save_model_flag=save_model_flag,
+                metrics=[Metric("F1", mask='train', average=None)]
+            )
+            y_predict_attack = self.gnn_manager.run_model(
+                gen_dataset=self.gen_dataset,
+                mask=mask,
+                out='logits',
+            )
+            metrics_values = self.evaluate_attack_defense(
+                y_predict_after_attack_only=y_predict_attack,
+                y_predict_clean=y_predict_clean,
+                metrics_attack=metrics_attack,
+            )
+            self.return_attack_defense_flags()
+
+        else:
+            warnings.warn(f"Poison attack is not available. Please set a poison attacker for "
+                          f"gnn_model_manager using set_poison_attacker")
+
+        return metrics_values
+
     def evaluate_attack_defense(
             self,
             y_predict_clean,
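For reference, a minimal sketch of how the new pipeline is driven, modeled on the call sites in experiments/attack_defense_metric_test.py. It assumes the manager class is named AttackDefenseManager (only the adm variable and constructor arguments are visible in the diff), that dataset, gnn_model_manager, steps_epochs, save_model_flag, and the AttackMetric import are set up earlier in the script, and that a poison attacker has already been registered on gnn_model_manager (presumably via a set_poison_attacker counterpart to the existing set_evasion_attacker), so that available_attacks["poison"] is True:

    adm = AttackDefenseManager(
        gen_dataset=dataset,
        gnn_manager=gnn_model_manager,
    )
    # Trains a clean model, retrains from scratch with the poison attack
    # enabled, and compares the two sets of 'test'-mask predictions under
    # each requested attack metric (here ASR). Returns an empty dict if the
    # poison attack is unavailable.
    metrics_values = adm.poison_attack_pipeline(
        steps=steps_epochs,
        save_model_flag=save_model_flag,
        metrics_attack=[AttackMetric("ASR")],
        mask='test',
    )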