Skip to content

Commit

Permalink
add AutoEncoderDefender
Browse files Browse the repository at this point in the history
  • Loading branch information
LukyanovKirillML committed Nov 25, 2024
1 parent 9e51ade commit aeed176
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 1 deletion.
13 changes: 12 additions & 1 deletion experiments/attack_defense_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ def test_attack_defense():
}
)

autoencoder_evasion_defense_config = ConfigPattern(
_class_name="AutoEncoderDefender",
_import_path=EVASION_DEFENSE_PARAMETERS_PATH,
_config_class="EvasionDefenseConfig",
_config_kwargs={
"hidden_dim": 300,
"bottleneck_dim": 100,
"reconstruction_loss_weight": 0.1,
}
)

fgsm_evasion_attack_config0 = ConfigPattern(
_class_name="FGSM",
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
Expand All @@ -253,7 +264,7 @@ def test_attack_defense():
# gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config)
# gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config)
gnn_model_manager.set_evasion_attacker(evasion_attack_config=fgsm_evasion_attack_config)
gnn_model_manager.set_evasion_defender(evasion_defense_config=distillation_evasion_defense_config)
gnn_model_manager.set_evasion_defender(evasion_defense_config=autoencoder_evasion_defense_config)

warnings.warn("Start training")
dataset.train_test_split()
Expand Down
5 changes: 5 additions & 0 deletions metainfo/evasion_defense_parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
"DistillationDefender": {
"temperature": ["temperature", "float", 5.0, {"min": 1, "step": 0.01}, "?"]
},
"AutoEncoderDefender": {
"hidden_dim": ["hidden_dim", "int", 5, {"min": 3, "step": 1}, "?"],
"bottleneck_dim": ["bottleneck_dim", "int", 3, {"min": 1, "step": 1}, "?"],
"reconstruction_loss_weight": ["reconstruction_loss_weight", "float", 0.1, {"min": 0.0001, "step": 0.01}, "?"]
},
"AdvTraining": {
"attack_name": ["attack_name", "str", "FGSM", {}, "?"]
}
Expand Down
92 changes: 92 additions & 0 deletions src/defense/evasion_defense.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,95 @@ def post_batch(
outputs = model_manager.gnn(self.perturbed_gen_dataset.data.x, self.perturbed_gen_dataset.data.edge_index)
loss_loc = model_manager.loss_function(outputs, batch.y)
return {"loss": loss + loss_loc}


class SimpleAutoEncoder(
torch.nn.Module
):
def __init__(
self,
input_dim: int,
hidden_dim: int,
bottleneck_dim: int,
device: str = 'cpu'
):
"""
"""
super(SimpleAutoEncoder, self).__init__()
self.device = device
self.encoder = torch.nn.Sequential(
torch.nn.Linear(input_dim, hidden_dim),
torch.nn.ReLU(),
torch.nn.Linear(hidden_dim, bottleneck_dim),
torch.nn.ReLU()
).to(self.device)
self.decoder = torch.nn.Sequential(
torch.nn.Linear(bottleneck_dim, hidden_dim),
torch.nn.ReLU(),
torch.nn.Linear(hidden_dim, input_dim)
).to(self.device)

def forward(
self,
x: torch.Tensor
):
x = x.to(self.device)
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return decoded


class AutoEncoderDefender(
EvasionDefender
):
name = "AutoEncoderDefender"

def __init__(
self,
hidden_dim: int,
bottleneck_dim: int,
reconstruction_loss_weight=0.1,
):
"""
"""
super().__init__()
self.autoencoder = None
self.hidden_dim = hidden_dim
self.bottleneck_dim = bottleneck_dim
self.reconstruction_loss_weight = reconstruction_loss_weight

def post_batch(self, model_manager, batch, loss):
"""
"""
model_manager.gnn.eval()
if self.autoencoder is None:
self.init_autoencoder(batch.x)
self.autoencoder.train()
reconstructed_x = self.autoencoder(batch.x)
reconstruction_loss = torch.nn.functional.mse_loss(reconstructed_x, batch.x)
modified_loss = loss + self.reconstruction_loss_weight * reconstruction_loss.detach().clone()
autoencoder_optimizer = torch.optim.Adam(self.autoencoder.parameters(), lr=0.001)
autoencoder_optimizer.zero_grad()
reconstruction_loss.backward()
autoencoder_optimizer.step()
return {"loss": modified_loss}

def denoise_with_autoencoder(self, x):
"""
"""
self.autoencoder.eval()
with torch.no_grad():
x_denoised = self.autoencoder(x)
return x_denoised

def init_autoencoder(
self,
x
):
self.autoencoder = SimpleAutoEncoder(
input_dim=x.shape[1],
bottleneck_dim=self.bottleneck_dim,
hidden_dim=self.hidden_dim,
device=x.device
)

0 comments on commit aeed176

Please sign in to comment.