From 6c694c37917d4da5a5b9f71d6ebb469ef7d3cc90 Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Mon, 25 Nov 2024 15:53:35 +0300 Subject: [PATCH] + --- src/defense/evasion_defense.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/defense/evasion_defense.py b/src/defense/evasion_defense.py index 003c7b2..52a9a0f 100644 --- a/src/defense/evasion_defense.py +++ b/src/defense/evasion_defense.py @@ -281,7 +281,7 @@ def __init__( self, hidden_dim: int, bottleneck_dim: int, - reconstruction_loss_weight=0.1, + reconstruction_loss_weight: float = 0.1, ): """ """ @@ -291,7 +291,12 @@ def __init__( self.bottleneck_dim = bottleneck_dim self.reconstruction_loss_weight = reconstruction_loss_weight - def post_batch(self, model_manager, batch, loss): + def post_batch( + self, + model_manager, + batch, + loss: torch.Tensor + ) -> dict: """ """ model_manager.gnn.eval() @@ -307,7 +312,10 @@ def post_batch(self, model_manager, batch, loss): autoencoder_optimizer.step() return {"loss": modified_loss} - def denoise_with_autoencoder(self, x): + def denoise_with_autoencoder( + self, + x: torch.Tensor + ) -> torch.Tensor: """ """ self.autoencoder.eval() @@ -317,12 +325,11 @@ def denoise_with_autoencoder(self, x): def init_autoencoder( self, - x - ): + x: torch.Tensor + ) -> None: self.autoencoder = SimpleAutoEncoder( input_dim=x.shape[1], bottleneck_dim=self.bottleneck_dim, hidden_dim=self.hidden_dim, device=x.device ) -