From 59937ab51d5fcf734dd686b2952c135ad1e734a6 Mon Sep 17 00:00:00 2001
From: twndus
Date: Thu, 22 Aug 2024 22:27:26 -0400
Subject: [PATCH] refactor: add add & norm and dropout #21

---
 configs/train_config.yaml |  9 +++++----
 models/s3rec.py           | 36 ++++++++++++++++++++++++++++++------
 2 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/configs/train_config.yaml b/configs/train_config.yaml
index 2f90bb7..b5c3dfb 100644
--- a/configs/train_config.yaml
+++ b/configs/train_config.yaml
@@ -8,14 +8,14 @@ log_dir: logs/
 sweep: False
 
 # wandb config
-wandb: True # True/ False
+wandb: False
 project: YelpRecommendation
 notes: "..."
 tags: [yelp, s3rec]
 
 # train config
-device: cuda # cpu
-epochs: 100
+device: cpu
+epochs: 5
 batch_size: 32
 lr: 0.0001
 optimizer: adam # adamw
@@ -49,7 +49,8 @@ model:
   max_seq_len: 50
   num_heads: 2
   num_blocks: 2
-  pretrain: True # False
+  pretrain: False # False
   load_pretrain: True
   pretrain_epochs: 100 # 100
   mask_portion: 0.2
+  dropout_ratio: 0.1
diff --git a/models/s3rec.py b/models/s3rec.py
index e22d0da..26c0286 100644
--- a/models/s3rec.py
+++ b/models/s3rec.py
@@ -4,6 +4,8 @@
 from models.base_model import BaseModel
 from loguru import logger
 
+
+
 class S3Rec(BaseModel):
 
     def __init__(self, cfg, num_items, attributes_count):
@@ -12,10 +14,21 @@ def __init__(self, cfg, num_items, attributes_count):
         self.item_embedding = nn.Embedding(num_items + 1, self.cfg.embed_size, dtype=torch.float32)
         self.attribute_embedding = nn.Embedding(attributes_count, self.cfg.embed_size, dtype=torch.float32)
         self.positional_encoding = nn.Parameter(torch.rand(self.cfg.max_seq_len, self.cfg.embed_size))
+
+        self.multihead_attns = nn.ModuleList(
+            [nn.MultiheadAttention(self.cfg.embed_size, self.cfg.num_heads) for _ in range(self.cfg.num_blocks)])
+        self.layernorm1s = nn.ModuleList(
+            [nn.LayerNorm(self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
 
-        self.ffn1s = nn.ModuleList([nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
-        self.ffn2s = nn.ModuleList([nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
-        self.multihead_attns = nn.ModuleList([nn.MultiheadAttention(self.cfg.embed_size, self.cfg.num_heads) for _ in range(self.cfg.num_blocks)])
+        self.ffn1s = nn.ModuleList(
+            [nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
+        self.ffn2s = nn.ModuleList(
+            [nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
+        self.layernorm2s = nn.ModuleList(
+            [nn.LayerNorm(self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
+
+        self.dropout = nn.Dropout(self.cfg.dropout_ratio)
+
         self.aap_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False)
         self.mip_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False)
         self.map_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False)
@@ -29,7 +42,7 @@ def _init_weights(self):
                 nn.init.xavier_uniform_(child.weight)
             elif isinstance(child, nn.ModuleList): # nn.Linear):
                 for sub_child in child.children():
-                    if not isinstance(sub_child, nn.MultiheadAttention):
+                    if isinstance(sub_child, nn.Linear):
                         nn.init.xavier_uniform_(sub_child.weight)
             elif isinstance(child, nn.Linear):
                 nn.init.xavier_uniform_(child.weight)
@@ -40,9 +53,20 @@ def _embedding_layer(self, X):
         return self.item_embedding(X) + self.positional_encoding
 
     def _self_attention_block(self, X):
-        for multihead_attn, ffn1, ffn2 in zip(self.multihead_attns, self.ffn1s, self.ffn2s):
+        for multihead_attn, ffn1, ffn2, layernorm1, layernorm2 in zip(
+            self.multihead_attns, self.ffn1s, self.ffn2s, self.layernorm1s, self.layernorm2s):
+            # multi-head self-attention
             attn_output, attn_output_weights = multihead_attn(X, X, X)
-            X = ffn2(nn.functional.relu(ffn1(attn_output)))
+            # dropout
+            attn_output = self.dropout(attn_output)
+            # add & norm
+            normalized_attn_output = layernorm1(X + attn_output)
+            # feed-forward network
+            ffn_output = ffn2(nn.functional.relu(ffn1(normalized_attn_output)))
+            # dropout
+            ffn_output = self.dropout(ffn_output)
+            # add & norm
+            X = layernorm2(X + ffn_output)
         return X
 
     def _prediction_layer(self, item, self_attn_output):
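
A minimal, standalone sketch of the block structure this patch introduces in _self_attention_block (multi-head attention, dropout, add & norm, feed-forward network, dropout, add & norm). It is not part of the patch; the hyperparameters (embed_size=64, num_heads=2, num_blocks=2, dropout_ratio=0.1, seq_len=50) are illustrative stand-ins for the values read from cfg.

import torch
import torch.nn as nn


class SelfAttentionBlockSketch(nn.Module):
    """Mirrors the patched _self_attention_block loop with fixed, illustrative sizes."""

    def __init__(self, embed_size=64, num_heads=2, num_blocks=2, dropout_ratio=0.1):
        super().__init__()
        self.multihead_attns = nn.ModuleList(
            [nn.MultiheadAttention(embed_size, num_heads) for _ in range(num_blocks)])
        self.layernorm1s = nn.ModuleList(
            [nn.LayerNorm(embed_size) for _ in range(num_blocks)])
        self.ffn1s = nn.ModuleList(
            [nn.Linear(embed_size, embed_size) for _ in range(num_blocks)])
        self.ffn2s = nn.ModuleList(
            [nn.Linear(embed_size, embed_size) for _ in range(num_blocks)])
        self.layernorm2s = nn.ModuleList(
            [nn.LayerNorm(embed_size) for _ in range(num_blocks)])
        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, X):
        for attn, ffn1, ffn2, layernorm1, layernorm2 in zip(
                self.multihead_attns, self.ffn1s, self.ffn2s,
                self.layernorm1s, self.layernorm2s):
            # multi-head self-attention, then dropout on the sub-layer output
            attn_output, _ = attn(X, X, X)
            attn_output = self.dropout(attn_output)
            # add & norm: residual connection around the attention sub-layer
            normalized_attn_output = layernorm1(X + attn_output)
            # position-wise feed-forward network, then dropout
            ffn_output = ffn2(nn.functional.relu(ffn1(normalized_attn_output)))
            ffn_output = self.dropout(ffn_output)
            # add & norm: as in the patched code, the residual is taken from the
            # block input X rather than from normalized_attn_output
            X = layernorm2(X + ffn_output)
        return X


if __name__ == "__main__":
    # nn.MultiheadAttention defaults to (seq_len, batch, embed) input layout
    block = SelfAttentionBlockSketch()
    out = block(torch.rand(50, 8, 64))
    print(out.shape)  # torch.Size([50, 8, 64])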