add poison_attack_pipeline
LukyanovKirillML committed Dec 3, 2024
1 parent 2c8bbf6 commit 6b51f7d
Showing 2 changed files with 59 additions and 1 deletion.
8 changes: 7 additions & 1 deletion experiments/attack_defense_metric_test.py
@@ -143,7 +143,13 @@ def attack_defense_metrics():
         gen_dataset=dataset,
         gnn_manager=gnn_model_manager,
     )
-    adm.evasion_attack_pipeline(
+    # adm.evasion_attack_pipeline(
+    #     steps=steps_epochs,
+    #     save_model_flag=save_model_flag,
+    #     metrics_attack=[AttackMetric("ASR")],
+    #     mask='test'
+    # )
+    adm.poison_attack_pipeline(
         steps=steps_epochs,
         save_model_flag=save_model_flag,
         metrics_attack=[AttackMetric("ASR")],
52 changes: 52 additions & 0 deletions src/models_builder/attack_defense_manager.py
@@ -124,6 +124,58 @@ def evasion_attack_pipeline(
 
         return metrics_values
 
+    def poison_attack_pipeline(
+            self,
+            metrics_attack,
+            steps: int,
+            save_model_flag: bool = True,
+            mask: Union[str, List[bool], torch.Tensor] = 'test',
+    ):
+        metrics_values = {}
+        if self.available_attacks["poison"]:
+            self.set_clear_model()
+            self.gnn_manager.modification.epochs = 0
+            self.gnn_manager.gnn.reset_parameters()
+            from models_builder.gnn_models import Metric
+            self.gnn_manager.train_model(
+                gen_dataset=self.gen_dataset,
+                steps=steps,
+                save_model_flag=save_model_flag,
+                metrics=[Metric("F1", mask='train', average=None)]
+            )
+            y_predict_clean = self.gnn_manager.run_model(
+                gen_dataset=self.gen_dataset,
+                mask=mask,
+                out='logits',
+            )
+
+            self.gnn_manager.poison_attack_flag = True
+            self.gnn_manager.modification.epochs = 0
+            self.gnn_manager.gnn.reset_parameters()
+            self.gnn_manager.train_model(
+                gen_dataset=self.gen_dataset,
+                steps=steps,
+                save_model_flag=save_model_flag,
+                metrics=[Metric("F1", mask='train', average=None)]
+            )
+            y_predict_attack = self.gnn_manager.run_model(
+                gen_dataset=self.gen_dataset,
+                mask=mask,
+                out='logits',
+            )
+            metrics_values = self.evaluate_attack_defense(
+                y_predict_after_attack_only=y_predict_attack,
+                y_predict_clean=y_predict_clean,
+                metrics_attack=metrics_attack,
+            )
+            self.return_attack_defense_flags()
+
+        else:
+            warnings.warn(f"Poison attack is not available. Please set a poison attacker for "
+                          f"gnn_model_manager using set_poison_attacker")
+
+        return metrics_values
+
     def evaluate_attack_defense(
             self,
             y_predict_clean,
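For context, a minimal usage sketch of the new pipeline, mirroring the change to experiments/attack_defense_metric_test.py above. The setup of adm (built from dataset and gnn_model_manager), steps_epochs, save_model_flag, and AttackMetric is assumed from earlier in that script and is not part of this commit; a poison attacker must already be attached so that available_attacks["poison"] is truthy.

# Sketch only: `adm`, `steps_epochs`, `save_model_flag` and `AttackMetric`
# are assumed to be defined earlier in attack_defense_metric_test.py.
metrics_values = adm.poison_attack_pipeline(
    steps=steps_epochs,
    save_model_flag=save_model_flag,
    metrics_attack=[AttackMetric("ASR")],
    mask='test'
)
# Internally the pipeline trains and runs a clean model, resets the weights,
# retrains with poison_attack_flag enabled, and passes both prediction sets
# to evaluate_attack_defense, so metrics_values holds the ASR results.
print(metrics_values)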
