Skip to content

Commit

Permalink
+
Browse files Browse the repository at this point in the history
  • Loading branch information
LukyanovKirillML committed Nov 25, 2024
1 parent aeed176 commit 6c694c3
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions src/defense/evasion_defense.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def __init__(
self,
hidden_dim: int,
bottleneck_dim: int,
reconstruction_loss_weight=0.1,
reconstruction_loss_weight: float = 0.1,
):
"""
"""
Expand All @@ -291,7 +291,12 @@ def __init__(
self.bottleneck_dim = bottleneck_dim
self.reconstruction_loss_weight = reconstruction_loss_weight

def post_batch(self, model_manager, batch, loss):
def post_batch(
self,
model_manager,
batch,
loss: torch.Tensor
) -> dict:
"""
"""
model_manager.gnn.eval()
Expand All @@ -307,7 +312,10 @@ def post_batch(self, model_manager, batch, loss):
autoencoder_optimizer.step()
return {"loss": modified_loss}

def denoise_with_autoencoder(self, x):
def denoise_with_autoencoder(
self,
x: torch.Tensor
) -> torch.Tensor:
"""
"""
self.autoencoder.eval()
Expand All @@ -317,12 +325,11 @@ def denoise_with_autoencoder(self, x):

def init_autoencoder(
self,
x
):
x: torch.Tensor
) -> None:
self.autoencoder = SimpleAutoEncoder(
input_dim=x.shape[1],
bottleneck_dim=self.bottleneck_dim,
hidden_dim=self.hidden_dim,
device=x.device
)

0 comments on commit 6c694c3

Please sign in to comment.