diff --git a/configs/train_config.yaml b/configs/train_config.yaml
index 70cb9b7..e0ad81d 100644
--- a/configs/train_config.yaml
+++ b/configs/train_config.yaml
@@ -14,8 +14,8 @@ notes: "..."
 tags: [sweep, yelp, cdae, hyper-parameter, model-structure]
 
 # train config
-device: cpu
-epochs: 100
+device: cuda # cpu
+epochs: 5
 batch_size: 32
 lr: 0.0001
 optimizer: adam # adamw
@@ -49,3 +49,7 @@ model:
     max_seq_len: 50
     num_heads: 2
     num_blocks: 2
+    pretrain: True # False
+    pretrain_epochs: 1 # 100
+    mask_portion: 0.2
+    iter_nums: 200
\ No newline at end of file
diff --git a/models/s3rec.py b/models/s3rec.py
index 45f1b0a..fadf918 100644
--- a/models/s3rec.py
+++ b/models/s3rec.py
@@ -11,7 +11,7 @@ def __init__(self, cfg, num_users, num_items, attributes_count):
         self.cfg = cfg
         # self.user_embedding = nn.Embedding(num_users, cfg.embed_size, dtype=torch.float32)
         self.item_embedding = nn.Embedding(num_items + 1, self.cfg.embed_size, dtype=torch.float32)
-        # self.attribute_embedding = nn.Embedding(attributes_count, self.cfg.embed_size, dtype=torch.float32)
+        self.attribute_embedding = nn.Embedding(attributes_count, self.cfg.embed_size, dtype=torch.float32)
         self.positional_encoding = nn.Parameter(torch.rand(self.cfg.max_seq_len, self.cfg.embed_size))
         # self.query = nn.ModuleList([nn.Linear(self.cfg.embed_size / self.num_heads) for _ in range(self.cfg.num_heads)])
@@ -20,8 +20,11 @@ def __init__(self, cfg, num_users, num_items, attributes_count):
         self.ffn1s = nn.ModuleList([nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
         self.ffn2s = nn.ModuleList([nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
         self.multihead_attns = nn.ModuleList([nn.MultiheadAttention(self.cfg.embed_size, self.cfg.num_heads) for _ in range(self.cfg.num_blocks)])
+        self.aap_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False)
+
         self._init_weights()
+
     def _init_weights(self):
         for child in self.children():
             if isinstance(child, nn.Embedding):
@@ -30,6 +33,8 @@ def _init_weights(self):
                 for sub_child in child.children():
                     if not isinstance(sub_child, nn.MultiheadAttention):
                         nn.init.xavier_uniform_(sub_child.weight)
+            elif isinstance(child, nn.Linear):
+                nn.init.xavier_uniform_(child.weight)
             else:
                 logger.info(f"other type: {child} / {type(child)}")
@@ -60,3 +65,8 @@ def evaluate(self, X, pos_item, neg_items):
                      self.item_embedding(neg_items[:,i]), X[:, -1]).view(neg_items.size(0), -1) for i in range(neg_items.size(-1))]
         neg_preds = torch.concat(neg_preds, dim=1)
         return pos_pred, neg_preds
+
+    def aap(self, items):
+        # items: (batch,) tensor of item ids
+        item_embeddings = self.item_embedding(items)
+        return torch.matmul(self.aap_weight(item_embeddings), self.attribute_embedding.weight.T) # (batch, embed_size) x (embed_size, attributes_count) -> (batch, attributes_count)
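For reference, a minimal self-contained sketch of the shape flow through the new aap head; the sizes are toy values for illustration, not the repo's config, and the multi-hot target here is random:

import torch
import torch.nn as nn

embed_size, num_items, attributes_count, batch = 8, 100, 15, 32
item_embedding = nn.Embedding(num_items + 1, embed_size)       # index 0 reserved for padding
attribute_embedding = nn.Embedding(attributes_count, embed_size)
aap_weight = nn.Linear(embed_size, embed_size, bias=False)

items = torch.randint(1, num_items + 1, (batch,))              # (batch,) item ids
logits = aap_weight(item_embedding(items)) @ attribute_embedding.weight.T  # (batch, attributes_count)
target = torch.randint(0, 2, (batch, attributes_count)).float()            # multi-hot attribute labels
loss = nn.functional.binary_cross_entropy_with_logits(logits, target)
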
diff --git a/train.py b/train.py
index 6c4c264..e1c101b 100644
--- a/train.py
+++ b/train.py
@@ -22,7 +22,7 @@
 from trainers.cdae_trainer import CDAETrainer
 from trainers.dcn_trainer import DCNTrainer
 from trainers.mf_trainer import MFTrainer
-from trainers.s3rec_trainer import S3RecTrainer
+from trainers.s3rec_trainer import S3RecTrainer, S3RecPreTrainer
 from utils import set_seed
 
@@ -100,11 +100,17 @@ def train(cfg, args):#train_dataset, valid_dataset, test_dataset, model_info):
         trainer.load_best_model()
         trainer.evaluate(args.test_eval_data, 'test')
     elif cfg.model_name in ('S3Rec',):
-        trainer = S3RecTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'],
-                               args.data_pipeline.item2attributes, args.data_pipeline.attributes_count)
-        trainer.run(train_dataloader, valid_dataloader)
-        trainer.load_best_model()
-        trainer.evaluate(test_dataloader)
+        if cfg.pretrain:
+            trainer = S3RecPreTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'],
+                                      args.data_pipeline.item2attributes, args.data_pipeline.attributes_count)
+            trainer.pretrain(args.train_dataset, args.valid_dataset)
+            trainer.load_best_model()
+        else:
+            trainer = S3RecTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'],
+                                   args.data_pipeline.item2attributes, args.data_pipeline.attributes_count)
+            trainer.run(train_dataloader, valid_dataloader)
+            trainer.load_best_model()
+            trainer.evaluate(test_dataloader)
 
 def unpack_model(cfg: OmegaConf) -> OmegaConf:
     if cfg.model_name not in cfg.model:
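Note that in the cfg.pretrain branch, load_best_model restores the weights into the pre-trainer's own model and the run ends there. A hypothetical handoff into fine-tuning, not part of this diff, could look like the sketch below; it assumes S3RecTrainer exposes its network as .model, the way S3RecPreTrainer does:

import torch

# Hypothetical glue code: reuse the pre-trained weights for fine-tuning.
# strict=False tolerates pretrain-only parameters (e.g. aap_weight) that a
# fine-tuning head might not declare.
state = torch.load(f'{cfg.model_dir}/best_pretrain_model.pt')
finetune_trainer = S3RecTrainer(cfg, num_items, num_users, item2attributes, attributes_count)
finetune_trainer.model.load_state_dict(state, strict=False)
finetune_trainer.run(train_dataloader, valid_dataloader)
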
logger.info(f"[Trainer] update best model...") + best_valid_loss = valid_loss + best_epoch = epoch + endurance = 0 + + torch.save(self.model.state_dict(), f'{self.cfg.model_dir}/best_pretrain_model.pt') + else: + endurance += 1 + if endurance > self.cfg.patience: + logger.info(f"[Trainer] ealry stopping...") + break + + def train(self, item_datasets, sequence_datasets) -> float: + self.model.train() + train_loss = 0 + + for iter_num in tqdm(range(self.cfg.iter_nums)): # sequence + item_chunk_size = self.num_items // self.cfg.iter_nums + items = item_datasets[item_chunk_size * iter_num: item_chunk_size * (iter_num + 1)] + + sequence_chunk_size = self.num_users // self.cfg.iter_nums + # sequences = sequence_datasets[sequence_chunk_size * iter_num: sequence_chunk_size * (iter_num + 1)] + + # AAP: item + atrributes + pred = self.model.aap(items) # (item_chunk_size, attributes_count) + actual = torch.Tensor([[1 if attriute in self.item2attribute[item.item()] else 0 for attriute in range(self.attributes_count)] for item in items]).to(self.device) # (item_chunk_size, attributes_count) + aap_loss = nn.functional.binary_cross_entropy_with_logits(pred, actual) + + # MIP: sequence + item + # mask + # def random_mask(sequence): + # # mask = torch.Tensor([0] * sequence.size(0)) + # non_zero_count = torch.nonzero(sequence, as_tuple=True)[0].size(0) + # mask_indices = torch.randint(sequence.size(0) - non_zero_count, sequence.size(0), size=1) + # # mask[mask_indices] = 1 + # return mask_indices + + # masks = torch.Tensor([random_mask(sequence) for sequence in sequences]) # () + # masked_sequences = sequences * (1 - masks) + # pred = self.model.mip(masked_sequences, ) # (sequence_chunk_size, mask_count, sequence_len) item idx pred + # nn.functional.binary_cross_entropy + # # MAP: sequence + attributes + # map_loss = self.loss() + # # SP: sequence + segment + # sp_loss = self.loss() + # # X, pos_item, neg_item = data['X'].to(self.device), data['pos_item'].to(self.device), data['neg_item'].to(self.device) + # # pos_pred, neg_pred = self.model(X, pos_item, neg_item) + + self.optimizer.zero_grad() + # loss = self.loss(pos_pred, neg_pred) + loss = aap_loss # + mip_loss + map_loss + sp_loss + loss.backward() + self.optimizer.step() + + train_loss += loss.item() + + return train_loss + + def validate(self, item_datasets, sequence_datasets) -> float: + self.model.eval() + valid_loss = 0 + + for iter_num in tqdm(range(self.cfg.iter_nums)): # sequence + item_chunk_size = self.num_items // self.cfg.iter_nums + items = item_datasets[item_chunk_size * iter_num: item_chunk_size * (iter_num + 1)] + + sequence_chunk_size = self.num_users // self.cfg.iter_nums + # sequences = sequence_datasets[sequence_chunk_size * iter_num: sequence_chunk_size * (iter_num + 1)] + + # AAP: item + atrributes + pred = self.model.aap(items) # (item_chunk_size, attributes_count) + actual = torch.Tensor([[1 if attriute in self.item2attribute[item.item()] else 0 for attriute in range(self.attributes_count)] for item in items]).to(self.device) # (item_chunk_size, attributes_count) + aap_loss = nn.functional.binary_cross_entropy_with_logits(pred, actual) + + # MIP: sequence + item + # mask + # def random_mask(sequence): + # # mask = torch.Tensor([0] * sequence.size(0)) + # non_zero_count = torch.nonzero(sequence, as_tuple=True)[0].size(0) + # mask_indices = torch.randint(sequence.size(0) - non_zero_count, sequence.size(0), size=1) + # # mask[mask_indices] = 1 + # return mask_indices + + # masks = torch.Tensor([random_mask(sequence) for sequence in 
+    def validate(self, item_datasets, sequence_datasets) -> float:
+        self.model.eval()
+        valid_loss = 0
+
+        for iter_num in tqdm(range(self.cfg.iter_nums)): # chunked pass over items (sequence tasks are stubbed below)
+            item_chunk_size = self.num_items // self.cfg.iter_nums
+            items = item_datasets[item_chunk_size * iter_num: item_chunk_size * (iter_num + 1)]
+
+            sequence_chunk_size = self.num_users // self.cfg.iter_nums
+            # sequences = sequence_datasets[sequence_chunk_size * iter_num: sequence_chunk_size * (iter_num + 1)]
+
+            # AAP: item + attributes
+            pred = self.model.aap(items) # (item_chunk_size, attributes_count)
+            actual = torch.Tensor([[1 if attribute in self.item2attribute[item.item()] else 0 for attribute in range(self.attributes_count)] for item in items]).to(self.device) # (item_chunk_size, attributes_count)
+            aap_loss = nn.functional.binary_cross_entropy_with_logits(pred, actual)
+
+            # MIP: sequence + item
+            # mask
+            # def random_mask(sequence):
+            #     # mask = torch.Tensor([0] * sequence.size(0))
+            #     non_zero_count = torch.nonzero(sequence, as_tuple=True)[0].size(0)
+            #     mask_indices = torch.randint(sequence.size(0) - non_zero_count, sequence.size(0), size=1)
+            #     # mask[mask_indices] = 1
+            #     return mask_indices
+
+            # masks = torch.Tensor([random_mask(sequence) for sequence in sequences]) # ()
+            # masked_sequences = sequences * (1 - masks)
+            # pred = self.model.mip(masked_sequences, ) # (sequence_chunk_size, sequence_len) item idx pred
+            # nn.functional.binary_cross_entropy
+            # # MAP: sequence + attributes
+            # map_loss = self.loss()
+            # # SP: sequence + segment
+            # sp_loss = self.loss()
+            # # X, pos_item, neg_item = data['X'].to(self.device), data['pos_item'].to(self.device), data['neg_item'].to(self.device)
+            # # pos_pred, neg_pred = self.model(X, pos_item, neg_item)
+
+            # loss = self.loss(pos_pred, neg_pred)
+            loss = aap_loss # + mip_loss + map_loss + sp_loss
+
+            valid_loss += loss.item()
+
+        return valid_loss
+
+    def load_best_model(self):
+        logger.info(f"[Trainer] Load best model...")
+        self.model.load_state_dict(torch.load(f'{self.cfg.model_dir}/best_pretrain_model.pt'))
+
 class S3RecTrainer(BaseTrainer):
     def __init__(self, cfg: DictConfig, num_items: int, num_users: int, item2attributes, attributes_count: int) -> None:
         super().__init__(cfg)
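The commented-out random_mask in train and validate would not run as written: torch.randint takes its size as a tuple, so size=1 should be (1,). A working sketch, assuming sequences are left-padded with zeros so the valid items sit at the tail:

import torch

def random_mask(sequence: torch.Tensor) -> torch.Tensor:
    # Choose one position to mask among the non-padding tail of the sequence.
    non_zero_count = int(torch.count_nonzero(sequence).item())
    if non_zero_count == 0:
        return torch.empty(0, dtype=torch.long)  # all-padding row: nothing to mask
    first_valid = sequence.size(0) - non_zero_count
    return torch.randint(first_valid, sequence.size(0), (1,))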