diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py
index 3f3316b..a93aad7 100644
--- a/src/models_builder/gnn_models.py
+++ b/src/models_builder/gnn_models.py
@@ -765,7 +765,7 @@ def train_complete(self, gen_dataset, steps=None, pbar=None, metrics=None, **kwa
             train_loss = self.train_1_step(gen_dataset)
             self.after_epoch(gen_dataset)
             early_stopping_flag = self.early_stopping(train_loss=train_loss, gen_dataset=gen_dataset,
-                                                      metrics=metrics)
+                                                      metrics=metrics, steps=steps)
             if self.socket:
                 self.report_results(train_loss=train_loss, gen_dataset=gen_dataset, metrics=metrics)
@@ -773,7 +773,7 @@ def train_complete(self, gen_dataset, steps=None, pbar=None, metrics=None, **kwa
             if early_stopping_flag:
                 break

-    def early_stopping(self, train_loss, gen_dataset, metrics):
+    def early_stopping(self, train_loss, gen_dataset, metrics, steps):
         return False

     def train_1_step(self, gen_dataset):
@@ -921,8 +921,9 @@ def train_model(self, gen_dataset, save_model_flag=True, mode=None, steps=None,
         assert issubclass(type(self), GNNModelManager)
         assert mode in ['1_step', 'full', None]

-        has_complete = self.train_complete != super(type(self), self).train_complete
-        assert has_complete
+        # TODO Kirill what is this? Outdated?
+        # has_complete = self.train_complete != super(type(self), self).train_complete
+        # assert has_complete

         do_1_step = True
         try:
@@ -1140,12 +1141,7 @@ def load_train_test_split(self, gen_dataset):
             gen_dataset.train_mask, gen_dataset.val_mask, gen_dataset.test_mask, _ = torch.load(path)[:]
         return gen_dataset

-
-# FIXME George
 class ProtGNNModelManager(FrameworkGNNModelManager):
-    """
-    Prot layer needs a special training procedure.
-    """
     # additional_config = ModelManagerConfig(
     #     loss_function={CONFIG_CLASS_NAME: "CrossEntropyLoss"},
     #     mask_features=[],
@@ -1154,14 +1150,14 @@ class ProtGNNModelManager(FrameworkGNNModelManager):
         _config_class="ModelManagerConfig",
         _config_kwargs={
             "mask_features": [],
-            # "optimizer": {
-            #     "_config_class": "Config",
-            #     "_class_name": "Adam",
-            #     "_import_path": OPTIMIZERS_PARAMETERS_PATH,
-            #     "_class_import_info": ["torch.optim"],
-            #     "_config_kwargs": {},
-            # },
-            # FUNCTIONS_PARAMETERS_PATH,
+            "optimizer": {
+                "_config_class": "Config",
+                "_class_name": "Adam",
+                "_import_path": OPTIMIZERS_PARAMETERS_PATH,
+                "_class_import_info": ["torch.optim"],
+                "_config_kwargs": {},
+            },
+            # FUNCTIONS_PARAMETERS_PATH,
             "loss_function": {
                 "_config_class": "Config",
                 "_class_name": "CrossEntropyLoss",
@@ -1172,6 +1168,29 @@ class ProtGNNModelManager(FrameworkGNNModelManager):
         }
     )

+    def __init__(self, gnn=None, dataset_path=None, **kwargs):
+        super().__init__(gnn=gnn, dataset_path=dataset_path, **kwargs)
+
+        # Get prot layer and its params
+        self.prot_layer = getattr(self.gnn, self.gnn.prot_layer_name)
+        _config_obj = getattr(self.manager_config, CONFIG_OBJ)
+        self.clst = _config_obj.clst
+        self.sep = _config_obj.sep
+        # lr = _config_obj.lr
+        self.early_stopping_marker = _config_obj.early_stopping
+        self.proj_epochs = _config_obj.proj_epochs
+        self.warm_epoch = _config_obj.warm_epoch
+        self.save_epoch = _config_obj.save_epoch
+        self.save_thrsh = _config_obj.save_thrsh
+        # TODO implement other MCTS args too
+        # TODO MCTS args via static ?
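+        # mcts_args appears to be a module-level settings object for the MCTS subgraph search
+        # used in prototype projection (assumption; it is imported elsewhere in this file);
+        # only its atom-count bounds are taken from the manager config here.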
+        mcts_args.min_atoms = _config_obj.mcts_min_atoms
+        mcts_args.max_atoms = _config_obj.mcts_max_atoms
+        self.prot_thrsh = _config_obj.prot_thrsh
+        self.early_stop_count = 0
+        self.gnn.best_prots = self.prot_layer.prototype_graphs
+        self.best_acc = 0.0
+
     def save_model(self, path=None):
         """
         Save the model in torch format
@@ -1200,317 +1219,133 @@ class variables
         self.init()
         return self.gnn

-    def evaluate_model(self, gen_dataset, metrics):
-        """
-        Compute metrics for a model result on a part of dataset specified by the metric mask.
-
-        :param gen_dataset: wrapper over the dataset, stores the dataset and all meta-information about the dataset
-        :param metrics: list of metrics ot compute
-        :return: dict {metric -> value}
-        """
-        mask_metrics = {}
-        for metric in metrics:
-            mask = metric.mask
-            if mask not in mask_metrics:
-                mask_metrics[mask] = []
-            mask_metrics[mask].append(metric)
-
-        metrics_values = {}
-        for mask, ms in mask_metrics.items():
-            metrics_values[mask] = {}
-            y_pred = self.run_model(gen_dataset, mask=mask)
-            try:
-                mask_tensor = {
-                    'train': gen_dataset.train_mask.tolist(),
-                    'val': gen_dataset.val_mask.tolist(),
-                    'test': gen_dataset.test_mask.tolist(),
-                    'all': [True] * len(gen_dataset.labels),
-                }[mask]
-            except KeyError:
-                assert isinstance(mask, list)
-            y_true = torch.tensor([y for m, y in zip(mask_tensor, gen_dataset.labels) if m])
-
-            for metric in ms:
-                metrics_values[mask][metric.name] = metric.compute(y_pred, y_true)
-                # metrics_values[mask][metric.name] = MetricManager.compute(metric, y_pred, y_true)
-
-        return metrics_values
-
-    def train_full(self, gen_dataset, steps=None, metrics=None):
-        """
-        Train ProtGNN model for Graph classification
-        """
-        metrics = metrics or []
-        # TODO Misha can we split into 1-step funcs?
-        # do we need steps here?
+    # TODO Kirill train_on_batch to be divided into two parts
+    def train_on_batch(self, batch, task_type=None):
+        if self.mi_defender:
+            self.mi_defender.pre_batch()
+        if self.evasion_defender:
+            self.evasion_defender.pre_batch(model_manager=self, batch=batch)
+        loss = None
+        if task_type == "single-graph":
+            # TODO Prot with single-graph
+            self.optimizer.zero_grad()
+            logits = self.gnn(batch.x, batch.edge_index)
+            min_distances = self.gnn.min_distances
+
+            # cluster loss
+            self.prot_layer.prototype_class_identity = self.prot_layer.prototype_class_identity
+            prototypes_of_correct_class = torch.t(
+                self.prot_layer.prototype_class_identity[:, batch.y].bool())
+            cluster_cost = torch.mean(
+                torch.min(min_distances[prototypes_of_correct_class]
+                          .reshape(-1, self.prot_layer.num_prototypes_per_class), dim=1)[0])
+
+            # separation loss
+            separation_cost = -torch.mean(
+                torch.min(min_distances[~prototypes_of_correct_class].reshape(-1, (
+                    self.prot_layer.output_dim - 1) * self.prot_layer.num_prototypes_per_class),
+                          dim=1)[0])
+
+            # sparsity loss
+            l1_mask = 1 - torch.t(self.prot_layer.prototype_class_identity)
+            l1 = (self.prot_layer.last_layer.weight * l1_mask).norm(p=1)
+
+            # diversity loss
+            ld = 0
+            # TODO experiments required. With zero coeff - meaningless
+            # for k in range(prot_layer.output_dim):
+            #     p = prot_layer.prototype_vectors[
+            #         k * prot_layer.num_prototypes_per_class:
+            #         (k + 1) * prot_layer.num_prototypes_per_class]
+            #     p = F.normalize(p, p=2, dim=1)
+            #     matrix1 = torch.mm(p, torch.t(p)) - torch.eye(p.shape[0]) - 0.3
+            #     matrix2 = torch.zeros(matrix1.shape)
+            #     ld += torch.sum(torch.where(matrix1 > 0, matrix1, matrix2))

-        # Get prot layer and its params
-        prot_layer = getattr(self.gnn, self.gnn.prot_layer_name)
-        _config_obj = getattr(self.manager_config, CONFIG_OBJ)
-        clst = _config_obj.clst
-        sep = _config_obj.sep
-        lr = _config_obj.lr
-        early_stopping = _config_obj.early_stopping
-        proj_epochs = _config_obj.proj_epochs
-        warm_epoch = _config_obj.warm_epoch
-        save_epoch = _config_obj.save_epoch
-        save_thrsh = _config_obj.save_thrsh
-        # TODO implement other MCTS args too
-        mcts_args.min_atoms = _config_obj.mcts_min_atoms
-        mcts_args.max_atoms = _config_obj.mcts_max_atoms
-        prot_thrsh = _config_obj.prot_thrsh
+            loss = self.loss_function(logits, batch.y)
+            loss += self.clst * cluster_cost + self.sep * separation_cost + 5e-4 * l1 + 0.00 * ld
+            if self.clip is not None:
+                clip_grad_norm(self.gnn.parameters(), self.clip)
+            self.optimizer.zero_grad()
+            # loss.backward()
+            # self.optimizer.step()
+        elif task_type == "multiple-graphs":
+            self.optimizer.zero_grad()
+            logits = self.gnn(batch.x, batch.edge_index, batch.batch)
+            loss = self.loss_function(logits, batch.y)
+            # loss.backward()
+            # self.optimizer.step()
+        # TODO Kirill, remove False when release edge recommendation task
+        elif task_type == "edge" and False:
+            self.optimizer.zero_grad()
+            edge_index = batch.edge_index
+            pos_edge_index = edge_index[:, batch.y == 1]
+            neg_edge_index = edge_index[:, batch.y == 0]

-        print(f"cluster loss cost: {clst}", f"separation loss cost: {sep}", sep='\n')
+            pos_out = self.gnn(batch.x, pos_edge_index)
+            neg_out = self.gnn(batch.x, neg_edge_index)

-        # TODO Misha use save_model_flag and other params
-        # TODO Misha add checkpoint
+            pos_loss = self.loss_function(pos_out, torch.ones_like(pos_out))
+            neg_loss = self.loss_function(neg_out, torch.zeros_like(neg_out))

-        # criterion = torch.nn.CrossEntropyLoss()
-        # FIXME use optimizer from manager_config and its LR
-        self.optimizer = torch.optim.Adam(self.gnn.parameters(), lr=lr)
+            loss = pos_loss + neg_loss
+            # loss.backward()
+        else:
+            raise ValueError("Unsupported task type")
+        if self.mi_defender:
+            self.mi_defender.post_batch()
+        evasion_defender_dict = None
+        if self.evasion_defender:
+            evasion_defender_dict = self.evasion_defender.post_batch(
+                model_manager=self, batch=batch, loss=loss,
+            )
+        if evasion_defender_dict and "loss" in evasion_defender_dict:
+            loss = evasion_defender_dict["loss"]
+        loss.backward()
+        torch.nn.utils.clip_grad_value_(self.gnn.parameters(), clip_value=2.0)
+        self.optimizer.step()
+        return loss

-        dataset = gen_dataset.dataset
-        data = gen_dataset.data
-        train_dataset = dataset.index_select(gen_dataset.train_mask)
-        train_loader = DataLoader(train_dataset, batch_size=self.batch, shuffle=False)
+    def before_epoch(self, gen_dataset):
+        cur_step = self.modification.epochs
         train_ind = [n for n, x in enumerate(gen_dataset.train_mask) if x]
-        # val_ind = [n for n, x in enumerate(gen_dataset.val_mask) if x]
-
-        # data.x = data.x.float()
-
-        avg_nodes = 0.0
-        avg_edge_index = 0.0
-        for i in range(len(dataset)):
-            avg_nodes += dataset[i].x.shape[0]
-            avg_edge_index += dataset[i].edge_index.shape[1]
-        avg_nodes /= len(dataset)
-        avg_edge_index /= len(dataset)
-        print(
-            f"graphs {len(dataset)}, avg_nodes{avg_nodes :.4f}, avg_edge_index_{avg_edge_index / 2 :.4f}")
-
-        best_acc = 0.0
-        data_size = len(dataset)
-        print(f'The total num of dataset is {data_size}')
-
-        early_stop_count = 0
-        best_prots = prot_layer.prototype_graphs
-        # data_indices = train_loader.dataset.indices
-        for step in range(steps):
-            acc = []
-            precision = []
-            recall = []
-            loss_list = []
-            ld_loss_list = []
-
-            # Prototype projection
-            if step > proj_epochs and step % proj_epochs == 0:
-                prot_layer.projection(self.gnn, dataset, train_ind, data, thrsh=prot_thrsh)
-                self.gnn.train()
-                for p in self.gnn.parameters():
+        # Prototype projection
+        if cur_step > self.proj_epochs and cur_step % self.proj_epochs == 0:
+            self.prot_layer.projection(self.gnn, gen_dataset.dataset, train_ind, gen_dataset.dataset.data, thrsh=self.prot_thrsh)
+            self.gnn.train()
+            for p in self.gnn.parameters():
+                p.requires_grad = True
+            self.prot_layer.prototype_vectors.requires_grad = True
+            if cur_step < self.warm_epoch:
+                for p in self.prot_layer.last_layer.parameters():
+                    p.requires_grad = False
+            else:
+                for p in self.prot_layer.last_layer.parameters():
                     p.requires_grad = True
-                prot_layer.prototype_vectors.requires_grad = True
-                if step < warm_epoch:
-                    for p in prot_layer.last_layer.parameters():
-                        p.requires_grad = False
-                else:
-                    for p in prot_layer.last_layer.parameters():
-                        p.requires_grad = True
-
-            for batch in train_loader:
-                min_distances = self.gnn.min_distances
-                logits = self.gnn(batch.x, batch.edge_index, batch.batch)
-                loss = self.loss_function(logits, batch.y)
-                # cluster loss
-                prot_layer.prototype_class_identity = prot_layer.prototype_class_identity
-                prototypes_of_correct_class = torch.t(
-                    prot_layer.prototype_class_identity[:, batch.y].bool())
-                cluster_cost = torch.mean(
-                    torch.min(min_distances[prototypes_of_correct_class]
-                              .reshape(-1, prot_layer.num_prototypes_per_class), dim=1)[0])
-
-                # seperation loss
-                separation_cost = -torch.mean(
-                    torch.min(min_distances[~prototypes_of_correct_class].reshape(-1, (
-                        prot_layer.output_dim - 1) * prot_layer.num_prototypes_per_class),
-                              dim=1)[0])
-
-                # sparsity loss
-                l1_mask = 1 - torch.t(prot_layer.prototype_class_identity)
-                l1 = (prot_layer.last_layer.weight * l1_mask).norm(p=1)
-
-                # diversity loss
-                ld = 0
-                for k in range(prot_layer.output_dim):
-                    p = prot_layer.prototype_vectors[
-                        k * prot_layer.num_prototypes_per_class:
-                        (k + 1) * prot_layer.num_prototypes_per_class]
-                    p = F.normalize(p, p=2, dim=1)
-                    matrix1 = torch.mm(p, torch.t(p)) - torch.eye(p.shape[0]) - 0.3
-                    matrix2 = torch.zeros(matrix1.shape)
-                    ld += torch.sum(torch.where(matrix1 > 0, matrix1, matrix2))
-
-                loss = loss + clst * cluster_cost + sep * separation_cost + 5e-4 * l1 + 0.00 * ld
-
-                # optimization
-                self.optimizer.zero_grad()
-                loss.backward()
-                torch.nn.utils.clip_grad_value_(self.gnn.parameters(), clip_value=2.0)
-                self.optimizer.step()
-
-                # record
-                _, prediction = torch.max(logits, -1)
-                loss_list.append(loss.item())
-                ld_loss_list.append(ld.item())
-                acc.append(prediction.eq(batch.y).cpu().numpy())
-                precision.append(
-                    prediction[prediction == 1].eq(batch.y[prediction == 1]).cpu().numpy())
-                recall.append(prediction[batch.y == 1].eq(batch.y[batch.y == 1]).cpu().numpy())
-
-                # if best_prots == 0:
-                #     best_prots = model.prototype_graphs
-
-            # report train msg
-            print(
-                f"Train Epoch:{step} |Loss: {np.average(loss_list):.3f} | Ld: {np.average(ld_loss_list):.3f} | "
-                f"Acc: {np.concatenate(acc, axis=0).mean():.3f} | "
-                f"Precision: {np.concatenate(precision, axis=0).mean():.3f} | "
-                f"Recall: {np.concatenate(recall, axis=0).mean():.3f}")
-
-            # report eval msg
-            # eval_state = self.evaluate(val_loader)
-
-            metrics_values = self.evaluate_model(
-                gen_dataset, metrics=[Metric("Accuracy", mask='val'),
-                                      Metric("Precision", mask='val'),
-                                      Metric("Recall", mask='val')] + metrics)
-            # model_data = self.get_stats_data(gen_dataset, predictions=True, logits=True)
-            # self.send_epoch_results(metrics_values, model_data, loss=np.average(loss_list))
-            acc = metrics_values['val']["Accuracy"]
-            print(
-                f"Eval Epoch: {step} | Acc: {acc:.3f} | "
-                f"Precision: {metrics_values['val']['Precision']:.3f} | "
-                f"Recall: {metrics_values['val']['Recall']:.3f}")
-
-            # only save the best model
-            is_best = (acc - best_acc >= 0.01)
-
-            if is_best:
-                early_stop_count = 0
-            else:
-                early_stop_count += 1
-
-            if early_stop_count > early_stopping:
-                break
-
-            if is_best:
-                best_acc = acc
-                early_stop_count = 0
-                best_prots = prot_layer.prototype_graphs
-            if is_best or step % save_epoch == 0:
-                # save_best(ckpt_dir, epoch, gnnNets, model_args.model_name, acc, is_best)
-                pass
-
-            if acc > save_thrsh:
-                if best_prots == 0:
-                    prot_layer.projection(dataset, train_ind, data, thrsh=prot_thrsh)
-                    best_prots = prot_layer.prototype_graphs
-                best_acc = acc
-                break
-
-            """
-            if best_acc >= save_thrsh and projected:
-                break
-            """
-        self.modification.epochs = step + 1
-
-        print(f"The best validation accuracy is {best_acc}.")
-        # report test msg
-        # checkpoint = torch.load(os.path.join(ckpt_dir, f'{model_args.model_name}_best.pth'))
-        # gnnNets.update_state_dict(checkpoint['net'])
-        # test_state, _, _ = test_GC(dataloader['test'], model, criterion)
-        # print(f"Test: | Loss: {test_state['loss']:.3f} | Acc: {test_state['acc']:.3f}")
-        print("End(for breakpoint)")
-        self.gnn.best_prots = best_prots
-
-        return best_acc
-
-    def run_model(self, gen_dataset, mask='test', out='answers'):
-        """
-        Run the model on a part of dataset specified with a mask.
-
-        :param gen_dataset: wrapper over the dataset, stores the dataset and all meta-information about the dataset
-        :param mask: 'train', 'val', 'test', or a bool valued list
-        :param out: if 'answers' return answers, otherwise predictions
-        :return: y_pred, y_true
-        """
-        try:
-            mask = {
-                'train': gen_dataset.train_mask,
-                'val': gen_dataset.val_mask,
-                'test': gen_dataset.test_mask,
-                'all': tensor([True] * len(gen_dataset.labels)),
-            }[mask]
-        except KeyError:
-            assert isinstance(mask, torch.Tensor)
-
-        run_func = {
-            'answers': self.gnn.get_answer,
-            'predictions': self.gnn.get_predictions,
-            'logits': self.gnn.__call__,
-        }[out]
-        self.gnn.eval()
-        with torch.no_grad():  # Turn off gradients computation
-            if gen_dataset.is_multi():
-                dataset = gen_dataset.dataset
-                part_loader = DataLoader(
-                    dataset.index_select(mask), batch_size=self.batch, shuffle=False)
-                full_out = torch.Tensor()
-                # y_true = torch.Tensor()
-                if hasattr(self, 'optimizer'):
-                    self.optimizer.zero_grad()
-                for data in part_loader:
-                    # logits_batch = self.gnn(data.x, data.edge_index, data.batch)
-                    # pred_batch = logits_batch.argmax(dim=1)
-                    out = run_func(data.x, data.edge_index, data.batch)
-                    full_out = torch.cat((full_out, out))
-                    # y_true = torch.cat((y_true, data.y))
-            else:  # single-graph
-                data = gen_dataset.dataset.data
-                ver_ind = [n for n, x in enumerate(gen_dataset.train_mask) if x]
-                mask_size = len(ver_ind)
-                random.shuffle(ver_ind)
-
-                number_of_batches = ceil(mask_size / self.batch)
-                # data_x_elem_len = data.x.size()[1]
-                full_out = torch.Tensor()
-
-                for batch_ind in range(number_of_batches):
-                    data_x_copy = torch.clone(data.x)
-                    mask_copy = [False] * data.x.size()[0]
-
-                    # features_mask_tensor_copy = torch.clone(features_mask_tensor)
-
-                    train_batch = ver_ind[batch_ind * self.batch: (batch_ind + 1) * self.batch]
-                    for elem_ind in train_batch:
-                        for feature in self.mask_features:
-                            # features_mask_tensor_copy[elem_ind][gen_dataset.info.node_attr_slices[feature][0]:
-                            #                                     gen_dataset.info.node_attr_slices[feature][1]] = False
-                            data_x_copy[elem_ind][gen_dataset.info.node_attr_slices[feature][0]:
-                                                  gen_dataset.info.node_attr_slices[feature][1]] = 0
-                        # if self.gnn_mm.train_mask_flag:
-                        #     data_x_copy[elem_ind] = torch.zeros(data_x_elem_len)
-                        # y_true = torch.masked.masked_tensor(data.y, mask_tensor)
-                        mask_copy[elem_ind] = True
-                    # mask_x_tensor = torch.masked.masked_tensor(data.x, features_mask_tensor_copy)
-
-                    # FIXME Kirill what to do if no optimizer, train_mask_flag, batch?
-                    if hasattr(self, 'optimizer'):
-                        self.optimizer.zero_grad()
-                    # logits_batch = self.gnn(data_x_copy, data.edge_index)
-                    # pred_batch = logits_batch.argmax(dim=1)
-                    out = run_func(data_x_copy, data.edge_index, data.batch)
-                    full_out = torch.cat((full_out, out[mask_copy]))
-                    # y_true = torch.cat((y_true, data.y[mask_copy]))
+    def after_epoch(self, gen_dataset):
+        # TODO compare is_best with different metrics to be implemented
+
+        # check if best model
+        metrics_values = self.evaluate_model(
+            gen_dataset, metrics=[Metric("Accuracy", mask='val'),
+                                  Metric("Precision", mask='val'),
+                                  Metric("Recall", mask='val')])
+        self.cur_acc = metrics_values['val']["Accuracy"]
+        self.is_best = (self.cur_acc - self.best_acc >= 0.01)
+
+        if self.is_best:
+            self.best_acc = self.cur_acc
+            self.early_stop_count = 0
+            self.gnn.best_prots = self.prot_layer.prototype_graphs
+
+
+    def early_stopping(self, train_loss, gen_dataset, metrics, steps):
+        step = self.modification.epochs
+        if self.is_best:
+            self.early_stop_count = 0
+        else:
+            self.early_stop_count += 1
+        last_projection = (step % self.proj_epochs == 0 and step + self.proj_epochs >= steps)

-        return full_out
+        return self.early_stop_count >= self.early_stopping_marker or last_projection
\ No newline at end of file
diff --git a/tests/explainers_test.py b/tests/explainers_test.py
index 8574392..ae3fe4b 100644
--- a/tests/explainers_test.py
+++ b/tests/explainers_test.py
@@ -126,8 +126,8 @@ def setUp(self) -> None:
                                      metrics=[Metric("F1", mask='test')])

         # TODO Kirill, tmp comment work and tests with Prot
-        # gin3_lin2_prot_mg_small = model_configs_zoo(
-        #     dataset=dataset_mg_small, model_name='gin_gin_gin_lin_lin_prot')
+        gin3_lin2_prot_mg_small = model_configs_zoo(
+            dataset=dataset_mg_small, model_name='gin_gin_gin_lin_lin_prot')

         gin3_lin1_mg_mutag = model_configs_zoo(
             dataset=dataset_mg_mutag, model_name='gin_gin_gin_lin')
@@ -156,14 +156,14 @@ def setUp(self) -> None:
             }
         )

-        # self.prot_gnn_mm_mg_small = ProtGNNModelManager(
-        #     gnn=gin3_lin2_prot_mg_small, dataset_path=results_dataset_path_mg_small,
-        #     # manager_config=gin3_lin2_mg_small_manager_config,
-        # )
+        self.prot_gnn_mm_mg_small = ProtGNNModelManager(
+            gnn=gin3_lin2_prot_mg_small, dataset_path=results_dataset_path_mg_small,
+            # manager_config=gin3_lin2_mg_small_manager_config,
+        )

         # TODO Misha use as training params: clst=clst, sep=sep, save_thrsh=save_thrsh, lr=lr
-        # best_acc = self.prot_gnn_mm_mg_small.train_model(
-        #     gen_dataset=gen_dataset_mg_small, steps=100, metrics=[])
+        best_acc = self.prot_gnn_mm_mg_small.train_model(
+            gen_dataset=gen_dataset_mg_small, steps=100, metrics=[])

         gin3_lin2_mg_small = model_configs_zoo(
             dataset=gen_dataset_mg_small, model_name='gin_gin_gin_lin_lin')
@@ -327,36 +327,36 @@ def test_Zorro(self):
         )
         explainer_Zorro.conduct_experiment(explainer_run_config)

-    # def test_ProtGNN(self):
-    #     warnings.warn("Start ProtGNN")
-    #     explainer_init_config = ConfigPattern(
-    #         _class_name="ProtGNN",
-    #         _import_path=EXPLAINERS_INIT_PARAMETERS_PATH,
-    #         _config_class="ExplainerInitConfig",
-    #         _config_kwargs={
-    #         }
-    #     )
-    #     explainer_run_config = ConfigPattern(
-    #         _config_class="ExplainerRunConfig",
-    #         _config_kwargs={
-    #             "mode": "global",
-    #             "kwargs": {
-    #                 "_class_name": "ProtGNN",
-    #                 "_import_path": EXPLAINERS_GLOBAL_RUN_PARAMETERS_PATH,
-    #                 "_config_class": "Config",
-    #                 "_config_kwargs": {
-    #
-    #                 },
-    #             }
-    #         }
-    #     )
-    #     explainer_Prot = FrameworkExplainersManager(
-    #         init_config=explainer_init_config,
-    #         dataset=self.dataset_mg_small, gnn_manager=self.prot_gnn_mm_mg_small,
-    #         explainer_name='ProtGNN',
-    #     )
-    #
-    #     explainer_Prot.conduct_experiment(explainer_run_config)
+    def test_ProtGNN(self):
+        warnings.warn("Start ProtGNN")
+        explainer_init_config = ConfigPattern(
+            _class_name="ProtGNN",
+            _import_path=EXPLAINERS_INIT_PARAMETERS_PATH,
+            _config_class="ExplainerInitConfig",
+            _config_kwargs={
+            }
+        )
+        explainer_run_config = ConfigPattern(
+            _config_class="ExplainerRunConfig",
+            _config_kwargs={
+                "mode": "global",
+                "kwargs": {
+                    "_class_name": "ProtGNN",
+                    "_import_path": EXPLAINERS_GLOBAL_RUN_PARAMETERS_PATH,
+                    "_config_class": "Config",
+                    "_config_kwargs": {
+
+                    },
+                }
+            }
+        )
+        explainer_Prot = FrameworkExplainersManager(
+            init_config=explainer_init_config,
+            dataset=self.dataset_mg_small, gnn_manager=self.prot_gnn_mm_mg_small,
+            explainer_name='ProtGNN',
+        )
+
+        explainer_Prot.conduct_experiment(explainer_run_config)

     def test_GNNExpl_PYG_SG(self):
         warnings.warn("Start GNNExplainer_PYG")
diff --git a/tests/models_test.py b/tests/models_test.py
index f123d06..8217376 100644
--- a/tests/models_test.py
+++ b/tests/models_test.py
@@ -72,7 +72,7 @@ def setUp(self) -> None:
                 labeling='binary',
                 dataset_ver_ind=0)
         )
-        self.gen_dataset_mg_small.train_test_split(percent_train_class=0.6, percent_test_class=0.4)
+        self.gen_dataset_mg_small.train_test_split(percent_train_class=0.6, percent_test_class=0.2)
         self.results_dataset_path_mg_small = self.gen_dataset_mg_small.results_dir
         self.default_config = ModelModificationConfig(
             model_ver_ind=0,
@@ -124,7 +124,8 @@ def test_model_on_multiple_graph(self):
         )

         gnn_mm_mg_small.train_model(gen_dataset=self.gen_dataset_mg_small, steps=100,
-                                    metrics=[Metric("F1", mask='test')])
+                                    metrics=[Metric("F1", mask='val'),
+                                             Metric("F1", mask='test')])
         metric_loc = gnn_mm_mg_small.evaluate_model(
             gen_dataset=self.gen_dataset_mg_small, metrics=[Metric("F1", mask='test', average='macro')])
         print(metric_loc)