diff --git a/configs/train_config.yaml b/configs/train_config.yaml
index b5c3dfb..3c14ea8 100644
--- a/configs/train_config.yaml
+++ b/configs/train_config.yaml
@@ -19,7 +19,7 @@ epochs: 5
 batch_size: 32
 lr: 0.0001
 optimizer: adam # adamw
-loss_name: bce # bpr # pointwise # bce
+loss_name: bpr # bpr # pointwise # bce
 patience: 5
 top_n: 10
 weight_decay: 0 #1e-5
@@ -43,7 +43,7 @@ model:
     embed_size: 64
   NGCF:
     embed_size: 64
-    num_orders: 3
+    num_orders: 2
   S3Rec:
     embed_size: 64
     max_seq_len: 50
diff --git a/data/datasets/ngcf_data_pipeline.py b/data/datasets/ngcf_data_pipeline.py
index 622f1a7..25e6fec 100644
--- a/data/datasets/ngcf_data_pipeline.py
+++ b/data/datasets/ngcf_data_pipeline.py
@@ -21,25 +21,26 @@ def _set_laplacian_matrix(self, df):
         # transform df to user-item interaction (R)
         logger.info('transform df to user-item interaction')
         user_item_interaction = df.pivot_table(index='user_id', columns=['business_id'], values=['rating'])
-        user_item_interaction = user_item_interaction.droplevel(0, 1)
+        user_item_interaction = user_item_interaction.droplevel(0, 1).fillna(0)
 
         # adjacency matrix
         logger.info('create adjacency matrix')
-        adjacency_matrix = np.zeros((self.num_items+self.num_users, self.num_items+self.num_users))
+        adjacency_matrix = np.zeros((self.num_items+self.num_users, self.num_items+self.num_users), dtype=np.float32)
         adjacency_matrix[:self.num_users,self.num_users:] = user_item_interaction
         adjacency_matrix[self.num_users:,:self.num_users] = user_item_interaction.T
 
         # diagonal degree matrix (n+m) x (m+n)
         logger.info('create diagonal degree matrix')
-        diagonal_degree_matrix = np.diag(1/np.sqrt(adjacency_matrix.sum(axis=0)))
+        diagonal_degree_matrix = np.diag(1/np.sqrt(adjacency_matrix.sum(axis=0))).astype(np.float32)
 
         # set laplacian matrix
         logger.info('set laplacian matrix')
-        diagonal_degree_matrix = torch.tensor(diagonal_degree_matrix).float().to('cuda')
-        adjacency_matrix = torch.tensor(adjacency_matrix).float().to('cuda')
-        self.laplacian_matrix = torch.matmul(diagonal_degree_matrix, adjacency_matrix)
-        adjacency_matrix = adjacency_matrix.cpu().detach()
-        self.laplacian_matrix = torch.matmul(self.laplacian_matrix, diagonal_degree_matrix)
+        diagonal_degree_matrix = torch.from_numpy(diagonal_degree_matrix).to_sparse().to('cuda')
+        adjacency_matrix = torch.from_numpy(adjacency_matrix).to_sparse().to('cuda')
+        self.laplacian_matrix = torch.sparse.mm(diagonal_degree_matrix, adjacency_matrix)
+        del adjacency_matrix
+        self.laplacian_matrix = torch.sparse.mm(self.laplacian_matrix, diagonal_degree_matrix)
+        del diagonal_degree_matrix
 
         logger.info('done...')
     def preprocess(self) -> pd.DataFrame:
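The pipeline change above moves the Laplacian construction to sparse ops: it builds L = D^(-1/2) A D^(-1/2), where A is the (users+items) x (users+items) bipartite adjacency matrix that holds the user-item matrix R in its off-diagonal blocks. A toy-sized sanity check of the same construction (the 2x2 R and all sizes here are made up; runs on CPU; note that 1/sqrt(degree) assumes every user and item has at least one interaction):

import numpy as np
import torch

R = np.array([[1., 0.], [1., 1.]], dtype=np.float32)    # 2 users x 2 items
n_users, n_items = R.shape
A = np.zeros((n_users + n_items,) * 2, dtype=np.float32)
A[:n_users, n_users:] = R                                # upper-right block
A[n_users:, :n_users] = R.T                              # lower-left block

D_inv_sqrt = np.diag(1 / np.sqrt(A.sum(axis=0))).astype(np.float32)
L = torch.sparse.mm(
    torch.sparse.mm(torch.from_numpy(D_inv_sqrt).to_sparse(),
                    torch.from_numpy(A).to_sparse()),
    torch.from_numpy(D_inv_sqrt).to_sparse())
print(L.to_dense())                                      # degree-normalized A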
diff --git a/data/datasets/ngcf_dataset.py b/data/datasets/ngcf_dataset.py
new file mode 100644
index 0000000..ae6c449
--- /dev/null
+++ b/data/datasets/ngcf_dataset.py
@@ -0,0 +1,6 @@
+from .mf_dataset import MFDataset
+
+# NGCF trains on the same (user_id, pos_item, neg_item) samples as MF,
+# so reuse MFDataset under a model-specific alias instead of declaring
+# an empty class and immediately rebinding the name.
+NGCFDataset = MFDataset
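mf_dataset.py is not part of this diff, but the trainer below reads each batch as data['user_id'], data['pos_item'], data['neg_item'], so MFDataset evidently yields one (user, positive item, sampled negative item) triplet per example. A hypothetical stand-in with that interface, for illustration only (TripletDataset and its constructor argument are invented names, not the repository's code):

import torch
from torch.utils.data import Dataset

class TripletDataset(Dataset):
    """Yields batches shaped like the ones NGCFTrainer consumes."""
    def __init__(self, triplets):        # [(user_id, pos_item, neg_item), ...]
        self.triplets = triplets

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        user, pos, neg = self.triplets[idx]
        return {'user_id': torch.tensor(user),
                'pos_item': torch.tensor(pos),
                'neg_item': torch.tensor(neg)}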
diff --git a/models/ngcf.py b/models/ngcf.py
index 7e9a6f8..818de24 100644
--- a/models/ngcf.py
+++ b/models/ngcf.py
@@ -2,13 +2,16 @@
 import torch.nn as nn
 
 from models.base_model import BaseModel
+from loguru import logger
 
 class NGCF(BaseModel):
-    def __init__(self, cfg, num_users, num_items, laplacian_matrix):
+    def __init__(self, cfg, num_users, num_items): #, laplacian_matrix):
         super().__init__()
+        self.cfg = cfg
         self.num_users = num_users
-        self.laplacian_matrix = laplacian_matrix
+        self.num_items = num_items
+        # self.laplacian_matrix = laplacian_matrix
         self.embedding = nn.Embedding(
             num_users+num_items, cfg.embed_size, dtype=torch.float32)
@@ -22,29 +25,48 @@ def __init__(self, cfg, num_users, num_items, laplacian_matrix):
     def _init_weights(self):
         for child in self.children():
             if isinstance(child, nn.Embedding):
-                nn.init.xavier_uniform_(child.weight)k
+                nn.init.xavier_uniform_(child.weight)
 
-    def forward(self, user_id, item_id):
+    def bpr_forward(self, user_id, pos_item_ids, neg_item_ids, laplacian_matrix):
+        user_embed_list, pos_item_embed_list, neg_item_embed_list = \
+            [self.embedding(user_id),], [self.embedding(self.num_users+pos_item_ids)], [self.embedding(self.num_users+neg_item_ids)]
+        last_embed = self.embedding.weight
+
+        for w1, w2 in zip(self.W1, self.W2):
+            last_embed: torch.Tensor = self.embedding_propagation(last_embed, w1, w2, laplacian_matrix)
+            user_embed_list.append(last_embed[user_id])
+            pos_item_embed_list.append(last_embed[self.num_users + pos_item_ids])
+            neg_item_embed_list.append(last_embed[self.num_users + neg_item_ids])
+
+        user_embed = torch.concat(user_embed_list, dim=1)
+        pos_item_embed = torch.concat(pos_item_embed_list, dim=1)
+        neg_item_embed = torch.concat(neg_item_embed_list, dim=1)
+
+        return torch.sum(user_embed * pos_item_embed, dim=1), torch.sum(user_embed * neg_item_embed, dim=1)
+
+    def forward(self, user_id, item_id, laplacian_matrix):
         user_embed_list, item_embed_list = [self.embedding(user_id),], [self.embedding(self.num_users+item_id)]
-        last_embed = self.embedding
+        last_embed = self.embedding.weight
 
         for w1, w2 in zip(self.W1, self.W2):
-            last_embed = embedding_propagation(last_embed, w1, w2)
-            user_embed_list.append(last_embed(user_id))
-            item_embed_list.append(last_embed(self.num_users+item_id))
+            last_embed: torch.Tensor = self.embedding_propagation(last_embed, w1, w2, laplacian_matrix)
+            user_embed_list.append(last_embed[user_id])
+            item_embed_list.append(last_embed[self.num_users + item_id])
 
         user_embed = torch.concat(user_embed_list, dim=1)
         item_embed = torch.concat(item_embed_list, dim=1)
 
-        return torch.sum(user_emb * item_emb, dim=1)
+        return torch.sum(user_embed * item_embed, dim=1)
+
+    def embedding_propagation(self, last_embed: torch.Tensor, w1, w2, laplacian_matrix):
+        identity_matrix = torch.eye(*laplacian_matrix.size(), dtype=torch.float32).to_sparse().to(self.cfg.device)
+        matrix = laplacian_matrix + identity_matrix
 
-    def embedding_propagation(self, last_embed, w1, w2):
-        identity_matrix = torch.eye(*self.laplacian_matrix.size())
-        term1 = torch.matmul(self.laplacian_matrix + identity_matrix, last_embed)
+        term1 = torch.sparse.mm(matrix, last_embed)
         term1 = w1(term1)
 
-        neighbor_embeddings = torch.matmul(self.laplacian_matrix, last_embed)
-        term2 = torch.mul(neighbor_embeddings, last_embed)
+        neighbor_embeddings = torch.sparse.mm(laplacian_matrix, last_embed)
+
+        term2 = torch.mul(last_embed, neighbor_embeddings)
         term2 = w2(term2)
 
         return nn.functional.leaky_relu(term1 + term2)
-
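One propagation step above computes E^(l+1) = LeakyReLU(W1((L + I) E^(l)) + W2((L E^(l)) * E^(l))); forward and bpr_forward stack num_orders such steps and concatenate every layer's output per node, so the final representation has (num_orders + 1) * embed_size dimensions. A dense toy check of a single step (all sizes and the random stand-in for L are hypothetical):

import torch
import torch.nn as nn

n, d = 6, 4                        # 6 nodes (users + items), embed size 4
L = torch.rand(n, n).to_sparse()   # stand-in for the normalized Laplacian
E = torch.rand(n, d)               # current embeddings E^(l)
w1, w2 = nn.Linear(d, d), nn.Linear(d, d)

I = torch.eye(n).to_sparse()
term1 = w1(torch.sparse.mm(L + I, E))             # self + neighbor messages
term2 = w2(torch.sparse.mm(L, E) * E)             # element-wise affinity term
E_next = nn.functional.leaky_relu(term1 + term2)  # E^(l+1), shape (6, 4)
print(E_next.shape)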
diff --git a/train.py b/train.py
index 5f23ff9..550c29c 100644
--- a/train.py
+++ b/train.py
@@ -18,10 +18,12 @@
 from data.datasets.cdae_dataset import CDAEDataset
 from data.datasets.mf_dataset import MFDataset
 from data.datasets.dcn_dataset import DCNDataset
+from data.datasets.ngcf_dataset import NGCFDataset
 from data.datasets.s3rec_dataset import S3RecDataset
 
 from trainers.cdae_trainer import CDAETrainer
 from trainers.dcn_trainer import DCNTrainer
 from trainers.mf_trainer import MFTrainer
+from trainers.ngcf_trainer import NGCFTrainer
 from trainers.s3rec_trainer import S3RecTrainer, S3RecPreTrainer
 from utils import set_seed
@@ -94,7 +96,7 @@ def train(cfg, args):#train_dataset, valid_dataset, test_dataset, model_info):
         trainer.load_best_model()
         trainer.evaluate(args.test_eval_data, 'test')
     elif cfg.model_name in ('NGCF', ):
-        trainer = MGCFTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'],
+        trainer = NGCFTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'],
                               args.data_pipeline.laplacian_matrix)
         trainer.run(train_dataloader, valid_dataloader, args.valid_eval_data)
         trainer.load_best_model()
@@ -169,8 +171,8 @@ def main(cfg: OmegaConf):
         model_info['num_items'], model_info['num_users'] = data_pipeline.num_items, data_pipeline.num_users
     elif cfg.model_name == 'NGCF':
         train_data, valid_data, valid_eval_data, test_eval_data = data_pipeline.split(df)
-        train_dataset = MFDataset(train_data, num_items=data_pipeline.num_items)
-        valid_dataset = MFDataset(valid_data, num_items=data_pipeline.num_items)
+        train_dataset = NGCFDataset(train_data, num_items=data_pipeline.num_items)
+        valid_dataset = NGCFDataset(valid_data, num_items=data_pipeline.num_items)
         args.update({'valid_eval_data': valid_eval_data, 'test_eval_data': test_eval_data})
         model_info['num_items'], model_info['num_users'] = data_pipeline.num_items, data_pipeline.num_users
     elif cfg.model_name == 'S3Rec':
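The trainer below optimizes BPRLoss from loss.py, which this diff doesn't touch. Given the call site loss(pos_pred, neg_pred), a minimal pairwise implementation consistent with it would be (a sketch of the standard BPR objective, not necessarily the repository's exact code):

import torch
import torch.nn as nn

class BPRLoss(nn.Module):
    """Bayesian Personalized Ranking: push positive scores above negatives."""
    def forward(self, pos_pred: torch.Tensor, neg_pred: torch.Tensor) -> torch.Tensor:
        # -log sigmoid(pos - neg), averaged over the batch; logsigmoid is
        # the numerically stable form of log(sigmoid(x))
        return -nn.functional.logsigmoid(pos_pred - neg_pred).mean()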
diff --git a/trainers/ngcf_trainer.py b/trainers/ngcf_trainer.py
new file mode 100644
index 0000000..ecaa690
--- /dev/null
+++ b/trainers/ngcf_trainer.py
@@ -0,0 +1,177 @@
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.utils.data import DataLoader
+from torch.optim import Optimizer
+
+from loguru import logger
+from omegaconf.dictconfig import DictConfig
+import wandb
+
+from models.ngcf import NGCF
+from .base_trainer import BaseTrainer
+from metric import *
+from loss import BPRLoss
+
+
+class NGCFTrainer(BaseTrainer):
+    def __init__(self, cfg: DictConfig, num_items: int, num_users: int, laplacian_matrix: torch.Tensor) -> None:
+        super().__init__(cfg)
+        logger.info(f'[DEVICE] device = {self.device}')
+        self.num_items = num_items
+        self.num_users = num_users
+        self.model = NGCF(self.cfg, num_users, num_items).to(self.device)
+        self.optimizer: Optimizer = self._optimizer(self.cfg.optimizer, self.model, self.cfg.lr, self.cfg.weight_decay)
+        self.loss = self._loss()
+        self.laplacian_matrix = laplacian_matrix
+
+    def _loss(self):
+        return BPRLoss()
+
+    def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, valid_eval_data: pd.DataFrame):
+        logger.info(f"[Trainer] run...")
+
+        best_valid_loss: float = 1e+6
+        best_valid_precision_at_k: float = .0
+        best_valid_recall_at_k: float = .0
+        best_valid_map_at_k: float = .0
+        best_valid_ndcg_at_k: float = .0
+        best_epoch: int = 0
+        endurance: int = 0
+
+        # train
+        for epoch in range(self.cfg.epochs):
+            train_loss: float = self.train(train_dataloader)
+            valid_loss: float = self.validate(valid_dataloader)
+            (valid_precision_at_k,
+             valid_recall_at_k,
+             valid_map_at_k,
+             valid_ndcg_at_k) = self.evaluate(valid_eval_data, 'valid')
+            logger.info(f'''\n[Trainer] epoch: {epoch} > train loss: {train_loss:.4f} /
+                valid loss: {valid_loss:.4f} /
+                Precision@K: {valid_precision_at_k:.4f} /
+                Recall@K: {valid_recall_at_k:.4f} /
+                MAP@K: {valid_map_at_k:.4f} /
+                NDCG@K: {valid_ndcg_at_k:.4f}''')
+
+            # wandb logging
+            if self.cfg.wandb:
+                wandb.log({
+                    'train_loss': train_loss,
+                    'valid_loss': valid_loss,
+                    'valid_Precision@K': valid_precision_at_k,
+                    'valid_Recall@K': valid_recall_at_k,
+                    'valid_MAP@K': valid_map_at_k,
+                    'valid_NDCG@K': valid_ndcg_at_k,
+                })
+
+            # update model
+            if self._is_surpass_best_metric(
+                current=(valid_loss,
+                         valid_precision_at_k,
+                         valid_recall_at_k,
+                         valid_map_at_k,
+                         valid_ndcg_at_k),
+                best=(best_valid_loss,
+                      best_valid_precision_at_k,
+                      best_valid_recall_at_k,
+                      best_valid_map_at_k,
+                      best_valid_ndcg_at_k)):
+                logger.info(f"[Trainer] update best model...")
+                best_valid_loss = valid_loss
+                best_valid_precision_at_k = valid_precision_at_k
+                best_valid_recall_at_k = valid_recall_at_k
+                best_valid_ndcg_at_k = valid_ndcg_at_k
+                best_valid_map_at_k = valid_map_at_k
+                best_epoch = epoch
+                endurance = 0
+
+                torch.save(self.model.state_dict(), f'{self.cfg.model_dir}/best_model.pt')
+            else:
+                endurance += 1
+                if endurance > self.cfg.patience:
+                    logger.info(f"[Trainer] early stopping...")
+                    break
+
+    def train(self, train_dataloader: DataLoader) -> float:
+        self.model.train()
+        train_loss = 0
+        for data in tqdm(train_dataloader):
+            user_id, pos_item, neg_item = data['user_id'].to(self.device), data['pos_item'].to(self.device), \
+                data['neg_item'].to(self.device)
+            pos_pred, neg_pred = self.model.bpr_forward(user_id, pos_item, neg_item, self.laplacian_matrix)
+
+            self.optimizer.zero_grad()
+            loss = self.loss(pos_pred, neg_pred)
+            loss.backward()
+            self.optimizer.step()
+
+            train_loss += loss.item()
+
+        return train_loss
+
+    def validate(self, valid_dataloader: DataLoader) -> float:
+        self.model.eval()
+        valid_loss = 0
+        for data in tqdm(valid_dataloader):
+            user_id, pos_item, neg_item = data['user_id'].to(self.device), data['pos_item'].to(self.device), \
+                data['neg_item'].to(self.device)
+            pos_pred, neg_pred = self.model.bpr_forward(user_id, pos_item, neg_item, self.laplacian_matrix)
+
+            loss = self.loss(pos_pred, neg_pred)
+
+            valid_loss += loss.item()
+
+        return valid_loss
+
+    def evaluate(self, eval_data: pd.DataFrame, mode='valid') -> tuple:
+        self.model.eval()
+        actual, predicted = [], []
+        item_input = torch.tensor([item_id for item_id in range(self.num_items)]).to(self.device)
+
+        for idx in tqdm(np.random.randint(eval_data.shape[0], size=100), total=100):
+            user_id = eval_data.iloc[[idx], :].index[0]
+            row = eval_data.iloc[idx, :]
+
+            pred = self.model(torch.tensor([user_id,]*self.num_items).to(self.device), item_input, self.laplacian_matrix)
+            batch_predicted = \
+                self._generate_top_k_recommendation(pred, row['mask_items'])
+            actual.append(row['pos_items'])
+            predicted.append(batch_predicted)
+
+        test_precision_at_k = precision_at_k(actual, predicted, self.cfg.top_n)
+        test_recall_at_k = recall_at_k(actual, predicted, self.cfg.top_n)
+        test_map_at_k = map_at_k(actual, predicted, self.cfg.top_n)
+        test_ndcg_at_k = ndcg_at_k(actual, predicted, self.cfg.top_n)
+
+        if mode == 'test':
+            logger.info(f'''\n[Trainer] Test >
+                Precision@{self.cfg.top_n}: {test_precision_at_k:.4f} /
+                Recall@{self.cfg.top_n}: {test_recall_at_k:.4f} /
+                MAP@{self.cfg.top_n}: {test_map_at_k:.4f} /
+                NDCG@{self.cfg.top_n}: {test_ndcg_at_k:.4f}''')
+
+        return (test_precision_at_k,
+                test_recall_at_k,
+                test_map_at_k,
+                test_ndcg_at_k)
+
+    def _generate_top_k_recommendation(self, pred: Tensor, mask_items) -> np.ndarray:
+        # mask items seen during training
+        pred = pred.cpu().detach().numpy()
+        pred[mask_items] = -3.40282e+38  # float32 min
+
+        # find the indexes of the largest top_n scores
+        topn_index = np.argpartition(pred, -self.cfg.top_n)[-self.cfg.top_n:]
+        # take scores from predictions using those indexes
+        topn_prob = np.take_along_axis(pred, topn_index, axis=0)
+        # sort the top_n scores descending and get their order
+        sorted_indices = np.argsort(-topn_prob)
+        # apply that order to the item indexes to get the ranked top_n items
+        topn_index_sorted = np.take_along_axis(topn_index, sorted_indices, axis=0)
+
+        return topn_index_sorted
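_generate_top_k_recommendation deliberately avoids sorting all item scores: np.argpartition pulls out the top_n scores in linear time, and only those top_n entries get sorted. A quick self-contained check of the trick (toy scores, top_n = 3):

import numpy as np

pred = np.array([0.1, 0.9, 0.3, 0.7, 0.5])
top_n = 3
topn_index = np.argpartition(pred, -top_n)[-top_n:]       # top-3, unordered
topn_prob = np.take_along_axis(pred, topn_index, axis=0)  # their scores
topn_sorted = np.take_along_axis(topn_index, np.argsort(-topn_prob), axis=0)
print(topn_sorted)                                        # [1 3 4]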