refactor: add add & norm and dropout #21
twndus committed Aug 23, 2024
1 parent fdca1cb commit 59937ab
Showing 2 changed files with 35 additions and 10 deletions.
9 changes: 5 additions & 4 deletions configs/train_config.yaml
@@ -8,14 +8,14 @@ log_dir: logs/
 sweep: False
 
 # wandb config
-wandb: True # True/ False
+wandb: False
 project: YelpRecommendation
 notes: "..."
 tags: [yelp, s3rec]
 
 # train config
-device: cuda # cpu
-epochs: 100
+device: cpu
+epochs: 5
 batch_size: 32
 lr: 0.0001
 optimizer: adam # adamw
@@ -49,7 +49,8 @@ model:
 max_seq_len: 50
 num_heads: 2
 num_blocks: 2
-pretrain: True # False
+pretrain: False # False
+load_pretrain: True
 pretrain_epochs: 100 # 100
 mask_portion: 0.2
 dropout_ratio: 0.1
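
For context, a minimal sketch of how such a YAML config might be consumed. The project's actual loader is not part of this diff, so PyYAML and the attribute-style access below are assumptions, not the repository's code:

import yaml
from types import SimpleNamespace

with open("configs/train_config.yaml") as f:
    raw = yaml.safe_load(f)

# Hypothetical split into train-level and model-level settings; the real
# project may flatten the keys differently or use another config library.
train_cfg = SimpleNamespace(**{k: v for k, v in raw.items() if k != "model"})
model_cfg = SimpleNamespace(**raw.get("model", {}))

print(train_cfg.device, train_cfg.epochs)             # after this commit: cpu 5
print(model_cfg.num_blocks, model_cfg.dropout_ratio)  # 2 0.1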
36 changes: 30 additions & 6 deletions models/s3rec.py
@@ -4,6 +4,8 @@
 from models.base_model import BaseModel
 
 from loguru import logger
+
+
 class S3Rec(BaseModel):
 
     def __init__(self, cfg, num_items, attributes_count):
@@ -12,10 +12,21 @@ def __init__(self, cfg, num_items, attributes_count):
         self.item_embedding = nn.Embedding(num_items + 1, self.cfg.embed_size, dtype=torch.float32)
         self.attribute_embedding = nn.Embedding(attributes_count, self.cfg.embed_size, dtype=torch.float32)
         self.positional_encoding = nn.Parameter(torch.rand(self.cfg.max_seq_len, self.cfg.embed_size))
+
+        self.multihead_attns = nn.ModuleList(
+            [nn.MultiheadAttention(self.cfg.embed_size, self.cfg.num_heads) for _ in range(self.cfg.num_blocks)])
+        self.layernorm1s = nn.ModuleList(
+            [nn.LayerNorm(self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
+
-        self.ffn1s = nn.ModuleList([nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
-        self.ffn2s = nn.ModuleList([nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
-        self.multihead_attns = nn.ModuleList([nn.MultiheadAttention(self.cfg.embed_size, self.cfg.num_heads) for _ in range(self.cfg.num_blocks)])
+        self.ffn1s = nn.ModuleList(
+            [nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
+        self.ffn2s = nn.ModuleList(
+            [nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
+        self.layernorm2s = nn.ModuleList(
+            [nn.LayerNorm(self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
+
+        self.dropout = nn.Dropout(self.cfg.dropout_ratio)
 
         self.aap_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False)
         self.mip_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False)
         self.map_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False)
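
The per-block ModuleLists above hand-roll what torch.nn.TransformerEncoderLayer already bundles: self-attention, a two-layer feed-forward net, two LayerNorms, and dropout, in the same post-norm arrangement. A rough built-in equivalent for comparison only, with embed_size assumed since it does not appear in this diff:

import torch.nn as nn

embed_size, num_heads, num_blocks, dropout_ratio = 64, 2, 2, 0.1  # embed_size assumed

# dim_feedforward matches the hand-rolled FFN, which keeps the embedding width;
# the layer's default ordering is post-norm (residual then LayerNorm), as in the commit.
layer = nn.TransformerEncoderLayer(
    d_model=embed_size,
    nhead=num_heads,
    dim_feedforward=embed_size,
    dropout=dropout_ratio,
    activation="relu",
)
encoder = nn.TransformerEncoder(layer, num_layers=num_blocks)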
@@ -29,7 +42,7 @@ def _init_weights(self):
                 nn.init.xavier_uniform_(child.weight)
             elif isinstance(child, nn.ModuleList): # nn.Linear):
                 for sub_child in child.children():
-                    if not isinstance(sub_child, nn.MultiheadAttention):
+                    if isinstance(sub_child, nn.Linear):
                         nn.init.xavier_uniform_(sub_child.weight)
             elif isinstance(child, nn.Linear):
                 nn.init.xavier_uniform_(child.weight)
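
A plausible reason for tightening the isinstance check (an inference; the commit does not say): the block ModuleLists now also contain nn.LayerNorm, whose weight is 1-D, and nn.init.xavier_uniform_ rejects tensors with fewer than two dimensions, so the old "anything but MultiheadAttention" test would now raise. A quick illustration:

import torch.nn as nn

ln = nn.LayerNorm(64)  # LayerNorm's weight has shape (64,), i.e. 1-D
try:
    nn.init.xavier_uniform_(ln.weight)
except ValueError as err:
    print(err)  # fan-in/fan-out cannot be computed for a tensor with < 2 dims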
@@ -40,9 +53,20 @@ def _embedding_layer(self, X):
         return self.item_embedding(X) + self.positional_encoding
 
     def _self_attention_block(self, X):
-        for multihead_attn, ffn1, ffn2 in zip(self.multihead_attns, self.ffn1s, self.ffn2s):
+        for multihead_attn, ffn1, ffn2, layernorm1, layernorm2 in zip(
+                self.multihead_attns, self.ffn1s, self.ffn2s, self.layernorm1s, self.layernorm2s):
+            # multi-head self-attention
             attn_output, attn_output_weights = multihead_attn(X, X, X)
-            X = ffn2(nn.functional.relu(ffn1(attn_output)))
+            # dropout
+            attn_output = self.dropout(attn_output)
+            # add & norm
+            normalized_attn_output = layernorm1(X + attn_output)
+            # feed-forward network
+            ffn_output = ffn2(nn.functional.relu(ffn1(normalized_attn_output)))
+            # dropout
+            ffn_output = self.dropout(ffn_output)
+            # add & norm
+            X = layernorm2(X + ffn_output)
         return X
 
     def _prediction_layer(self, item, self_attn_output):
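
With these changes, _self_attention_block follows the standard post-norm Transformer encoder pattern: attention, dropout, add & norm, feed-forward, dropout, add & norm. A self-contained single-block sketch of that pattern, with embed_size assumed and inputs in the (seq_len, batch, embed) layout that nn.MultiheadAttention expects by default:

import torch
import torch.nn as nn

embed_size, num_heads, max_seq_len, batch_size, p = 64, 2, 50, 32, 0.1  # embed_size assumed

attn = nn.MultiheadAttention(embed_size, num_heads)   # batch_first=False by default
ln1, ln2 = nn.LayerNorm(embed_size), nn.LayerNorm(embed_size)
ffn1, ffn2 = nn.Linear(embed_size, embed_size), nn.Linear(embed_size, embed_size)
dropout = nn.Dropout(p)

X = torch.rand(max_seq_len, batch_size, embed_size)   # (L, N, E)

attn_out, _ = attn(X, X, X)                 # multi-head self-attention
X = ln1(X + dropout(attn_out))              # dropout, then add & norm
ffn_out = ffn2(torch.relu(ffn1(X)))         # position-wise feed-forward
X = ln2(X + dropout(ffn_out))               # dropout, then add & norm
print(X.shape)                              # torch.Size([50, 32, 64])

One difference worth noting: in the committed loop the second residual adds the original block input X to the FFN output, whereas the textbook formulation sketched above adds the FFN's own input, i.e. the output of the first add & norm.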
