diff --git a/python/graphstorm/model/edge_gnn.py b/python/graphstorm/model/edge_gnn.py index 4f2cd57ca5..aed46cb31e 100644 --- a/python/graphstorm/model/edge_gnn.py +++ b/python/graphstorm/model/edge_gnn.py @@ -175,6 +175,7 @@ def edge_mini_batch_gnn_predict(model, loader, return_proba=True, return_label=F input_nodes = {g.ntypes[0]: input_nodes} input_feats = data.get_node_feats(input_nodes, device) blocks = [block.to(device) for block in blocks] + batch_graph = batch_graph.to(device) pred = model.predict(blocks, batch_graph, input_feats, None, input_nodes, return_proba) preds.append(pred.cpu()) diff --git a/tests/end2end-tests/graphstorm-ec/mgpu_test.sh b/tests/end2end-tests/graphstorm-ec/mgpu_test.sh index f69e769e5a..162b7d85b3 100644 --- a/tests/end2end-tests/graphstorm-ec/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-ec/mgpu_test.sh @@ -202,3 +202,9 @@ python3 -m graphstorm.run.gs_edge_classification --workspace $GS_HOME/training_s error_and_exit $? rm -fr /data/gsgnn_ec/* + +echo "**************dataset: Generated multilabel MovieLens EC, RGCN layer: 1, node feat: generated feature, inference: minibatch, exclude-training-targets: True, decoder edge feat: label" +python3 -m graphstorm.run.gs_edge_classification --workspace $GS_HOME/training_scripts/gsgnn_ep/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_label_ec/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec.yaml --exclude-training-targets True --multilabel true --num-classes 6 --node-feat-name feat --use-mini-batch-infer true --topk-model-to-save 1 --save-embed-path /data/gsgnn_ec/emb/ --save-model-path /data/gsgnn_ec/ --save-model-frequency 1000 --decoder-edge-feat user,rating,movie:rate --fanout 'user/rating/movie:4@movie/rating-rev/user:5,user/rating/movie:2@movie/rating-rev/user:2' --num-layers 2 --decoder-type MLPEFeatEdgeDecoder + +error_and_exit $? +rm -fr /data/gsgnn_ec/*