Skip to content

Commit

Permalink
fix: implement S3Rec pretrain method and insert self-attention network into the pretraining stage (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
twndus committed Jul 26, 2024
1 parent dafe357 commit 18fca50
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 79 deletions.
79 changes: 69 additions & 10 deletions models/s3rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,20 @@ class S3Rec(BaseModel):
def __init__(self, cfg, num_users, num_items, attributes_count):
    """Build the S3Rec encoder and the four pretraining projection heads.

    Args:
        cfg: config object; fields read here: embed_size, max_seq_len,
            num_blocks, num_heads.
        num_users: number of users (currently unused — user embedding is
            commented out below).
        num_items: number of items; the embedding table has num_items + 1
            rows — presumably index 0 is reserved for padding/mask, TODO confirm.
        attributes_count: number of distinct item attributes.
    """
    super().__init__()
    self.cfg = cfg
    # self.user_embedding = nn.Embedding(num_users, cfg.embed_size, dtype=torch.float32)
    self.item_embedding = nn.Embedding(num_items + 1, self.cfg.embed_size, dtype=torch.float32)
    self.attribute_embedding = nn.Embedding(attributes_count, self.cfg.embed_size, dtype=torch.float32)
    # learned (not sinusoidal) positional encoding, one vector per position
    self.positional_encoding = nn.Parameter(torch.rand(self.cfg.max_seq_len, self.cfg.embed_size))
    # self.query = nn.ModuleList([nn.Linear(self.cfg.embed_size / self.num_heads) for _ in range(self.cfg.num_heads)])
    # self.key = nn.ModuleList([nn.Linear(self.cfg.embed_size) for _ in range(self.cfg.num_heads)])
    # self.value = nn.ModuleList([nn.Linear(self.cfg.embed_size) for _ in range(self.cfg.num_heads)])
    # one feed-forward pair + one multi-head attention per transformer block
    self.ffn1s = nn.ModuleList([nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
    self.ffn2s = nn.ModuleList([nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
    self.multihead_attns = nn.ModuleList([nn.MultiheadAttention(self.cfg.embed_size, self.cfg.num_heads) for _ in range(self.cfg.num_blocks)])
    # bilinear weights for the four S3Rec pretraining objectives
    # (AAP / MIP / MAP / SP)
    self.aap_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False)
    self.mip_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False)
    self.map_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False)
    self.sp_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False)

    self._init_weights()


def _init_weights(self):
for child in self.children():
if isinstance(child, nn.Embedding):
Expand All @@ -50,7 +48,7 @@ def _self_attention_block(self, X):
def _prediction_layer(self, item, self_attn_output):
return torch.einsum('bi,bi->b', (item, self_attn_output))

def forward(self, X, pos_item, neg_item):
def finetune(self, X, pos_item, neg_item):
X = self._embedding_layer(X)
X = self._self_attention_block(X)
pos_pred = self._prediction_layer(self.item_embedding(pos_item), X[:, -1])
Expand All @@ -65,8 +63,69 @@ def evaluate(self, X, pos_item, neg_items):
self.item_embedding(neg_items[:,i]), X[:, -1]).view(neg_items.size(0), -1) for i in range(neg_items.size(-1))]
neg_preds = torch.concat(neg_preds, dim=1)
return pos_pred, neg_preds

def encode(self, X):
    """Embed the id sequence, then run it through the stacked
    self-attention blocks; returns the contextualized representations."""
    embedded = self._embedding_layer(X)
    return self._self_attention_block(embedded)

def aap(self, items):
    """Associated Attribute Prediction over raw item ids.

    NOTE(review): this definition is shadowed by a later `aap` method in
    this file that takes encoder outputs instead of item ids — confirm
    which one is intended and remove the other.

    Args:
        items: item id tensor, embedded to [batch, embed_size].
    Returns:
        [batch, attributes_count] attribute logits.
    """
    # item
    item_embeddings = self.item_embedding(items)
    return torch.matmul(self.aap_weight(item_embeddings), self.attribute_embedding.weight.T) # (batch, embed_size) * (attribute_size, embed_size) (batch, attribute_size)
