From 8951a57c30f881e806b5a15a895333bd35b1df92 Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Wed, 30 Oct 2024 14:50:08 +0300 Subject: [PATCH] fix_train_test_split. Add masks in batch --- experiments/attack_defense_test.py | 13 ++++++++----- src/base/datasets_processing.py | 3 +++ src/models_builder/gnn_constructor.py | 2 +- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/experiments/attack_defense_test.py b/experiments/attack_defense_test.py index 0065a5a..13f1c6e 100644 --- a/experiments/attack_defense_test.py +++ b/experiments/attack_defense_test.py @@ -25,7 +25,8 @@ def test_attack_defense(): # full_name = ("multiple-graphs", "TUDataset", 'MUTAG') # full_name = ("single-graph", "custom", 'karate') - full_name = ("single-graph", "Planetoid", 'Cora') + # full_name = ("single-graph", "Planetoid", 'Cora') + full_name = ("single-graph", "Amazon", 'Photo') # full_name = ("single-graph", "Planetoid", 'CiteSeer') # full_name = ("multiple-graphs", "TUDataset", 'PROTEINS') @@ -235,8 +236,8 @@ def test_attack_defense(): # gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config) # gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config) - gnn_model_manager.set_evasion_attacker(evasion_attack_config=netattackgroup_evasion_attack_config) - # gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config) + # gnn_model_manager.set_evasion_attacker(evasion_attack_config=netattackgroup_evasion_attack_config) + gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config) warnings.warn("Start training") dataset.train_test_split() @@ -266,7 +267,8 @@ def test_attack_defense(): def test_meta(): from attacks.metattack import meta_gradient_attack - my_device = device('cpu') + # my_device = device('cpu') + my_device = device('cuda' if torch.cuda.is_available() else 'cpu') full_name = ("single-graph", "Planetoid", 'Cora') dataset, data, results_dataset_path = DatasetManager.get_by_full_name( @@ -705,7 +707,8 @@ def test_adv_training(): from defense.evasion_defense import AdvTraining my_device = device('cpu') - full_name = ("single-graph", "Planetoid", 'Cora') + # full_name = ("single-graph", "Planetoid", 'Cora') + full_name = ("single-graph", "Amazon", 'Photo') dataset, data, results_dataset_path = DatasetManager.get_by_full_name( full_name=full_name, diff --git a/src/base/datasets_processing.py b/src/base/datasets_processing.py index bb2dc56..d696d04 100644 --- a/src/base/datasets_processing.py +++ b/src/base/datasets_processing.py @@ -676,6 +676,9 @@ def train_test_split(self, percent_train_class: float = 0.8, percent_test_class: self.train_mask = train_mask self.test_mask = test_mask self.val_mask = val_mask + self.dataset.data.train_mask = train_mask + self.dataset.data.test_mask = test_mask + self.dataset.data.val_mask = val_mask def save_train_test_mask(self, path): """ Save current train/test mask to a given path (together with the model). """ diff --git a/src/models_builder/gnn_constructor.py b/src/models_builder/gnn_constructor.py index 02ada4d..9c82197 100644 --- a/src/models_builder/gnn_constructor.py +++ b/src/models_builder/gnn_constructor.py @@ -427,7 +427,7 @@ def forward(self, *args, **kwargs): tensor_storage[layer_ind] = torch.clone(x) layer_ind = curr_layer_ind x_copy = torch.clone(x) - connection_tensor = torch.Tensor() + connection_tensor = torch.empty(0, device=x_copy.device) for key, value in self.conn_dict.items(): if key[1] == curr_layer_ind: if key[1] - key[0] == 1: