diff --git a/experiments/attack_defense_test.py b/experiments/attack_defense_test.py
index eb8cf4e..8b386d5 100644
--- a/experiments/attack_defense_test.py
+++ b/experiments/attack_defense_test.py
@@ -232,6 +232,17 @@ def test_attack_defense():
         }
     )
 
+    autoencoder_evasion_defense_config = ConfigPattern(
+        _class_name="AutoEncoderDefender",
+        _import_path=EVASION_DEFENSE_PARAMETERS_PATH,
+        _config_class="EvasionDefenseConfig",
+        _config_kwargs={
+            "hidden_dim": 300,
+            "bottleneck_dim": 100,
+            "reconstruction_loss_weight": 0.1,
+        }
+    )
+
     fgsm_evasion_attack_config0 = ConfigPattern(
         _class_name="FGSM",
         _import_path=EVASION_ATTACK_PARAMETERS_PATH,
@@ -253,7 +264,7 @@ def test_attack_defense():
     # gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config)
     # gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config)
     gnn_model_manager.set_evasion_attacker(evasion_attack_config=fgsm_evasion_attack_config)
-    gnn_model_manager.set_evasion_defender(evasion_defense_config=distillation_evasion_defense_config)
+    gnn_model_manager.set_evasion_defender(evasion_defense_config=autoencoder_evasion_defense_config)
 
     warnings.warn("Start training")
     dataset.train_test_split()
diff --git a/metainfo/evasion_defense_parameters.json b/metainfo/evasion_defense_parameters.json
index e9ff771..da30af1 100644
--- a/metainfo/evasion_defense_parameters.json
+++ b/metainfo/evasion_defense_parameters.json
@@ -10,6 +10,11 @@
   "DistillationDefender": {
     "temperature": ["temperature", "float", 5.0, {"min": 1, "step": 0.01}, "?"]
   },
+  "AutoEncoderDefender": {
+    "hidden_dim": ["hidden_dim", "int", 5, {"min": 3, "step": 1}, "?"],
+    "bottleneck_dim": ["bottleneck_dim", "int", 3, {"min": 1, "step": 1}, "?"],
+    "reconstruction_loss_weight": ["reconstruction_loss_weight", "float", 0.1, {"min": 0.0001, "step": 0.01}, "?"]
+  },
   "AdvTraining": {
     "attack_name": ["attack_name", "str", "FGSM", {}, "?"]
   }
diff --git a/src/defense/evasion_defense.py b/src/defense/evasion_defense.py
index 2eea79e..003c7b2 100644
--- a/src/defense/evasion_defense.py
+++ b/src/defense/evasion_defense.py
@@ -234,3 +234,108 @@ def post_batch(
         outputs = model_manager.gnn(self.perturbed_gen_dataset.data.x, self.perturbed_gen_dataset.data.edge_index)
         loss_loc = model_manager.loss_function(outputs, batch.y)
         return {"loss": loss + loss_loc}
+
+
+class SimpleAutoEncoder(torch.nn.Module):
+    def __init__(
+            self,
+            input_dim: int,
+            hidden_dim: int,
+            bottleneck_dim: int,
+            device: str = 'cpu'
+    ):
+        """
+        Fully connected autoencoder (input_dim -> hidden_dim ->
+        bottleneck_dim -> hidden_dim -> input_dim) for reconstructing
+        node features.
+        """
+        super().__init__()
+        self.device = device
+        self.encoder = torch.nn.Sequential(
+            torch.nn.Linear(input_dim, hidden_dim),
+            torch.nn.ReLU(),
+            torch.nn.Linear(hidden_dim, bottleneck_dim),
+            torch.nn.ReLU()
+        ).to(self.device)
+        self.decoder = torch.nn.Sequential(
+            torch.nn.Linear(bottleneck_dim, hidden_dim),
+            torch.nn.ReLU(),
+            torch.nn.Linear(hidden_dim, input_dim)
+        ).to(self.device)
+
+    def forward(
+            self,
+            x: torch.Tensor
+    ) -> torch.Tensor:
+        x = x.to(self.device)
+        encoded = self.encoder(x)
+        decoded = self.decoder(encoded)
+        return decoded
+
+
+class AutoEncoderDefender(EvasionDefender):
+    name = "AutoEncoderDefender"
+
+    def __init__(
+            self,
+            hidden_dim: int,
+            bottleneck_dim: int,
+            reconstruction_loss_weight: float = 0.1,
+    ):
+        """
+        Evasion defense that trains an autoencoder on node features
+        alongside the GNN. The trained autoencoder can then be used to
+        denoise potentially perturbed features.
+        """
+        super().__init__()
+        # Both are built lazily on the first batch, once the feature
+        # dimensionality is known.
+        self.autoencoder = None
+        self.autoencoder_optimizer = None
+        self.hidden_dim = hidden_dim
+        self.bottleneck_dim = bottleneck_dim
+        self.reconstruction_loss_weight = reconstruction_loss_weight
+
+    def post_batch(self, model_manager, batch, loss) -> dict:
+        """
+        Runs one optimization step of the autoencoder on the batch
+        features and reports the combined loss.
+        """
+        model_manager.gnn.eval()
+        if self.autoencoder is None:
+            self.init_autoencoder(batch.x)
+        self.autoencoder.train()
+        reconstructed_x = self.autoencoder(batch.x)
+        reconstruction_loss = torch.nn.functional.mse_loss(reconstructed_x, batch.x)
+        # The reconstruction term is detached: it is added to the reported
+        # loss only for monitoring, while the autoencoder itself is updated
+        # by its own optimizer below.
+        modified_loss = loss + self.reconstruction_loss_weight * reconstruction_loss.detach().clone()
+        self.autoencoder_optimizer.zero_grad()
+        reconstruction_loss.backward()
+        self.autoencoder_optimizer.step()
+        return {"loss": modified_loss}
+
+    def denoise_with_autoencoder(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Passes (possibly perturbed) features through the trained
+        autoencoder and returns the reconstruction.
+        """
+        self.autoencoder.eval()
+        with torch.no_grad():
+            x_denoised = self.autoencoder(x)
+        return x_denoised
+
+    def init_autoencoder(self, x: torch.Tensor) -> None:
+        """
+        Builds the autoencoder and its optimizer from the shape and device
+        of the first feature batch. Creating the optimizer once (rather
+        than on every batch) preserves Adam's moment estimates.
+        """
+        self.autoencoder = SimpleAutoEncoder(
+            input_dim=x.shape[1],
+            hidden_dim=self.hidden_dim,
+            bottleneck_dim=self.bottleneck_dim,
+            device=x.device
+        )
+        self.autoencoder_optimizer = torch.optim.Adam(self.autoencoder.parameters(), lr=0.001)
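
A minimal smoke-test sketch of the new defender, not part of the patch itself: it builds the autoencoder from a random feature matrix and checks that denoising preserves the feature shape. The Cora-like dimensions and the random tensor are illustrative stand-ins for a real dataset's node features, and the import path simply mirrors the file location in the patch.

import torch

from src.defense.evasion_defense import AutoEncoderDefender

defender = AutoEncoderDefender(hidden_dim=300, bottleneck_dim=100,
                               reconstruction_loss_weight=0.1)
x = torch.rand(2708, 1433)           # stand-in for a Cora-sized feature matrix
defender.init_autoencoder(x)         # builds the AE and its optimizer from x
x_denoised = defender.denoise_with_autoencoder(x)
assert x_denoised.shape == x.shape   # reconstruction keeps feature dimensions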