Skip to content

Commit

Permalink
ProtGNN fix for new train. WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeratt committed Sep 5, 2024
1 parent 8692b7e commit 31a290f
Showing 1 changed file with 40 additions and 10 deletions.
50 changes: 40 additions & 10 deletions src/models_builder/gnn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,14 +1152,14 @@ class ProtGNNModelManager(FrameworkGNNModelManager):
_config_class="ModelManagerConfig",
_config_kwargs={
"mask_features": [],
# "optimizer": {
# "_config_class": "Config",
# "_class_name": "Adam",
# "_import_path": OPTIMIZERS_PARAMETERS_PATH,
# "_class_import_info": ["torch.optim"],
# "_config_kwargs": {},
# },
# FUNCTIONS_PARAMETERS_PATH,
"optimizer": {
"_config_class": "Config",
"_class_name": "Adam",
"_import_path": OPTIMIZERS_PARAMETERS_PATH,
"_class_import_info": ["torch.optim"],
"_config_kwargs": {},
},
#FUNCTIONS_PARAMETERS_PATH,
"loss_function": {
"_config_class": "Config",
"_class_name": "CrossEntropyLoss",
Expand Down Expand Up @@ -1234,7 +1234,7 @@ def evaluate_model(self, gen_dataset, metrics):

return metrics_values

def train_full(self, gen_dataset, steps=None, metrics=None):
def train_full(self, gen_dataset, steps=None, metrics=None, pbar=None):
"""
Train ProtGNN model for Graph classification
"""
Expand Down Expand Up @@ -1294,6 +1294,9 @@ def train_full(self, gen_dataset, steps=None, metrics=None):
best_prots = prot_layer.prototype_graphs
# data_indices = train_loader.dataset.indices
for step in range(steps):
self.before_epoch(gen_dataset)
print("epoch", self.modification.epochs)

acc = []
precision = []
recall = []
Expand All @@ -1315,8 +1318,8 @@ def train_full(self, gen_dataset, steps=None, metrics=None):
p.requires_grad = True

for batch in train_loader:
min_distances = self.gnn.min_distances
logits = self.gnn(batch.x, batch.edge_index, batch.batch)
min_distances = self.gnn.min_distances
loss = self.loss_function(logits, batch.y)
# cluster loss
prot_layer.prototype_class_identity = prot_layer.prototype_class_identity
Expand Down Expand Up @@ -1421,6 +1424,16 @@ def train_full(self, gen_dataset, steps=None, metrics=None):
"""
self.modification.epochs = step + 1

self.after_epoch(gen_dataset)
early_stopping_flag = self.early_stopping(train_loss=np.average(loss_list), gen_dataset=gen_dataset,
metrics=metrics)
if self.socket:
self.report_results(train_loss=np.average(loss_list), gen_dataset=gen_dataset,
metrics=metrics)
pbar.update(1)
if early_stopping_flag:
break

print(f"The best validation accuracy is {best_acc}.")
# report test msg
# checkpoint = torch.load(os.path.join(ckpt_dir, f'{model_args.model_name}_best.pth'))
Expand All @@ -1432,6 +1445,23 @@ def train_full(self, gen_dataset, steps=None, metrics=None):

return best_acc

def train_complete(self, gen_dataset, steps=None, pbar=None, metrics=None, **kwargs):
print("TEST TEST TEST")
self.train_full(gen_dataset=gen_dataset, steps=steps, pbar=pbar, metrics=metrics)
# for _ in range(steps):
# self.before_epoch(gen_dataset)
# print("epoch", self.modification.epochs)
# train_loss = self.train_1_step(gen_dataset)
# self.after_epoch(gen_dataset)
# early_stopping_flag = self.early_stopping(train_loss=train_loss, gen_dataset=gen_dataset,
# metrics=metrics)
# if self.socket:
# self.report_results(train_loss=train_loss, gen_dataset=gen_dataset,
# metrics=metrics)
# pbar.update(1)
# if early_stopping_flag:
# break

def run_model(self, gen_dataset, mask='test', out='answers'):
"""
Run the model on a part of dataset specified with a mask.
Expand Down

0 comments on commit 31a290f

Please sign in to comment.