Skip to content

Commit

Permalink
Merge pull request #38 from ispras/fix_device_problem
Browse files Browse the repository at this point in the history
fix device problem in def run_model
  • Loading branch information
mishadr authored Dec 2, 2024
2 parents 1426490 + 2299432 commit 31337ba
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions experiments/attack_defense_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def test_attack_defense():

# print(data.train_mask)

gnn = model_configs_zoo(dataset=dataset, model_name='gat_gcn_sage_gcn_gcn')
gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn')
# gnn = model_configs_zoo(dataset=dataset, model_name='gat_gcn_sage_gcn_gcn')
# gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn_lin')
# gnn = model_configs_zoo(dataset=dataset, model_name='test_gnn')
# gnn = model_configs_zoo(dataset=dataset, model_name='gin_gin_gin_lin_lin')
Expand Down Expand Up @@ -263,8 +264,8 @@ def test_attack_defense():

# gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config)
# gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config)
gnn_model_manager.set_evasion_attacker(evasion_attack_config=fgsm_evasion_attack_config)
gnn_model_manager.set_evasion_defender(evasion_defense_config=autoencoder_evasion_defense_config)
# gnn_model_manager.set_evasion_attacker(evasion_attack_config=fgsm_evasion_attack_config)
# gnn_model_manager.set_evasion_defender(evasion_defense_config=autoencoder_evasion_defense_config)

warnings.warn("Start training")
dataset.train_test_split()
Expand Down
2 changes: 1 addition & 1 deletion src/models_builder/gnn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,7 @@ def run_model(
dataset = gen_dataset.dataset
part_loader = DataLoader(
dataset.index_select(mask), batch_size=self.batch, shuffle=False)
full_out = torch.empty(0)
full_out = torch.empty(0, device=dataset.data.x.device)
# y_true = torch.Tensor()
if hasattr(self, 'optimizer'):
self.optimizer.zero_grad()
Expand Down

0 comments on commit 31337ba

Please sign in to comment.