diff --git a/models/s3rec.py b/models/s3rec.py
index 263c52c..06f513a 100644
--- a/models/s3rec.py
+++ b/models/s3rec.py
@@ -1,3 +1,4 @@
+import math
 import torch
 import torch.nn as nn
 
@@ -16,7 +17,7 @@ def __init__(self, cfg, num_items, attributes_count):
         self.positional_encoding = nn.Parameter(torch.rand(self.cfg.max_seq_len, self.cfg.embed_size))
 
         self.multihead_attns = nn.ModuleList(
-            [nn.MultiheadAttention(self.cfg.embed_size, self.cfg.num_heads, batch_first=True) for _ in range(self.cfg.num_blocks)])
+            [MultiHeadAttention(self.cfg.embed_size, self.cfg.num_heads) for _ in range(self.cfg.num_blocks)])
 
         self.layernorm1s = nn.ModuleList(
             [nn.LayerNorm(self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
@@ -55,11 +56,12 @@ def _self_attention_block(self, X, padding_mask, attn_mask):
         for multihead_attn, ffn1, ffn2, layernorm1, layernorm2 in zip(
             self.multihead_attns, self.ffn1s, self.ffn2s, self.layernorm1s, self.layernorm2s):
-            # multi-head self-attention
-            merged_mask,_ = multihead_attn.merge_masks(attn_mask, padding_mask, X)
-            attn_output, attn_output_weights = multihead_attn(
-                X, X, X, #key_padding_mask=padding_mask,
-                is_causal=True, attn_mask=attn_mask)
+            # # multi-head self-attention
+            # merged_mask,_ = multihead_attn.merge_masks(attn_mask, padding_mask, X)
+            # attn_output, attn_output_weights = multihead_attn(
+            #     X, X, X, #key_padding_mask=padding_mask,
+            #     is_causal=True, attn_mask=attn_mask)
+            attn_output = multihead_attn(X, X, X, padding_mask, attn_mask)
             # dropout
             attn_output = self.dropout(attn_output)
             # add & norm
@@ -161,3 +163,48 @@ def sp(self, context_attention_output, subsequence_attention_output):
         s_tilde = subsequence_attention_output[:, -1, :]  # [ B H ]
         SW = self.sp_weight(s)
         return torch.einsum('bi,bi->b', SW, s_tilde)  # [ B ]
+
+
+class MultiHeadAttention(nn.Module):
+    """Per-head multi-head self-attention used in place of nn.MultiheadAttention."""
+
+    def __init__(self, embed_size, num_heads):
+        super().__init__()
+        self.embed_size = embed_size
+        self.num_heads = num_heads
+        # one full-size (Q, K, V) projection per head
+        self.q_weights = nn.ModuleList(
+            [nn.Linear(self.embed_size, self.embed_size, bias=False) for _ in range(self.num_heads)])
+        self.k_weights = nn.ModuleList(
+            [nn.Linear(self.embed_size, self.embed_size, bias=False) for _ in range(self.num_heads)])
+        self.v_weights = nn.ModuleList(
+            [nn.Linear(self.embed_size, self.embed_size, bias=False) for _ in range(self.num_heads)])
+        # project the concatenated heads back to embed_size
+        self.output = nn.Linear(num_heads * embed_size, embed_size)
+
+    def forward(self, q, k, v, padding_mask, attn_mask):
+        # padding_mask: (B, L) with 1 for real tokens, 0 for padding
+        # attn_mask:    (L, L) lower-triangular 0/1 causal mask
+        attention_outputs = []
+        for q_weight, k_weight, v_weight in zip(self.q_weights, self.k_weights, self.v_weights):
+            Q = q_weight(q)                    # (B, L, E)
+            K = k_weight(k)                    # (B, L, E)
+            K = K * padding_mask.unsqueeze(2)  # zero out padded key positions
+            V = v_weight(v)                    # (B, L, E)
+
+            attention_score = torch.matmul(Q, K.permute(0, 2, 1).contiguous())  # (B, L, L)
+            attention_score /= math.sqrt(self.embed_size)
+            attention_score = attention_score * attn_mask.unsqueeze(0)  # apply causal mask
+            attention_score = torch.nn.functional.softmax(attention_score, dim=-1)
+            attention_output = torch.matmul(attention_score, V)  # (B, L, E)
+            attention_outputs.append(attention_output)  # num_heads tensors of (B, L, E)
+
+        attention_scores = torch.cat(attention_outputs, dim=-1)  # (B, L, E * num_heads)
+
+        return self.output(attention_scores)  # (B, L, E)
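
For reference, below is a minimal standalone sketch (not from the repository) of a single attention head with the same tensor shapes, but using the additive-masking convention: blocked positions are filled with -inf before the softmax rather than multiplying keys and scores by 0/1 masks. It assumes padding_mask is (B, L) with 1 for real tokens and 0 for padding, and attn_mask is a lower-triangular (L, L) 0/1 causal mask; the helper name masked_head_attention is illustrative only.

import math

import torch
import torch.nn.functional as F


def masked_head_attention(Q, K, V, padding_mask, attn_mask):
    # Q, K, V: (B, L, E) projections for one head
    scores = torch.matmul(Q, K.transpose(1, 2)) / math.sqrt(Q.size(-1))  # (B, L, L)

    # Combined boolean mask: True where attention is NOT allowed.
    causal_blocked = attn_mask.unsqueeze(0) == 0      # (1, L, L)
    padding_blocked = padding_mask.unsqueeze(1) == 0  # (B, 1, L): blocks padded keys
    blocked = causal_blocked | padding_blocked        # (B, L, L)

    scores = scores.masked_fill(blocked, float("-inf"))
    weights = F.softmax(scores, dim=-1)               # blocked keys get ~0 weight
    return torch.matmul(weights, V)                   # (B, L, E)


if __name__ == "__main__":
    B, L, E = 2, 5, 8
    x = torch.randn(B, L, E)
    padding_mask = torch.tensor([[1, 1, 1, 0, 0],
                                 [1, 1, 1, 1, 1]], dtype=torch.float)
    attn_mask = torch.tril(torch.ones(L, L))          # causal 0/1 mask
    out = masked_head_attention(x, x, x, padding_mask, attn_mask)
    print(out.shape)  # torch.Size([2, 5, 8])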