diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index ca7f642..72b6973 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -835,7 +835,7 @@ def train_1_step(self, gen_dataset): self.gnn.eval() return loss.cpu().detach().numpy().tolist() - def train_on_batch(self, batch, task_type): + def train_on_batch(self, batch, task_type=None): loss = None if task_type == "single-graph": self.optimizer.zero_grad()