Skip to content

Commit

Permalink
Add experiment pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeratt committed Dec 5, 2024
1 parent 1bf8166 commit 2a4d574
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 32 deletions.
112 changes: 80 additions & 32 deletions experiments/attack_defense_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import torch
import copy

import warnings

from torch import device

from models_builder.attack_defense_manager import FrameworkAttackDefenseManager
from models_builder.attack_defense_metric import AttackMetric, DefenseMetric
from models_builder.models_utils import apply_decorator_to_graph_layers
from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
EVASION_DEFENSE_PARAMETERS_PATH
Expand All @@ -17,15 +20,17 @@
from defense.GNNGuard import gnnguard


def test_attack_defense():
def test_attack_defense(d='Cora', m='gin_2', a='fgsm'):
my_device = device('cuda' if torch.cuda.is_available() else 'cpu')

full_name = None

# full_name = ("multiple-graphs", "TUDataset", 'MUTAG')
# full_name = ("single-graph", "custom", 'karate')
full_name = ("single-graph", "Planetoid", 'Cora')
# full_name = ("single-graph", "Amazon", 'Photo')
if d == 'Cora':
full_name = ("single-graph", "Planetoid", 'Cora')
elif d == 'Photo':
full_name = ("single-graph", "Amazon", 'Photo')
# full_name = ("single-graph", "Planetoid", 'CiteSeer')
# full_name = ("multiple-graphs", "TUDataset", 'PROTEINS')

Expand Down Expand Up @@ -60,7 +65,12 @@ def test_attack_defense():

# print(data.train_mask)

gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn')
if m == 'gcn_2':
gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn')
elif m == 'gcn_3':
gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn_gcn')
elif m == 'gin_2':
gnn = model_configs_zoo(dataset=dataset, model_name='gin_gin')
# gnn = model_configs_zoo(dataset=dataset, model_name='gat_gcn_sage_gcn_gcn')
# gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn_lin')
# gnn = model_configs_zoo(dataset=dataset, model_name='test_gnn')
Expand Down Expand Up @@ -102,8 +112,8 @@ def test_attack_defense():
modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs)
)

save_model_flag = False
# save_model_flag = True
# save_model_flag = False
save_model_flag = True

# data.x = data.x.float()
gnn_model_manager.gnn.to(my_device)
Expand Down Expand Up @@ -250,7 +260,16 @@ def test_attack_defense():
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
_config_class="EvasionAttackConfig",
_config_kwargs={
"epsilon": 0.1 * 1,
"epsilon": 0.01,
}
)

fgsm_evasion_attack_config1 = ConfigPattern(
_class_name="FGSM",
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
_config_class="EvasionAttackConfig",
_config_kwargs={
"epsilon": 0.1,
}
)
at_evasion_defense_config = ConfigPattern(
Expand All @@ -259,40 +278,58 @@ def test_attack_defense():
_config_class="EvasionDefenseConfig",
_config_kwargs={
"attack_name": None,
"attack_config": fgsm_evasion_attack_config0
"attack_config": fgsm_evasion_attack_config1
}
)

# gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config)
# gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config)
# gnn_model_manager.set_evasion_attacker(evasion_attack_config=fgsm_evasion_attack_config)
# gnn_model_manager.set_evasion_defender(evasion_defense_config=autoencoder_evasion_defense_config)
if a == 'fgsm':
gnn_model_manager.set_evasion_attacker(evasion_attack_config=fgsm_evasion_attack_config)
elif a == 'nettack':
gnn_model_manager.set_evasion_attacker(evasion_attack_config=netattack_evasion_attack_config)
gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config)

warnings.warn("Start training")
dataset.train_test_split()

try:
raise FileNotFoundError()
# gnn_model_manager.load_model_executor()
except FileNotFoundError:
gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0
train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs,
save_model_flag=save_model_flag,
metrics=[Metric("F1", mask='train', average=None)])

if train_test_split_path is not None:
dataset.save_train_test_mask(train_test_split_path)
train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[
:]
dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask
data.percent_train_class, data.percent_test_class = train_test_sizes

warnings.warn("Training was successful")
for i in range(20):
adm = FrameworkAttackDefenseManager(
gen_dataset=copy.deepcopy(dataset),
gnn_manager=gnn_model_manager,
)
adm.evasion_defense_pipeline(
steps=steps_epochs,
save_model_flag=save_model_flag,
metrics_attack=[AttackMetric("ASR"), AttackMetric("AuccAttackDiff"),],
metrics_defense=[DefenseMetric("AuccDefenseCleanDiff"), DefenseMetric("AuccDefenseAttackDiff"), ],
mask='test'
)

