From 31a290fbd37c93c872088bbcf0539a361b960cd4 Mon Sep 17 00:00:00 2001 From: Sazonov_ISP Date: Thu, 5 Sep 2024 12:16:31 +0300 Subject: [PATCH] ProtGNN fix for new train. WIP --- src/models_builder/gnn_models.py | 50 +++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 10 deletions(-) diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index 975aa5e..a459717 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -1152,14 +1152,14 @@ class ProtGNNModelManager(FrameworkGNNModelManager): _config_class="ModelManagerConfig", _config_kwargs={ "mask_features": [], - # "optimizer": { - # "_config_class": "Config", - # "_class_name": "Adam", - # "_import_path": OPTIMIZERS_PARAMETERS_PATH, - # "_class_import_info": ["torch.optim"], - # "_config_kwargs": {}, - # }, - # FUNCTIONS_PARAMETERS_PATH, + "optimizer": { + "_config_class": "Config", + "_class_name": "Adam", + "_import_path": OPTIMIZERS_PARAMETERS_PATH, + "_class_import_info": ["torch.optim"], + "_config_kwargs": {}, + }, + #FUNCTIONS_PARAMETERS_PATH, "loss_function": { "_config_class": "Config", "_class_name": "CrossEntropyLoss", @@ -1234,7 +1234,7 @@ def evaluate_model(self, gen_dataset, metrics): return metrics_values - def train_full(self, gen_dataset, steps=None, metrics=None): + def train_full(self, gen_dataset, steps=None, metrics=None, pbar=None): """ Train ProtGNN model for Graph classification """ @@ -1294,6 +1294,9 @@ def train_full(self, gen_dataset, steps=None, metrics=None): best_prots = prot_layer.prototype_graphs # data_indices = train_loader.dataset.indices for step in range(steps): + self.before_epoch(gen_dataset) + print("epoch", self.modification.epochs) + acc = [] precision = [] recall = [] @@ -1315,8 +1318,8 @@ def train_full(self, gen_dataset, steps=None, metrics=None): p.requires_grad = True for batch in train_loader: - min_distances = self.gnn.min_distances logits = self.gnn(batch.x, batch.edge_index, batch.batch) + min_distances = self.gnn.min_distances loss = self.loss_function(logits, batch.y) # cluster loss prot_layer.prototype_class_identity = prot_layer.prototype_class_identity @@ -1421,6 +1424,16 @@ def train_full(self, gen_dataset, steps=None, metrics=None): """ self.modification.epochs = step + 1 + self.after_epoch(gen_dataset) + early_stopping_flag = self.early_stopping(train_loss=np.average(loss_list), gen_dataset=gen_dataset, + metrics=metrics) + if self.socket: + self.report_results(train_loss=np.average(loss_list), gen_dataset=gen_dataset, + metrics=metrics) + pbar.update(1) + if early_stopping_flag: + break + print(f"The best validation accuracy is {best_acc}.") # report test msg # checkpoint = torch.load(os.path.join(ckpt_dir, f'{model_args.model_name}_best.pth')) @@ -1432,6 +1445,23 @@ def train_full(self, gen_dataset, steps=None, metrics=None): return best_acc + def train_complete(self, gen_dataset, steps=None, pbar=None, metrics=None, **kwargs): + print("TEST TEST TEST") + self.train_full(gen_dataset=gen_dataset, steps=steps, pbar=pbar, metrics=metrics) + # for _ in range(steps): + # self.before_epoch(gen_dataset) + # print("epoch", self.modification.epochs) + # train_loss = self.train_1_step(gen_dataset) + # self.after_epoch(gen_dataset) + # early_stopping_flag = self.early_stopping(train_loss=train_loss, gen_dataset=gen_dataset, + # metrics=metrics) + # if self.socket: + # self.report_results(train_loss=train_loss, gen_dataset=gen_dataset, + # metrics=metrics) + # pbar.update(1) + # if early_stopping_flag: + # break + def run_model(self, gen_dataset, mask='test', out='answers'): """ Run the model on a part of dataset specified with a mask.