diff --git a/recbole/model/general_recommender/diffrec.py b/recbole/model/general_recommender/diffrec.py index 82e33a75c..84716ac76 100644 --- a/recbole/model/general_recommender/diffrec.py +++ b/recbole/model/general_recommender/diffrec.py @@ -328,7 +328,7 @@ def full_sort_predict(self, interaction): def predict(self, interaction): item = interaction[self.ITEM_ID] x_t = self.full_sort_predict(interaction) - scores = x_t[:, item] + scores = x_t[torch.arange(len(item)).to(self.device), item] return scores def calculate_loss(self, interaction): diff --git a/recbole/model/general_recommender/ldiffrec.py b/recbole/model/general_recommender/ldiffrec.py index 7d8364f0b..a537db7d2 100644 --- a/recbole/model/general_recommender/ldiffrec.py +++ b/recbole/model/general_recommender/ldiffrec.py @@ -335,7 +335,7 @@ def full_sort_predict(self, interaction): def predict(self, interaction): item = interaction[self.ITEM_ID] x_t = self.full_sort_predict(interaction) - scores = x_t[:, item] + scores = x_t[torch.arange(len(item)).to(self.device), item] return scores