diff --git a/models/s3rec.py b/models/s3rec.py index fadf918..88056ea 100644 --- a/models/s3rec.py +++ b/models/s3rec.py @@ -9,22 +9,20 @@ class S3Rec(BaseModel): def __init__(self, cfg, num_users, num_items, attributes_count): super().__init__() self.cfg = cfg - # self.user_embedding = nn.Embedding(num_users, cfg.embed_size, dtype=torch.float32) self.item_embedding = nn.Embedding(num_items + 1, self.cfg.embed_size, dtype=torch.float32) self.attribute_embedding = nn.Embedding(attributes_count, self.cfg.embed_size, dtype=torch.float32) self.positional_encoding = nn.Parameter(torch.rand(self.cfg.max_seq_len, self.cfg.embed_size)) - # self.query = nn.ModuleList([nn.Linear(self.cfg.embed_size / self.num_heads) for _ in range(self.cfg.num_heads)]) - # self.key = nn.ModuleList([nn.Linear(self.cfg.embed_size) for _ in range(self.cfg.num_heads)]) - # self.value = nn.ModuleList([nn.Linear(self.cfg.embed_size) for _ in range(self.cfg.num_heads)]) self.ffn1s = nn.ModuleList([nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)]) self.ffn2s = nn.ModuleList([nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)]) self.multihead_attns = nn.ModuleList([nn.MultiheadAttention(self.cfg.embed_size, self.cfg.num_heads) for _ in range(self.cfg.num_blocks)]) self.aap_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False) + self.mip_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False) + self.map_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False) + self.sp_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False) self._init_weights() - def _init_weights(self): for child in self.children(): if isinstance(child, nn.Embedding): @@ -50,7 +48,7 @@ def _self_attention_block(self, X): def _prediction_layer(self, item, self_attn_output): return torch.einsum('bi,bi->b', (item, self_attn_output)) - def forward(self, X, pos_item, neg_item): + def 
finetune(self, X, pos_item, neg_item): X = self._embedding_layer(X) X = self._self_attention_block(X) pos_pred = self._prediction_layer(self.item_embedding(pos_item), X[:, -1]) @@ -65,8 +63,69 @@ def evaluate(self, X, pos_item, neg_items): self.item_embedding(neg_items[:,i]), X[:, -1]).view(neg_items.size(0), -1) for i in range(neg_items.size(-1))] neg_preds = torch.concat(neg_preds, dim=1) return pos_pred, neg_preds + + def encode(self, X): + return self._self_attention_block(self._embedding_layer(X)) - def aap(self, items): - # item - item_embeddings = self.item_embedding(items) - return torch.matmul(self.aap_weight(item_embeddings), self.attribute_embedding.weight.T) # (batch, embed_size) * (attribute_size, embed_size) (batch, attribute_size) + def pretrain(self, item_masked_sequences, subsequence_masked_sequences, pos_subsequences, neg_subsequences): + # encode + attention_output = self.encode(item_masked_sequences) + subsequence_attention_output = self.encode(subsequence_masked_sequences) + pos_subsequence_attention_output = self.encode(pos_subsequences) + neg_subsequence_attention_output = self.encode(neg_subsequences) + # aap + aap_output = self.aap(attention_output) # (B, L, A) + # mip + mip_output = self.mip(attention_output) + # map + map_output = self.map(attention_output) + # sp + sp_output_pos = self.sp(attention_output, pos_subsequence_attention_output) # pos 1 + sp_output_neg = self.sp(attention_output, neg_subsequence_attention_output) # neg 1 + return aap_output, mip_output, map_output, (sp_output_pos, sp_output_neg) + + def aap(self, attention_output): + ''' + inputs: + attention_output: [ B, L, H ] + output: + [ B, L, A ] + ''' + FW = self.aap_weight(attention_output) # [ B L H ] + return torch.matmul(FW, self.attribute_embedding.weight.T) # [ B L H ] [ H A ] -> [ B L A ] + + def mip(self, attention_output): + ''' + inputs: + attention_output: [ B, L, H ] + output: + ''' + FW = self.mip_weight(attention_output) # [ B L H ] + return 
torch.matmul(FW, self.item_embedding.weight.t()) # [ B L H ] [ H I ] -> [ B L I ] + + def map(self, attention_output): + ''' + inputs: + attention_output: [ B, L, H ] + output: + [ B, L, A ] + ''' + FW = self.map_weight(attention_output) # [ B L H ] + return torch.matmul(FW, self.attribute_embedding.weight.T) # [ B L H ] [ H A ] -> [ B L A ] + + def sp(self, context_attention_output, subsequence_attention_output): + ''' + inputs: + context_attention_output: [ B, L, H ] + subsequence_attention_output: [ B, len_subsequence, H ] + output: + [ B ] + + s - input [ i1, i2, mask, mask, mask, ..., in ] + s~ - input [ i3, i4, i5 ] + + ''' + s = context_attention_output[:, -1, :] # [ B H ] + s_tilde = subsequence_attention_output[:, -1, :] # [ B H ] + SW = self.sp_weight(s) + return torch.einsum('bi,bi->b', SW, s_tilde) # [ B ] diff --git a/train.py b/train.py index e1c101b..3637fe2 100644 --- a/train.py +++ b/train.py @@ -103,7 +103,7 @@ def train(cfg, args):#train_dataset, valid_dataset, test_dataset, model_info): if cfg.pretrain: trainer = S3RecPreTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'], args.data_pipeline.item2attributes, args.data_pipeline.attributes_count) - trainer.pretrain(args.train_dataset, args.valid_dataset) + trainer.pretrain(train_dataloader) trainer.load_best_model() else: trainer = S3RecTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'], diff --git a/trainers/s3rec_trainer.py b/trainers/s3rec_trainer.py index dd1d011..46ad4a6 100644 --- a/trainers/s3rec_trainer.py +++ b/trainers/s3rec_trainer.py @@ -56,7 +56,7 @@ def _is_surpass_best_metric(self, **metric) -> bool: else: return False - def pretrain(self, train_dataset, valid_dataset): + def pretrain(self, train_dataset): logger.info(f"[Trainer] run...") best_valid_loss: float = 1e+6 @@ -65,8 +65,8 @@ def pretrain(self, train_dataset, valid_dataset): # train for epoch in range(self.cfg.pretrain_epochs): - train_loss: float = self.train(torch.tensor([i for i 
in range(1, self.num_items+1)], dtype=torch.int32).to(self.device), train_dataset) - valid_loss = self.validate(torch.tensor([i for i in range(1, self.num_items+1)], dtype=torch.int32).to(self.device), valid_dataset) + train_loss: float = self.train(train_dataset) + valid_loss = self.validate(train_dataset) logger.info(f'''\n[Trainer] epoch: {epoch} > train loss: {train_loss:.4f} / valid loss: {valid_loss:.4f}''') @@ -92,42 +92,66 @@ def pretrain(self, train_dataset, valid_dataset): if endurance > self.cfg.patience: logger.info(f"[Trainer] ealry stopping...") break - - def train(self, item_datasets, sequence_datasets) -> float: + + def item_level_masking(self, sequences): + masks = torch.rand_like(sequences, dtype=torch.float32) < .2 + item_masked_sequences = ~masks * sequences + return masks, item_masked_sequences + + def segment_masking(self, sequences): + masks, pos_sequences, neg_sequences = torch.zeros_like(sequences), torch.zeros_like(sequences), torch.zeros_like(sequences) + for i in range(sequences.size(0)): + # sample segment length randomly + segment_len = torch.randint(low=2, high=self.cfg.max_seq_len//2, size=(1,)) + # start_index + start_idx = torch.randint(self.cfg.max_seq_len-segment_len, size=(1,)) + masks[i, start_idx:start_idx+segment_len] = 1 + # pos_sequence + pos_sequences[i, -segment_len:] = sequences[i, start_idx:start_idx+segment_len] + # neg_sequence + ## other user in same batch + neg_user_idx = torch.randint(sequences.size(0), size=(1,)) + while neg_user_idx == i: + neg_user_idx = torch.randint(sequences.size(0), size=(1,)) + ## start_idx + neg_start_idx = torch.randint(self.cfg.max_seq_len-segment_len, size=(1,)) + neg_sequences[i, -segment_len:] = sequences[neg_user_idx, neg_start_idx:neg_start_idx+segment_len] + segment_masked_sequences = (1-masks) * sequences + return segment_masked_sequences, pos_sequences, neg_sequences + + # def train(self, item_datasets, sequence_datasets) -> float: + def train(self, train_dataloader) -> float: 
self.model.train() train_loss = 0 - for iter_num in tqdm(range(self.cfg.iter_nums)): # sequence - item_chunk_size = self.num_items // self.cfg.iter_nums - items = item_datasets[item_chunk_size * iter_num: item_chunk_size * (iter_num + 1)] + for data in tqdm(train_dataloader): # sequence + sequences = data['X'].to(self.device) + # item_masked_sequences + masks, item_masked_sequences = self.item_level_masking(sequences) + # segment_masked_sequences + segment_masked_sequences, pos_segments, neg_segments = self.segment_masking(sequences) - sequence_chunk_size = self.num_users // self.cfg.iter_nums - # sequences = sequence_datasets[sequence_chunk_size * iter_num: sequence_chunk_size * (iter_num + 1)] + # pretrain + aap_output, mip_output, map_output, (sp_output_pos, sp_output_neg) = self.model.pretrain( + item_masked_sequences, segment_masked_sequences, pos_segments, neg_segments) # AAP: item + atrributes - pred = self.model.aap(items) # (item_chunk_size, attributes_count) - actual = torch.Tensor([[1 if attriute in self.item2attribute[item.item()] else 0 for attriute in range(self.attributes_count)] for item in items]).to(self.device) # (item_chunk_size, attributes_count) - aap_loss = nn.functional.binary_cross_entropy_with_logits(pred, actual) + aap_actual = torch.ones_like(aap_output).to(self.device) +# actual = torch.Tensor([ +# [1 if attriute in self.item2attribute[item.item()] else 0 \ + # for attriute in range(self.attributes_count)] for item in items] +# ).to(self.device) # (item_chunk_size, attributes_count) + ## compute unmasked area only + aap_loss = nn.functional.binary_cross_entropy_with_logits(aap_output, aap_actual) # MIP: sequence + item - # mask - # def random_mask(sequence): - # # mask = torch.Tensor([0] * sequence.size(0)) - # non_zero_count = torch.nonzero(sequence, as_tuple=True)[0].size(0) - # mask_indices = torch.randint(sequence.size(0) - non_zero_count, sequence.size(0), size=1) - # # mask[mask_indices] = 1 - # return mask_indices - - # masks = 
torch.Tensor([random_mask(sequence) for sequence in sequences]) # () - # masked_sequences = sequences * (1 - masks) - # pred = self.model.mip(masked_sequences, ) # (sequence_chunk_size, mask_count, sequence_len) item idx pred - # nn.functional.binary_cross_entropy - # # MAP: sequence + attributes - # map_loss = self.loss() - # # SP: sequence + segment - # sp_loss = self.loss() - # # X, pos_item, neg_item = data['X'].to(self.device), data['pos_item'].to(self.device), data['neg_item'].to(self.device) - # # pos_pred, neg_pred = self.model(X, pos_item, neg_item) + ## compute masked area only + + # MAP: sequence + attribute + ## compute masked area only + + # SP: sequence + segment + ## pos_segment > neg_segment self.optimizer.zero_grad() # loss = self.loss(pos_pred, neg_pred) @@ -139,45 +163,12 @@ def train(self, item_datasets, sequence_datasets) -> float: return train_loss - def validate(self, item_datasets, sequence_datasets) -> float: + def validate(self, sequence_datasets) -> float: self.model.eval() valid_loss = 0 for iter_num in tqdm(range(self.cfg.iter_nums)): # sequence - item_chunk_size = self.num_items // self.cfg.iter_nums - items = item_datasets[item_chunk_size * iter_num: item_chunk_size * (iter_num + 1)] - - sequence_chunk_size = self.num_users // self.cfg.iter_nums - # sequences = sequence_datasets[sequence_chunk_size * iter_num: sequence_chunk_size * (iter_num + 1)] - - # AAP: item + atrributes - pred = self.model.aap(items) # (item_chunk_size, attributes_count) - actual = torch.Tensor([[1 if attriute in self.item2attribute[item.item()] else 0 for attriute in range(self.attributes_count)] for item in items]).to(self.device) # (item_chunk_size, attributes_count) - aap_loss = nn.functional.binary_cross_entropy_with_logits(pred, actual) - - # MIP: sequence + item - # mask - # def random_mask(sequence): - # # mask = torch.Tensor([0] * sequence.size(0)) - # non_zero_count = torch.nonzero(sequence, as_tuple=True)[0].size(0) - # mask_indices = 
torch.randint(sequence.size(0) - non_zero_count, sequence.size(0), size=1) - # # mask[mask_indices] = 1 - # return mask_indices - - # masks = torch.Tensor([random_mask(sequence) for sequence in sequences]) # () - # masked_sequences = sequences * (1 - masks) - # pred = self.model.mip(masked_sequences, ) # (sequence_chunk_size, sequence_len) item idx pred - # nn.functional.binary_cross_entropy - # # MAP: sequence + attributes - # map_loss = self.loss() - # # SP: sequence + segment - # sp_loss = self.loss() - # # X, pos_item, neg_item = data['X'].to(self.device), data['pos_item'].to(self.device), data['neg_item'].to(self.device) - # # pos_pred, neg_pred = self.model(X, pos_item, neg_item) - - # loss = self.loss(pos_pred, neg_pred) - loss = aap_loss # + mip_loss + map_loss + sp_loss - + break valid_loss += loss.item() return valid_loss @@ -250,7 +241,7 @@ def train(self, train_dataloader: DataLoader) -> float: train_loss = 0 for data in tqdm(train_dataloader): X, pos_item, neg_item = data['X'].to(self.device), data['pos_item'].to(self.device), data['neg_item'].to(self.device) - pos_pred, neg_pred = self.model(X, pos_item, neg_item) + pos_pred, neg_pred = self.model.finetune(X, pos_item, neg_item) self.optimizer.zero_grad() loss = self.loss(pos_pred, neg_pred) @@ -267,7 +258,7 @@ def validate(self, valid_dataloader: DataLoader) -> tuple[float]: # actual, predicted = [], [] for data in tqdm(valid_dataloader): X, pos_item, neg_item = data['X'].to(self.device), data['pos_item'].to(self.device), data['neg_item'].to(self.device) - pos_pred, neg_pred = self.model(X, pos_item, neg_item) + pos_pred, neg_pred = self.model.finetune(X, pos_item, neg_item) self.optimizer.zero_grad() loss = self.loss(pos_pred, neg_pred)