-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #15 from twndus/feat/12-mf-bpr
Feat/12 mf-bpr
- Loading branch information
Showing 10 changed files with 399 additions and 40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import os | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from tqdm import tqdm | ||
from sklearn.model_selection import train_test_split | ||
from loguru import logger | ||
|
||
from .data_pipeline import DataPipeline | ||
|
||
class MFDataPipeline(DataPipeline):
    '''
    Data pipeline for matrix-factorization models on the Yelp interactions.

    Loads the interaction table, records the number of distinct users and
    items, optionally adds sampled negatives (pointwise loss only) and
    produces a random per-user train/valid/test split.
    '''

    def __init__(self, cfg):
        # NOTE(review): `self.cfg` is read below; presumably the base class
        # stores `cfg` on the instance -- confirm against DataPipeline.
        super().__init__(cfg)
        # Both are filled in by preprocess() via _set_num_items_and_num_users().
        self.num_items = None
        self.num_users = None

    def split(self, df):
        '''
        Randomly split every user's interactions into train/valid/test.

        Each user's rows are split 80/20 into train+valid/test, then the
        train+valid part is split 75/25 into train/valid. For the pointwise
        loss both splits are stratified on `rating`.

        df: interactions with columns (user_id, business_id, rating)

        Returns a 4-tuple of DataFrames:
            train_data      -- train rows plus `pos_items` (the user's train items)
            valid_data      -- valid rows plus `pos_items` (the user's train+valid items)
            valid_eval_data -- per user: `pos_items` (valid items) and
                               `mask_items` (train items)
            test_eval_data  -- per user: `pos_items` (test items) and
                               `mask_items` (train+valid items)

        NOTE(review): train_test_split is called per user, so users with very
        few interactions may make it raise; that is not handled here.
        '''
        logger.info('start random user split...')
        stratified = self.cfg.loss_name == 'pointwise'
        train_parts, valid_parts, test_parts = [], [], []

        for _, user_df in df.groupby('user_id'):
            user_train_df, user_test_df = train_test_split(
                user_df,
                test_size=.2,
                stratify=user_df['rating'] if stratified else None,
            )
            user_train_df, user_valid_df = train_test_split(
                user_train_df,
                test_size=.25,
                stratify=user_train_df['rating'] if stratified else None,
            )
            train_parts.append(user_train_df)
            valid_parts.append(user_valid_df)
            test_parts.append(user_test_df)

        # reset_index() (without drop=True) keeps the original row labels in
        # an extra `index` column; kept as-is so the returned frames do not
        # change shape for existing consumers.
        train_df = pd.concat(train_parts).reset_index()
        valid_df = pd.concat(valid_parts).reset_index()
        test_df = pd.concat(test_parts).reset_index()

        train_pos_df = self._collect_pos_items(train_df)
        valid_pos_df = self._collect_pos_items(valid_df)
        train_valid_pos_df = self._collect_pos_items(pd.concat([train_df, valid_df], axis=0))
        test_pos_df = self._collect_pos_items(test_df)

        train_data = pd.merge(train_df, train_pos_df, left_on='user_id', right_on='user_id', how='left')
        valid_data = pd.merge(valid_df, train_valid_pos_df, left_on='user_id', right_on='user_id', how='left')
        # For evaluation, items already seen in earlier splits are exposed as
        # `mask_items` next to the split's own positives.
        valid_eval_data = pd.merge(
            valid_pos_df,
            train_pos_df.rename(columns={'pos_items': 'mask_items'}),
            left_on='user_id', right_on='user_id', how='left',
        )
        test_eval_data = pd.merge(
            test_pos_df,
            train_valid_pos_df.rename(columns={'pos_items': 'mask_items'}),
            left_on='user_id', right_on='user_id', how='left',
        )

        return train_data, valid_data, valid_eval_data, test_eval_data

    @staticmethod
    def _collect_pos_items(df):
        '''
        Group `df` by user and gather each user's business ids into a list.

        Returns a DataFrame indexed by user_id with a single `pos_items` column.
        '''
        return df.groupby('user_id').agg({'business_id': [('pos_items', list)]}).droplevel(0, 1)

    def preprocess(self) -> pd.DataFrame:
        '''
        Load the interactions and prepare them for splitting.

        Sets `num_items` / `num_users` as a side effect. For the pointwise
        loss, sampled negative rows (rating 0) are added and all observed
        rows get rating 1; otherwise the loaded table is returned unchanged.

        output: interaction DataFrame (one row per user-item pair)
        '''
        logger.info("start preprocessing...")
        # load df
        df = self._load_df()
        # set num items and num users
        self._set_num_items_and_num_users(df)
        # negative sampling
        if self.cfg.loss_name == 'pointwise':
            df = self._negative_sampling(df, self.cfg.neg_times)
        logger.info("done")
        return df

    def _load_df(self):
        '''
        Read `yelp_interactions.tsv` (tab separated) from `cfg.data_dir`.
        '''
        logger.info("load df...")
        return pd.read_csv(os.path.join(self.cfg.data_dir, 'yelp_interactions.tsv'), sep='\t', index_col=False)

    def _set_num_items_and_num_users(self, df):
        '''
        Store the number of distinct business ids and user ids found in `df`.
        '''
        self.num_items = df.business_id.nunique()
        self.num_users = df.user_id.nunique()

    def _negative_sampling(self, df: pd.DataFrame, neg_times: int = 5) -> pd.DataFrame:
        '''
        Add randomly sampled negative interactions for every user.

        Every row of `df` is treated as a positive (rating 1). For each user,
        up to `len(positives) * neg_times` distinct items the user has not
        interacted with are drawn uniformly from all items in `df` and added
        with rating 0. The request is capped at the number of items actually
        available, so users who interacted with (almost) the whole catalogue
        cannot stall the sampling loop.

        df: interactions with at least the columns (user_id, business_id);
            it is not modified.
        neg_times: number of negatives to draw per positive.

        Returns a shuffled DataFrame with positives and negatives, index reset.
        '''
        logger.info("negative sampling...")
        logger.info(f"before neg sampling: {df.shape}")
        all_items = df.business_id.unique()

        # assign() returns a new frame, so the caller's df is left untouched.
        df = df.assign(rating=1)
        neg_data = []
        for user_id, user_df in df.groupby('user_id'):
            # Start the "already used" set with the positives; sampled
            # negatives are added to it so membership tests stay O(1).
            used_items = set(user_df.business_id.unique())
            num_pos = len(used_items)
            # Never ask for more negatives than there are unseen items.
            num_neg = min(num_pos * neg_times, len(all_items) - num_pos)
            neg_items = []
            while len(neg_items) < num_neg:
                neg_item = np.random.choice(all_items)
                if neg_item in used_items:
                    continue
                used_items.add(neg_item)
                neg_items.append(neg_item)
            neg_data.extend([user_id, neg_item, 0] for neg_item in neg_items)

        # Name the columns explicitly instead of borrowing df.columns, so the
        # result does not depend on the column order/count of the input file.
        neg_df = pd.DataFrame(neg_data, columns=['user_id', 'business_id', 'rating'])
        df = pd.concat([df, neg_df], axis=0)
        df = df.sample(frac=1).reset_index(drop=True)
        logger.info(f"after neg sampling: {df.shape}")
        logger.info("done...")
        return df
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import numpy as np | ||
|
||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
from loguru import logger | ||
|
||
class MFDataset(Dataset):
    '''
    Dataset yielding (user, positive item, sampled negative item) triples.

    data: DataFrame whose rows provide `user_id`, `business_id` and
          `pos_items` (the collection of items to exclude when sampling a
          negative for that row).
    num_items: exclusive upper bound of the item ids negatives are drawn from.
    '''

    def __init__(self, data, num_items=None):
        super().__init__()
        self.data = data
        self.num_items = num_items

    def __len__(self):
        # One sample per interaction row.
        return self.data.shape[0]

    def _negative_sampling(self, user_positives):
        '''
        Draw a random item id in [0, num_items) that is not in `user_positives`.

        Raises ValueError when `user_positives` already contains every id in
        that range, since rejection sampling could then never succeed.
        '''
        # Cheap length test first; the full coverage check only runs when the
        # positives could possibly span the whole id range.
        if (
            self.num_items is not None
            and len(user_positives) >= self.num_items
            and set(user_positives).issuperset(range(self.num_items))
        ):
            raise ValueError(
                'cannot sample a negative item: all item ids are positives for this user'
            )
        neg_item = np.random.randint(self.num_items)
        while neg_item in user_positives:
            neg_item = np.random.randint(self.num_items)
        return neg_item

    def __getitem__(self, index):
        '''
        Return a dict with `user_id`, `pos_item` and a freshly sampled `neg_item`.

        A new negative is drawn on every access, so repeated reads of the same
        index can return different `neg_item` values.
        '''
        row = self.data.iloc[index, :]
        return {
            'user_id': row['user_id'].astype('int64'),
            'pos_item': row['business_id'].astype('int64'),
            'neg_item': self._negative_sampling(row['pos_items']),
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
from models.base_model import BaseModel | ||
|
||
from loguru import logger | ||
class MatrixFactorization(BaseModel):
    '''
    Matrix-factorization scorer: one embedding table for users, one for items.

    The score of a (user, item) pair is the sum over dim 1 of the element-wise
    product of the two looked-up embeddings.

    NOTE(review): relies on `self.children()`, so BaseModel is presumably a
    torch `nn.Module` subclass -- confirm.
    '''

    def __init__(self, cfg, num_users, num_items):
        super().__init__()
        embed_size = cfg.embed_size
        self.user_embedding = nn.Embedding(num_users, embed_size, dtype=torch.float32)
        self.item_embedding = nn.Embedding(num_items, embed_size, dtype=torch.float32)
        self._init_weights()

    def _init_weights(self):
        '''
        Apply Xavier-uniform initialisation to every direct child embedding.
        '''
        embeddings = (layer for layer in self.children() if isinstance(layer, nn.Embedding))
        for embedding in embeddings:
            nn.init.xavier_uniform_(embedding.weight)

    def forward(self, user_id, item_id):
        '''
        Look up the embeddings for `user_id` and `item_id`, multiply them
        element-wise and reduce with a sum over dim 1.
        '''
        interaction = self.user_embedding(user_id) * self.item_embedding(item_id)
        return interaction.sum(dim=1)
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.