forked from zwt233/GAMLP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_dataset.py
224 lines (200 loc) · 8.08 KB
/
load_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import dgl
import dgl.function as fn
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from heteo_data import load_data, read_relation_subsets, gen_rel_subset_feature, preprocess_features
import torch.nn.functional as F
import gc
def prepare_label_emb(args, g, labels, n_classes, train_idx, valid_idx, test_idx, label_teacher_emb=None):
    """
    Build propagated label embeddings for label-reuse ("RLU") training.

    Train nodes are seeded with one-hot ground-truth labels; when a teacher's
    soft predictions are supplied, valid/test nodes are seeded with those
    instead of zeros. The label matrix is then smoothed for
    ``args.label_num_hops`` rounds of neighbor mean-averaging and returned
    with rows re-ordered as [train | valid | test].

    Parameters
    ----------
    args : namespace with ``dataset`` and ``label_num_hops``.
    g : DGL graph; heterogeneous for ogbn-mag, homogeneous otherwise.
    labels : per-node label tensor, indexed by ``train_idx``.
    n_classes : number of classes (embedding width).
    train_idx, valid_idx, test_idx : split index tensors.
    label_teacher_emb : optional teacher predictions whose rows are laid out
        as [train | valid | test]; only the valid/test rows are used.

    Returns
    -------
    torch.Tensor of shape (len(train)+len(valid)+len(test), n_classes).
    """
    if args.dataset == 'ogbn-mag':
        # MAG is heterogeneous: flatten to a homogeneous graph (adding
        # reverse edges) and mark which flattened nodes are "paper" nodes.
        target_type_id = g.get_ntype_id("paper")
        homo_g = dgl.to_homogeneous(g, ndata=["feat"])
        homo_g = dgl.add_reverse_edges(homo_g, copy_ndata=True)
        homo_g.ndata["target_mask"] = homo_g.ndata[dgl.NTYPE] == target_type_id
    print(n_classes)
    print(labels.shape[0])
    if label_teacher_emb is None:  # BUGFIX: was `== None`; use identity test
        # No teacher: only train nodes carry (one-hot) label information.
        y = np.zeros(shape=(labels.shape[0], int(n_classes)))
        y[train_idx] = F.one_hot(labels[train_idx].to(
            torch.long), num_classes=n_classes).float().squeeze(1)
        y = torch.Tensor(y)
    else:
        print("use teacher label")
        # Teacher rows are ordered [train | valid | test]; copy the valid and
        # test slices, then overwrite train rows with one-hot ground truth.
        y = np.zeros(shape=(labels.shape[0], int(n_classes)))
        y[valid_idx] = label_teacher_emb[len(
            train_idx):len(train_idx)+len(valid_idx)]
        y[test_idx] = label_teacher_emb[len(
            train_idx)+len(valid_idx):len(train_idx)+len(valid_idx)+len(test_idx)]
        y[train_idx] = F.one_hot(labels[train_idx].to(
            torch.long), num_classes=n_classes).float().squeeze(1)
        y = torch.Tensor(y)
    if args.dataset == 'ogbn-mag':
        # Scatter the paper-only label rows into the full homogeneous
        # node space (non-paper rows stay zero).
        target_mask = homo_g.ndata["target_mask"]
        target_ids = homo_g.ndata[dgl.NID][target_mask]
        new_label_emb = torch.zeros((len(homo_g.ndata["feat"]),) + y.shape[1:],
                                    dtype=y.dtype, device=y.device)
        new_label_emb[target_mask] = y[target_ids]
        y = new_label_emb
        g = homo_g
    del labels
    gc.collect()
    for hop in range(args.label_num_hops):
        y = neighbor_average_labels(g, y.to(torch.float), args)
        gc.collect()
    if args.dataset == "ogbn-mag":
        # Gather the propagated rows back into original paper-node order.
        target_mask = g.ndata['target_mask']
        target_ids = g.ndata[dgl.NID][target_mask]
        num_target = target_mask.sum().item()
        new_res = torch.zeros((num_target,) + y.shape[1:],
                              dtype=y.dtype, device=y.device)
        new_res[target_ids] = y[target_mask]
        y = new_res
    res = y
    return torch.cat([res[train_idx], res[valid_idx], res[test_idx]], dim=0)
