From 87854b303ccebb506accfe3f2d0bae73216224b7 Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Thu, 21 Nov 2024 15:38:09 +0300 Subject: [PATCH] make better files in aux, models_zoo.py + some small fix --- src/aux/custom_decorators.py | 36 +++++++++--- src/aux/data_info.py | 85 ++++++++++++++++++++------- src/aux/declaration.py | 48 +++++++++++---- src/aux/prefix_storage.py | 69 +++++++++++++++++----- src/aux/utils.py | 25 ++++++-- src/models_builder/gnn_constructor.py | 3 +- src/models_builder/gnn_models.py | 1 + src/models_builder/models_zoo.py | 6 +- 8 files changed, 210 insertions(+), 63 deletions(-) diff --git a/src/aux/custom_decorators.py b/src/aux/custom_decorators.py index f6b9cb2..7f290bd 100644 --- a/src/aux/custom_decorators.py +++ b/src/aux/custom_decorators.py @@ -2,19 +2,28 @@ from functools import wraps import logging import functools +from typing import Callable logging.basicConfig(level=logging.INFO) -def retry(max_tries=3, delay_seconds=1): +def retry( + max_tries: int = 3, + delay_seconds: int = 1 +): """ Allows you to re-execute the program after the Nth amount of time :param max_tries: number of restart attempts :param delay_seconds: time interval between attempts """ - def decorator_retry(func): + def decorator_retry( + func: Callable + ) -> Callable: @wraps(func) - def wrapper_retry(*args, **kwargs): + def wrapper_retry( + *args, + **kwargs + ): tries = 0 while tries < max_tries: try: @@ -30,13 +39,17 @@ def wrapper_retry(*args, **kwargs): return decorator_retry -def memoize(func): +def memoize( + func: Callable +) -> Callable: """ Caching function """ cache = {} - def wrapper(*args): + def wrapper( + *args + ): if args in cache: return cache[args] else: @@ -47,11 +60,16 @@ def wrapper(*args): return wrapper -def timing_decorator(func): +def timing_decorator( + func: Callable +) -> Callable: """ Timing functions """ - def wrapper(*args, **kwargs): + def wrapper( + *args, + **kwargs + ): start_time = time.time() result = func(*args, **kwargs) end_time = time.time() @@ -61,7 +79,9 @@ def wrapper(*args, **kwargs): return wrapper -def log_execution(func): +def log_execution( + func: Callable +) -> Callable: """ Function call logging """ diff --git a/src/aux/data_info.py b/src/aux/data_info.py index 390823a..ae82230 100644 --- a/src/aux/data_info.py +++ b/src/aux/data_info.py @@ -2,6 +2,8 @@ import importlib.util import json import logging +from typing import List, Tuple, Union + # from pydantic.utils import deep_update # from pydantic.v1.utils import deep_update @@ -24,7 +26,8 @@ class DataInfo: """ @staticmethod - def refresh_all_data_info(): + def refresh_all_data_info( + ) -> None: """ Calling all files to update with information about saved objects """ @@ -35,7 +38,8 @@ def refresh_all_data_info(): DataInfo.refresh_explanations_dir_structure() @staticmethod - def refresh_data_dir_structure(): + def refresh_data_dir_structure( + ) -> None: """ Calling a file update with information about saved raw datasets """ @@ -51,7 +55,8 @@ def refresh_data_dir_structure(): prev_path = path @staticmethod - def refresh_models_dir_structure(): + def refresh_models_dir_structure( + ) -> None: """ Calling a file update with information about saved models """ @@ -62,7 +67,8 @@ def refresh_models_dir_structure(): f.write(str(Path(*path)) + '\n') @staticmethod - def refresh_explanations_dir_structure(): + def refresh_explanations_dir_structure( + ) -> None: """ Calling a file update with information about saved explanations """ @@ -73,7 +79,8 @@ def refresh_explanations_dir_structure(): f.write(str(Path(*path)) + '\n') @staticmethod - def refresh_data_var_dir_structure(): + def refresh_data_var_dir_structure( + ) -> None: """ Calling a file update with information about saved prepared datasets """ @@ -84,7 +91,9 @@ def refresh_data_var_dir_structure(): f.write(str(Path(*path)) + '\n') @staticmethod - def take_keys_etc_by_prefix(prefix): + def take_keys_etc_by_prefix( + prefix: Tuple + ) -> [List, List, dict, int]: """ :param prefix: what data and in what order were used to form the path when saving the object @@ -114,7 +123,11 @@ def take_keys_etc_by_prefix(prefix): return keys_list, full_keys_list, dir_structure, empty_dir_shift @staticmethod - def values_list_by_path_and_keys(path, full_keys_list, dir_structure): + def values_list_by_path_and_keys( + path: Union[str, Path], + full_keys_list: List, + dir_structure: dict + ) -> List: """ :param path: path of the saved object @@ -133,7 +146,10 @@ def values_list_by_path_and_keys(path, full_keys_list, dir_structure): return parts_val @staticmethod - def values_list_and_technical_files_by_path_and_prefix(path, prefix): + def values_list_and_technical_files_by_path_and_prefix( + path: Union[str, Path], + prefix: Tuple + ) -> [List, dict]: """ :param path: path of the saved object @@ -169,12 +185,16 @@ def values_list_and_technical_files_by_path_and_prefix(path, prefix): else: file_name = file_info_dict["file_name"] file_name += file_info_dict["format"] - description_info.update({key: {parts_val[-1]: os.path.join(os.path.join(*path[:parts_parse]), file_name)}}) + description_info.update( + {key: {parts_val[-1]: os.path.join(os.path.join(*path[:parts_parse]), file_name)}}) parts_parse += 1 return parts_val, description_info @staticmethod - def fill_prefix_storage(prefix, file_with_paths): + def fill_prefix_storage( + prefix: Tuple, + file_with_paths: Union[str, Path] + ) -> [PrefixStorage, dict]: """ Fill prefix storage by file with paths @@ -183,7 +203,7 @@ def fill_prefix_storage(prefix, file_with_paths): :param file_with_paths: file with paths of saved objects :return: fill prefix storage and dict with description_info about objects use hash """ - keys_list, full_keys_list, dir_structure, empty_dir_shift =\ + keys_list, full_keys_list, dir_structure, empty_dir_shift = \ DataInfo.take_keys_etc_by_prefix(prefix=prefix) ps = PrefixStorage(keys_list) with open(file_with_paths, 'r', encoding='utf-8') as f: @@ -202,7 +222,10 @@ def fill_prefix_storage(prefix, file_with_paths): return ps, description_info @staticmethod - def deep_update(d, u): + def deep_update( + d: dict, + u: dict + ) -> dict: for k, v in u.items(): if isinstance(v, collections.abc.Mapping): d[k] = DataInfo.deep_update(d.get(k, {}), v) @@ -211,14 +234,19 @@ def deep_update(d, u): return d @staticmethod - def description_info_with_paths_to_description_info_with_files_values(description_info, root_path): + def description_info_with_paths_to_description_info_with_files_values( + description_info: dict, + root_path: Union[str, Path] + ) -> dict: for description_info_key, description_info_val in description_info.items(): for obj_name, obj_file_path in description_info_val.items(): with open(os.path.join(root_path, obj_file_path)) as f: description_info[description_info_key][obj_name] = f.read() return description_info + @staticmethod - def explainers_parse(): + def explainers_parse( + ) -> [PrefixStorage, dict]: """ Parses the path to explainers from a technical file with the paths of all saved explainers. """ @@ -232,7 +260,8 @@ def explainers_parse(): return ps, description_info @staticmethod - def models_parse(): + def models_parse( + ) -> [PrefixStorage, dict]: """ Parses the path to models from a technical file with the paths of all saved models. """ @@ -246,7 +275,8 @@ def models_parse(): return ps, description_info @staticmethod - def data_parse(): + def data_parse( + ) -> [PrefixStorage, dict]: """ Parses the path to raw datasets from a technical file with the paths of all saved raw datasets. """ @@ -257,7 +287,8 @@ def data_parse(): return ps @staticmethod - def data_var_parse(): + def data_var_parse( + ) -> [PrefixStorage, dict]: """ Parses the path to prepared datasets from a technical file with the paths of all saved prepared datasets. """ @@ -268,7 +299,9 @@ def data_var_parse(): return ps @staticmethod - def clean_prepared_data(dry_run=False): + def clean_prepared_data( + dry_run: bool = False + ) -> None: """ Remove all prepared data for all datasets. """ @@ -279,7 +312,9 @@ def clean_prepared_data(dry_run=False): shutil.rmtree(path) @staticmethod - def all_obj_ver_by_obj_path(obj_dir_path): + def all_obj_ver_by_obj_path( + obj_dir_path: Union[str, Path] + ) -> set: """ :param obj_dir_path: path to the saved object @@ -294,7 +329,9 @@ def all_obj_ver_by_obj_path(obj_dir_path): return set(vers_ind) @staticmethod - def del_all_empty_folders(dir_path): + def del_all_empty_folders( + dir_path: Union[str, Path] + ) -> None: """ Deletes all empty folders and files with meta information in the selected directory @@ -316,7 +353,8 @@ def del_all_empty_folders(dir_path): class UserCodeInfo: @staticmethod - def user_models_list_ref(): + def user_models_list_ref( + ) -> dict: """ :return: dict with information about user models objects in directory /user_model_list Contains information about objects class name, objects names and import paths @@ -371,7 +409,10 @@ def models_init(): return user_models_obj_dict_info @staticmethod - def take_user_model_obj(user_file_path, obj_name: str): + def take_user_model_obj( + user_file_path: Union[str, Path], + obj_name: str + ) -> object: """ :param user_file_path: path to the user file with user model :param obj_name: user object name diff --git a/src/aux/declaration.py b/src/aux/declaration.py index 7988911..ee3eefc 100644 --- a/src/aux/declaration.py +++ b/src/aux/declaration.py @@ -1,5 +1,7 @@ import json +from typing import Union, Type +from aux.configs import DatasetConfig, DatasetVarConfig from aux.utils import MODELS_DIR, GRAPHS_DIR, EXPLANATIONS_DIR, hash_data_sha256, \ SAVE_DIR_STRUCTURE_PATH import os @@ -12,7 +14,11 @@ class Declare: """ @staticmethod - def obj_info_to_path(what_save=None, previous_path=None, obj_info=None): + def obj_info_to_path( + what_save: str = None, + previous_path: Union[str, Path] = None, + obj_info: Union[None, list, tuple, dict] = None + ) -> [Path, list]: """ :param what_save: the path for which object is being built. Now support: data_root, data_prepared, models, explanations @@ -87,7 +93,9 @@ def obj_info_to_path(what_save=None, previous_path=None, obj_info=None): return path, files_paths @staticmethod - def dataset_root_dir(dataset_config): + def dataset_root_dir( + dataset_config: DatasetConfig + ) -> [Path, list]: """ :param dataset_config: DatasetConfig :return: forms the path to the data folder and adds to it the path to a specific dataset @@ -99,7 +107,10 @@ def dataset_root_dir(dataset_config): return path, files_paths @staticmethod - def dataset_prepared_dir(dataset_config, dataset_var_config): + def dataset_prepared_dir( + dataset_config: DatasetConfig, + dataset_var_config: DatasetVarConfig + ) -> [Path, list]: """ :param dataset_config: DatasetConfig :param dataset_var_config: DatasetVarConfig @@ -128,7 +139,9 @@ def dataset_prepared_dir(dataset_config, dataset_var_config): return path, files_paths @staticmethod - def models_path(class_obj): + def models_path( + class_obj: Type + ) -> [Path, list]: """ :param class_obj: class base on GNNModelManager :return: The path where the model will be saved @@ -189,8 +202,8 @@ def declare_model_by_config( mi_attack_hash: str, evasion_attack_hash: str, poison_attack_hash: str, - epochs=None, - ): + epochs: Union[int, str] = None, + ) -> [Path, list]: """ Formation of the way to save the path of the model in the root of the project according to its hyperparameters and features @@ -199,6 +212,12 @@ def declare_model_by_config( :param model_ver_ind: index of explain version :param gnn_name: gnn hash :param epochs: number of epochs during which the model was trained + :param mi_defense_hash: + :param evasion_defense_hash: + :param poison_defense_hash: + :param mi_attack_hash: + :param evasion_attack_hash: + :param poison_attack_hash: :return: the path where the model is saved use information from ModelConfig """ if not isinstance(model_ver_ind, int) or model_ver_ind < 0: @@ -223,16 +242,19 @@ def declare_model_by_config( return path, files_paths @staticmethod - def explanation_file_path(models_path: str, explainer_name: str, - explainer_ver_ind: int = None, - explainer_run_kwargs=None, explainer_init_kwargs=None): + def explanation_file_path( + models_path: str, + explainer_name: str, + explainer_ver_ind: int = None, + explainer_run_kwargs: dict = None, + explainer_init_kwargs: dict = None + ) -> [Path, list]: """ :param explainer_init_kwargs: dict with kwargs for explainer class :param explainer_run_kwargs:dict with kwargs for run explanation :param models_path: model path :param explainer_name: explainer name. Example: Zorro :param explainer_ver_ind: index of explain version - :param explainer_attack_type: type of attack on explainer. Now support: original :return: path for explanations result file and list with technical files """ explainer_init_kwargs = explainer_init_kwargs.copy() @@ -279,7 +301,10 @@ def explanation_file_path(models_path: str, explainer_name: str, return path, files_paths @staticmethod - def explainer_kwargs_path_full(model_path, explainer_path): + def explainer_kwargs_path_full( + model_path: Union[str, Path], + explainer_path: Union[str, Path] + ) -> list: """ :param model_path: model path :param explainer_path: explanation path @@ -287,6 +312,7 @@ def explainer_kwargs_path_full(model_path, explainer_path): """ path = Path(str(model_path).replace(str(MODELS_DIR), str(EXPLANATIONS_DIR))) what_save = "explanations" + # BUG Misha, check is correct next line, because in def obj_info_to_path can't be Path or str obj_info = explainer_path _, files_paths = Declare.obj_info_to_path(what_save=what_save, previous_path=path, diff --git a/src/aux/prefix_storage.py b/src/aux/prefix_storage.py index 376a780..8b3c844 100644 --- a/src/aux/prefix_storage.py +++ b/src/aux/prefix_storage.py @@ -1,6 +1,7 @@ import json from json.encoder import JSONEncoder from pathlib import Path +from typing import Union class PrefixStorage: @@ -10,26 +11,38 @@ class PrefixStorage: * adding, removing, filtering, iterating elements; * gathering contents from file structure. """ - def __init__(self, keys: (tuple, list)): + def __init__( + self, + keys: Union[tuple, list] + ): assert isinstance(keys, (tuple, list)) assert len(keys) >= 1 self._keys = keys self.content = {} if len(keys) > 1 else set() @property - def depth(self): + def depth( + self + ) -> int: return len(self._keys) @property - def keys(self): + def keys( + self + ) -> tuple: return tuple(self._keys) - def size(self): + def size( + self + ) -> int: def count(obj): return sum(count(_) for _ in obj.values()) if isinstance(obj, dict) else len(obj) return count(self.content) - def add(self, values: (dict, tuple, list)): + def add( + self, + values: Union[dict, tuple, list] + ) -> None: """ Add one list of values. """ @@ -56,7 +69,11 @@ def add(obj, depth): else: raise TypeError("dict, tuple, or list were expected") - def merge(self, ps, ignore_conflicts=False): + def merge( + self, + ps, + ignore_conflicts: bool = False + ) -> None: """ Extend this with another PrefixStorage with same keys. if ignore_conflicts=True, do not raise Exception when values sets intersect. @@ -80,7 +97,10 @@ def merge(content1, content2): merge(self.content, ps.content) - def remove(self, values: (dict, tuple, list)): + def remove( + self, + values: Union[dict, tuple, list] + ) -> None: """ Remove one tuple of values if it is present. """ @@ -97,7 +117,10 @@ def rm(obj, depth): rm(self.content, 0) - def filter(self, key_values: dict): + def filter( + self, + key_values: dict + ): """ Find all items satisfying specified key values. Returns a new PrefixStorage. """ @@ -132,7 +155,10 @@ def filter(obj, depth): ps.content = filter(self.content, 0) return ps - def check(self, values: (dict, tuple, list)): + def check( + self, + values: Union[dict, tuple, list] + ) -> bool: """ Check if a tuple of values is present. """ @@ -151,7 +177,9 @@ def check(self, values: (dict, tuple, list)): else: return False - def __iter__(self): + def __iter__( + self + ): def enum(obj, elems): if isinstance(obj, (set, list)): for e in obj: @@ -165,7 +193,9 @@ def enum(obj, elems): yield _ @staticmethod - def from_json(string): + def from_json( + string: str + ): """ Construct PrefixStorage object from a json string. """ @@ -174,7 +204,10 @@ def from_json(string): ps.content = data["content"] return ps - def to_json(self, **dump_args): + def to_json( + self, + **dump_args + ) -> str: """ Return json string. """ class Encoder(JSONEncoder): def default(self, obj): @@ -183,7 +216,11 @@ def default(self, obj): return json.JSONEncoder.default(self, obj) return json.dumps({"keys": self.keys, "content": self.content}, cls=Encoder, **dump_args) - def fill_from_folder(self, path: Path, file_pattern=r".*"): + def fill_from_folder( + self, + path: Path, + file_pattern: str = r".*" + ) -> None: """ Recursively walk over the given folder and repeat its structure. The content will be replaced. @@ -208,7 +245,11 @@ def walk(p, elems): self.add(e) print(f"Added {self.size()} items of {len(res)} files found.") - def remap(self, mapping, only_values=False): + def remap( + self, + mapping, + only_values: bool = False + ): """ Change keys order and combination. """ diff --git a/src/aux/utils.py b/src/aux/utils.py index 2acd503..017951b 100644 --- a/src/aux/utils.py +++ b/src/aux/utils.py @@ -3,6 +3,8 @@ import warnings from pathlib import Path from pydoc import locate +from typing import Union, Type, Any + import numpy as np root_dir = Path(__file__).parent.parent.parent.resolve() # directory of source root @@ -40,11 +42,16 @@ TECHNICAL_PARAMETER_KEY = "_technical_parameter" -def hash_data_sha256(data): +def hash_data_sha256( + data +) -> str: return hashlib.sha256(data).hexdigest() -def import_by_name(name: str, packs: list = None): +def import_by_name( + name: str, + packs: list = None +) -> None: """ Import name from packages, return class :param name: class name, full or relative @@ -63,7 +70,9 @@ def import_by_name(name: str, packs: list = None): raise ImportError(f"Unknown {packs} model '{name}', couldn't import.") -def model_managers_info_by_names_list(model_managers_names: set): +def model_managers_info_by_names_list( + model_managers_names: set +) -> dict: """ :param model_managers_names: set with model managers class names (user and framework) :return: dict with info about model managers @@ -86,7 +95,11 @@ def model_managers_info_by_names_list(model_managers_names: set): return model_managers_info -def setting_class_default_parameters(class_name: str, class_kwargs: dict, default_parameters_file_path): +def setting_class_default_parameters( + class_name: str, + class_kwargs: dict, + default_parameters_file_path: Union[str, Path] +) -> [dict, dict]: """ :param class_name: class name, should be same in default_parameters_file :param class_kwargs: dict with parameters, which needs to be supplemented with default parameters @@ -143,6 +156,8 @@ def setting_class_default_parameters(class_name: str, class_kwargs: dict, defaul return class_kwargs_for_save, class_kwargs_for_init -def all_subclasses(cls): +def all_subclasses( + cls: Type[Any] +) -> set: return set(cls.__subclasses__()).union( [s for c in cls.__subclasses__() for s in all_subclasses(c)]) diff --git a/src/models_builder/gnn_constructor.py b/src/models_builder/gnn_constructor.py index e09e8a5..3df22e8 100644 --- a/src/models_builder/gnn_constructor.py +++ b/src/models_builder/gnn_constructor.py @@ -6,7 +6,6 @@ import torch from torch.nn.parameter import UninitializedParameter from torch.utils import hooks -from ..parameter import Parameter from torch.utils.hooks import RemovableHandle from torch_geometric.nn import MessagePassing @@ -577,7 +576,7 @@ def get_predictions( def get_parameters( self - ) -> Iterator[Parameter]: + ) -> Iterator: return self.parameters() def get_answer( diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index 846b418..bd547b6 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -1228,6 +1228,7 @@ def run_model( def evaluate_model( self, gen_dataset: DatasetManager, + metrics: Union[List[Metric], Metric] ) -> dict: """ Compute metrics for a model result on a part of dataset specified by the metric mask. diff --git a/src/models_builder/models_zoo.py b/src/models_builder/models_zoo.py index 72003bd..a8a9b18 100644 --- a/src/models_builder/models_zoo.py +++ b/src/models_builder/models_zoo.py @@ -1,8 +1,12 @@ +from base.datasets_processing import DatasetManager from models_builder.gnn_constructor import FrameworkGNNConstructor from aux.configs import ModelConfig, ModelStructureConfig -def model_configs_zoo(dataset, model_name): +def model_configs_zoo( + dataset: DatasetManager, + model_name: str +): gat_gin_lin = FrameworkGNNConstructor( model_config=ModelConfig( structure=ModelStructureConfig(