
feat: implements MultiheadAttention #21
GangBean committed Aug 29, 2024
1 parent bfce575 commit b7e9701
Showing 1 changed file with 53 additions and 6 deletions.
59 changes: 53 additions & 6 deletions models/s3rec.py
@@ -1,3 +1,4 @@
import math
import torch
import torch.nn as nn

@@ -16,7 +17,7 @@ def __init__(self, cfg, num_items, attributes_count):
self.positional_encoding = nn.Parameter(torch.rand(self.cfg.max_seq_len, self.cfg.embed_size))

self.multihead_attns = nn.ModuleList(
[nn.MultiheadAttention(self.cfg.embed_size, self.cfg.num_heads, batch_first=True) for _ in range(self.cfg.num_blocks)])
[MultiHeadAttention(self.cfg.embed_size, self.cfg.num_heads) for _ in range(self.cfg.num_blocks)])
self.layernorm1s = nn.ModuleList(
[nn.LayerNorm(self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])

@@ -55,11 +56,12 @@ def _embedding_layer(self, X):
def _self_attention_block(self, X, padding_mask, attn_mask):
for multihead_attn, ffn1, ffn2, layernorm1, layernorm2 in zip(
self.multihead_attns, self.ffn1s, self.ffn2s, self.layernorm1s, self.layernorm2s):
# multi-head self-attention
merged_mask,_ = multihead_attn.merge_masks(attn_mask, padding_mask, X)
attn_output, attn_output_weights = multihead_attn(
X, X, X, #key_padding_mask=padding_mask,
is_causal=True, attn_mask=attn_mask)
# # multi-head self-attention
# merged_mask,_ = multihead_attn.merge_masks(attn_mask, padding_mask, X)
# attn_output, attn_output_weights = multihead_attn(
# X, X, X, #key_padding_mask=padding_mask,
# is_causal=True, attn_mask=attn_mask)
attn_output = multihead_attn(X, X, X, padding_mask, attn_mask)
# dropout
attn_output = self.dropout(attn_output)
# add & norm
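
The new call passes both masks positionally into the custom module, which multiplies by them rather than adding them, so they are expected to be float tensors with 1 = keep and 0 = drop (unlike nn.MultiheadAttention, which takes boolean or additive masks). A minimal sketch of masks in that convention, with batch size, sequence length, and per-sample lengths assumed purely for illustration:

import torch

B, L = 4, 8  # assumed batch size and max sequence length
seq_lens = torch.tensor([8, 5, 3, 8])  # assumed per-sample lengths
# (B, L): 1.0 for real tokens, 0.0 for padding
padding_mask = (torch.arange(L).unsqueeze(0) < seq_lens.unsqueeze(1)).float()
# (L, L): lower-triangular causal mask, 1.0 where a position may attend
attn_mask = torch.tril(torch.ones(L, L))

# attn_output = multihead_attn(X, X, X, padding_mask, attn_mask)
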
@@ -161,3 +163,48 @@ def sp(self, context_attention_output, subsequence_attention_output):
s_tilde = subsequence_attention_output[:, -1, :] # [ B H ]
SW = self.sp_weight(s)
return torch.einsum('bi,bi->b', SW, s_tilde) # [ B ]


class MultiHeadAttention(nn.Module):
def __init__(self, embed_size, num_heads):
super().__init__()
# self.multihead_attns = nn.ModuleList(
# [nn.MultiheadAttention(self.cfg.embed_size, self.cfg.num_heads, batch_first=True) for _ in range(self.cfg.num_blocks)])
self.embed_size = embed_size
self.num_heads = num_heads
self.q_weights = nn.ModuleList(
[nn.Linear(self.embed_size, self.embed_size, bias=False) for _ in range(self.num_heads)])
self.k_weights = nn.ModuleList(
[nn.Linear(self.embed_size, self.embed_size, bias=False) for _ in range(self.num_heads)])
self.v_weights = nn.ModuleList(
[nn.Linear(self.embed_size, self.embed_size, bias=False) for _ in range(self.num_heads)])
self.output = nn.Linear(num_heads * embed_size, embed_size)

def forward(self, q, k, v, padding_mask, attn_mask):
# merged_mask,_ = multihead_attn.merge_masks(attn_mask, padding_mask, X)
# attn_output, attn_output_weights = multihead_attn(
# X, X, X, #key_padding_mask=padding_mask,
# is_causal=True, attn_mask=attn_mask)
# # dropout
attention_outputs = []
for q_weight, k_weight, v_weight in zip(self.q_weights, self.k_weights, self.v_weights):
Q = q_weight(q)
K = k_weight(k) # (B, L, E)
K = K * padding_mask.unsqueeze(2) # zero out padded key positions; padding_mask (B, L) -> (B, L, 1)
logger.info(K)
V = v_weight(v)

attention_score = torch.matmul(Q, K.permute(0,2,1).contiguous()) # (B, L, L)
attention_score /= math.sqrt(self.embed_size)
attention_score = attention_score * attn_mask.unsqueeze(0) # causal mask (L, L) -> (1, L, L), 1 = attend
logger.info(attention_score)
attention_score = torch.nn.functional.softmax(attention_score, dim=-1) # softmax over keys; without dim, F.softmax picks dim=0 for a 3D input
logger.info(attention_score)
attention_score = torch.matmul(attention_score, V) # (B, L, E)
# torch.einsum('bi,bi->b', Q, K) # (batch, seq_len, embed_size)
attention_outputs.append(attention_score) # list of H tensors, each (B, L, E)

attention_scores = torch.cat(attention_outputs, dim=-1) # (B, L, E*H)

return self.output(attention_scores) # (B, L, E)
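
One caveat with the forward above (a hedged observation, not a change to the commit): multiplying the raw scores by 0/1 masks leaves masked positions with a score of 0, which still receives probability mass after the softmax. A self-contained sketch of the more common additive-masking variant, with function name and shapes assumed for illustration:

import math
import torch
import torch.nn.functional as F

def masked_attention(Q, K, V, padding_mask, attn_mask):
    # Q, K, V: (B, L, E); padding_mask: (B, L), 1 = real token; attn_mask: (L, L) causal, 1 = attend
    scores = torch.matmul(Q, K.transpose(1, 2)) / math.sqrt(Q.size(-1))  # (B, L, L)
    keep = attn_mask.unsqueeze(0) * padding_mask.unsqueeze(1)            # (B, L, L), 1 = attend
    scores = scores.masked_fill(keep == 0, float('-inf'))                # excluded keys get zero weight
    probs = F.softmax(scores, dim=-1)
    probs = torch.nan_to_num(probs)                                      # rows with no valid key become all-zero
    return torch.matmul(probs, V)                                        # (B, L, E)
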
