diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index c906a32..846b418 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -2,10 +2,14 @@ import json import random from math import ceil +from pathlib import Path from types import FunctionType +from typing import Callable, List, Union, Any, Type, Protocol + import numpy as np import torch import sklearn.metrics +from flask_socketio import SocketIO from torch.nn.utils import clip_grad_norm from torch import tensor import torch.nn.functional as F @@ -19,6 +23,7 @@ hash_data_sha256, \ TECHNICAL_PARAMETER_KEY, IMPORT_INFO_KEY, OPTIMIZERS_PARAMETERS_PATH, FUNCTIONS_PARAMETERS_PATH from aux.declaration import Declare +from base.datasets_processing import DatasetManager from explainers.explainer import ProgressBar from explainers.ProtGNN.MCTS import mcts_args from attacks.evasion_attacks import EvasionAttacker @@ -44,7 +49,10 @@ class Metric: } @staticmethod - def add_custom(name, compute_function): + def add_custom( + name: str, + compute_function: Callable + ) -> None: """ Register a custom metric. Example for accuracy: @@ -60,7 +68,12 @@ def add_custom(name, compute_function): raise NameError(f"Metric '{name}' already registered, use another name") Metric.available_metrics[name] = compute_function - def __init__(self, name, mask, **kwargs): + def __init__( + self, + name: str, + mask: Union[str, List[bool]], + **kwargs + ): """ :param name: name to refer to this metric :param mask: 'train', 'val', 'test', or a bool valued list @@ -70,16 +83,22 @@ def __init__(self, name, mask, **kwargs): self.mask = mask self.kwargs = kwargs - def compute(self, y_true, y_pred): + def compute( + self, + y_true, + y_pred + ): if self.name in Metric.available_metrics: if y_true.device != "cpu": y_true = y_true.cpu() return Metric.available_metrics[self.name](y_true, y_pred, **self.kwargs) - raise NotImplementedError() @staticmethod - def create_mask_by_target_list(y_true, target_list=None): + def create_mask_by_target_list( + y_true, + target_list: List = None + ) -> torch.Tensor: if target_list is None: mask = [True] * len(y_true) else: @@ -88,7 +107,6 @@ def create_mask_by_target_list(y_true, target_list=None): if 0 <= i < len(mask): mask[i] = True return tensor(mask) - # return mask class GNNModelManager: @@ -96,9 +114,11 @@ class GNNModelManager: training, evaluation, save and load principle """ - def __init__(self, - manager_config=None, - modification: ModelModificationConfig = None): + def __init__( + self, + manager_config: ModelManagerConfig = None, + modification: ModelModificationConfig = None + ): """ :param manager_config: socket to use for sending data to frontend :param modification: socket to use for sending data to frontend @@ -118,11 +138,7 @@ def __init__(self, # else: # raise Exception() - # if modification is None: - # modification = ModelModificationConfig() - if modification is None: - # raise RuntimeError("model manager config must be specified") modification = ConfigPattern( _config_class="ModelModificationConfig", _config_kwargs={}, @@ -141,10 +157,6 @@ def __init__(self, # QUE Kirill do we need to store it? maybe pass when need to self.dataset_path = None - - # FIXME Kirill, remove self.gen_dataset - self.gen_dataset = None - self.mi_defender = None self.mi_defense_name = None self.mi_defense_config = None @@ -184,28 +196,48 @@ def __init__(self, self.set_evasion_attacker() self.set_evasion_defender() - def train_model(self, **kwargs): + def train_model( + self, + **kwargs + ): pass - def train_1_step(self, gen_dataset): + def train_1_step( + self, + gen_dataset: DatasetManager + ): """ Perform 1 step of model training. """ # raise NotImplementedError() pass - def train_complete(self, gen_dataset, steps=None, **kwargs): + def train_complete( + self, + gen_dataset: DatasetManager, + steps: int = None, + **kwargs + ) -> None: """ """ # raise NotImplementedError() pass - def train_on_batch(self, batch, **kwargs): + def train_on_batch( + self, + batch, + **kwargs + ): pass - def evaluate_model(self, **kwargs): + def evaluate_model( + self, + **kwargs + ): pass - def get_name(self): + def get_name( + self + ) -> str: manager_name = self.manager_config.to_saveable_dict() # FIXME Kirill, make ModelManagerConfig and remove manager_name[CONFIG_CLASS_NAME] manager_name[CONFIG_CLASS_NAME] = self.__class__.__name__ @@ -215,13 +247,20 @@ def get_name(self): json_str = json.dumps(manager_name, indent=2) return json_str - def load_model(self, path=None, **kwargs): + def load_model( + self, + path: Union[str, Path] = None, + **kwargs + ) -> Type: """ Load model from torch save format """ raise NotImplementedError() - def save_model(self, path=None): + def save_model( + self, + path: Union[str, Path] = None + ) -> None: """ Save the model in torch format @@ -230,11 +269,17 @@ def save_model(self, path=None): """ raise NotImplementedError() - def model_path_info(self): + def model_path_info( + self + ) -> Union[str, Path]: path, _ = Declare.models_path(self) return path - def load_model_executor(self, path=None, **kwargs): + def load_model_executor( + self, + path: Union[str, Path, None] = None, + **kwargs + ) -> Union[str, Path]: """ Load executor. Generates the download model path if no other path is specified. @@ -267,7 +312,9 @@ class variables self.gnn.eval() return model_dir_path - def get_hash(self): + def get_hash( + self + ) -> str: """ calculates the hash on behalf of the model manager required for storage. The sha256 algorithm is used. @@ -277,15 +324,18 @@ def get_hash(self): gnn_MM_name_hash = hash_data_sha256(json_object.encode('utf-8')) return gnn_MM_name_hash - def save_model_executor(self, path=None, files_paths=None): + def save_model_executor( + self, + path: Union[str, Path, None] = None, + files_paths: List[Union[str, Path]] = None + ) -> Path: """ Save executor, generates paths and prepares all information about the model and its parameters for saving - :param gnn_architecture_path: path to save the architecture of the model, - by default it forms the path itself. :param path: path to save the model. By default, the path is compiled based on the global class variables + :param files_paths: """ if path is None: dir_path, files_paths = Declare.models_path(self) @@ -319,7 +369,11 @@ def save_model_executor(self, path=None, files_paths=None): f.write(self.mi_attack_config.json_for_config()) return path.parent - def set_poison_attacker(self, poison_attack_config=None, poison_attack_name: str = None): + def set_poison_attacker( + self, + poison_attack_config: PoisonAttackConfig = None, + poison_attack_name: str = None + ) -> None: if poison_attack_config is None: poison_attack_config = ConfigPattern( _class_name=poison_attack_name or "EmptyPoisonAttacker", @@ -357,7 +411,11 @@ def set_poison_attacker(self, poison_attack_config=None, poison_attack_name: str ) self.poison_attack_flag = True - def set_evasion_attacker(self, evasion_attack_config=None, evasion_attack_name: str = None): + def set_evasion_attacker( + self, + evasion_attack_config: EvasionAttackConfig = None, + evasion_attack_name: str = None + ) -> None: if evasion_attack_config is None: evasion_attack_config = ConfigPattern( _class_name=evasion_attack_name or "EmptyEvasionAttacker", @@ -393,7 +451,11 @@ def set_evasion_attacker(self, evasion_attack_config=None, evasion_attack_name: ) self.evasion_attack_flag = True - def set_mi_attacker(self, mi_attack_config=None, mi_attack_name: str = None): + def set_mi_attacker( + self, + mi_attack_config: MIAttackConfig = None, + mi_attack_name: str = None + ) -> None: if mi_attack_config is None: mi_attack_config = ConfigPattern( _class_name=mi_attack_name or "EmptyMIAttacker", @@ -429,7 +491,11 @@ def set_mi_attacker(self, mi_attack_config=None, mi_attack_name: str = None): ) self.mi_attack_flag = True - def set_poison_defender(self, poison_defense_config=None, poison_defense_name: str = None): + def set_poison_defender( + self, + poison_defense_config: PoisonDefenseConfig = None, + poison_defense_name: str = None + ) -> None: if poison_defense_config is None: poison_defense_config = ConfigPattern( _class_name=poison_defense_name or "EmptyPoisonDefender", @@ -465,7 +531,11 @@ def set_poison_defender(self, poison_defense_config=None, poison_defense_name: s ) self.poison_defense_flag = True - def set_evasion_defender(self, evasion_defense_config=None, evasion_defense_name: str = None): + def set_evasion_defender( + self, + evasion_defense_config: EvasionDefenseConfig = None, + evasion_defense_name: str = None + ) -> None: if evasion_defense_config is None: evasion_defense_config = ConfigPattern( _class_name=evasion_defense_name or "EmptyEvasionDefender", @@ -501,7 +571,11 @@ def set_evasion_defender(self, evasion_defense_config=None, evasion_defense_name ) self.evasion_defense_flag = True - def set_mi_defender(self, mi_defense_config=None, mi_defense_name: str = None): + def set_mi_defender( + self, + mi_defense_config: MIDefenseConfig = None, + mi_defense_name: str = None + ) -> None: """ """ @@ -541,15 +615,21 @@ def set_mi_defender(self, mi_defense_config=None, mi_defense_name: str = None): self.mi_defense_flag = True @staticmethod - def available_attacker(): + def available_attacker( + ): pass @staticmethod - def available_defender(): + def available_defender( + ): pass @staticmethod - def from_model_path(model_path, dataset_path, **kwargs): + def from_model_path( + model_path: dict, + dataset_path: Union[str, Path], + **kwargs + ) -> [Type, Path]: """ Use information about model and model manager take gnn model, create gnn model manager object and load weights to the save model @@ -610,7 +690,9 @@ def from_model_path(model_path, dataset_path, **kwargs): return gnn_model_manager_obj, model_dir_path - def get_full_info(self): + def get_full_info( + self + ) -> dict: """ Get available info about model for frontend """ @@ -623,7 +705,9 @@ def get_full_info(self): result["epochs"] = f"Epochs={self.epochs}" return result - def get_model_data(self): + def get_model_data( + self + ) -> dict: """ :return: dict with the available functions of the model manager by the 'functions' key. """ @@ -638,7 +722,9 @@ def get_own_functions(cls): return model_data @staticmethod - def take_gnn_obj(gnn_file): + def take_gnn_obj( + gnn_file: Union[str, Path] + ) -> Type: with open(gnn_file) as f: params = json.load(f) class_name = params.pop(CONFIG_CLASS_NAME) @@ -667,22 +753,34 @@ def take_gnn_obj(gnn_file): obj_name) return gnn - def before_epoch(self, gen_dataset): + def before_epoch( + self, + gen_dataset: DatasetManager + ): """ This hook is called before training the next training epoch """ pass - def after_epoch(self, gen_dataset): + def after_epoch( + self, + gen_dataset: DatasetManager + ): """ This hook is called after training the next training epoch """ pass - def before_batch(self, batch): + def before_batch( + self, + batch + ): """ This hook is called before training the next training batch """ pass - def after_batch(self, batch): + def after_batch( + self, + batch + ): """ This hook is called after training the next training batch """ pass @@ -725,8 +823,8 @@ class FrameworkGNNModelManager(GNNModelManager): to prevent leakage of the response during training. """ - def __init__(self, gnn=None, - dataset_path=None, + def __init__(self, gnn: Type = None, + dataset_path: Union[str, Path] = None, **kwargs ): """ @@ -770,7 +868,9 @@ def __init__(self, gnn=None, if self.gnn is not None: self.init() - def init(self): + def init( + self + ) -> None: """ Initialize optimizer and loss function. """ @@ -785,7 +885,14 @@ def init(self): if "loss_function" in getattr(self.manager_config, CONFIG_OBJ): self.loss_function = getattr(self.manager_config, CONFIG_OBJ).loss_function.create_obj() - def train_complete(self, gen_dataset, steps=None, pbar=None, metrics=None, **kwargs): + def train_complete( + self, + gen_dataset: DatasetManager, + steps: int = None, + pbar: Protocol = None, + metrics: Union[List[Metric], Metric] = None, + **kwargs + ) -> None: for _ in range(steps): self.before_epoch(gen_dataset) print("epoch", self.modification.epochs) @@ -800,10 +907,19 @@ def train_complete(self, gen_dataset, steps=None, pbar=None, metrics=None, **kwa if early_stopping_flag: break - def early_stopping(self, train_loss, gen_dataset, metrics, steps): + def early_stopping( + self, + train_loss, + gen_dataset: DatasetManager, + metrics: Union[List[Metric], Metric], + steps: int + ) -> bool: return False - def train_1_step(self, gen_dataset): + def train_1_step( + self, + gen_dataset: DatasetManager + ) -> List[Union[float, int]]: task_type = gen_dataset.domain() if task_type == "single-graph": # FIXME Kirill, add data_x_copy mask @@ -828,9 +944,13 @@ def train_1_step(self, gen_dataset): print("loss %.8f" % loss) self.modification.epochs += 1 self.gnn.eval() - return loss.cpu().detach().numpy().tolist() + return loss.detach().numpy().tolist() - def train_on_batch_full(self, batch, task_type=None): + def train_on_batch_full( + self, + batch, + task_type: str = None + ) -> torch.Tensor: if self.mi_defender: self.mi_defender.pre_batch() if self.evasion_defender: @@ -848,12 +968,19 @@ def train_on_batch_full(self, batch, task_type=None): loss = self.optimizer_step(loss=loss) return loss - def optimizer_step(self, loss): + def optimizer_step( + self, + loss: torch.Tensor + ) -> torch.Tensor: loss.backward() self.optimizer.step() return loss - def train_on_batch(self, batch, task_type=None): + def train_on_batch( + self, + batch, + task_type: str = None + ) -> torch.Tensor: loss = None if hasattr(batch, "edge_weight"): weight = batch.edge_weight @@ -893,11 +1020,18 @@ def train_on_batch(self, batch, task_type=None): raise ValueError("Unsupported task type") return loss - def get_name(self, **kwargs): + def get_name( + self, + **kwargs + ) -> str: json_str = super().get_name() return json_str - def load_model(self, path=None, **kwargs): + def load_model( + self, + path: Union[str, Path, None] = None, + **kwargs + ) -> Type: """ Load model from torch save format @@ -914,7 +1048,10 @@ class variables self.init() return self.gnn - def save_model(self, path=None): + def save_model( + self, + path: Union[str, Path] = None + ) -> None: """ Save the model in torch format @@ -923,7 +1060,12 @@ def save_model(self, path=None): """ torch.save(self.gnn.state_dict(), path) - def report_results(self, train_loss, gen_dataset, metrics): + def report_results( + self, + train_loss, + gen_dataset: DatasetManager, + metrics: List[Metric] + ) -> None: metrics_values = self.evaluate_model(gen_dataset=gen_dataset, metrics=metrics) self.compute_stats_data(gen_dataset, predictions=True, logits=True) self.send_epoch_results( @@ -932,8 +1074,15 @@ def report_results(self, train_loss, gen_dataset, metrics): for k, v in self.stats_data.items()}, weights={"weights": self.gnn.get_weights()}, loss=train_loss) - def train_model(self, gen_dataset, save_model_flag=True, mode=None, steps=None, metrics=None, - socket=None): + def train_model( + self, + gen_dataset: DatasetManager, + save_model_flag: bool = True, + mode: Union[str, None] = None, + steps=None, + metrics: List[Metric] = None, + socket: SocketIO = None + ) -> None: """ Convenient train method. @@ -988,7 +1137,12 @@ def train_model(self, gen_dataset, save_model_flag=True, mode=None, steps=None, finally: self.socket = None - def run_model(self, gen_dataset, mask='test', out='answers'): + def run_model( + self, + gen_dataset: DatasetManager, + mask: Union[str, List[bool], torch.Tensor] = 'test', + out: str = 'answers' + ) -> torch.Tensor: """ Run the model on a part of dataset specified with a mask. @@ -1019,7 +1173,7 @@ def run_model(self, gen_dataset, mask='test', out='answers'): dataset = gen_dataset.dataset part_loader = DataLoader( dataset.index_select(mask), batch_size=self.batch, shuffle=False) - full_out = torch.Tensor() + full_out = torch.empty(0) # y_true = torch.Tensor() if hasattr(self, 'optimizer'): self.optimizer.zero_grad() @@ -1071,7 +1225,10 @@ def run_model(self, gen_dataset, mask='test', out='answers'): return full_out - def evaluate_model(self, gen_dataset, metrics): + def evaluate_model( + self, + gen_dataset: DatasetManager, + ) -> dict: """ Compute metrics for a model result on a part of dataset specified by the metric mask. @@ -1111,7 +1268,12 @@ def evaluate_model(self, gen_dataset, metrics): self.mi_attacker.attack() return metrics_values - def compute_stats_data(self, gen_dataset, predictions=False, logits=False): + def compute_stats_data( + self, + gen_dataset: DatasetManager, + predictions: bool = False, + logits: bool = False + ): """ :param gen_dataset: wrapper over the dataset, stores the dataset and all meta-information about the dataset @@ -1132,7 +1294,14 @@ def compute_stats_data(self, gen_dataset, predictions=False, logits=False): logits = self.run_model(gen_dataset, mask='all', out='logits') self.stats_data["embeddings"] = logits.detach().cpu().tolist() - def send_data(self, block, msg, tag='model', obligate=True, socket=None): + def send_data( + self, + block, + msg, + tag='model', + obligate=True, + socket=None + ): """ Send data to the frontend. @@ -1151,8 +1320,15 @@ def send_data(self, block, msg, tag='model', obligate=True, socket=None): socket.send(block=block, msg=msg, tag=tag, obligate=obligate) return True - def send_epoch_results(self, metrics_values=None, stats_data=None, weights=None, loss=None, obligate=False, - socket=None): + def send_epoch_results( + self, + metrics_values=None, + stats_data=None, + weights=None, + loss=None, + obligate=False, + socket=None + ): """ Send updates to the frontend after a training epoch: epoch, metrics, logits, loss. @@ -1174,7 +1350,10 @@ def send_epoch_results(self, metrics_values=None, stats_data=None, weights=None, if stats_data: self.send_data("mt", stats_data, tag='model_stats', obligate=obligate, socket=socket) - def load_train_test_split(self, gen_dataset): + def load_train_test_split( + self, + gen_dataset: DatasetManager + ) -> DatasetManager: path = self.model_path_info() path = path / 'train_test_split' gen_dataset.train_mask, gen_dataset.val_mask, gen_dataset.test_mask, _ = torch.load(path)[:] @@ -1182,10 +1361,6 @@ def load_train_test_split(self, gen_dataset): class ProtGNNModelManager(FrameworkGNNModelManager): - # additional_config = ModelManagerConfig( - # loss_function={CONFIG_CLASS_NAME: "CrossEntropyLoss"}, - # mask_features=[], - # ) additional_config = ConfigPattern( _config_class="ModelManagerConfig", _config_kwargs={ @@ -1208,10 +1383,17 @@ class ProtGNNModelManager(FrameworkGNNModelManager): } ) - def __init__(self, gnn=None, dataset_path=None, **kwargs): + def __init__( + self, + gnn: Type = None, + dataset_path: Union[str, Path] = None, + **kwargs + ): super().__init__(gnn=gnn, dataset_path=dataset_path, **kwargs) # Get prot layer and its params + self.is_best = None + self.cur_acc = None self.prot_layer = getattr(self.gnn, self.gnn.prot_layer_name) _config_obj = getattr(self.manager_config, CONFIG_OBJ) self.clst = _config_obj.clst @@ -1231,7 +1413,10 @@ def __init__(self, gnn=None, dataset_path=None, **kwargs): self.gnn.best_prots = self.prot_layer.prototype_graphs self.best_acc = 0.0 - def save_model(self, path=None): + def save_model( + self, + path: Union[str, Path, None] = None + ) -> None: """ Save the model in torch format @@ -1242,7 +1427,11 @@ def save_model(self, path=None): "best_prots": self.gnn.best_prots, }, path) - def load_model(self, path=None, **kwargs): + def load_model( + self, + path: Union[str, Path, None] = None, + **kwargs + ) -> Type: """ Load model from torch save format @@ -1259,7 +1448,11 @@ class variables self.init() return self.gnn - def train_on_batch(self, batch, task_type=None): + def train_on_batch( + self, + batch, + task_type: str = None + ) -> torch.Tensor: if task_type == "single-graph": self.optimizer.zero_grad() logits = self.gnn(batch.x, batch.edge_index) @@ -1322,13 +1515,19 @@ def train_on_batch(self, batch, task_type=None): raise ValueError("Unsupported task type") return loss - def optimizer_step(self, loss): + def optimizer_step( + self, + loss: torch.Tensor + ) -> torch.Tensor: loss.backward() torch.nn.utils.clip_grad_value_(self.gnn.parameters(), clip_value=2.0) self.optimizer.step() return loss - def before_epoch(self, gen_dataset): + def before_epoch( + self, + gen_dataset: DatasetManager + ): cur_step = self.modification.epochs train_ind = [n for n, x in enumerate(gen_dataset.train_mask) if x] # Prototype projection @@ -1351,9 +1550,12 @@ def after_epoch(self, gen_dataset): # check if best model metrics_values = self.evaluate_model( - gen_dataset, metrics=[Metric("Accuracy", mask='val'), - Metric("Precision", mask='val'), - Metric("Recall", mask='val')]) + gen_dataset, metrics=[ + Metric("Accuracy", mask='val'), + Metric("Precision", mask='val'), + Metric("Recall", mask='val') + ] + ) self.cur_acc = metrics_values['val']["Accuracy"] self.is_best = (self.cur_acc - self.best_acc >= 0.01) @@ -1362,7 +1564,13 @@ def after_epoch(self, gen_dataset): self.early_stop_count = 0 self.gnn.best_prots = self.prot_layer.prototype_graphs - def early_stopping(self, train_loss, gen_dataset, metrics, steps): + def early_stopping( + self, + train_loss, + gen_dataset: DatasetManager, + metrics: Union[List[Metric], Metric], + steps: int + ) -> bool: step = self.modification.epochs if self.is_best: self.early_stop_count = 0 diff --git a/src/models_builder/models_utils.py b/src/models_builder/models_utils.py index d060e45..c1b8a6b 100644 --- a/src/models_builder/models_utils.py +++ b/src/models_builder/models_utils.py @@ -1,8 +1,13 @@ +from typing import Any + import torch from torch_geometric.nn import MessagePassing -def apply_message_gradient_capture(layer, name): +def apply_message_gradient_capture( + layer: Any, + name: str +) -> None: """ # Example how get Tensors # for name, layer in self.gnn.named_children(): @@ -12,23 +17,32 @@ def apply_message_gradient_capture(layer, name): original_message = layer.message layer.message_gradients = {} - def capture_message_gradients(x_j, *args, **kwargs): + def capture_message_gradients( + x_j: torch.Tensor, + *args, + **kwargs + ): x_j = x_j.requires_grad_() if not layer.training: return original_message(x_j=x_j, *args, **kwargs) - def save_message_grad(grad): + def save_message_grad( + grad: torch.Tensor + ) -> None: layer.message_gradients[name] = grad.detach() x_j.register_hook(save_message_grad) return original_message(x_j=x_j, *args, **kwargs) layer.message = capture_message_gradients - def get_message_gradients(): + def get_message_gradients( + ) -> dict: return layer.message_gradients layer.get_message_gradients = get_message_gradients -def apply_decorator_to_graph_layers(model): +def apply_decorator_to_graph_layers( + model: Any +) -> None: # TODO Kirill add more options """ Example how use this def