add 3 metrics and add poison_defense_pipeline
LukyanovKirillML committed Dec 3, 2024
1 parent a974200 commit 9195446
Showing 3 changed files with 206 additions and 15 deletions.
experiments/attack_defense_metric_test.py (10 additions, 3 deletions)

@@ -5,7 +5,7 @@
 from torch import device
 
 from models_builder.attack_defense_manager import FrameworkAttackDefenseManager
-from models_builder.attack_defense_metric import AttackMetric
+from models_builder.attack_defense_metric import AttackMetric, DefenseMetric
 from models_builder.models_utils import apply_decorator_to_graph_layers
 from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
     EVASION_DEFENSE_PARAMETERS_PATH
@@ -150,10 +150,17 @@ def attack_defense_metrics():
     #     metrics_attack=[AttackMetric("ASR")],
     #     mask='test'
     # )
-    adm.poison_attack_pipeline(
+    # adm.poison_attack_pipeline(
+    #     steps=steps_epochs,
+    #     save_model_flag=save_model_flag,
+    #     metrics_attack=[AttackMetric("ASR")],
+    #     mask='test'
+    # )
+    adm.poison_defense_pipeline(
         steps=steps_epochs,
         save_model_flag=save_model_flag,
-        metrics_attack=[AttackMetric("ASR")],
+        metrics_attack=[AttackMetric("ASR"), AttackMetric("AuccAttackDiff")],
+        metrics_defense=[DefenseMetric("AuccDefenseCleanDiff"), DefenseMetric("AuccDefenseAttackDiff")],
         mask='test'
     )

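Note: judging from evaluate_attack_defense below, the pipelines report metrics as nested dicts keyed first by mask and then by metric name, so running this test with mask='test' should print something like the following (values are illustrative, not real results):

    !!!!  {'test': {'ASR': 0.35, 'AuccAttackDiff': 0.21}} {'test': {'AuccDefenseCleanDiff': 0.02, 'AuccDefenseAttackDiff': 0.15}}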
src/models_builder/attack_defense_manager.py (114 additions, 5 deletions)

@@ -200,6 +200,101 @@ def poison_attack_pipeline(
 
         return metrics_values
 
+    def poison_defense_pipeline(
+            self,
+            metrics_attack: List,
+            metrics_defense: List,
+            steps: int,
+            save_model_flag: bool = True,
+            mask: Union[str, List[bool], torch.Tensor] = 'test',
+    ) -> dict:
+        metrics_values = {}
+        if self.available_attacks["poison"] and self.available_defense["poison"]:
+            from models_builder.gnn_models import Metric
+            # Run 1 of 4: train a clean model (no attack, no defense) on an
+            # untouched copy of the dataset to get baseline predictions.
+            local_gen_dataset_copy = copy.deepcopy(self.gen_dataset)
+            self.set_clear_model()
+            self.gnn_manager.modification.epochs = 0
+            self.gnn_manager.gnn.reset_parameters()
+            self.gnn_manager.train_model(
+                gen_dataset=local_gen_dataset_copy,
+                steps=steps,
+                save_model_flag=False,
+                metrics=[Metric("F1", mask='train', average=None)]
+            )
+            y_predict_clean = self.gnn_manager.run_model(
+                gen_dataset=local_gen_dataset_copy,
+                mask=mask,
+                out='logits',
+            )
+
+            # Run 2 of 4: retrain from scratch with the poison defense enabled
+            # on clean data, to isolate the effect of the defense alone.
+            self.gnn_manager.poison_defense_flag = True
+            self.gnn_manager.modification.epochs = 0
+            self.gnn_manager.gnn.reset_parameters()
+            self.gnn_manager.train_model(
+                gen_dataset=local_gen_dataset_copy,
+                steps=steps,
+                save_model_flag=False,
+                metrics=[Metric("F1", mask='train', average=None)]
+            )
+            y_predict_after_defense_only = self.gnn_manager.run_model(
+                gen_dataset=local_gen_dataset_copy,
+                mask=mask,
+                out='logits',
+            )
+
+            # Run 3 of 4: retrain on a fresh dataset copy with the poison
+            # attack enabled and the defense switched off.
+            local_gen_dataset_copy = copy.deepcopy(self.gen_dataset)
+            self.gnn_manager.poison_defense_flag = False
+            self.gnn_manager.poison_attack_flag = True
+            self.gnn_manager.modification.epochs = 0
+            self.gnn_manager.gnn.reset_parameters()
+            self.gnn_manager.train_model(
+                gen_dataset=local_gen_dataset_copy,
+                steps=steps,
+                save_model_flag=False,
+                metrics=[Metric("F1", mask='train', average=None)]
+            )
+            y_predict_after_attack_only = self.gnn_manager.run_model(
+                gen_dataset=local_gen_dataset_copy,
+                mask=mask,
+                out='logits',
+            )
+
+            # Run 4 of 4: keep the poisoned dataset and re-enable the defense,
+            # training the attacked-and-defended model.
+            self.gnn_manager.poison_defense_flag = True
+            self.gnn_manager.modification.epochs = 0
+            self.gnn_manager.gnn.reset_parameters()
+            self.gnn_manager.train_model(
+                gen_dataset=local_gen_dataset_copy,
+                steps=steps,
+                save_model_flag=save_model_flag,
+                metrics=[Metric("F1", mask='train', average=None)]
+            )
+            y_predict_after_attack_and_defense = self.gnn_manager.run_model(
+                gen_dataset=local_gen_dataset_copy,
+                mask=mask,
+                out='logits',
+            )
+
+            # Compare the four prediction sets to score the attack and the defense.
+            metrics_attack_values, metrics_defense_values = self.evaluate_attack_defense(
+                y_predict_after_attack_only=y_predict_after_attack_only,
+                y_predict_clean=y_predict_clean,
+                y_predict_after_defense_only=y_predict_after_defense_only,
+                y_predict_after_attack_and_defense=y_predict_after_attack_and_defense,
+                metrics_attack=metrics_attack,
+                metrics_defense=metrics_defense,
+                mask=mask,
+            )
+            if save_model_flag:
+                self.save_metrics(
+                    metrics_attack_values=metrics_attack_values,
+                    metrics_defense_values=metrics_defense_values,
+                )
+            self.return_attack_defense_flags()
+        else:
+            warnings.warn("Poison attack or poison defense is not available. Please set "
+                          "a poison attacker and a poison defender for gnn_model_manager")
+
+        return metrics_values
+
     def save_metrics(
             self,
             metrics_attack_values: Union[dict, None] = None,
@@ -221,9 +316,8 @@ def save_metrics(
                 new_dict=metrics_defense_values
             )
 
-    @staticmethod
     def evaluate_attack_defense(
-            # self,
+            self,
             y_predict_clean: Union[List, torch.Tensor, np.array],
             mask: Union[str, torch.Tensor],
             y_predict_after_attack_only: Union[List, torch.Tensor, np.array, None] = None,
@@ -232,13 +326,26 @@ def evaluate_attack_defense(
             metrics_attack: Union[List, None] = None,
             metrics_defense: Union[List, None] = None,
     ):
 
+        # Resolve a named mask ('train' / 'val' / 'test' / 'all') to a boolean
+        # list; any other value must already be a boolean tensor.
+        try:
+            mask_tensor = {
+                'train': self.gen_dataset.train_mask.tolist(),
+                'val': self.gen_dataset.val_mask.tolist(),
+                'test': self.gen_dataset.test_mask.tolist(),
+                'all': [True] * len(self.gen_dataset.labels),
+            }[mask]
+        except KeyError:
+            assert isinstance(mask, torch.Tensor)
+            mask_tensor = mask
+        y_true = copy.deepcopy(self.gen_dataset.labels[mask_tensor])
         metrics_attack_values = {mask: {}}
         metrics_defense_values = {mask: {}}
         if metrics_attack is not None and y_predict_after_attack_only is not None:
             for metric in metrics_attack:
                 metrics_attack_values[mask][metric.name] = metric.compute(
                     y_predict_clean=y_predict_clean,
-                    y_predict_after_attack_only=y_predict_after_attack_only
+                    y_predict_after_attack_only=y_predict_after_attack_only,
+                    y_true=y_true,
                 )
         if (
                 metrics_defense is not None
@@ -248,10 +355,12 @@
             for metric in metrics_defense:
                 metrics_defense_values[mask][metric.name] = metric.compute(
                     y_predict_clean=y_predict_clean,
+                    y_predict_after_attack_only=y_predict_after_attack_only,
                     y_predict_after_defense_only=y_predict_after_defense_only,
-                    y_predict_after_attack_and_defense=y_predict_after_attack_and_defense
+                    y_predict_after_attack_and_defense=y_predict_after_attack_and_defense,
+                    y_true=y_true,
                 )
-
+        print("!!!! ", metrics_attack_values, metrics_defense_values)
         return metrics_attack_values, metrics_defense_values
 
     @staticmethod
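A note on the mask handling added to evaluate_attack_defense: the four known string masks are resolved through a dict lookup, and a KeyError falls through to treating mask as an explicit boolean tensor. A minimal, self-contained sketch of the same idiom on toy masks (resolve_mask and the toy tensors are illustrative names, not framework API):

    import torch

    def resolve_mask(mask, train_mask, val_mask, test_mask, num_nodes):
        # Same pattern as evaluate_attack_defense: try the known names first,
        # fall back to an explicit boolean tensor on KeyError.
        try:
            return {
                'train': train_mask.tolist(),
                'val': val_mask.tolist(),
                'test': test_mask.tolist(),
                'all': [True] * num_nodes,
            }[mask]
        except KeyError:
            assert isinstance(mask, torch.Tensor)
            return mask

    train = torch.tensor([True, True, False, False])
    val = torch.tensor([False, False, True, False])
    test = torch.tensor([False, False, False, True])
    print(resolve_mask('test', train, val, test, 4))  # [False, False, False, True]
    print(resolve_mask(torch.tensor([True, False, True, False]), train, val, test, 4))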
src/models_builder/attack_defense_metric.py (82 additions, 7 deletions)

@@ -20,9 +20,79 @@ def asr(
     return 1 - sklearn.metrics.accuracy_score(y_true=y_predict_clean, y_pred=y_predict_after_attack_only)
 
 
+# TODO Kirill, change for any classic metric
+def aucc_change_attack(
+        y_predict_clean,
+        y_predict_after_attack_only,
+        y_true,
+        **kwargs
+):
+    if isinstance(y_predict_clean, torch.Tensor):
+        if y_predict_clean.dim() > 1:
+            y_predict_clean = y_predict_clean.argmax(dim=1)
+        y_predict_clean = y_predict_clean.cpu()
+    if isinstance(y_predict_after_attack_only, torch.Tensor):
+        if y_predict_after_attack_only.dim() > 1:
+            y_predict_after_attack_only = y_predict_after_attack_only.argmax(dim=1)
+        y_predict_after_attack_only = y_predict_after_attack_only.cpu()
+    if isinstance(y_true, torch.Tensor):
+        if y_true.dim() > 1:
+            y_true = y_true.argmax(dim=1)
+        y_true = y_true.cpu()
+    return (sklearn.metrics.accuracy_score(y_true=y_true, y_pred=y_predict_clean) -
+            sklearn.metrics.accuracy_score(y_true=y_true, y_pred=y_predict_after_attack_only))
+
+
+# TODO Kirill, change for any classic metric
+def aucc_change_defense_only(
+        y_predict_clean,
+        y_predict_after_defense_only,
+        y_true,
+        **kwargs
+):
+    if isinstance(y_predict_clean, torch.Tensor):
+        if y_predict_clean.dim() > 1:
+            y_predict_clean = y_predict_clean.argmax(dim=1)
+        y_predict_clean = y_predict_clean.cpu()
+    if isinstance(y_predict_after_defense_only, torch.Tensor):
+        if y_predict_after_defense_only.dim() > 1:
+            y_predict_after_defense_only = y_predict_after_defense_only.argmax(dim=1)
+        y_predict_after_defense_only = y_predict_after_defense_only.cpu()
+    if isinstance(y_true, torch.Tensor):
+        if y_true.dim() > 1:
+            y_true = y_true.argmax(dim=1)
+        y_true = y_true.cpu()
+    return (sklearn.metrics.accuracy_score(y_true=y_true, y_pred=y_predict_clean) -
+            sklearn.metrics.accuracy_score(y_true=y_true, y_pred=y_predict_after_defense_only))
+
+
+# TODO Kirill, change for any classic metric
+def aucc_change_defense_with_attack(
+        y_predict_after_attack_only,
+        y_predict_after_attack_and_defense,
+        y_true,
+        **kwargs
+):
+    if isinstance(y_predict_after_attack_only, torch.Tensor):
+        if y_predict_after_attack_only.dim() > 1:
+            y_predict_after_attack_only = y_predict_after_attack_only.argmax(dim=1)
+        y_predict_after_attack_only = y_predict_after_attack_only.cpu()
+    if isinstance(y_predict_after_attack_and_defense, torch.Tensor):
+        if y_predict_after_attack_and_defense.dim() > 1:
+            y_predict_after_attack_and_defense = y_predict_after_attack_and_defense.argmax(dim=1)
+        y_predict_after_attack_and_defense = y_predict_after_attack_and_defense.cpu()
+    if isinstance(y_true, torch.Tensor):
+        if y_true.dim() > 1:
+            y_true = y_true.argmax(dim=1)
+        y_true = y_true.cpu()
+    return (sklearn.metrics.accuracy_score(y_true=y_true, y_pred=y_predict_after_attack_and_defense) -
+            sklearn.metrics.accuracy_score(y_true=y_true, y_pred=y_predict_after_attack_only))
+
+
 class AttackMetric:
     available_metrics = {
-        'ASR': asr,
+        "ASR": asr,
+        "AuccAttackDiff": aucc_change_attack,
     }
 
     def __init__(
@@ -37,18 +107,22 @@ def compute(
             self,
             y_predict_clean,
             y_predict_after_attack_only,
+            y_true,
     ):
         if self.name in AttackMetric.available_metrics:
             return AttackMetric.available_metrics[self.name](
                 y_predict_clean=y_predict_clean,
                 y_predict_after_attack_only=y_predict_after_attack_only,
+                y_true=y_true,
                 **self.kwargs
             )
         raise NotImplementedError()
 
 
 class DefenseMetric:
     available_metrics = {
+        "AuccDefenseCleanDiff": aucc_change_defense_only,
+        "AuccDefenseAttackDiff": aucc_change_defense_with_attack,
     }
 
     def __init__(
@@ -62,17 +136,18 @@ def __init__(
     def compute(
             self,
             y_predict_clean,
+            y_predict_after_attack_only,
             y_predict_after_defense_only,
             y_predict_after_attack_and_defense,
+            y_true,
     ):
-        if self.name in AttackMetric.available_metrics:
-            return AttackMetric.available_metrics[self.name](
+        if self.name in DefenseMetric.available_metrics:
+            return DefenseMetric.available_metrics[self.name](
                 y_predict_clean=y_predict_clean,
                 y_predict_after_defense_only=y_predict_after_defense_only,
+                y_predict_after_attack_only=y_predict_after_attack_only,
                 y_predict_after_attack_and_defense=y_predict_after_attack_and_defense,
+                y_true=y_true,
                 **self.kwargs
             )
-        raise NotImplementedError()
+        raise NotImplementedError(f"Metric {self.name} is not implemented")
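In effect the three new metrics are accuracy differences: AuccAttackDiff = acc(clean) - acc(attack), AuccDefenseCleanDiff = acc(clean) - acc(defense), and AuccDefenseAttackDiff = acc(attack+defense) - acc(attack). A quick sanity check on toy predictions (illustrative values; assumes the repository's src directory is on the import path, as the tests assume):

    import torch
    from models_builder.attack_defense_metric import AttackMetric, DefenseMetric

    y_true = torch.tensor([0, 1, 1, 0, 1])
    y_clean = torch.tensor([0, 1, 1, 0, 0])           # clean accuracy 0.8
    y_attack = torch.tensor([1, 0, 1, 0, 0])          # accuracy after attack 0.4
    y_defense = torch.tensor([0, 1, 1, 1, 1])         # accuracy with defense only 0.8
    y_attack_defense = torch.tensor([0, 1, 1, 1, 0])  # accuracy with attack + defense 0.6

    print(AttackMetric("AuccAttackDiff").compute(
        y_predict_clean=y_clean,
        y_predict_after_attack_only=y_attack,
        y_true=y_true,
    ))  # 0.8 - 0.4 = 0.4: accuracy lost to the attack

    print(DefenseMetric("AuccDefenseAttackDiff").compute(
        y_predict_clean=y_clean,
        y_predict_after_attack_only=y_attack,
        y_predict_after_defense_only=y_defense,
        y_predict_after_attack_and_defense=y_attack_defense,
        y_true=y_true,
    ))  # 0.6 - 0.4 = 0.2: accuracy recovered by the defense under attack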
