diff --git a/tests/models_test.py b/tests/models_test.py index 12dc78e..c4e117e 100644 --- a/tests/models_test.py +++ b/tests/models_test.py @@ -115,6 +115,26 @@ def test_combo_model_on_single_graph(self): sg_example_model_path = gnn_model_manager_sg_example.model_path_info() / 'model' gnn_model_manager_sg_example.load_model_executor(path=sg_example_model_path) + def test_combo_model_differ_acts_on_single_graph(self): + gat_gcn_sage_gcn_gcn = model_configs_zoo(dataset=self.gen_dataset_sg_example, model_name="gat_gcn_sage_gcn_gcn") + + gnn_model_manager_sg_example = FrameworkGNNModelManager( + gnn=gat_gcn_sage_gcn_gcn, + dataset_path=self.results_dataset_path_sg_example, + modification=self.default_config, + manager_config=self.manager_config, + ) + + gnn_model_manager_sg_example.train_model(gen_dataset=self.gen_dataset_sg_example, steps=50, + save_model_flag=True, + metrics=[Metric("F1", mask='test')]) + metric_loc = gnn_model_manager_sg_example.evaluate_model( + gen_dataset=self.gen_dataset_sg_example, metrics=[Metric("F1", mask='test', )]) + print(metric_loc) + + sg_example_model_path = gnn_model_manager_sg_example.model_path_info() / 'model' + gnn_model_manager_sg_example.load_model_executor(path=sg_example_model_path) + def test_model_on_multiple_graph(self): gin3_lin2_mg_small = model_configs_zoo(dataset=self.gen_dataset_mg_small, model_name='gin_gin_gin_lin_lin')