From 2cc57eab65f6a7439984af35582bb32094fa91c7 Mon Sep 17 00:00:00 2001 From: Tian Zhen <1204216974@qq.com> Date: Wed, 30 Mar 2022 19:34:48 +0800 Subject: [PATCH] FIX: code optimization of SRGNN --- recbole/model/sequential_recommender/srgnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recbole/model/sequential_recommender/srgnn.py b/recbole/model/sequential_recommender/srgnn.py index 0147f1499..5e94459df 100644 --- a/recbole/model/sequential_recommender/srgnn.py +++ b/recbole/model/sequential_recommender/srgnn.py @@ -177,7 +177,7 @@ def _get_slice(self, item_seq): # The relative coordinates of the item node, shape of [batch_size, max_session_len] alias_inputs = torch.LongTensor(alias_inputs).to(self.device) # The connecting matrix, shape of [batch_size, max_session_len, 2 * max_session_len] - A = torch.FloatTensor(A).to(self.device) + A = torch.FloatTensor(np.array(A)).to(self.device) # The unique item nodes, shape of [batch_size, max_session_len] items = torch.LongTensor(items).to(self.device)