diff --git a/recbole/model/sequential_recommender/srgnn.py b/recbole/model/sequential_recommender/srgnn.py index 0147f1499..5e94459df 100644 --- a/recbole/model/sequential_recommender/srgnn.py +++ b/recbole/model/sequential_recommender/srgnn.py @@ -177,7 +177,7 @@ def _get_slice(self, item_seq): # The relative coordinates of the item node, shape of [batch_size, max_session_len] alias_inputs = torch.LongTensor(alias_inputs).to(self.device) # The connecting matrix, shape of [batch_size, max_session_len, 2 * max_session_len] - A = torch.FloatTensor(A).to(self.device) + A = torch.FloatTensor(np.array(A)).to(self.device) # The unique item nodes, shape of [batch_size, max_session_len] items = torch.LongTensor(items).to(self.device)