metric_loc = gnn_model_manager.evaluate_model(
gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro'),
Metric("Accuracy", mask='test')])
print(metric_loc)
#
# try:
# raise FileNotFoundError()
# # gnn_model_manager.load_model_executor()
# except FileNotFoundError:
# gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0
# train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs,
# save_model_flag=save_model_flag,
# metrics=[Metric("F1", mask='train', average=None)])
#
# if train_test_split_path is not None:
# dataset.save_train_test_mask(train_test_split_path)
# train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[
# :]
# dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask
# data.percent_train_class, data.percent_test_class = train_test_sizes
#
# warnings.warn("Training was successful")
#
# metric_loc = gnn_model_manager.evaluate_model(
# gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro'),
# Metric("Accuracy", mask='test')])
# print(metric_loc)


def test_meta():
Expand Down Expand Up @@ -1005,11 +1042,22 @@ def test_pgd():
print(f"After PGD attack on graph (MUTAG dataset): {info_after_pgd_attack_on_graph}")



def exp_pipeline():
dataset_grid = ['Photo', 'Cora']
model_grid = ['gcn_2', 'gcn_3', 'gin_2']
attack_grid = ['fgsm', 'nettack']
for d in dataset_grid:
for m in model_grid:
for a in attack_grid:
test_attack_defense(d, m, a)

if __name__ == '__main__':
import random

random.seed(10)
test_attack_defense()
#random.seed(10)
#test_attack_defense()
exp_pipeline()
# torch.manual_seed(5000)
# test_gnnguard()
# test_jaccard()
Expand Down
158 changes: 158 additions & 0 deletions src/models_builder/models_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,111 @@ def model_configs_zoo(
)
)

gin_gin = FrameworkGNNConstructor(
model_config=ModelConfig(
structure=ModelStructureConfig(
[
{
'label': 'n',
'layer': {
'layer_name': 'GINConv',
'layer_kwargs': None,
'gin_seq': [
{
'layer': {
'layer_name': 'Linear',
'layer_kwargs': {
'in_features': dataset.num_node_features,
'out_features': 16,
},
},
'batchNorm': {
'batchNorm_name': 'BatchNorm1d',
'batchNorm_kwargs': {
'num_features': 16,
'eps': 1e-05,
}
},
'activation': {
'activation_name': 'ReLU',
'activation_kwargs': None,
},
},
{
'layer': {
'layer_name': 'Linear',
'layer_kwargs': {
'in_features': 16,
'out_features': 16,
},
},
'batchNorm': {
'batchNorm_name': 'BatchNorm1d',
'batchNorm_kwargs': {
'num_features': 16,
'eps': 1e-05,
}
},
'activation': {
'activation_name': 'ReLU',
'activation_kwargs': None,
},
},
],
},
'activation': {
'activation_name': 'ReLU',
'activation_kwargs': None,
},
},

{
'label': 'n',
'layer': {
'layer_name': 'GINConv',
'layer_kwargs': None,
'gin_seq': [
{
'layer': {
'layer_name': 'Linear',
'layer_kwargs': {
'in_features': 16,
'out_features': 16,
},
},
'batchNorm': {
'batchNorm_name': 'BatchNorm1d',
'batchNorm_kwargs': {
'num_features': 16,
'eps': 1e-05,
}
},
'activation': {
'activation_name': 'ReLU',
'activation_kwargs': None,
},
},
{
'layer': {
'layer_name': 'Linear',
'layer_kwargs': {
'in_features': 16,
'out_features': dataset.num_classes,
},
},
},
],
},
'activation': {
'activation_name': 'LogSoftmax',
'activation_kwargs': None,
},
},
]
)
)
)

gat_gat = FrameworkGNNConstructor(
model_config=ModelConfig(
structure=ModelStructureConfig(
Expand Down Expand Up @@ -414,6 +519,59 @@ def model_configs_zoo(
)
)

gcn_gcn_gcn = FrameworkGNNConstructor(
model_config=ModelConfig(
structure=ModelStructureConfig(
[
{
'label': 'n',
'layer': {
'layer_name': 'GCNConv',
'layer_kwargs': {
'in_channels': dataset.num_node_features,
'out_channels': 16,
},
},
'activation': {
'activation_name': 'ReLU',
'activation_kwargs': None,
},
},

{
'label': 'n',
'layer': {
'layer_name': 'GCNConv',
'layer_kwargs': {
'in_channels': 16,
'out_channels': 16,
},
},
'activation': {
'activation_name': 'ReLU',
'activation_kwargs': None,
},
},

{
'label': 'n',
'layer': {
'layer_name': 'GCNConv',
'layer_kwargs': {
'in_channels': 16,
'out_channels': dataset.num_classes,
},
},
'activation': {
'activation_name': 'LogSoftmax',
'activation_kwargs': None,
},
},
]
)
)
)

gcn_gcn_no_self_loops = FrameworkGNNConstructor(
model_config=ModelConfig(
structure=ModelStructureConfig(
Expand Down

0 comments on commit 2a4d574

Please sign in to comment.