Skip to content

Commit

Permalink
ProtGNN fix for new train
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeratt committed Sep 5, 2024
1 parent 31a290f commit e7053aa
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tests/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def setUp(self) -> None:
labeling='binary',
dataset_ver_ind=0)
)
self.gen_dataset_mg_small.train_test_split(percent_train_class=0.6, percent_test_class=0.4)
self.gen_dataset_mg_small.train_test_split(percent_train_class=0.6, percent_test_class=0.2)
self.results_dataset_path_mg_small = self.gen_dataset_mg_small.results_dir
self.default_config = ModelModificationConfig(
model_ver_ind=0,
Expand Down Expand Up @@ -124,7 +124,8 @@ def test_model_on_multiple_graph(self):
)

gnn_mm_mg_small.train_model(gen_dataset=self.gen_dataset_mg_small, steps=100,
metrics=[Metric("F1", mask='test')])
metrics=[Metric("F1", mask='val'),
Metric("F1", mask='test')])
metric_loc = gnn_mm_mg_small.evaluate_model(
gen_dataset=self.gen_dataset_mg_small, metrics=[Metric("F1", mask='test', average='macro')])
print(metric_loc)
Expand Down

0 comments on commit e7053aa

Please sign in to comment.