Skip to content

Commit

Permalink
add evasion_defense_pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
LukyanovKirillML committed Dec 3, 2024
1 parent 9195446 commit 0df1ec1
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 23 deletions.
2 changes: 1 addition & 1 deletion experiments/attack_defense_metric_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def attack_defense_metrics():
# metrics_attack=[AttackMetric("ASR")],
# mask='test'
# )
adm.poison_defense_pipeline(
adm.evasion_defense_pipeline(
steps=steps_epochs,
save_model_flag=save_model_flag,
metrics_attack=[AttackMetric("ASR"), AttackMetric("AuccAttackDiff"),],
Expand Down
100 changes: 97 additions & 3 deletions src/models_builder/attack_defense_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,101 @@ def evasion_attack_pipeline(

return metrics_values

def evasion_defense_pipeline(
self,
metrics_attack: List,
metrics_defense: List,
steps: int,
save_model_flag: bool = True,
mask: Union[str, List[bool], torch.Tensor] = 'test',
) -> dict:
metrics_values = {}
if self.available_attacks["evasion"] and self.available_defense["evasion"]:
from models_builder.gnn_models import Metric
local_gen_dataset_copy = copy.deepcopy(self.gen_dataset)
self.set_clear_model()
self.gnn_manager.modification.epochs = 0
self.gnn_manager.gnn.reset_parameters()
self.gnn_manager.train_model(
gen_dataset=local_gen_dataset_copy,
steps=steps,
save_model_flag=False,
metrics=[Metric("F1", mask='train', average=None)]
)
y_predict_clean = self.gnn_manager.run_model(
gen_dataset=local_gen_dataset_copy,
mask=mask,
out='logits',
)

self.gnn_manager.evasion_defense_flag = True
self.gnn_manager.modification.epochs = 0
self.gnn_manager.gnn.reset_parameters()
self.gnn_manager.train_model(
gen_dataset=local_gen_dataset_copy,
steps=steps,
save_model_flag=False,
metrics=[Metric("F1", mask='train', average=None)]
)
y_predict_after_defense_only = self.gnn_manager.run_model(
gen_dataset=local_gen_dataset_copy,
mask=mask,
out='logits',
)

local_gen_dataset_copy = copy.deepcopy(self.gen_dataset)
self.gnn_manager.evasion_defense_flag = False
self.gnn_manager.evasion_attack_flag = True
self.gnn_manager.modification.epochs = 0
self.gnn_manager.gnn.reset_parameters()
self.gnn_manager.train_model(
gen_dataset=local_gen_dataset_copy,
steps=steps,
save_model_flag=False,
metrics=[Metric("F1", mask='train', average=None)]
)
y_predict_after_attack_only = self.gnn_manager.run_model(
gen_dataset=local_gen_dataset_copy,
mask=mask,
out='logits',
)

self.gnn_manager.evasion_defense_flag = True
self.gnn_manager.modification.epochs = 0
self.gnn_manager.gnn.reset_parameters()
self.gnn_manager.train_model(
gen_dataset=local_gen_dataset_copy,
steps=steps,
save_model_flag=save_model_flag,
metrics=[Metric("F1", mask='train', average=None)]
)
y_predict_after_attack_and_defense = self.gnn_manager.run_model(
gen_dataset=local_gen_dataset_copy,
mask=mask,
out='logits',
)

metrics_attack_values, metrics_defense_values = self.evaluate_attack_defense(
y_predict_after_attack_only=y_predict_after_attack_only,
y_predict_clean=y_predict_clean,
y_predict_after_defense_only=y_predict_after_defense_only,
y_predict_after_attack_and_defense=y_predict_after_attack_and_defense,
metrics_attack=metrics_attack,
metrics_defense=metrics_defense,
mask=mask,
)
if save_model_flag:
self.save_metrics(
metrics_attack_values=metrics_attack_values,
metrics_defense_values=metrics_defense_values,
)
self.return_attack_defense_flags()
else:
warnings.warn(f"Evasion attack and defense is not available. Please set evasion attack for "
f"gnn_model_manager use def set_evasion_attacker")

return metrics_values

def poison_attack_pipeline(
self,
metrics_attack: List,
Expand Down Expand Up @@ -195,7 +290,7 @@ def poison_attack_pipeline(
)
self.return_attack_defense_flags()
else:
warnings.warn(f"Evasion attack is not available. Please set evasion attack for "
warnings.warn(f"Poison attack is not available. Please set evasion attack for "
f"gnn_model_manager use def set_evasion_attacker")

return metrics_values
Expand Down Expand Up @@ -290,7 +385,7 @@ def poison_defense_pipeline(
)
self.return_attack_defense_flags()
else:
warnings.warn(f"Evasion attack is not available. Please set evasion attack for "
warnings.warn(f"Poison attack and defense is not available. Please set evasion attack for "
f"gnn_model_manager use def set_evasion_attacker")

return metrics_values
Expand Down Expand Up @@ -360,7 +455,6 @@ def evaluate_attack_defense(
y_predict_after_attack_and_defense=y_predict_after_attack_and_defense,
y_true=y_true,
)
print("!!!! ", metrics_attack_values, metrics_defense_values)
return metrics_attack_values, metrics_defense_values

@staticmethod
Expand Down
39 changes: 20 additions & 19 deletions src/models_builder/attack_defense_metric.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Union, List, Callable
from typing import Union, List, Callable, Any

import numpy as np
import sklearn
import torch


def asr(
y_predict_clean,
y_predict_after_attack_only,
y_predict_clean: Union[List, torch.Tensor, np.array],
y_predict_after_attack_only: Union[List, torch.Tensor, np.array],
**kwargs
):
if isinstance(y_predict_clean, torch.Tensor):
Expand All @@ -22,8 +23,8 @@ def asr(

# TODO Kirill, change for any classic metric
def aucc_change_attack(
y_predict_clean,
y_predict_after_attack_only,
y_predict_clean: Union[List, torch.Tensor, np.array],
y_predict_after_attack_only: Union[List, torch.Tensor, np.array],
y_true,
**kwargs
):
Expand All @@ -45,9 +46,9 @@ def aucc_change_attack(

# TODO Kirill, change for any classic metric
def aucc_change_defense_only(
y_predict_clean,
y_predict_after_defense_only,
y_true,
y_predict_clean: Union[List, torch.Tensor, np.array],
y_predict_after_defense_only: Union[List, torch.Tensor, np.array],
y_true: Union[List, torch.Tensor, np.array],
**kwargs
):
if isinstance(y_predict_clean, torch.Tensor):
Expand All @@ -68,9 +69,9 @@ def aucc_change_defense_only(

# TODO Kirill, change for any classic metric
def aucc_change_defense_with_attack(
y_predict_after_attack_only,
y_predict_after_attack_and_defense,
y_true,
y_predict_after_attack_only: Union[List, torch.Tensor, np.array],
y_predict_after_attack_and_defense: Union[List, torch.Tensor, np.array],
y_true: Union[List, torch.Tensor, np.array],
**kwargs
):
if isinstance(y_predict_after_attack_only, torch.Tensor):
Expand Down Expand Up @@ -105,9 +106,9 @@ def __init__(

def compute(
self,
y_predict_clean,
y_predict_after_attack_only,
y_true,
y_predict_clean: Union[List, torch.Tensor, np.array, None],
y_predict_after_attack_only: Union[List, torch.Tensor, np.array, None],
y_true: Union[List, torch.Tensor, np.array, None],
):
if self.name in AttackMetric.available_metrics:
return AttackMetric.available_metrics[self.name](
Expand Down Expand Up @@ -135,11 +136,11 @@ def __init__(

def compute(
self,
y_predict_clean,
y_predict_after_attack_only,
y_predict_after_defense_only,
y_predict_after_attack_and_defense,
y_true,
y_predict_clean: Union[List, torch.Tensor, np.array, None],
y_predict_after_attack_only: Union[List, torch.Tensor, np.array, None],
y_predict_after_defense_only: Union[List, torch.Tensor, np.array, None],
y_predict_after_attack_and_defense: Union[List, torch.Tensor, np.array, None],
y_true: Union[List, torch.Tensor, np.array, None],
):
if self.name in DefenseMetric.available_metrics:
return DefenseMetric.available_metrics[self.name](
Expand Down

0 comments on commit 0df1ec1

Please sign in to comment.