Skip to content

Commit

Permalink
make better files in attacks
Browse files Browse the repository at this point in the history
  • Loading branch information
LukyanovKirillML committed Nov 21, 2024
1 parent b326cb0 commit 0b82795
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 73 deletions.
23 changes: 19 additions & 4 deletions src/attacks/attack_base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
from typing import Type

from base.datasets_processing import DatasetManager


class Attacker:
name = "Attacker"

def __init__(self):
def __init__(
self
):
pass

def attack(self, **kwargs):
def attack(
self,
**kwargs
):
pass

def attack_diff(self):
def attack_diff(
self
):
pass

@staticmethod
def check_availability(gen_dataset, model_manager):
def check_availability(
gen_dataset: DatasetManager,
model_manager: Type
):
return False


150 changes: 110 additions & 40 deletions src/attacks/evasion_attacks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Type, Union

import torch
import torch.nn.functional as F
import numpy as np

from attacks.attack_base import Attacker
from base.datasets_processing import DatasetManager

# Nettack imports
from src.attacks.nettack.nettack import Nettack
Expand All @@ -11,31 +14,50 @@
# PGD imports
from attacks.evasion_attacks_collection.pgd.utils import Projection, RandomSampling
import torch.nn.functional as F
from torch_geometric.utils import to_dense_adj, dense_to_sparse, k_hop_subgraph
from torch_geometric.utils import k_hop_subgraph
from tqdm import tqdm
from torch_geometric.nn import SGConv


class EvasionAttacker(Attacker):
def __init__(self, **kwargs):
class EvasionAttacker(
Attacker
):
def __init__(
self,
**kwargs
):
super().__init__()


class EmptyEvasionAttacker(EvasionAttacker):
class EmptyEvasionAttacker(
EvasionAttacker
):
name = "EmptyEvasionAttacker"

def attack(self, **kwargs):
def attack(
self,
**kwargs
):
pass


class FGSMAttacker(EvasionAttacker):
class FGSMAttacker(
EvasionAttacker
):
name = "FGSM"

def __init__(self, epsilon=0.1):
def __init__(
self,
epsilon: float = 0.1
):
super().__init__()
self.epsilon = epsilon

