Skip to content

Commit

Permalink
feat: implements S3Rec preprocess method #21
Browse files Browse the repository at this point in the history
  • Loading branch information
GangBean committed Jul 15, 2024
1 parent f701e91 commit f3f13ee
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
4 changes: 3 additions & 1 deletion configs/train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ weight_decay: 0 #1e-5
best_metric: loss # loss, precision, recall, map, ndcg

# model config
model_name: NGCF
model_name: S3Rec
model:
CDAE:
negative_sampling: True # False
Expand All @@ -44,3 +44,5 @@ model:
NGCF:
embed_size: 64
num_orders: 3
S3Rec:
max_seq_len: 50
40 changes: 40 additions & 0 deletions data/datasets/s3rec_data_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
from loguru import logger
import pandas as pd
from .data_pipeline import DataPipeline

class S3RecDataPipeline(DataPipeline):
def __init__(self, cfg):
super().__init__(cfg)
self.num_users = None
self.num_items = None

def split(self):
return None

def preprocess(self) -> pd.DataFrame:
'''
output: pivot table (row: user, col: user-specific vector + item set, values: binary preference)
'''
logger.info("start preprocessing...")
# load df
df = self._load_df()
# set num items and num users
self._set_num_items_and_num_users(df)

# user 별로 item sequence 뽑아야돼
# train_pos_df = train_df.groupby('user_id').agg({'business_id': [('pos_items', list)]}).droplevel(0, 1)
df = df.groupby(['user_id']).agg({'business_id': [('behaviors', list)]}).droplevel(0, 1)
# logger.info(f"after groupby: {df.head()}")

logger.info("done")
return df

def _load_df(self):
logger.info("load df...")
return pd.read_csv(os.path.join(self.cfg.data_dir, 'yelp_interactions.tsv'), sep='\t', index_col=False)


def _set_num_items_and_num_users(self, df):
self.num_items = df.business_id.nunique()
self.num_users = df.user_id.nunique()
3 changes: 3 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from data.datasets.ngcf_data_pipeline import NGCFDataPipeline
from data.datasets.cdae_dataset import CDAEDataset
from data.datasets.mf_dataset import MFDataset
from data.datasets.s3rec_data_pipeline import S3RecDataPipeline
from trainers.cdae_trainer import CDAETrainer
from trainers.dcn_trainer import DCNTrainer
from trainers.mf_trainer import MFTrainer
Expand Down Expand Up @@ -122,6 +123,8 @@ def main(cfg: OmegaConf):
data_pipeline = DCNDatapipeline(cfg)
elif cfg.model_name == 'NGCF':
data_pipeline = NGCFDataPipeline(cfg)
elif cfg.model_name == 'S3Rec':
data_pipeline = S3RecDataPipeline(cfg)
else:
raise ValueError()

Expand Down

0 comments on commit f3f13ee

Please sign in to comment.