From 9ec13bfba8d6180090f8650dcae3aced0a641349 Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Tue, 23 Jul 2024 18:15:25 +0300 Subject: [PATCH] fix rename problem --- src/models_builder/gnn_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index 9636f67..975aa5e 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -803,9 +803,9 @@ def train_1_step(self, gen_dataset): def train_on_batch(self, batch, task_type=None): if self.mi_defender: - self.mi_defender.pre_epoch() + self.mi_defender.pre_batch() if self.evasion_defender: - self.evasion_defender.pre_epoch(model_manager=self, batch=batch) + self.evasion_defender.pre_batch(model_manager=self, batch=batch) loss = None if task_type == "single-graph": self.optimizer.zero_grad() @@ -840,7 +840,7 @@ def train_on_batch(self, batch, task_type=None): else: raise ValueError("Unsupported task type") if self.mi_defender: - self.mi_defender.post_epoch() + self.mi_defender.post_batch() evasion_defender_dict = None if self.evasion_defender: evasion_defender_dict = self.evasion_defender.post_batch(