def pretrain(self, item_masked_sequences, subsequence_masked_sequences, pos_subsequences, neg_subsequences):
    """Run the four S3Rec pretraining heads on one batch.

    Args:
        item_masked_sequences: sequences with ~20% of items masked.
        subsequence_masked_sequences: sequences with one contiguous
            segment masked out.
        pos_subsequences: the true masked segment (right-aligned).
        neg_subsequences: a segment sampled from another user.
    Returns:
        (aap_output, mip_output, map_output, (sp_output_pos, sp_output_neg)).
    """
    # encode
    attention_output = self.encode(item_masked_sequences)
    # NOTE(review): subsequence_attention_output is computed but never used
    # below — sp() is fed attention_output instead. Per the S3Rec SP
    # objective the segment-masked context is presumably intended; confirm.
    subsequence_attention_output = self.encode(subsequence_masked_sequences)
    pos_subsequence_attention_output = self.encode(pos_subsequences)
    neg_subsequence_attention_output = self.encode(neg_subsequences)
    # aap
    aap_output = self.aap(attention_output) # (B, L, A)
    # mip
    mip_output = self.mip(attention_output)
    # map
    map_output = self.map(attention_output)
    # sp
    sp_output_pos = self.sp(attention_output, pos_subsequence_attention_output) # pos 1
    sp_output_neg = self.sp(attention_output, neg_subsequence_attention_output) # neg 1
    return aap_output, mip_output, map_output, (sp_output_pos, sp_output_neg)

def aap(self, attention_output):
'''
inputs:
attention_output: [ B, L, H ]
output:
[ B, L, A ]
'''
FW = self.aap_weight(attention_output) # [ B L H ]
return torch.matmul(FW, self.attribute_embedding.weight.T) # [ B L H ] [ H A ] -> [ B L A ]

def mip(self, attention_output):
'''
inputs:
attention_output: [ B, L, H ]
output:
'''
FW = self.mip_weight(attention_output) # [ B L H ]
return torch.matmul(FW, self.item_embedding.weight.t()) # [ B L H ] [ H I ] -> [ B L I ]

def map(self, attention_output):
'''
inputs:
attention_output: [ B, L, H ]
output:
[ B, L, A ]
'''
FW = self.aap_weight(attention_output) # [ B L H ]
return torch.matmul(FW, self.attribute_embedding.weight.T) # [ B L H ] [ H A ] -> [ B L A ]

def sp(self, context_attention_output, subsequence_attention_output):
'''
inputs:
context_attention_output: [ B, L, H ]
subsequence_attention_output: [ B, len_subsequence, H ]
output:
[ B ]
s - input [ i1, i2, mask, mask, mask, ..., in ]
s~ - input [ i3, i4, i5 ]
'''
s = context_attention_output[:, -1, :] # [ B H ]
s_tilde = subsequence_attention_output[:, -1, :] # [ B H ]
SW = self.sp_weight(s)
return torch.einsum('bi,bi->b', SW, s_tilde) # [ B ]
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def train(cfg, args):#train_dataset, valid_dataset, test_dataset, model_info):
if cfg.pretrain:
trainer = S3RecPreTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'],
args.data_pipeline.item2attributes, args.data_pipeline.attributes_count)
trainer.pretrain(args.train_dataset, args.valid_dataset)
trainer.pretrain(train_dataloader)
trainer.load_best_model()
else:
trainer = S3RecTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'],
Expand Down
127 changes: 59 additions & 68 deletions trainers/s3rec_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _is_surpass_best_metric(self, **metric) -> bool:
else:
return False

def pretrain(self, train_dataset, valid_dataset):
def pretrain(self, train_dataset):
logger.info(f"[Trainer] run...")

best_valid_loss: float = 1e+6
Expand All @@ -65,8 +65,8 @@ def pretrain(self, train_dataset, valid_dataset):

# train
for epoch in range(self.cfg.pretrain_epochs):
train_loss: float = self.train(torch.tensor([i for i in range(1, self.num_items+1)], dtype=torch.int32).to(self.device), train_dataset)
valid_loss = self.validate(torch.tensor([i for i in range(1, self.num_items+1)], dtype=torch.int32).to(self.device), valid_dataset)
train_loss: float = self.train(train_dataset)
valid_loss = self.validate(train_dataset)
logger.info(f'''\n[Trainer] epoch: {epoch} > train loss: {train_loss:.4f} /
valid loss: {valid_loss:.4f}''')

Expand All @@ -92,42 +92,66 @@ def pretrain(self, train_dataset, valid_dataset):
if endurance > self.cfg.patience:
logger.info(f"[Trainer] ealry stopping...")
break

def train(self, item_datasets, sequence_datasets) -> float:

