From e7053aa2ff7b29bfa3281d79a9ca076213384229 Mon Sep 17 00:00:00 2001 From: Sazonov_ISP Date: Thu, 5 Sep 2024 14:04:22 +0300 Subject: [PATCH] ProtGNN fix for new train --- tests/models_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/models_test.py b/tests/models_test.py index f123d06..8217376 100644 --- a/tests/models_test.py +++ b/tests/models_test.py @@ -72,7 +72,7 @@ def setUp(self) -> None: labeling='binary', dataset_ver_ind=0) ) - self.gen_dataset_mg_small.train_test_split(percent_train_class=0.6, percent_test_class=0.4) + self.gen_dataset_mg_small.train_test_split(percent_train_class=0.6, percent_test_class=0.2) self.results_dataset_path_mg_small = self.gen_dataset_mg_small.results_dir self.default_config = ModelModificationConfig( model_ver_ind=0, @@ -124,7 +124,8 @@ def test_model_on_multiple_graph(self): ) gnn_mm_mg_small.train_model(gen_dataset=self.gen_dataset_mg_small, steps=100, - metrics=[Metric("F1", mask='test')]) + metrics=[Metric("F1", mask='val'), + Metric("F1", mask='test')]) metric_loc = gnn_mm_mg_small.evaluate_model( gen_dataset=self.gen_dataset_mg_small, metrics=[Metric("F1", mask='test', average='macro')]) print(metric_loc)