Skip to content

Commit

Permalink
Merge pull request #33 from ispras/fix_split
Browse files Browse the repository at this point in the history
fix_train_test_split. Add masks in batch
  • Loading branch information
LukyanovKirillML authored Oct 30, 2024
2 parents 4f09088 + c3b7399 commit c2a7a42
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
13 changes: 8 additions & 5 deletions experiments/attack_defense_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def test_attack_defense():

# full_name = ("multiple-graphs", "TUDataset", 'MUTAG')
# full_name = ("single-graph", "custom", 'karate')
full_name = ("single-graph", "Planetoid", 'Cora')
# full_name = ("single-graph", "Planetoid", 'Cora')
full_name = ("single-graph", "Amazon", 'Photo')
# full_name = ("single-graph", "Planetoid", 'CiteSeer')
# full_name = ("multiple-graphs", "TUDataset", 'PROTEINS')

Expand Down Expand Up @@ -235,8 +236,8 @@ def test_attack_defense():

# gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config)
# gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config)
gnn_model_manager.set_evasion_attacker(evasion_attack_config=netattackgroup_evasion_attack_config)
# gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config)
# gnn_model_manager.set_evasion_attacker(evasion_attack_config=netattackgroup_evasion_attack_config)
gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config)

warnings.warn("Start training")
dataset.train_test_split()
Expand Down Expand Up @@ -266,7 +267,8 @@ def test_attack_defense():

def test_meta():
from attacks.metattack import meta_gradient_attack
my_device = device('cpu')
# my_device = device('cpu')
my_device = device('cuda' if torch.cuda.is_available() else 'cpu')
full_name = ("single-graph", "Planetoid", 'Cora')

dataset, data, results_dataset_path = DatasetManager.get_by_full_name(
Expand Down Expand Up @@ -705,7 +707,8 @@ def test_adv_training():
from defense.evasion_defense import AdvTraining

my_device = device('cpu')
full_name = ("single-graph", "Planetoid", 'Cora')
# full_name = ("single-graph", "Planetoid", 'Cora')
full_name = ("single-graph", "Amazon", 'Photo')

dataset, data, results_dataset_path = DatasetManager.get_by_full_name(
full_name=full_name,
Expand Down
3 changes: 3 additions & 0 deletions src/base/datasets_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,9 @@ def train_test_split(self, percent_train_class: float = 0.8, percent_test_class:
self.train_mask = train_mask
self.test_mask = test_mask
self.val_mask = val_mask
self.dataset.data.train_mask = train_mask
self.dataset.data.test_mask = test_mask
self.dataset.data.val_mask = val_mask

def save_train_test_mask(self, path):
""" Save current train/test mask to a given path (together with the model). """
Expand Down
2 changes: 1 addition & 1 deletion src/models_builder/gnn_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def forward(self, *args, **kwargs):
tensor_storage[layer_ind] = torch.clone(x)
layer_ind = curr_layer_ind
x_copy = torch.clone(x)
connection_tensor = torch.Tensor()
connection_tensor = torch.empty(0, device=x_copy.device)
for key, value in self.conn_dict.items():
if key[1] == curr_layer_ind:
if key[1] - key[0] == 1:
Expand Down

0 comments on commit c2a7a42

Please sign in to comment.