diff --git a/experiments/attack_defense_metric_test.py b/experiments/attack_defense_metric_test.py index e2bd79e..408b77e 100644 --- a/experiments/attack_defense_metric_test.py +++ b/experiments/attack_defense_metric_test.py @@ -5,7 +5,7 @@ from torch import device from models_builder.attack_defense_manager import FrameworkAttackDefenseManager -from models_builder.attack_defense_metric import AttackMetric +from models_builder.attack_defense_metric import AttackMetric, DefenseMetric from models_builder.models_utils import apply_decorator_to_graph_layers from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \ EVASION_DEFENSE_PARAMETERS_PATH @@ -150,10 +150,17 @@ def attack_defense_metrics(): # metrics_attack=[AttackMetric("ASR")], # mask='test' # ) - adm.poison_attack_pipeline( + # adm.poison_attack_pipeline( + # steps=steps_epochs, + # save_model_flag=save_model_flag, + # metrics_attack=[AttackMetric("ASR")], + # mask='test' + # ) + adm.poison_defense_pipeline( steps=steps_epochs, save_model_flag=save_model_flag, - metrics_attack=[AttackMetric("ASR")], + metrics_attack=[AttackMetric("ASR"), AttackMetric("AuccAttackDiff"),], + metrics_defense=[DefenseMetric("AuccDefenseCleanDiff"), DefenseMetric("AuccDefenseAttackDiff"), ], mask='test' ) diff --git a/src/models_builder/attack_defense_manager.py b/src/models_builder/attack_defense_manager.py index 4e0f98d..4be401c 100644 --- a/src/models_builder/attack_defense_manager.py +++ b/src/models_builder/attack_defense_manager.py @@ -200,6 +200,101 @@ def poison_attack_pipeline( return metrics_values + def poison_defense_pipeline( + self, + metrics_attack: List, + metrics_defense: List, + steps: int, + save_model_flag: bool = True, + mask: Union[str, List[bool], torch.Tensor] = 'test', + ) -> dict: + metrics_values = {} + if self.available_attacks["poison"] and self.available_defense["poison"]: + from models_builder.gnn_models import Metric + local_gen_dataset_copy = copy.deepcopy(self.gen_dataset) + self.set_clear_model() + self.gnn_manager.modification.epochs = 0 + self.gnn_manager.gnn.reset_parameters() + self.gnn_manager.train_model( + gen_dataset=local_gen_dataset_copy, + steps=steps, + save_model_flag=False, + metrics=[Metric("F1", mask='train', average=None)] + ) + y_predict_clean = self.gnn_manager.run_model( + gen_dataset=local_gen_dataset_copy, + mask=mask, + out='logits', + ) + + self.gnn_manager.poison_defense_flag = True + self.gnn_manager.modification.epochs = 0 + self.gnn_manager.gnn.reset_parameters() + self.gnn_manager.train_model( + gen_dataset=local_gen_dataset_copy, + steps=steps, + save_model_flag=False, + metrics=[Metric("F1", mask='train', average=None)] + ) + y_predict_after_defense_only = self.gnn_manager.run_model( + gen_dataset=local_gen_dataset_copy, + mask=mask, + out='logits', + ) + + local_gen_dataset_copy = copy.deepcopy(self.gen_dataset) + self.gnn_manager.poison_defense_flag = False + self.gnn_manager.poison_attack_flag = True + self.gnn_manager.modification.epochs = 0 + self.gnn_manager.gnn.reset_parameters() + self.gnn_manager.train_model( + gen_dataset=local_gen_dataset_copy, + steps=steps, + save_model_flag=False, + metrics=[Metric("F1", mask='train', average=None)] + ) + y_predict_after_attack_only = self.gnn_manager.run_model( + gen_dataset=local_gen_dataset_copy, + mask=mask, + out='logits', + ) + + self.gnn_manager.poison_defense_flag = True + self.gnn_manager.modification.epochs = 0 + self.gnn_manager.gnn.reset_parameters() + self.gnn_manager.train_model( + gen_dataset=local_gen_dataset_copy, + steps=steps, + save_model_flag=save_model_flag, + metrics=[Metric("F1", mask='train', average=None)] + ) + y_predict_after_attack_and_defense = self.gnn_manager.run_model( + gen_dataset=local_gen_dataset_copy, + mask=mask, + out='logits', + ) + + metrics_attack_values, metrics_defense_values = self.evaluate_attack_defense( + y_predict_after_attack_only=y_predict_after_attack_only, + y_predict_clean=y_predict_clean, + y_predict_after_defense_only=y_predict_after_defense_only, + y_predict_after_attack_and_defense=y_predict_after_attack_and_defense, + metrics_attack=metrics_attack, + metrics_defense=metrics_defense, + mask=mask, + ) + if save_model_flag: + self.save_metrics( + metrics_attack_values=metrics_attack_values, + metrics_defense_values=metrics_defense_values, + ) + self.return_attack_defense_flags() + else: + warnings.warn(f"Evasion attack is not available. Please set evasion attack for " + f"gnn_model_manager use def set_evasion_attacker") + + return metrics_values + def save_metrics( self, metrics_attack_values: Union[dict, None] = None, @@ -221,9 +316,8 @@ def save_metrics( new_dict=metrics_defense_values ) - @staticmethod def evaluate_attack_defense( - # self, + self, y_predict_clean: Union[List, torch.Tensor, np.array], mask: Union[str, torch.Tensor], y_predict_after_attack_only: Union[List, torch.Tensor, np.array, None] = None, @@ -232,13 +326,26 @@ def evaluate_attack_defense( metrics_attack: Union[List, None] = None, metrics_defense: Union[List, None] = None, ): + + try: + mask_tensor = { + 'train': self.gen_dataset.train_mask.tolist(), + 'val': self.gen_dataset.val_mask.tolist(), + 'test': self.gen_dataset.test_mask.tolist(), + 'all': [True] * len(self.gen_dataset.labels), + }[mask] + except KeyError: + assert isinstance(mask, torch.Tensor) + mask_tensor = mask + y_true = copy.deepcopy(self.gen_dataset.labels[mask_tensor]) metrics_attack_values = {mask: {}} metrics_defense_values = {mask: {}} if metrics_attack is not None and y_predict_after_attack_only is not None: for metric in metrics_attack: metrics_attack_values[mask][metric.name] = metric.compute( y_predict_clean=y_predict_clean, - y_predict_after_attack_only=y_predict_after_attack_only + y_predict_after_attack_only=y_predict_after_attack_only, + y_true=y_true, ) if ( metrics_defense is not None @@ -248,10 +355,12 @@ def evaluate_attack_defense( for metric in metrics_defense: metrics_defense_values[mask][metric.name] = metric.compute( y_predict_clean=y_predict_clean, + y_predict_after_attack_only=y_predict_after_attack_only, y_predict_after_defense_only=y_predict_after_defense_only, - y_predict_after_attack_and_defense=y_predict_after_attack_and_defense + y_predict_after_attack_and_defense=y_predict_after_attack_and_defense, + y_true=y_true, ) - + print("!!!! ", metrics_attack_values, metrics_defense_values) return metrics_attack_values, metrics_defense_values @staticmethod diff --git a/src/models_builder/attack_defense_metric.py b/src/models_builder/attack_defense_metric.py index 27da020..43cfbee 100644 --- a/src/models_builder/attack_defense_metric.py +++ b/src/models_builder/attack_defense_metric.py @@ -20,9 +20,79 @@ def asr( return 1 - sklearn.metrics.accuracy_score(y_true=y_predict_clean, y_pred=y_predict_after_attack_only) +# TODO Kirill, change for any classic metric +def aucc_change_attack( + y_predict_clean, + y_predict_after_attack_only, + y_true, + **kwargs +): + if isinstance(y_predict_clean, torch.Tensor): + if y_predict_clean.dim() > 1: + y_predict_clean = y_predict_clean.argmax(dim=1) + y_predict_clean.cpu() + if isinstance(y_predict_after_attack_only, torch.Tensor): + if y_predict_after_attack_only.dim() > 1: + y_predict_after_attack_only = y_predict_after_attack_only.argmax(dim=1) + y_predict_after_attack_only.cpu() + if isinstance(y_true, torch.Tensor): + if y_true.dim() > 1: + y_true = y_true.argmax(dim=1) + y_true.cpu() + return (sklearn.metrics.accuracy_score(y_true=y_true, y_pred=y_predict_clean) - + sklearn.metrics.accuracy_score(y_true=y_true, y_pred=y_predict_after_attack_only)) + + +# TODO Kirill, change for any classic metric +def aucc_change_defense_only( + y_predict_clean, + y_predict_after_defense_only, + y_true, + **kwargs +): + if isinstance(y_predict_clean, torch.Tensor): + if y_predict_clean.dim() > 1: + y_predict_clean = y_predict_clean.argmax(dim=1) + y_predict_clean.cpu() + if isinstance(y_predict_after_defense_only, torch.Tensor): + if y_predict_after_defense_only.dim() > 1: + y_predict_after_defense_only = y_predict_after_defense_only.argmax(dim=1) + y_predict_after_defense_only.cpu() + if isinstance(y_true, torch.Tensor): + if y_true.dim() > 1: + y_true = y_true.argmax(dim=1) + y_true.cpu() + return (sklearn.metrics.accuracy_score(y_true=y_true, y_pred=y_predict_clean) - + sklearn.metrics.accuracy_score(y_true=y_true, y_pred=y_predict_after_defense_only)) + + +# TODO Kirill, change for any classic metric +def aucc_change_defense_with_attack( + y_predict_after_attack_only, + y_predict_after_attack_and_defense, + y_true, + **kwargs +): + if isinstance(y_predict_after_attack_only, torch.Tensor): + if y_predict_after_attack_only.dim() > 1: + y_predict_after_attack_only = y_predict_after_attack_only.argmax(dim=1) + y_predict_after_attack_only.cpu() + if isinstance(y_predict_after_attack_and_defense, torch.Tensor): + if y_predict_after_attack_and_defense.dim() > 1: + y_predict_after_attack_and_defense = y_predict_after_attack_and_defense.argmax(dim=1) + y_predict_after_attack_and_defense.cpu() + if isinstance(y_true, torch.Tensor): + if y_true.dim() > 1: + y_true = y_true.argmax(dim=1) + y_true.cpu() + return (sklearn.metrics.accuracy_score(y_true=y_true, y_pred=y_predict_after_attack_and_defense) - + sklearn.metrics.accuracy_score(y_true=y_true, y_pred=y_predict_after_attack_only)) + + class AttackMetric: available_metrics = { - 'ASR': asr, + "ASR": asr, + "AuccAttackDiff": aucc_change_attack, } def __init__( @@ -37,11 +107,13 @@ def compute( self, y_predict_clean, y_predict_after_attack_only, + y_true, ): if self.name in AttackMetric.available_metrics: return AttackMetric.available_metrics[self.name]( y_predict_clean=y_predict_clean, y_predict_after_attack_only=y_predict_after_attack_only, + y_true=y_true, **self.kwargs ) raise NotImplementedError() @@ -49,6 +121,8 @@ def compute( class DefenseMetric: available_metrics = { + "AuccDefenseCleanDiff": aucc_change_defense_only, + "AuccDefenseAttackDiff": aucc_change_defense_with_attack, } def __init__( @@ -62,17 +136,18 @@ def __init__( def compute( self, y_predict_clean, + y_predict_after_attack_only, y_predict_after_defense_only, y_predict_after_attack_and_defense, + y_true, ): - if self.name in AttackMetric.available_metrics: - return AttackMetric.available_metrics[self.name]( + if self.name in DefenseMetric.available_metrics: + return DefenseMetric.available_metrics[self.name]( y_predict_clean=y_predict_clean, y_predict_after_defense_only=y_predict_after_defense_only, + y_predict_after_attack_only=y_predict_after_attack_only, y_predict_after_attack_and_defense=y_predict_after_attack_and_defense, + y_true=y_true, **self.kwargs ) - raise NotImplementedError() - - - + raise NotImplementedError(f"Metric {self.name} is not implemented")