From 2c8bbf67eb3ea9c02b604ecd8fa9a9b055d24048 Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Tue, 3 Dec 2024 13:15:16 +0300 Subject: [PATCH] add class AttackMetric and asr meric --- experiments/attack_defense_metric_test.py | 53 ++++++++++--------- src/models_builder/attack_defense_manager.py | 54 +++++++++++++++----- src/models_builder/attack_defense_metric.py | 51 ++++++++++++++++++ src/models_builder/gnn_models.py | 47 +++++++++++++---- 4 files changed, 160 insertions(+), 45 deletions(-) create mode 100644 src/models_builder/attack_defense_metric.py diff --git a/experiments/attack_defense_metric_test.py b/experiments/attack_defense_metric_test.py index c469cad..169b625 100644 --- a/experiments/attack_defense_metric_test.py +++ b/experiments/attack_defense_metric_test.py @@ -4,6 +4,7 @@ from torch import device from models_builder.attack_defense_manager import FrameworkAttackDefenseManager +from models_builder.attack_defense_metric import AttackMetric from models_builder.models_utils import apply_decorator_to_graph_layers from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \ EVASION_DEFENSE_PARAMETERS_PATH @@ -114,34 +115,40 @@ def attack_defense_metrics(): warnings.warn("Start training") dataset.train_test_split() - try: - # raise FileNotFoundError() - gnn_model_manager.load_model_executor() - dataset = gnn_model_manager.load_train_test_split(dataset) - except FileNotFoundError: - gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0 - train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs, - save_model_flag=save_model_flag, - metrics=[Metric("F1", mask='train', average=None)]) - - if train_test_split_path is not None: - dataset.save_train_test_mask(train_test_split_path) - train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[ - :] - dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask - data.percent_train_class, data.percent_test_class = train_test_sizes - - warnings.warn("Training was successful") - - metric_loc = gnn_model_manager.evaluate_model( - gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro'), - Metric("Accuracy", mask='test')]) - print(metric_loc) + # try: + # # raise FileNotFoundError() + # gnn_model_manager.load_model_executor() + # dataset = gnn_model_manager.load_train_test_split(dataset) + # except FileNotFoundError: + # gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0 + # train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs, + # save_model_flag=save_model_flag, + # metrics=[Metric("F1", mask='train', average=None)]) + # + # if train_test_split_path is not None: + # dataset.save_train_test_mask(train_test_split_path) + # train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[ + # :] + # dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask + # data.percent_train_class, data.percent_test_class = train_test_sizes + # + # warnings.warn("Training was successful") + # + # metric_loc = gnn_model_manager.evaluate_model( + # gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro'), + # Metric("Accuracy", mask='test')]) + # print(metric_loc) adm = FrameworkAttackDefenseManager( gen_dataset=dataset, gnn_manager=gnn_model_manager, ) + adm.evasion_attack_pipeline( + steps=steps_epochs, + save_model_flag=save_model_flag, + metrics_attack=[AttackMetric("ASR")], + mask='test' + ) if __name__ == '__main__': diff --git a/src/models_builder/attack_defense_manager.py b/src/models_builder/attack_defense_manager.py index c18e9f8..5374f9f 100644 --- a/src/models_builder/attack_defense_manager.py +++ b/src/models_builder/attack_defense_manager.py @@ -1,5 +1,7 @@ import warnings -from typing import Type +from typing import Type, Union, List + +import torch from base.datasets_processing import GeneralDataset @@ -69,13 +71,15 @@ def return_attack_defense_flags(self): def evasion_attack_pipeline( self, metrics_attack, - model_metrics, steps: int, save_model_flag: bool = True, + mask: Union[str, List[bool], torch.Tensor] = 'test', ): metrics_values = {} if self.available_attacks["evasion"]: self.set_clear_model() + self.gnn_manager.modification.epochs = 0 + self.gnn_manager.gnn.reset_parameters() from models_builder.gnn_models import Metric self.gnn_manager.train_model( gen_dataset=self.gen_dataset, @@ -83,31 +87,55 @@ def evasion_attack_pipeline( save_model_flag=save_model_flag, metrics=[Metric("F1", mask='train', average=None)] ) - metric_clean_model = self.gnn_manager.evaluate_model( + y_predict_clean = self.gnn_manager.run_model( gen_dataset=self.gen_dataset, - metrics=model_metrics + mask=mask, + out='logits', ) + self.gnn_manager.evasion_attack_flag = True + self.gnn_manager.modification.epochs = 0 + self.gnn_manager.gnn.reset_parameters() self.gnn_manager.train_model( gen_dataset=self.gen_dataset, steps=steps, save_model_flag=save_model_flag, metrics=[Metric("F1", mask='train', average=None)] ) - metric_evasion_attack_only = self.gnn_manager.evaluate_model( + self.gnn_manager.call_evasion_attack( gen_dataset=self.gen_dataset, - metrics=model_metrics + mask=mask, + ) + y_predict_attack = self.gnn_manager.run_model( + gen_dataset=self.gen_dataset, + mask=mask, + out='logits', + ) + metrics_values = self.evaluate_attack_defense( + y_predict_after_attack_only=y_predict_attack, + y_predict_clean=y_predict_clean, + metrics_attack=metrics_attack, ) - # TODO Kirill - # metrics_values = evaluate_attacks( - # metric_clean_model, - # metric_evasion_attack_only, - # metrics_attack=metrics_attack - # ) self.return_attack_defense_flags() - pass + else: warnings.warn(f"Evasion attack is not available. Please set evasion attack for " f"gnn_model_manager use def set_evasion_attacker") return metrics_values + + def evaluate_attack_defense( + self, + y_predict_clean, + y_predict_after_attack_only=None, + y_predict_after_defense_only=None, + y_predict_after_attack_and_defense=None, + metrics_attack=None, + metrics_defense=None, + ): + metrics_attack_values = {} + if metrics_attack is not None and y_predict_after_attack_only is not None: + for metric in metrics_attack: + metrics_attack_values[metric.name] = metric.compute(y_predict_clean, y_predict_after_attack_only) + + return metrics_attack_values diff --git a/src/models_builder/attack_defense_metric.py b/src/models_builder/attack_defense_metric.py new file mode 100644 index 0000000..4fd2181 --- /dev/null +++ b/src/models_builder/attack_defense_metric.py @@ -0,0 +1,51 @@ +from typing import Union, List, Callable + +import sklearn +import torch + + +def asr( + y_predict_clean, + y_predict_after_attack_only, + **kwargs +): + if isinstance(y_predict_clean, torch.Tensor): + if y_predict_clean.dim() > 1: + y_predict_clean = y_predict_clean.argmax(dim=1) + y_predict_clean.cpu() + if isinstance(y_predict_after_attack_only, torch.Tensor): + if y_predict_after_attack_only.dim() > 1: + y_predict_after_attack_only = y_predict_after_attack_only.argmax(dim=1) + y_predict_after_attack_only.cpu() + print("ASR ", 1 - sklearn.metrics.accuracy_score(y_true=y_predict_clean, y_pred=y_predict_after_attack_only)) + return 1 - sklearn.metrics.accuracy_score(y_true=y_predict_clean, y_pred=y_predict_after_attack_only) + + +class AttackMetric: + available_metrics = { + 'ASR': asr, + } + + def __init__( + self, + name: str, + **kwargs + ): + self.name = name + self.kwargs = kwargs + + def compute( + self, + metrics_clean_model, + metrics_after_attack + ): + if self.name in AttackMetric.available_metrics: + return AttackMetric.available_metrics[self.name]( + metrics_clean_model, + metrics_after_attack, + **self.kwargs + ) + raise NotImplementedError() + + + diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index 09c0490..411f21a 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -951,15 +951,15 @@ def train_on_batch_full( batch, task_type: str = None ) -> torch.Tensor: - if self.mi_defender: + if self.mi_defender and self.mi_defense_flag: self.mi_defender.pre_batch() - if self.evasion_defender: + if self.evasion_defender and self.evasion_defense_flag: self.evasion_defender.pre_batch(model_manager=self, batch=batch) loss = self.train_on_batch(batch=batch, task_type=task_type) - if self.mi_defender: + if self.mi_defender and self.mi_defense_flag: self.mi_defender.post_batch() evasion_defender_dict = None - if self.evasion_defender: + if self.evasion_defender and self.evasion_defense_flag: evasion_defender_dict = self.evasion_defender.post_batch( model_manager=self, batch=batch, loss=loss, ) @@ -1093,12 +1093,12 @@ def train_model( :param metrics: list of metrics to measure at each step or at the end of training :param socket: socket to use for sending data to frontend """ - if self.poison_attacker: + if self.poison_attacker and self.poison_attack_flag: loc = self.poison_attacker.attack(gen_dataset=gen_dataset) if loc is not None: gen_dataset = loc - if self.poison_defender: + if self.poison_defender and self.poison_defense_flag: loc = self.poison_defender.defense(gen_dataset=gen_dataset) if loc is not None: gen_dataset = loc @@ -1256,8 +1256,11 @@ def evaluate_model( except KeyError: assert isinstance(mask, torch.Tensor) mask_tensor = mask - if self.evasion_attacker: - self.evasion_attacker.attack(model_manager=self, gen_dataset=gen_dataset, mask_tensor=mask_tensor) + if self.evasion_attacker and self.evasion_attack_flag: + self.call_evasion_attack( + gen_dataset=gen_dataset, + mask=mask, + ) metrics_values[mask] = {} y_pred = self.run_model(gen_dataset, mask=mask) y_true = gen_dataset.labels[mask_tensor] @@ -1265,9 +1268,35 @@ def evaluate_model( for metric in ms: metrics_values[mask][metric.name] = metric.compute(y_pred, y_true) # metrics_values[mask][metric.name] = MetricManager.compute(metric, y_pred, y_true) + if self.mi_attacker and self.mi_attack_flag: + self.call_mi_attack() + return metrics_values + + def call_evasion_attack( + self, + gen_dataset: GeneralDataset, + mask: Union[str, List[bool], torch.Tensor] = 'test', + ): + if self.evasion_attacker: + try: + mask_tensor = { + 'train': gen_dataset.train_mask.tolist(), + 'val': gen_dataset.val_mask.tolist(), + 'test': gen_dataset.test_mask.tolist(), + 'all': [True] * len(gen_dataset.labels), + }[mask] + except KeyError: + assert isinstance(mask, torch.Tensor) + mask_tensor = mask + self.evasion_attacker.attack( + model_manager=self, + gen_dataset=gen_dataset, + mask_tensor=mask_tensor + ) + + def call_mi_attack(self): if self.mi_attacker: self.mi_attacker.attack() - return metrics_values def compute_stats_data( self,