Skip to content

Commit

Permalink
Glem full graph inference (#498)
Browse files Browse the repository at this point in the history
*Issue #, if available:*

*Description of changes:* Enabled in full-graph inference for GLEM
model. Added integration test to ensure the inference results are
consistent between mini-batch and full-graph inference.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Da Zheng <[email protected]>
  • Loading branch information
2 people authored and Xiang Song committed Sep 29, 2023
1 parent 708a759 commit 925c145
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 14 deletions.
3 changes: 2 additions & 1 deletion python/graphstorm/model/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,8 @@ def do_full_graph_inference(model, data, batch_size=1024, fanout=None, edge_mask
-------
dict of th.Tensor : node embeddings.
"""
assert isinstance(model, GSgnnModel), "Only GSgnnModel supports full-graph inference."
assert isinstance(model, GSgnnModel) or type(model).__name__ == 'GLEM',\
"Only GSgnnModel and GLEM support full-graph inference."
t1 = time.time() # pylint: disable=invalid-name
# full graph evaluation
barrier()
Expand Down
45 changes: 38 additions & 7 deletions python/graphstorm/model/node_glem.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,25 @@ def __init__(self,
self.num_pretrain_epochs = num_pretrain_epochs
self.lm = GSgnnNodeModel(alpha_l2norm)
self.gnn = GSgnnNodeModel(alpha_l2norm)
self.training_lm = not em_order_gnn_first
# `training_lm` has three states, controled by `.toggle()`:
# None: model is loaded for inference, inference logic is decided by `inference_using_gnn`
# True: lm is being trained
# False: gnn is being trained
self.training_lm = None

@property
def inference_route_is_gnn(self):
"""This flag decides which inference route to perform: gnn (True) or lm (False).
This is decided based on the values of `training_lm` and `inference_using_gnn`.
There are two inference routes for GLEM:
- False (lm): lm.node_input_encoder->lm.decoder
- True (gnn): lm.node_input_encoder->gnn.gnn_encoder->gnn.decoder
"""
if self.training_lm is None:
# GLEM is loaded for inference only, decide based `inference_using_gnn`
return self.inference_using_gnn
# GLEM is being trained: use gnn route if not training lm
return not self.training_lm

def init_optimizer(self, lr, sparse_optimizer_lr, weight_decay, lm_lr=None):
"""Initialize optimzer, which will be stored in self.lm._optimizer, self.gnn._optimizer
Expand Down Expand Up @@ -125,8 +143,21 @@ def set_gnn_encoder(self, gnn_encoder):

@property
def gnn_encoder(self):
"""Alias for accessing the gnn_encoder"""
return self.gnn.gnn_encoder
"""Alias for accessing the gnn_encoder. Hide gnn_encoder if the inference route is lm.
This property is only used for model inference and evaluation."""
return self.gnn.gnn_encoder if self.inference_route_is_gnn else None

@property
def node_input_encoder(self):
"""Alias for accessing the node_input_encoder.
This property is only used for model inference and evaluation."""
return self.lm.node_input_encoder

@property
def decoder(self):
"""Alias for accessing the decoder.
This property is only used for model inference and evaluation."""
return self.gnn.decoder if self.inference_route_is_gnn else self.lm.decoder

def set_decoder(self, decoder):
"""Set the same decoder for both, since lm needs to be able to
Expand Down Expand Up @@ -320,8 +351,8 @@ def forward_lm_semisup(self, blocks, node_feats, edge_feats, labels, input_nodes

def predict(self, blocks, node_feats, edge_feats, input_nodes, return_proba):
"""Make prediction on the nodes with the LM or GNN.
The model's `inference_using_gnn` flag determines how inference is performed.
If inference_using_gnn is True, message-passing GNN is used on the LM features,
The model's `inference_route_is_gnn` flag determines how inference is performed.
If inference_route_is_gnn, message-passing GNN is used on the LM features,
Otherwise, LM's decoder is used for inference, no message-passing involved.
Parameters
Expand All @@ -344,8 +375,8 @@ def predict(self, blocks, node_feats, edge_feats, input_nodes, return_proba):
Tensor : the GNN embeddings.
"""
emb_lm, emb_gnn = self._embed_nodes(blocks, node_feats, edge_feats, input_nodes,
do_gnn_encode=self.inference_using_gnn)
if self.inference_using_gnn:
do_gnn_encode=self.inference_route_is_gnn)
if self.inference_route_is_gnn:
decoder = self.gnn.decoder
emb = emb_gnn
else:
Expand Down
4 changes: 0 additions & 4 deletions python/graphstorm/trainer/glem_np_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from ..model.node_gnn import GSgnnNodeModelInterface
from ..model.node_glem import GLEM
from ..model.gnn import GSgnnModel
from .np_trainer import GSgnnNodePredictionTrainer

from ..utils import sys_tracker, rt_profiler, print_mem
Expand Down Expand Up @@ -134,9 +133,6 @@ def fit(self, train_loader, num_epochs,
if self.evaluator is not None:
assert val_loader is not None, \
"The evaluator is provided but validation set is not provided."
if not use_mini_batch_infer:
assert isinstance(self._model, GSgnnModel), \
"Only GSgnnModel supports full-graph inference."

# computation graph will be changed during training.
on_cpu = self.device == th.device('cpu')
Expand Down
20 changes: 18 additions & 2 deletions tests/end2end-tests/graphstorm-nc/mgpu_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ python3 -m graphstorm.run.gs_node_classification --workspace $GS_HOME/training_s
error_and_exit $?
rm -fr /data/gsgnn_nc_ml_text/*

echo "**************dataset: MovieLens classification, GLEM co-training, RGCN layer: 1, node feat: BERT nodes: movie, user inference: mini-batch save model save emb node"
echo "**************dataset: MovieLens classification, GLEM co-training, RGCN layer: 1, node feat: BERT nodes: movie, user inference: full-graph save model save emb node"
python3 -m graphstorm.run.gs_node_classification --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lm_encoder_train_val_1p_4t/movie-lens-100k-text.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_utext_glem.yml --save-model-path /data/gsgnn_nc_ml_text/ --topk-model-to-save 1 --num-epochs 3

error_and_exit $?
Expand Down Expand Up @@ -299,11 +299,27 @@ fi

best_epoch=$(ls /data/gsgnn_nc_ml_text/ | grep epoch)
echo "**************dataset: MovieLens node classification, GLEM loads GLEM trained checkpoints, RGCN layer: 1, node feat: BERT nodes: movie, user inference: mini-batch save model"
python3 -m graphstorm.run.gs_node_classification --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lm_encoder_train_val_1p_4t/movie-lens-100k-text.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_utext_glem.yml --restore-model-path /data/gsgnn_nc_ml_text/$best_epoch --restore-model-layers embed,decoder --inference
python3 -m graphstorm.run.gs_node_classification --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lm_encoder_train_val_1p_4t/movie-lens-100k-text.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_utext_glem.yml --restore-model-path /data/gsgnn_nc_ml_text/$best_epoch --restore-model-layers embed,decoder --inference --use-mini-batch-infer false --logging-file /tmp/full_graph_inf.txt

error_and_exit $?

python3 -m graphstorm.run.gs_node_classification --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lm_encoder_train_val_1p_4t/movie-lens-100k-text.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_utext_glem.yml --restore-model-path /data/gsgnn_nc_ml_text/$best_epoch --restore-model-layers embed,decoder --inference --use-mini-batch-infer true --logging-file /tmp/mini_batch_inf.txt

error_and_exit $?

# assert same performance results from mini-batch and full-graph inference
perf_full=$(grep "Best Test accuracy" /tmp/full_graph_inf.txt | sed 's/^.*: //')
perf_mini=$(grep "Best Test accuracy" /tmp/mini_batch_inf.txt | sed 's/^.*: //')
if $perf_full != $perf_mini
then
echo "The performance is different from full-graph and mini-batch inference!"
exit -1
fi


rm -fr /data/gsgnn_nc_ml_text/*
rm /tmp/full_graph_inf.txt
rm /tmp/mini_batch_inf.txt

echo "**************dataset: MovieLens classification, RGCN layer: 1, node feat: fixed HF BERT, BERT nodes: movie, inference: mini-batch save model save emb node, Backend nccl"
python3 -m graphstorm.run.gs_node_classification --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc.yaml --save-model-path /data/gsgnn_nc_ml/ --num-epochs 1 --backend nccl --node-feat-name movie:title user:feat
Expand Down

0 comments on commit 925c145

Please sign in to comment.