Skip to content

Commit

Permalink
tmp comment work and tests with Prot
Browse files Browse the repository at this point in the history
  • Loading branch information
LukyanovKirillML committed Jul 18, 2024
1 parent 146f328 commit b7bd4df
Showing 1 changed file with 39 additions and 38 deletions.
77 changes: 39 additions & 38 deletions tests/explainers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,9 @@ def setUp(self) -> None:
save_model_flag=False,
metrics=[Metric("F1", mask='test')])

gin3_lin2_prot_mg_small = model_configs_zoo(
dataset=dataset_mg_small, model_name='gin_gin_gin_lin_lin_prot')
# TODO Kirill, tmp comment work and tests with Prot
# gin3_lin2_prot_mg_small = model_configs_zoo(
# dataset=dataset_mg_small, model_name='gin_gin_gin_lin_lin_prot')
gin3_lin1_mg_mutag = model_configs_zoo(
dataset=dataset_mg_mutag, model_name='gin_gin_gin_lin')

Expand Down Expand Up @@ -155,14 +156,14 @@ def setUp(self) -> None:
}
)

self.prot_gnn_mm_mg_small = ProtGNNModelManager(
gnn=gin3_lin2_prot_mg_small, dataset_path=results_dataset_path_mg_small,
# manager_config=gin3_lin2_mg_small_manager_config,
)
# self.prot_gnn_mm_mg_small = ProtGNNModelManager(
# gnn=gin3_lin2_prot_mg_small, dataset_path=results_dataset_path_mg_small,
# # manager_config=gin3_lin2_mg_small_manager_config,
# )
# TODO Misha use as training params: clst=clst, sep=sep, save_thrsh=save_thrsh, lr=lr

best_acc = self.prot_gnn_mm_mg_small.train_model(
gen_dataset=gen_dataset_mg_small, steps=100, metrics=[])
# best_acc = self.prot_gnn_mm_mg_small.train_model(
# gen_dataset=gen_dataset_mg_small, steps=100, metrics=[])

gin3_lin2_mg_small = model_configs_zoo(
dataset=gen_dataset_mg_small, model_name='gin_gin_gin_lin_lin')
Expand Down Expand Up @@ -326,36 +327,36 @@ def test_Zorro(self):
)
explainer_Zorro.conduct_experiment(explainer_run_config)

def test_ProtGNN(self):
warnings.warn("Start ProtGNN")
explainer_init_config = ConfigPattern(
_class_name="ProtGNN",
_import_path=EXPLAINERS_INIT_PARAMETERS_PATH,
_config_class="ExplainerInitConfig",
_config_kwargs={
}
)
explainer_run_config = ConfigPattern(
_config_class="ExplainerRunConfig",
_config_kwargs={
"mode": "global",
"kwargs": {
"_class_name": "ProtGNN",
"_import_path": EXPLAINERS_GLOBAL_RUN_PARAMETERS_PATH,
"_config_class": "Config",
"_config_kwargs": {

},
}
}
)
explainer_Prot = FrameworkExplainersManager(
init_config=explainer_init_config,
dataset=self.dataset_mg_small, gnn_manager=self.prot_gnn_mm_mg_small,
explainer_name='ProtGNN',
)

explainer_Prot.conduct_experiment(explainer_run_config)
# def test_ProtGNN(self):
# warnings.warn("Start ProtGNN")
# explainer_init_config = ConfigPattern(
# _class_name="ProtGNN",
# _import_path=EXPLAINERS_INIT_PARAMETERS_PATH,
# _config_class="ExplainerInitConfig",
# _config_kwargs={
# }
# )
# explainer_run_config = ConfigPattern(
# _config_class="ExplainerRunConfig",
# _config_kwargs={
# "mode": "global",
# "kwargs": {
# "_class_name": "ProtGNN",
# "_import_path": EXPLAINERS_GLOBAL_RUN_PARAMETERS_PATH,
# "_config_class": "Config",
# "_config_kwargs": {
#
# },
# }
# }
# )
# explainer_Prot = FrameworkExplainersManager(
# init_config=explainer_init_config,
# dataset=self.dataset_mg_small, gnn_manager=self.prot_gnn_mm_mg_small,
# explainer_name='ProtGNN',
# )
#
# explainer_Prot.conduct_experiment(explainer_run_config)

def test_GNNExpl_PYG_SG(self):
warnings.warn("Start GNNExplainer_PYG")
Expand Down

0 comments on commit b7bd4df

Please sign in to comment.