def neighbor_average_labels(g, feat, args):
    """
    Propagate label embeddings one hop by mean-aggregation over neighbors.

    Parameters
    ----------
    g : DGL graph defining the neighborhood structure.
    feat : (num_nodes, n_classes) label-embedding tensor.
    args : unused; kept for call-site compatibility.

    Returns
    -------
    torch.Tensor : the one-hop neighbor-averaged label embeddings.
    """
    print("Compute neighbor-averaged labels")
    # One round of message passing: each node receives the mean of its
    # in-neighbors' current label embeddings.
    g.ndata["f"] = feat
    g.update_all(fn.copy_u("f", "msg"),
                 fn.mean("msg", "f"))
    feat = g.ndata.pop('f')
    return feat
def neighbor_average_features(g, args):
    """
    Compute multi-hop neighbor-averaged node features (SIGN-style).

    Returns [feat_0, ..., feat_K] with K = args.num_hops, where feat_0 is
    the raw input feature and feat_k is the mean aggregation of feat_{k-1}
    over each node's neighbors.
    """
    print("Compute neighbor-averaged feats")
    g.ndata["feat_0"] = g.ndata["feat"]
    num_hops = args.num_hops
    for k in range(1, num_hops + 1):
        prev_key, cur_key = f"feat_{k-1}", f"feat_{k}"
        g.update_all(fn.copy_u(prev_key, "msg"), fn.mean("msg", cur_key))
    # Pop every hop's result off the graph so no references linger in ndata.
    return [g.ndata.pop(f"feat_{k}") for k in range(num_hops + 1)]
def batched_acc(labels, pred):
    """Per-sample correctness for single-label multi-class prediction.

    Returns a 1-tuple containing a boolean tensor: True where the argmax
    of ``pred`` along dim 1 equals the corresponding entry of ``labels``.
    """
    predicted = torch.argmax(pred, dim=1)
    return (predicted == labels,)
def get_evaluator(dataset):
    """Select an evaluation function by dataset family.

    OAG-prefixed datasets use NDCG/MRR evaluation; everything else uses
    plain accuracy (``batched_acc``).

    NOTE(review): ``batched_ndcg_mrr`` is neither defined nor imported in
    this file — presumably provided elsewhere; confirm before using an
    "oag*" dataset, or this will raise NameError at call time.
    """
    name = dataset.lower()
    if name.startswith("oag"):
        return batched_ndcg_mrr
    return batched_acc
def get_ogb_evaluator(dataset):
    """
    Get evaluator from Open Graph Benchmark based on dataset.

    Returns a callable mapping (preds, labels) to the scalar "acc" value
    reported by the official OGB Evaluator for *dataset*.
    """
    evaluator = Evaluator(name=dataset)

    def eval_acc(preds, labels):
        # OGB expects column vectors for both predictions and ground truth.
        result = evaluator.eval({
            "y_true": labels.view(-1, 1),
            "y_pred": preds.view(-1, 1),
        })
        return result["acc"]

    return eval_acc
