-
Notifications
You must be signed in to change notification settings - Fork 6
/
SVDppRecommender.py
48 lines (41 loc) · 1.63 KB
/
SVDppRecommender.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
'''
@Author: Yu Di
@Date: 2019-08-08 14:43:33
@LastEditors: Yudi
@LastEditTime: 2019-08-13 16:10:21
@Company: Cardinal Operation
@Email: [email protected]
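@Description: SVD++ recommender model (biased matrix factorization plus an implicit-feedback term) implemented as a torch.nn.Module.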
'''
import torch
class SVDpp(torch.nn.Module):
    """SVD++ rating predictor: biased matrix factorization plus an implicit-feedback term.

    The predicted rating for user ``u`` and item ``i`` is::

        mu + b_u + b_i + q_i . (p_u + |Iu|^(-1/2) * sum_{j in Iu} y_j)

    Parameters
    ----------
    params : dict
        Must contain ``'num_users'``, ``'num_items'``, ``'latent_dim'`` and
        ``'global_mean'`` (the constant ``mu`` added to every prediction).
    """

    def __init__(self, params):
        super(SVDpp, self).__init__()
        self.num_users = params['num_users']
        self.num_items = params['num_items']
        self.latent_dim = params['latent_dim']
        self.mu = params['global_mean']
        # p_u and q_i latent factors (default nn.Embedding initialisation).
        self.user_embedding = torch.nn.Embedding(self.num_users, self.latent_dim)
        self.item_embedding = torch.nn.Embedding(self.num_items, self.latent_dim)
        # Per-user / per-item scalar biases, started at zero instead of the
        # Embedding's random default.
        self.user_bias = torch.nn.Embedding(self.num_users, 1)
        torch.nn.init.zeros_(self.user_bias.weight)
        self.item_bias = torch.nn.Embedding(self.num_items, 1)
        torch.nn.init.zeros_(self.item_bias.weight)
        # y_j: implicit-feedback factor for every item.
        self.yj = torch.nn.Embedding(self.num_items, self.latent_dim)

    def _implicit_feedback(self, Iu):
        """Return ``|Iu|^(-1/2) * sum_{j in Iu} y_j`` as a ``(latent_dim,)`` vector.

        Returns a zero vector when ``Iu`` is empty; the naive formula would
        divide by ``sqrt(0)`` and yield NaN.
        """
        weight = self.yj.weight
        if isinstance(Iu, torch.Tensor):
            idx = Iu.to(device=weight.device, dtype=torch.long).reshape(-1)
        else:
            # list() so that sets and other iterables are accepted, as the
            # original ``for j in Iu`` loop did.
            idx = torch.as_tensor(list(Iu), dtype=torch.long, device=weight.device)
        if idx.numel() == 0:
            return weight.new_zeros(self.latent_dim)
        # One batched lookup instead of a Python loop of single-row lookups.
        return self.yj(idx).sum(dim=0) / (idx.numel() ** 0.5)

    def forward(self, user_idx, item_idx, Iu):
        '''
        Predict ratings for a batch of (user, item) pairs.

        Parameters
        ----------
        user_idx: 1-D LongTensor of user indices (length = batch size)
        item_idx: 1-D LongTensor of item indices, aligned with user_idx
        Iu: item set that user u interacted before (iterable of item indices
            or a LongTensor). A single set is used for the whole batch: its
            implicit-feedback vector is added to every row's user vector,
            so presumably all rows belong to the same user -- confirm with caller.

        Returns
        -------
        1-D tensor of predicted ratings, one per row of the batch.
        '''
        # Out-of-place add: avoids mutating the embedding-lookup output.
        user_vec = self.user_embedding(user_idx) + self._implicit_feedback(Iu)
        item_vec = self.item_embedding(item_idx)
        dot = torch.mul(user_vec, item_vec).sum(dim=1)
        rating = (
            dot
            + self.mu
            + self.user_bias(user_idx).view(-1)
            + self.item_bias(item_idx).view(-1)
        )
        return rating