From 9f6260715e131f3b69315550769bebc8dbd38074 Mon Sep 17 00:00:00 2001 From: misha-isp Date: Wed, 30 Oct 2024 18:40:44 +0300 Subject: [PATCH] fix loading models and explanations --- web_interface/back_front/block.py | 3 ++- web_interface/back_front/dataset_blocks.py | 6 +++--- web_interface/back_front/explainer_blocks.py | 8 +++----- web_interface/back_front/model_blocks.py | 10 ++++------ web_interface/back_front/utils.py | 14 +++++++++++++- 5 files changed, 25 insertions(+), 16 deletions(-) diff --git a/web_interface/back_front/block.py b/web_interface/back_front/block.py index 143f45f..eca8aa3 100644 --- a/web_interface/back_front/block.py +++ b/web_interface/back_front/block.py @@ -78,7 +78,8 @@ def finalize(self): raise RuntimeError(f'Block[{self.name}] failed to finalize') def _finalize(self): - """ Returns True or False + """ Checks whether the config is correct to create the object. + Returns True if OK or False. # TODO can we send to front errors to be fixed? """ raise NotImplementedError diff --git a/web_interface/back_front/dataset_blocks.py b/web_interface/back_front/dataset_blocks.py index 2908f47..16393de 100644 --- a/web_interface/back_front/dataset_blocks.py +++ b/web_interface/back_front/dataset_blocks.py @@ -5,7 +5,7 @@ from base.datasets_processing import DatasetManager, GeneralDataset from aux.configs import DatasetConfig, DatasetVarConfig from web_interface.back_front.block import Block -from web_interface.back_front.utils import json_dumps +from web_interface.back_front.utils import json_dumps, get_config_keys class DatasetBlock(Block): @@ -18,7 +18,7 @@ def _init(self): pass def _finalize(self): - if not (len(self._config.keys()) == 3): # TODO better check + if set(get_config_keys("data_root")) != set(self._config.keys()): return False self.dataset_config = DatasetConfig(**self._config) @@ -68,7 +68,7 @@ def _init(self, dataset: GeneralDataset): return self.gen_dataset.info.to_dict() def _finalize(self): - if not (len(self._config.keys()) == 3): # TODO better check + if set(get_config_keys("data_prepared")) != set(self._config.keys()): return False self.dataset_var_config = DatasetVarConfig(**self._config) diff --git a/web_interface/back_front/explainer_blocks.py b/web_interface/back_front/explainer_blocks.py index 61b64d6..96acee4 100644 --- a/web_interface/back_front/explainer_blocks.py +++ b/web_interface/back_front/explainer_blocks.py @@ -11,7 +11,7 @@ from explainers.explainers_manager import FrameworkExplainersManager from models_builder.gnn_models import GNNModelManager from web_interface.back_front.block import Block, WrapperBlock -from web_interface.back_front.utils import json_loads +from web_interface.back_front.utils import json_loads, get_config_keys class ExplainerWBlock(WrapperBlock): @@ -46,8 +46,8 @@ def _init(self, gen_dataset: GeneralDataset, gmm: GNNModelManager): # return self.get_index() def _finalize(self): - # if 1: # TODO better check - # return False + if set(get_config_keys("explanations")) != set(self._config.keys()): + return False self.explainer_path = self._config return True @@ -57,8 +57,6 @@ def _submit(self): explainer_path=self.explainer_path) modification_config = ExplainerModificationConfig( explainer_ver_ind=self.explainer_path["explainer_ver_ind"], - # FIXME Kirill front attack - explainer_attack_type=self.explainer_path["explainer_attack_type"] ) from explainers.explainers_manager import FrameworkExplainersManager diff --git a/web_interface/back_front/model_blocks.py b/web_interface/back_front/model_blocks.py index bda0c7f..310d2bb 100644 --- a/web_interface/back_front/model_blocks.py +++ b/web_interface/back_front/model_blocks.py @@ -8,13 +8,14 @@ from aux.data_info import UserCodeInfo, DataInfo from aux.declaration import Declare from aux.prefix_storage import PrefixStorage -from aux.utils import import_by_name, model_managers_info_by_names_list, GRAPHS_DIR, TECHNICAL_PARAMETER_KEY, \ +from aux.utils import import_by_name, model_managers_info_by_names_list, GRAPHS_DIR, \ + TECHNICAL_PARAMETER_KEY, \ IMPORT_INFO_KEY from base.datasets_processing import GeneralDataset, VisiblePart from models_builder.gnn_constructor import FrameworkGNNConstructor from models_builder.gnn_models import ModelManagerConfig, GNNModelManager, Metric from web_interface.back_front.block import Block, WrapperBlock -from web_interface.back_front.utils import WebInterfaceError, json_dumps +from web_interface.back_front.utils import WebInterfaceError, json_dumps, get_config_keys class ModelWBlock(WrapperBlock): @@ -43,7 +44,7 @@ def _init(self, gen_dataset: GeneralDataset): return self.get_index() def _finalize(self): - if not (len(self._config.keys()) == 5): # TODO better check + if set(get_config_keys("models")) != set(self._config.keys()): return False self.model_path = self._config @@ -178,9 +179,6 @@ def _init(self, gen_dataset: GeneralDataset, gnn): return mm_info def _finalize(self): - # if 1: # TODO better check - # return False - self.klass = self._config.pop("class") self.model_manager_config = ModelManagerConfig(**self._config) return True diff --git a/web_interface/back_front/utils.py b/web_interface/back_front/utils.py index 5255d52..75d127c 100644 --- a/web_interface/back_front/utils.py +++ b/web_interface/back_front/utils.py @@ -5,6 +5,8 @@ import numpy as np +from aux.utils import SAVE_DIR_STRUCTURE_PATH + class WebInterfaceError(Exception): def __init__(self, *args): @@ -133,4 +135,14 @@ def parser(arg): arg[key] = c[value] return arg - return json.loads(string, object_hook=parser) \ No newline at end of file + return json.loads(string, object_hook=parser) + + +def get_config_keys(object_type): + """ Get a list of keys for a config describing an object of the specified type. + """ + with open(SAVE_DIR_STRUCTURE_PATH) as f: + save_dir_structure = json.loads(f.read())[object_type] + + return [k for k, v in save_dir_structure.items() if v["add_key_name_flag"] is not None] +