Skip to content

Commit

Permalink
split train_on_batch into def train_on_batch, def optimizer_step and …
Browse files Browse the repository at this point in the history
…def train_on_batch_full. Fix ProtGNN
  • Loading branch information
LukyanovKirillML committed Oct 21, 2024
1 parent 04535db commit 5cfd429
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 38 deletions.
11 changes: 6 additions & 5 deletions src/defense/evasion_defense.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import copy


class EvasionDefender(Defender):
def __init__(self, **kwargs):
super().__init__()
Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(self, batch) -> None:
self.data = batch
self.dataset = self


class AdvTraining(EvasionDefender):
name = "AdvTraining"

Expand Down Expand Up @@ -104,7 +106,7 @@ def __init__(self, attack_name=None, attack_config=None, attack_type=None, devic
self.prob_cross = self.attack_config._config_kwargs["prob_cross"]
self.prob_mutate = self.attack_config._config_kwargs["prob_mutate"]
# set attacker
self.attacker = qattack.QAttacker(self.population_size, self.individual_size,
self.attacker = qattack.QAttacker(self.population_size, self.individual_size,
self.generations, self.prob_cross,
self.prob_mutate)
elif self.attack_config._class_name == "MetaAttackFull":
Expand All @@ -123,11 +125,10 @@ def pre_batch(self, model_manager, batch):
self.perturbed_gen_dataset.dataset = self.perturbed_gen_dataset.data
self.perturbed_gen_dataset.dataset.data = self.perturbed_gen_dataset.data
if self.attack_type == "EVASION":
self.perturbed_gen_dataset = self.attacker.attack(model_manager=model_manager,
gen_dataset=self.perturbed_gen_dataset,
mask_tensor=self.perturbed_gen_dataset.data.train_mask)
self.perturbed_gen_dataset = self.attacker.attack(model_manager=model_manager,
gen_dataset=self.perturbed_gen_dataset,
mask_tensor=self.perturbed_gen_dataset.data.train_mask)


def post_batch(self, model_manager, batch, loss) -> dict:
super().post_batch(model_manager=model_manager, batch=batch, loss=loss)
# Output on perturbed data
Expand Down
57 changes: 24 additions & 33 deletions src/models_builder/gnn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,11 +815,30 @@ def train_1_step(self, gen_dataset):
self.gnn.eval()
return loss.cpu().detach().numpy().tolist()

def train_on_batch(self, batch, task_type=None):
def train_on_batch_full(self, batch, task_type=None):
    """Execute one complete training step for a single batch.

    Sequence: defender pre-hooks -> forward/loss via ``train_on_batch`` ->
    defender post-hooks (an evasion defender may substitute its own loss)
    -> backward pass and parameter update via ``optimizer_step``.

    Returns the loss that was actually optimized.
    """
    # Let registered defenders act before the forward pass.
    if self.mi_defender:
        self.mi_defender.pre_batch()
    if self.evasion_defender:
        self.evasion_defender.pre_batch(model_manager=self, batch=batch)

    # Forward pass / loss computation only (no optimizer work here).
    loss = self.train_on_batch(batch=batch, task_type=task_type)

    # Post-hooks; the evasion defender may return a replacement loss
    # (e.g. one that mixes in loss on adversarially perturbed data).
    if self.mi_defender:
        self.mi_defender.post_batch()
    if self.evasion_defender:
        defender_result = self.evasion_defender.post_batch(
            model_manager=self, batch=batch, loss=loss,
        )
        if defender_result and "loss" in defender_result:
            loss = defender_result["loss"]

    # Backward pass + optimizer update, then hand the loss back to the caller.
    return self.optimizer_step(loss=loss)

def optimizer_step(self, loss):
    """Backpropagate *loss* and apply a single optimizer update.

    The loss object is returned unchanged so callers (e.g.
    ``train_on_batch_full``) can log or inspect it after the step.
    """
    loss.backward()
    self.optimizer.step()
    return loss

def train_on_batch(self, batch, task_type=None):
loss = None
if hasattr(batch, "edge_weight"):
weight = batch.edge_weight
Expand Down Expand Up @@ -857,17 +876,6 @@ def train_on_batch(self, batch, task_type=None):
# loss.backward()
else:
raise ValueError("Unsupported task type")
if self.mi_defender:
self.mi_defender.post_batch()
evasion_defender_dict = None
if self.evasion_defender:
evasion_defender_dict = self.evasion_defender.post_batch(
model_manager=self, batch=batch, loss=loss,
)
if evasion_defender_dict and "loss" in evasion_defender_dict:
loss = evasion_defender_dict["loss"]
loss.backward()
self.optimizer.step()
return loss

def get_name(self, **kwargs):
Expand Down Expand Up @@ -1157,6 +1165,7 @@ def load_train_test_split(self, gen_dataset):
gen_dataset.train_mask, gen_dataset.val_mask, gen_dataset.test_mask, _ = torch.load(path)[:]
return gen_dataset


class ProtGNNModelManager(FrameworkGNNModelManager):
# additional_config = ModelManagerConfig(
# loss_function={CONFIG_CLASS_NAME: "CrossEntropyLoss"},
Expand Down Expand Up @@ -1235,15 +1244,8 @@ class variables
self.init()
return self.gnn

# TODO Kirill train_on_batch to be divided into two parts
def train_on_batch(self, batch, task_type=None):
if self.mi_defender:
self.mi_defender.pre_batch()
if self.evasion_defender:
self.evasion_defender.pre_batch(model_manager=self, batch=batch)
loss = None
if task_type == "single-graph":
# TODO Prot with single-graph
self.optimizer.zero_grad()
logits = self.gnn(batch.x, batch.edge_index)
min_distances = self.gnn.min_distances
Expand Down Expand Up @@ -1283,14 +1285,10 @@ def train_on_batch(self, batch, task_type=None):
if self.clip is not None:
clip_grad_norm(self.gnn.parameters(), self.clip)
self.optimizer.zero_grad()
# loss.backward()
# self.optimizer.step()
elif task_type == "multiple-graphs":
self.optimizer.zero_grad()
logits = self.gnn(batch.x, batch.edge_index, batch.batch)
loss = self.loss_function(logits, batch.y)
# loss.backward()
# self.optimizer.step()
# TODO Kirill, remove False when release edge recommendation task
elif task_type == "edge" and False:
self.optimizer.zero_grad()
Expand All @@ -1305,18 +1303,11 @@ def train_on_batch(self, batch, task_type=None):
neg_loss = self.loss_function(neg_out, torch.zeros_like(neg_out))

loss = pos_loss + neg_loss
# loss.backward()
else:
raise ValueError("Unsupported task type")
if self.mi_defender:
self.mi_defender.post_batch()
evasion_defender_dict = None
if self.evasion_defender:
evasion_defender_dict = self.evasion_defender.post_batch(
model_manager=self, batch=batch, loss=loss,
)
if evasion_defender_dict and "loss" in evasion_defender_dict:
loss = evasion_defender_dict["loss"]
return loss

def optimizer_step(self, loss):
    # ProtGNN override of the base optimizer step: in addition to the
    # plain backward + step, every parameter gradient is clamped
    # element-wise to [-2.0, 2.0] — presumably to keep prototype-layer
    # training stable; TODO confirm the rationale for the 2.0 bound.
    loss.backward()
    torch.nn.utils.clip_grad_value_(self.gnn.parameters(), clip_value=2.0)
    self.optimizer.step()
Expand Down

0 comments on commit 5cfd429

Please sign in to comment.