
Commit

Merge branch 'main' into feat/21-s3rec
GangBean authored Sep 5, 2024
2 parents 322ea17 + 0485a0f commit 1ba8ea1
Showing 6 changed files with 241 additions and 28 deletions.
4 changes: 2 additions & 2 deletions configs/train_config.yaml
@@ -19,7 +19,7 @@ epochs: 5
 batch_size: 32
 lr: 0.0001
 optimizer: adam # adamw
-loss_name: bce # bpr # pointwise # bce
+loss_name: bpr # bpr # pointwise # bce
 patience: 5
 top_n: 10
 weight_decay: 0 #1e-5
@@ -43,7 +43,7 @@ model:
     embed_size: 64
   NGCF:
     embed_size: 64
-    num_orders: 3
+    num_orders: 2
   S3Rec:
     embed_size: 64
     max_seq_len: 50
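Note: the default loss_name switches from bce to bpr here, matching the new BPR training path (bpr_forward and BPRLoss) introduced in the files below.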
17 changes: 9 additions & 8 deletions data/datasets/ngcf_data_pipeline.py
@@ -21,25 +21,26 @@ def _set_laplacian_matrix(self, df):
         # transform df to user-item interaction (R)
         logger.info('transform df to user-item interaction')
         user_item_interaction = df.pivot_table(index='user_id', columns=['business_id'], values=['rating'])
-        user_item_interaction = user_item_interaction.droplevel(0, 1)
+        user_item_interaction = user_item_interaction.droplevel(0, 1).fillna(0)

         # adjacency matrix
         logger.info('create adjacency matrix')
-        adjacency_matrix = np.zeros((self.num_items+self.num_users, self.num_items+self.num_users))
+        adjacency_matrix = np.zeros((self.num_items+self.num_users, self.num_items+self.num_users), dtype=np.float32)
         adjacency_matrix[:self.num_users,self.num_users:] = user_item_interaction
         adjacency_matrix[self.num_users:,:self.num_users] = user_item_interaction.T

         # diagonal degree matrix (n+m) x (m+n)
         logger.info('create diagonal degree matrix')
-        diagonal_degree_matrix = np.diag(1/np.sqrt(adjacency_matrix.sum(axis=0)))
+        diagonal_degree_matrix = np.diag(1/np.sqrt(adjacency_matrix.sum(axis=0))).astype(np.float32)

         # set laplacian matrix
         logger.info('set laplacian matrix')
-        diagonal_degree_matrix = torch.tensor(diagonal_degree_matrix).float().to('cuda')
-        adjacency_matrix = torch.tensor(adjacency_matrix).float().to('cuda')
-        self.laplacian_matrix = torch.matmul(diagonal_degree_matrix, adjacency_matrix)
-        adjacency_matrix = adjacency_matrix.cpu().detach()
-        self.laplacian_matrix = torch.matmul(self.laplacian_matrix, diagonal_degree_matrix)
+        diagonal_degree_matrix = torch.from_numpy(diagonal_degree_matrix).to_sparse().to('cuda')
+        adjacency_matrix = torch.from_numpy(adjacency_matrix).to_sparse().to('cuda')
+        self.laplacian_matrix = torch.sparse.mm(diagonal_degree_matrix, adjacency_matrix)
+        del adjacency_matrix
+        self.laplacian_matrix = torch.sparse.mm(self.laplacian_matrix, diagonal_degree_matrix)
+        del diagonal_degree_matrix
         logger.info('done...')

     def preprocess(self) -> pd.DataFrame:
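For context, the rewritten pipeline builds the symmetrically normalized Laplacian L = D^(-1/2) A D^(-1/2) as sparse tensors instead of two dense GPU matmuls. A minimal standalone sketch of the same computation (function name hypothetical, CPU-only for simplicity):

import numpy as np
import torch

def normalized_laplacian(interaction: np.ndarray) -> torch.Tensor:
    """Sketch: L = D^(-1/2) A D^(-1/2) over the bipartite user-item graph."""
    num_users, num_items = interaction.shape
    n = num_users + num_items

    # Block adjacency: user rows on top, item rows on the bottom.
    adjacency = np.zeros((n, n), dtype=np.float32)
    adjacency[:num_users, num_users:] = interaction
    adjacency[num_users:, :num_users] = interaction.T

    # D^(-1/2); assumes every node has at least one interaction,
    # otherwise the division needs a zero-degree guard.
    d_inv_sqrt = np.diag(1.0 / np.sqrt(adjacency.sum(axis=0))).astype(np.float32)

    d = torch.from_numpy(d_inv_sqrt).to_sparse()
    a = torch.from_numpy(adjacency).to_sparse()
    return torch.sparse.mm(torch.sparse.mm(d, a), d)

Sparse storage matters here because the adjacency has (num_users + num_items)^2 entries, almost all of which are zero.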
6 changes: 6 additions & 0 deletions data/datasets/ngcf_dataset.py
@@ -0,0 +1,6 @@
from .mf_dataset import MFDataset

class NGCFDataset:
    pass

NGCFDataset = MFDataset
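The placeholder class is immediately rebound, so NGCFDataset is effectively an alias for MFDataset; the trainer below consumes the same user_id / pos_item / neg_item batches that dataset yields.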
52 changes: 37 additions & 15 deletions models/ngcf.py
@@ -2,13 +2,16 @@
 import torch.nn as nn

 from models.base_model import BaseModel
+from loguru import logger

 class NGCF(BaseModel):

-    def __init__(self, cfg, num_users, num_items, laplacian_matrix):
+    def __init__(self, cfg, num_users, num_items): #, laplacian_matrix):
         super().__init__()
         self.cfg = cfg
         self.num_users = num_users
-        self.laplacian_matrix = laplacian_matrix
+        self.num_items = num_items
+        # self.laplacian_matrix = laplacian_matrix
         self.embedding = nn.Embedding(
             num_users+num_items, cfg.embed_size, dtype=torch.float32)

@@ -22,29 +25,48 @@ def __init__(self, cfg, num_users, num_items, laplacian_matrix):
     def _init_weights(self):
         for child in self.children():
             if isinstance(child, nn.Embedding):
-                nn.init.xavier_uniform_(child.weight)k
+                nn.init.xavier_uniform_(child.weight)

-    def forward(self, user_id, item_id):
+    def bpr_forward(self, user_id, pos_item_ids, neg_item_ids, laplacian_matrix):
+        user_embed_list, pos_item_embed_list, neg_item_embed_list = \
+            [self.embedding(user_id),], [self.embedding(self.num_users+pos_item_ids)], [self.embedding(self.num_users+neg_item_ids)]
+        last_embed = self.embedding.weight
+
+        for w1, w2 in zip(self.W1, self.W2):
+            last_embed: torch.Tensor = self.embedding_propagation(last_embed, w1, w2, laplacian_matrix)
+            user_embed_list.append(last_embed[user_id])
+            pos_item_embed_list.append(last_embed[self.num_users + pos_item_ids])
+            neg_item_embed_list.append(last_embed[self.num_users + neg_item_ids])
+
+        user_embed = torch.concat(user_embed_list, dim=1)
+        pos_item_embed = torch.concat(pos_item_embed_list, dim=1)
+        neg_item_embed = torch.concat(neg_item_embed_list, dim=1)
+
+        return torch.sum(user_embed * pos_item_embed, dim=1), torch.sum(user_embed * neg_item_embed, dim=1)
+
+    def forward(self, user_id, item_id, laplacian_matrix):
         user_embed_list, item_embed_list = [self.embedding(user_id),], [self.embedding(self.num_users+item_id)]
-        last_embed = self.embedding
+        last_embed = self.embedding.weight
         for w1, w2 in zip(self.W1, self.W2):
-            last_embed = embedding_propagation(last_embed, w1, w2)
-            user_embed_list.append(last_embed(user_id))
-            item_embed_list.append(last_embed(self.num_users+item_id))
+            last_embed: torch.Tensor = self.embedding_propagation(last_embed, w1, w2, laplacian_matrix)
+            user_embed_list.append(last_embed[user_id])
+            item_embed_list.append(last_embed[self.num_users + item_id])

         user_embed = torch.concat(user_embed_list, dim=1)
         item_embed = torch.concat(item_embed_list, dim=1)

-        return torch.sum(user_emb * item_emb, dim=1)
+        return torch.sum(user_embed * item_embed, dim=1)

-    def embedding_propagation(self, last_embed, w1, w2):
-        identity_matrix = torch.eye(*self.laplacian_matrix.size())
-        term1 = torch.matmul(self.laplacian_matrix + identity_matrix, last_embed)
+    def embedding_propagation(self, last_embed: torch.Tensor, w1, w2, laplacian_matrix):
+        identity_matrix = torch.eye(*laplacian_matrix.size(), dtype=torch.float32).to_sparse().to(self.cfg.device)
+        matrix = laplacian_matrix + identity_matrix
+
+        term1 = torch.sparse.mm(matrix, last_embed)
         term1 = w1(term1)

-        neighbor_embeddings = torch.matmul(self.laplacian_matrix, last_embed)
-        term2 = torch.mul(neighbor_embeddings, last_embed)
+        neighbor_embeddings = torch.sparse.mm(laplacian_matrix, last_embed)
+
+        term2 = torch.mul(last_embed, neighbor_embeddings)
         term2 = w2(term2)

         return nn.functional.leaky_relu(term1 + term2)
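For reference, each call to embedding_propagation computes one NGCF message-passing layer, E(l+1) = LeakyReLU((L + I) E(l) W1 + ((L E(l)) ⊙ E(l)) W2), where L is the normalized Laplacian and ⊙ is element-wise multiplication. A minimal dense sketch of the same rule (toy sizes, all names hypothetical):

import torch
import torch.nn as nn

n_nodes, embed_size = 4, 8                          # toy sizes for illustration
laplacian = torch.rand(n_nodes, n_nodes)            # stands in for the sparse L
embeddings = torch.rand(n_nodes, embed_size)        # E(l)
w1 = nn.Linear(embed_size, embed_size, bias=False)
w2 = nn.Linear(embed_size, embed_size, bias=False)

identity = torch.eye(n_nodes)
term1 = w1((laplacian + identity) @ embeddings)             # (L + I) E(l) W1
term2 = w2((laplacian @ embeddings) * embeddings)           # (L E(l)) ⊙ E(l), then W2
next_embeddings = nn.functional.leaky_relu(term1 + term2)   # E(l+1)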

8 changes: 5 additions & 3 deletions train.py
@@ -18,10 +18,12 @@
 from data.datasets.cdae_dataset import CDAEDataset
 from data.datasets.mf_dataset import MFDataset
 from data.datasets.dcn_dataset import DCNDataset
+from data.datasets.ngcf_dataset import NGCFDataset
 from data.datasets.s3rec_dataset import S3RecDataset
 from trainers.cdae_trainer import CDAETrainer
 from trainers.dcn_trainer import DCNTrainer
 from trainers.mf_trainer import MFTrainer
+from trainers.ngcf_trainer import NGCFTrainer
 from trainers.s3rec_trainer import S3RecTrainer, S3RecPreTrainer
 from utils import set_seed

@@ -94,7 +96,7 @@ def train(cfg, args):#train_dataset, valid_dataset, test_dataset, model_info):
         trainer.load_best_model()
         trainer.evaluate(args.test_eval_data, 'test')
     elif cfg.model_name in ('NGCF', ):
-        trainer = MGCFTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'],
+        trainer = NGCFTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'],
                               args.data_pipeline.laplacian_matrix)
         trainer.run(train_dataloader, valid_dataloader, args.valid_eval_data)
         trainer.load_best_model()
@@ -169,8 +171,8 @@ def main(cfg: OmegaConf):
         model_info['num_items'], model_info['num_users'] = data_pipeline.num_items, data_pipeline.num_users
     elif cfg.model_name == 'NGCF':
         train_data, valid_data, valid_eval_data, test_eval_data = data_pipeline.split(df)
-        train_dataset = MFDataset(train_data, num_items=data_pipeline.num_items)
-        valid_dataset = MFDataset(valid_data, num_items=data_pipeline.num_items)
+        train_dataset = NGCFDataset(train_data, num_items=data_pipeline.num_items)
+        valid_dataset = NGCFDataset(valid_data, num_items=data_pipeline.num_items)
         args.update({'valid_eval_data': valid_eval_data, 'test_eval_data': test_eval_data})
         model_info['num_items'], model_info['num_users'] = data_pipeline.num_items, data_pipeline.num_users
     elif cfg.model_name == 'S3Rec':
182 changes: 182 additions & 0 deletions trainers/ngcf_trainer.py
@@ -0,0 +1,182 @@
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader
from torch.optim import Optimizer

from loguru import logger
from omegaconf.dictconfig import DictConfig
import wandb

from models.ngcf import NGCF
from .base_trainer import BaseTrainer
from metric import *
from loss import BPRLoss

class NGCFTrainer(BaseTrainer):
    def __init__(self, cfg: DictConfig, num_items: int, num_users: int, laplacian_matrix: torch.Tensor) -> None:
        super().__init__(cfg)
        logger.info(f'[DEVICE] device = {self.device}')
        self.num_items = num_items
        self.num_users = num_users
        self.model = NGCF(self.cfg, num_users, num_items).to(self.device)
        self.optimizer: Optimizer = self._optimizer(self.cfg.optimizer, self.model, self.cfg.lr, self.cfg.weight_decay)
        self.loss = self._loss()
        self.laplacian_matrix = laplacian_matrix

    def _loss(self):
        return BPRLoss()
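The BPRLoss implementation itself is not part of this diff; a common formulation consistent with how the trainer calls it (a pair of score tensors in, a scalar loss out) is the negative log-sigmoid of the score margin. A sketch, assuming the repo's loss module matches the standard definition:

import torch
import torch.nn as nn

class BPRLoss(nn.Module):
    """Bayesian Personalized Ranking: -log sigmoid(pos_score - neg_score)."""
    def forward(self, pos_pred: torch.Tensor, neg_pred: torch.Tensor) -> torch.Tensor:
        return -nn.functional.logsigmoid(pos_pred - neg_pred).mean()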

    def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, valid_eval_data: pd.DataFrame):
        logger.info(f"[Trainer] run...")

        best_valid_loss: float = 1e+6
        best_valid_precision_at_k: float = .0
        best_valid_recall_at_k: float = .0
        best_valid_map_at_k: float = .0
        best_valid_ndcg_at_k: float = .0
        best_epoch: int = 0
        endurance: int = 0

        # train
        for epoch in range(self.cfg.epochs):
            train_loss: float = self.train(train_dataloader)
            valid_loss: float = self.validate(valid_dataloader)
            (valid_precision_at_k,
             valid_recall_at_k,
             valid_map_at_k,
             valid_ndcg_at_k) = self.evaluate(valid_eval_data, 'valid')
            logger.info(f'''\n[Trainer] epoch: {epoch} > train loss: {train_loss:.4f} /
                        valid loss: {valid_loss:.4f} /
                        precision@K : {valid_precision_at_k:.4f} /
                        Recall@K: {valid_recall_at_k:.4f} /
                        MAP@K: {valid_map_at_k:.4f} /
                        NDCG@K: {valid_ndcg_at_k:.4f}''')

            # wandb logging
            if self.cfg.wandb:
                wandb.log({
                    'train_loss': train_loss,
                    'valid_loss': valid_loss,
                    'valid_Precision@K': valid_precision_at_k,
                    'valid_Recall@K': valid_recall_at_k,
                    'valid_MAP@K': valid_map_at_k,
                    'valid_NDCG@K': valid_ndcg_at_k,
                })

            # update model
            if self._is_surpass_best_metric(
                    current=(valid_loss,
                             valid_precision_at_k,
                             valid_recall_at_k,
                             valid_map_at_k,
                             valid_ndcg_at_k),
                    best=(best_valid_loss,
                          best_valid_precision_at_k,
                          best_valid_recall_at_k,
                          best_valid_map_at_k,
                          best_valid_ndcg_at_k)):
                logger.info(f"[Trainer] update best model...")
                best_valid_loss = valid_loss
                best_valid_precision_at_k = valid_precision_at_k
                best_valid_recall_at_k = valid_recall_at_k
                best_valid_ndcg_at_k = valid_ndcg_at_k
                best_valid_map_at_k = valid_map_at_k
                best_epoch = epoch
                endurance = 0

                torch.save(self.model.state_dict(), f'{self.cfg.model_dir}/best_model.pt')
            else:
                endurance += 1
                if endurance > self.cfg.patience:
                    logger.info(f"[Trainer] early stopping...")
                    break


    def train(self, train_dataloader: DataLoader) -> float:
        self.model.train()
        train_loss = 0
        for data in tqdm(train_dataloader):
            user_id, pos_item, neg_item = data['user_id'].to(self.device), data['pos_item'].to(self.device), \
                data['neg_item'].to(self.device)
            pos_pred, neg_pred = self.model.bpr_forward(user_id, pos_item, neg_item, self.laplacian_matrix)

            self.optimizer.zero_grad()
            loss = self.loss(pos_pred, neg_pred)
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()

        return train_loss

    def validate(self, valid_dataloader: DataLoader) -> float:
        self.model.eval()
        valid_loss = 0
        for data in tqdm(valid_dataloader):
            user_id, pos_item, neg_item = data['user_id'].to(self.device), data['pos_item'].to(self.device), \
                data['neg_item'].to(self.device)
            pos_pred, neg_pred = self.model.bpr_forward(user_id, pos_item, neg_item, self.laplacian_matrix)

            loss = self.loss(pos_pred, neg_pred)

            valid_loss += loss.item()

        return valid_loss

    def evaluate(self, eval_data: pd.DataFrame, mode='valid') -> tuple:
        self.model.eval()
        actual, predicted = [], []
        item_input = torch.tensor([item_id for item_id in range(self.num_items)]).to(self.device)

        for idx in tqdm(np.random.randint(eval_data.shape[0], size=100), total=100):
            user_id = eval_data.iloc[[idx], :].index[0]
            row = eval_data.iloc[idx, :]

            pred = self.model(torch.tensor([user_id,]*self.num_items).to(self.device), item_input, self.laplacian_matrix)
            batch_predicted = \
                self._generate_top_k_recommendation(pred, row['mask_items'])
            actual.append(row['pos_items'])
            predicted.append(batch_predicted)

        test_precision_at_k = precision_at_k(actual, predicted, self.cfg.top_n)
        test_recall_at_k = recall_at_k(actual, predicted, self.cfg.top_n)
        test_map_at_k = map_at_k(actual, predicted, self.cfg.top_n)
        test_ndcg_at_k = ndcg_at_k(actual, predicted, self.cfg.top_n)

        if mode == 'test':
            logger.info(f'''\n[Trainer] Test >
                        precision@{self.cfg.top_n} : {test_precision_at_k:.4f} /
                        Recall@{self.cfg.top_n}: {test_recall_at_k:.4f} /
                        MAP@{self.cfg.top_n}: {test_map_at_k:.4f} /
                        NDCG@{self.cfg.top_n}: {test_ndcg_at_k:.4f}''')

        return (test_precision_at_k,
                test_recall_at_k,
                test_map_at_k,
                test_ndcg_at_k)

    def _generate_top_k_recommendation(self, pred: Tensor, mask_items) -> np.ndarray:
        # mask out items seen in training
        pred = pred.cpu().detach().numpy()
        pred[mask_items] = -3.40282e+38  # ~finfo(float32).min

        # find the indexes of the largest top-N predictions
        topn_index = np.argpartition(pred, -self.cfg.top_n)[-self.cfg.top_n:]
        # take the scores at those indexes
        topn_prob = np.take_along_axis(pred, topn_index, axis=0)
        # sort the top-N scores in descending order and keep the sort order
        sorted_indices = np.argsort(-topn_prob)
        # apply that order to the item indexes to get the sorted top-N items
        topn_index_sorted = np.take_along_axis(topn_index, sorted_indices, axis=0)

        return topn_index_sorted
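As a sanity check on the argpartition-then-argsort pattern above, a toy run (hypothetical scores) shows it returns the top-N indices in descending score order:

import numpy as np

pred = np.array([0.1, 0.9, 0.4, 0.7, 0.2])
top_n = 3
topn_index = np.argpartition(pred, -top_n)[-top_n:]    # top-3 indices, unordered
topn_prob = np.take_along_axis(pred, topn_index, axis=0)
topn_sorted = np.take_along_axis(topn_index, np.argsort(-topn_prob), axis=0)
print(topn_sorted)  # [1 3 2], i.e. scores 0.9, 0.7, 0.4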
