From 380a7523f3bca9f10c4fd16af8c915f02e6f0bbf Mon Sep 17 00:00:00 2001 From: Misha D Date: Fri, 6 Dec 2024 15:14:34 +0300 Subject: [PATCH] style fixed --- src/aux/configs.py | 2 +- src/base/custom_datasets.py | 12 +- src/base/dataset_stats.py | 14 +- src/base/datasets_processing.py | 36 +-- src/base/ptg_datasets.py | 4 +- src/base/vk_datasets.py | 4 +- web_interface/back_front/block.py | 218 +++++++++++-------- web_interface/back_front/dataset_blocks.py | 63 ++++-- web_interface/back_front/diagram.py | 41 +++- web_interface/back_front/explainer_blocks.py | 95 ++++++-- web_interface/back_front/frontend_client.py | 17 +- web_interface/back_front/model_blocks.py | 180 +++++++++++---- web_interface/back_front/utils.py | 71 ++++-- web_interface/main_multi.py | 42 +++- 14 files changed, 568 insertions(+), 231 deletions(-) diff --git a/src/aux/configs.py b/src/aux/configs.py index 303affa..cf1a2d6 100644 --- a/src/aux/configs.py +++ b/src/aux/configs.py @@ -199,7 +199,7 @@ def set_defaults_config_pattern_info( def to_json( self - ): + ) -> dict: """ Special method which allows to use json.dumps() on Config object """ return self.to_dict() diff --git a/src/base/custom_datasets.py b/src/base/custom_datasets.py index 2055c15..0afe0f0 100644 --- a/src/base/custom_datasets.py +++ b/src/base/custom_datasets.py @@ -36,41 +36,41 @@ def __init__( @property def node_attributes_dir( self - ): + ) -> Path: """ Path to dir with node attributes. """ return self.root_dir / 'raw' / (self.name + '.node_attributes') @property def edge_attributes_dir( self - ): + ) -> Path: """ Path to dir with edge attributes. """ return self.root_dir / 'raw' / (self.name + '.edge_attributes') @property def labels_dir( self - ): + ) -> Path: """ Path to dir with labels. """ return self.root_dir / 'raw' / (self.name + '.labels') @property def edges_path( self - ): + ) -> Path: """ Path to file with edge list. """ return self.root_dir / 'raw' / (self.name + '.ij') @property def edge_index_path( self - ): + ) -> Path: """ Path to dir with labels. """ return self.root_dir / 'raw' / (self.name + '.edge_index') def check_validity( self - ): + ) -> None: """ Check that dataset files (graph and attributes) are valid and consistent with .info. """ # Assuming info is OK diff --git a/src/base/dataset_stats.py b/src/base/dataset_stats.py index 0fd2baf..801b7f5 100644 --- a/src/base/dataset_stats.py +++ b/src/base/dataset_stats.py @@ -117,7 +117,7 @@ def set( self, stat: str, value: Union[int, float, dict, str] - ): + ) -> None: """ Set statistics to a specified value and save to file. """ assert stat in DatasetStats.all_stats @@ -129,7 +129,7 @@ def set( def remove( self, stat: str - ): + ) -> None: """ Remove statistics from dict and file. """ if stat in self.stats: @@ -140,7 +140,7 @@ def remove( def clear_all_stats( self - ): + ) -> None: """ Remove all stats. E.g. the graph has changed. """ for s in DatasetStats.all_stats: @@ -148,7 +148,7 @@ def clear_all_stats( def update_var_config( self - ): + ) -> None: """ Remove var stats from dict since dataset config has changed. """ for s in DatasetStats.var_stats: @@ -158,7 +158,7 @@ def update_var_config( def _compute( self, stat: str - ): + ) -> None: """ Compute statistics for a single graph. Result could be: a number, a string, a distribution, a dict of ones. """ @@ -255,7 +255,7 @@ def _compute( def _compute_multi( self, stat: str - ): + ) -> None: """ Compute statistics for a multiple-graphs dataset. Result could be: a number, a string, a distribution, a dict of ones. """ @@ -286,7 +286,7 @@ def _compute_multi( def list_to_hist( a_list: list -): +) -> dict: """ Convert a list of integers/floats to a frequency histogram, return it as a dict """ return {k: v for k, v in Counter(a_list).most_common()} diff --git a/src/base/datasets_processing.py b/src/base/datasets_processing.py index 053f378..8aa12c8 100644 --- a/src/base/datasets_processing.py +++ b/src/base/datasets_processing.py @@ -7,7 +7,7 @@ import torch import torch_geometric from torch import default_generator, randperm -from torch_geometric.data import Dataset, InMemoryDataset +from torch_geometric.data import Dataset, InMemoryDataset, Data from aux.configs import DatasetConfig, DatasetVarConfig, ConfigPattern from aux.custom_decorators import timing_decorator @@ -123,7 +123,7 @@ def save( @staticmethod def induce( dataset: Dataset - ): + ) -> object: """ Induce metainfo from a given PTG dataset. """ res = DatasetInfo() @@ -144,7 +144,7 @@ def induce( @staticmethod def read( path: Union[str, Path] - ): + ) -> object: """ Read info from a file. """ with path.open('r') as f: a_dict = json.load(f) @@ -346,7 +346,7 @@ def __init__( @property def root_dir( self - ): + ) -> Path: """ Dataset root directory with folders 'raw' and 'prepared'. """ # FIXME Misha, dataset_prepared_dir return path and files_paths not only path return Declare.dataset_root_dir(self.dataset_config)[0] @@ -354,7 +354,7 @@ def root_dir( @property def results_dir( self - ): + ) -> Path: """ Path to 'prepared/../' folder where tensor data is stored. """ # FIXME Misha, dataset_prepared_dir return path and files_paths not only path return Path(Declare.dataset_prepared_dir(self.dataset_config, self.dataset_var_config)[0]) @@ -362,46 +362,46 @@ def results_dir( @property def raw_dir( self - ): + ) -> Path: """ Path to 'raw/' folder where raw data is stored. """ return self.root_dir / 'raw' @property def api_path( self - ): + ) -> Path: """ Path to '.api' file. Could be not present. """ return self.root_dir / '.api' @property def info_path( self - ): + ) -> Path: """ Path to '.info' file. """ return self.root_dir / 'raw' / '.info' @property def data( self - ): + ) -> Data: return self.dataset._data @property def num_classes( self - ): + ) -> int: return self.dataset.num_classes @property def num_node_features( self - ): + ) -> int: return self.dataset.num_node_features @property def labels( self - ): + ) -> torch.Tensor: if self._labels is None: # NOTE: this is a copy from torch_geometric.data.dataset v=2.3.1 from torch_geometric.data.dataset import _get_flattened_data_list @@ -428,7 +428,7 @@ def is_multi( def build( self, dataset_var_config: Union[ConfigPattern, DatasetVarConfig] - ): + ) -> None: """ Create node feature tensors from attributes based on dataset_var_config. """ raise NotImplementedError() @@ -589,7 +589,7 @@ def _compute_dataset_var_data( def get_stat( self, stat: str - ): + ) -> Union[int, float, dict, str]: """ Get statistics. """ return self.stats.get(stat) @@ -597,15 +597,15 @@ def get_stat( def _compute_stat( self, stat: str - ): + ) -> None: """ Compute a non-standard statistics. """ - # Should bw defined in a subclass + # Should be defined in a subclass raise NotImplementedError() def is_one_hot_able( self - ): + ) -> bool: """ Return whether features are 1-hot encodings. If yes, nodes can be colored. """ assert self.dataset_var_config @@ -942,7 +942,7 @@ def merge_directories( source_dir: Union[Path, str], destination_dir: Union[Path, str], remove_source: bool = False -): +) -> None: """ Merge source directory into destination directory, replacing existing files. diff --git a/src/base/ptg_datasets.py b/src/base/ptg_datasets.py index f42f058..1cc7ff7 100644 --- a/src/base/ptg_datasets.py +++ b/src/base/ptg_datasets.py @@ -214,12 +214,12 @@ def __init__( @property def processed_file_names( self - ): + ) -> str: return 'data.pt' def process( self - ): + ) -> None: raise RuntimeError("Dataset is supposed to be processed and saved earlier.") # torch.save(self.collate(self.data_list), self.processed_paths[0]) diff --git a/src/base/vk_datasets.py b/src/base/vk_datasets.py index 4ef3a07..c510a01 100644 --- a/src/base/vk_datasets.py +++ b/src/base/vk_datasets.py @@ -23,7 +23,7 @@ class AttrInfo: @staticmethod def vk_attr( - ): + ) -> dict: vk_dict = { ('age',): list(range(0, len(AGE_GROUPS) + 1)), ('sex',): [1, 2], @@ -175,7 +175,7 @@ def __init__( def _compute_dataset_data( self - ): + ) -> None: """ Get DatasetData for VK graph """ super()._compute_dataset_data() diff --git a/web_interface/back_front/block.py b/web_interface/back_front/block.py index eca8aa3..aeba9dc 100644 --- a/web_interface/back_front/block.py +++ b/web_interface/back_front/block.py @@ -4,6 +4,52 @@ from web_interface.back_front.utils import SocketConnect +class BlockConfig(dict): + def __init__( + self + ): + super().__init__() + + def init( + self, + *args + ): + pass + + def modify( + self, + **kwargs + ): + for key, value in kwargs.items(): + self[key] = value + + def finalize( + self + ): + """ Check correctness """ + # TODO check correctness + return True + + def toDefaults( + self + ): + """ Set default values """ + self.clear() + + def breik( + self, + arg: str = None + ): + if arg is None: + return + if arg == "full": + self.clear() + elif arg == "default": + self.toDefaults() + else: + raise ValueError(f"Unknown argument for breik(): {arg}") + + class Block: """ A logical block of a dependency diagram. @@ -25,15 +71,16 @@ def __init__(self, name, socket: SocketConnect = None): self._object = None # Result of backend request, will be passed to dependent blocks self._result = None # Info to be send to frontend at submit - # Whether block is defined - def is_set(self): + def is_set( + self + ) -> bool: + """ Whether block is defined """ return self._is_set - # Get the config - def get_config(self): - return self._config.copy() - - def init(self, *args): + def init( + self, + *args + ) -> None: """ Create the default version of config. @@ -49,14 +96,21 @@ def init(self, *args): init_params = self._init(*args) self._send('onInit', init_params) - def _init(self, *args): + def _init( + self, + *args + ) -> None: """ Returns jsonable info to be sent to front with onInit() """ # To be overridden in subclass raise NotImplementedError - # Change some values of the config - def modify(self, **key_values): + def modify( + self, + **key_values + ) -> None: + """ Change some values of the config + """ if self._is_set: raise RuntimeError(f'Block[{self.name}] is set and cannot be modified!') else: @@ -64,8 +118,10 @@ def modify(self, **key_values): self._config.modify(**key_values) self._send('onModify') - # Check config correctness and make block to be defined - def finalize(self): + def finalize( + self + ) -> None: + """ Check config correctness and make block to be defined """ if self._is_set: print(f'Block[{self.name}] already set') return @@ -77,15 +133,19 @@ def finalize(self): else: raise RuntimeError(f'Block[{self.name}] failed to finalize') - def _finalize(self): + def _finalize( + self + ) -> None: """ Checks whether the config is correct to create the object. Returns True if OK or False. # TODO can we send to front errors to be fixed? """ raise NotImplementedError - # Run diagram with this block value - def submit(self): + def submit( + self + ) -> None: + """ Run diagram with this block value """ self.finalize() if not self._is_set: @@ -98,17 +158,24 @@ def submit(self): if self.diagram: self.diagram.on_submit(self) - # Perform back request, ect - def _submit(self): + def _submit( + self + ) -> None: + """ Perform back request, ect """ # To be overridden in subclass raise NotImplementedError - def get_object(self): + def get_object( + self + ) -> object: """ Get contained backend object """ return self._object - def unlock(self, toDefault=False): + def unlock( + self, + toDefault: bool = False + ) -> None: """ Make block to be undefined """ if self._is_set: @@ -125,7 +192,10 @@ def unlock(self, toDefault=False): if self.diagram: self.diagram.on_drop(self) - def breik(self, arg=None): + def breik( + self, + arg: str = None + ) -> None: """ Break block logically """ print(f'Block[{self.name}].break()') @@ -133,7 +203,11 @@ def breik(self, arg=None): self._config.breik(arg) self._send('onBreak') - def _send(self, func, kw_params=None): + def _send( + self, + func: str, + kw_params: dict = None + ) -> None: """ Send signal to frontend listeners. """ kw_params_str = str(kw_params) if len(kw_params_str) > 30: @@ -143,39 +217,13 @@ def _send(self, func, kw_params=None): self.socket.send(block=self.name, func=func, msg=kw_params, tag=self.tag) -class BlockConfig(dict): - def __init__(self): - super().__init__() - - def init(self, *args): - pass - - def modify(self, **kwargs): - for key, value in kwargs.items(): - self[key] = value - - # Check correctness - def finalize(self): - # TODO check correctness - return True - - # Set default values - def toDefaults(self): - self.clear() - - def breik(self, arg=None): - if arg is None: - return - if arg == "full": - self.clear() - elif arg == "default": - self.toDefaults() - else: - raise ValueError(f"Unknown argument for breik(): {arg}") - - class WrapperBlock(Block): - def __init__(self, blocks, *args, **kwargs): + def __init__( + self, + blocks: [Block], + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.blocks = blocks # Patch submit and unlock functions @@ -194,22 +242,35 @@ def new_submit(slf): old_unlocks[b] = copy_func(b.unlock) - def new_unlock(slf, *args, **kwargs): + def new_unlock( + slf, + *args, + **kwargs + ): old_unlocks[slf](slf, *args, **kwargs) self.unlock() b.unlock = types.MethodType(new_unlock, b) - def init(self, *args): + def init( + self, + *args + ) -> None: super().init(*args) for b in self.blocks: b.init(*args) - def breik(self, arg=None): + def breik( + self, + arg: str = None + ) -> None: for b in self.blocks: b.breik(arg) super().breik(arg) - def onsubmit(self, block): + def onsubmit( + self, + block + ) -> None: # # Break all but the given # for b in self.blocks: # if b != block: @@ -223,20 +284,29 @@ def onsubmit(self, block): if self.diagram: self.diagram.on_submit(self) - def modify(self, **key_values): + def modify( + self, + **key_values + ) -> None: # Must not be called raise RuntimeError - def _finalize(self): + def _finalize( + self + ) -> None: # Must not be called raise RuntimeError - def _submit(self): + def _submit( + self + ) -> None: # Must not be called raise RuntimeError -def copy_func(f): +def copy_func( + f +): """Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)""" g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, @@ -245,29 +315,3 @@ def copy_func(f): g = functools.update_wrapper(g, f) g.__kwdefaults__ = f.__kwdefaults__ return g - - -# if __name__ == '__main__': -# class A: -# def __init__(self, x): -# self.x = x -# -# def f(self): -# print(self.x) -# -# a_list = [A(1), A(2), A(3)] -# -# def pf(a): -# print('pf', a.x) -# -# # Patching -# for a in a_list: -# old_f = copy_func(a.f) -# -# def new_f(): -# pf(a) -# old_f(a) -# a.f = new_f -# -# for a in a_list: -# a.f() diff --git a/web_interface/back_front/dataset_blocks.py b/web_interface/back_front/dataset_blocks.py index 16393de..227c960 100644 --- a/web_interface/back_front/dataset_blocks.py +++ b/web_interface/back_front/dataset_blocks.py @@ -2,35 +2,50 @@ from aux.data_info import DataInfo from aux.utils import TORCH_GEOM_GRAPHS_PATH -from base.datasets_processing import DatasetManager, GeneralDataset +from base.datasets_processing import DatasetManager, GeneralDataset, VisiblePart from aux.configs import DatasetConfig, DatasetVarConfig from web_interface.back_front.block import Block from web_interface.back_front.utils import json_dumps, get_config_keys class DatasetBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.dataset_config = None - def _init(self): + def _init( + self + ) -> None: pass - def _finalize(self): + def _finalize( + self + ) -> bool: if set(get_config_keys("data_root")) != set(self._config.keys()): return False self.dataset_config = DatasetConfig(**self._config) return True - def _submit(self): + def _submit( + self + ) -> None: self._object = DatasetManager.get_by_config(self.dataset_config) - def get_stat(self, stat): + def get_stat( + self, + stat + ) -> object: return self._object.get_stat(stat) - def get_index(self): + def get_index( + self + ) -> str: DataInfo.refresh_data_dir_structure() index = DataInfo.data_parse() @@ -47,39 +62,59 @@ def get_index(self): return json_dumps([index.to_json(), json_dumps('')]) - def set_visible_part(self, part=None): + def set_visible_part( + self, + part: dict = None + ) -> str: self._object.set_visible_part(part=part) return '' - def get_dataset_data(self, part=None): + def get_dataset_data( + self, + part: dict = None + ) -> dict: return self._object.get_dataset_data(part=part) class DatasetVarBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.tag = 'dvc' self.gen_dataset: GeneralDataset = None # FIXME duplication!!! self.dataset_var_config = None - def _init(self, dataset: GeneralDataset): + def _init( + self, + dataset: GeneralDataset + ) -> dict: self.gen_dataset = dataset return self.gen_dataset.info.to_dict() - def _finalize(self): + def _finalize( + self + ) -> bool: if set(get_config_keys("data_prepared")) != set(self._config.keys()): return False self.dataset_var_config = DatasetVarConfig(**self._config) return True - def _submit(self): + def _submit( + self + ) -> None: self.gen_dataset.build(self.dataset_var_config) self._object = self.gen_dataset # NOTE: we need to compute var_data to be able to get is_one_hot_able() self.gen_dataset.get_dataset_var_data() self._result = [self.dataset_var_config.labeling, self.gen_dataset.is_one_hot_able()] - def get_dataset_var_data(self, part=None): + def get_dataset_var_data( + self, + part: dict = None + ) -> dict: return self._object.get_dataset_var_data(part=part) diff --git a/web_interface/back_front/diagram.py b/web_interface/back_front/diagram.py index f7cbabb..76a851e 100644 --- a/web_interface/back_front/diagram.py +++ b/web_interface/back_front/diagram.py @@ -1,18 +1,28 @@ -from web_interface.back_front.block import WrapperBlock +from typing import Union + +from web_interface.back_front.block import WrapperBlock, Block class Diagram: """Diagram of frontend states and transitions between them. """ - def __init__(self): - self.blocks = {} + def __init__( + self + ): + self.blocks = {} # {name -> Block} - def get(self, name): + def get( + self, + name: str + ) -> Block: """ Get block by its name """ return self.blocks[name] - def add_block(self, block): + def add_block( + self, + block: Block + ) -> None: if block.name in self.blocks: return @@ -23,7 +33,12 @@ def add_block(self, block): self.blocks[b.name] = b b.diagram = self - def add_dependency(self, _from, to, condition=all): + def add_dependency( + self, + _from: Union[list, Block], + to: Union[list, Block], + condition=all + ) -> None: # assert condition in [all, any] if not isinstance(_from, list): _from = [_from] @@ -36,7 +51,10 @@ def add_dependency(self, _from, to, condition=all): self.add_block(to) to.condition = condition - def on_submit(self, block): + def on_submit( + self, + block: Block + ) -> None: """ Init all blocks possible after the block submission """ print('Diagram.onSubmit(' + block.name + ')') @@ -46,7 +64,10 @@ def on_submit(self, block): # IMP add params names as blocks names b_after.init(*[x.get_object() for x in b_after.requires if x.is_set()]) - def on_drop(self, block): + def on_drop( + self, + block: Block + ) -> None: """ Recursively break all block that critically depend on the given one """ print('Diagram.onBreak(' + block.name + ')') @@ -54,7 +75,9 @@ def on_drop(self, block): for b in block.influences: b.breik() - def drop(self): + def drop( + self + ) -> None: """ Drop all blocks. """ # FIXME many blocks will break many times for block in self.blocks.values(): diff --git a/web_interface/back_front/explainer_blocks.py b/web_interface/back_front/explainer_blocks.py index 2214936..cae8e3d 100644 --- a/web_interface/back_front/explainer_blocks.py +++ b/web_interface/back_front/explainer_blocks.py @@ -1,5 +1,7 @@ import json import os +from pathlib import Path +from typing import Union from aux.configs import ExplainerInitConfig, ExplainerModificationConfig, ExplainerRunConfig, \ ConfigPattern @@ -15,22 +17,40 @@ class ExplainerWBlock(WrapperBlock): - def __init__(self, name, blocks, *args, **kwargs): + def __init__( + self, + name: str, + blocks: [Block], + *args, + **kwargs + ): super().__init__(blocks, name, *args, **kwargs) - def _init(self, gen_dataset: GeneralDataset, gmm: GNNModelManager): + def _init( + self, + gen_dataset: GeneralDataset, + gmm: GNNModelManager + ) -> None: self.gen_dataset = gen_dataset self.gmm = gmm - def _finalize(self): + def _finalize( + self + ) -> bool: return True - def _submit(self): + def _submit( + self + ) -> None: pass class ExplainerLoadBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.explainer_path = None @@ -38,21 +58,29 @@ def __init__(self, *args, **kwargs): self.gen_dataset = None self.gmm = None - def _init(self, gen_dataset: GeneralDataset, gmm: GNNModelManager): + def _init( + self, + gen_dataset: GeneralDataset, + gmm: GNNModelManager + ) -> list: # Define options for model manager self.gen_dataset = gen_dataset self.gmm = gmm return [gen_dataset.dataset.num_node_features, gen_dataset.is_multi(), self.get_index()] # return self.get_index() - def _finalize(self): + def _finalize( + self + ) -> bool: if set(get_config_keys("explanations")) != set(self._config.keys()): return False self.explainer_path = self._config return True - def _submit(self): + def _submit( + self + ) -> None: init_config, run_config = self._explainer_kwargs(model_path=self.gmm.model_path_info(), explainer_path=self.explainer_path) modification_config = ExplainerModificationConfig( @@ -72,7 +100,9 @@ def _submit(self): "explanation_data": self._object.load_explanation(run_config=run_config) } - def get_index(self): + def get_index( + self + ) -> [str, str]: """ Get all available explanations with respect to current dataset and model """ path = os.path.relpath(self.gmm.model_path_info(), MODELS_DIR) @@ -89,7 +119,11 @@ def get_index(self): # return [ps.to_json(), json_dumps(self.info)] FIXME misha parsing error on front return [ps.to_json(), '{}'] - def _explainer_kwargs(self, model_path, explainer_path): + def _explainer_kwargs( + self, + model_path: Union[str, Path], + explainer_path: Union[str, Path] + ): init_kwargs_file, run_kwargs_file = Declare.explainer_kwargs_path_full( model_path=model_path, explainer_path=explainer_path) with open(init_kwargs_file) as f: @@ -100,27 +134,39 @@ def _explainer_kwargs(self, model_path, explainer_path): class ExplainerInitBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.explainer_init_config = None self.gen_dataset = None self.gmm = None - def _init(self, gen_dataset: GeneralDataset, gmm: GNNModelManager): + def _init( + self, + gen_dataset: GeneralDataset, + gmm: GNNModelManager + ) -> list: # Define options for model manager self.gen_dataset = gen_dataset self.gmm = gmm return FrameworkExplainersManager.available_explainers(self.gen_dataset, self.gmm) - def _finalize(self): + def _finalize( + self + ) -> bool: self.explainer_init_config = ConfigPattern( **self._config, _import_path=EXPLAINERS_INIT_PARAMETERS_PATH, _config_class="ExplainerInitConfig") return True - def _submit(self): + def _submit( + self + ) -> None: # Build an explainer self._object = FrameworkExplainersManager( dataset=self.gen_dataset, gnn_manager=self.gmm, @@ -131,19 +177,30 @@ def _submit(self): class ExplainerRunBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.explainer_run_config = None self.explainer_manager = None - def _init(self, explainer_manager: FrameworkExplainersManager): + def _init( + self, + explainer_manager: FrameworkExplainersManager + ) -> list: self.explainer_manager = explainer_manager return [self.explainer_manager.gen_dataset.dataset.num_node_features, self.explainer_manager.gen_dataset.is_multi(), self.explainer_manager.explainer.name] - def do(self, do, params): + def do( + self, + do: str, + params: dict + ) -> str: if do == "run": config = json_loads(params.get('explainerRunConfig')) config['_config_kwargs']['kwargs']["_import_path"] =\ @@ -163,7 +220,9 @@ def do(self, do, params): # elif do == "save": # return self._save_explainer() - def _run_explainer(self): + def _run_explainer( + self + ) -> None: self.socket.send("explainer", { "status": "STARTED", "mode": self.explainer_run_config.mode}) # Saves explanation by default, save_explanation_flag=True diff --git a/web_interface/back_front/frontend_client.py b/web_interface/back_front/frontend_client.py index 6b1bd02..2c61107 100644 --- a/web_interface/back_front/frontend_client.py +++ b/web_interface/back_front/frontend_client.py @@ -1,4 +1,5 @@ import json +from typing import Union from aux.utils import FUNCTIONS_PARAMETERS_PATH, FRAMEWORK_PARAMETERS_PATH, MODULES_PARAMETERS_PATH, \ EXPLAINERS_INIT_PARAMETERS_PATH, EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, \ @@ -26,7 +27,9 @@ class FrontendClient: 'F': None, 'FW': None, 'M': None, 'EI': None, 'ER': None, 'O': None} @staticmethod - def get_parameters(type): + def get_parameters( + type: str + ) -> Union[dict, None]: """ """ if type not in FrontendClient.parameters: @@ -45,7 +48,10 @@ def get_parameters(type): return FrontendClient.parameters[type] - def __init__(self, sid): + def __init__( + self, + sid: str + ): self.sid = sid # socket ID self.socket = SocketConnect(sid=sid) @@ -91,7 +97,12 @@ def __init__(self, sid): # """ # self.diagram.drop() - def request_block(self, block, func, params: dict = None): + def request_block( + self, + block: str, + func: str, + params: dict = None + ) -> object: """ :param block: name of block :param func: block function to call diff --git a/web_interface/back_front/model_blocks.py b/web_interface/back_front/model_blocks.py index b9c11b1..053fa86 100644 --- a/web_interface/back_front/model_blocks.py +++ b/web_interface/back_front/model_blocks.py @@ -1,5 +1,7 @@ import json import os +from pathlib import Path +from typing import Union import torch from torch_geometric.data import Dataset @@ -12,47 +14,72 @@ TECHNICAL_PARAMETER_KEY, \ IMPORT_INFO_KEY from base.datasets_processing import GeneralDataset, VisiblePart -from models_builder.gnn_constructor import FrameworkGNNConstructor +from models_builder.gnn_constructor import FrameworkGNNConstructor, GNNConstructor from models_builder.gnn_models import ModelManagerConfig, GNNModelManager, Metric from web_interface.back_front.block import Block, WrapperBlock -from web_interface.back_front.utils import WebInterfaceError, json_dumps, get_config_keys +from web_interface.back_front.utils import WebInterfaceError, json_dumps, get_config_keys, \ + SocketConnect TENSOR_SIZE_LIMIT = 1024 # Max size of weights tensor we sent to frontend class ModelWBlock(WrapperBlock): - def __init__(self, name, blocks, *args, **kwargs): + def __init__( + self, + name: str, + blocks: [Block], + *args, + **kwargs + ): super().__init__(blocks, name, *args, **kwargs) - def _init(self, ptg_dataset: Dataset): + def _init( + self, + ptg_dataset: Dataset + ) -> list[int]: return [ptg_dataset.num_node_features, ptg_dataset.num_classes] - def _finalize(self): + def _finalize( + self + ) -> bool: return True - def _submit(self): + def _submit( + self + ) -> None: pass class ModelLoadBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.model_path = None self.gen_dataset = None - def _init(self, gen_dataset: GeneralDataset): + def _init( + self, + gen_dataset: GeneralDataset + ) -> list[str]: self.gen_dataset = gen_dataset return self.get_index() - def _finalize(self): + def _finalize( + self + ) -> bool: if set(get_config_keys("models")) != set(self._config.keys()): return False self.model_path = self._config return True - def _submit(self): + def _submit( + self + ) -> None: from models_builder.gnn_models import GNNModelManager self.model_manager, train_test_split_path = GNNModelManager.from_model_path( model_path=self.model_path, dataset_path=self.gen_dataset.results_dir) @@ -62,7 +89,9 @@ def _submit(self): self._result = self._object.get_full_info() self._result.update(self._object.gnn.get_full_info(tensor_size_limit=TENSOR_SIZE_LIMIT)) - def get_index(self): + def get_index( + self + ) -> list[str]: """ Get all available models with respect to current dataset """ DataInfo.refresh_models_dir_structure() @@ -78,7 +107,10 @@ def get_index(self): ps = index.filter(dict(zip(keys_list, values_info))) return [ps.to_json(), json_dumps(info)] - def _load_train_test_mask(self, path): + def _load_train_test_mask( + self, + path: Union[Path, str] + ) -> None: """ Load train/test mask associated to the model and send to frontend """ # FIXME self.manager_config.train_test_split self.gen_dataset.train_mask, self.gen_dataset.val_mask, \ @@ -87,16 +119,25 @@ def _load_train_test_mask(self, path): class ModelConstructorBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.model_config = None - def _init(self, gen_dataset: GeneralDataset): + def _init( + self, + gen_dataset: GeneralDataset + ) -> list: ptg_dataset = gen_dataset.dataset return [ptg_dataset.num_node_features, ptg_dataset.num_classes, gen_dataset.is_multi()] - def _finalize(self): + def _finalize( + self + ) -> bool: # TODO better check if not ('layers' in self._config and isinstance(self._config['layers'], list)): return False @@ -104,30 +145,43 @@ def _finalize(self): self.model_config = ModelConfig(structure=ModelStructureConfig(**self._config)) return True - def _submit(self): + def _submit( + self + ) -> None: self._object = FrameworkGNNConstructor(self.model_config) self._result = self._object.get_full_info(tensor_size_limit=TENSOR_SIZE_LIMIT) class ModelCustomBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.gen_dataset = None self.model_name: dict = None - def _init(self, gen_dataset: GeneralDataset): + def _init( + self, + gen_dataset: GeneralDataset + ) -> list[str]: self.gen_dataset = gen_dataset return self.get_index() - def _finalize(self): + def _finalize( + self + ) -> bool: if not (len(self._config.keys()) == 2): # TODO better check return False self.model_name = self._config return True - def _submit(self): + def _submit( + self + ) -> None: # FIXME misha this is bad way user_models_obj_dict_info = UserCodeInfo.user_models_list_ref() cm_path = None @@ -143,7 +197,9 @@ def _submit(self): self._object = UserCodeInfo.take_user_model_obj(cm_path, self.model_name["model"]) self._result = self._object.get_full_info(tensor_size_limit=TENSOR_SIZE_LIMIT) - def get_index(self): + def get_index( + self + ) -> list[str]: """ Get all available models with respect to current dataset """ user_models_obj_dict_info = UserCodeInfo.user_models_list_ref() @@ -161,13 +217,21 @@ def get_index(self): class ModelManagerBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.model_manager_config = None self.klass = None - def _init(self, gen_dataset: GeneralDataset, gnn): + def _init( + self, + gen_dataset: GeneralDataset, + gnn: GNNConstructor + ) -> dict: # Define options for model manager self.gen_dataset = gen_dataset self.gnn = gnn @@ -179,12 +243,16 @@ def _init(self, gen_dataset: GeneralDataset, gnn): mm_info = model_managers_info_by_names_list(mm_set) return mm_info - def _finalize(self): + def _finalize( + self + ) -> bool: self.klass = self._config.pop("class") self.model_manager_config = ModelManagerConfig(**self._config) return True - def _submit(self): + def _submit( + self + ) -> None: create_train_test_mask = True assert self.gnn is not None @@ -217,7 +285,10 @@ def _submit(self): self.gen_dataset.train_test_split(*self.model_manager_config.train_test_split) send_train_test_mask(self.gen_dataset, self.socket) - def get_satellites(self, part=None): + def get_satellites( + self, + part: dict = None + ) -> dict: """ Resend model dependent satellites data: train-test mask, embeds, preds """ visible_part = self.gen_dataset.visible_part if part is None else\ @@ -233,7 +304,11 @@ def get_satellites(self, part=None): return res -def send_train_test_mask(gen_dataset, socket, visible_part=None): +def send_train_test_mask( + gen_dataset, + socket: SocketConnect = None, + visible_part: VisiblePart = None +) -> Union[None, dict]: """ Compute train/test mask for the dataset and send to frontend. """ if visible_part is None: @@ -255,27 +330,43 @@ def send_train_test_mask(gen_dataset, socket, visible_part=None): class ModelTrainerBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.gen_dataset = None self.model_manager = None - def _init(self, gen_dataset: GeneralDataset, gmm: GNNModelManager): + def _init( + self, + gen_dataset: GeneralDataset, + gmm: GNNModelManager + ) -> dict: self.gen_dataset = gen_dataset self.model_manager = gmm return self.model_manager.get_model_data() - def _finalize(self): + def _finalize( + self + ) -> bool: # TODO for ProtGNN model must be trained return True - def _submit(self): + def _submit( + self + ) -> None: self._object = self.model_manager - def do(self, do, params): + def do( + self, + do, + params + ) -> str: if do == "run": metrics = [Metric(**m) for m in json.loads(params.get('metrics'))] self._run_model(metrics) @@ -308,14 +399,19 @@ def do(self, do, params): else: raise WebInterfaceError(f"Unknown 'do' command {do} for model") - def _reset_model(self): + def _reset_model( + self + ) -> None: self.model_manager.gnn.reset_parameters() self.model_manager.modification.epochs = 0 self.gen_dataset.train_test_split(*self.model_manager.manager_config.train_test_split) send_train_test_mask(self.gen_dataset, self.socket) self._run_model([Metric("Accuracy", mask='train'), Metric("Accuracy", mask='test')]) - def _run_model(self, metrics): + def _run_model( + self, + metrics + ) -> None: """ Runs model to compute predictions and logits """ # TODO add set of nodes assert self.model_manager @@ -330,20 +426,30 @@ def _run_model(self, metrics): self.model_manager.send_epoch_results( metrics_values=metrics_values, stats_data=stats_data, socket=self.socket) - def _train_model(self, mode, steps, metrics): + def _train_model( + self, + mode: Union[str, None], + steps: Union[int, None], + metrics: list[Metric] + ) -> None: self._check_metrics(metrics) self.model_manager.train_model( gen_dataset=self.gen_dataset, save_model_flag=False, mode=mode, steps=steps, metrics=metrics, socket=self.socket) - def _save_model(self): + def _save_model( + self + ) -> str: path = self.model_manager.save_model_executor() self.gen_dataset.save_train_test_mask(path) DataInfo.refresh_models_dir_structure() # TODO send dir_structure info to front return str(path) - def _check_metrics(self, metrics): + def _check_metrics( + self, + metrics: list[Metric] + ) -> None: """ Adjust metrics parameters if dataset has many classes, e.g. binary -> macro averaging """ classes = self.gen_dataset.num_classes diff --git a/web_interface/back_front/utils.py b/web_interface/back_front/utils.py index 75d127c..207f5e0 100644 --- a/web_interface/back_front/utils.py +++ b/web_interface/back_front/utils.py @@ -1,3 +1,6 @@ +from typing import Any + +from flask_socketio import SocketIO import json from collections import deque from threading import Thread @@ -9,10 +12,15 @@ class WebInterfaceError(Exception): - def __init__(self, *args): + def __init__( + self, + *args + ): self.message = args[0] if args else None - def __str__(self): + def __str__( + self + ): if self.message: return f"WebInterfaceError: {self.message}" else: @@ -20,18 +28,29 @@ def __str__(self): class Queue(deque): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super(Queue, self).__init__(*args, **kwargs) self.last_obl = True - def push(self, obj, id, obligate): + def push( + self, + obj: object, + id: int, + obligate: bool + ) -> None: # If last is not obligate - replace it if len(self) > 0 and self.last_obl is False: self.pop() super(Queue, self).append((obj, id)) self.last_obl = obligate - def get_first_id(self): + def get_first_id( + self + ) -> int: if len(self) > 0: obj, id = self.popleft() self.appendleft((obj, id)) @@ -46,9 +65,12 @@ class SocketConnect: # max_packet_size = 1024**2 # 1MB limit by default - def __init__(self, socket=None, sid=None): + def __init__( + self, + socket: SocketIO = None, + sid: str = None + ): if socket is None: - from flask_socketio import SocketIO self.socket = SocketIO(message_queue='redis://') else: self.socket = socket @@ -59,9 +81,16 @@ def __init__(self, socket=None, sid=None): self.sleep_time = 0.5 self.active = False # True when sending cycle is running - def send(self, block, msg, func=None, tag='all', obligate=True): + def send( + self, + block: str, + msg: dict, + func: str = None, + tag: str = 'all', + obligate: bool = True + ): """ Send info message to frontend. - :param dst: destination, e.g. "" (to console), "model", "explainer" + :param block: destination block, e.g. "" (to console), "model", "explainer" :param msg: dict :param tag: keep messages in a separate queue with this tag, all but last unobligate messages will be squashed @@ -82,7 +111,9 @@ def send(self, block, msg, func=None, tag='all', obligate=True): if not self.active: Thread(target=self._cycle, args=()).start() - def _send(self): + def _send( + self + ) -> None: """ Send leftmost actual data element from the queue. """ data = None # Find actual data elem @@ -102,7 +133,9 @@ def _send(self): self.sleep_time = 0.5 * size / 25e6 * 10 print('sent data', id, tag, 'of len=', size, 'sleep', self.sleep_time) - def _cycle(self): + def _cycle( + self + ) -> None: """ Send messages from the queue until it is empty. """ self.active = True while True: @@ -113,7 +146,9 @@ def _cycle(self): sleep(self.sleep_time) -def json_dumps(object): +def json_dumps( + object +) -> str: """ Dump an object to JSON properly handling values "-Infinity", "Infinity", and "NaN" """ string = json.dumps(object, ensure_ascii=False) @@ -123,12 +158,16 @@ def json_dumps(object): .replace('Infinity', '"Infinity"') -def json_loads(string): +def json_loads( + string: str +) -> Any: """ Parse JSON string properly handling values "-Infinity", "Infinity", and "NaN" """ c = {"-Infinity": -np.inf, "Infinity": np.inf, "NaN": np.nan} - def parser(arg): + def parser( + arg + ): if isinstance(arg, dict): for key, value in arg.items(): if isinstance(value, str) and value in c: @@ -138,7 +177,9 @@ def parser(arg): return json.loads(string, object_hook=parser) -def get_config_keys(object_type): +def get_config_keys( + object_type: str +) -> list: """ Get a list of keys for a config describing an object of the specified type. """ with open(SAVE_DIR_STRUCTURE_PATH) as f: diff --git a/web_interface/main_multi.py b/web_interface/main_multi.py index 883964b..e73cdb0 100644 --- a/web_interface/main_multi.py +++ b/web_interface/main_multi.py @@ -2,6 +2,7 @@ import logging import multiprocessing from multiprocessing import Pipe +from multiprocessing.connection import Connection from flask import Flask, render_template, request from flask_socketio import SocketIO, emit @@ -20,7 +21,11 @@ active_sessions = {} # {session Id -> sid, process, conn} -def worker_process(process_id, conn, sid): +def worker_process( + process_id: str, + conn: Connection, + sid: str +) -> None: print(f"Process {process_id} started") # TODO problem is each process sends data to main process then to frontend. # Easier to send it directly to url @@ -158,7 +163,8 @@ def worker_process(process_id, conn, sid): @socketio.on('connect') -def handle_connect(): +def handle_connect( +) -> None: # FIXME create process not when socket connects but when a new tab is open session_id = str(uuid.uuid4()) print('handle_connect', session_id, request.sid) @@ -175,7 +181,8 @@ def handle_connect(): @socketio.on('disconnect') -def handle_disconnect(): +def handle_disconnect( +) -> None: print('handle_disconnect from some websocket') for session_id, (sid, process, parent_conn) in active_sessions.items(): @@ -186,29 +193,34 @@ def handle_disconnect(): @app.route('/') -def home(): +def home( +) -> str: # FIXME ? DataInfo.refresh_all_data_info() return render_template('analysis.html') @app.route('/analysis') -def analysis(): +def analysis( +) -> str: return render_template('analysis.html') @app.route('/interpretation') -def interpretation(): +def interpretation( +) -> str: return render_template('interpretation.html') @app.route('/defense') -def defense(): +def defense( +) -> str: return render_template('defense.html') @app.route("/drop", methods=['GET', 'POST']) -def drop(): +def drop( +) -> str: if request.method == 'POST': session_id = json.loads(request.data)['sessionId'] if session_id in active_sessions: @@ -217,7 +229,9 @@ def drop(): return '' -def stop_session(session_id): +def stop_session( + session_id: str +): _, process, conn = active_sessions[session_id] # Stop corresponding process @@ -240,7 +254,8 @@ def stop_session(session_id): @app.route("/ask", methods=['GET', 'POST']) -def storage(): +def storage( +) -> str: if request.method == 'POST': session_id = request.form.get('sessionId') assert session_id in active_sessions @@ -256,7 +271,8 @@ def storage(): @app.route("/block", methods=['GET', 'POST']) -def block(): +def block( +) -> str: if request.method == 'POST': session_id = request.form.get('sessionId') assert session_id in active_sessions @@ -268,7 +284,9 @@ def block(): @app.route("/", methods=['GET', 'POST']) -def url(url): +def url( + url: str +) -> str: assert url in ['dataset', 'model', 'explainer'] if request.method == 'POST': sid = request.form.get('sessionId')