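"""Train a GraphCNN for whole-graph classification.

The script loads a graph dataset, splits it with 10-fold cross-validation,
trains on the fold selected by --fold_idx, and reports train/test accuracy
after every epoch.

Example invocation (the dataset name and output file are illustrative
defaults, not a prescribed configuration):

    python main.py --dataset MUTAG --fold_idx 0 --filename output.txt
"""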
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from util import load_data, separate_data
from models.graphcnn import GraphCNN

criterion = nn.CrossEntropyLoss()

def train(args, model, device, train_graphs, optimizer, epoch):
    model.train()

    total_iters = args.iters_per_epoch
    pbar = tqdm(range(total_iters), unit='batch')

    loss_accum = 0
    for pos in pbar:
        # Sample a random minibatch of training graphs.
        selected_idx = np.random.permutation(len(train_graphs))[:args.batch_size]
        batch_graph = [train_graphs[idx] for idx in selected_idx]

        output = model(batch_graph)
        labels = torch.LongTensor([graph.label for graph in batch_graph]).to(device)

        # Compute loss.
        loss = criterion(output, labels)

        # Backprop.
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        loss = loss.detach().cpu().numpy()
        loss_accum += loss

        # Report progress.
        pbar.set_description('epoch: %d' % (epoch))

    average_loss = loss_accum / total_iters
    print("loss training: %f" % (average_loss))

    return average_loss

# Pass data to the model in minibatches during testing to avoid memory
# overflow (does not perform backpropagation).
def pass_data_iteratively(model, graphs, minibatch_size=64):
    model.eval()
    output = []
    idx = np.arange(len(graphs))
    for i in range(0, len(graphs), minibatch_size):
        sampled_idx = idx[i:i + minibatch_size]
        if len(sampled_idx) == 0:
            continue
        output.append(model([graphs[j] for j in sampled_idx]).detach())
    return torch.cat(output, 0)

def test(args, model, device, train_graphs, test_graphs, epoch):
    model.eval()

    output = pass_data_iteratively(model, train_graphs)
    pred = output.max(1, keepdim=True)[1]
    labels = torch.LongTensor([graph.label for graph in train_graphs]).to(device)
    correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
    acc_train = correct / float(len(train_graphs))

    output = pass_data_iteratively(model, test_graphs)
    pred = output.max(1, keepdim=True)[1]
    labels = torch.LongTensor([graph.label for graph in test_graphs]).to(device)
    correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
    acc_test = correct / float(len(test_graphs))

    print("accuracy train: %f test: %f" % (acc_train, acc_test))

    return acc_train, acc_test

def main():
    # Training settings
    # Note: hyper-parameters need to be tuned to obtain the results reported in the paper.
    parser = argparse.ArgumentParser(
        description='PyTorch graph convolutional neural net for whole-graph classification')
    parser.add_argument('--dataset', type=str, default="MUTAG",
                        help='name of dataset (default: MUTAG)')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--iters_per_epoch', type=int, default=50,
                        help='number of iterations per epoch (default: 50)')
    parser.add_argument('--epochs', type=int, default=350,
                        help='number of epochs to train (default: 350)')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate (default: 0.01)')
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed for splitting the dataset into 10 folds (default: 0)')
    parser.add_argument('--fold_idx', type=int, default=0,
                        help='index of the fold in 10-fold cross-validation; must be less than 10 (default: 0)')
    parser.add_argument('--num_layers', type=int, default=5,
                        help='number of layers INCLUDING the input one (default: 5)')
    parser.add_argument('--num_mlp_layers', type=int, default=2,
                        help='number of layers for MLP EXCLUDING the input one (default: 2); 1 means a linear model')
    parser.add_argument('--hidden_dim', type=int, default=64,
                        help='number of hidden units (default: 64)')
    parser.add_argument('--final_dropout', type=float, default=0.5,
                        help='final layer dropout (default: 0.5)')
    parser.add_argument('--graph_pooling_type', type=str, default="sum", choices=["sum", "average"],
                        help='pooling over nodes in a graph: sum or average')
    parser.add_argument('--neighbor_pooling_type', type=str, default="sum", choices=["sum", "average", "max"],
                        help='pooling over neighboring nodes: sum, average, or max')
    parser.add_argument('--learn_eps', action="store_true",
                        help='whether to learn the epsilon weighting for the center nodes (does not affect training accuracy)')
    parser.add_argument('--degree_as_tag', action="store_true",
                        help='use node degrees as input node features (a heuristic for unlabeled graphs)')
    parser.add_argument('--filename', type=str, default="",
                        help='output file')
    args = parser.parse_args()

    # Set up seeds and the GPU device.
    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    graphs, num_classes = load_data(args.dataset, args.degree_as_tag)

    # 10-fold cross-validation: run the experiment on the fold specified by args.fold_idx.
    train_graphs, test_graphs = separate_data(graphs, args.seed, args.fold_idx)

    model = GraphCNN(args.num_layers, args.num_mlp_layers, train_graphs[0].node_features.shape[1],
                     args.hidden_dim, num_classes, args.final_dropout, args.learn_eps,
                     args.graph_pooling_type, args.neighbor_pooling_type, device).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Halve the learning rate every 50 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    for epoch in range(1, args.epochs + 1):
        avg_loss = train(args, model, device, train_graphs, optimizer, epoch)
        # Step the LR scheduler after the epoch's optimizer updates.
        scheduler.step()

        acc_train, acc_test = test(args, model, device, train_graphs, test_graphs, epoch)

        if args.filename != "":
            # Append so results from earlier epochs are not overwritten.
            with open(args.filename, 'a') as f:
                f.write("%f %f %f" % (avg_loss, acc_train, acc_test))
                f.write("\n")

        print("")
        print(model.eps)


if __name__ == '__main__':
    main()
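
# For reference: `separate_data` (imported from util) is expected to perform a
# stratified 10-fold split and return the train/test graphs for one fold. A
# minimal sketch of that behavior, assuming scikit-learn is available (an
# illustration, not necessarily the repository's exact implementation):
#
#   from sklearn.model_selection import StratifiedKFold
#   skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
#   labels = [g.label for g in graphs]
#   train_idx, test_idx = list(skf.split(np.zeros(len(labels)), labels))[fold_idx]
#   train_graphs = [graphs[i] for i in train_idx]
#   test_graphs = [graphs[i] for i in test_idx]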