feat: implements aap #21
GangBean committed Jul 23, 2024
1 parent facb224 commit dafe357
Showing 4 changed files with 198 additions and 10 deletions.
8 changes: 6 additions & 2 deletions configs/train_config.yaml
@@ -14,8 +14,8 @@ notes: "..."
tags: [sweep, yelp, cdae, hyper-parameter, model-structure]

# train config
-device: cpu
-epochs: 100
+device: cuda # cpu
+epochs: 5
batch_size: 32
lr: 0.0001
optimizer: adam # adamw
@@ -49,3 +49,7 @@ model:
max_seq_len: 50
num_heads: 2
num_blocks: 2
+pretrain: True # False
+pretrain_epochs: 1 # 100
+mask_portion: 0.2
+iter_nums: 200
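
For reference, a minimal sketch (not part of the commit) of how these new keys would be read once the YAML above is loaded with OmegaConf; the file path and key names come straight from this diff, everything else is illustrative:

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/train_config.yaml")
if cfg.pretrain:                    # switches train.py to the S3RecPreTrainer branch (see below)
    print(cfg.pretrain_epochs)      # 1 in this commit (100 left as a comment)
    print(cfg.iter_nums)            # 200 chunks per pretraining epoch
    print(cfg.mask_portion)         # 0.2; presumably for the masked-item tasks that are still commented out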
12 changes: 11 additions & 1 deletion models/s3rec.py
@@ -11,7 +11,7 @@ def __init__(self, cfg, num_users, num_items, attributes_count):
self.cfg = cfg
# self.user_embedding = nn.Embedding(num_users, cfg.embed_size, dtype=torch.float32)
self.item_embedding = nn.Embedding(num_items + 1, self.cfg.embed_size, dtype=torch.float32)
-# self.attribute_embedding = nn.Embedding(attributes_count, self.cfg.embed_size, dtype=torch.float32)
+self.attribute_embedding = nn.Embedding(attributes_count, self.cfg.embed_size, dtype=torch.float32)
self.positional_encoding = nn.Parameter(torch.rand(self.cfg.max_seq_len, self.cfg.embed_size))

# self.query = nn.ModuleList([nn.Linear(self.cfg.embed_size / self.num_heads) for _ in range(self.cfg.num_heads)])
@@ -20,8 +20,11 @@ def __init__(self, cfg, num_users, num_items, attributes_count):
self.ffn1s = nn.ModuleList([nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
self.ffn2s = nn.ModuleList([nn.Linear(self.cfg.embed_size, self.cfg.embed_size) for _ in range(self.cfg.num_blocks)])
self.multihead_attns = nn.ModuleList([nn.MultiheadAttention(self.cfg.embed_size, self.cfg.num_heads) for _ in range(self.cfg.num_blocks)])
+self.aap_weight = nn.Linear(self.cfg.embed_size, self.cfg.embed_size, bias=False)

self._init_weights()


def _init_weights(self):
for child in self.children():
if isinstance(child, nn.Embedding):
@@ -30,6 +33,8 @@ def _init_weights(self):
for sub_child in child.children():
if not isinstance(sub_child, nn.MultiheadAttention):
nn.init.xavier_uniform_(sub_child.weight)
+elif isinstance(child, nn.Linear):
+    nn.init.xavier_uniform_(child.weight)
else:
logger.info(f"other type: {child} / {type(child)}")

@@ -60,3 +65,8 @@ def evaluate(self, X, pos_item, neg_items):
self.item_embedding(neg_items[:,i]), X[:, -1]).view(neg_items.size(0), -1) for i in range(neg_items.size(-1))]
neg_preds = torch.concat(neg_preds, dim=1)
return pos_pred, neg_preds

+def aap(self, items):
+    # AAP head: score every attribute for each item
+    item_embeddings = self.item_embedding(items)
+    return torch.matmul(self.aap_weight(item_embeddings), self.attribute_embedding.weight.T) # (batch, embed_size) @ (embed_size, attributes_count) -> (batch, attributes_count)
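
As a quick sanity check (not part of the commit), the shapes flowing through aap can be reproduced with plain nn modules; the sizes below are made up, only the layer shapes mirror the model:

import torch
import torch.nn as nn

embed_size, num_items, attributes_count = 64, 10, 5        # made-up sizes
item_embedding = nn.Embedding(num_items + 1, embed_size)
attribute_embedding = nn.Embedding(attributes_count, embed_size)
aap_weight = nn.Linear(embed_size, embed_size, bias=False)

items = torch.arange(1, num_items + 1)                     # (batch,)
scores = torch.matmul(aap_weight(item_embedding(items)),   # (batch, embed_size)
                      attribute_embedding.weight.T)        # @ (embed_size, attributes_count)
assert scores.shape == (num_items, attributes_count)       # (batch, attributes_count) of attribute logits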
18 changes: 12 additions & 6 deletions train.py
@@ -22,7 +22,7 @@
from trainers.cdae_trainer import CDAETrainer
from trainers.dcn_trainer import DCNTrainer
from trainers.mf_trainer import MFTrainer
-from trainers.s3rec_trainer import S3RecTrainer
+from trainers.s3rec_trainer import S3RecTrainer, S3RecPreTrainer
from utils import set_seed


@@ -100,11 +100,17 @@ def train(cfg, args): #train_dataset, valid_dataset, test_dataset, model_info
trainer.load_best_model()
trainer.evaluate(args.test_eval_data, 'test')
elif cfg.model_name in ('S3Rec',):
-trainer = S3RecTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'],
-                       args.data_pipeline.item2attributes, args.data_pipeline.attributes_count)
-trainer.run(train_dataloader, valid_dataloader)
-trainer.load_best_model()
-trainer.evaluate(test_dataloader)
+if cfg.pretrain:
+    trainer = S3RecPreTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'],
+                              args.data_pipeline.item2attributes, args.data_pipeline.attributes_count)
+    trainer.pretrain(args.train_dataset, args.valid_dataset)
+    trainer.load_best_model()
+else:
+    trainer = S3RecTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'],
+                           args.data_pipeline.item2attributes, args.data_pipeline.attributes_count)
+    trainer.run(train_dataloader, valid_dataloader)
+    trainer.load_best_model()
+    trainer.evaluate(test_dataloader)

def unpack_model(cfg: OmegaConf) -> OmegaConf:
if cfg.model_name not in cfg.model:
170 changes: 169 additions & 1 deletion trainers/s3rec_trainer.py
@@ -11,13 +11,181 @@
from omegaconf.dictconfig import DictConfig
import wandb

from models.cdae import CDAE
from models.s3rec import S3Rec
from utils import log_metric
from .base_trainer import BaseTrainer
from metric import *
from loss import BPRLoss

class S3RecPreTrainer:
def __init__(self, cfg: DictConfig, num_items: int, num_users: int, item2attributes, attributes_count: int) -> None:
self.cfg = cfg
self.device = self.cfg.device
self.model = S3Rec(self.cfg, num_items, num_users, attributes_count).to(self.device)
self.optimizer: Optimizer = self._optimizer(self.cfg.optimizer, self.model, self.cfg.lr)
self.loss = self._loss()
self.item2attribute = item2attributes
self.num_items = num_items
self.num_users = num_users
self.attributes_count = attributes_count

def _loss(self):
# AAP + MIP + MAP + SP
return nn.BCEWithLogitsLoss()
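
Note that the module returned here and the functional call actually used in train/validate below compute the same mean-reduced BCE on logits; a tiny check with made-up tensors:

import torch
import torch.nn as nn

logits = torch.randn(4, 3)
targets = torch.randint(0, 2, (4, 3)).float()
assert torch.allclose(nn.BCEWithLogitsLoss()(logits, targets),
                      nn.functional.binary_cross_entropy_with_logits(logits, targets))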

def _optimizer(self, optimizer_name: str, model: nn.Module, learning_rate: float, weight_decay: float=0) -> Optimizer:
if optimizer_name.lower() == 'adam':
return torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
elif optimizer_name.lower() == 'adamw':
return torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
elif optimizer_name.lower() == 'sgd':
return torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
else:
logger.error(f"Optimizer Not Exists: {optimizer_name}")
raise NotImplementedError(f"Optimizer Not Exists: {optimizer_name}")

def _is_surpass_best_metric(self, **metric) -> bool:
(valid_loss,
) = metric['current']

(best_valid_loss,
) = metric['best']

if self.cfg.best_metric == 'loss':
return valid_loss < best_valid_loss
else:
return False

def pretrain(self, train_dataset, valid_dataset):
logger.info(f"[Trainer] run...")

best_valid_loss: float = 1e+6
best_epoch: int = 0
endurance: int = 0

# train
for epoch in range(self.cfg.pretrain_epochs):
train_loss: float = self.train(torch.tensor([i for i in range(1, self.num_items+1)], dtype=torch.int32).to(self.device), train_dataset)
valid_loss = self.validate(torch.tensor([i for i in range(1, self.num_items+1)], dtype=torch.int32).to(self.device), valid_dataset)
logger.info(f'''\n[Trainer] epoch: {epoch} > train loss: {train_loss:.4f} /
valid loss: {valid_loss:.4f}''')

if self.cfg.wandb:
wandb.log({
'train_loss': train_loss,
'valid_loss': valid_loss,
})

# update model
if self._is_surpass_best_metric(
current=(valid_loss,),
best=(best_valid_loss,)):

logger.info(f"[Trainer] update best model...")
best_valid_loss = valid_loss
best_epoch = epoch
endurance = 0

torch.save(self.model.state_dict(), f'{self.cfg.model_dir}/best_pretrain_model.pt')
else:
endurance += 1
if endurance > self.cfg.patience:
logger.info(f"[Trainer] ealry stopping...")
break

def train(self, item_datasets, sequence_datasets) -> float:
self.model.train()
train_loss = 0

for iter_num in tqdm(range(self.cfg.iter_nums)): # sequence
item_chunk_size = self.num_items // self.cfg.iter_nums
items = item_datasets[item_chunk_size * iter_num: item_chunk_size * (iter_num + 1)]

sequence_chunk_size = self.num_users // self.cfg.iter_nums
# sequences = sequence_datasets[sequence_chunk_size * iter_num: sequence_chunk_size * (iter_num + 1)]

# AAP: item + attributes
pred = self.model.aap(items) # (item_chunk_size, attributes_count)
actual = torch.Tensor([[1 if attribute in self.item2attribute[item.item()] else 0 for attribute in range(self.attributes_count)] for item in items]).to(self.device) # (item_chunk_size, attributes_count)
aap_loss = nn.functional.binary_cross_entropy_with_logits(pred, actual)

# MIP: sequence + item
# mask
# def random_mask(sequence):
# # mask = torch.Tensor([0] * sequence.size(0))
# non_zero_count = torch.nonzero(sequence, as_tuple=True)[0].size(0)
# mask_indices = torch.randint(sequence.size(0) - non_zero_count, sequence.size(0), size=1)
# # mask[mask_indices] = 1
# return mask_indices

# masks = torch.Tensor([random_mask(sequence) for sequence in sequences]) # ()
# masked_sequences = sequences * (1 - masks)
# pred = self.model.mip(masked_sequences, ) # (sequence_chunk_size, mask_count, sequence_len) item idx pred
# nn.functional.binary_cross_entropy
# # MAP: sequence + attributes
# map_loss = self.loss()
# # SP: sequence + segment
# sp_loss = self.loss()
# # X, pos_item, neg_item = data['X'].to(self.device), data['pos_item'].to(self.device), data['neg_item'].to(self.device)
# # pos_pred, neg_pred = self.model(X, pos_item, neg_item)

self.optimizer.zero_grad()
# loss = self.loss(pos_pred, neg_pred)
loss = aap_loss # + mip_loss + map_loss + sp_loss
loss.backward()
self.optimizer.step()

train_loss += loss.item()

return train_loss
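
The nested comprehension that builds actual above probes every attribute id for every item; a vectorized per-row assignment (an assumption for illustration, not part of the commit; item2attribute maps an item id to its list of attribute ids, as the lookup above implies) would produce the same multi-hot targets:

import torch

def multi_hot_targets(items, item2attribute, attributes_count, device):
    # one row per item, 1.0 at each of that item's attribute ids
    actual = torch.zeros(items.size(0), attributes_count, device=device)
    for row, item in enumerate(items):
        actual[row, item2attribute[item.item()]] = 1.0
    return actual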

def validate(self, item_datasets, sequence_datasets) -> float:
self.model.eval()
valid_loss = 0

for iter_num in tqdm(range(self.cfg.iter_nums)): # sequence
item_chunk_size = self.num_items // self.cfg.iter_nums
items = item_datasets[item_chunk_size * iter_num: item_chunk_size * (iter_num + 1)]

sequence_chunk_size = self.num_users // self.cfg.iter_nums
# sequences = sequence_datasets[sequence_chunk_size * iter_num: sequence_chunk_size * (iter_num + 1)]

# AAP: item + attributes
pred = self.model.aap(items) # (item_chunk_size, attributes_count)
actual = torch.Tensor([[1 if attribute in self.item2attribute[item.item()] else 0 for attribute in range(self.attributes_count)] for item in items]).to(self.device) # (item_chunk_size, attributes_count)
aap_loss = nn.functional.binary_cross_entropy_with_logits(pred, actual)

# MIP: sequence + item
# mask
# def random_mask(sequence):
# # mask = torch.Tensor([0] * sequence.size(0))
# non_zero_count = torch.nonzero(sequence, as_tuple=True)[0].size(0)
# mask_indices = torch.randint(sequence.size(0) - non_zero_count, sequence.size(0), size=1)
# # mask[mask_indices] = 1
# return mask_indices

# masks = torch.Tensor([random_mask(sequence) for sequence in sequences]) # ()
# masked_sequences = sequences * (1 - masks)
# pred = self.model.mip(masked_sequences, ) # (sequence_chunk_size, sequence_len) item idx pred
# nn.functional.binary_cross_entropy
# # MAP: sequence + attributes
# map_loss = self.loss()
# # SP: sequence + segment
# sp_loss = self.loss()
# # X, pos_item, neg_item = data['X'].to(self.device), data['pos_item'].to(self.device), data['neg_item'].to(self.device)
# # pos_pred, neg_pred = self.model(X, pos_item, neg_item)

# loss = self.loss(pos_pred, neg_pred)
loss = aap_loss # + mip_loss + map_loss + sp_loss

valid_loss += loss.item()

return valid_loss

def load_best_model(self):
logger.info(f"[Trainer] Load best model...")
self.model.load_state_dict(torch.load(f'{self.cfg.model_dir}/best_pretrain_model.pt'))
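
Putting the pretrainer together, it is driven exactly as in the train.py hunk above; a condensed usage sketch (argument names follow that diff):

trainer = S3RecPreTrainer(cfg, num_items, num_users, item2attributes, attributes_count)
trainer.pretrain(train_dataset, valid_dataset)
trainer.load_best_model()   # reloads {cfg.model_dir}/best_pretrain_model.pt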

class S3RecTrainer(BaseTrainer):
def __init__(self, cfg: DictConfig, num_items: int, num_users: int, item2attributes, attributes_count: int) -> None:
super().__init__(cfg)