Skip to content

Commit

Permalink
add deepcopy in attack_pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
LukyanovKirillML committed Dec 3, 2024
1 parent 255f851 commit a974200
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
3 changes: 2 additions & 1 deletion experiments/attack_defense_metric_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import warnings

import torch
Expand Down Expand Up @@ -140,7 +141,7 @@ def attack_defense_metrics():
# print(metric_loc)

adm = FrameworkAttackDefenseManager(
gen_dataset=dataset,
gen_dataset=copy.deepcopy(dataset),
gnn_manager=gnn_model_manager,
)
# adm.evasion_attack_pipeline(
Expand Down
23 changes: 13 additions & 10 deletions src/models_builder/attack_defense_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import json
import os
import warnings
Expand Down Expand Up @@ -89,14 +90,15 @@ def evasion_attack_pipeline(
self.gnn_manager.modification.epochs = 0
self.gnn_manager.gnn.reset_parameters()
from models_builder.gnn_models import Metric
local_gen_dataset_copy = copy.deepcopy(self.gen_dataset)
self.gnn_manager.train_model(
gen_dataset=self.gen_dataset,
gen_dataset=local_gen_dataset_copy,
steps=steps,
save_model_flag=save_model_flag,
save_model_flag=False,
metrics=[Metric("F1", mask='train', average=None)]
)
y_predict_clean = self.gnn_manager.run_model(
gen_dataset=self.gen_dataset,
gen_dataset=local_gen_dataset_copy,
mask=mask,
out='logits',
)
Expand All @@ -105,17 +107,17 @@ def evasion_attack_pipeline(
self.gnn_manager.modification.epochs = 0
self.gnn_manager.gnn.reset_parameters()
self.gnn_manager.train_model(
gen_dataset=self.gen_dataset,
gen_dataset=local_gen_dataset_copy,
steps=steps,
save_model_flag=save_model_flag,
metrics=[Metric("F1", mask='train', average=None)]
)
self.gnn_manager.call_evasion_attack(
gen_dataset=self.gen_dataset,
gen_dataset=local_gen_dataset_copy,
mask=mask,
)
y_predict_attack = self.gnn_manager.run_model(
gen_dataset=self.gen_dataset,
gen_dataset=local_gen_dataset_copy,
mask=mask,
out='logits',
)
Expand Down Expand Up @@ -152,14 +154,15 @@ def poison_attack_pipeline(
self.gnn_manager.modification.epochs = 0
self.gnn_manager.gnn.reset_parameters()
from models_builder.gnn_models import Metric
local_gen_dataset_copy = copy.deepcopy(self.gen_dataset)
self.gnn_manager.train_model(
gen_dataset=self.gen_dataset,
gen_dataset=local_gen_dataset_copy,
steps=steps,
save_model_flag=False,
metrics=[Metric("F1", mask='train', average=None)]
)
y_predict_clean = self.gnn_manager.run_model(
gen_dataset=self.gen_dataset,
gen_dataset=local_gen_dataset_copy,
mask=mask,
out='logits',
)
Expand All @@ -168,13 +171,13 @@ def poison_attack_pipeline(
self.gnn_manager.modification.epochs = 0
self.gnn_manager.gnn.reset_parameters()
self.gnn_manager.train_model(
gen_dataset=self.gen_dataset,
gen_dataset=local_gen_dataset_copy,
steps=steps,
save_model_flag=save_model_flag,
metrics=[Metric("F1", mask='train', average=None)]
)
y_predict_attack = self.gnn_manager.run_model(
gen_dataset=self.gen_dataset,
gen_dataset=local_gen_dataset_copy,
mask=mask,
out='logits',
)
Expand Down

0 comments on commit a974200

Please sign in to comment.