diff --git a/configs/train_config.yaml b/configs/train_config.yaml
index 4509527..70cb9b7 100644
--- a/configs/train_config.yaml
+++ b/configs/train_config.yaml
@@ -14,7 +14,7 @@ notes: "..."
 tags: [sweep, yelp, cdae, hyper-parameter, model-structure]
 
 # train config
-device: cuda # cpu
+device: cpu
 epochs: 100
 batch_size: 32
 lr: 0.0001
@@ -45,4 +45,7 @@ model:
     embed_size: 64
     num_orders: 3
   S3Rec:
+    embed_size: 64
     max_seq_len: 50
+    num_heads: 2
+    num_blocks: 2
diff --git a/data/datasets/s3rec_data_pipeline.py b/data/datasets/s3rec_data_pipeline.py
index ecf5547..938eaca 100644
--- a/data/datasets/s3rec_data_pipeline.py
+++ b/data/datasets/s3rec_data_pipeline.py
@@ -74,9 +74,9 @@ def _load_df(self):
 
     def _load_attributes(self):
         logger.info("load item2attributes...")
         df = pd.read_json(os.path.join(self.cfg.data_dir, 'yelp_item2attributes.json')).transpose()
-        self.attributes_count = [df.categories.explode().nunique(), df.statecity.nunique()]
+        self.attributes_count = df.categories.explode().nunique()
 
-        return df.transpose().to_dict()
+        return df.drop(columns=['statecity']).transpose().to_dict()
 
     def _set_num_items_and_num_users(self, df):
diff --git a/data/datasets/s3rec_dataset.py b/data/datasets/s3rec_dataset.py
index cbd575a..8360b7d 100644
--- a/data/datasets/s3rec_dataset.py
+++ b/data/datasets/s3rec_dataset.py
@@ -32,14 +32,14 @@ def __getitem__(self, user_id):
         if self.train:
             return {
                 'user_id': user_id,
-                'X': data['X'],
+                'X': np.array(data['X'], dtype='int64'),
                 'pos_item': pos_item,
                 'neg_item': self._negative_sampling(data['behaviors'])[0]
             }
         else:
             return {
                 'user_id': user_id,
-                'X': data['X'],
+                'X': np.array(data['X'], dtype='int64'),
                 'pos_item': pos_item,
                 'neg_items': self._negative_sampling(data['behaviors'])
             }
diff --git a/train.py b/train.py
index ed235ae..5210259 100644
--- a/train.py
+++ b/train.py
@@ -22,6 +22,7 @@ from trainers.cdae_trainer import CDAETrainer
 from trainers.dcn_trainer import DCNTrainer
 from trainers.mf_trainer import MFTrainer
+from trainers.s3rec_trainer import S3RecTrainer
 
 from utils import set_seed
 
@@ -98,6 +99,12 @@ def train(cfg, args):#train_dataset, valid_dataset, test_dataset, model_info):
         trainer.run(train_dataloader, valid_dataloader, args.valid_eval_data)
         trainer.load_best_model()
         trainer.evaluate(args.test_eval_data, 'test')
+    elif cfg.model_name in ('S3Rec',):
+        trainer = S3RecTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'],
+                               args.data_pipeline.item2attributes, args.data_pipeline.attributes_count)
+        trainer.run(train_dataloader, valid_dataloader)
+        trainer.load_best_model()
+        trainer.evaluate(test_dataloader)
 
 def unpack_model(cfg: OmegaConf) -> OmegaConf:
     if cfg.model_name not in cfg.model:
diff --git a/trainers/s3rec_trainer.py b/trainers/s3rec_trainer.py
index d9723f0..daff18e 100644
--- a/trainers/s3rec_trainer.py
+++ b/trainers/s3rec_trainer.py
@@ -12,20 +12,33 @@ import wandb
 
 from models.cdae import CDAE
+from models.s3rec import S3Rec
 from utils import log_metric
 from .base_trainer import BaseTrainer
 from metric import *
 from loss import BPRLoss
 
-class CDAETrainer(BaseTrainer):
-    def __init__(self, cfg: DictConfig, num_items: int, num_users: int) -> None:
+class S3RecTrainer(BaseTrainer):
+    def __init__(self, cfg: DictConfig, num_items: int, num_users: int, item2attributes, attributes_count: int) -> None:
         super().__init__(cfg)
-        self.model = CDAE(self.cfg, num_items, num_users) ##
+        self.model = S3Rec(self.cfg, num_items, num_users, attributes_count)
         self.optimizer: Optimizer = self._optimizer(self.cfg.optimizer,
                                                     self.model, self.cfg.lr)
         self.loss = self._loss()
 
     def _loss(self):
         return BPRLoss()
+
+    def _is_surpass_best_metric(self, **metric) -> bool:
+        (valid_loss,
+        ) = metric['current']
+
+        (best_valid_loss,
+        ) = metric['best']
+
+        if self.cfg.best_metric == 'loss':
+            return valid_loss < best_valid_loss
+        else:
+            return False
 
     def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader):
         logger.info(f"[Trainer] run...")
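
A note on the _is_surpass_best_metric override added in trainers/s3rec_trainer.py: it unpacks metric['current'] and metric['best'] as single-element tuples and only handles cfg.best_metric == 'loss'; any other setting makes it return False. The snippet below is a minimal, standalone sketch of that calling convention under those assumptions; the free-function name and the sample loss values are illustrative only and do not come from the repository.

# Hypothetical, standalone restatement of the comparison done by
# S3RecTrainer._is_surpass_best_metric. It assumes best_metric == 'loss' and
# that the caller passes metrics as single-element tuples, mirroring the
# (valid_loss,) unpacking in the diff above.
def is_surpass_best_metric(best_metric: str = 'loss', **metric) -> bool:
    (valid_loss,) = metric['current']
    (best_valid_loss,) = metric['best']
    if best_metric == 'loss':
        return valid_loss < best_valid_loss
    return False

# Lower validation loss should win, so this prints True.
print(is_surpass_best_metric(current=(0.41,), best=(0.47,)))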