def item_level_masking(self, sequences, mask_prob=0.2):
    """Randomly mask items in each sequence (BERT/S3Rec-style).

    Fixed: the original computed `masks * sequences`, which ZEROED the
    ~80% of positions where the mask was False and KEPT the ~20% that
    were drawn to be masked — the inverse of the intent, and inconsistent
    with segment_masking's `(1 - masks) * sequences`.

    Args:
        sequences: [B, L] integer item-id sequences (0 = padding/mask id
            — presumably, given the num_items+1 embedding table; confirm).
        mask_prob: probability each position is masked (default 0.2,
            matching the original constant).
    Returns:
        (masks, item_masked_sequences):
            masks: [B, L] bool, True at MASKED positions.
            item_masked_sequences: sequences with masked positions set to 0.
    """
    masks = torch.rand_like(sequences, dtype=torch.float32) < mask_prob
    # zero out the masked positions, keep the rest
    item_masked_sequences = sequences * ~masks
    return masks, item_masked_sequences

def segment_masking(self, sequences):
    """Mask one contiguous segment per sequence for the SP objective.

    For each row, a random segment of length [2, max_seq_len//2) is
    zeroed out of the input; the true segment (positive) and a segment
    from a DIFFERENT user in the batch (negative) are returned
    right-aligned.

    Fixed: the negative-user resampling loop was `while neg_user_idx != i`,
    which spun until the "other" user EQUALED i — so every negative
    segment came from the same user. It now resamples while the draw
    equals i (and skips resampling entirely for batch size 1, where no
    other user exists and the original could loop forever).

    Args:
        sequences: [B, L] integer item-id sequences, L == cfg.max_seq_len.
    Returns:
        (segment_masked_sequences, pos_sequences, neg_sequences), each [B, L].
    """
    masks = torch.zeros_like(sequences)
    pos_sequences = torch.zeros_like(sequences)
    neg_sequences = torch.zeros_like(sequences)
    batch_size = sequences.size(0)
    for i in range(batch_size):
        # sample segment length and start position
        segment_len = int(torch.randint(low=2, high=self.cfg.max_seq_len // 2, size=(1,)))
        start_idx = int(torch.randint(self.cfg.max_seq_len - segment_len, size=(1,)))
        masks[i, start_idx:start_idx + segment_len] = 1
        # positive: the user's own masked segment, right-aligned
        pos_sequences[i, -segment_len:] = sequences[i, start_idx:start_idx + segment_len]
        # negative: a segment from another user in the same batch
        neg_user_idx = int(torch.randint(batch_size, size=(1,)))
        while batch_size > 1 and neg_user_idx == i:
            neg_user_idx = int(torch.randint(batch_size, size=(1,)))
        neg_start_idx = int(torch.randint(self.cfg.max_seq_len - segment_len, size=(1,)))
        neg_sequences[i, -segment_len:] = sequences[neg_user_idx, neg_start_idx:neg_start_idx + segment_len]
    # zero out the masked segment in the model input
    segment_masked_sequences = (1 - masks) * sequences
    return segment_masked_sequences, pos_sequences, neg_sequences

# def train(self, item_datasets, sequence_datasets) -> float:
def train(self, train_dataloader) -> float:
self.model.train()
train_loss = 0

for iter_num in tqdm(range(self.cfg.iter_nums)): # sequence
item_chunk_size = self.num_items // self.cfg.iter_nums
items = item_datasets[item_chunk_size * iter_num: item_chunk_size * (iter_num + 1)]
for data in tqdm(train_dataloader): # sequence
sequences = data['X'].to(self.device)
# item_masked_sequences
masks, item_masked_sequences = self.item_level_masking(sequences)
# segment_masked_sequences
segment_masked_sequences, pos_segments, neg_segments = self.segment_masking(sequences)

sequence_chunk_size = self.num_users // self.cfg.iter_nums
# sequences = sequence_datasets[sequence_chunk_size * iter_num: sequence_chunk_size * (iter_num + 1)]
# pretrain
aap_output, mip_output, map_output, (sp_output_pos, sp_output_neg) = self.model.pretrain(
item_masked_sequences, segment_masked_sequences, pos_segments, neg_segments)

# AAP: item + atrributes
pred = self.model.aap(items) # (item_chunk_size, attributes_count)
actual = torch.Tensor([[1 if attriute in self.item2attribute[item.item()] else 0 for attriute in range(self.attributes_count)] for item in items]).to(self.device) # (item_chunk_size, attributes_count)
aap_loss = nn.functional.binary_cross_entropy_with_logits(pred, actual)
aap_actual = torch.ones_like(aap_output).to(self.device)
# actual = torch.Tensor([
# [1 if attriute in self.item2attribute[item.item()] else 0 \
# for attriute in range(self.attributes_count)] for item in items]
# ).to(self.device) # (item_chunk_size, attributes_count)
## compute unmasked area only
aap_loss = nn.functional.binary_cross_entropy_with_logits(aap_output, aap_actual)

# MIP: sequence + item
# mask
# def random_mask(sequence):
# # mask = torch.Tensor([0] * sequence.size(0))
# non_zero_count = torch.nonzero(sequence, as_tuple=True)[0].size(0)
# mask_indices = torch.randint(sequence.size(0) - non_zero_count, sequence.size(0), size=1)
# # mask[mask_indices] = 1
# return mask_indices

# masks = torch.Tensor([random_mask(sequence) for sequence in sequences]) # ()
# masked_sequences = sequences * (1 - masks)
# pred = self.model.mip(masked_sequences, ) # (sequence_chunk_size, mask_count, sequence_len) item idx pred
# nn.functional.binary_cross_entropy
# # MAP: sequence + attributes
# map_loss = self.loss()
# # SP: sequence + segment
# sp_loss = self.loss()
# # X, pos_item, neg_item = data['X'].to(self.device), data['pos_item'].to(self.device), data['neg_item'].to(self.device)
# # pos_pred, neg_pred = self.model(X, pos_item, neg_item)
## compute masked area only

# MAP: sequence + attribute
## compute masked area only

# SP: sequence + segment
## pos_segment > neg_segment

self.optimizer.zero_grad()
# loss = self.loss(pos_pred, neg_pred)
Expand All @@ -139,45 +163,12 @@ def train(self, item_datasets, sequence_datasets) -> float:

return train_loss

def validate(self, item_datasets, sequence_datasets) -> float:
def validate(self, sequence_datasets) -> float:
self.model.eval()
valid_loss = 0

for iter_num in tqdm(range(self.cfg.iter_nums)): # sequence
item_chunk_size = self.num_items // self.cfg.iter_nums
items = item_datasets[item_chunk_size * iter_num: item_chunk_size * (iter_num + 1)]

sequence_chunk_size = self.num_users // self.cfg.iter_nums
# sequences = sequence_datasets[sequence_chunk_size * iter_num: sequence_chunk_size * (iter_num + 1)]

# AAP: item + atrributes
pred = self.model.aap(items) # (item_chunk_size, attributes_count)
actual = torch.Tensor([[1 if attriute in self.item2attribute[item.item()] else 0 for attriute in range(self.attributes_count)] for item in items]).to(self.device) # (item_chunk_size, attributes_count)
aap_loss = nn.functional.binary_cross_entropy_with_logits(pred, actual)

# MIP: sequence + item
# mask
# def random_mask(sequence):
# # mask = torch.Tensor([0] * sequence.size(0))
# non_zero_count = torch.nonzero(sequence, as_tuple=True)[0].size(0)
# mask_indices = torch.randint(sequence.size(0) - non_zero_count, sequence.size(0), size=1)
# # mask[mask_indices] = 1
# return mask_indices

# masks = torch.Tensor([random_mask(sequence) for sequence in sequences]) # ()
# masked_sequences = sequences * (1 - masks)
# pred = self.model.mip(masked_sequences, ) # (sequence_chunk_size, sequence_len) item idx pred
# nn.functional.binary_cross_entropy
# # MAP: sequence + attributes
# map_loss = self.loss()
# # SP: sequence + segment
# sp_loss = self.loss()
# # X, pos_item, neg_item = data['X'].to(self.device), data['pos_item'].to(self.device), data['neg_item'].to(self.device)
# # pos_pred, neg_pred = self.model(X, pos_item, neg_item)

# loss = self.loss(pos_pred, neg_pred)
loss = aap_loss # + mip_loss + map_loss + sp_loss

break
valid_loss += loss.item()

return valid_loss
Expand Down Expand Up @@ -250,7 +241,7 @@ def train(self, train_dataloader: DataLoader) -> float:
train_loss = 0
for data in tqdm(train_dataloader):
X, pos_item, neg_item = data['X'].to(self.device), data['pos_item'].to(self.device), data['neg_item'].to(self.device)
pos_pred, neg_pred = self.model(X, pos_item, neg_item)
pos_pred, neg_pred = self.model.finetune(X, pos_item, neg_item)

self.optimizer.zero_grad()
loss = self.loss(pos_pred, neg_pred)
Expand All @@ -267,7 +258,7 @@ def validate(self, valid_dataloader: DataLoader) -> tuple[float]:
# actual, predicted = [], []
for data in tqdm(valid_dataloader):
X, pos_item, neg_item = data['X'].to(self.device), data['pos_item'].to(self.device), data['neg_item'].to(self.device)
pos_pred, neg_pred = self.model(X, pos_item, neg_item)
pos_pred, neg_pred = self.model.finetune(X, pos_item, neg_item)

self.optimizer.zero_grad()
loss = self.loss(pos_pred, neg_pred)
Expand Down

0 comments on commit 18fca50

Please sign in to comment.