-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into feat/9-mlflow
- Loading branch information
Showing
13 changed files
with
501 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,59 @@ | ||
import numpy as np | ||
|
||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
from loguru import logger | ||
|
||
class CDAEDataset(Dataset): | ||
|
||
def __init__(self, data, mode='train'): | ||
def __init__(self, data, mode='train', neg_times: int=5): | ||
super().__init__() | ||
self.data = data | ||
self.mode = mode | ||
if self.mode != 'test': | ||
self.neg_times = neg_times | ||
|
||
def __len__(self): | ||
return len(self.data.keys()) | ||
|
||
def _negative_sampling(self, input_mask): | ||
# Calculate the number of positive samples. | ||
num_pos = int(input_mask.sum()) | ||
# Flip zeros and ones to generate candidates for negative sampling. | ||
flipped_mask = 1-input_mask | ||
# Retrieve indexes of the negative candidates. | ||
negative_indexes = flipped_mask.nonzero()[0] | ||
# Sample from negative indexes, selecting multiple times the number of positive samples. | ||
negative_samples = np.random.choice(negative_indexes, num_pos*self.neg_times, replace=False) | ||
# Create a negative mask of the same shape as input_mask | ||
negative_mask = np.zeros_like(input_mask) | ||
# Set sampled indexes to 1 in the negative mask | ||
# Only the masked indexes need to be computed for the loss | ||
negative_mask[negative_samples] = 1. | ||
return negative_mask | ||
|
||
def __getitem__(self, user_id): | ||
input_mask = self.data[user_id]['input_mask'].astype('float32') | ||
if self.mode == 'train': | ||
return { | ||
'user_id': user_id, | ||
'input_mask': self.data[user_id]['input_mask'].astype('float32'), | ||
'input_mask': input_mask, | ||
'negative_mask': self._negative_sampling(input_mask) | ||
} | ||
elif self.mode == 'valid': | ||
valid_mask = self.data[user_id]['valid_mask'].astype('float32') | ||
return { | ||
'user_id': user_id, | ||
'input_mask': self.data[user_id]['input_mask'].astype('float32'), | ||
'valid_mask': self.data[user_id]['valid_mask'].astype('float32'), | ||
'input_mask': input_mask, | ||
'valid_mask': valid_mask, | ||
'negative_mask': self._negative_sampling(input_mask + valid_mask) | ||
} | ||
else: | ||
test_mask = self.data[user_id]['test_mask'].astype('float32') | ||
return { | ||
'user_id': user_id, | ||
'input_mask': self.data[user_id]['input_mask'].astype('float32'), | ||
'test_mask': self.data[user_id]['test_mask'].astype('float32') | ||
'input_mask': input_mask, | ||
'test_mask': test_mask, | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import os | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from tqdm import tqdm | ||
from sklearn.model_selection import train_test_split | ||
from loguru import logger | ||
|
||
from .data_pipeline import DataPipeline | ||
|
||
class MFDataPipeline(DataPipeline): | ||
|
||
def __init__(self, cfg): | ||
super().__init__(cfg) | ||
self.num_items = None | ||
self.num_users = None | ||
|
||
def split(self, df): | ||
''' | ||
data: ((user_id, item_id, rating), ...) | ||
''' | ||
logger.info(f'start random user split...') | ||
train_df, valid_df, test_df = [], [], [] | ||
|
||
for _, user_df in df.groupby('user_id'): | ||
if self.cfg.loss_name == 'pointwise': | ||
user_train_df, user_test_df = train_test_split(user_df, test_size=.2, stratify=user_df['rating']) | ||
user_train_df, user_valid_df = train_test_split(user_train_df, test_size=.25, stratify=user_train_df['rating']) | ||
else: | ||
user_train_df, user_test_df = train_test_split(user_df, test_size=.2) | ||
user_train_df, user_valid_df = train_test_split(user_train_df, test_size=.25) | ||
train_df.append(user_train_df) | ||
valid_df.append(user_valid_df) | ||
test_df.append(user_test_df) | ||
|
||
train_df = pd.concat(train_df).reset_index() | ||
valid_df = pd.concat(valid_df).reset_index() | ||
test_df = pd.concat(test_df).reset_index() | ||
|
||
train_pos_df = train_df.groupby('user_id').agg({'business_id': [('pos_items', list)]}).droplevel(0, 1) | ||
valid_pos_df = valid_df.groupby('user_id').agg({'business_id': [('pos_items', list)]}).droplevel(0, 1) | ||
train_valid_pos_df = pd.concat([train_df, valid_df], axis=0).groupby('user_id').agg({'business_id': [('pos_items', list)]}).droplevel(0, 1) | ||
test_pos_df = test_df.groupby('user_id').agg({'business_id': [('pos_items', list)]}).droplevel(0, 1) | ||
|
||
train_data = pd.merge(train_df, train_pos_df, left_on='user_id', right_on='user_id', how='left') | ||
valid_data = pd.merge(valid_df, train_valid_pos_df, left_on='user_id', right_on='user_id', how='left') | ||
valid_eval_data = pd.merge(valid_pos_df, train_pos_df.rename(columns={'pos_items': 'mask_items'}), left_on='user_id', right_on='user_id', how='left') | ||
test_eval_data = pd.merge(test_pos_df, train_valid_pos_df.rename(columns={'pos_items': 'mask_items'}), left_on='user_id', right_on='user_id', how='left') | ||
|
||
return train_data, valid_data, valid_eval_data, test_eval_data | ||
|
||
def preprocess(self) -> pd.DataFrame: | ||
''' | ||
output: pivot table (row: user, col: user-specific vector + item set, values: binary preference) | ||
''' | ||
logger.info("start preprocessing...") | ||
# load df | ||
df = self._load_df() | ||
# set num items and num users | ||
self._set_num_items_and_num_users(df) | ||
# negative sampling | ||
if self.cfg.loss_name == 'pointwise': | ||
df = self._negative_sampling(df, self.cfg.neg_times) | ||
logger.info("done") | ||
return df | ||
|
||
def _load_df(self): | ||
logger.info("load df...") | ||
return pd.read_csv(os.path.join(self.cfg.data_dir, 'yelp_interactions.tsv'), sep='\t', index_col=False) | ||
|
||
def _set_num_items_and_num_users(self, df): | ||
self.num_items = df.business_id.nunique() | ||
self.num_users = df.user_id.nunique() | ||
|
||
def _negative_sampling(self, df: pd.DataFrame, neg_times: 5) -> pd.DataFrame: | ||
logger.info(f"negative sampling...") | ||
logger.info(f"before neg sampling: {df.shape}") | ||
all_items = df.business_id.unique() | ||
|
||
df['rating'] = 1 | ||
neg_data = [] | ||
for _, user_df in df.groupby('user_id'): | ||
user_id = user_df.user_id.values[0] | ||
pos_items = user_df.business_id.unique() | ||
neg_items = [] | ||
while len(neg_items) < len(pos_items)*neg_times: | ||
neg_item = np.random.choice(all_items) | ||
if (neg_item in pos_items) or (neg_item in neg_items): continue | ||
neg_items.append(neg_item) | ||
neg_data.extend([[user_id, neg_item, 0] for neg_item in neg_items]) | ||
|
||
df = pd.concat([df, pd.DataFrame(neg_data, columns=df.columns)], axis=0) | ||
df = df.sample(frac=1).reset_index(drop=True) | ||
logger.info(f"after neg sampling: {df.shape}") | ||
logger.info(f"done...") | ||
return df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import numpy as np | ||
|
||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
from loguru import logger | ||
|
||
class MFDataset(Dataset): | ||
|
||
def __init__(self, data, num_items=None): | ||
super().__init__() | ||
self.data = data | ||
self.num_items = num_items | ||
|
||
def __len__(self): | ||
return self.data.shape[0] | ||
|
||
def _negative_sampling(self, user_positives): | ||
neg_item = np.random.randint(self.num_items) | ||
while neg_item in user_positives: | ||
neg_item = np.random.randint(self.num_items) | ||
return neg_item | ||
|
||
def __getitem__(self, index): | ||
data = self.data.iloc[index,:] | ||
pos_item = data['business_id'].astype('int64') | ||
user_pos_items = data['pos_items'] | ||
return { | ||
'user_id': data['user_id'].astype('int64'), | ||
'pos_item': pos_item, | ||
'neg_item': self._negative_sampling(user_pos_items) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
import torch.nn as nn | ||
from torch import Tensor | ||
|
||
class NSBCELoss(nn.BCELoss): | ||
|
||
def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None: | ||
super().__init__(weight, size_average, reduce, reduction) | ||
|
||
def forward(self, input: Tensor, target: Tensor, negative_mask: Tensor) -> Tensor: | ||
# make loss masking adding negative_mask to target and find nonzero indices | ||
loss_targets = (target.add(negative_mask)).nonzero(as_tuple=True) | ||
# compute loss only for nonzero indices | ||
return nn.functional.binary_cross_entropy(input[loss_targets], target[loss_targets], weight=self.weight, reduction=self.reduction) | ||
|
||
|
||
class BPRLoss(nn.Module): | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.logsigmoid = nn.LogSigmoid() | ||
|
||
def forward(self, positive_preds, negative_preds): | ||
difference = positive_preds - negative_preds | ||
return torch.mean(-self.logsigmoid(difference)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
from models.base_model import BaseModel | ||
|
||
from loguru import logger | ||
class MatrixFactorization(BaseModel): | ||
|
||
def __init__(self, cfg, num_users, num_items): | ||
super().__init__() | ||
self.user_embedding = nn.Embedding(num_users, cfg.embed_size, dtype=torch.float32) | ||
self.item_embedding = nn.Embedding(num_items, cfg.embed_size, dtype=torch.float32) | ||
self._init_weights() | ||
|
||
def _init_weights(self): | ||
for child in self.children(): | ||
if isinstance(child, nn.Embedding): | ||
nn.init.xavier_uniform_(child.weight) | ||
|
||
def forward(self, user_id, item_id): | ||
user_emb = self.user_embedding(user_id) | ||
item_emb = self.item_embedding(item_id) | ||
return torch.sum(user_emb * item_emb, dim=1) |
Oops, something went wrong.