-
Notifications
You must be signed in to change notification settings - Fork 2
/
training_validation_Davis_KIBA.py
125 lines (108 loc) · 5.07 KB
/
training_validation_Davis_KIBA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# training_validation_Davis_KIBA.py
#
# This file contains training code for Davis and KIBA datasets.
#
# Author: Shugang Zhang
# Created: Wednesday, Aug 4th, 2021
# Last update: Sunday, Dec 26th, 2021
import sys
import torch.nn as nn
from models.gat import GATNet
from models.gat_gcn import GAT_GCN
from models.gcn import GCNNet
from models.ginconv import GINConvNet
from models.sag_global_pooling import SAGNet_GLOBAL
from models.sag_hierarchical_pooling import SAGNet_HIER
from utils import *
from lifelines.utils import concordance_index
# training function at each epoch
def train(model, device, train_loader, optimizer, epoch, criterion=None, log_interval=None):
    """Run one training epoch over ``train_loader`` and return the mean loss.

    Args:
        model: network called as ``model(data)`` on each mini-batch.
        device: device each batch (and its labels) is moved to.
        train_loader: iterable of batches exposing ``.dataset``; each batch
            must provide ``.to(device)`` and a label tensor ``.y``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the progress log.
        criterion: loss function called as ``criterion(output, target)``.
            Defaults to the module-level ``loss_fn`` set up by the main loop.
        log_interval: print a progress line every this many batches.
            Defaults to the module-level ``LOG_INTERVAL``.

    Returns:
        Mean training loss over the epoch: each batch's loss is weighted by
        its number of labels. This assumes ``criterion`` averages over the
        batch (true for ``nn.MSELoss()``). Returns ``0.0`` if the loader
        yields no batches.
    """
    # Fall back to the script-level globals so existing callers are unaffected.
    if criterion is None:
        criterion = loss_fn
    if log_interval is None:
        log_interval = LOG_INTERVAL
    print('Training on {} samples...'.format(len(train_loader.dataset)))
    model.train()
    seen = 0          # samples processed before the current batch
    total_loss = 0.0  # sum of batch_loss * batch_size; kept on-device to avoid per-batch syncs
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        # One label per sample, reshaped to a column to match the model output.
        target = data.y.view(-1, 1).float().to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        batch_size = target.size(0)
        if batch_idx % log_interval == 0:
            # Count samples via the labels. len(data.x) counts rows of the
            # batched feature matrix (presumably graph nodes in a PyG batch),
            # which overshoots the dataset size.
            print('Train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                seen,
                len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item()))
        total_loss = total_loss + loss.detach() * batch_size
        seen += batch_size
    return float(total_loss) / seen if seen else 0.0
def predicting(model, device, loader):
    """Run ``model`` over ``loader`` without gradients and collect outputs.

    Args:
        model: network called as ``model(data)``. It is switched to eval mode
            and left there; ``train()`` switches it back.
        device: device each batch is moved to before the forward pass.
        loader: iterable of batches exposing ``.dataset``; each batch must
            provide ``.to(device)`` and a label tensor ``.y``.

    Returns:
        ``(labels, predictions)`` as flat NumPy arrays in loader order. Both
        arrays are empty when the loader yields no batches.
    """
    model.eval()
    # Collect per-batch tensors and concatenate once at the end. Calling
    # torch.cat inside the loop re-copies everything seen so far (O(n^2)).
    pred_chunks = []
    label_chunks = []
    print('Make prediction for {} samples...'.format(len(loader.dataset)))
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data)
            pred_chunks.append(output.cpu())
            label_chunks.append(data.y.view(-1, 1).cpu())
    total_preds = torch.cat(pred_chunks, 0) if pred_chunks else torch.Tensor()
    total_labels = torch.cat(label_chunks, 0) if label_chunks else torch.Tensor()
    return total_labels.numpy().flatten(), total_preds.numpy().flatten()
# Run configuration: the index into each candidate sequence picks what is trained.
# Candidates: 'davis', 'kiba' -> index 0 selects Davis.
datasets = [('davis', 'kiba')[0]]
# Candidates: GINConvNet, GATNet, GAT_GCN, GCNNet, SAGNet_GLOBAL, SAGNet_HIER
# -> index 4 selects SAGNet_GLOBAL.
modeling = (GINConvNet, GATNet, GAT_GCN, GCNNet, SAGNet_GLOBAL, SAGNet_HIER)[4]
model_st = modeling.__name__
print("dataset:", datasets)
print("modeling:", modeling)

# Device: if a 4th command-line token (sys.argv[3]) is given it is parsed as the
# CUDA device index; otherwise cuda:0 is used.
cuda_name = "cuda:{}".format(int(sys.argv[3])) if len(sys.argv) > 3 else "cuda:0"
print('cuda_name:', cuda_name)

# Training hyper-parameters
TRAIN_BATCH_SIZE = 512
TEST_BATCH_SIZE = 512
LR = 0.001  # the original author also noted 0.0005 as an alternative
LOG_INTERVAL = 20  # batches between progress lines printed by train()
NUM_EPOCHS = 2000
print('Learning rate: ', LR)
print('Epochs: ', NUM_EPOCHS)
# Main program: iterate over different datasets
for dataset in datasets:
    print('\nrunning on ', model_st + '_' + dataset)
    processed_data_file_train = 'data/processed/' + dataset + '_train.pt'
    processed_data_file_test = 'data/processed/' + dataset + '_test.pt'
    if not (os.path.isfile(processed_data_file_train) and os.path.isfile(processed_data_file_test)):
        print('please run create_data.py to prepare data in pytorch format!')
        continue

    train_data = TestbedDataset(root='data', dataset=dataset + '_train')
    test_data = TestbedDataset(root='data', dataset=dataset + '_test')
    # make data PyTorch mini-batch processing ready
    train_loader = DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=TEST_BATCH_SIZE, shuffle=False)

    # training the model
    device = torch.device(cuda_name if torch.cuda.is_available() else "cpu")
    model = modeling().to(device)
    # train() reads this module-level name as its loss function.
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # Start from +inf so the first evaluated epoch always counts as an improvement.
    best_mse = float('inf')
    best_ci = 0
    best_epoch = -1
    model_file_name = 'model_' + model_st + '_' + dataset + '.model'
    result_file_name = 'result_' + model_st + '_' + dataset + '.csv'

    # Find total parameters and trainable parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')

    # NOTE(review): checkpoint selection uses the *_test split, so the best
    # MSE/CI reported here are chosen on the same data they are measured on.
    for epoch in range(1, NUM_EPOCHS + 1):
        train(model, device, train_loader, optimizer, epoch)
        G, P = predicting(model, device, test_loader)
        epoch_mse = mse(G, P)
        if epoch_mse < best_mse:
            # CI is only reported for the best checkpoint, so compute it only here.
            epoch_ci = concordance_index(G, P)
            torch.save(model.state_dict(), model_file_name)
            with open(result_file_name, 'w') as f:
                f.write(','.join(map(str, [epoch_mse, epoch_ci])))
            best_epoch = epoch
            best_mse = epoch_mse
            best_ci = epoch_ci
            print('mse improved at epoch ', best_epoch, '; best_mse,best_ci:', best_mse, best_ci, model_st, dataset)
        else:
            print(epoch_mse, 'No improvement since epoch ', best_epoch, '; best_mse,best_ci:', best_mse, best_ci, model_st, dataset)