diff --git a/data/multiple-graphs/custom/small/raw/.info b/data/multiple-graphs/custom/small/raw/.info
index 9f7c830..48f95ff 100644
--- a/data/multiple-graphs/custom/small/raw/.info
+++ b/data/multiple-graphs/custom/small/raw/.info
@@ -1,7 +1,7 @@
 {
   "count": 8,
   "nodes": [5, 4, 4, 8, 6, 7, 7, 9],
-  "directed": false,
+  "directed": true,
   "node_attributes": {
     "names": [
       "a", "b"
diff --git a/experiments/user_dataset.py b/experiments/user_dataset.py
index 5ecd7f9..0f15a70 100644
--- a/experiments/user_dataset.py
+++ b/experiments/user_dataset.py
@@ -122,9 +122,31 @@ def simgnn():
     print("len =", len(gen_dataset))
 
 
+def nx_to_ptg_converter():
+    from aux.utils import GRAPHS_DIR
+    from base.dataset_converter import networkx_to_ptg
+    from base.datasets_processing import DatasetManager
+    import networkx as nx
+    import torch  # NOTE: imported here in case torch is not already imported at module level
+
+    nx_path = GRAPHS_DIR / 'networkx-graphs' / 'input' / 'reply_graph.edgelist'
+    nx_graph = nx.read_edgelist(nx_path)
+    nx_graph = nx.to_undirected(nx_graph)
+    ptg_graph = networkx_to_ptg(nx_graph)
+    if ptg_graph.x is None:
+        ptg_graph.x = torch.zeros((ptg_graph.num_nodes, 1))
+    if ptg_graph.y is None:
+        ptg_graph.y = torch.zeros(ptg_graph.num_nodes)
+        ptg_graph.y[0] = 1
+    ptg_dataset = UserLocalDataset('test_dataset_single', [ptg_graph])
+    gen_dataset = DatasetManager.register_torch_geometric_local(ptg_dataset)
+    print(len(gen_dataset))
+
+
 if __name__ == '__main__':
     # local()
-    converted_local()
+    # converted_local()
     # api()
     # simgnn()
+
+    nx_to_ptg_converter()
diff --git a/metainfo/optimizers_parameters.json b/metainfo/optimizers_parameters.json
index cb3d372..ad868bf 100644
--- a/metainfo/optimizers_parameters.json
+++ b/metainfo/optimizers_parameters.json
@@ -1,10 +1,10 @@
 {"Adam": {
     "lr": ["learn rate", "float", 0.001, {"min": 0.0001, "step": 0.001}, "learning rate"],
-    "beta1": ["beta1", "float", 0.9, {}, "coefficient used for computing running averages of gradient and its square"],
-    "beta2": ["beta2", "float", 0.999, {}, "coefficient used for computing running averages of gradient and its square"],
-    "eps": ["Epsilon", "float", 0.00000001, {}, "term added to the denominator to improve numerical stability"],
-    "weight_decay": ["Weight decay (L2)", "float", 5e-4, {}, "weight decay (L2 penalty)"],
+    "beta1": ["beta1", "float", 0.9, {"min": 0}, "coefficient used for computing running averages of gradient and its square"],
+    "beta2": ["beta2", "float", 0.999, {"min": 0}, "coefficient used for computing running averages of gradient and its square"],
+    "eps": ["Epsilon", "float", 0.00000001, {"min": 0}, "term added to the denominator to improve numerical stability"],
+    "weight_decay": ["Weight decay (L2)", "float", 5e-4, {"min": 0}, "weight decay (L2 penalty)"],
     "amsgrad": ["AMSGrad", "bool", false, {}, "whether to use the AMSGrad"],
 
     "_technical_parameter": {
@@ -17,22 +17,22 @@
     "lr": ["learn rate", "float", 1.0, {"min": 0.0001, "step": 1}, "coefficient that scale delta before it is applied to the parameters"],
     "rho": ["rho", "float", 0.9, {}, "coefficient used for computing a running average of squared gradients"],
     "eps": ["Epsilon", "float", 1e-6, {}, "term added to the denominator to improve numerical stability"],
-    "weight_decay": ["Weight decay (L2)", "float", 0, {}, "weight decay (L2 penalty)"]
+    "weight_decay": ["Weight decay (L2)", "float", 0, {"min": 0}, "weight decay (L2 penalty)"]
     },
 
 "Adagrad": {
     "lr": ["learn rate", "float", 0.01, {"min": 0.0001, "step": 0.01}, "learning rate"],
-    "lr_decay": ["lr decay", "float", 0, {}, "learning rate decay"],
-    "eps": ["Epsilon", "float", 
1e-10, {}, "term added to the denominator to improve numerical stability"], - "weight_decay": ["Weight decay (L2)", "float", 0, {}, "weight decay (L2 penalty)"] + "lr_decay": ["lr decay", "float", 0, {"min": 0}, "learning rate decay"], + "eps": ["Epsilon", "float", 1e-10, {"min": 0}, "term added to the denominator to improve numerical stability"], + "weight_decay": ["Weight decay (L2)", "float", 0, {"min": 0}, "weight decay (L2 penalty)"] }, "AdamW": { "lr": ["learn rate", "float", 0.001, {"min": 0.0001, "step": 0.001}, "learning rate"], - "beta1": ["beta1", "float", 0.9, {}, "coefficient used for computing running averages of gradient and its square"], - "beta2": ["beta2", "float", 0.999, {}, "coefficient used for computing running averages of gradient and its square"], - "eps": ["Epsilon", "float", 1e-8, {}, "term added to the denominator to improve numerical stability"], - "weight_decay": ["Weight decay (L2)", "float", 0.01, {}, "weight decay coefficient"], + "beta1": ["beta1", "float", 0.9, {"min": 0}, "coefficient used for computing running averages of gradient and its square"], + "beta2": ["beta2", "float", 0.999, {"min": 0}, "coefficient used for computing running averages of gradient and its square"], + "eps": ["Epsilon", "float", 1e-8, {"min": 0}, "term added to the denominator to improve numerical stability"], + "weight_decay": ["Weight decay (L2)", "float", 0.01, {"min": 0}, "weight decay coefficient"], "amsgrad": ["AMSGrad", "bool", false, {}, "whether to use the AMSGrad"], "maximize": ["maximize", "bool", false, {}, "maximize the params based on the objective, instead of minimizing"], "_technical_parameter": @@ -44,9 +44,9 @@ "SparseAdam": { "lr": ["learn rate", "float", 0.001, {"min": 0.0001, "step": 0.001}, "learning rate "], - "beta1": ["beta1", "float", 0.9, {}, "coefficients used for computing running averages of gradient and its square"], - "beta2": ["beta2", "float", 0.999, {}, "coefficients used for computing running averages of gradient and its square"], - "eps": ["Epsilon", "float", 1e-8, {}, "term added to the denominator to improve numerical stability"], + "beta1": ["beta1", "float", 0.9, {"min": 0}, "coefficients used for computing running averages of gradient and its square"], + "beta2": ["beta2", "float", 0.999, {"min": 0}, "coefficients used for computing running averages of gradient and its square"], + "eps": ["Epsilon", "float", 1e-8, {"min": 0}, "term added to the denominator to improve numerical stability"], "_technical_parameter": { "parameters_grouping": [[ @@ -56,10 +56,10 @@ "Adamax": { "lr": ["learn rate", "float", 0.002, {"min": 0.0001, "step": 0.001}, "learning rate"], - "beta1": ["beta1", "float", 0.9, {}, "coefficient used for computing running averages of gradient and its square"], - "beta2": ["beta2", "float", 0.999, {}, "coefficient used for computing running averages of gradient and its square"], - "eps": ["Epsilon", "float", 1e-8, {}, "term added to the denominator to improve numerical stability"], - "weight_decay": ["Weight decay (L2)", "float", 0.01, {}, "weight decay (L2 penalty)"], + "beta1": ["beta1", "float", 0.9, {"min": 0}, "coefficient used for computing running averages of gradient and its square"], + "beta2": ["beta2", "float", 0.999, {"min": 0}, "coefficient used for computing running averages of gradient and its square"], + "eps": ["Epsilon", "float", 1e-8, {"min": 0}, "term added to the denominator to improve numerical stability"], + "weight_decay": ["Weight decay (L2)", "float", 0.01, {"min": 0}, "weight decay (L2 penalty)"], 
"maximize": ["maximize", "bool", false, {}, "maximize the params based on the objective, instead of minimizing"], "_technical_parameter": { @@ -73,7 +73,7 @@ "lambd": ["lambd", "float", 0.0001, {}, "decay term"], "alpha": ["alpha", "float", 0.75, {}, "power for eta update"], "t0": ["t0", "float", 1000000.0, {}, "point at which to start averaging"], - "weight_decay": ["Weight decay (L2)", "float", 0, {}, "weight decay (L2 penalty)"] + "weight_decay": ["Weight decay (L2)", "float", 0, {"min": 0}, "weight decay (L2 penalty)"] }, "LBFGS": { @@ -85,11 +85,11 @@ "NAdam": { "lr": ["learn rate", "float", 0.002, {"min": 0.0001, "step": 0.001}, "learning rate"], - "beta1": ["beta1", "float", 0.9, {}, "coefficient used for computing running averages of gradient and its square"], - "beta2": ["beta2", "float", 0.999, {}, "coefficient used for computing running averages of gradient and its square"], - "eps": ["Epsilon", "float", 1e-8, {}, "term added to the denominator to improve numerical stability"], - "weight_decay": ["Weight decay (L2)", "float", 0, {}, "weight decay (L2 penalty)"], - "momentum_decay": ["Momentum decay", "float", 0.004, {}, "momentum momentum_decay"], + "beta1": ["beta1", "float", 0.9, {"min": 0}, "coefficient used for computing running averages of gradient and its square"], + "beta2": ["beta2", "float", 0.999, {"min": 0}, "coefficient used for computing running averages of gradient and its square"], + "eps": ["Epsilon", "float", 1e-8, {"min": 0}, "term added to the denominator to improve numerical stability"], + "weight_decay": ["Weight decay (L2)", "float", 0, {"min": 0}, "weight decay (L2 penalty)"], + "momentum_decay": ["Momentum decay", "float", 0.004, {"min": 0}, "momentum momentum_decay"], "_technical_parameter": { "parameters_grouping": [[ @@ -99,10 +99,10 @@ "RAdam": { "lr": ["learn rate", "float", 0.001, {"min": 0.0001, "step": 0.001}, "learning rate"], - "beta1": ["beta1", "float", 0.9, {}, "coefficient used for computing running averages of gradient and its square"], - "beta2": ["beta2", "float", 0.999, {}, "coefficient used for computing running averages of gradient and its square"], - "eps": ["Epsilon", "float", 1e-8, {}, "term added to the denominator to improve numerical stability"], - "weight_decay": ["Weight decay (L2)", "float", 0, {}, "weight decay (L2 penalty)"], + "beta1": ["beta1", "float", 0.9, {"min": 0}, "coefficient used for computing running averages of gradient and its square"], + "beta2": ["beta2", "float", 0.999, {"min": 0}, "coefficient used for computing running averages of gradient and its square"], + "eps": ["Epsilon", "float", 1e-8, {"min": 0}, "term added to the denominator to improve numerical stability"], + "weight_decay": ["Weight decay (L2)", "float", 0, {"min": 0}, "weight decay (L2 penalty)"], "_technical_parameter": { "parameters_grouping": [[ @@ -112,9 +112,9 @@ "RMSprop": { "lr": ["learn rate", "float", 0.01, {"min": 0.0001, "step": 0.01}, "learning rate"], - "alpha": ["alpha", "float", 0.99, {}, "smoothing constant"], - "eps": ["Epsilon", "float", 1e-8, {}, "term added to the denominator to improve numerical stability"], - "weight_decay": ["Weight decay (L2)", "float", 0, {}, "weight decay (L2 penalty)"], + "alpha": ["alpha", "float", 0.99, {"min": 0}, "smoothing constant"], + "eps": ["Epsilon", "float", 1e-8, {"min": 0}, "term added to the denominator to improve numerical stability"], + "weight_decay": ["Weight decay (L2)", "float", 0, {"min": 0}, "weight decay (L2 penalty)"], "momentum": ["momentum", "float", 0, {}, "momentum 
factor"], "centered": ["centered", "bool", false, {}, "if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance"] }, @@ -136,8 +136,8 @@ "SGD": { "lr": ["learn rate", "float", 0.001, {"min": 0.0001, "step": 0.001}, "learning rate"], - "weight_decay": ["Weight decay (L2)", "float", 0, {}, "weight decay (L2 penalty)"], - "momentum": ["momentum", "float", 0, {}, "momentum factor"], + "weight_decay": ["Weight decay (L2)", "float", 0, {"min": 0}, "weight decay (L2 penalty)"], + "momentum": ["momentum", "float", 0, {"min": 0}, "momentum factor"], "dampening": ["dampening", "float", 0, {}, "dampening for momentum"], "nesterov": ["nesterov", "bool", false, {}, "enables Nesterov momentum"], "maximize": ["maximize", "bool", false, {}, "maximize the params based on the objective, instead of minimizing"] diff --git a/requirements2.txt b/requirements2.txt index 17249f5..bdc7583 100644 --- a/requirements2.txt +++ b/requirements2.txt @@ -1,7 +1,7 @@ ## These are reqs needed for documentation numpy==1.26.3 -multiprocess +multiprocess==0.70.16 pydantic tqdm pyparsing @@ -23,4 +23,4 @@ opt-einsum pandas==2.2.0 pylint -torch_geometric==2.3.1 \ No newline at end of file +torch_geometric==2.3.1 diff --git a/src/aux/configs.py b/src/aux/configs.py index 303affa..cf1a2d6 100644 --- a/src/aux/configs.py +++ b/src/aux/configs.py @@ -199,7 +199,7 @@ def set_defaults_config_pattern_info( def to_json( self - ): + ) -> dict: """ Special method which allows to use json.dumps() on Config object """ return self.to_dict() diff --git a/src/base/custom_datasets.py b/src/base/custom_datasets.py index ebe3e3a..0afe0f0 100644 --- a/src/base/custom_datasets.py +++ b/src/base/custom_datasets.py @@ -36,41 +36,41 @@ def __init__( @property def node_attributes_dir( self - ): + ) -> Path: """ Path to dir with node attributes. """ return self.root_dir / 'raw' / (self.name + '.node_attributes') @property def edge_attributes_dir( self - ): + ) -> Path: """ Path to dir with edge attributes. """ return self.root_dir / 'raw' / (self.name + '.edge_attributes') @property def labels_dir( self - ): + ) -> Path: """ Path to dir with labels. """ return self.root_dir / 'raw' / (self.name + '.labels') @property def edges_path( self - ): + ) -> Path: """ Path to file with edge list. """ return self.root_dir / 'raw' / (self.name + '.ij') @property def edge_index_path( self - ): + ) -> Path: """ Path to dir with labels. """ return self.root_dir / 'raw' / (self.name + '.edge_index') def check_validity( self - ): + ) -> None: """ Check that dataset files (graph and attributes) are valid and consistent with .info. 
""" # Assuming info is OK @@ -159,6 +159,7 @@ def build( return self.dataset_var_data = None + self.stats.update_var_config() self.dataset_var_config = dataset_var_config self.dataset = LocalDataset(self.results_dir, process_func=self._create_ptg) @@ -178,6 +179,7 @@ def _compute_stat( with open(self.node_attributes_dir / a, 'r') as f: attr_node_attrs[a] = json.load(f) + # FIXME misha - for single graph [0] edges = self.edge_index node_map = (lambda i: str(self.node_map[i])) if self.node_map else lambda i: str(i) @@ -223,8 +225,8 @@ def _compute_stat( pearson_corr[i][j] = min(1, max(-1, pc)) return {'attributes': attrs, 'correlations': pearson_corr.tolist()} - else: - return super()._compute_stat(stat) + + raise NotImplementedError() def _compute_dataset_data( self @@ -420,7 +422,7 @@ def _iter_nodes( def _labeling_tensor( self, - g_ix=None + g_ix: int = None ) -> list: """ Returns list of labels (not tensors) """ y = [] @@ -445,7 +447,7 @@ def _labeling_tensor( def _feature_tensor( self, - g_ix=None + g_ix: int = None ) -> list: """ Returns list of features (not tensors) for graph g_ix. """ diff --git a/src/base/dataset_stats.py b/src/base/dataset_stats.py new file mode 100644 index 0000000..801b7f5 --- /dev/null +++ b/src/base/dataset_stats.py @@ -0,0 +1,292 @@ +import os +import json +from collections import Counter +from pathlib import Path +from typing import Union + +import networkx as nx +from networkx import NetworkXError, NetworkXNotImplemented +from torch_geometric.data import Dataset + +from base.datasets_processing import GeneralDataset + + +class DatasetStats: + stats = [ + 'num_nodes', + 'num_edges', + 'avg_degree', + 'degree_distr', + 'degree_assort', + + 'clustering_coeff', + 'num_triangles', + + 'gcc_size', + 'gcc_rel_size', + 'num_cc', + 'cc_distr', + 'gcc_diam', + # 'gcc_diam90', + + 'attr_corr', + ] + var_stats = [ + 'label_distr', + 'label_assort', + 'feature_distr', + 'feature_assort', + ] + all_stats = stats + var_stats + + def __init__( + self, + dataset: GeneralDataset + ): + self.gen_dataset: GeneralDataset = dataset + + self.stats = {} # keep the computed stats + self.nx_graph = None # converted to networkx version + + @property + def dataset( + self + ) -> Dataset: + return self.gen_dataset.dataset # PTG Dataset + + @property + def is_directed( + self + ) -> bool: + return self.gen_dataset.info.directed + + @property + def is_multi( + self + ) -> bool: + return self.gen_dataset.is_multi() + + def _save_path( + self, + stat: str + ) -> Path: + """ Save directory for statistics of variable part. """ + directory = None + if stat in DatasetStats.stats: + directory = self.gen_dataset.root_dir / '.stats' + elif stat in DatasetStats.var_stats: + # We suppose here that dataset_var_config is defined for our gen dataset. + directory = self.gen_dataset.results_dir / '.stats' + else: + raise NotImplementedError + directory.mkdir(exist_ok=True) + return directory / stat + + def get( + self, + stat: str + ) -> Union[int, float, dict, str]: + """ Get the specified statistics. + It will be read from file or computed and saved. 
+ """ + assert stat in DatasetStats.all_stats + if stat in self.stats: + return self.stats[stat] + + # Try to read from file + path = self._save_path(stat) + if path.exists(): + with path.open('r') as f: + value = json.load(f) + self.stats[stat] = value + return value + + # Compute + method = { + False: self._compute, + True: self._compute_multi, + }[self.is_multi] + method(stat) + value = self.stats[stat] + if value is None: + value = f"Statistics '{stat}' is not implemented." + + return value + + def set( + self, + stat: str, + value: Union[int, float, dict, str] + ) -> None: + """ Set statistics to a specified value and save to file. + """ + assert stat in DatasetStats.all_stats + self.stats[stat] = value + path = self._save_path(stat) + with path.open('w') as f: + json.dump(value, f, ensure_ascii=False, indent=1) + + def remove( + self, + stat: str + ) -> None: + """ Remove statistics from dict and file. + """ + if stat in self.stats: + del self.stats[stat] + try: + os.remove(self._save_path(stat)) + except FileNotFoundError: pass + + def clear_all_stats( + self + ) -> None: + """ Remove all stats. E.g. the graph has changed. + """ + for s in DatasetStats.all_stats: + self.remove(s) + + def update_var_config( + self + ) -> None: + """ Remove var stats from dict since dataset config has changed. + """ + for s in DatasetStats.var_stats: + if s in self.stats: + del self.stats[s] + + def _compute( + self, + stat: str + ) -> None: + """ Compute statistics for a single graph. + Result could be: a number, a string, a distribution, a dict of ones. + """ + # assert self.info.count == 1 + # data: Data = self.dataset.get(0) + edges = self.gen_dataset.dataset_data["edges"][0] + num_nodes = self.gen_dataset.info.nodes[0] + + # Simple stats + if stat in ["num_edges", "avg_degree"]: + num_edges = len(edges) + avg_deg = len(edges) / num_nodes * (1 if self.is_directed else 2) + self.set("num_edges", num_edges) + self.set("avg_degree", avg_deg) + return + + # Var stats + if stat == "label_distr": + labels = self.gen_dataset.dataset_var_data["labels"] + self.set("label_distr", list_to_hist(labels)) + return + + # More complex stats - we use networkx + if self.nx_graph is None: + # Converting to networkx + self.nx_graph = nx.DiGraph() if self.is_directed else nx.Graph() + for i, j in edges: + self.nx_graph.add_edge(i, j) + + try: + # TODO misha simplify - some stats can be computed easier + + if stat == "clustering_coeff": + # NOTE this is average local clustering, not global + self.set("clustering_coeff", nx.average_clustering(self.nx_graph)) + + elif stat == "num_triangles": + self.set("num_triangles", int(sum(nx.triangles(self.nx_graph).values()) / 3)) + + elif stat in ['gcc_size', 'gcc_rel_size', 'num_cc', 'cc_distr', 'gcc_diam']: + if self.is_directed: + wcc = list(nx.weakly_connected_components(self.nx_graph)) + scc = list(nx.strongly_connected_components(self.nx_graph)) + self.set("num_cc", {"WCC": len(wcc), "SCC": len(scc)}) + self.set("gcc_size", {"WCC": len(wcc[0]), "SCC": len(scc[0])}) + self.set("gcc_rel_size", {"WCC": len(wcc[0]) / num_nodes, + "SCC": len(scc[0]) / num_nodes}) + self.set("cc_distr", {"WCC": list_to_hist([len(c) for c in wcc]), + "SCC": list_to_hist([len(c) for c in scc])}) + self.set("gcc_diam", nx.diameter(self.nx_graph.subgraph(scc[0]))) + # self.set("gcc_diam", {"WCC": nx.diameter(self.nx_graph.to_undirected().subgraph(wcc[0])), + # "SCC": nx.diameter(self.nx_graph.subgraph(scc[0]))}) + else: + cc = list(nx.connected_components(self.nx_graph)) + self.set("num_cc", len(cc)) 
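+                    # NOTE: networkx yields connected components in no guaranteed
+                    # order, so cc[0] here (and wcc[0]/scc[0] above) is only assumed
+                    # to be the giant component; max(cc, key=len) would be safer.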
+ self.set("gcc_size", len(cc[0])) + self.set("gcc_rel_size", len(cc[0]) / num_nodes) + self.set("cc_distr", list_to_hist([len(c) for c in cc])) + self.set("gcc_diam", nx.diameter(self.nx_graph.subgraph(cc[0]))) + + elif stat == "degree_assort": + if self.is_directed: + degree_assort = { + "in-in": nx.degree_assortativity_coefficient(self.nx_graph, "in", "in"), + "in-out": nx.degree_assortativity_coefficient(self.nx_graph, "in", "out"), + "out-in": nx.degree_assortativity_coefficient(self.nx_graph, "out", "in"), + "out-out": nx.degree_assortativity_coefficient(self.nx_graph, "out", "out"), + } + self.set("degree_assort", degree_assort) + else: + self.set("degree_assort", nx.degree_assortativity_coefficient(self.nx_graph)) + + elif stat == "degree_distr": + if self.is_directed: + self.set("degree_distr", { + "in": list_to_hist([d for _, d in self.nx_graph.in_degree()]), + "out": list_to_hist([d for _, d in self.nx_graph.out_degree()]) + }) + else: + self.set("degree_distr", {i: d for i, d in enumerate(nx.degree_histogram(self.nx_graph))}) + + elif stat == "label_assort": + labels = self.gen_dataset.dataset_var_data["labels"] + nx.set_node_attributes(self.nx_graph, dict(list(enumerate(labels))), 'label') + self.set("label_assort", nx.attribute_assortativity_coefficient(self.nx_graph, 'label')) + return + + else: + value = self.gen_dataset._compute_stat(stat) + self.set(stat, value) + except (NetworkXError, NetworkXNotImplemented) as e: + self.set(stat, str(e)) + + def _compute_multi( + self, + stat: str + ) -> None: + """ Compute statistics for a multiple-graphs dataset. + Result could be: a number, a string, a distribution, a dict of ones. + """ + edges = self.gen_dataset.dataset_data["edges"] + + # Var stats + if stat == "label_distr": + labels = self.gen_dataset.dataset_var_data["labels"] + self.set("label_distr", list_to_hist([x for xs in labels for x in xs])) + return + + # Simple stats + if stat in ["num_nodes", "num_edges", "avg_degree"]: + num_nodes = list(self.gen_dataset.info.nodes) + num_edges = [len(e) for e in edges] + avg_degree = [e / n * (1 if self.is_directed else 2) for n, e in zip(num_nodes, num_edges)] + + self.set("num_nodes", list_to_hist(num_nodes)) + self.set("num_edges", list_to_hist(num_edges)) + self.set("avg_degree", list_to_hist(avg_degree)) + return + + else: + value = 'Unknown stats' + # except (NetworkXError, NetworkXNotImplemented) as e: + # value = str(e) + + +def list_to_hist( + a_list: list +) -> dict: + """ Convert a list of integers/floats to a frequency histogram, return it as a dict + """ + return {k: v for k, v in Counter(a_list).most_common()} diff --git a/src/base/datasets_processing.py b/src/base/datasets_processing.py index 3e9819f..8aa12c8 100644 --- a/src/base/datasets_processing.py +++ b/src/base/datasets_processing.py @@ -2,12 +2,12 @@ import shutil import os from pathlib import Path -from typing import Union, Type +from typing import Union import torch import torch_geometric from torch import default_generator, randperm -from torch_geometric.data import Dataset, InMemoryDataset +from torch_geometric.data import Dataset, InMemoryDataset, Data from aux.configs import DatasetConfig, DatasetVarConfig, ConfigPattern from aux.custom_decorators import timing_decorator @@ -123,7 +123,7 @@ def save( @staticmethod def induce( dataset: Dataset - ): + ) -> object: """ Induce metainfo from a given PTG dataset. 
""" res = DatasetInfo() @@ -144,7 +144,7 @@ def induce( @staticmethod def read( path: Union[str, Path] - ): + ) -> object: """ Read info from a file. """ with path.open('r') as f: a_dict = json.load(f) @@ -326,8 +326,8 @@ def __init__( self.dataset_var_data = None # features data prepared for frontend self.name = self.dataset_config.graph # Last folder name - self.stats_dir.mkdir(exist_ok=True, parents=True) - self.stats = {} # dict of {stat -> value} + from base.dataset_stats import DatasetStats + self.stats = DatasetStats(self) # dict of {stat -> value} self.info: DatasetInfo = None self.dataset: Dataset = None # PTG dataset @@ -346,7 +346,7 @@ def __init__( @property def root_dir( self - ): + ) -> Path: """ Dataset root directory with folders 'raw' and 'prepared'. """ # FIXME Misha, dataset_prepared_dir return path and files_paths not only path return Declare.dataset_root_dir(self.dataset_config)[0] @@ -354,7 +354,7 @@ def root_dir( @property def results_dir( self - ): + ) -> Path: """ Path to 'prepared/../' folder where tensor data is stored. """ # FIXME Misha, dataset_prepared_dir return path and files_paths not only path return Path(Declare.dataset_prepared_dir(self.dataset_config, self.dataset_var_config)[0]) @@ -362,53 +362,46 @@ def results_dir( @property def raw_dir( self - ): + ) -> Path: """ Path to 'raw/' folder where raw data is stored. """ return self.root_dir / 'raw' @property def api_path( self - ): + ) -> Path: """ Path to '.api' file. Could be not present. """ return self.root_dir / '.api' @property def info_path( self - ): + ) -> Path: """ Path to '.info' file. """ return self.root_dir / 'raw' / '.info' - @property - def stats_dir( - self - ): - """ Path to '.stats' directory. """ - return self.root_dir / '.stats' - @property def data( self - ): + ) -> Data: return self.dataset._data @property def num_classes( self - ): + ) -> int: return self.dataset.num_classes @property def num_node_features( self - ): + ) -> int: return self.dataset.num_node_features @property def labels( self - ): + ) -> torch.Tensor: if self._labels is None: # NOTE: this is a copy from torch_geometric.data.dataset v=2.3.1 from torch_geometric.data.dataset import _get_flattened_data_list @@ -435,7 +428,7 @@ def is_multi( def build( self, dataset_var_config: Union[ConfigPattern, DatasetVarConfig] - ): + ) -> None: """ Create node feature tensors from attributes based on dataset_var_config. """ raise NotImplementedError() @@ -480,28 +473,30 @@ def _compute_dataset_data( num = len(self.dataset) data_list = [self.dataset.get(ix) for ix in range(num)] is_directed = self.info.directed - # # FIXME node_attributes must be attributes, features only for ptg dataset! - name_type = self.dataset_var_config.features['attr'] - - if self.is_multi(): - edges_list = [] - self.nodes = [] - for data in data_list: - edges_list.append(data.edge_index.T.tolist()) - self.nodes.append(len(data.x)) - - node_attributes = { - list(name_type.keys())[0]: [data.x.tolist() for data in data_list] - } - - else: - assert len(data_list) == 1 - data = data_list[0] - - self.nodes = [len(data.x)] - node_attributes = { - list(name_type.keys())[0]: [data.x.tolist()] - } + # node_attributes exist for custom datasets. + # We can treat them as PTG features but it's not good. 
+ node_attributes = {} + # name_type = self.dataset_var_config.features['attr'] + # + # if self.is_multi(): + # edges_list = [] + # self.nodes = [] + # for data in data_list: + # edges_list.append(data.edge_index.T.tolist()) + # self.nodes.append(len(data.x)) + # + # node_attributes = { + # list(name_type.keys())[0]: [data.x.tolist() for data in data_list] + # } + # + # else: + # assert len(data_list) == 1 + # data = data_list[0] + # + # self.nodes = [len(data.x)] + # node_attributes = { + # list(name_type.keys())[0]: [data.x.tolist()] + # } edges_list = [] for data in data_list: @@ -553,6 +548,7 @@ def get_dataset_var_data( visible_part = self.visible_part if part is None else VisiblePart(self, **part) for ix in visible_part.ixes(): + # FIXME replace with getting data from tensors instead of keeping the whole data features[ix] = self.dataset_var_data['features'][ix] labels[ix] = self.dataset_var_data['labels'][ix] @@ -592,132 +588,20 @@ def _compute_dataset_var_data( def get_stat( self, - stat - ): + stat: str + ) -> Union[int, float, dict, str]: """ Get statistics. """ - if stat in self.stats: - return self.stats[stat] - - # Try to read from file - path = self.stats_dir / stat - if path.exists(): - with path.open('r') as f: - value = json.load(f) - self.stats[stat] = value - return value - - # Compute - value = self._compute_stat(stat) - if value is None: - value = f"Statistics '{stat}' is not implemented." - - # Save - self.stats[stat] = value - path = self.stats_dir / stat - with path.open('w') as f: - json.dump(value, f, ensure_ascii=False) - return value + return self.stats.get(stat) def _compute_stat( self, - stat - ): - """ Compute statistics. """ - if self.is_multi(): - # try: - if stat == 'num_nodes_distr': - value = {} - for i in self.info.nodes: - if i in value.keys(): - value[i] += 1 - else: - value[i] = 1 - return value - - elif stat == 'avg_degree_distr': - # TODO check for (un)directed - m = self.dataset_data['edges'] - # FIXME misha can't use dataset_data when partial data is sent to front - coeff = 1 if self.info.directed else 2 - avg = [coeff * len(m[i]) / self.info.nodes[i] for i in range(self.info.count)] - value = {} - for i in avg: - if i in value.keys(): - value[i] += 1 - else: - value[i] = 1 - return value - - elif stat == "num_edges": - import numpy as np - m = self.dataset_data['edges'] - # FIXME misha can't use dataset_data when partial data is sent to front - coeff = 1 if self.info.directed else 2 - es = [coeff * len(m[i]) for i in range(self.info.count)] - value = f"{np.min(es)} — {np.max(es)}" - # value = f"{np.mean(es)} ± {np.var(es)**0.5}" - # value = np.mean(es) - - elif stat == "avg_deg": - import numpy as np - m = self.dataset_data['edges'] - # FIXME misha can't use dataset_data when partial data is sent to front - coeff = 1 if self.info.directed else 2 - value = np.mean([coeff * len(m[i]) for i in range(self.info.count)]) - - else: - value = 'Unknown stats' - # except (NetworkXError, NetworkXNotImplemented) as e: - # value = str(e) - - else: - assert self.info.count == 1 - import networkx as nx - from networkx import NetworkXError, NetworkXNotImplemented - # Converting to networkx - g = nx.DiGraph() if self.info.directed else nx.Graph() - for i, j in self.dataset_data["edges"][0]: - g.add_edge(i, j) - try: - # TODO misha simplify - some stats can be computed easier - if stat == "num_edges": - value = g.number_of_edges() - - elif stat == "avg_deg": - value = g.number_of_edges() / g.number_of_nodes() - if not self.info.directed: - value = 2 * value - - 
elif stat == "CC": - # NOTE this is average local clustering, not global - value = nx.average_clustering(g) - - elif stat == "triangles": - value = int(sum(nx.triangles(g).values()) / 3) - - elif stat == "diameter": - value = nx.diameter(g) - - elif stat == "degree_assortativity": - value = nx.degree_assortativity_coefficient(g) - - elif stat == "cc": - cc = nx.connected_components(g) - value = len(list(cc)) - - elif stat == "lcc": - cc = nx.connected_components(g) - value = max(len(c) for c in cc) - - elif stat == "DD": - value = {i: d for i, d in enumerate(nx.degree_histogram(g))} - - else: - value = None - except (NetworkXError, NetworkXNotImplemented) as e: - value = str(e) - return value + stat: str + ) -> None: + """ Compute a non-standard statistics. + """ + # Should be defined in a subclass + raise NotImplementedError() def is_one_hot_able( self @@ -852,7 +736,7 @@ def get_by_config( gen_dataset = CustomDataset(dataset_config) elif dataset_group in ["vk_samples"]: - # FIXME misha - it is a kind of custom? + # TODO misha - it is a kind of custom? from base.vk_datasets import VKDataset gen_dataset = VKDataset(dataset_config) @@ -1058,7 +942,7 @@ def merge_directories( source_dir: Union[Path, str], destination_dir: Union[Path, str], remove_source: bool = False -): +) -> None: """ Merge source directory into destination directory, replacing existing files. diff --git a/src/base/ptg_datasets.py b/src/base/ptg_datasets.py index f42f058..1cc7ff7 100644 --- a/src/base/ptg_datasets.py +++ b/src/base/ptg_datasets.py @@ -214,12 +214,12 @@ def __init__( @property def processed_file_names( self - ): + ) -> str: return 'data.pt' def process( self - ): + ) -> None: raise RuntimeError("Dataset is supposed to be processed and saved earlier.") # torch.save(self.collate(self.data_list), self.processed_paths[0]) diff --git a/src/base/vk_datasets.py b/src/base/vk_datasets.py index 936ae11..c510a01 100644 --- a/src/base/vk_datasets.py +++ b/src/base/vk_datasets.py @@ -1,7 +1,6 @@ import ast import datetime import json -import os from bisect import bisect_right from numbers import Number from operator import itemgetter @@ -10,7 +9,6 @@ import numpy as np -from aux.data_info import DATASET_KEYS from aux.utils import GRAPHS_DIR from base.custom_datasets import CustomDataset from aux.configs import DatasetConfig @@ -25,7 +23,7 @@ class AttrInfo: @staticmethod def vk_attr( - ): + ) -> dict: vk_dict = { ('age',): list(range(0, len(AGE_GROUPS) + 1)), ('sex',): [1, 2], @@ -177,7 +175,7 @@ def __init__( def _compute_dataset_data( self - ): + ) -> None: """ Get DatasetData for VK graph """ super()._compute_dataset_data() diff --git a/src/explainers/explainers_manager.py b/src/explainers/explainers_manager.py index 8909d8e..cfe1034 100644 --- a/src/explainers/explainers_manager.py +++ b/src/explainers/explainers_manager.py @@ -187,7 +187,8 @@ def conduct_experiment( # TODO what if save_explanation_flag=False? 
             if self.save_explanation_flag:
                 self.save_explanation(run_config)
-            self.model_manager.save_model_executor()
+            path = self.model_manager.save_model_executor()
+            self.gen_dataset.save_train_test_mask(path)
         except Exception as e:
             if socket:
                 socket.send("er", {"status": "FAILED"})
diff --git a/src/models_builder/gnn_constructor.py b/src/models_builder/gnn_constructor.py
index bf8b325..6d86e81 100644
--- a/src/models_builder/gnn_constructor.py
+++ b/src/models_builder/gnn_constructor.py
@@ -133,7 +133,8 @@ def get_hash(
         return gnn_name_hash
 
     def get_full_info(
-        self
+        self,
+        tensor_size_limit: int = None
     ) -> dict:
         """ Get available info about model for frontend
         """
@@ -144,7 +145,7 @@ def get_full_info(
         except (AttributeError, NotImplementedError): pass
         try:
-            result["weights"] = self.get_weights()
+            result["weights"] = self.get_weights(tensor_size_limit=tensor_size_limit)
         except (AttributeError, NotImplementedError): pass
         try:
@@ -201,7 +202,8 @@ def get_neurons(
         return neurons
 
     def get_weights(
-        self
+        self,
+        tensor_size_limit: int = None
     ):
         """
         Get model weights calling torch.nn.Module.state_dict() to draw them on the frontend.
@@ -222,8 +224,15 @@ def get_weights(
             k = sub_keys[-1]
 
             if type(value) == UninitializedParameter:
-                continue
-            part[k] = value.numpy().tolist()
+                part[k] = '?'
+            else:
+                size = 1
+                for dim in value.shape:
+                    size *= dim
+                if tensor_size_limit and size > tensor_size_limit:  # Tensor is too big - return just its shape, e.g. '64x32'
+                    part[k] = 'x'.join(str(d) for d in value.shape)
+                else:
+                    part[k] = value.numpy().tolist()
 
         return model_data
diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py
index 9674224..8c79fb1 100644
--- a/src/models_builder/gnn_models.py
+++ b/src/models_builder/gnn_models.py
@@ -120,8 +120,8 @@ def __init__(
         modification: ModelModificationConfig = None
     ):
         """
-        :param manager_config: socket to use for sending data to frontend
-        :param modification: socket to use for sending data to frontend
+        :param manager_config: model manager config
+        :param modification: model modification config
         """
         if manager_config is None:
             # raise RuntimeError("model manager config must be specified")
@@ -184,9 +184,7 @@ def __init__(
         self.mi_defense_flag = False
 
         self.gnn = None
-        # We do not want store socket because it is not picklable for a subprocess
-        self.socket = None
-        self.stop_signal = False
+        self.socket = None  # Websocket for sending info to frontend; we avoid storing it since it is not picklable for a subprocess
         self.stats_data = None  # Stores some stats to be sent to frontend
 
         self.set_poison_defender()
@@ -851,7 +849,6 @@ def __init__(self, gnn: Type = None,
         # Add fields from additional config
         self.manager_config = self.manager_config.merge(self.additional_config)
-        self.stop_signal = False  # TODO misha do we need it?
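+        # NOTE: 'stop_signal' is dropped here and above; the corresponding 'stop'
+        # action is also removed from ExplainerRunBlock.do() in explainer_blocks.py below.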
         self.gnn = gnn
 
         if self.modification.epochs is None:
diff --git a/web_interface/back_front/block.py b/web_interface/back_front/block.py
index 143f45f..aeba9dc 100644
--- a/web_interface/back_front/block.py
+++ b/web_interface/back_front/block.py
@@ -4,6 +4,52 @@
 from web_interface.back_front.utils import SocketConnect
 
 
+class BlockConfig(dict):
+    def __init__(
+        self
+    ):
+        super().__init__()
+
+    def init(
+        self,
+        *args
+    ):
+        pass
+
+    def modify(
+        self,
+        **kwargs
+    ):
+        for key, value in kwargs.items():
+            self[key] = value
+
+    def finalize(
+        self
+    ):
+        """ Check correctness """
+        # TODO check correctness
+        return True
+
+    def toDefaults(
+        self
+    ):
+        """ Set default values """
+        self.clear()
+
+    def breik(
+        self,
+        arg: str = None
+    ):
+        if arg is None:
+            return
+        if arg == "full":
+            self.clear()
+        elif arg == "default":
+            self.toDefaults()
+        else:
+            raise ValueError(f"Unknown argument for breik(): {arg}")
+
+
 class Block:
     """
     A logical block of a dependency diagram.
@@ -25,15 +71,16 @@ def __init__(self, name, socket: SocketConnect = None):
         self._object = None  # Result of backend request, will be passed to dependent blocks
         self._result = None  # Info to be sent to frontend at submit
 
-    # Whether block is defined
-    def is_set(self):
+    def is_set(
+        self
+    ) -> bool:
+        """ Whether block is defined """
         return self._is_set
 
-    # Get the config
-    def get_config(self):
-        return self._config.copy()
-
-    def init(self, *args):
+    def init(
+        self,
+        *args
+    ) -> None:
         """
         Create the default version of config.
 
@@ -49,14 +96,21 @@ def init(self, *args):
             init_params = self._init(*args)
             self._send('onInit', init_params)
 
-    def _init(self, *args):
+    def _init(
+        self,
+        *args
+    ) -> object:
         """ Returns jsonable info to be sent to front with onInit()
         """
         # To be overridden in subclass
         raise NotImplementedError
 
-    # Change some values of the config
-    def modify(self, **key_values):
+    def modify(
+        self,
+        **key_values
+    ) -> None:
+        """ Change some values of the config
+        """
         if self._is_set:
             raise RuntimeError(f'Block[{self.name}] is set and cannot be modified!')
         else:
@@ -64,8 +118,10 @@
             self._config.modify(**key_values)
             self._send('onModify')
 
-    # Check config correctness and make block to be defined
-    def finalize(self):
+    def finalize(
+        self
+    ) -> None:
+        """ Check config correctness and make block to be defined """
         if self._is_set:
             print(f'Block[{self.name}] already set')
             return
@@ -77,14 +133,19 @@
         else:
             raise RuntimeError(f'Block[{self.name}] failed to finalize')
 
-    def _finalize(self):
-        """ Returns True or False
+    def _finalize(
+        self
+    ) -> bool:
+        """ Checks whether the config is correct for creating the object.
+        Returns True if it is, False otherwise.
 
         # TODO can we send to front errors to be fixed?
""" raise NotImplementedError - # Run diagram with this block value - def submit(self): + def submit( + self + ) -> None: + """ Run diagram with this block value """ self.finalize() if not self._is_set: @@ -97,17 +158,24 @@ def submit(self): if self.diagram: self.diagram.on_submit(self) - # Perform back request, ect - def _submit(self): + def _submit( + self + ) -> None: + """ Perform back request, ect """ # To be overridden in subclass raise NotImplementedError - def get_object(self): + def get_object( + self + ) -> object: """ Get contained backend object """ return self._object - def unlock(self, toDefault=False): + def unlock( + self, + toDefault: bool = False + ) -> None: """ Make block to be undefined """ if self._is_set: @@ -124,7 +192,10 @@ def unlock(self, toDefault=False): if self.diagram: self.diagram.on_drop(self) - def breik(self, arg=None): + def breik( + self, + arg: str = None + ) -> None: """ Break block logically """ print(f'Block[{self.name}].break()') @@ -132,7 +203,11 @@ def breik(self, arg=None): self._config.breik(arg) self._send('onBreak') - def _send(self, func, kw_params=None): + def _send( + self, + func: str, + kw_params: dict = None + ) -> None: """ Send signal to frontend listeners. """ kw_params_str = str(kw_params) if len(kw_params_str) > 30: @@ -142,39 +217,13 @@ def _send(self, func, kw_params=None): self.socket.send(block=self.name, func=func, msg=kw_params, tag=self.tag) -class BlockConfig(dict): - def __init__(self): - super().__init__() - - def init(self, *args): - pass - - def modify(self, **kwargs): - for key, value in kwargs.items(): - self[key] = value - - # Check correctness - def finalize(self): - # TODO check correctness - return True - - # Set default values - def toDefaults(self): - self.clear() - - def breik(self, arg=None): - if arg is None: - return - if arg == "full": - self.clear() - elif arg == "default": - self.toDefaults() - else: - raise ValueError(f"Unknown argument for breik(): {arg}") - - class WrapperBlock(Block): - def __init__(self, blocks, *args, **kwargs): + def __init__( + self, + blocks: [Block], + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.blocks = blocks # Patch submit and unlock functions @@ -193,22 +242,35 @@ def new_submit(slf): old_unlocks[b] = copy_func(b.unlock) - def new_unlock(slf, *args, **kwargs): + def new_unlock( + slf, + *args, + **kwargs + ): old_unlocks[slf](slf, *args, **kwargs) self.unlock() b.unlock = types.MethodType(new_unlock, b) - def init(self, *args): + def init( + self, + *args + ) -> None: super().init(*args) for b in self.blocks: b.init(*args) - def breik(self, arg=None): + def breik( + self, + arg: str = None + ) -> None: for b in self.blocks: b.breik(arg) super().breik(arg) - def onsubmit(self, block): + def onsubmit( + self, + block + ) -> None: # # Break all but the given # for b in self.blocks: # if b != block: @@ -222,20 +284,29 @@ def onsubmit(self, block): if self.diagram: self.diagram.on_submit(self) - def modify(self, **key_values): + def modify( + self, + **key_values + ) -> None: # Must not be called raise RuntimeError - def _finalize(self): + def _finalize( + self + ) -> None: # Must not be called raise RuntimeError - def _submit(self): + def _submit( + self + ) -> None: # Must not be called raise RuntimeError -def copy_func(f): +def copy_func( + f +): """Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)""" g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, @@ -244,29 +315,3 @@ def copy_func(f): g = functools.update_wrapper(g, f) 
g.__kwdefaults__ = f.__kwdefaults__ return g - - -# if __name__ == '__main__': -# class A: -# def __init__(self, x): -# self.x = x -# -# def f(self): -# print(self.x) -# -# a_list = [A(1), A(2), A(3)] -# -# def pf(a): -# print('pf', a.x) -# -# # Patching -# for a in a_list: -# old_f = copy_func(a.f) -# -# def new_f(): -# pf(a) -# old_f(a) -# a.f = new_f -# -# for a in a_list: -# a.f() diff --git a/web_interface/back_front/dataset_blocks.py b/web_interface/back_front/dataset_blocks.py index 2908f47..227c960 100644 --- a/web_interface/back_front/dataset_blocks.py +++ b/web_interface/back_front/dataset_blocks.py @@ -2,35 +2,50 @@ from aux.data_info import DataInfo from aux.utils import TORCH_GEOM_GRAPHS_PATH -from base.datasets_processing import DatasetManager, GeneralDataset +from base.datasets_processing import DatasetManager, GeneralDataset, VisiblePart from aux.configs import DatasetConfig, DatasetVarConfig from web_interface.back_front.block import Block -from web_interface.back_front.utils import json_dumps +from web_interface.back_front.utils import json_dumps, get_config_keys class DatasetBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.dataset_config = None - def _init(self): + def _init( + self + ) -> None: pass - def _finalize(self): - if not (len(self._config.keys()) == 3): # TODO better check + def _finalize( + self + ) -> bool: + if set(get_config_keys("data_root")) != set(self._config.keys()): return False self.dataset_config = DatasetConfig(**self._config) return True - def _submit(self): + def _submit( + self + ) -> None: self._object = DatasetManager.get_by_config(self.dataset_config) - def get_stat(self, stat): + def get_stat( + self, + stat + ) -> object: return self._object.get_stat(stat) - def get_index(self): + def get_index( + self + ) -> str: DataInfo.refresh_data_dir_structure() index = DataInfo.data_parse() @@ -47,39 +62,59 @@ def get_index(self): return json_dumps([index.to_json(), json_dumps('')]) - def set_visible_part(self, part=None): + def set_visible_part( + self, + part: dict = None + ) -> str: self._object.set_visible_part(part=part) return '' - def get_dataset_data(self, part=None): + def get_dataset_data( + self, + part: dict = None + ) -> dict: return self._object.get_dataset_data(part=part) class DatasetVarBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.tag = 'dvc' self.gen_dataset: GeneralDataset = None # FIXME duplication!!! 
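+        # _finalize() below validates the config by comparing its key set against
+        # get_config_keys("data_prepared") - stricter than the old len() == 3 check.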
self.dataset_var_config = None - def _init(self, dataset: GeneralDataset): + def _init( + self, + dataset: GeneralDataset + ) -> dict: self.gen_dataset = dataset return self.gen_dataset.info.to_dict() - def _finalize(self): - if not (len(self._config.keys()) == 3): # TODO better check + def _finalize( + self + ) -> bool: + if set(get_config_keys("data_prepared")) != set(self._config.keys()): return False self.dataset_var_config = DatasetVarConfig(**self._config) return True - def _submit(self): + def _submit( + self + ) -> None: self.gen_dataset.build(self.dataset_var_config) self._object = self.gen_dataset # NOTE: we need to compute var_data to be able to get is_one_hot_able() self.gen_dataset.get_dataset_var_data() self._result = [self.dataset_var_config.labeling, self.gen_dataset.is_one_hot_able()] - def get_dataset_var_data(self, part=None): + def get_dataset_var_data( + self, + part: dict = None + ) -> dict: return self._object.get_dataset_var_data(part=part) diff --git a/web_interface/back_front/diagram.py b/web_interface/back_front/diagram.py index 8f7d0cf..76a851e 100644 --- a/web_interface/back_front/diagram.py +++ b/web_interface/back_front/diagram.py @@ -1,18 +1,28 @@ -from web_interface.back_front.block import WrapperBlock +from typing import Union + +from web_interface.back_front.block import WrapperBlock, Block class Diagram: - """Diagram of fontend states and transitions between them. + """Diagram of frontend states and transitions between them. """ - def __init__(self): - self.blocks = {} + def __init__( + self + ): + self.blocks = {} # {name -> Block} - def get(self, name): + def get( + self, + name: str + ) -> Block: """ Get block by its name """ return self.blocks[name] - def add_block(self, block): + def add_block( + self, + block: Block + ) -> None: if block.name in self.blocks: return @@ -23,7 +33,12 @@ def add_block(self, block): self.blocks[b.name] = b b.diagram = self - def add_dependency(self, _from, to, condition=all): + def add_dependency( + self, + _from: Union[list, Block], + to: Union[list, Block], + condition=all + ) -> None: # assert condition in [all, any] if not isinstance(_from, list): _from = [_from] @@ -36,7 +51,10 @@ def add_dependency(self, _from, to, condition=all): self.add_block(to) to.condition = condition - def on_submit(self, block): + def on_submit( + self, + block: Block + ) -> None: """ Init all blocks possible after the block submission """ print('Diagram.onSubmit(' + block.name + ')') @@ -46,7 +64,10 @@ def on_submit(self, block): # IMP add params names as blocks names b_after.init(*[x.get_object() for x in b_after.requires if x.is_set()]) - def on_drop(self, block): + def on_drop( + self, + block: Block + ) -> None: """ Recursively break all block that critically depend on the given one """ print('Diagram.onBreak(' + block.name + ')') @@ -54,7 +75,9 @@ def on_drop(self, block): for b in block.influences: b.breik() - def drop(self): + def drop( + self + ) -> None: """ Drop all blocks. 
""" # FIXME many blocks will break many times for block in self.blocks.values(): diff --git a/web_interface/back_front/explainer_blocks.py b/web_interface/back_front/explainer_blocks.py index 61b64d6..cae8e3d 100644 --- a/web_interface/back_front/explainer_blocks.py +++ b/web_interface/back_front/explainer_blocks.py @@ -1,5 +1,7 @@ import json import os +from pathlib import Path +from typing import Union from aux.configs import ExplainerInitConfig, ExplainerModificationConfig, ExplainerRunConfig, \ ConfigPattern @@ -11,26 +13,44 @@ from explainers.explainers_manager import FrameworkExplainersManager from models_builder.gnn_models import GNNModelManager from web_interface.back_front.block import Block, WrapperBlock -from web_interface.back_front.utils import json_loads +from web_interface.back_front.utils import json_loads, get_config_keys class ExplainerWBlock(WrapperBlock): - def __init__(self, name, blocks, *args, **kwargs): + def __init__( + self, + name: str, + blocks: [Block], + *args, + **kwargs + ): super().__init__(blocks, name, *args, **kwargs) - def _init(self, gen_dataset: GeneralDataset, gmm: GNNModelManager): + def _init( + self, + gen_dataset: GeneralDataset, + gmm: GNNModelManager + ) -> None: self.gen_dataset = gen_dataset self.gmm = gmm - def _finalize(self): + def _finalize( + self + ) -> bool: return True - def _submit(self): + def _submit( + self + ) -> None: pass class ExplainerLoadBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.explainer_path = None @@ -38,27 +58,33 @@ def __init__(self, *args, **kwargs): self.gen_dataset = None self.gmm = None - def _init(self, gen_dataset: GeneralDataset, gmm: GNNModelManager): + def _init( + self, + gen_dataset: GeneralDataset, + gmm: GNNModelManager + ) -> list: # Define options for model manager self.gen_dataset = gen_dataset self.gmm = gmm return [gen_dataset.dataset.num_node_features, gen_dataset.is_multi(), self.get_index()] # return self.get_index() - def _finalize(self): - # if 1: # TODO better check - # return False + def _finalize( + self + ) -> bool: + if set(get_config_keys("explanations")) != set(self._config.keys()): + return False self.explainer_path = self._config return True - def _submit(self): + def _submit( + self + ) -> None: init_config, run_config = self._explainer_kwargs(model_path=self.gmm.model_path_info(), explainer_path=self.explainer_path) modification_config = ExplainerModificationConfig( explainer_ver_ind=self.explainer_path["explainer_ver_ind"], - # FIXME Kirill front attack - explainer_attack_type=self.explainer_path["explainer_attack_type"] ) from explainers.explainers_manager import FrameworkExplainersManager @@ -74,7 +100,9 @@ def _submit(self): "explanation_data": self._object.load_explanation(run_config=run_config) } - def get_index(self): + def get_index( + self + ) -> [str, str]: """ Get all available explanations with respect to current dataset and model """ path = os.path.relpath(self.gmm.model_path_info(), MODELS_DIR) @@ -91,7 +119,11 @@ def get_index(self): # return [ps.to_json(), json_dumps(self.info)] FIXME misha parsing error on front return [ps.to_json(), '{}'] - def _explainer_kwargs(self, model_path, explainer_path): + def _explainer_kwargs( + self, + model_path: Union[str, Path], + explainer_path: Union[str, Path] + ): init_kwargs_file, run_kwargs_file = Declare.explainer_kwargs_path_full( model_path=model_path, explainer_path=explainer_path) with open(init_kwargs_file) as f: @@ -102,31 
+134,39 @@ def _explainer_kwargs(self, model_path, explainer_path): class ExplainerInitBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.explainer_init_config = None self.gen_dataset = None self.gmm = None - def _init(self, gen_dataset: GeneralDataset, gmm: GNNModelManager): + def _init( + self, + gen_dataset: GeneralDataset, + gmm: GNNModelManager + ) -> list: # Define options for model manager self.gen_dataset = gen_dataset self.gmm = gmm return FrameworkExplainersManager.available_explainers(self.gen_dataset, self.gmm) - def _finalize(self): - # if 1: # TODO better check - # return False - - # self.explainer_init_config = ExplainerInitConfig(**self._config) + def _finalize( + self + ) -> bool: self.explainer_init_config = ConfigPattern( **self._config, _import_path=EXPLAINERS_INIT_PARAMETERS_PATH, _config_class="ExplainerInitConfig") return True - def _submit(self): + def _submit( + self + ) -> None: # Build an explainer self._object = FrameworkExplainersManager( dataset=self.gen_dataset, gnn_manager=self.gmm, @@ -137,61 +177,31 @@ def _submit(self): class ExplainerRunBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.explainer_run_config = None self.explainer_manager = None - def _init(self, explainer_manager: FrameworkExplainersManager): + def _init( + self, + explainer_manager: FrameworkExplainersManager + ) -> list: self.explainer_manager = explainer_manager return [self.explainer_manager.gen_dataset.dataset.num_node_features, self.explainer_manager.gen_dataset.is_multi(), self.explainer_manager.explainer.name] - def _finalize(self): - # if 1: # TODO better check - # return False - raise NotImplementedError - - # self.explainer_run_config = ExplainerRunConfig(**self._config) # FIXME add class_name - import copy - config = copy.deepcopy(self._config) - config['_config_kwargs']['kwargs']["_import_path"] = EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH - config['_config_kwargs']['kwargs']["_config_class"] = "Config" - self.explainer_run_config = ConfigPattern( - **config, - _config_class="ExplainerRunConfig" - ) - return True - - def _submit(self): - raise NotImplementedError - # # NOTE: multiprocess = multiprocessing + dill instead of pickle, so it can serialize our objects and - # # send them via a Queue - # # NOTE 2: should be imported separately from torch.multiprocessing - # from multiprocess import Process as mpProcess, Queue as mpQueue - - # queue = mpQueue() - # self.explainer_subprocess = mpProcess( - # target=run_function, args=( - # self.explainer_manager, 'conduct_experiment', - # self.explainer_run_config, queue)) - # self.explainer_subprocess.start() - # self.socket.send("explainer", {"status": "STARTED", "mode": self.explainer_run_config["mode"]}) - # self.explainer_subprocess.join() - # - # # Get result if present - otherwise nothing changed - # if not queue.empty(): - # self.explainer_manager = queue.get_nowait() - - self.socket.send(block="er", msg= - {"status": "STARTED", "mode": self.explainer_run_config.mode}) - self.explainer_manager.conduct_experiment(self.explainer_run_config, socket=self.socket) - - def do(self, do, params): + def do( + self, + do: str, + params: dict + ) -> str: if do == "run": - import copy config = json_loads(params.get('explainerRunConfig')) config['_config_kwargs']['kwargs']["_import_path"] =\ EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH \ @@ -207,48 +217,17 @@ def do(self, do, 
params): self._run_explainer() return '' - elif do == "stop": - # BACK_FRONT.model_manager.stop_signal = True # TODO misha remove stop_signal - self._stop_explainer() - return '' - - elif do == "save": - return self._save_explainer() - - def _run_explainer(self): - # self.explainer_run_config = explainer_run_config - - # # NOTE: multiprocess = multiprocessing + dill instead of pickle, so it can serialize our objects and - # # send them via a Queue - # # NOTE 2: should be imported separately from torch.multiprocessing - # from multiprocess import Process as mpProcess, Queue as mpQueue - # - # # queue = mpQueue() - # # self.explainer_subprocess = mpProcess( - # # target=run_function, args=( - # # self.explainer_manager, 'conduct_experiment', - # # self.explainer_run_config, queue)) - # # self.explainer_subprocess.start() - # # self.socket.send("explainer", {"status": "STARTED", "mode": self.explainer_run_config["mode"]}) - # # self.explainer_subprocess.join() - # # - # # # Get result if present - otherwise nothing changed - # # if not queue.empty(): - # # self.explainer_manager = queue.get_nowait() + # elif do == "save": + # return self._save_explainer() + def _run_explainer( + self + ) -> None: self.socket.send("explainer", { "status": "STARTED", "mode": self.explainer_run_config.mode}) + # Saves explanation by default, save_explanation_flag=True self.explainer_manager.conduct_experiment(self.explainer_run_config, socket=self.socket) - def _stop_explainer(self): - raise NotImplementedError - # FIXME not implemented - print('stop explainer') - if self.explainer_subprocess and self.explainer_subprocess.is_alive(): - self.explainer_subprocess.terminate() - self.socket.send("er", {"status": "INTERRUPTED", "mode": mode}) - self.explanation_data = None - - def _save_explainer(self): - # self.explainer.save_explanation() TODO is it necessary? - return str(self.explainer_manager.explainer_result_file_path) + # def _save_explainer(self): + # # self.explainer_manager.save_explanation() + # return str(self.explainer_manager.explainer_result_file_path) diff --git a/web_interface/back_front/frontend_client.py b/web_interface/back_front/frontend_client.py index 32f3ce3..2c61107 100644 --- a/web_interface/back_front/frontend_client.py +++ b/web_interface/back_front/frontend_client.py @@ -1,4 +1,5 @@ import json +from typing import Union from aux.utils import FUNCTIONS_PARAMETERS_PATH, FRAMEWORK_PARAMETERS_PATH, MODULES_PARAMETERS_PATH, \ EXPLAINERS_INIT_PARAMETERS_PATH, EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, \ @@ -14,19 +15,45 @@ class FrontendClient: """ - Keeps data currently loaded at frontend, s.t. dataset, model, explainer to avoid loading them at - each call. - + Frontend client. + Keeps data currently loaded at frontend for the client: dataset, model, explainer. """ - def __init__(self, socketio: SocketConnect): - self.socket = SocketConnect(socket=socketio) + # Global values. 
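+    # These class-level dicts are shared by all FrontendClient instances and
+    # act as process-wide caches rather than per-client state as before.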
+    # TODO this should be updated regularly or by some event
+    storage_index = {  # type -> PrefixStorage
+        'D': None, 'DV': None, 'M': None, 'CM': None, 'E': None}
+    parameters = {  # type -> Parameters dict
+        'F': None, 'FW': None, 'M': None, 'EI': None, 'ELR': None, 'EGR': None, 'O': None}
+
+    @staticmethod
+    def get_parameters(
+            type: str
+    ) -> Union[dict, None]:
+        """ Load the parameters dict of the given type from its JSON file, caching the result.
+        """
+        if type not in FrontendClient.parameters:
+            raise WebInterfaceError(f"Unknown 'ask' argument 'type'={type}")
+
+        if FrontendClient.parameters[type] is None:
+            with open({
+                'F': FUNCTIONS_PARAMETERS_PATH,
+                'FW': FRAMEWORK_PARAMETERS_PATH,
+                'M': MODULES_PARAMETERS_PATH,
+                'EI': EXPLAINERS_INIT_PARAMETERS_PATH,
+                'ELR': EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH,
+                'EGR': EXPLAINERS_GLOBAL_RUN_PARAMETERS_PATH,
+                'O': OPTIMIZERS_PARAMETERS_PATH,
+            }[type], 'r') as f:
+                FrontendClient.parameters[type] = json.load(f)
+
+        return FrontendClient.parameters[type]

-        # TODO this should be updated regularly or by some event
-        self.storage_index = {  # type -> PrefixStorage
-            'D': None, 'DV': None, 'M': None, 'CM': None, 'E': None}
-        self.parameters = {  # type -> Parameters dict
-            'F': None, 'FW': None, 'M': None, 'EI': None, 'ER': None, 'O': None}
+    def __init__(
+            self,
+            sid: str
+    ):
+        self.sid = sid  # socket ID
+        self.socket = SocketConnect(sid=sid)

         # Build the diagram
         self.diagram = Diagram()
@@ -65,32 +92,17 @@ def __init__(self, socketio: SocketConnect):
         self.erBlock = ExplainerRunBlock("er", socket=self.socket)
         self.diagram.add_dependency(self.eiBlock, self.erBlock)

-    def drop(self):
-        """ Drop all current data """
-        self.diagram.drop()
-        # self.storage_index = {'D': None, 'DV': None, 'M': None, 'CM': None, 'E': None}
-        self.parameters = {'F': None, 'FW': None, 'M': None, 'EI': None, 'ER': None, 'O': None}
-
-    def get_parameters(self, type):
-        """
-        """
-        if type not in self.parameters:
-            WebInterfaceError(f"Unknown 'ask' argument 'type'={type}")
-
-        with open({
-            'F': FUNCTIONS_PARAMETERS_PATH,
-            'FW': FRAMEWORK_PARAMETERS_PATH,
-            'M': MODULES_PARAMETERS_PATH,
-            'EI': EXPLAINERS_INIT_PARAMETERS_PATH,
-            'ELR': EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH,
-            'EGR': EXPLAINERS_GLOBAL_RUN_PARAMETERS_PATH,
-            'O': OPTIMIZERS_PARAMETERS_PATH,
-        }[type], 'r') as f:
-            self.parameters[type] = json.load(f)
-
-        return self.parameters[type]
-
-    def request_block(self, block, func, params: dict = None):
+    # def drop(self):
+    #     """ Drop all current data
+    #     """
+    #     self.diagram.drop()
+
+    def request_block(
+            self,
+            block: str,
+            func: str,
+            params: dict = None
+    ) -> object:
         """
         :param block: name of block
         :param func: block function to call
@@ -101,4 +113,4 @@ def request_block(self, block, func, params: dict = None):
         block = self.diagram.get(block)
         func = getattr(block, func)
         res = func(**params or {})
-        return res
\ No newline at end of file
+        return res
diff --git a/web_interface/back_front/model_blocks.py b/web_interface/back_front/model_blocks.py
index 727dd76..053fa86 100644
--- a/web_interface/back_front/model_blocks.py
+++ b/web_interface/back_front/model_blocks.py
@@ -1,5 +1,7 @@
 import json
 import os
+from pathlib import Path
+from typing import Union

 import torch
 from torch_geometric.data import Dataset
@@ -8,48 +10,76 @@
 from aux.data_info import UserCodeInfo, DataInfo
 from aux.declaration import Declare
 from aux.prefix_storage import PrefixStorage
-from aux.utils import import_by_name, model_managers_info_by_names_list, GRAPHS_DIR, TECHNICAL_PARAMETER_KEY, \
+from aux.utils import import_by_name, model_managers_info_by_names_list, GRAPHS_DIR, \
+    TECHNICAL_PARAMETER_KEY, \
     IMPORT_INFO_KEY
 from
base.datasets_processing import GeneralDataset, VisiblePart -from models_builder.gnn_constructor import FrameworkGNNConstructor +from models_builder.gnn_constructor import FrameworkGNNConstructor, GNNConstructor from models_builder.gnn_models import ModelManagerConfig, GNNModelManager, Metric from web_interface.back_front.block import Block, WrapperBlock -from web_interface.back_front.utils import WebInterfaceError, json_dumps +from web_interface.back_front.utils import WebInterfaceError, json_dumps, get_config_keys, \ + SocketConnect + +TENSOR_SIZE_LIMIT = 1024 # Max size of weights tensor we sent to frontend class ModelWBlock(WrapperBlock): - def __init__(self, name, blocks, *args, **kwargs): + def __init__( + self, + name: str, + blocks: [Block], + *args, + **kwargs + ): super().__init__(blocks, name, *args, **kwargs) - def _init(self, ptg_dataset: Dataset): + def _init( + self, + ptg_dataset: Dataset + ) -> list[int]: return [ptg_dataset.num_node_features, ptg_dataset.num_classes] - def _finalize(self): + def _finalize( + self + ) -> bool: return True - def _submit(self): + def _submit( + self + ) -> None: pass class ModelLoadBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.model_path = None self.gen_dataset = None - def _init(self, gen_dataset: GeneralDataset): + def _init( + self, + gen_dataset: GeneralDataset + ) -> list[str]: self.gen_dataset = gen_dataset return self.get_index() - def _finalize(self): - if not (len(self._config.keys()) == 5): # TODO better check + def _finalize( + self + ) -> bool: + if set(get_config_keys("models")) != set(self._config.keys()): return False self.model_path = self._config return True - def _submit(self): + def _submit( + self + ) -> None: from models_builder.gnn_models import GNNModelManager self.model_manager, train_test_split_path = GNNModelManager.from_model_path( model_path=self.model_path, dataset_path=self.gen_dataset.results_dir) @@ -57,9 +87,11 @@ def _submit(self): self._object = self.model_manager self._result = self._object.get_full_info() - self._result.update(self._object.gnn.get_full_info()) + self._result.update(self._object.gnn.get_full_info(tensor_size_limit=TENSOR_SIZE_LIMIT)) - def get_index(self): + def get_index( + self + ) -> list[str]: """ Get all available models with respect to current dataset """ DataInfo.refresh_models_dir_structure() @@ -70,13 +102,15 @@ def get_index(self): keys_list, full_keys_list, dir_structure, _ = DataInfo.take_keys_etc_by_prefix( prefix=("data_root", "data_prepared") ) - values_info = DataInfo.values_list_by_path_and_keys(path=path, - full_keys_list=full_keys_list, - dir_structure=dir_structure) + values_info = DataInfo.values_list_by_path_and_keys( + path=path, full_keys_list=full_keys_list, dir_structure=dir_structure) ps = index.filter(dict(zip(keys_list, values_info))) return [ps.to_json(), json_dumps(info)] - def _load_train_test_mask(self, path): + def _load_train_test_mask( + self, + path: Union[Path, str] + ) -> None: """ Load train/test mask associated to the model and send to frontend """ # FIXME self.manager_config.train_test_split self.gen_dataset.train_mask, self.gen_dataset.val_mask, \ @@ -85,16 +119,25 @@ def _load_train_test_mask(self, path): class ModelConstructorBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.model_config = None - def _init(self, gen_dataset: GeneralDataset): + def _init( + self, 
+ gen_dataset: GeneralDataset + ) -> list: ptg_dataset = gen_dataset.dataset return [ptg_dataset.num_node_features, ptg_dataset.num_classes, gen_dataset.is_multi()] - def _finalize(self): + def _finalize( + self + ) -> bool: # TODO better check if not ('layers' in self._config and isinstance(self._config['layers'], list)): return False @@ -102,30 +145,43 @@ def _finalize(self): self.model_config = ModelConfig(structure=ModelStructureConfig(**self._config)) return True - def _submit(self): + def _submit( + self + ) -> None: self._object = FrameworkGNNConstructor(self.model_config) - self._result = self._object.get_full_info() + self._result = self._object.get_full_info(tensor_size_limit=TENSOR_SIZE_LIMIT) class ModelCustomBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.gen_dataset = None self.model_name: dict = None - def _init(self, gen_dataset: GeneralDataset): + def _init( + self, + gen_dataset: GeneralDataset + ) -> list[str]: self.gen_dataset = gen_dataset return self.get_index() - def _finalize(self): + def _finalize( + self + ) -> bool: if not (len(self._config.keys()) == 2): # TODO better check return False self.model_name = self._config return True - def _submit(self): + def _submit( + self + ) -> None: # FIXME misha this is bad way user_models_obj_dict_info = UserCodeInfo.user_models_list_ref() cm_path = None @@ -139,9 +195,11 @@ def _submit(self): assert cm_path self._object = UserCodeInfo.take_user_model_obj(cm_path, self.model_name["model"]) - self._result = self._object.get_full_info() + self._result = self._object.get_full_info(tensor_size_limit=TENSOR_SIZE_LIMIT) - def get_index(self): + def get_index( + self + ) -> list[str]: """ Get all available models with respect to current dataset """ user_models_obj_dict_info = UserCodeInfo.user_models_list_ref() @@ -159,13 +217,21 @@ def get_index(self): class ModelManagerBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.model_manager_config = None self.klass = None - def _init(self, gen_dataset: GeneralDataset, gnn): + def _init( + self, + gen_dataset: GeneralDataset, + gnn: GNNConstructor + ) -> dict: # Define options for model manager self.gen_dataset = gen_dataset self.gnn = gnn @@ -177,21 +243,22 @@ def _init(self, gen_dataset: GeneralDataset, gnn): mm_info = model_managers_info_by_names_list(mm_set) return mm_info - def _finalize(self): - # if 1: # TODO better check - # return False - + def _finalize( + self + ) -> bool: self.klass = self._config.pop("class") self.model_manager_config = ModelManagerConfig(**self._config) return True - def _submit(self): + def _submit( + self + ) -> None: create_train_test_mask = True assert self.gnn is not None # Import correct class - from web_interface.main import FRONTEND_CLIENT - if self.klass in FRONTEND_CLIENT.get_parameters("FW"): + from web_interface.main_multi import FrontendClient + if self.klass in FrontendClient.get_parameters("FW"): mm_class = import_by_name(self.klass, ["models_builder.gnn_models"]) else: # Custom MM @@ -218,7 +285,10 @@ def _submit(self): self.gen_dataset.train_test_split(*self.model_manager_config.train_test_split) send_train_test_mask(self.gen_dataset, self.socket) - def get_satellites(self, part=None): + def get_satellites( + self, + part: dict = None + ) -> dict: """ Resend model dependent satellites data: train-test mask, embeds, preds """ visible_part = 
self.gen_dataset.visible_part if part is None else\ @@ -234,7 +304,11 @@ def get_satellites(self, part=None): return res -def send_train_test_mask(gen_dataset, socket, visible_part=None): +def send_train_test_mask( + gen_dataset, + socket: SocketConnect = None, + visible_part: VisiblePart = None +) -> Union[None, dict]: """ Compute train/test mask for the dataset and send to frontend. """ if visible_part is None: @@ -256,27 +330,43 @@ def send_train_test_mask(gen_dataset, socket, visible_part=None): class ModelTrainerBlock(Block): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super().__init__(*args, **kwargs) self.gen_dataset = None self.model_manager = None - def _init(self, gen_dataset: GeneralDataset, gmm: GNNModelManager): + def _init( + self, + gen_dataset: GeneralDataset, + gmm: GNNModelManager + ) -> dict: self.gen_dataset = gen_dataset self.model_manager = gmm return self.model_manager.get_model_data() - def _finalize(self): + def _finalize( + self + ) -> bool: # TODO for ProtGNN model must be trained return True - def _submit(self): + def _submit( + self + ) -> None: self._object = self.model_manager - def do(self, do, params): + def do( + self, + do, + params + ) -> str: if do == "run": metrics = [Metric(**m) for m in json.loads(params.get('metrics'))] self._run_model(metrics) @@ -293,11 +383,6 @@ def do(self, do, params): self._train_model(mode=mode, steps=steps, metrics=metrics) return '' - elif do == "stop": - # BACK_FRONT.model_manager.stop_signal = True # TODO remove stop_signal - self.stop_model() - return '' - elif do == "save": return self._save_model() @@ -314,14 +399,19 @@ def do(self, do, params): else: raise WebInterfaceError(f"Unknown 'do' command {do} for model") - def _reset_model(self): + def _reset_model( + self + ) -> None: self.model_manager.gnn.reset_parameters() self.model_manager.modification.epochs = 0 self.gen_dataset.train_test_split(*self.model_manager.manager_config.train_test_split) send_train_test_mask(self.gen_dataset, self.socket) self._run_model([Metric("Accuracy", mask='train'), Metric("Accuracy", mask='test')]) - def _run_model(self, metrics): + def _run_model( + self, + metrics + ) -> None: """ Runs model to compute predictions and logits """ # TODO add set of nodes assert self.model_manager @@ -336,42 +426,30 @@ def _run_model(self, metrics): self.model_manager.send_epoch_results( metrics_values=metrics_values, stats_data=stats_data, socket=self.socket) - def _train_model(self, mode, steps, metrics): - # # Remove unpickable socket - # self.model_manager.gnn_mm.socket = None - # - # assert self.model_training_subprocess is None or not self.model_training_subprocess.is_alive() - # - # queue = tQueue() - # self.model_training_subprocess = tProcess( - # target=run_function, args=( - # self.model_manager, 'train_model', { - # "gen_dataset": self.gen_dataset, "save_model_flag": False, "mode": mode, - # "steps": steps, "metrics": metrics}, - # queue)) - # - # self.model_training_subprocess.start() - # self.socket.send("model", {"status": "STARTED"}) - # self.model_training_subprocess.join() - # - # # Get result if present - otherwise nothing changed - # self.model_manager = queue.get_nowait() - # # Put unpickable socket back - # self.model_manager.gnn_mm.socket = self.socket - + def _train_model( + self, + mode: Union[str, None], + steps: Union[int, None], + metrics: list[Metric] + ) -> None: self._check_metrics(metrics) self.model_manager.train_model( gen_dataset=self.gen_dataset, save_model_flag=False, 
mode=mode, steps=steps, metrics=metrics, socket=self.socket) - def _save_model(self): + def _save_model( + self + ) -> str: path = self.model_manager.save_model_executor() self.gen_dataset.save_train_test_mask(path) DataInfo.refresh_models_dir_structure() # TODO send dir_structure info to front return str(path) - def _check_metrics(self, metrics): + def _check_metrics( + self, + metrics: list[Metric] + ) -> None: """ Adjust metrics parameters if dataset has many classes, e.g. binary -> macro averaging """ classes = self.gen_dataset.num_classes diff --git a/web_interface/back_front/utils.py b/web_interface/back_front/utils.py index 2e51e5e..207f5e0 100644 --- a/web_interface/back_front/utils.py +++ b/web_interface/back_front/utils.py @@ -1,3 +1,6 @@ +from typing import Any + +from flask_socketio import SocketIO import json from collections import deque from threading import Thread @@ -5,12 +8,19 @@ import numpy as np +from aux.utils import SAVE_DIR_STRUCTURE_PATH + class WebInterfaceError(Exception): - def __init__(self, *args): + def __init__( + self, + *args + ): self.message = args[0] if args else None - def __str__(self): + def __str__( + self + ): if self.message: return f"WebInterfaceError: {self.message}" else: @@ -18,18 +28,29 @@ def __str__(self): class Queue(deque): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs + ): super(Queue, self).__init__(*args, **kwargs) self.last_obl = True - def push(self, obj, id, obligate): + def push( + self, + obj: object, + id: int, + obligate: bool + ) -> None: # If last is not obligate - replace it if len(self) > 0 and self.last_obl is False: self.pop() super(Queue, self).append((obj, id)) self.last_obl = obligate - def get_first_id(self): + def get_first_id( + self + ) -> int: if len(self) > 0: obj, id = self.popleft() self.appendleft((obj, id)) @@ -39,26 +60,37 @@ def get_first_id(self): class SocketConnect: - """ Sends messages to JS socket from python process + """ Sends messages to JS socket from a python process """ # max_packet_size = 1024**2 # 1MB limit by default - def __init__(self, socket=None): + def __init__( + self, + socket: SocketIO = None, + sid: str = None + ): if socket is None: - from flask_socketio import SocketIO self.socket = SocketIO(message_queue='redis://') else: self.socket = socket + self.sid = sid self.queue = deque() # general queue self.tag_queue = {} # {tag -> Queue} self.obj_id = 0 # Messages ids counter self.sleep_time = 0.5 self.active = False # True when sending cycle is running - def send(self, block, msg, func=None, tag='all', obligate=True): + def send( + self, + block: str, + msg: dict, + func: str = None, + tag: str = 'all', + obligate: bool = True + ): """ Send info message to frontend. - :param dst: destination, e.g. "" (to console), "model", "explainer" + :param block: destination block, e.g. "" (to console), "model", "explainer" :param msg: dict :param tag: keep messages in a separate queue with this tag, all but last unobligate messages will be squashed @@ -79,7 +111,9 @@ def send(self, block, msg, func=None, tag='all', obligate=True): if not self.active: Thread(target=self._cycle, args=()).start() - def _send(self): + def _send( + self + ) -> None: """ Send leftmost actual data element from the queue. 
""" data = None # Find actual data elem @@ -92,14 +126,16 @@ def _send(self): if data is None: return - self.socket.send(data) + self.socket.send(data, to=self.sid) size = len(json_dumps(data)) if size > 25e6: raise RuntimeError(f"Too big package size: {size} bytes") self.sleep_time = 0.5 * size / 25e6 * 10 print('sent data', id, tag, 'of len=', size, 'sleep', self.sleep_time) - def _cycle(self): + def _cycle( + self + ) -> None: """ Send messages from the queue until it is empty. """ self.active = True while True: @@ -110,7 +146,9 @@ def _cycle(self): sleep(self.sleep_time) -def json_dumps(object): +def json_dumps( + object +) -> str: """ Dump an object to JSON properly handling values "-Infinity", "Infinity", and "NaN" """ string = json.dumps(object, ensure_ascii=False) @@ -120,16 +158,32 @@ def json_dumps(object): .replace('Infinity', '"Infinity"') -def json_loads(string): +def json_loads( + string: str +) -> Any: """ Parse JSON string properly handling values "-Infinity", "Infinity", and "NaN" """ c = {"-Infinity": -np.inf, "Infinity": np.inf, "NaN": np.nan} - def parser(arg): + def parser( + arg + ): if isinstance(arg, dict): for key, value in arg.items(): if isinstance(value, str) and value in c: arg[key] = c[value] return arg - return json.loads(string, object_hook=parser) \ No newline at end of file + return json.loads(string, object_hook=parser) + + +def get_config_keys( + object_type: str +) -> list: + """ Get a list of keys for a config describing an object of the specified type. + """ + with open(SAVE_DIR_STRUCTURE_PATH) as f: + save_dir_structure = json.loads(f.read())[object_type] + + return [k for k, v in save_dir_structure.items() if v["add_key_name_flag"] is not None] + diff --git a/web_interface/main.py b/web_interface/main.py deleted file mode 100644 index a4b3c82..0000000 --- a/web_interface/main.py +++ /dev/null @@ -1,164 +0,0 @@ -import json -import logging -from time import sleep - -from flask import Flask, render_template, request -from flask_socketio import SocketIO - -from aux.data_info import DataInfo -from web_interface.back_front.frontend_client import FrontendClient -from web_interface.back_front.utils import WebInterfaceError, json_dumps, json_loads - -app = Flask(__name__) -app.config['SECRET_KEY'] = '57916211bb0b13ce0c676dfde280ba245' -socketio = SocketIO(app, async_mode='threading', message_queue='redis://') -FRONTEND_CLIENT = FrontendClient(socketio) - - -@app.route("/") -@app.route("/home", methods=['GET', 'POST']) -def home(): - # Drop all data at page reload - FRONTEND_CLIENT.drop() - DataInfo.refresh_all_data_info() - return render_template('home.html') - - -@app.route("/ask", methods=['GET', 'POST']) -def storage(): - if request.method == 'POST': - ask = request.form.get('ask') - - if ask == "parameters": - type = request.form.get('type') - return json_dumps(FRONTEND_CLIENT.get_parameters(type)) - - else: - raise WebInterfaceError(f"Unknown 'ask' command {ask}") - - -@app.route("/block", methods=['GET', 'POST']) -def block(): - if request.method == 'POST': - block = request.form.get('block') - func = request.form.get('func') - params = request.form.get('params') - if params: - params = json_loads(params) - print(f"request_block: block={block}, func={func}, params={params}") - FRONTEND_CLIENT.request_block(block, func, params) - return '{}' - - -@app.route("/dataset", methods=['GET', 'POST']) -def dataset(): - if request.method == 'POST': - get = request.form.get('get') - set = request.form.get('set') - part = request.form.get('part') - if part: - 
part = json_loads(part) - - # # FIXME tmp - # - # from web_interface.back_front.communication import SocketConnect, WebInterfaceError - # socket = SocketConnect(socket=socketio) - # for i in range(300): - # print('sending', i, 'big') - # socket.send(i, 1000000 * "x") - # # print('sending', i, 'small') - # # socket.send(i, "small") - # sleep(0.5/25) - - if set == "visible_part": - return FRONTEND_CLIENT.dcBlock.set_visible_part(part=part) - - if get == "data": - dataset_data = FRONTEND_CLIENT.dcBlock.get_dataset_data(part=part) - data = json.dumps(dataset_data) - logging.info(f"Length of dataset_data: {len(data)}") - return data - - elif get == "var_data": - if not FRONTEND_CLIENT.dvcBlock.is_set(): - return '' - dataset_var_data = FRONTEND_CLIENT.dvcBlock.get_dataset_var_data(part=part) - data = json.dumps(dataset_var_data) - logging.info(f"Length of dataset_var_data: {len(data)}") - return data - - elif get == "stat": - stat = request.form.get('stat') - return json_dumps(FRONTEND_CLIENT.dcBlock.get_stat(stat)) - - elif get == "index": - return FRONTEND_CLIENT.dcBlock.get_index() - - else: - raise WebInterfaceError(f"Unknown 'part' command {get} for dataset") - - -@app.route("/model", methods=['GET', 'POST']) -def model(): - if request.method == 'POST': - do = request.form.get('do') - get = request.form.get('get') - - if do: - print(f"model.do: do={do}, params={request.form}") - if do == 'index': - type = request.form.get('type') - if type == "saved": - return FRONTEND_CLIENT.mloadBlock.get_index() - if type == "custom": - return FRONTEND_CLIENT.mcustomBlock.get_index() - else: - return FRONTEND_CLIENT.mtBlock.do(do, request.form) - - if get: - if get == "satellites": - if FRONTEND_CLIENT.mmcBlock.is_set(): - part = request.form.get('part') - if part: - part = json_loads(part) - return FRONTEND_CLIENT.mmcBlock.get_satellites(part=part) - else: - return '' - - -@app.route("/explainer", methods=['GET', 'POST']) -def explainer(): - # session.clear() - if request.method == 'POST': - do = request.form.get('do') - - print(f"explainer.do: do={do}, params={request.form}") - - if do in ["run", "stop"]: - return FRONTEND_CLIENT.erBlock.do(do, request.form) - - elif do == 'index': - return FRONTEND_CLIENT.elBlock.get_index() - - # elif do == "save": - # return FRONTEND_CLIENT.save_explanation() - - else: - raise WebInterfaceError(f"Unknown 'do' command {do} for explainer") - - -if __name__ == '__main__': - logging.getLogger().setLevel(logging.INFO) - - # TODO think about multiple instances of a client - - # TODO switch to 'run' in production - # In production mode the eventlet web server is used if available, - # else the gevent web server is used. If eventlet and gevent are not installed, - # the Werkzeug development web server is used. 
- # app.run(debug=True, port=4568) - - # TODO Flask development web server is used, use eventlet or gevent, - # see https://flask-socketio.readthedocs.io/en/latest/deployment.html - socketio.run(app, host='0.0.0.0', debug=True, port=4567, - allow_unsafe_werkzeug=True) diff --git a/web_interface/main_multi.py b/web_interface/main_multi.py new file mode 100644 index 0000000..e73cdb0 --- /dev/null +++ b/web_interface/main_multi.py @@ -0,0 +1,313 @@ +import json +import logging +import multiprocessing +from multiprocessing import Pipe +from multiprocessing.connection import Connection + +from flask import Flask, render_template, request +from flask_socketio import SocketIO, emit +import uuid + +from aux.data_info import DataInfo +from web_interface.back_front.frontend_client import FrontendClient +from web_interface.back_front.utils import WebInterfaceError, json_dumps, json_loads + +app = Flask(__name__) +app.config['SECRET_KEY'] = '57916211bb0b13ce0c676dfde280ba245' +## Need to run redis server: sudo apt install redis-server +socketio = SocketIO(app, async_mode='threading', message_queue='redis://', cors_allowed_origins="*") + +# Store active sessions +active_sessions = {} # {session Id -> sid, process, conn} + + +def worker_process( + process_id: str, + conn: Connection, + sid: str +) -> None: + print(f"Process {process_id} started") + # TODO problem is each process sends data to main process then to frontend. + # Easier to send it directly to url + + client = FrontendClient(sid) + # client.socket.socket.send('hello from subprocess') + + # import time + # from threading import Thread + # + # def report(process_id): + # while True: + # print(f"Process {process_id} is working...") + # time.sleep(1) + # + # Thread(target=report, args=(process_id,)).start() + + while True: + command = conn.recv() # This blocks until a command is received + type = command.get('type') + args = command.get('args') + print(f"Received command: {type} with args: {args}") + + if type == "dataset": + get = args.get('get') + set = args.get('set') + part = args.get('part') + if part: + part = json_loads(part) + + # # FIXME tmp + # + # from web_interface.back_front.communication import SocketConnect, WebInterfaceError + # socket = SocketConnect(socket=socketio) + # for i in range(300): + # print('sending', i, 'big') + # socket.send(i, 1000000 * "x") + # # print('sending', i, 'small') + # # socket.send(i, "small") + # sleep(0.5/25) + + if set == "visible_part": + result = client.dcBlock.set_visible_part(part=part) + + elif get == "data": + dataset_data = client.dcBlock.get_dataset_data(part=part) + data = json.dumps(dataset_data) + logging.info(f"Length of dataset_data: {len(data)}") + result = data + + elif get == "var_data": + if not client.dvcBlock.is_set(): + result = '' + else: + dataset_var_data = client.dvcBlock.get_dataset_var_data(part=part) + data = json.dumps(dataset_var_data) + logging.info(f"Length of dataset_var_data: {len(data)}") + result = data + + elif get == "stat": + stat = args.get('stat') + result = json_dumps(client.dcBlock.get_stat(stat)) + + elif get == "index": + result = client.dcBlock.get_index() + + else: + raise WebInterfaceError(f"Unknown 'part' command {get} for dataset") + + conn.send(result) + + elif type == "block": + block = args.get('block') + func = args.get('func') + params = args.get('params') + if params: + params = json_loads(params) + print(f"request_block: block={block}, func={func}, params={params}") + # TODO what if raise exception? 
process will stop + client.request_block(block, func, params) + # conn.send('{}') + + elif type == "model": + do = args.get('do') + get = args.get('get') + + if do: + print(f"model.do: do={do}, params={args}") + if do == 'index': + type = args.get('type') + if type == "saved": + result = client.mloadBlock.get_index() + elif type == "custom": + result = client.mcustomBlock.get_index() + else: + result = client.mtBlock.do(do, args) + + if get: + if get == "satellites": + if client.mmcBlock.is_set(): + part = args.get('part') + if part: + part = json_loads(part) + result = client.mmcBlock.get_satellites(part=part) + else: + result = '' + + assert result is not None + conn.send(result) + + elif type == "explainer": + do = args.get('do') + + print(f"explainer.do: do={do}, params={args}") + + if do in ["run", "stop"]: + result = client.erBlock.do(do, args) + + elif do == 'index': + result = client.elBlock.get_index() + + # elif do == "save": + # return client.save_explanation() + + else: + raise WebInterfaceError(f"Unknown 'do' command {do} for explainer") + + conn.send(result) + + elif type == "EXIT": + break + + print(f"Process {process_id} received STOP command") + # client.drop() + + +@socketio.on('connect') +def handle_connect( +) -> None: + # FIXME create process not when socket connects but when a new tab is open + session_id = str(uuid.uuid4()) + print('handle_connect', session_id, request.sid) + emit('session_id', {'session_id': session_id}) + + # Create a couple of connections + parent_conn, child_conn = Pipe() + + # Start the worker process + process = multiprocessing.Process(target=worker_process, + args=(session_id, child_conn, request.sid)) + active_sessions[session_id] = request.sid, process, parent_conn + process.start() + + +@socketio.on('disconnect') +def handle_disconnect( +) -> None: + print('handle_disconnect from some websocket') + + for session_id, (sid, process, parent_conn) in active_sessions.items(): + if sid == request.sid: + # print(f"Disconnected: {session_id}") + stop_session(session_id) + break + + +@app.route('/') +def home( +) -> str: + # FIXME ? 
+    DataInfo.refresh_all_data_info()
+    return render_template('analysis.html')
+
+
+@app.route('/analysis')
+def analysis(
+) -> str:
+    return render_template('analysis.html')
+
+
+@app.route('/interpretation')
+def interpretation(
+) -> str:
+    return render_template('interpretation.html')
+
+
+@app.route('/defense')
+def defense(
+) -> str:
+    return render_template('defense.html')
+
+
+@app.route("/drop", methods=['GET', 'POST'])
+def drop(
+) -> str:
+    if request.method == 'POST':
+        session_id = json.loads(request.data)['sessionId']
+        if session_id in active_sessions:
+            stop_session(session_id)
+    return ''
+
+
+def stop_session(
+        session_id: str
+):
+    _, process, conn = active_sessions[session_id]
+
+    # Stop corresponding process
+    try:
+        # Send exit command: "EXIT" is the type the worker_process loop breaks on
+        conn.send({'type': "EXIT"})
+    except Exception as e:
+        print('exception:', e)
+
+    # Wait for the process to terminate
+    process.join(timeout=1)
+
+    # If the process is still alive, terminate it
+    if process.is_alive():
+        print(f"Forcefully terminating process {session_id}")
+        process.terminate()
+        process.join(timeout=1)
+
+    del active_sessions[session_id]
+
+
+@app.route("/ask", methods=['GET', 'POST'])
+def storage(
+) -> str:
+    if request.method == 'POST':
+        session_id = request.form.get('sessionId')
+        assert session_id in active_sessions
+        print('ask request from', session_id)
+        ask = request.form.get('ask')
+
+        if ask == "parameters":
+            type = request.form.get('type')
+            return json_dumps(FrontendClient.get_parameters(type))
+
+        else:
+            raise WebInterfaceError(f"Unknown 'ask' command {ask}")
+
+
+@app.route("/block", methods=['GET', 'POST'])
+def block(
+) -> str:
+    if request.method == 'POST':
+        session_id = request.form.get('sessionId')
+        assert session_id in active_sessions
+        print('block request from', session_id)
+        _, process, conn = active_sessions[session_id]
+
+        conn.send({'type': 'block', 'args': request.form})
+    return '{}'
+
+
+@app.route("/<url>", methods=['GET', 'POST'])
+def url(
+        url: str
+) -> str:
+    assert url in ['dataset', 'model', 'explainer']
+    if request.method == 'POST':
+        sid = request.form.get('sessionId')
+        _, process, conn = active_sessions[sid]
+        print(url, 'request from', sid)
+
+        conn.send({'type': url, 'args': request.form})
+        return conn.recv()
+
+
+if __name__ == '__main__':
+    # print(f"Async mode is: {socketio.async_mode}")
+    socketio.run(app, debug=True, allow_unsafe_werkzeug=True)
+
+    # TODO switch to 'run' in production
+    # In production mode the eventlet web server is used if available,
+    # else the gevent web server is used. If eventlet and gevent are not installed,
+    # the Werkzeug development web server is used.
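+    # For illustration (an assumption, not part of this patch): if eventlet is
+    # installed and async_mode is left at its default instead of 'threading',
+    #     socketio = SocketIO(app, message_queue='redis://')
+    #     socketio.run(app, host='0.0.0.0', port=4567)
+    # would pick the eventlet server automatically.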
+ # app.run(debug=True, port=4568) + + # TODO Flask development web server is used, use eventlet or gevent, + # see https://flask-socketio.readthedocs.io/en/latest/deployment.html + # socketio.run(app, host='0.0.0.0', debug=True, port=4567, + # allow_unsafe_werkzeug=True) diff --git a/web_interface/static/css/dropdown-menu.css b/web_interface/static/css/dropdown-menu.css new file mode 100644 index 0000000..b7842c6 --- /dev/null +++ b/web_interface/static/css/dropdown-menu.css @@ -0,0 +1,58 @@ +.dropdown-menu-button { + padding: 10px; + background-color: #f0f0f0; + border: none; + cursor: pointer; + display: flex; + align-items: center; +} +.dropdown-menu-button svg { + width: 20px; + height: 20px; +} +.dropdown-menu { + display: none; + position: absolute; + background-color: #f9f9f9; + min-width: 160px; + box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2); + z-index: 1; +} +.dropdownmenuitem { + color: black; + padding: 12px 16px; + text-decoration: none; + user-select: none; + display: block; +} +.dropdownmenuitem:hover { + background-color: #f1f1f1; +} +.dropdownmenuitem.has-submenu { + position: relative; + padding-right: 24px; /* Make room for the caret */ +} +.dropdownmenuitem.has-submenu::after { + content: ''; + position: absolute; + right: 10px; + top: 50%; + transform: translateY(-50%); + width: 0; + height: 0; + border-left: 5px solid black; + border-top: 5px solid transparent; + border-bottom: 5px solid transparent; +} +.dropdownmenuitem.has-submenu:hover::after { + border-left-color: #555; /* Change color on hover if desired */ +} +.dropdown-submenu { + display: none; + position: absolute; + left: 100%; + top: 0; + background-color: #f9f9f9; + min-width: 160px; + box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2); +} diff --git a/web_interface/static/js/controllers/controller.js b/web_interface/static/js/controllers/controller.js index c212b9c..4bdbb2d 100644 --- a/web_interface/static/js/controllers/controller.js +++ b/web_interface/static/js/controllers/controller.js @@ -1,10 +1,32 @@ class Controller { + static sessionId // ID of session + static isActive = false // this controller was started + constructor() { this.presenter = new Presenter() // Setup socket connection this.socket = io() - this.socket.on('connect', () => console.log('socket connected')) + this.socket.on('connect', () => { + console.log('socket connected') + if (Controller.isActive) { + // Means re-connection to server. Need to reload the page + alert("This session is outdated. 
Press OK to reload the page.") + Controller.isActive = false + window.location.reload(true) + } + }) + this.socket.on('session_id', (data) => { + Controller.sessionId = data["session_id"] + Controller.isActive = true + console.log('session_id', Controller.sessionId) + this.run() + }) + this.socket.on('disconnect', () => { + // Controller.isActive = false + console.log('Disconnected from server'); + }); + this.socket.on('message', async (data) => { // Message to block listeners let msg = JSON_parse(data["msg"]) @@ -72,14 +94,14 @@ class Controller { static async ajaxRequest(url, data) { let result = null + console.assert(!('sessionId' in data)) + data['sessionId'] = Controller.sessionId + // console.log('ajaxRequest', data) await $.ajax({ type: 'POST', url: url, data: data, success: (res, status, jqXHR) => { - // console.log('Result: ' + res) - // console.log('status: ' + status) - // console.log('jqXHR: ' + jqXHR) result = res } }) diff --git a/web_interface/static/js/paramsBuilder.js b/web_interface/static/js/paramsBuilder.js index 1ccb1c7..7a15f12 100644 --- a/web_interface/static/js/paramsBuilder.js +++ b/web_interface/static/js/paramsBuilder.js @@ -11,18 +11,11 @@ class ParamsBuilder { if (type in ParamsBuilder.cachedParams) return ParamsBuilder.cachedParams[type] - let params = null - await $.ajax({ - type: 'POST', - url: '/ask', - data: { + let params = await Controller.ajaxRequest('/ask', { ask: "parameters", type: type, - }, - success: (parameters) => { - params = ParamsBuilder.cachedParams[type] = JSON_parse(parameters) - } - }) + }) + ParamsBuilder.cachedParams[type] = params return params } @@ -159,6 +152,7 @@ class ParamsBuilder { $input.css("min-width", "60px") delete possible.special } + if (type === "int") { $input.attr("step", 1) $input.attr("pattern", "\d+") @@ -173,6 +167,9 @@ class ParamsBuilder { } for (const [key, value] of Object.entries(possible)) $input.attr(key, value) + + // Check input value when user unfocus it or change it + addValueChecker($input, type, def, possible["min"], possible["max"], "change") } else if (type === "string") { diff --git a/web_interface/static/js/presentation/left_menu/model/builder/layer.js b/web_interface/static/js/presentation/left_menu/model/builder/layer.js index eb5a23c..9e8855e 100644 --- a/web_interface/static/js/presentation/left_menu/model/builder/layer.js +++ b/web_interface/static/js/presentation/left_menu/model/builder/layer.js @@ -237,7 +237,9 @@ class LayerBlock { id = "menu-model-constructor-dropout-probability-" this.$dropoutProbInput = $("").attr("id", id).attr("type", "number") - .attr("min", 0).attr("max", 1).attr("step", 0.01).val(0.5) + .attr("min", 0).attr("max", 0.999).attr("step", 0.01).val(0.5) + addValueChecker(this.$dropoutProbInput, "float", 0.5, 0, 0.999, "change") + $dropoutParamsDiv.append($("").text("probability").attr("for", id)) $dropoutParamsDiv.append(this.$dropoutProbInput) @@ -338,6 +340,7 @@ class LayerBlock { .css("background-color", LayerBlock.linearColor) this.$linearOutputSizeInput = $("").attr("id", id).attr("type", "number") .attr("min", 1).attr("step", 1).val(LINEAR_LAYER_OUTPUT_SIZE) + addValueChecker(this.$linearOutputSizeInput, "int", 1, 1) $cb.append(this.$linearOutputSizeInput) } else if (type === "gin") { @@ -364,6 +367,7 @@ class LayerBlock { .css("background-color", LayerBlock.linearColor) this.$linearOutputSizeInput = $("").attr("id", id).attr("type", "number") .attr("min", 1).attr("step", 1).val(LINEAR_LAYER_OUTPUT_SIZE) + addValueChecker(this.$linearOutputSizeInput, "int", 
LINEAR_LAYER_OUTPUT_SIZE, 1) $cb.append(this.$linearOutputSizeInput) // Not showing - because prot layer is the last and outSize = num_classes $cb.hide() diff --git a/web_interface/static/js/presentation/left_menu/model/builder/sequential.js b/web_interface/static/js/presentation/left_menu/model/builder/sequential.js index eda6933..d70c333 100644 --- a/web_interface/static/js/presentation/left_menu/model/builder/sequential.js +++ b/web_interface/static/js/presentation/left_menu/model/builder/sequential.js @@ -42,6 +42,7 @@ class SequentialLayer { .css("background-color", LayerBlock.linearColor) let $linearOutputSizeInput = $("").attr("id", id).attr("type", "number") .attr("min", 1).attr("step", 1).val(LINEAR_LAYER_OUTPUT_SIZE) + addValueChecker($linearOutputSizeInput, "int", LINEAR_LAYER_OUTPUT_SIZE, 1) $cb.append($linearOutputSizeInput) this.linearInputs.push($linearOutputSizeInput) diff --git a/web_interface/static/js/presentation/left_menu/model/menuModelManagerView.js b/web_interface/static/js/presentation/left_menu/model/menuModelManagerView.js index 0373be9..1ecc1e4 100644 --- a/web_interface/static/js/presentation/left_menu/model/menuModelManagerView.js +++ b/web_interface/static/js/presentation/left_menu/model/menuModelManagerView.js @@ -84,20 +84,28 @@ class MenuModelManagerView extends MenuView { $cb.append($("").text("Train/validation/test ratio")) let $div3 = $("
").css("display", "flex") $cb.append($div3) - this.$trainRatioInput = $("").attr("type", "number").attr("min", "0") - .attr("max", "1").attr("step", "0.01").attr("value", "0.6") + this.$trainRatioInput = $("").attr("type", "number") + .attr("min", "0").attr("max", "1") + .attr("step", "0.01").attr("value", "0.6") + addValueChecker(this.$trainRatioInput, "float", 0.6, 0, 0.9999, "change") $div3.append(this.$trainRatioInput) - this.$valRatioInput = $("").attr("type", "number").attr("min", "0") - .attr("max", "0.99").attr("step", "0.01").attr("value", "0") + this.$valRatioInput = $("").attr("type", "number") + .attr("min", "0").attr("max", "0.99") + .attr("step", "0.01").attr("value", "0") + addValueChecker(this.$valRatioInput, "float", 0, 0, 0.9999, "change") $div3.append(this.$valRatioInput) - this.$testRatioInput = $("").attr("type", "number").attr("min", "0") - .attr("max", "1").attr("step", "0.01").attr("value", "0.4") + this.$testRatioInput = $("").attr("type", "number") + .attr("min", "0").attr("max", "1") + .attr("step", "0.01").attr("value", "0.4") + addValueChecker(this.$testRatioInput, "float", 0.4, 0, 0.9999, "change") $div3.append(this.$testRatioInput) // Balance between 3 fields this.$trainRatioInput.change((e) => { // val, test + if (isNaN(e.target.valueAsNumber)) // incorrect value, wait for checker + return let val = 1 - e.target.valueAsNumber - this.$testRatioInput.val() if (val > 0) this.$valRatioInput.val(Math.round(val*1e4)/1e4) @@ -108,6 +116,8 @@ class MenuModelManagerView extends MenuView { } }) this.$valRatioInput.change((e) => { // train, test + if (isNaN(e.target.valueAsNumber)) // incorrect value, wait for checker + return let val = 1 - e.target.valueAsNumber - this.$testRatioInput.val() if (val > 0) this.$trainRatioInput.val(Math.round(val*1e4)/1e4) @@ -118,6 +128,8 @@ class MenuModelManagerView extends MenuView { } }) this.$testRatioInput.change((e) => { // val, train + if (isNaN(e.target.valueAsNumber)) // incorrect value, wait for checker + return let val = 1 - e.target.valueAsNumber - this.$trainRatioInput.val() if (val > 0) this.$valRatioInput.val(Math.round(val*1e4)/1e4) diff --git a/web_interface/static/js/presentation/left_menu/model/menuModelTrainerView.js b/web_interface/static/js/presentation/left_menu/model/menuModelTrainerView.js index ed67b9d..5aa65a1 100644 --- a/web_interface/static/js/presentation/left_menu/model/menuModelTrainerView.js +++ b/web_interface/static/js/presentation/left_menu/model/menuModelTrainerView.js @@ -142,8 +142,11 @@ class MenuModelTrainerView extends MenuView { $cb = $("
").attr("class", "control-block") this.$mainDiv.append($cb) $cb.append($("").text("Epochs to train")) // TODO assoc to select ? - this.$epochsInput = $("").attr("type", "number").attr("min", "1") - .attr("max", "3000").attr("step", "1").attr("value", "10") + this.$epochsInput = $("").attr("type", "number") + // .attr("min", "1").attr("max", "3000") + .attr("step", "1").attr("value", "10") + addValueChecker(this.$epochsInput, "natural", 10, 1, 3000, "change") + $cb.append(this.$epochsInput) // for (const metric of ["Accuracy", "BalancedAccuracy", "Precision", "Recall", "F1", "Jaccard"]) { diff --git a/web_interface/static/js/presentation/left_menu/visualsView.js b/web_interface/static/js/presentation/left_menu/visualsView.js index 81f7474..614ada8 100644 --- a/web_interface/static/js/presentation/left_menu/visualsView.js +++ b/web_interface/static/js/presentation/left_menu/visualsView.js @@ -98,6 +98,7 @@ class VisualsView extends View { let $graphInput = $("").attr("type", "number") .attr("min", "0").attr("step", "1") .attr("id", this.idPrefix + '-' + this.multiGraphId) + addValueChecker($graphInput, "int", 0, 0, null, "change") $cb.append($graphInput) $graphInput.change((e) => this._update( this.multiGraphId, $graphInput.val(), true, $graphInput)) @@ -116,6 +117,7 @@ class VisualsView extends View { this.setEnabled(this.multiArrangeId, val > 0) this._update(this.multiDepthId, val, true, $depthInput) }) + addValueChecker($depthInput, "int", 0, 0, null, "change") $cb = $("
").attr("class", "control-block") this.$multiDiv.append($cb) @@ -179,18 +181,20 @@ class VisualsView extends View { $cb.append($nodeInput) $nodeInput.change(async (e) => await this._update( this.singleNeighNodeId, $nodeInput.val(), true, $nodeInput)) + addValueChecker($nodeInput, "int", 0, 0, null, "change") $cb = $("
").attr("class", "control-block") this.$neighborhoodDiv.append($cb) $label = $("").text("Neighborhood depth") $cb.append($label) - let $depthSelect = $("").attr("type", "number") + let $depthInput = $("").attr("type", "number") .attr("min", "0").attr("max", Neighborhood.MAX_DEPTH) .attr("step", "1") .attr("id", this.idPrefix + '-' + this.singleNeighDepthId) - $cb.append($depthSelect) - $depthSelect.change((e) => this._update( - this.singleNeighDepthId, $depthSelect.val(), true, $depthSelect)) + $cb.append($depthInput) + $depthInput.change((e) => this._update( + this.singleNeighDepthId, $depthInput.val(), true, $depthInput)) + addValueChecker($depthInput, "int", 0, 0, null, "change") $cb = $("
").attr("class", "control-block") this.$neighborhoodDiv.append($cb) @@ -432,12 +436,12 @@ class VisualsView extends View { // [this.showModeId, 'whole-graph'], [this.showModeId, 'neighborhood'], [this.singleGraphLayoutId, 'random'], - [this.singleNeighLayoutId, 'force'], + [this.singleNeighLayoutId, 'random'], [this.singleNeighNodeId, 0], - [this.singleNeighDepthId, 2], + [this.singleNeighDepthId, 1], [this.singleClassAsColorId, true], [this.multiNodeTypeAsColorId, true], - [this.multiLayoutId, 'force'], + [this.multiLayoutId, 'random'], [this.multiGraphId, 0], [this.multiDepthId, 0], [this.multiCountId, "several"], diff --git a/web_interface/static/js/presentation/right_panel/model/panelModelArchView.js b/web_interface/static/js/presentation/right_panel/model/panelModelArchView.js index 687eea2..ae50c50 100644 --- a/web_interface/static/js/presentation/right_panel/model/panelModelArchView.js +++ b/web_interface/static/js/presentation/right_panel/model/panelModelArchView.js @@ -101,9 +101,9 @@ class PanelModelArchView extends PanelView { // Init all SVG primitives this.primitives = {} - // Function drawing an parameters data object let marginText = 5 let marginBlocks = 140 // TODO make it = rightmost bound of all texts + // Function drawing an parameters data object let draw = (kv, primitives, offsets, keysList=[], depth=0) => { for (let [key, value] of Object.entries(kv)) { let text @@ -173,6 +173,15 @@ class PanelModelArchView extends PanelView { primitives[key] = text offsets[1] += 25 } + else if (typeof(value) === 'string') { // String, indicates tensor size is over limit + this.svgPanel.$svg.append(Svg.text( + `[Tensor ${value}]`, + marginBlocks + 12*depth, offsets[1], + 'middle', '20px', + 'normal', "#000000" + )) + offsets[1] += this.size + 20 + } else if (value.constructor === Array) { let arrayPrimitives = [] primitives[key] = arrayPrimitives @@ -248,7 +257,7 @@ class PanelModelArchView extends PanelView { } } } - else console.error("Unknown type") + else console.error("Model data contains unknown data type.") } } diff --git a/web_interface/static/js/presentation/right_panel/panelDatasetView.js b/web_interface/static/js/presentation/right_panel/panelDatasetView.js index cea837d..920df09 100644 --- a/web_interface/static/js/presentation/right_panel/panelDatasetView.js +++ b/web_interface/static/js/presentation/right_panel/panelDatasetView.js @@ -1,11 +1,15 @@ class PanelDatasetView extends PanelView { constructor($div, requestBlock, listenBlocks) { super($div, requestBlock, listenBlocks) + this.$infoDiv = null + this.$statsDiv = null + this.$varStatsDiv = null this.init() // Variables this.datasetInfo = null + this.labeling = null } init() { @@ -21,108 +25,124 @@ class PanelDatasetView extends PanelView { } } + onUnlock(block) { + super.onUnlock(block) + if (block === "dvc") { + this.labeling = null + this.$varStatsDiv.empty() + } + } + onSubmit(block, data) { - // Do nothing + // No super call + if (block === "dvc") { + this.labeling = data[0] + this.updateVar() + } } - addNumericStat(name, stat, fracFlag) { + /// Called at each submit + addNumericStat($whereDiv, name, stat, fracFlag) { let $div = $("
") - this.$body.append($div) + $whereDiv.append($div) let $button = $("").text("get") $div.append(name + ': ') $div.append($button) $button.click(async () => { $button.prop("disabled", true) let res = await Controller.ajaxRequest('/dataset', {get: "stat", stat: stat}) + // console.log(res) $div.empty() - $div.append(name + ': ' + (fracFlag ? parseFloat(res).toFixed(4) : res)) + if (res.constructor === Object) { // dict + let str = "" + for (let [k, v] of Object.entries(res)) { + str += '
' + k + ': ' + (fracFlag ? parseFloat(v).toFixed(4) : v) + '
' + } + $div.append(name + ': ' + str) + } + else + $div.append(name + ': ' + (fracFlag ? parseFloat(res).toFixed(4) : res)) }) } - plotDistribution(name, st, txt, lbl, oX, oY) { + plotDistribution($div, name, st, lbl, oX, oY, dictFlag) { let $ddDiv = $("
") - this.$body.append($ddDiv) + $div.append($ddDiv) let $button = $("").text("get") $ddDiv.append(name + ': ') $ddDiv.append($button) $button.click(async () => { $button.prop("disabled", true) - await $.ajax({ - type: 'POST', - url: '/dataset', - data: { - get: "stat", - stat: st, - }, - success: function (data) { - data = JSON_parse(data) - // console.log(data) - $ddDiv.empty() - let scale = 'linear' - let type = 'bar' - if (Object.keys(data).length > 20) { - scale = 'logarithmic' - type = 'scatter' - delete data[0] - } - let $canvas = $("").css("height", "300px") - $ddDiv.append($canvas) - const ctx = $canvas[0].getContext('2d') - new Chart(ctx, { - type: type, - data: { - datasets: [{ - label: lbl, - data: data, - backgroundColor: 'rgb(52, 132, 246, 0.6)', - // borderColor: borderColor, - borderWidth: 1, - barPercentage: 1, - categoryPercentage: 1, - borderRadius: 0, - }] - }, - options: { - // responsive: false, - // maintainAspectRatio: true, - // aspectRatio: 3, - scales: { - x: { - type: scale, - beginAtZero: false, - // offset: false, - // grid: { - // offset: false - // }, - ticks: {stepSize: 1}, - title: { - display: true, - text: oX, - font: {size: 14} - } - }, - y: { - type: scale, - suggestedMin: 1, - title: { - display: true, - text: oY, - font: {size: 14} - } + let data = await Controller.ajaxRequest('/dataset', {get: "stat",stat: st}) + // console.log(data) + $ddDiv.empty() + if (!dictFlag) + data = {"": data} + for (let [k, v] of Object.entries(data)) { + let scale = 'linear' + let type = 'bar' + if (Object.keys(v).length > 20) { + scale = 'logarithmic' + type = 'scatter' + delete v[0] + } + let $canvas = $("").css("height", "300px") + $ddDiv.append($canvas) + const ctx = $canvas[0].getContext('2d') + new Chart(ctx, { + type: type, + data: { + datasets: [{ + label: lbl, + data: v, + backgroundColor: 'rgb(52, 132, 246, 0.6)', + // borderColor: borderColor, + borderWidth: 1, + barPercentage: 1, + categoryPercentage: 1, + borderRadius: 0, + }] + }, + options: { + // responsive: false, + // maintainAspectRatio: true, + // aspectRatio: 3, + scales: { + x: { + type: scale, + beginAtZero: false, + // offset: false, + // grid: { + // offset: false + // }, + ticks: {stepSize: 1}, + title: { + display: true, + text: oX, + font: {size: 14} } }, - plugins: { + y: { + type: scale, + suggestedMin: 1, title: { display: true, - text: name, - font: {size: 16} - }, - legend: {display: false}, + text: oY, + font: {size: 14} + } } + }, + plugins: { + title: { + display: true, + text: k + ' ' + name, + font: {size: 16} + }, + legend: {display: false}, } - }) - } - }) + } + }) + } }) } @@ -161,131 +181,134 @@ class PanelDatasetView extends PanelView { // Update a dataset info panel update() { this._collapse(false) - // this.updateArgs = arguments - // if (this.collapsed) { - // return - // } - // if (dataset === this.dataset) return - // this.dataset = dataset + this.$infoDiv = $("
") + this.$statsDiv = $("
") + this.$varStatsDiv = $("
") this.$body.empty() + this.$body.append(this.$infoDiv) + this.$body.append(this.$statsDiv) + this.$body.append(this.$varStatsDiv) if (this.datasetInfo == null) { - this.$body.append('No dataset specified') + this.$infoDiv.append('No dataset specified') return } // Info let html = 'Info' html += '
<br>' + this.getInfo()
-        this.$body.append(html)
+        this.$infoDiv.append(html)

         // Stats
         let multi = this.datasetInfo.count > 1
-        this.$body.append('Statistics<br>
')
+        this.$statsDiv.append('Degree statistics<br>
')
         if (multi) {
-            this.$body.append('Graphs: ' + this.datasetInfo.count + '<br>
')
-            this.$body.append('Nodes: ' + Math.min(...this.datasetInfo.nodes)
+            this.$statsDiv.append('Graphs: ' + this.datasetInfo.count + '<br>
')
+            this.$statsDiv.append('Nodes: ' + Math.min(...this.datasetInfo.nodes)
+                + ' — ' + Math.max(...this.datasetInfo.nodes) + '<br>
')
         } else {
-            this.$body.append('Nodes: ' + this.datasetInfo.nodes[0] + '<br>
')
+            this.$statsDiv.append('Nodes: ' + this.datasetInfo.nodes[0] + '<br>
')
+            this.addNumericStat(this.$statsDiv, "Edges", "num_edges", false)
+            this.addNumericStat(this.$statsDiv, "Average degree", "avg_degree", true)
         }
-        this.addNumericStat("Edges", "num_edges", false)
-        this.addNumericStat("Average degree", "avg_deg", true)

         if (!multi) {
-            this.addNumericStat("Clustering", "CC", true)
-            this.addNumericStat("Triangles", "triangles", false)
-            this.addNumericStat("Diameter", "diameter", false)
-            this.addNumericStat("Number of connected components", "cc", false)
-            this.addNumericStat("Largest connected component size", "lcc", false)
-            this.addNumericStat("Degree assortativity", "degree_assortativity", true)
+            this.addNumericStat(this.$statsDiv, "Clustering", "clustering_coeff", true)
+            this.addNumericStat(this.$statsDiv, "Triangles", "num_triangles", false)
+            this.addNumericStat(this.$statsDiv, "Number of connected components", "num_cc", false)
+            this.addNumericStat(this.$statsDiv, "Giant connected component (GCC) size", "gcc_size", false)
+            this.addNumericStat(this.$statsDiv, "GCC relative size", "gcc_rel_size", true)
+            this.addNumericStat(this.$statsDiv, "GCC diameter", "gcc_diam", false)
+            // this.addNumericStat(this.$statsDiv, "GCC 90% effective diameter", "gcc_diam90", true)
+            this.addNumericStat(this.$statsDiv, "Degree assortativity", "degree_assort", true)
         }

         if (multi) {
-            this.plotDistribution('Distribution of number of nodes', 'num_nodes_distr', 'Maximum nodes: ', 'Number of graphs', 'Nodes',
-                'Number of graphs', multi, false)
-            this.plotDistribution('Distribution of average degree', 'avg_degree_distr', 'Highest average: ', 'Number of graphs',
-                'Average degree', 'Number of graphs', multi, false)
+            this.plotDistribution(this.$statsDiv, 'Distribution of number of nodes', 'num_nodes', 'Number of graphs', 'Nodes', 'Number of graphs', false)
+            this.plotDistribution(this.$statsDiv, 'Distribution of number of edges', 'num_edges', 'Number of graphs', 'Edges', 'Number of graphs', false)
+            // this.plotDistribution(this.$statsDiv, 'Distribution of average degree', 'avg_degree', 'Highest average: ', 'Number of graphs', 'Average degree', 'Number of graphs', false)
         } else {
-            this.plotDistribution(
-                'Degree distribution', 'DD', 'Maximum degree: ', 'Degree',
-                'Nodes', 'Degree', multi, true)
+            this.plotDistribution(this.$statsDiv, 'Degree distribution', 'degree_distr', 'Degree', 'Degree', 'Nodes', this.datasetInfo.directed)

             let name1 = 'Attributes assortativity'
             let $acDiv = $("
") - this.$body.append($acDiv) + this.$statsDiv.append($acDiv) let $button1 = $("").text("get") $acDiv.append(name1 + ': ') $acDiv.append($button1) $button1.click(async () => { $button1.prop("disabled", true) - await $.ajax({ - type: 'POST', - url: '/dataset', - data: { - get: "stat", - stat: "attr_corr", - }, - success: function (data) { - data = JSON_parse(data) - let attrs = data['attributes'] - let correlations = data['correlations'] - $acDiv.empty() - $acDiv.append(name1 + ':
')
+                let data = await Controller.ajaxRequest('/dataset', {get: "stat", stat: "attr_corr"})

+                let attrs = data['attributes']
+                let correlations = data['correlations']
+                $acDiv.empty()
+                $acDiv.append(name1 + ':<br>
') - // SVG with table - let count = attrs.length - let size = Math.min(30, Math.floor(300 / count)) - let svg = document.createElementNS("http://www.w3.org/2000/svg", "svg"); - let $svg = $(svg) - .css("background-color", "#e7e7e7") - // .css("flex-shrink", "0") - .css("margin", "5px") - .css("width", (count * size) + "px") - .css("height", (count * size) + "px") - $acDiv.append($svg) - for (let j = 0; j < count; j++) { - for (let i = 0; i < count; i++) { - let rect = document.createElementNS("http://www.w3.org/2000/svg", "rect") - rect.setAttribute('x', size * i) - rect.setAttribute('y', size * j) - rect.setAttribute('width', size) - rect.setAttribute('height', size) - let color = valueToColor(correlations[i][j], CORRELATION_COLORMAP, -1, 1) - _addTip(rect, `Corr[${attrs[i]}][${attrs[j]}]=` + correlations[i][j]) - rect.setAttribute('fill', color) - rect.setAttribute('stroke', '#e7e7e7') - rect.setAttribute('stroke-width', 1) - $svg.append(rect) - } - } + // Adds mouse listener for all elements which shows a tip with given text + let $tip = $("").addClass("tooltip-text") + $acDiv.append($tip) + let _addTip = (element, text) => { + element.onmousemove = (e) => { + $tip.show() + $tip.css("left", e.pageX + 10) + $tip.css("top", e.pageY + 15) + $tip.html(text) + } + element.onmouseout = (e) => { + $tip.hide() + } + } + // SVG with table + let count = attrs.length + let size = Math.min(30, Math.floor(300 / count)) + let svg = document.createElementNS("http://www.w3.org/2000/svg", "svg"); + let $svg = $(svg) + .css("background-color", "#e7e7e7") + // .css("flex-shrink", "0") + .css("margin", "5px") + .css("width", (count * size) + "px") + .css("height", (count * size) + "px") + $acDiv.append($svg) + for (let j = 0; j < count; j++) { + for (let i = 0; i < count; i++) { + let rect = document.createElementNS("http://www.w3.org/2000/svg", "rect") + rect.setAttribute('x', size * i) + rect.setAttribute('y', size * j) + rect.setAttribute('width', size) + rect.setAttribute('height', size) + let color = valueToColor(correlations[i][j], CORRELATION_COLORMAP, -1, 1) + _addTip(rect, `Corr[${attrs[i]}][${attrs[j]}]=` + correlations[i][j]) + rect.setAttribute('fill', color) + rect.setAttribute('stroke', '#e7e7e7') + rect.setAttribute('stroke-width', 1) + $svg.append(rect) } - }) + } }) } } + // Update a dataset info panel when Var data is known + updateVar() { + this._collapse(false) + this.$varStatsDiv.empty() + this.$varStatsDiv.append('Variable data statistics') + let multi = this.datasetInfo.count > 1 + + this.plotDistribution(this.$varStatsDiv, 'Labels distribution', 'label_distr', 'Items', 'Class', 'Count', false) + + if (!multi) + this.addNumericStat(this.$varStatsDiv,'Labels assortativity', 'label_assort', true) + } + break() { super.break() + this.datasetInfo = null this.$body.append("No Dataset selected") } } \ No newline at end of file diff --git a/web_interface/static/js/presentation/right_panel/panelView.js b/web_interface/static/js/presentation/right_panel/panelView.js index f1c32b2..a5aa254 100644 --- a/web_interface/static/js/presentation/right_panel/panelView.js +++ b/web_interface/static/js/presentation/right_panel/panelView.js @@ -64,7 +64,6 @@ class PanelView extends View { break() { if (this.$body) this.$body.empty() - this.datasetInfo = null this._collapse(true) } diff --git a/web_interface/static/js/utils.js b/web_interface/static/js/utils.js index d3d49c2..a2947b2 100644 --- a/web_interface/static/js/utils.js +++ b/web_interface/static/js/utils.js @@ -487,3 +487,29 @@ async function 
addOptionsWithParams(id, label, options, paramsType, paramsColor)
     return [$cb, $optionSelect, $paramsDiv, paramsBuilder]
 }

+/// Add a checker for the value a user enters into an input form
+function addValueChecker($elem, type, defaultValue, min=null, max=null, on="change") {
+    $elem[0].addEventListener(on, function (e) {
+        let typeCheck = false
+        if (this.value === "")
+            typeCheck = false
+        else if (type === "int")
+            typeCheck = /^-?\d+$/.test(this.value)
+        else if (type === "natural")
+            typeCheck = /^\d+$/.test(this.value)
+        else if (type === "float")
+            typeCheck = !isNaN(this.value)
+        else
+            console.error('Unknown type to be checked:', type)
+
+        if (!typeCheck) {
+            this.value = defaultValue
+            return
+        }
+
+        if (min !== null && this.value < min)
+            this.value = min
+        if (max !== null && this.value > max)
+            this.value = max
+    }, true) // true enables capture phase, so this handler runs before other listeners
+}
diff --git a/web_interface/templates/analysis.html b/web_interface/templates/analysis.html
new file mode 100644
index 0000000..0d08f94
--- /dev/null
+++ b/web_interface/templates/analysis.html
@@ -0,0 +1,18 @@
+{% extends "base.html" %}
+
+{% block title %}
+    GNN-AID - Analysis
+{% endblock %}
+
+{% block status %}
+    Analysis
+{% endblock %}
+
+{% block menu_content %}
+
+
+{% endblock %}
+
+{% block panel_content %}
+
+{% endblock %}
diff --git a/web_interface/templates/home.html b/web_interface/templates/base.html
similarity index 66%
rename from web_interface/templates/home.html
rename to web_interface/templates/base.html
index 098dd12..5990072 100644
--- a/web_interface/templates/home.html
+++ b/web_interface/templates/base.html
@@ -3,10 +3,11 @@
-    <title>GNN visualizer</title>
+    <title>{% block title %}{% endblock %}</title>
@@ -16,18 +17,38 @@
-    [static menu markup of home.html; HTML lost in extraction]
+    {% block menu_content %}{% endblock %}
-    [static panel markup of home.html, including the "Explainer panel" container; HTML lost in extraction]
+    {% block panel_content %}{% endblock %}
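A note on the per-session protocol introduced in main_multi.py above: each browser session gets its own worker process, and the Flask process talks to it over a multiprocessing Pipe, sending {'type': ..., 'args': ...} dicts and reading back one result per command; {'type': "EXIT"} ends the worker loop. Below is a minimal, self-contained sketch of that round trip. The worker here (fake_worker) is a hypothetical stand-in for worker_process and FrontendClient, shown only to make the protocol concrete.

    import multiprocessing
    from multiprocessing import Pipe
    from multiprocessing.connection import Connection


    def fake_worker(conn: Connection) -> None:
        # Stand-in for worker_process(): one command in, one result out
        while True:
            command = conn.recv()  # blocks until the parent sends a command
            if command.get('type') == "EXIT":
                break  # mirrors the worker loop's exit condition
            # A real worker would dispatch to FrontendClient blocks here
            conn.send({'echo': command.get('args')})


    if __name__ == '__main__':
        parent_conn, child_conn = Pipe()
        process = multiprocessing.Process(target=fake_worker, args=(child_conn,))
        process.start()

        # One request/response round trip, as the /dataset, /model, /explainer routes do
        parent_conn.send({'type': 'dataset', 'args': {'get': 'index'}})
        print(parent_conn.recv())  # -> {'echo': {'get': 'index'}}

        # Shutdown, as stop_session() does: ask politely, then force if needed
        parent_conn.send({'type': 'EXIT'})
        process.join(timeout=1)
        if process.is_alive():
            process.terminate()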