From c73763a9e939bff6caf043f57723fd2e06c94db4 Mon Sep 17 00:00:00 2001
From: "lukyanov_kirya@bk.ru"
Date: Thu, 21 Nov 2024 17:14:43 +0300
Subject: [PATCH] improve base modules: add type hints and reformat signatures

---
 src/base/custom_datasets.py     |  75 ++++++++---
 src/base/datasets_processing.py | 220 ++++++++++++++++++++++++--------
 src/base/ptg_datasets.py        |  67 +++++++---
 src/base/vk_datasets.py         |  51 ++++++--
 4 files changed, 319 insertions(+), 94 deletions(-)

diff --git a/src/base/custom_datasets.py b/src/base/custom_datasets.py
index 912835c..2be1bf5 100644
--- a/src/base/custom_datasets.py
+++ b/src/base/custom_datasets.py
@@ -1,20 +1,27 @@
 import json
 import os
 from pathlib import Path
+from typing import Union, Iterator
+
 import numpy as np
 import torch
 from torch_geometric.data import Data, InMemoryDataset
 
 from aux.declaration import Declare
 from base.datasets_processing import GeneralDataset, DatasetInfo
-from aux.configs import DatasetConfig, DatasetVarConfig
+from aux.configs import DatasetConfig, DatasetVarConfig, ConfigPattern
 from base.ptg_datasets import LocalDataset
 
 
-class CustomDataset(GeneralDataset):
+class CustomDataset(
+    GeneralDataset
+):
     """ User-defined dataset in 'ij' format.
     """
-    def __init__(self, dataset_config: DatasetConfig):
+    def __init__(
+            self,
+            dataset_config: Union[ConfigPattern, DatasetConfig]
+    ):
         """
         Args:
             dataset_config: DatasetConfig dict from frontend
@@ -27,31 +34,44 @@ def __init__(self, dataset_config: DatasetConfig):
         self.edge_index = None
 
     @property
-    def node_attributes_dir(self):
+    def node_attributes_dir(
+            self
+    ):
         """ Path to dir with node attributes. """
         return self.root_dir / 'raw' / (self.name + '.node_attributes')
 
     @property
-    def edge_attributes_dir(self):
+    def edge_attributes_dir(
+            self
+    ):
         """ Path to dir with edge attributes. """
         return self.root_dir / 'raw' / (self.name + '.edge_attributes')
 
     @property
-    def labels_dir(self):
+    def labels_dir(
+            self
+    ):
         """ Path to dir with labels. """
         return self.root_dir / 'raw' / (self.name + '.labels')
 
     @property
-    def edges_path(self):
+    def edges_path(
+            self
+    ):
         """ Path to file with edge list. """
         return self.root_dir / 'raw' / (self.name + '.ij')
 
     @property
-    def edge_index_path(self):
-        """ Path to dir with labels. """
+    def edge_index_path(
+            self
+    ):
+        """ Path to file with edge index. """
         return self.root_dir / 'raw' / (self.name + '.edge_index')
 
-    def build(self, dataset_var_config: DatasetVarConfig):
+    def build(
+            self,
+            dataset_var_config: Union[ConfigPattern, DatasetVarConfig]
+    ) -> None:
         """ Build ptg dataset based on dataset_var_config and create DatasetVarData.
