feat: add padding and causal masks #21
twndus committed Aug 23, 2024
1 parent 59937ab commit bfce575
Showing 1 changed file with 12 additions and 4 deletions.
models/s3rec.py — 16 changes: 12 additions & 4 deletions
@@ -16,7 +16,7 @@ def __init__(self, cfg, num_items, attributes_count):
        self.positional_encoding = nn.Parameter(torch.rand(self.cfg.max_seq_len, self.cfg.embed_size))

        self.multihead_attns = nn.ModuleList(
-            [nn.MultiheadAttention(self.cfg.embed_size, self.cfg.num_heads) for _ in range(self.cfg.num_blocks)])
+            [nn.MultiheadAttention(self.cfg.embed_size, self.cfg.num_heads, batch_first=True) for _ in range(self.cfg.num_blocks)])
        self.layernorm1s = nn.ModuleList(
            [nn.LayerNorm(self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])

@@ -52,11 +52,14 @@ def _init_weights(self):
    def _embedding_layer(self, X):
        return self.item_embedding(X) + self.positional_encoding

-    def _self_attention_block(self, X):
+    def _self_attention_block(self, X, padding_mask, attn_mask):
        for multihead_attn, ffn1, ffn2, layernorm1, layernorm2 in zip(
            self.multihead_attns, self.ffn1s, self.ffn2s, self.layernorm1s, self.layernorm2s):
            # multi-head self-attention
-            attn_output, attn_output_weights = multihead_attn(X, X, X)
+            merged_mask,_ = multihead_attn.merge_masks(attn_mask, padding_mask, X)
+            attn_output, attn_output_weights = multihead_attn(
+                X, X, X, #key_padding_mask=padding_mask,
+                is_causal=True, attn_mask=attn_mask)
            # dropout
            attn_output = self.dropout(attn_output)
            # add & norm
@@ -73,8 +76,13 @@ def _prediction_layer(self, item, self_attn_output):
        return torch.einsum('bi,bi->b', (item, self_attn_output))

    def finetune(self, X, pos_item, neg_item):
+        # create padding mask
+        padding_mask = (X <= 0).to(self.cfg.device)
+        attn_mask = torch.triu(
+            torch.ones(self.cfg.max_seq_len, self.cfg.max_seq_len), diagonal=1
+        ).bool().to(self.cfg.device)
        X = self._embedding_layer(X)
-        X = self._self_attention_block(X)
+        X = self._self_attention_block(X, padding_mask, attn_mask)
        pos_pred = self._prediction_layer(self.item_embedding(pos_item), X[:, -1])
        neg_pred = self._prediction_layer(self.item_embedding(neg_item), X[:, -1])
        return pos_pred, neg_pred
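For context, here is a minimal self-contained sketch of the masking pattern this commit introduces: a boolean key_padding_mask marking padded positions plus an upper-triangular causal attn_mask, fed to a batch-first nn.MultiheadAttention. This is not code from the repository; the sizes and variable names below are made up for illustration. Note that the committed diff computes merged_mask via merge_masks but passes only attn_mask (with is_causal=True) to the attention call, leaving key_padding_mask commented out; the sketch shows both masks applied together.

import torch
import torch.nn as nn

# hypothetical sizes, for illustration only
batch, max_seq_len, embed_size, num_heads = 2, 10, 64, 4

attn = nn.MultiheadAttention(embed_size, num_heads, batch_first=True)

# pretend item ids, with 0 used as the padding id in the last two positions
item_ids = torch.randint(1, 100, (batch, max_seq_len))
item_ids[:, -2:] = 0
X = torch.randn(batch, max_seq_len, embed_size)   # (batch, seq, embed) because batch_first=True

# boolean masks: True marks positions the attention must ignore
padding_mask = (item_ids <= 0)                     # (batch, seq), True at padded steps
causal_mask = torch.triu(
    torch.ones(max_seq_len, max_seq_len), diagonal=1
).bool()                                           # (seq, seq), True above the diagonal

attn_output, attn_weights = attn(
    X, X, X,
    key_padding_mask=padding_mask,
    attn_mask=causal_mask,
)
print(attn_output.shape)  # torch.Size([2, 10, 64])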
