Skip to content

Commit

Permalink
feat: implements s3rec model #21
Browse files Browse the repository at this point in the history
  • Loading branch information
GangBean committed Jul 19, 2024
1 parent a4b2721 commit 5c222dc
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 8 deletions.
5 changes: 4 additions & 1 deletion configs/train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ notes: "..."
tags: [sweep, yelp, cdae, hyper-parameter, model-structure]

# train config
device: cuda # cpu
device: cpu
epochs: 100
batch_size: 32
lr: 0.0001
Expand Down Expand Up @@ -45,4 +45,7 @@ model:
embed_size: 64
num_orders: 3
S3Rec:
embed_size: 64
max_seq_len: 50
num_heads: 2
num_blocks: 2
4 changes: 2 additions & 2 deletions data/datasets/s3rec_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def _load_df(self):
def _load_attributes(self):
logger.info("load item2attributes...")
df = pd.read_json(os.path.join(self.cfg.data_dir, 'yelp_item2attributes.json')).transpose()
self.attributes_count = [df.categories.explode().nunique(), df.statecity.nunique()]
self.attributes_count = df.categories.explode().nunique()

return df.transpose().to_dict()
return df.drop(columns=['statecity']).transpose().to_dict()


def _set_num_items_and_num_users(self, df):
Expand Down
4 changes: 2 additions & 2 deletions data/datasets/s3rec_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ def __getitem__(self, user_id):
if self.train:
return {
'user_id': user_id,
'X': data['X'],
'X': np.array(data['X'], dtype='int64'),
'pos_item': pos_item,
'neg_item': self._negative_sampling(data['behaviors'])[0]
}
else:
return {
'user_id': user_id,
'X': data['X'],
'X': np.array(data['X'], dtype='int64'),
'pos_item': pos_item,
'neg_items': self._negative_sampling(data['behaviors'])
}
7 changes: 7 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from trainers.cdae_trainer import CDAETrainer
from trainers.dcn_trainer import DCNTrainer
from trainers.mf_trainer import MFTrainer
from trainers.s3rec_trainer import S3RecTrainer
from utils import set_seed


Expand Down Expand Up @@ -98,6 +99,12 @@ def train(cfg, args):#train_dataset, valid_dataset, test_dataset, model_info):
trainer.run(train_dataloader, valid_dataloader, args.valid_eval_data)
trainer.load_best_model()
trainer.evaluate(args.test_eval_data, 'test')
elif cfg.model_name in ('S3Rec',):
trainer = S3RecTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'],
args.data_pipeline.item2attributes, args.data_pipeline.attributes_count)
trainer.run(train_dataloader, valid_dataloader)
trainer.load_best_model()
trainer.evaluate(test_dataloader)

def unpack_model(cfg: OmegaConf) -> OmegaConf:
if cfg.model_name not in cfg.model:
Expand Down
19 changes: 16 additions & 3 deletions trainers/s3rec_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,33 @@
import wandb

from models.cdae import CDAE
from models.s3rec import S3Rec
from utils import log_metric
from .base_trainer import BaseTrainer
from metric import *
from loss import BPRLoss

class CDAETrainer(BaseTrainer):
def __init__(self, cfg: DictConfig, num_items: int, num_users: int) -> None:
class S3RecTrainer(BaseTrainer):
def __init__(self, cfg: DictConfig, num_items: int, num_users: int, item2attributes, attributes_count: int) -> None:
super().__init__(cfg)
self.model = CDAE(self.cfg, num_items, num_users) ##
self.model = S3Rec(self.cfg, num_items, num_users, attributes_count)
self.optimizer: Optimizer = self._optimizer(self.cfg.optimizer, self.model, self.cfg.lr)
self.loss = self._loss()

def _loss(self):
return BPRLoss()

def _is_surpass_best_metric(self, **metric) -> bool:
(valid_loss,
) = metric['current']

(best_valid_loss,
) = metric['best']

if self.cfg.best_metric == 'loss':
return valid_loss < best_valid_loss
else:
return False

def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader):
logger.info(f"[Trainer] run...")
Expand Down

0 comments on commit 5c222dc

Please sign in to comment.