add class AttackMetric and ASR metric
LukyanovKirillML committed Dec 3, 2024
1 parent bc4a85c commit 2c8bbf6
Showing 4 changed files with 160 additions and 45 deletions.
53 changes: 30 additions & 23 deletions experiments/attack_defense_metric_test.py
@@ -4,6 +4,7 @@
from torch import device

from models_builder.attack_defense_manager import FrameworkAttackDefenseManager
from models_builder.attack_defense_metric import AttackMetric
from models_builder.models_utils import apply_decorator_to_graph_layers
from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
EVASION_DEFENSE_PARAMETERS_PATH
@@ -114,34 +115,40 @@ def attack_defense_metrics():
warnings.warn("Start training")
dataset.train_test_split()

try:
# raise FileNotFoundError()
gnn_model_manager.load_model_executor()
dataset = gnn_model_manager.load_train_test_split(dataset)
except FileNotFoundError:
gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0
train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs,
save_model_flag=save_model_flag,
metrics=[Metric("F1", mask='train', average=None)])

if train_test_split_path is not None:
dataset.save_train_test_mask(train_test_split_path)
train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[
:]
dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask
data.percent_train_class, data.percent_test_class = train_test_sizes

warnings.warn("Training was successful")

metric_loc = gnn_model_manager.evaluate_model(
gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro'),
Metric("Accuracy", mask='test')])
print(metric_loc)
# try:
# # raise FileNotFoundError()
# gnn_model_manager.load_model_executor()
# dataset = gnn_model_manager.load_train_test_split(dataset)
# except FileNotFoundError:
# gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0
# train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs,
# save_model_flag=save_model_flag,
# metrics=[Metric("F1", mask='train', average=None)])
#
# if train_test_split_path is not None:
# dataset.save_train_test_mask(train_test_split_path)
# train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[
# :]
# dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask
# data.percent_train_class, data.percent_test_class = train_test_sizes
#
# warnings.warn("Training was successful")
#
# metric_loc = gnn_model_manager.evaluate_model(
# gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro'),
# Metric("Accuracy", mask='test')])
# print(metric_loc)

adm = FrameworkAttackDefenseManager(
gen_dataset=dataset,
gnn_manager=gnn_model_manager,
)
adm.evasion_attack_pipeline(
steps=steps_epochs,
save_model_flag=save_model_flag,
metrics_attack=[AttackMetric("ASR")],
mask='test'
)


if __name__ == '__main__':
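For reference, evasion_attack_pipeline returns the computed attack metrics as a dict keyed by metric name, so the test could also capture and inspect them. A minimal usage sketch (hypothetical, reusing the objects constructed in this test script; the printed value is illustrative only):

metrics_attack_values = adm.evasion_attack_pipeline(
    steps=steps_epochs,
    save_model_flag=save_model_flag,
    metrics_attack=[AttackMetric("ASR")],
    mask='test',
)
print(metrics_attack_values)  # e.g. {'ASR': 0.37} -- illustrative value, not a real result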
54 changes: 41 additions & 13 deletions src/models_builder/attack_defense_manager.py
@@ -1,5 +1,7 @@
import warnings
from typing import Type
from typing import Type, Union, List

import torch

from base.datasets_processing import GeneralDataset

@@ -69,45 +71,71 @@ def return_attack_defense_flags(self):
def evasion_attack_pipeline(
self,
metrics_attack,
model_metrics,
steps: int,
save_model_flag: bool = True,
mask: Union[str, List[bool], torch.Tensor] = 'test',
):
metrics_values = {}
if self.available_attacks["evasion"]:
self.set_clear_model()
self.gnn_manager.modification.epochs = 0
self.gnn_manager.gnn.reset_parameters()
from models_builder.gnn_models import Metric
self.gnn_manager.train_model(
gen_dataset=self.gen_dataset,
steps=steps,
save_model_flag=save_model_flag,
metrics=[Metric("F1", mask='train', average=None)]
)
metric_clean_model = self.gnn_manager.evaluate_model(
y_predict_clean = self.gnn_manager.run_model(
gen_dataset=self.gen_dataset,
metrics=model_metrics
mask=mask,
out='logits',
)

self.gnn_manager.evasion_attack_flag = True
self.gnn_manager.modification.epochs = 0
self.gnn_manager.gnn.reset_parameters()
self.gnn_manager.train_model(
gen_dataset=self.gen_dataset,
steps=steps,
save_model_flag=save_model_flag,
metrics=[Metric("F1", mask='train', average=None)]
)
metric_evasion_attack_only = self.gnn_manager.evaluate_model(
self.gnn_manager.call_evasion_attack(
gen_dataset=self.gen_dataset,
metrics=model_metrics
mask=mask,
)
y_predict_attack = self.gnn_manager.run_model(
gen_dataset=self.gen_dataset,
mask=mask,
out='logits',
)
metrics_values = self.evaluate_attack_defense(
y_predict_after_attack_only=y_predict_attack,
y_predict_clean=y_predict_clean,
metrics_attack=metrics_attack,
)
# TODO Kirill
# metrics_values = evaluate_attacks(
# metric_clean_model,
# metric_evasion_attack_only,
# metrics_attack=metrics_attack
# )
self.return_attack_defense_flags()
pass

else:
warnings.warn(f"Evasion attack is not available. Please set evasion attack for "
f"gnn_model_manager use def set_evasion_attacker")

return metrics_values

def evaluate_attack_defense(
self,
y_predict_clean,
y_predict_after_attack_only=None,
y_predict_after_defense_only=None,
y_predict_after_attack_and_defense=None,
metrics_attack=None,
metrics_defense=None,
):
metrics_attack_values = {}
if metrics_attack is not None and y_predict_after_attack_only is not None:
for metric in metrics_attack:
metrics_attack_values[metric.name] = metric.compute(y_predict_clean, y_predict_after_attack_only)

return metrics_attack_values
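To illustrate what evaluate_attack_defense computes, here is a small self-contained sketch; the random tensors below are stand-ins for the clean and post-attack logits, not real results:

import torch

from models_builder.attack_defense_metric import AttackMetric

y_predict_clean = torch.randn(100, 7)         # stand-in for clean-model logits
y_predict_after_attack = torch.randn(100, 7)  # stand-in for post-attack logits

metrics_attack_values = {}
for metric in [AttackMetric("ASR")]:
    metrics_attack_values[metric.name] = metric.compute(y_predict_clean, y_predict_after_attack)
print(metrics_attack_values)  # fraction of nodes whose predicted class changed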
51 changes: 51 additions & 0 deletions src/models_builder/attack_defense_metric.py
@@ -0,0 +1,51 @@
from typing import Union, List, Callable

import sklearn.metrics
import torch


def asr(
y_predict_clean,
y_predict_after_attack_only,
**kwargs
):
    # Collapse logits to predicted labels and move tensors to CPU before handing them to sklearn.
    if isinstance(y_predict_clean, torch.Tensor):
        if y_predict_clean.dim() > 1:
            y_predict_clean = y_predict_clean.argmax(dim=1)
        y_predict_clean = y_predict_clean.cpu()
    if isinstance(y_predict_after_attack_only, torch.Tensor):
        if y_predict_after_attack_only.dim() > 1:
            y_predict_after_attack_only = y_predict_after_attack_only.argmax(dim=1)
        y_predict_after_attack_only = y_predict_after_attack_only.cpu()
    asr_value = 1 - sklearn.metrics.accuracy_score(y_true=y_predict_clean, y_pred=y_predict_after_attack_only)
    print("ASR ", asr_value)
    return asr_value


class AttackMetric:
available_metrics = {
'ASR': asr,
}

def __init__(
self,
name: str,
**kwargs
):
self.name = name
self.kwargs = kwargs

def compute(
self,
metrics_clean_model,
metrics_after_attack
):
if self.name in AttackMetric.available_metrics:
return AttackMetric.available_metrics[self.name](
metrics_clean_model,
metrics_after_attack,
**self.kwargs
)
raise NotImplementedError()
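A small worked example of the new ASR metric (hypothetical labels, not taken from the commit): ASR is simply the fraction of predictions that differ between the clean and the attacked model.

import torch

from models_builder.attack_defense_metric import AttackMetric

clean_labels = torch.tensor([0, 1, 2, 1])     # clean-model predicted labels
attacked_labels = torch.tensor([0, 2, 2, 0])  # post-attack predicted labels

print(AttackMetric("ASR").compute(clean_labels, attacked_labels))  # 0.5 -- two of four labels flipped

Additional metrics could presumably be supported by registering another callable in AttackMetric.available_metrics alongside 'ASR'.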



47 changes: 38 additions & 9 deletions src/models_builder/gnn_models.py
@@ -951,15 +951,15 @@ def train_on_batch_full(
batch,
task_type: str = None
) -> torch.Tensor:
if self.mi_defender:
if self.mi_defender and self.mi_defense_flag:
self.mi_defender.pre_batch()
if self.evasion_defender:
if self.evasion_defender and self.evasion_defense_flag:
self.evasion_defender.pre_batch(model_manager=self, batch=batch)
loss = self.train_on_batch(batch=batch, task_type=task_type)
if self.mi_defender:
if self.mi_defender and self.mi_defense_flag:
self.mi_defender.post_batch()
evasion_defender_dict = None
if self.evasion_defender:
if self.evasion_defender and self.evasion_defense_flag:
evasion_defender_dict = self.evasion_defender.post_batch(
model_manager=self, batch=batch, loss=loss,
)
@@ -1093,12 +1093,12 @@ def train_model(
:param metrics: list of metrics to measure at each step or at the end of training
:param socket: socket to use for sending data to frontend
"""
if self.poison_attacker:
if self.poison_attacker and self.poison_attack_flag:
loc = self.poison_attacker.attack(gen_dataset=gen_dataset)
if loc is not None:
gen_dataset = loc

if self.poison_defender:
if self.poison_defender and self.poison_defense_flag:
loc = self.poison_defender.defense(gen_dataset=gen_dataset)
if loc is not None:
gen_dataset = loc
@@ -1256,18 +1256,47 @@ def evaluate_model(
except KeyError:
assert isinstance(mask, torch.Tensor)
mask_tensor = mask
if self.evasion_attacker:
self.evasion_attacker.attack(model_manager=self, gen_dataset=gen_dataset, mask_tensor=mask_tensor)
if self.evasion_attacker and self.evasion_attack_flag:
self.call_evasion_attack(
gen_dataset=gen_dataset,
mask=mask,
)
metrics_values[mask] = {}
y_pred = self.run_model(gen_dataset, mask=mask)
y_true = gen_dataset.labels[mask_tensor]

for metric in ms:
metrics_values[mask][metric.name] = metric.compute(y_pred, y_true)
# metrics_values[mask][metric.name] = MetricManager.compute(metric, y_pred, y_true)
if self.mi_attacker and self.mi_attack_flag:
self.call_mi_attack()
return metrics_values

def call_evasion_attack(
self,
gen_dataset: GeneralDataset,
mask: Union[str, List[bool], torch.Tensor] = 'test',
):
if self.evasion_attacker:
try:
mask_tensor = {
'train': gen_dataset.train_mask.tolist(),
'val': gen_dataset.val_mask.tolist(),
'test': gen_dataset.test_mask.tolist(),
'all': [True] * len(gen_dataset.labels),
}[mask]
except KeyError:
assert isinstance(mask, torch.Tensor)
mask_tensor = mask
self.evasion_attacker.attack(
model_manager=self,
gen_dataset=gen_dataset,
mask_tensor=mask_tensor
)

def call_mi_attack(self):
if self.mi_attacker:
self.mi_attacker.attack()
return metrics_values

def compute_stats_data(
self,
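The new call_evasion_attack helper can also be invoked outside evaluate_model; a hypothetical direct use, assuming a trained gnn_model_manager with an evasion attacker attached (names as in the test script above):

gnn_model_manager.evasion_attack_flag = True  # gate introduced in this commit
gnn_model_manager.call_evasion_attack(gen_dataset=dataset, mask='test')
y_predict_attack = gnn_model_manager.run_model(dataset, mask='test', out='logits')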