def load_dataset(name, device, args):
    """
    Load an OGB node-property-prediction dataset and move it to *device*.

    Parameters
    ----------
    name : one of "ogbn-products", "ogbn-mag", "ogbn-papers100M".
    device : target device; only used by the ogbn-mag loader (``load_data``).
    args : needs ``root`` (OGB download directory) plus whatever
        ``load_data`` requires for ogbn-mag.

    Returns
    -------
    (g, labels, n_classes, train_nid, val_nid, test_nid, evaluator)

    Raises
    ------
    RuntimeError
        If *name* is not one of the supported datasets.
    """
    if name not in ["ogbn-products", "ogbn-mag", "ogbn-papers100M"]:
        raise RuntimeError("Dataset {} is not supported".format(name))
    # NOTE(review): the OGB dataset object (and its split) is constructed
    # even for ogbn-mag, which then loads through load_data() instead and
    # ignores it — kept as-is to preserve behavior.
    dataset = DglNodePropPredDataset(name=name, root=args.root)
    splitted_idx = dataset.get_idx_split()
    if name == "ogbn-products":
        train_nid = splitted_idx["train"]
        val_nid = splitted_idx["valid"]
        test_nid = splitted_idx["test"]
        g, labels = dataset[0]
        g.ndata["labels"] = labels
        g.ndata['feat'] = g.ndata['feat'].float()
        n_classes = dataset.num_classes
        labels = labels.squeeze()
        evaluator = get_ogb_evaluator(name)
    elif name == "ogbn-mag":
        # Heterogeneous dataset with its own loader.
        data = load_data(device, args)
        g, labels, n_classes, train_nid, val_nid, test_nid = data
        evaluator = get_ogb_evaluator(name)
    elif name == "ogbn-papers100M":
        train_nid = splitted_idx["train"]
        val_nid = splitted_idx["valid"]
        test_nid = splitted_idx["test"]
        g, labels = dataset[0]
        n_classes = dataset.num_classes
        labels = labels.squeeze()
        evaluator = get_ogb_evaluator(name)

    print(f"# Nodes: {g.number_of_nodes()}\n"
          f"# Edges: {g.number_of_edges()}\n"
          f"# Train: {len(train_nid)}\n"
          f"# Val: {len(val_nid)}\n"
          f"# Test: {len(test_nid)}\n"
          f"# Classes: {n_classes}\n")

    return g, labels, n_classes, train_nid, val_nid, test_nid, evaluator
def prepare_data(device, args, teacher_probs):
    """
    Load dataset and compute neighbor-averaged node features used by SIGN model

    Returns (feats, labels, in_feats, n_classes, train_nid, val_nid,
    test_nid, evaluator, label_emb); the rows of feats, labels, and
    label_emb are re-ordered as [train | valid | test].
    """
    data = load_dataset(args.dataset, device, args)
    g, labels, n_classes, train_nid, val_nid, test_nid, evaluator = data
    if args.dataset == 'ogbn-products':
        # Multi-hop mean aggregation on the homogeneous graph.
        feats = neighbor_average_features(g, args)
        in_feats = feats[0].shape[1]
    elif args.dataset == 'ogbn-mag':
        # Heterogeneous case: aggregate features per relation subset.
        rel_subsets = read_relation_subsets(args.use_relation_subsets)
        with torch.no_grad():
            feats = preprocess_features(g, rel_subsets, args, device)
            print("Done preprocessing")
        _, num_feats, in_feats = feats[0].shape
    elif args.dataset == 'ogbn-papers100M':
        # Hop features for papers100M are loaded from precomputed files
        # below; here we only symmetrize the graph (for label propagation)
        # and drop the raw feature tensor to free memory.
        g = dgl.add_reverse_edges(g, copy_ndata=True)
        feat=g.ndata.pop('feat')
        gc.collect()
    label_emb = None
    if args.use_rlu:
        # Propagated label embeddings, optionally seeded by teacher_probs.
        label_emb = prepare_label_emb(args, g, labels, n_classes, train_nid, val_nid, test_nid, teacher_probs)
    # move to device
    if args.dataset=='ogbn-papers100M':
        # HACK: loads precomputed per-hop features from a hard-coded local
        # path. NOTE(review): these files must already have rows ordered
        # [train | valid | test] since the reordering branch below is
        # skipped for this dataset — confirm against the preprocessing job.
        feats=[]
        for i in range(args.num_hops+1):
            feats.append(torch.load(f"/data2/zwt/ogbn_papers100M/feat/papers100m_feat_{i}.pt"))
        in_feats=feats[0].shape[1]
        '''
        g.ndata['feat']=feat
        feats=neighbor_average_features(g,args)
        in_feats=feats[0].shape[1]
        for i, x in enumerate(feats):
            feats[i] = torch.cat((x[train_nid], x[val_nid], x[test_nid]), dim=0)
        '''
    else:
        # Re-order every hop's feature matrix as [train | valid | test].
        for i, x in enumerate(feats):
            feats[i] = torch.cat((x[train_nid], x[val_nid], x[test_nid]), dim=0)
    train_nid = train_nid.to(device)
    val_nid = val_nid.to(device)
    test_nid = test_nid.to(device)
    labels = labels.to(device).to(torch.long)
    return feats, torch.cat([labels[train_nid], labels[val_nid], labels[test_nid]]), in_feats, n_classes, \
        train_nid, val_nid, test_nid, evaluator, label_emb