Skip to content

Commit

Permalink
fix load
Browse files Browse the repository at this point in the history
  • Loading branch information
LukyanovKirillML committed Oct 28, 2024
1 parent c64e606 commit a22c1d1
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 15 deletions.
13 changes: 6 additions & 7 deletions experiments/GNNExplainerGEOM_exp_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@ def geom_GNNExplainer_test():

gcn2 = model_configs_zoo(dataset=dataset, model_name='gcn_gcn')

gnn_model_manager_config = ModelManagerConfig(**{
"mask_features": [],
"optimizer": {
"_class_name": "Adam",
"_config_kwargs": {},
}
})
gnn_model_manager_config = ConfigPattern(
_config_class="ModelManagerConfig",
_config_kwargs={
"mask_features": []
}
)

steps_epochs = 200
gnn_model_manager = FrameworkGNNModelManager(
Expand Down
14 changes: 13 additions & 1 deletion src/aux/declaration.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ def declare_model_by_config(
GNNModelManager_hash: str,
model_ver_ind: int,
gnn_name: str,
mi_defense_hash: str,
evasion_defense_hash: str,
poison_defense_hash: str,
mi_attack_hash: str,
evasion_attack_hash: str,
poison_attack_hash: str,
epochs=None,
):
"""
Expand All @@ -202,8 +208,14 @@ def declare_model_by_config(
obj_info = {
"gnn": gnn_name,
"gnn_model_manager": GNNModelManager_hash,
"epochs": str(epochs),
"model_ver_ind": str(model_ver_ind),
"poison_attacker": str(poison_attack_hash),
"poison_defender": str(poison_defense_hash),
"evasion_defender": str(evasion_defense_hash),
"mi_defender": str(mi_defense_hash),
"evasion_attacker": str(evasion_attack_hash),
"mi_attacker": str(mi_attack_hash),
"epochs": str(epochs),
}

path, files_paths = Declare.obj_info_to_path(previous_path=path, what_save=what_save,
Expand Down
27 changes: 20 additions & 7 deletions src/models_builder/gnn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

from aux.configs import ModelManagerConfig, ModelModificationConfig, ModelConfig, CONFIG_CLASS_NAME
from aux.data_info import UserCodeInfo
from aux.utils import import_by_name, all_subclasses, FRAMEWORK_PARAMETERS_PATH, model_managers_info_by_names_list, hash_data_sha256, \
from aux.utils import import_by_name, all_subclasses, FRAMEWORK_PARAMETERS_PATH, model_managers_info_by_names_list, \
hash_data_sha256, \
TECHNICAL_PARAMETER_KEY, IMPORT_INFO_KEY, OPTIMIZERS_PARAMETERS_PATH, FUNCTIONS_PARAMETERS_PATH
from aux.declaration import Declare
from explainers.explainer import ProgressBar
Expand Down Expand Up @@ -247,7 +248,13 @@ class variables
model_ver_ind=kwargs.get('model_ver_ind') if 'model_ver_ind' in kwargs else
self.modification.model_ver_ind,
epochs=self.modification.epochs,
gnn_name=self.gnn.get_hash()
gnn_name=self.gnn.get_hash(),
mi_defense_hash=self.mi_defense_config.hash_for_config(),
evasion_defense_hash=self.evasion_defense_config.hash_for_config(),
poison_defense_hash=self.poison_defense_config.hash_for_config(),
mi_attack_hash=self.mi_attack_config.hash_for_config(),
evasion_attack_hash=self.evasion_attack_config.hash_for_config(),
poison_attack_hash=self.poison_attack_config.hash_for_config(),
)
path = model_dir_path / 'model'
else:
Expand Down Expand Up @@ -556,6 +563,12 @@ def from_model_path(model_path, dataset_path, **kwargs):
epochs=int(model_path['epochs']) if model_path['epochs'] != 'None' else None,
model_ver_ind=int(model_path['model_ver_ind']),
gnn_name=model_path['gnn'],
poison_attack_hash=model_path['poison_attacker'],
poison_defense_hash=model_path['poison_defender'],
evasion_defense_hash=model_path['evasion_defender'],
mi_defense_hash=model_path['mi_defender'],
evasion_attack_hash=model_path['evasion_attacker'],
mi_attack_hash=model_path['mi_attacker'],
)

gnn_mm_file = files_paths[1]
Expand Down Expand Up @@ -1182,7 +1195,7 @@ class ProtGNNModelManager(FrameworkGNNModelManager):
"_class_import_info": ["torch.optim"],
"_config_kwargs": {},
},
#FUNCTIONS_PARAMETERS_PATH,
# FUNCTIONS_PARAMETERS_PATH,
"loss_function": {
"_config_class": "Config",
"_class_name": "CrossEntropyLoss",
Expand All @@ -1201,7 +1214,7 @@ def __init__(self, gnn=None, dataset_path=None, **kwargs):
_config_obj = getattr(self.manager_config, CONFIG_OBJ)
self.clst = _config_obj.clst
self.sep = _config_obj.sep
#lr = _config_obj.lr
# lr = _config_obj.lr
self.early_stopping_marker = _config_obj.early_stopping
self.proj_epochs = _config_obj.proj_epochs
self.warm_epoch = _config_obj.warm_epoch
Expand Down Expand Up @@ -1318,7 +1331,8 @@ def before_epoch(self, gen_dataset):
train_ind = [n for n, x in enumerate(gen_dataset.train_mask) if x]
# Prototype projection
if cur_step > self.proj_epochs and cur_step % self.proj_epochs == 0:
self.prot_layer.projection(self.gnn, gen_dataset.dataset, train_ind, gen_dataset.dataset.data, thrsh=self.prot_thrsh)
self.prot_layer.projection(self.gnn, gen_dataset.dataset, train_ind, gen_dataset.dataset.data,
thrsh=self.prot_thrsh)
self.gnn.train()
for p in self.gnn.parameters():
p.requires_grad = True
Expand Down Expand Up @@ -1346,7 +1360,6 @@ def after_epoch(self, gen_dataset):
self.early_stop_count = 0
self.gnn.best_prots = self.prot_layer.prototype_graphs


def early_stopping(self, train_loss, gen_dataset, metrics, steps):
step = self.modification.epochs
if self.is_best:
Expand All @@ -1355,4 +1368,4 @@ def early_stopping(self, train_loss, gen_dataset, metrics, steps):
self.early_stop_count += 1
last_projection = (step % self.proj_epochs == 0 and step + self.proj_epochs >= steps)

return self.early_stop_count >= self.early_stopping_marker or last_projection
return self.early_stop_count >= self.early_stopping_marker or last_projection

0 comments on commit a22c1d1

Please sign in to comment.