<!DOCTYPE html>
<html>
<head>
<title>FlyKD</title>
<!-- <meta name="viewport" content="width=device-width, initial-scale=1.0"> -->
<link rel="stylesheet" type="text/css" href="style.css">
<link rel="icon" type="image/svg+xml" href="/Images/pill.svg">
<script type="text/javascript" src="app.js"></script>
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto">
<!-- <script src="https://cdn.jsdelivr.net/npm/chart.js"></script> -->
<script src="https://cdn.plot.ly/plotly-2.30.0.min.js"></script>
</head>
<body>
<button class="back-button" onclick="slideDown()"> &lt; Back </button>
<div id="overviewText" class="overlay-text">
<div id="overviewBackground"></div>
<h2>Overview:</h2>
<p> We propose FlyKD, a novel and scalable way to compress large Graph Neural Networks
(GNNs) into lighter, deployable GNNs. Specifically, FlyKD is a variant of Knowledge
Distillation (KD), in which a larger, more capable teacher model generates
pseudo labels for a smaller student model to learn from.
</p>
<p>
FlyKD adds two novel components to the original KD to address two problems
in Knowledge Distillation on graphs. The first is the memory cost of GNNs: when many pseudo labels
are generated with a GNN, users often run into CUDA Out of Memory (OOM) errors as the student
model backpropagates. By instead generating new, unseen pseudo labels on the fly, with randomness,
at every epoch, FlyKD can supply a virtually unlimited number of pseudo labels for the student model
to learn from, since only the current epoch's labels need to be kept in memory.
</p>
<p>
But this poses a second problem: pseudo labels are inherently noisy and difficult to optimize,
and generating an immense number of them worsens this problem. To alleviate it,
FlyKD incorporates a form of Curriculum Learning. Inspired by how humans learn, Curriculum
Learning aids the model's optimization by introducing data of increasing
complexity. To incorporate Curriculum Learning, we use our prior knowledge of how noisy
each type of label is and introduce the label types gradually, from least to most noisy.
</p>
<h2>PrimeKG:</h2>
<p>
Our dataset is the Precision Medicine Knowledge Graph (PrimeKG; Chandak, Huang and Zitnik (2023)),
a recent biomedical Knowledge Graph that combines 20+ credible databases such
as DrugBank, DrugCentral, and DisGeNET. PrimeKG is specifically designed for therapeutic
drug repurposing and can be used to find new treatments for diseases that currently
have no treatment.
</p>
<p>
Furthermore, PrimeKG contains over 4 million relations. The schema of the Knowledge Graph is illustrated
below (adapted from Chandak, Huang and Zitnik (2023)):
</p>
<div style="text-align: center;">
<img src="Images/schematic-nobg.png" alt="PrimeKG schema" style="height: auto;">
</div>
<br><br>
<div class="nav-button-container ">
<button class="section-nav-button" id="toMethodsButton" onclick="switchSection('methodsText', 'method-text')">Methods ></button>
</div>
</div>
<div id="methodsText" class="method-text">
<div id="methodsBackground"></div>
<h2>Methods:</h2>
<p> FlyKD first trains a teacher TxGNN, a zero-shot Relational GCN (R-GCN) developed by Chandak, Huang and Zitnik
(2023), to predict the existence of indication and contraindication relations between drugs and diseases. From the
pre-trained teacher TxGNN, we obtain the final node embeddings of the original graph and combine them with a
scoring function called DistMult to generate pseudo labels. Together with the original labels, this gives three
types of labels: Original Label, Pseudo Label on Training Graph, and Pseudo Label on Degree-aware Random Graph.
These three types of labels are illustrated below, ordered by their level of noisiness. </p>
<img src="Images/Labels.png" alt="Three types of labels, ordered by noisiness" style="height: auto;">
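<p>
The DistMult scoring step can be sketched in Python as below. This is a minimal illustration, not the authors'
implementation: it assumes embeddings are plain lists of floats, that each relation has its own vector
<code>r</code>, and that a soft pseudo label is obtained by passing the score through a sigmoid. The function
names <code>distmult_score</code> and <code>soft_pseudo_label</code> are hypothetical.
</p>
<pre><code class="language-python">
```python
import math
from typing import Sequence


def distmult_score(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    """DistMult score: the trilinear product sum_i h_i * r_i * t_i."""
    if not (len(h) == len(r) == len(t)):
        raise ValueError("embeddings must share the same dimension")
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))


def soft_pseudo_label(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    """Hypothetical: map the score to a probability in (0, 1) with a sigmoid."""
    return 1.0 / (1.0 + math.exp(-distmult_score(h, r, t)))


# Toy example: a drug embedding, an 'indication' relation vector, a disease embedding.
drug = [0.5, -1.0, 2.0]
indication = [1.0, 0.5, 1.0]
disease = [2.0, 1.0, 0.5]
print(distmult_score(drug, indication, disease))     # → 1.5
print(soft_pseudo_label(drug, indication, disease))  # ≈ 0.818
```
</code></pre>
<p>
Note that DistMult is symmetric in the head and tail embeddings, so swapping <code>drug</code> and
<code>disease</code> gives the same score.
</p>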
<p>
The recipe for generating a Degree-aware Random Graph is shown below:
</p>
<img src="Images/Algo.png" alt="Degree-aware random graph generation algorithm" style="height: auto;">
<p>
In plain words, GenerateRandomGraph selects two nodes, each with probability proportional to its
degree in the original graph (with respect to the relation). This improves the quality of the pseudo labels on the
random graph by exploiting the prior that nodes that have appeared in more labels during training are likely
to have higher-quality embeddings. Once the two nodes are selected, a link is formed between them. This process
of generating pseudo links is repeated k times.
</p>
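<p>
The sampling step of GenerateRandomGraph can be sketched in Python as follows. This is a minimal illustration
under stated assumptions, not the authors' code: the name <code>generate_random_graph</code> is hypothetical,
the edges of one relation are given as (head, tail) pairs such as (drug, disease), degrees are counted
separately on the head and tail side, and both endpoints are sampled independently with replacement using
Python's <code>random.choices</code>.
</p>
<pre><code class="language-python">
```python
import random
from collections import Counter
from typing import List, Sequence, Tuple

Edge = Tuple[str, str]  # (head, tail), e.g. (drug, disease) for one relation


def generate_random_graph(edges: Sequence[Edge], k: int,
                          rng: random.Random) -> List[Edge]:
    """Sample k pseudo links, picking each endpoint with probability
    proportional to its degree under this relation (hypothetical helper)."""
    if not edges:
        raise ValueError("need at least one edge to compute degrees")
    # Degrees are counted separately on the head side and the tail side.
    head_deg = Counter(h for h, _ in edges)
    tail_deg = Counter(t for _, t in edges)
    heads, head_w = zip(*head_deg.items())
    tails, tail_w = zip(*tail_deg.items())
    # random.choices samples with replacement, using the degrees as weights.
    sampled_heads = rng.choices(heads, weights=head_w, k=k)
    sampled_tails = rng.choices(tails, weights=tail_w, k=k)
    # Pair the i-th sampled head with the i-th sampled tail to form k links.
    return list(zip(sampled_heads, sampled_tails))


train_edges = [("aspirin", "pain"), ("aspirin", "fever"), ("ibuprofen", "pain")]
pseudo_links = generate_random_graph(train_edges, k=5, rng=random.Random(0))
print(pseudo_links)  # 5 (drug, disease) pairs, biased toward high-degree nodes
```
</code></pre>
<p>
Nodes that never appear in the training edges have zero weight and are never selected, while
well-connected nodes such as <code>aspirin</code> and <code>pain</code> are favored.
</p>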
<p>
Because a new random graph is generated at every epoch, the teacher model can, in principle, produce an
unlimited number of pseudo labels for the student model to learn from and backpropagate through.
</p>
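<p>
The per-epoch regeneration can be sketched as a Python generator. This is an illustrative sketch, not the FlyKD
code: <code>pseudo_link_stream</code> and the toy <code>uniform_sampler</code> are hypothetical (the uniform
sampler stands in for the degree-aware GenerateRandomGraph only to keep the block self-contained), and the
<code>fixed=True</code> option mimics the "Fix Random Graph" configuration from the ablation study in the Results.
</p>
<pre><code class="language-python">
```python
import random
from typing import Callable, Iterator, List, Tuple

Edge = Tuple[str, str]
Sampler = Callable[[random.Random], List[Edge]]


def pseudo_link_stream(sample: Sampler, epochs: int, seed: int = 0,
                       fixed: bool = False) -> Iterator[List[Edge]]:
    """Yield one batch of pseudo links per epoch (hypothetical helper).

    fixed=False: draw a fresh random graph at every epoch (on the fly).
    fixed=True:  draw a single random graph once and reuse it every epoch.
    """
    rng = random.Random(seed)
    if fixed:
        links = sample(rng)
        for _ in range(epochs):
            yield links
    else:
        for _ in range(epochs):
            # Only this epoch's batch is created here; earlier batches can be
            # freed once the caller stops referencing them.
            yield sample(rng)


# Toy uniform sampler standing in for the degree-aware GenerateRandomGraph.
DRUGS = ["d1", "d2", "d3"]
DISEASES = ["x1", "x2"]


def uniform_sampler(rng: random.Random, k: int = 4) -> List[Edge]:
    return [(rng.choice(DRUGS), rng.choice(DISEASES)) for _ in range(k)]


for epoch, links in enumerate(pseudo_link_stream(uniform_sampler, epochs=3)):
    # In training, the teacher would score `links` here and the student
    # would take gradient steps on the resulting pseudo labels.
    print(epoch, links)
```
</code></pre>
<p>
Because the generator yields one batch at a time, memory use in this sketch depends on k rather than on the
total number of pseudo labels seen over training.
</p>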
<p>
As mentioned earlier, pseudo labels are inherently noisy, and pseudo labels on random graphs are
even noisier. We therefore incorporate a form of Curriculum Learning through a loss scheduler,
which gradually shifts the emphasis from the easy original labels, to the medium pseudo labels on the training graph,
and ultimately to the hard pseudo labels on the random graph. The loss scheduler is depicted below:
</p>
<img src="Images/scheduler.png" alt="Loss scheduler" style="height: auto;">
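<p>
A loss scheduler of this kind can be sketched in Python as below. The actual curve is the one in the figure
above; the anchor points used here (all weight on the original labels at the start, on the pseudo labels on the
training graph at the midpoint, and on the pseudo labels on the random graph at the end) are purely illustrative
assumptions, and the function names are hypothetical. A step-wise variant is included for contrast with the
step-wise configuration in the ablation study.
</p>
<pre><code class="language-python">
```python
from typing import Tuple

# (original labels, pseudo labels on train graph, pseudo labels on random graph)
Weights = Tuple[float, float, float]


def smooth_weights(progress: float) -> Weights:
    """Piecewise-linear shift easy -> medium -> hard (illustrative anchors)."""
    p = min(max(progress, 0.0), 1.0)  # training progress clamped to [0, 1]
    if p > 0.5:
        a = (p - 0.5) / 0.5  # 0 -> 1 over the second half of training
        return (0.0, 1.0 - a, a)
    a = p / 0.5  # 0 -> 1 over the first half of training
    return (1.0 - a, a, 0.0)


def stepwise_weights(progress: float) -> Weights:
    """Abrupt switches between label types (step-wise variant)."""
    p = min(max(progress, 0.0), 1.0)
    if p >= 2 / 3:
        return (0.0, 0.0, 1.0)
    if p >= 1 / 3:
        return (0.0, 1.0, 0.0)
    return (1.0, 0.0, 0.0)


def total_loss(losses: Weights, weights: Weights) -> float:
    """Weighted sum of the three per-label-type losses."""
    return sum(w * l for w, l in zip(weights, losses))


for epoch in range(0, 101, 25):
    progress = epoch / 100
    print(epoch, smooth_weights(progress), stepwise_weights(progress))
```
</code></pre>
<p>
In this sketch the smooth schedule always mixes at most two adjacent label types, whereas the step-wise one
switches abruptly from one label type to the next.
</p>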
<br>
<br><br>
<div class="nav-button-container ">
<button class="section-nav-button" id="toOverviewButton" onclick="switchSection('overviewText', 'overlay-text')">&lt; Overview</button>
<button class="section-nav-button" id="toMethodsButton" onclick="switchSection('resultsText', 'method-text')">Results ></button>
</div>
</div>
<div id="resultsText" class="method-text">
<div id="methodsBackground"></div>
<h2>Results:</h2>
<p>
The table below reports the baseline models' performance, where Baseline 130 (TxGNN 130) is the more capable
teacher model and Baseline 80 (TxGNN 80) is the lighter student model.
</p>
<table style="width:75%; margin: auto; border-collapse: collapse; border: 1px solid black; margin-bottom: 20px;">
<caption style="text-align: center; font-weight: bold; margin-bottom: 0.5em;">
Baseline AUPRC (%):
</caption>
<tr>
<th style="border: 1px solid black;">Model</th>
<th style="border: 1px solid black;">Num. Params</th>
<th style="border: 1px solid black;">Seed 45</th>
<th style="border: 1px solid black;">Seed 46</th>
<th style="border: 1px solid black;">Seed 47</th>
<th style="border: 1px solid black;">Seed 48</th>
<th style="border: 1px solid black;">Seed 49</th>
<th style="border: 1px solid black;">Mean ± std</th>
</tr>
<tr>
<td style="border: 1px solid black;">Baseline 130</td>
<td style="border: 1px solid black;">(1.7M)</td>
<td style="border: 1px solid black;">80.33</td>
<td style="border: 1px solid black;">73.66</td>
<td style="border: 1px solid black;">76.29</td>
<td style="border: 1px solid black;">84.19</td>
<td style="border: 1px solid black;">79.91</td>
<td style="border: 1px solid black;">78.87 ± 4.04</td>
</tr>
<tr>
<td style="border: 1px solid black;">Baseline 80</td>
<td style="border: 1px solid black;">(650k)</td>
<td style="border: 1px solid black;">78.64</td>
<td style="border: 1px solid black;">71.97</td>
<td style="border: 1px solid black;">74.44</td>
<td style="border: 1px solid black;">82.74</td>
<td style="border: 1px solid black;">77.87</td>
<td style="border: 1px solid black;">77.13 ± 4.13</td>
</tr>
</table>
<p>
Performance varies considerably across seeds because of the zero-shot
evaluation setting. We therefore compute the relative performance gain over the baseline at each seed and
average it across seeds. This way, we can more accurately attribute changes in performance to the methods
themselves rather than to fluctuations in task difficulty across seeds.
</p>
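<p>
The per-seed comparison can be sketched in Python as follows. As an assumption, the sketch treats the relative
gain as the per-seed difference in AUPRC percentage points; as a worked example, it computes the gap between the
two baselines using the per-seed values from the table above. The helper name <code>relative_gains</code> is
hypothetical.
</p>
<pre><code class="language-python">
```python
from statistics import mean, stdev
from typing import List, Sequence

# Per-seed AUPRC (%) for seeds 45-49, taken from the baseline table above.
baseline_130 = [80.33, 73.66, 76.29, 84.19, 79.91]
baseline_80 = [78.64, 71.97, 74.44, 82.74, 77.87]


def relative_gains(method: Sequence[float], reference: Sequence[float]) -> List[float]:
    """Per-seed difference (method - reference), in AUPRC percentage points."""
    if len(method) != len(reference):
        raise ValueError("both runs must cover the same seeds")
    return [m - r for m, r in zip(method, reference)]


gains = relative_gains(baseline_130, baseline_80)
# statistics.stdev is the sample standard deviation (n - 1 denominator).
print(f"{mean(gains):.2f} ± {stdev(gains):.2f}")  # → 1.74 ± 0.22
```
</code></pre>
<p>
The per-seed gap has a standard deviation of about 0.22, far smaller than either baseline's own spread across
seeds (about 4), which is why comparing within each seed gives a clearer signal.
</p>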
<p>
We compare three Knowledge Distillation (KD) methods for Graph Neural Networks (GNNs):
Basic Knowledge Distillation (BKD), introduced by Hinton, Vinyals, and Dean (2015);
Local Structure Preserving GCN (LSPGCN, also known as DistillGCN), developed by Yang et al. (2020);
and our proposed approach, FlyKD.
</p>
<table style="width:75%; margin: auto; border-collapse: collapse; border: 1px solid black; margin-bottom: 20px;">
<caption style="text-align: center; font-weight: bold; margin-bottom: 0.5em;">
Knowledge Distillation Methods (Relative gains from Baseline 80)
</caption>
<tr>
<th style="border: 1px solid black;">Model</th>
<th style="border: 1px solid black;">Time</th>
<th style="border: 1px solid black;">Curriculum Learning</th>
<th style="border: 1px solid black;">Mean±std</th>
</tr>
<tr>
<td style="border: 1px solid black;">Basic KD</td>
<td style="border: 1px solid black;">1600</td>
<td style="border: 1px solid black;">No</td>
<td style="border: 1px solid black;">-0.62±0.59</td>
</tr>
<tr>
<td style="border: 1px solid black;">LSP 1 layer (RBF)</td>
<td style="border: 1px solid black;">20000</td>
<td style="border: 1px solid black;">No</td>
<td style="border: 1px solid black;">-1.09±0.23</td>
</tr>
<tr>
<td style="border: 1px solid black;">LSP 2 layers (RBF)</td>
<td style="border: 1px solid black;">40000</td>
<td style="border: 1px solid black;">No</td>
<td style="border: 1px solid black;">-1.41±0.82</td>
</tr>
<tr>
<td style="border: 1px solid black;">FlyKD</td>
<td style="border: 1px solid black;">2000</td>
<td style="border: 1px solid black;">Yes</td>
<td style="border: 1px solid black;">1.16±0.36</td>
</tr>
</table>
<p>
While BKD and LSPGCN both yield negative gains relative to Baseline 80 (i.e., distillation hurts the student),
FlyKD achieves a positive relative gain. The following ablation study investigates the reason for this
gap.
</p>
<!-- <div id='myDiv2'></div> -->
<table style="width:75%; margin-left: auto; margin-right: auto; border-collapse: collapse; border: 1px solid black; margin-bottom: 20px;">
<caption style="text-align: center; font-weight: bold; margin-bottom: 0.5em;">
Ablation study (Relative gains from Baseline 80)
</caption>
<tr>
<th style="border: 1px solid black;">Model</th>
<th style="border: 1px solid black;">Configuration</th>
<th style="border: 1px solid black;">Mean±std</th>
</tr>
<tr>
<td style="border: 1px solid black;">Basic KD</td>
<td style="border: 1px solid black;">Employ Curriculum Learning</td>
<td style="border: 1px solid black;">0.93±0.45</td>
</tr>
<tr>
<td style="border: 1px solid black;">FlyKD</td>
<td style="border: 1px solid black;">Fix Random Graph</td>
<td style="border: 1px solid black;">1.14±0.39</td>
</tr>
<tr>
<td style="border: 1px solid black;">FlyKD</td>
<td style="border: 1px solid black;">No Curriculum Learning</td>
<td style="border: 1px solid black;">0.19±0.42</td>
</tr>
<tr>
<td style="border: 1px solid black;">FlyKD</td>
<td style="border: 1px solid black;">Take Out Pseudo Labels on Train Dataset</td>
<td style="border: 1px solid black;">-0.68±0.63</td>
</tr>
<tr>
<td style="border: 1px solid black;">FlyKD</td>
<td style="border: 1px solid black;">Stepwise Function for Curriculum Learning</td>
<td style="border: 1px solid black;">-1.44±0.86</td>
</tr>
</table>
<p>
Our ablation study suggests that the gap between FlyKD and the other KD methods stems largely from how well the student copes with noise in the teacher model's pseudo labels: removing Curriculum Learning from FlyKD reduces its gain from 1.16 to 0.19,
while adding Curriculum Learning to BKD raises its gain by 1.55 percentage points (from -0.62 to 0.93).
</p>
<br><br>
<div class="nav-button-container ">
<button class="section-nav-button" id="toOverviewButton" onclick="switchSection('methodsText', 'method-text')">&lt; Methods</button>
<button class="section-nav-button" id="toMethodsButton" onclick="switchSection('findingsText', 'method-text')">Findings ></button>
</div>
</div>
<div id="findingsText" class="method-text">
<div id="methodsBackground"></div>
<h2>Findings:</h2>
<ul>
<li>The optimization process of the student model in KD is inherently noisy, and Curriculum Learning can greatly alleviate this issue.</li>
<li>Curriculum Learning for KD appears to require a gradual change in difficulty: replacing the gradual loss scheduler with a step-wise one sharply degrades FlyKD's gain, from 1.16 to -1.44, below even Basic KD.</li>
<li>LSPGCN does not scale well to Knowledge Graphs: it takes 10-20x more time than FlyKD, as shown in Table 2 (Knowledge Distillation Methods). We expect the same limitation to hold for most graph-specific Knowledge Distillation methods that rely on similar mechanisms.</li>
<li>FlyKD’s pseudo labels on degree-aware random graphs seem to help, but generating a new random graph at every epoch yields no additional gain over generating a single random graph once and keeping it fixed throughout training (1.16±0.36 vs. 1.14±0.39). Further investigation is needed to understand this phenomenon.</li>
</ul>
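<p>
The 10-20x figure can be checked against the Time column of the Knowledge Distillation Methods table (units as
reported there) with a few lines of Python:
</p>
<pre><code class="language-python">
```python
# 'Time' column of the Knowledge Distillation Methods table (units as reported there).
times = {
    "Basic KD": 1600,
    "LSP 1 layer (RBF)": 20000,
    "LSP 2 layers (RBF)": 40000,
    "FlyKD": 2000,
}

for reference in ("FlyKD", "Basic KD"):
    for lsp in ("LSP 1 layer (RBF)", "LSP 2 layers (RBF)"):
        # Slowdown of each LSPGCN variant relative to the reference method.
        print(f"{lsp} vs {reference}: {times[lsp] / times[reference]:.1f}x")
```
</code></pre>
<p>
The 10-20x range corresponds to the comparison with FlyKD; relative to Basic KD, LSPGCN is 12.5-25x slower.
</p>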
<p>
Last Remark:
We believe our experimental results suggest a new research direction: improving <em>how</em> the student model
is optimized, rather than the more common focus on <em>what</em> pseudo labels to
generate (what to distill).
</p>
<br><br>
<div class="nav-button-container ">
<button class="section-nav-button" id="toOverviewButton" onclick="switchSection('resultsText', 'method-text')">&lt; Results</button>
</div>
</div>
<div class="main-container bg-image">
<div class="centered-container">
<div id="title-div">
<h1 id="title">FlyKD</h1> <h1>Graph Knowledge Distillation on the Fly with Curriculum Learning</h1>
</div>
<div class="button-container">
<button class="text-button" onclick="slideUp('overviewText', 'overlay-text')">Overview</button>
<button class="text-button" onclick="slideUp('methodsText', 'method-text')">Methods</button>
<button class="text-button" onclick="slideUp('resultsText', 'method-text')">Results</button>
<button class="text-button" onclick="slideUp('findingsText', 'method-text')">Findings</button>
<button class="text-button" onclick="window.open('https://github.com/Drug-Repurposing-GNN/SSL-DiseaseDrug-Prediction')">Code</button>
<button class="text-button" onclick="window.open('https://drive.google.com/file/d/148UQyEkZ6p4ZGRXUR3hh5w5m7xPlSrUP/view?usp=sharing')">Paper</button>
</div>
<div class="title-div">
<p> Dear incoming capstone students, I have created a <a href="https://docs.google.com/document/d/1lfRQBqIBnwr-BRNaXZfbS_bbZerJPE4mibgd3uF5rgk/edit#heading=h.2fagso2oojx"> short guide </a> on how to better prepare for Yusu and Gal's capstone. I hope you will take a look at it and find it useful. Note that these are purely my own opinions and do not represent Yusu and Gal's views. Best of luck, and feel free to reach out if you have any other questions: [email protected]. </p>
</div>
</div>
</div>
</body>
</html>