def attack(self, model_manager, gen_dataset, mask_tensor):
def attack(
self,
model_manager: Type,
gen_dataset: DatasetManager,
mask_tensor: torch.Tensor
):
gen_dataset.data.x.requires_grad = True
output = model_manager.gnn(gen_dataset.data.x, gen_dataset.data.edge_index, gen_dataset.data.batch)
loss = model_manager.loss_function(output[mask_tensor],
Expand All @@ -49,16 +71,20 @@ def attack(self, model_manager, gen_dataset, mask_tensor):
return gen_dataset


class PGDAttacker(EvasionAttacker):
class PGDAttacker(
EvasionAttacker
):
name = "PGD"

def __init__(self,
is_feature_attack=False,
element_idx=0,
epsilon=0.5,
learning_rate=0.001,
num_iterations=100,
num_rand_trials=100):
def __init__(
self,
is_feature_attack: bool = False,
element_idx: int = 0,
epsilon: float = 0.5,
learning_rate: float = 0.001,
num_iterations: int = 100,
num_rand_trials: int = 100
):

super().__init__()
self.attack_diff = None
Expand All @@ -69,13 +95,22 @@ def __init__(self,
self.num_iterations = num_iterations
self.num_rand_trials = num_rand_trials

def attack(self, model_manager, gen_dataset, mask_tensor):
def attack(
self,
model_manager: Type,
gen_dataset: DatasetManager,
mask_tensor: torch.Tensor
) -> None:
if gen_dataset.is_multi():
self._attack_on_graph(model_manager, gen_dataset)
else:
self._attack_on_node(model_manager, gen_dataset)

def _attack_on_node(self, model_manager, gen_dataset):
def _attack_on_node(
self,
model_manager: Type,
gen_dataset: DatasetManager
) -> None:
node_idx = self.element_idx

edge_index = gen_dataset.data.edge_index
Expand Down Expand Up @@ -118,7 +153,11 @@ def _attack_on_node(self, model_manager, gen_dataset):
else: # structure attack
pass

def _attack_on_graph(self, model_manager, gen_dataset):
def _attack_on_graph(
self,
model_manager: Type,
gen_dataset: DatasetManager
):
graph_idx = self.element_idx

edge_index = gen_dataset.dataset[graph_idx].edge_index
Expand Down Expand Up @@ -149,21 +188,26 @@ def _attack_on_graph(self, model_manager, gen_dataset):
else: # structure attack
pass

def attack_diff(self):
def attack_diff(
self
):
return self.attack_diff


class NettackEvasionAttacker(EvasionAttacker):
class NettackEvasionAttacker(
EvasionAttacker
):
name = "NettackEvasionAttacker"

def __init__(self,
node_idx=0,
n_perturbations=None,
perturb_features=True,
perturb_structure=True,
direct=True,
n_influencers=0
):
def __init__(
self,
node_idx: int = 0,
n_perturbations: Union[int, None] = None,
perturb_features: bool = True,
perturb_structure: bool = True,
direct: bool = True,
n_influencers: int = 0
):

super().__init__()
self.attack_diff = None
Expand All @@ -174,7 +218,12 @@ def __init__(self,
self.direct = direct
self.n_influencers = n_influencers

def attack(self, model_manager, gen_dataset, mask_tensor):
def attack(
self,
model_manager: Type,
gen_dataset: DatasetManager,
mask_tensor: torch.Tensor
) -> DatasetManager:
# Prepare
data = gen_dataset.data
_A_obs, _X_obs, _z_obs = data_to_csr_matrix(data)
Expand Down Expand Up @@ -222,11 +271,17 @@ def attack(self, model_manager, gen_dataset, mask_tensor):

return gen_dataset

def attack_diff(self):
def attack_diff(
self
):
return self.attack_diff

@staticmethod
def _evasion(gen_dataset, feature_perturbations, structure_perturbations):
def _evasion(
gen_dataset: DatasetManager,
feature_perturbations,
structure_perturbations
):
cleaned_feat_pert = list(filter(None, feature_perturbations))
if cleaned_feat_pert: # list is not empty
x = gen_dataset.data.x.clone()
Expand All @@ -243,17 +298,27 @@ def _evasion(gen_dataset, feature_perturbations, structure_perturbations):
# add edges
for edge in cleaned_struct_pert:
edge_index = torch.cat((edge_index,
torch.tensor((edge[0], edge[1]), dtype=torch.int32).to(torch.int64).unsqueeze(1)), dim=1)
torch.tensor((edge[0], edge[1]), dtype=torch.int32).to(torch.int64).unsqueeze(
1)), dim=1)
edge_index = torch.cat((edge_index,
torch.tensor((edge[1], edge[0]), dtype=torch.int32).to(torch.int64).unsqueeze(1)), dim=1)
torch.tensor((edge[1], edge[0]), dtype=torch.int32).to(torch.int64).unsqueeze(
1)), dim=1)

gen_dataset.data.edge_index = edge_index

class NettackGroupEvasionAttacker(EvasionAttacker):

class NettackGroupEvasionAttacker(
EvasionAttacker
):
name = "NettackGroupEvasionAttacker"
def __init__(self,node_idxs, **kwargs):

def __init__(
self,
node_idxs: list,
**kwargs
):
super().__init__()
self.node_idxs = node_idxs # kwargs.get("node_idxs")
self.node_idxs = node_idxs # kwargs.get("node_idxs")
assert isinstance(self.node_idxs, list)
self.n_perturbations = kwargs.get("n_perturbations")
self.perturb_features = kwargs.get("perturb_features")
Expand All @@ -262,8 +327,13 @@ def __init__(self,node_idxs, **kwargs):
self.n_influencers = kwargs.get("n_influencers")
self.attacker = NettackEvasionAttacker(0, **kwargs)

def attack(self, model_manager, gen_dataset, mask_tensor):
def attack(
self,
model_manager: Type,
gen_dataset: DatasetManager,
mask_tensor: torch.Tensor
) -> DatasetManager:
for node_idx in self.node_idxs:
self.attacker.node_idx = node_idx
gen_dataset = self.attacker.attack(model_manager, gen_dataset, mask_tensor)
return gen_dataset
return gen_dataset
18 changes: 14 additions & 4 deletions src/attacks/mi_attacks.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
from attacks.attack_base import Attacker


class MIAttacker(Attacker):
def __init__(self, **kwargs):
class MIAttacker(
Attacker
):
def __init__(
self,
**kwargs
):
super().__init__()


class EmptyMIAttacker(MIAttacker):
class EmptyMIAttacker(
MIAttacker
):
name = "EmptyMIAttacker"

def attack(self, **kwargs):
def attack(
self,
**kwargs
):
pass
Loading

0 comments on commit 0b82795

Please sign in to comment.