From 2e4846dfd3ba3b16a69664825018ae2d63a2a36d Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Thu, 18 Jul 2024 14:39:02 +0300 Subject: [PATCH] add early_stopping --- src/models_builder/gnn_models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index 72b6973..7653812 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -795,11 +795,16 @@ def train_complete(self, gen_dataset, steps=None, pbar=None, metrics=None, **kwa print("epoch", self.modification.epochs) train_loss = self.train_1_step(gen_dataset) self._after_epoch(gen_dataset) + self.early_stopping(train_loss=train_loss, gen_dataset=gen_dataset, + metrics=metrics) if self.socket: self.report_results(train_loss=train_loss, gen_dataset=gen_dataset, metrics=metrics) pbar.update(1) + def early_stopping(self, train_loss, gen_dataset, metrics): + pass + def train_1_step(self, gen_dataset): task_type = gen_dataset.domain() if self.mi_defender: