Skip to content

Commit

Permalink
make better files in defense
Browse files Browse the repository at this point in the history
  • Loading branch information
LukyanovKirillML committed Nov 21, 2024
1 parent 87854b3 commit b326cb0
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 39 deletions.
18 changes: 15 additions & 3 deletions src/defense/defense_base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
from typing import Type

from base.datasets_processing import DatasetManager


class Defender:
name = "Defender"

def __init__(self):
def __init__(
self
):
pass

def defense_diff(self):
def defense_diff(
self
):
pass

@staticmethod
def check_availability(gen_dataset, model_manager):
def check_availability(
gen_dataset: DatasetManager,
model_manager: Type
):
return False


84 changes: 63 additions & 21 deletions src/defense/evasion_defense.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Type

import torch

from defense.defense_base import Defender
from src.aux.utils import import_by_name
from src.aux.configs import ModelModificationConfig, ConfigPattern
from src.aux.configs import ModelModificationConfig, ConfigPattern, EvasionAttackConfig
from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
EVASION_DEFENSE_PARAMETERS_PATH
from attacks.evasion_attacks import FGSMAttacker
Expand All @@ -12,24 +14,43 @@
import copy


class EvasionDefender(Defender):
def __init__(self, **kwargs):
class EvasionDefender(
Defender
):
def __init__(
self,
**kwargs
):
super().__init__()

def pre_batch(self, **kwargs):
def pre_batch(
self,
**kwargs
):
pass

def post_batch(self, **kwargs):
def post_batch(
self,
**kwargs
):
pass


class EmptyEvasionDefender(EvasionDefender):
class EmptyEvasionDefender(
EvasionDefender
):
name = "EmptyEvasionDefender"

def pre_batch(self, **kwargs):
def pre_batch(
self,
**kwargs
):
pass

def post_batch(self, **kwargs):
def post_batch(
self,
**kwargs
):
pass


Expand All @@ -52,27 +73,39 @@ def post_batch(self, model_manager, batch, loss, **kwargs):


# TODO Kirill, add code in pre_batch
class QuantizationDefender(EvasionDefender):
class QuantizationDefender(
EvasionDefender
):
name = "QuantizationDefender"

def __init__(self, qbit=8):
def __init__(
self,
qbit: int = 8
):
super().__init__()
self.regularization_strength = qbit

def pre_batch(self, **kwargs):
def pre_batch(
self,
**kwargs
):
# TODO Kirill
pass


class DataWrap:
def __init__(self, batch) -> None:
self.data = batch
self.dataset = self


class AdvTraining(EvasionDefender):
class AdvTraining(
EvasionDefender
):
# TODO Kirill, rewrite
name = "AdvTraining"

def __init__(self, attack_name=None, attack_config=None, attack_type=None, device='cpu'):
def __init__(
self,
attack_name: str = None,
attack_config: EvasionAttackConfig = None,
attack_type: str = None,
device: str = 'cpu'
):
super().__init__()
assert device is not None, "Please specify 'device'!"
if not attack_config:
Expand Down Expand Up @@ -118,7 +151,11 @@ def __init__(self, attack_name=None, attack_config=None, attack_type=None, devic
else:
raise KeyError(f"There is no {self.attack_config._class_name} class")

def pre_batch(self, model_manager, batch):
def pre_batch(
self,
model_manager: Type,
batch
):
super().pre_batch(model_manager=model_manager, batch=batch)
self.perturbed_gen_dataset = data.Data()
self.perturbed_gen_dataset.data = copy.deepcopy(batch)
Expand All @@ -129,7 +166,12 @@ def pre_batch(self, model_manager, batch):
gen_dataset=self.perturbed_gen_dataset,
mask_tensor=self.perturbed_gen_dataset.data.train_mask)

def post_batch(self, model_manager, batch, loss) -> dict:
def post_batch(
self,
model_manager: Type,
batch,
loss: torch.Tensor
) -> dict:
super().post_batch(model_manager=model_manager, batch=batch, loss=loss)
# Output on perturbed data
outputs = model_manager.gnn(self.perturbed_gen_dataset.data.x, self.perturbed_gen_dataset.data.edge_index)
Expand Down
29 changes: 23 additions & 6 deletions src/defense/mi_defense.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,39 @@
from defense.defense_base import Defender


class MIDefender(Defender):
def __init__(self, **kwargs):
class MIDefender(
Defender
):
def __init__(
self,
**kwargs
):
super().__init__()

def pre_batch(self, **kwargs):
def pre_batch(
self,
**kwargs
):
pass

def post_batch(self, **kwargs):
def post_batch(
self,
**kwargs
):
pass


class EmptyMIDefender(MIDefender):
name = "EmptyMIDefender"

def pre_batch(self, **kwargs):
def pre_batch(
self,
**kwargs
):
pass

def post_batch(self, **kwargs):
def post_batch(
self,
**kwargs
):
pass
42 changes: 33 additions & 9 deletions src/defense/poison_defense.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,43 @@
import numpy as np

from base.datasets_processing import DatasetManager
from defense.defense_base import Defender


class PoisonDefender(Defender):
def __init__(self, **kwargs):
class PoisonDefender(
Defender
):
def __init__(
self,
**kwargs
):
super().__init__()

def defense(self, **kwargs):
def defense(
self,
**kwargs
):
pass


class BadRandomPoisonDefender(PoisonDefender):
class BadRandomPoisonDefender(
PoisonDefender
):
name = "BadRandomPoisonDefender"

def __init__(self, n_edges_percent=0.1):
def __init__(
self,
n_edges_percent: float = 0.1
):
self.defense_diff = None

super().__init__()
self.n_edges_percent = n_edges_percent

def defense(self, gen_dataset):
def defense(
self,
gen_dataset: DatasetManager
) -> DatasetManager:
edge_index = gen_dataset.data.edge_index
random_indices = np.random.choice(
edge_index.shape[1],
Expand All @@ -35,12 +52,19 @@ def defense(self, gen_dataset):
self.defense_diff = edge_index_diff
return gen_dataset

def defense_diff(self):
def defense_diff(
self
):
return self.defense_diff


class EmptyPoisonDefender(PoisonDefender):
class EmptyPoisonDefender(
PoisonDefender
):
name = "EmptyPoisonDefender"

def defense(self, gen_dataset):
def defense(
self,
gen_dataset:DatasetManager
) -> DatasetManager:
return gen_dataset

0 comments on commit b326cb0

Please sign in to comment.