Skip to content

Commit

Permalink
[Bug fix] Fix using weight in link prediction. (#278)
Browse files Browse the repository at this point in the history
*Issue #, if available:*
GSgnnLinkPredictionDataLoader does not return edge weight in graph.

*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
  • Loading branch information
classicsong and Xiang Song committed Jun 21, 2023
1 parent cd0bc20 commit f7393ea
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions python/graphstorm/dataloading/dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ def _prepare_dataloader(self, g, target_idxs, fanout,
return loader

def __iter__(self):
return self.dataloader.__iter__()
self.dataloader.__iter__()
return self

def __next__(self):
input_nodes, pos_graph, neg_graph, blocks = self.dataloader.__next__()
Expand All @@ -258,7 +259,7 @@ def __next__(self):
for etype in pos_graph.canonical_etypes}
edge_weight_feats = self._data.get_edge_feats(input_edges,
self._lp_edge_weight_for_loss,
self._device)
pos_graph.device)
# store edge feature into graph
for etype, feat in edge_weight_feats.items():
pos_graph.edges[etype].data[LP_DECODER_EDGE_WEIGHT] = feat
Expand Down Expand Up @@ -504,7 +505,8 @@ def _prepare_dataloader(self, g, target_idxs, fanout, num_negative_edges,
return loader

def __iter__(self):
return self.dataloader.__iter__()
self.dataloader.__iter__()
return self

def __next__(self):
input_nodes, pos_graph, neg_graph, blocks = self.dataloader.__next__()
Expand All @@ -513,7 +515,7 @@ def __next__(self):
for etype in pos_graph.canonical_etypes}
edge_weight_feats = self._data.get_edge_feats(input_edges,
self._lp_edge_weight_for_loss,
self._device)
pos_graph.device)
# store edge feature into graph
for etype, feat in edge_weight_feats.items():
pos_graph.edges[etype].data[LP_DECODER_EDGE_WEIGHT] = feat
Expand Down

0 comments on commit f7393ea

Please sign in to comment.