""" if dataset_var_config == self.dataset_var_config: @@ -62,7 +82,10 @@ def build(self, dataset_var_config: DatasetVarConfig): self.dataset_var_config = dataset_var_config self.dataset = LocalDataset(self.results_dir, process_func=self._create_ptg) - def _compute_stat(self, stat): + def _compute_stat( + self, + stat: str + ) -> dict: """ Compute some additional stats """ if stat == "attr_corr": @@ -123,7 +146,9 @@ def _compute_stat(self, stat): else: return super()._compute_stat(stat) - def _compute_dataset_data(self): + def _compute_dataset_data( + self + ) -> None: """ Get DatasetData for debug graph Structure according to https://docs.google.com/spreadsheets/d/1fNI3sneeGoOFyIZP_spEjjD-7JX2jNl_P8CQrA4HZiI/edit#gid=1096434224 """ @@ -272,7 +297,9 @@ def _compute_dataset_data(self): # if self.info.name == "": # self.dataset_data['info']['name'] = '/'.join(self.dataset_config.full_name()) - def _create_ptg(self): + def _create_ptg( + self + ) -> None: """ Create PTG Dataset and save tensors """ if self.edge_index is None: @@ -295,7 +322,10 @@ def _create_ptg(self): self.results_dir.mkdir(exist_ok=True, parents=True) torch.save(InMemoryDataset.collate(data_list), self.results_dir / 'data.pt') - def _iter_nodes(self, graph: int = None): + def _iter_nodes( + self, + graph: int = None + ) -> None: """ Iterate over nodes according to mapping. Yields pairs of (node_index, original_id) """ # offset = sum(self.info.nodes[:graph]) if self.is_multi() else 0 @@ -308,7 +338,10 @@ def _iter_nodes(self, graph: int = None): for n in range(self.info.nodes[graph or 0]): yield offset+n, str(n) - def _labeling_tensor(self, g_ix=None) -> list: + def _labeling_tensor( + self, + g_ix=None + ) -> list: """ Returns list of labels (not tensors) """ y = [] # Read labels @@ -330,21 +363,29 @@ def _labeling_tensor(self, g_ix=None) -> list: return y - def _feature_tensor(self, g_ix=None) -> list: + def _feature_tensor( + self, + g_ix=None + ) -> list: """ Returns list of features (not tensors) for graph g_ix. """ features = self.dataset_var_config.features # dict about attributes construction nodes_onehot = "str_g" in features and features["str_g"] == "one_hot" # Read attributes - def one_hot(x, values): + def one_hot( + x: int, + values: list + ) -> list: res = [0] * len(values) for ix, v in enumerate(values): if x == v: res[ix] = 1 return res - def as_is(x): + def as_is( + x + ) -> list: return x if isinstance(x, list) else [x] # TODO other encoding types from Kirill diff --git a/src/base/datasets_processing.py b/src/base/datasets_processing.py index d696d04..1fce165 100644 --- a/src/base/datasets_processing.py +++ b/src/base/datasets_processing.py @@ -1,11 +1,14 @@ import json import os from pathlib import Path +from typing import Union, Type + import torch +import torch_geometric from torch import default_generator, randperm from torch_geometric.data import Dataset, InMemoryDataset -from aux.configs import DatasetConfig, DatasetVarConfig +from aux.configs import DatasetConfig, DatasetVarConfig, ConfigPattern from aux.custom_decorators import timing_decorator from aux.declaration import Declare from aux.utils import TORCH_GEOM_GRAPHS_PATH @@ -17,7 +20,9 @@ class DatasetInfo: Some fields are obligate, others are not. 
""" - def __init__(self): + def __init__( + self + ): self.name: str = "" self.count: int = None self.directed: bool = None @@ -35,7 +40,9 @@ def __init__(self): self.edge_info: dict = {} self.graph_info: dict = {} - def check_validity(self): + def check_validity( + self + ) -> None: """ Check existing fields have allowed values. """ assert self.count > 0 assert len(self.node_attributes) > 0 @@ -56,7 +63,9 @@ def check_validity(self): assert isinstance(k, str) assert isinstance(v, int) and v > 1 - def check_consistency(self): + def check_consistency( + self + ) -> None: """ Check existing fields are consistent. """ assert self.count == len(self.nodes) assert len(self.node_attributes["names"]) == len(self.node_attributes["types"]) == len( @@ -64,13 +73,18 @@ def check_consistency(self): assert len(self.edge_attributes["names"]) == len(self.edge_attributes["types"]) == len( self.edge_attributes["values"]) - def check_sufficiency(self): + def check_sufficiency( + self + ) -> None: """ Check all obligates fields are defined. """ for attr in self.__dict__.keys(): if attr is None: raise ValueError(f"Attribute '{attr}' of metainfo should be defined.") - def check_consistency_with_dataset(self, dataset: Dataset): + def check_consistency_with_dataset( + self, + dataset: Dataset + ) -> None: """ Check if metainfo fields are consistent with dataset. """ assert self.count == len(dataset) from base.ptg_datasets import is_graph_directed @@ -80,17 +94,24 @@ def check_consistency_with_dataset(self, dataset: Dataset): assert self.node_attributes["types"][0] == "other" # TODO check features values range - def check(self): + def check( + self + ) -> None: """ Check metainfo is sufficient, consistent, and valid. """ self.check_sufficiency() self.check_consistency() self.check_validity() - def to_dict(self): + def to_dict( + self + ) -> dict: """ Return info as a dictionary. """ return dict(self.__dict__) - def save(self, path: Path): + def save( + self, + path: Union[str, Path] + ) -> None: """ Save into file non-null info. """ not_nones = {k: v for k, v in self.__dict__.items() if v is not None} path.parent.mkdir(exist_ok=True, parents=True) @@ -98,7 +119,9 @@ def save(self, path: Path): json.dump(not_nones, f, indent=1) @staticmethod - def induce(dataset: Dataset): + def induce( + dataset: Dataset + ): """ Induce metainfo from a given PTG dataset. """ res = DatasetInfo() res.count = len(dataset) @@ -116,7 +139,9 @@ def induce(dataset: Dataset): return res @staticmethod - def read(path: Path): + def read( + path: Union[str, Path] + ): """ Read info from a file. 
""" with path.open('r') as f: a_dict = json.load(f) @@ -128,7 +153,9 @@ def read(path: Path): return res @staticmethod - def node_attributes_to_node_attr_slices(node_attributes): + def node_attributes_to_node_attr_slices( + node_attributes: dict + ) -> dict: node_attr_slices = {} start_attr_index = 0 for i in range(len(node_attributes['names'])): @@ -145,7 +172,12 @@ def node_attributes_to_node_attr_slices(node_attributes): class VisiblePart: - def __init__(self, gen_dataset, center: [int, list] = None, depth: [int] = None): + def __init__( + self, + gen_dataset, + center: [int, list] = None, + depth: [int] = None + ): """ Compute a part of dataset specified by a center node/graph and a depth :param gen_dataset: @@ -226,10 +258,14 @@ def __init__(self, gen_dataset, center: [int, list] = None, depth: [int] = None) self.nodes = gen_dataset.info.nodes[0] self._ixes = list(range(self.nodes)) - def ixes(self): + def ixes( + self + ) -> list: return self._ixes - def as_dict(self): + def as_dict( + self + ) -> dict: res = {} if self.nodes: res['nodes'] = self.nodes @@ -239,7 +275,10 @@ def as_dict(self): res['graphs'] = self.graphs return res - def filter(self, array): + def filter( + self, + array + ) -> dict: """ Suppose ixes = [2,4]: [a, b, c, d] -> {2: b, 4: d} """ return {ix: array[ix] for ix in self._ixes} @@ -249,7 +288,10 @@ class GeneralDataset: """ Generalisation of PTG and user-defined datasets: custom, VK, etc. """ - def __init__(self, dataset_config: DatasetConfig): + def __init__( + self, + dataset_config: Union[DatasetConfig, ConfigPattern] + ): """ Args: dataset_config: DatasetConfig dict from frontend @@ -281,51 +323,71 @@ def __init__(self, dataset_config: DatasetConfig): self._labels = None @property - def root_dir(self): + def root_dir( + self + ): """ Dataset root directory with folders 'raw' and 'prepared'. """ # FIXME Misha, dataset_prepared_dir return path and files_paths not only path return Declare.dataset_root_dir(self.dataset_config)[0] @property - def results_dir(self): + def results_dir( + self + ): """ Path to 'prepared/../' folder where tensor data is stored. """ # FIXME Misha, dataset_prepared_dir return path and files_paths not only path return Path(Declare.dataset_prepared_dir(self.dataset_config, self.dataset_var_config)[0]) @property - def raw_dir(self): + def raw_dir( + self + ): """ Path to 'raw/' folder where raw data is stored. """ return self.root_dir / 'raw' @property - def api_path(self): + def api_path( + self + ): """ Path to '.api' file. Could be not present. """ return self.root_dir / '.api' @property - def info_path(self): + def info_path( + self + ): """ Path to '.info' file. """ return self.root_dir / 'raw' / '.info' @property - def stats_dir(self): + def stats_dir( + self + ): """ Path to '.stats' directory. 
""" return self.root_dir / '.stats' @property - def data(self): + def data( + self + ): return self.dataset._data @property - def num_classes(self): + def num_classes( + self + ): return self.dataset.num_classes @property - def num_node_features(self): + def num_node_features( + self + ): return self.dataset.num_node_features @property - def labels(self): + def labels( + self + ): if self._labels is None: # NOTE: this is a copy from torch_geometric.data.dataset v=2.3.1 from torch_geometric.data.dataset import _get_flattened_data_list @@ -333,22 +395,34 @@ def labels(self): self._labels = torch.cat([data.y for data in data_list if 'y' in data], dim=0) return self._labels - def __len__(self): + def __len__( + self + ) -> int: return self.info.count - def domain(self): + def domain( + self + ) -> str: return self.dataset_config.domain - def is_multi(self): + def is_multi( + self + ) -> bool: """ Return whether this dataset is multiple-graphs or single-graph. """ return self.info.count > 1 - def build(self, dataset_var_config: DatasetVarConfig): + def build( + self, + dataset_var_config: Union[ConfigPattern, DatasetVarConfig] + ): """ Create node feature tensors from attributes based on dataset_var_config. """ raise NotImplementedError() - def get_dataset_data(self, part=None): + def get_dataset_data( + self, + part: Union[dict, None] = None + ) -> dict: """ Get DatasetData for specified graphs or nodes """ edges_list = [] @@ -379,7 +453,9 @@ def get_dataset_data(self, part=None): return res - def _compute_dataset_data(self): + def _compute_dataset_data( + self + ) -> None: num = len(self.dataset) data_list = [self.dataset.get(ix) for ix in range(num)] is_directed = self.info.directed @@ -427,13 +503,19 @@ def _compute_dataset_data(self): # if self.info.name == "": # self.dataset_data['info']['name'] = '/'.join(self.dataset_config.full_name()) - def set_visible_part(self, part: dict): + def set_visible_part( + self, + part: dict + ) -> None: if self.dataset_data is None: self._compute_dataset_data() self.visible_part = VisiblePart(self, **part) - def get_dataset_var_data(self, part=None): + def get_dataset_var_data( + self, + part: Union[dict, None] = None + ) -> dict: """ Get DatasetVarData for specified graphs or nodes """ if self.dataset_var_data is None: @@ -455,7 +537,9 @@ def get_dataset_var_data(self, part=None): return dataset_var_data - def _compute_dataset_var_data(self): + def _compute_dataset_var_data( + self + ) -> None: """ Prepare dataset_var_data for frontend on demand. """ # FIXME version fail in torch-geom 2.3.1 @@ -485,7 +569,10 @@ def _compute_dataset_var_data(self): "labels": labels if self.is_multi() else labels[0], } - def get_stat(self, stat): + def get_stat( + self, + stat + ): """ Get statistics. """ if stat in self.stats: @@ -511,7 +598,10 @@ def get_stat(self, stat): json.dump(value, f, ensure_ascii=False) return value - def _compute_stat(self, stat): + def _compute_stat( + self, + stat + ): """ Compute statistics. """ if self.is_multi(): # try: @@ -608,7 +698,9 @@ def _compute_stat(self, stat): value = str(e) return value - def is_one_hot_able(self): + def is_one_hot_able( + self + ) -> bool: """ Return whether features are 1-hot encodings. If yes, nodes can be colored. 
""" assert self.dataset_var_config @@ -633,12 +725,16 @@ def is_one_hot_able(self): elif features['attr'][attr] == 'other': # Check honestly each feature vector feats = self.dataset_var_data['features'] - res = all(all(all(x == 1 or x == 0 for x in f) for f in feat) for feat in feats) and\ + res = all(all(all(x == 1 or x == 0 for x in f) for f in feat) for feat in feats) and \ all(all(sum(f) == 1 for f in feat) for feat in feats) return res - def train_test_split(self, percent_train_class: float = 0.8, percent_test_class: float = 0.2): + def train_test_split( + self, + percent_train_class: float = 0.8, + percent_test_class: float = 0.2 + ) -> None: """ Compute train-validation-test split of graphs/nodes. """ self.percent_train_class = percent_train_class self.percent_test_class = percent_test_class @@ -680,7 +776,10 @@ def train_test_split(self, percent_train_class: float = 0.8, percent_test_class: self.dataset.data.test_mask = test_mask self.dataset.data.val_mask = val_mask - def save_train_test_mask(self, path): + def save_train_test_mask( + self, + path: Union[str, Path] + ) -> None: """ Save current train/test mask to a given path (together with the model). """ if path is not None: path /= 'train_test_split' @@ -700,7 +799,9 @@ class DatasetManager: @staticmethod def register_torch_geometric_local( - dataset: InMemoryDataset, name: str = None) -> GeneralDataset: + dataset: InMemoryDataset, + name: str = None + ) -> GeneralDataset: """ Save a given PTG dataset locally. Dataset is then always available for use by its config. @@ -717,8 +818,10 @@ def register_torch_geometric_local( # QUE Misha, Kirill - can we use get_by_config always instead of it? @staticmethod @timing_decorator - def get_by_config(dataset_config: DatasetConfig, - dataset_var_config: DatasetVarConfig = None) -> GeneralDataset: + def get_by_config( + dataset_config: DatasetConfig, + dataset_var_config: DatasetVarConfig = None + ) -> GeneralDataset: """ Get GeneralDataset by dataset config. Used from the frontend. """ dataset_group = dataset_config.group @@ -743,7 +846,10 @@ def get_by_config(dataset_config: DatasetConfig, @staticmethod @timing_decorator - def get_by_full_name(full_name=None, **kwargs): + def get_by_full_name( + full_name=None, + **kwargs + ) -> [GeneralDataset, torch_geometric.data.Data, Path]: """ Get PTG dataset by its full name tuple. Starts the creation of an object from raw data or takes already saved datasets in prepared @@ -770,8 +876,10 @@ def get_by_full_name(full_name=None, **kwargs): @staticmethod def register_torch_geometric_api( - dataset: Dataset, name: str = None, - obj_name: str = 'DATASET_TO_EXPORT') -> GeneralDataset: + dataset: Dataset, + name: str = None, + obj_name: str = 'DATASET_TO_EXPORT' + ) -> GeneralDataset: """ Register a user defined code implementing a PTG dataset. This function should be called at each framework run to make the dataset available for use. 
@@ -797,7 +905,9 @@ def register_torch_geometric_api(
         return gen_dataset
 
     @staticmethod
-    def register_custom_ij(path: Path) -> GeneralDataset:
+    def register_custom_ij(
+            path: Path
+    ) -> GeneralDataset:
         """
         :return: GeneralDataset
         """
@@ -805,8 +915,12 @@ def register_custom_ij(path: Path) -> GeneralDataset:
 
     @staticmethod
     def _register_torch_geometric(
-            dataset: Dataset, name=None, group=None,
-            exists_ok=False, copy_data=False) -> GeneralDataset:
+            dataset: Dataset,
+            name: Union[str, None] = None,
+            group: str = None,
+            exists_ok: bool = False,
+            copy_data: bool = False
+    ) -> GeneralDataset:
         """
         Create GeneralDataset from an externally specified torch geometric dataset.
@@ -874,7 +988,9 @@ def _register_torch_geometric(
         return gen_dataset
 
 
-def is_in_torch_geometric_datasets(full_name=None):
+def is_in_torch_geometric_datasets(
+        full_name: tuple = None
+) -> bool:
     from aux.prefix_storage import PrefixStorage
     with open(TORCH_GEOM_GRAPHS_PATH, 'r') as f:
         return PrefixStorage.from_json(f.read()).check(full_name)
diff --git a/src/base/ptg_datasets.py b/src/base/ptg_datasets.py
index 5cd555f..f42f058 100644
--- a/src/base/ptg_datasets.py
+++ b/src/base/ptg_datasets.py
@@ -3,16 +3,20 @@
 import os
 import shutil
 from pathlib import Path
+from typing import Union, Callable
+
 import torch
 from torch_geometric.data import InMemoryDataset, Data, Dataset
 from torch_geometric.data.data import BaseData
 
 from aux.utils import import_by_name, root_dir, root_dir_len
 from base.datasets_processing import GeneralDataset, is_in_torch_geometric_datasets, DatasetInfo
-from aux.configs import DatasetConfig, DatasetVarConfig
+from aux.configs import DatasetConfig, DatasetVarConfig, ConfigPattern
 
 
-class PTGDataset(GeneralDataset):
+class PTGDataset(
+    GeneralDataset
+):
     """ Contains a PTG dataset.
     """
     attr_name = 'unknown'
@@ -22,7 +26,11 @@ class PTGDataset(GeneralDataset):
         dataset_ver_ind=0
     )
 
-    def __init__(self, dataset_config: DatasetConfig, **kwargs):
+    def __init__(
+            self,
+            dataset_config: Union[ConfigPattern, DatasetConfig],
+            **kwargs
+    ):
         """
         :param dataset_config: dataset config dictionary
         :param kwargs: additional args to init torch dataset class
@@ -127,14 +135,20 @@ def __init__(self, dataset_config: DatasetConfig, **kwargs):
             # raise FileNotFoundError(
             #     f"No data found for dataset '{self.dataset_config.full_name()}'")
 
-    def move_processed(self, processed: (str, Path)):
+    def move_processed(
+            self,
+            processed: Union[str, Path]
+    ) -> None:
         if not self.results_dir.exists():
             self.results_dir.mkdir(parents=True)
             os.rename(processed, self.results_dir)
         else:
             shutil.rmtree(processed)
 
-    def move_raw(self, raw: (str, Path)):
+    def move_raw(
+            self,
+            raw: Union[str, Path]
+    ) -> None:
         if Path(raw) == self.raw_dir:
             return
         if not self.raw_dir.exists():
@@ -143,25 +157,39 @@ def move_raw(self, raw: (str, Path)):
         else:
             raise RuntimeError(f"raw_dir '{self.raw_dir}' already exists")
 
-    def _compute_dataset_data(self, center=None, depth=None):
+    def _compute_dataset_data(
+            self,
+            center=None,
+            depth: Union[int, None] = None
+    ) -> None:
         # assert len(name_type) == 1  # FIXME
         dataset_data = super()._compute_dataset_data()
         # FIXME add features
         return dataset_data
 
-    def build(self, dataset_var_config: dict=None):
+    def build(
+            self,
+            dataset_var_config: dict = None
+    ) -> None:
         """ PTG dataset is already built """
         # Use cached ptg dataset. Only default dataset_var_config is allowed.
         assert self.dataset_var_config == dataset_var_config
 
 
-class LocalDataset(InMemoryDataset):
+class LocalDataset(
+    InMemoryDataset
+):
     """ Locally saved PTG Dataset. """
-    def __init__(self, results_dir, process_func=None, **kwargs):
+    def __init__(
+            self,
+            results_dir: Union[str, Path],
+            process_func: Union[Callable, None] = None,
+            **kwargs
+    ):
         """
         :param results_dir:
@@ -172,7 +200,7 @@ def __init__(self, results_dir, process_func=None, **kwargs):
         if process_func:
             self.process = process_func
         # Init and process if needed
-        super().__init__(None, **kwargs)
+        super().__init__(None, **kwargs)
 
         # Load
         self.data, *rest_data = torch.load(self.processed_paths[0])
@@ -180,22 +208,31 @@ def __init__(self, results_dir, process_func=None, **kwargs):
         try:
             self.slices = rest_data[0]
             # TODO can use rest_data[1] ?
-        except IndexError: pass
+        except IndexError:
+            pass
 
     @property
-    def processed_file_names(self):
+    def processed_file_names(
+            self
+    ):
         return 'data.pt'
 
-    def process(self):
+    def process(
+            self
+    ):
         raise RuntimeError("Dataset is supposed to be processed and saved earlier.")
         # torch.save(self.collate(self.data_list), self.processed_paths[0])
 
     @property
-    def processed_dir(self) -> str:
+    def processed_dir(
+            self
+    ) -> str:
         return self.results_dir
 
 
-def is_graph_directed(data: (Data, BaseData)) -> bool:
+def is_graph_directed(
+        data: Union[Data, BaseData]
+) -> bool:
     """ Detect whether graph is directed or not (for each edge i->j, exists j->i).
     """
     # Note: this does not work correctly. E.g. for TUDataset/MUTAG it incorrectly says directed.
diff --git a/src/base/vk_datasets.py b/src/base/vk_datasets.py
index 4d48349..936ae11 100644
--- a/src/base/vk_datasets.py
+++ b/src/base/vk_datasets.py
@@ -6,6 +6,7 @@
 from numbers import Number
 from operator import itemgetter
 from pathlib import Path
+from typing import Union
 
 import numpy as np
+from aux.configs import ConfigPattern
@@ -23,7 +24,8 @@ class AttrInfo:
     _attribute_vals_cache = {}  # (full_name, attribute) -> attribute_vals
 
     @staticmethod
-    def vk_attr():
+    def vk_attr() -> dict:
         vk_dict = {
             ('age',): list(range(0, len(AGE_GROUPS) + 1)),
             ('sex',): [1, 2],
@@ -39,7 +41,10 @@ def vk_attr():
         return vk_dict
 
     @staticmethod
-    def attribute_vals(full_name, attribute: [str, tuple, list]) -> list:
+    def attribute_vals(
+            full_name: tuple,
+            attribute: Union[str, tuple, list]
+    ) -> list:
         """ Get a set of possible attribute values or None for textual or continuous attributes.
         """
         # Convert to tuple
@@ -90,10 +95,14 @@ def attribute_vals(full_name, attribute: [str, tuple, list]) -> list:
         return res
 
     @staticmethod
-    def one_hot(full_name, attribute: [str, tuple, list], value, add_none=False):
-        """ 1-hot encoding feature. If no such value, return all zeros or with 1 it in last element.
+    def one_hot(
+            full_name: tuple,
+            attribute: Union[str, tuple, list],
+            value: Union[int, list],
+            add_none: bool = False
+    ) -> Union[np.ndarray, list]:
+        """ 1-hot encoding feature. If no such value, return all zeros, or a 1 in the last element.
         :param full_name:
-        :param graph: MyGraph
         :param attribute: attribute name, e.g. 'sex', ('personal', 'smoking').
         :param value: value of this attribute. If a list, a multiple-hot vector will be constructed.
         :param add_none: if True, last element of returned vector encodes undefined or
@@ -142,11 +151,21 @@ def one_hot(full_name, attribute: [str, tuple, list], value, add_none=False):
         return res  # all zeros here
 
 
-class VKDataset(CustomDataset):
+class VKDataset(
+    CustomDataset
+):
     """ Custom dataset of VK samples with specific attributes processing and features creation.
""" - def __init__(self, dataset_config: DatasetConfig, add_none=False): + def __init__( + self, + dataset_config: Union[ConfigPatter, DatasetConfig], + add_none: bool = False + ): """ Args: dataset_config: DatasetConfig dict from frontend @@ -156,7 +175,9 @@ def __init__(self, dataset_config: DatasetConfig, add_none=False): super().__init__(dataset_config) self.add_none = add_none - def _compute_dataset_data(self): + def _compute_dataset_data( + self + ): """ Get DatasetData for VK graph """ super()._compute_dataset_data() @@ -174,7 +195,10 @@ def _compute_dataset_data(self): # labelings[filename] = max([-1 if x is None else x for x in d.values()]) + 1 # self.dataset_data["info"]["labelings"] = labelings - def _feature_tensor(self, g_ix=None) -> list: + def _feature_tensor( + self, + g_ix=None + ) -> list: # FIXME Misha self.node_map[graph] ... x = [[] for _ in range(len(self.node_map))] features = self.dataset_var_config.features @@ -204,7 +228,10 @@ def _feature_tensor(self, g_ix=None) -> list: return x @staticmethod - def bdate_to_age(attr_dir_path: str, node_map: list): + def bdate_to_age( + attr_dir_path: str, + node_map: list + ) -> None: with open(attr_dir_path / Path('bdate'), 'r') as f: age_dict = json.load(f) node_age = {} @@ -223,7 +250,11 @@ def bdate_to_age(attr_dir_path: str, node_map: list): json.dump(node_age, f1) -def make_vk_labeling(attr_path: str, labeling_path: str, attr_val: int = 1): +def make_vk_labeling( + attr_path: str, + labeling_path: str, + attr_val: int = 1 +) -> None: """ Creates a markup file where the attribute's target value is set to 1 and the rest to 0 Args: