Fix the device mis-match bug. (#288) (#290)
*Issue #, if available:*
The bug is triggered when edge features are used and mini-batch
inference is enabled.


*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

Co-authored-by: Xiang Song <[email protected]>
classicsong and Xiang Song authored Jun 26, 2023
1 parent 198c141 commit 34dd469
Showing 2 changed files with 7 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/graphstorm/model/edge_gnn.py
@@ -175,6 +175,7 @@ def edge_mini_batch_gnn_predict(model, loader, return_proba=True, return_label=F
 input_nodes = {g.ntypes[0]: input_nodes}
 input_feats = data.get_node_feats(input_nodes, device)
 blocks = [block.to(device) for block in blocks]
+batch_graph = batch_graph.to(device)
 pred = model.predict(blocks, batch_graph, input_feats, None, input_nodes,
                      return_proba)
 preds.append(pred.cpu())
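The added line in edge_gnn.py moves batch_graph to the same device as the blocks and the input features before model.predict is called. A minimal sketch of why that matters follows. It uses toy stand-ins (FakeTensor, FakeGraph, and decode are hypothetical names, not DGL, PyTorch, or GraphStorm classes) that only track a device tag, and it assumes the decoder reads an edge feature from batch_graph and combines it with node embeddings that already live on the target device.

```python
from dataclasses import dataclass, field, replace


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor; it only records a device tag."""
    name: str
    device: str = "cpu"

    def to(self, device):
        return replace(self, device=device)


@dataclass(frozen=True)
class FakeGraph:
    """Hypothetical stand-in for the batch graph holding edge features."""
    edata: dict = field(default_factory=dict)

    def to(self, device):
        # Moving the graph carries every stored edge feature along with it.
        return FakeGraph({k: v.to(device) for k, v in self.edata.items()})


def decode(node_emb, batch_graph):
    """Combine node embeddings with an edge feature (assumed decoder behavior)."""
    efeat = batch_graph.edata["rate"]
    if efeat.device != node_emb.device:
        raise RuntimeError(
            f"device mismatch: {node_emb.device} vs {efeat.device}")
    return FakeTensor("pred", node_emb.device)


device = "cuda:0"
node_emb = FakeTensor("emb").to(device)                 # embeddings on the device
batch_graph = FakeGraph({"rate": FakeTensor("rate")})   # edge feature still on cpu

try:
    decode(node_emb, batch_graph)       # without the fix: devices differ
    failed = False
except RuntimeError:
    failed = True

batch_graph = batch_graph.to(device)    # mirrors the line added in this commit
pred = decode(node_emb, batch_graph)    # with the fix: devices agree
```

In this toy model the first decode call raises because the edge feature is still tagged cpu, and the second succeeds once the graph has been moved. The real failure mode in GraphStorm may differ in detail; the sketch only illustrates the pattern of moving every input to one device before prediction.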
6 changes: 6 additions & 0 deletions tests/end2end-tests/graphstorm-ec/mgpu_test.sh
@@ -202,3 +202,9 @@ python3 -m graphstorm.run.gs_edge_classification --workspace $GS_HOME/training_s
 
 error_and_exit $?
 rm -fr /data/gsgnn_ec/*
+
+echo "**************dataset: Generated multilabel MovieLens EC, RGCN layer: 1, node feat: generated feature, inference: minibatch, exclude-training-targets: True, decoder edge feat: label"
+python3 -m graphstorm.run.gs_edge_classification --workspace $GS_HOME/training_scripts/gsgnn_ep/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_label_ec/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec.yaml --exclude-training-targets True --multilabel true --num-classes 6 --node-feat-name feat --use-mini-batch-infer true --topk-model-to-save 1 --save-embed-path /data/gsgnn_ec/emb/ --save-model-path /data/gsgnn_ec/ --save-model-frequency 1000 --decoder-edge-feat user,rating,movie:rate --fanout 'user/rating/movie:4@movie/rating-rev/user:5,user/rating/movie:2@movie/rating-rev/user:2' --num-layers 2 --decoder-type MLPEFeatEdgeDecoder
+
+error_and_exit $?
+rm -fr /data/gsgnn_ec/*
