Skip to content

Commit

Permalink
fix loading models and explanations
Browse files Browse the repository at this point in the history
  • Loading branch information
mishadr committed Oct 30, 2024
1 parent d81b188 commit 9f62607
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 16 deletions.
3 changes: 2 additions & 1 deletion web_interface/back_front/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def finalize(self):
raise RuntimeError(f'Block[{self.name}] failed to finalize')

def _finalize(self):
""" Returns True or False
""" Checks whether the config is correct to create the object.
Returns True if OK or False.
# TODO can we send to front errors to be fixed?
"""
raise NotImplementedError
Expand Down
6 changes: 3 additions & 3 deletions web_interface/back_front/dataset_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from base.datasets_processing import DatasetManager, GeneralDataset
from aux.configs import DatasetConfig, DatasetVarConfig
from web_interface.back_front.block import Block
from web_interface.back_front.utils import json_dumps
from web_interface.back_front.utils import json_dumps, get_config_keys


class DatasetBlock(Block):
Expand All @@ -18,7 +18,7 @@ def _init(self):
pass

def _finalize(self):
if not (len(self._config.keys()) == 3): # TODO better check
if set(get_config_keys("data_root")) != set(self._config.keys()):
return False

self.dataset_config = DatasetConfig(**self._config)
Expand Down Expand Up @@ -68,7 +68,7 @@ def _init(self, dataset: GeneralDataset):
return self.gen_dataset.info.to_dict()

def _finalize(self):
if not (len(self._config.keys()) == 3): # TODO better check
if set(get_config_keys("data_prepared")) != set(self._config.keys()):
return False

self.dataset_var_config = DatasetVarConfig(**self._config)
Expand Down
8 changes: 3 additions & 5 deletions web_interface/back_front/explainer_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from explainers.explainers_manager import FrameworkExplainersManager
from models_builder.gnn_models import GNNModelManager
from web_interface.back_front.block import Block, WrapperBlock
from web_interface.back_front.utils import json_loads
from web_interface.back_front.utils import json_loads, get_config_keys


class ExplainerWBlock(WrapperBlock):
Expand Down Expand Up @@ -46,8 +46,8 @@ def _init(self, gen_dataset: GeneralDataset, gmm: GNNModelManager):
# return self.get_index()

def _finalize(self):
# if 1: # TODO better check
# return False
if set(get_config_keys("explanations")) != set(self._config.keys()):
return False

self.explainer_path = self._config
return True
Expand All @@ -57,8 +57,6 @@ def _submit(self):
explainer_path=self.explainer_path)
modification_config = ExplainerModificationConfig(
explainer_ver_ind=self.explainer_path["explainer_ver_ind"],
# FIXME Kirill front attack
explainer_attack_type=self.explainer_path["explainer_attack_type"]
)

from explainers.explainers_manager import FrameworkExplainersManager
Expand Down
10 changes: 4 additions & 6 deletions web_interface/back_front/model_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from aux.data_info import UserCodeInfo, DataInfo
from aux.declaration import Declare
from aux.prefix_storage import PrefixStorage
from aux.utils import import_by_name, model_managers_info_by_names_list, GRAPHS_DIR, TECHNICAL_PARAMETER_KEY, \
from aux.utils import import_by_name, model_managers_info_by_names_list, GRAPHS_DIR, \
TECHNICAL_PARAMETER_KEY, \
IMPORT_INFO_KEY
from base.datasets_processing import GeneralDataset, VisiblePart
from models_builder.gnn_constructor import FrameworkGNNConstructor
from models_builder.gnn_models import ModelManagerConfig, GNNModelManager, Metric
from web_interface.back_front.block import Block, WrapperBlock
from web_interface.back_front.utils import WebInterfaceError, json_dumps
from web_interface.back_front.utils import WebInterfaceError, json_dumps, get_config_keys


class ModelWBlock(WrapperBlock):
Expand Down Expand Up @@ -43,7 +44,7 @@ def _init(self, gen_dataset: GeneralDataset):
return self.get_index()

def _finalize(self):
if not (len(self._config.keys()) == 5): # TODO better check
if set(get_config_keys("models")) != set(self._config.keys()):
return False

self.model_path = self._config
Expand Down Expand Up @@ -178,9 +179,6 @@ def _init(self, gen_dataset: GeneralDataset, gnn):
return mm_info

def _finalize(self):
# if 1: # TODO better check
# return False

self.klass = self._config.pop("class")
self.model_manager_config = ModelManagerConfig(**self._config)
return True
Expand Down
14 changes: 13 additions & 1 deletion web_interface/back_front/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import numpy as np

from aux.utils import SAVE_DIR_STRUCTURE_PATH


class WebInterfaceError(Exception):
def __init__(self, *args):
Expand Down Expand Up @@ -133,4 +135,14 @@ def parser(arg):
arg[key] = c[value]
return arg

return json.loads(string, object_hook=parser)
return json.loads(string, object_hook=parser)


def get_config_keys(object_type):
""" Get a list of keys for a config describing an object of the specified type.
"""
with open(SAVE_DIR_STRUCTURE_PATH) as f:
save_dir_structure = json.loads(f.read())[object_type]

return [k for k, v in save_dir_structure.items() if v["add_key_name_flag"] is not None]

0 comments on commit 9f62607

Please sign in to comment.