From a22c1d140b1a7d4c624605bf03e39088d83d97d7 Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Mon, 28 Oct 2024 19:34:27 +0300 Subject: [PATCH] fix load --- experiments/GNNExplainerGEOM_exp_example.py | 13 +++++----- src/aux/declaration.py | 14 ++++++++++- src/models_builder/gnn_models.py | 27 +++++++++++++++------ 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/experiments/GNNExplainerGEOM_exp_example.py b/experiments/GNNExplainerGEOM_exp_example.py index 6b5771c..becf7fe 100644 --- a/experiments/GNNExplainerGEOM_exp_example.py +++ b/experiments/GNNExplainerGEOM_exp_example.py @@ -23,13 +23,12 @@ def geom_GNNExplainer_test(): gcn2 = model_configs_zoo(dataset=dataset, model_name='gcn_gcn') - gnn_model_manager_config = ModelManagerConfig(**{ - "mask_features": [], - "optimizer": { - "_class_name": "Adam", - "_config_kwargs": {}, - } - }) + gnn_model_manager_config = ConfigPattern( + _config_class="ModelManagerConfig", + _config_kwargs={ + "mask_features": [] + } + ) steps_epochs = 200 gnn_model_manager = FrameworkGNNModelManager( diff --git a/src/aux/declaration.py b/src/aux/declaration.py index 504af88..7988911 100644 --- a/src/aux/declaration.py +++ b/src/aux/declaration.py @@ -183,6 +183,12 @@ def declare_model_by_config( GNNModelManager_hash: str, model_ver_ind: int, gnn_name: str, + mi_defense_hash: str, + evasion_defense_hash: str, + poison_defense_hash: str, + mi_attack_hash: str, + evasion_attack_hash: str, + poison_attack_hash: str, epochs=None, ): """ @@ -202,8 +208,14 @@ def declare_model_by_config( obj_info = { "gnn": gnn_name, "gnn_model_manager": GNNModelManager_hash, - "epochs": str(epochs), "model_ver_ind": str(model_ver_ind), + "poison_attacker": str(poison_attack_hash), + "poison_defender": str(poison_defense_hash), + "evasion_defender": str(evasion_defense_hash), + "mi_defender": str(mi_defense_hash), + "evasion_attacker": str(evasion_attack_hash), + "mi_attacker": str(mi_attack_hash), + "epochs": str(epochs), } path, files_paths = Declare.obj_info_to_path(previous_path=path, what_save=what_save, diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index 205bf7c..1fad619 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -15,7 +15,8 @@ from aux.configs import ModelManagerConfig, ModelModificationConfig, ModelConfig, CONFIG_CLASS_NAME from aux.data_info import UserCodeInfo -from aux.utils import import_by_name, all_subclasses, FRAMEWORK_PARAMETERS_PATH, model_managers_info_by_names_list, hash_data_sha256, \ +from aux.utils import import_by_name, all_subclasses, FRAMEWORK_PARAMETERS_PATH, model_managers_info_by_names_list, \ + hash_data_sha256, \ TECHNICAL_PARAMETER_KEY, IMPORT_INFO_KEY, OPTIMIZERS_PARAMETERS_PATH, FUNCTIONS_PARAMETERS_PATH from aux.declaration import Declare from explainers.explainer import ProgressBar @@ -247,7 +248,13 @@ class variables model_ver_ind=kwargs.get('model_ver_ind') if 'model_ver_ind' in kwargs else self.modification.model_ver_ind, epochs=self.modification.epochs, - gnn_name=self.gnn.get_hash() + gnn_name=self.gnn.get_hash(), + mi_defense_hash=self.mi_defense_config.hash_for_config(), + evasion_defense_hash=self.evasion_defense_config.hash_for_config(), + poison_defense_hash=self.poison_defense_config.hash_for_config(), + mi_attack_hash=self.mi_attack_config.hash_for_config(), + evasion_attack_hash=self.evasion_attack_config.hash_for_config(), + poison_attack_hash=self.poison_attack_config.hash_for_config(), ) path = model_dir_path / 'model' else: @@ -556,6 +563,12 @@ def from_model_path(model_path, dataset_path, **kwargs): epochs=int(model_path['epochs']) if model_path['epochs'] != 'None' else None, model_ver_ind=int(model_path['model_ver_ind']), gnn_name=model_path['gnn'], + poison_attack_hash=model_path['poison_attacker'], + poison_defense_hash=model_path['poison_defender'], + evasion_defense_hash=model_path['evasion_defender'], + mi_defense_hash=model_path['mi_defender'], + evasion_attack_hash=model_path['evasion_attacker'], + mi_attack_hash=model_path['mi_attacker'], ) gnn_mm_file = files_paths[1] @@ -1182,7 +1195,7 @@ class ProtGNNModelManager(FrameworkGNNModelManager): "_class_import_info": ["torch.optim"], "_config_kwargs": {}, }, - #FUNCTIONS_PARAMETERS_PATH, + # FUNCTIONS_PARAMETERS_PATH, "loss_function": { "_config_class": "Config", "_class_name": "CrossEntropyLoss", @@ -1201,7 +1214,7 @@ def __init__(self, gnn=None, dataset_path=None, **kwargs): _config_obj = getattr(self.manager_config, CONFIG_OBJ) self.clst = _config_obj.clst self.sep = _config_obj.sep - #lr = _config_obj.lr + # lr = _config_obj.lr self.early_stopping_marker = _config_obj.early_stopping self.proj_epochs = _config_obj.proj_epochs self.warm_epoch = _config_obj.warm_epoch @@ -1318,7 +1331,8 @@ def before_epoch(self, gen_dataset): train_ind = [n for n, x in enumerate(gen_dataset.train_mask) if x] # Prototype projection if cur_step > self.proj_epochs and cur_step % self.proj_epochs == 0: - self.prot_layer.projection(self.gnn, gen_dataset.dataset, train_ind, gen_dataset.dataset.data, thrsh=self.prot_thrsh) + self.prot_layer.projection(self.gnn, gen_dataset.dataset, train_ind, gen_dataset.dataset.data, + thrsh=self.prot_thrsh) self.gnn.train() for p in self.gnn.parameters(): p.requires_grad = True @@ -1346,7 +1360,6 @@ def after_epoch(self, gen_dataset): self.early_stop_count = 0 self.gnn.best_prots = self.prot_layer.prototype_graphs - def early_stopping(self, train_loss, gen_dataset, metrics, steps): step = self.modification.epochs if self.is_best: @@ -1355,4 +1368,4 @@ def early_stopping(self, train_loss, gen_dataset, metrics, steps): self.early_stop_count += 1 last_projection = (step % self.proj_epochs == 0 and step + self.proj_epochs >= steps) - return self.early_stop_count >= self.early_stopping_marker or last_projection \ No newline at end of file + return self.early_stop_count >= self.early_stopping_marker or last_projection