diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py index f4584b1bf1..ae2092d921 100644 --- a/python/graphstorm/dataloading/dataloading.py +++ b/python/graphstorm/dataloading/dataloading.py @@ -249,7 +249,8 @@ def _prepare_dataloader(self, g, target_idxs, fanout, return loader def __iter__(self): - return self.dataloader.__iter__() + self.dataloader.__iter__() + return self def __next__(self): input_nodes, pos_graph, neg_graph, blocks = self.dataloader.__next__() @@ -258,7 +259,7 @@ def __next__(self): for etype in pos_graph.canonical_etypes} edge_weight_feats = self._data.get_edge_feats(input_edges, self._lp_edge_weight_for_loss, - self._device) + pos_graph.device) # store edge feature into graph for etype, feat in edge_weight_feats.items(): pos_graph.edges[etype].data[LP_DECODER_EDGE_WEIGHT] = feat @@ -504,7 +505,8 @@ def _prepare_dataloader(self, g, target_idxs, fanout, num_negative_edges, return loader def __iter__(self): - return self.dataloader.__iter__() + self.dataloader.__iter__() + return self def __next__(self): input_nodes, pos_graph, neg_graph, blocks = self.dataloader.__next__() @@ -513,7 +515,7 @@ def __next__(self): for etype in pos_graph.canonical_etypes} edge_weight_feats = self._data.get_edge_feats(input_edges, self._lp_edge_weight_for_loss, - self._device) + pos_graph.device) # store edge feature into graph for etype, feat in edge_weight_feats.items(): pos_graph.edges[etype].data[LP_DECODER_EDGE_WEIGHT] = feat