feat: implement DCN datapipeline, dataset, model, trainer #16
Showing 6 changed files with 300 additions and 11 deletions.
@@ -0,0 +1,46 @@
import os

import pandas as pd
from loguru import logger

from .mf_data_pipeline import MFDataPipeline

class DCNDatapipeline(MFDataPipeline):
    def preprocess(self) -> pd.DataFrame:
        '''
        output: pivot table (row: user, col: user-specific vector + item set, values: binary preference)
        '''
        logger.info("start preprocessing...")
        # load df
        df = self._load_df()
        # set num items and num users
        self._set_num_items_and_num_users(df)
        # negative sampling
        if self.cfg.loss_name == 'pointwise':
            df = self._negative_sampling(df, self.cfg.neg_times)

        # load item attributes
        self.item2attributes = self._load_attributes()

        logger.info("done")

        return df

    def _load_attributes(self):
        logger.info("load item2attributes...")
        df = pd.read_json(os.path.join(self.cfg.data_dir, 'yelp_item2attributes.json')).transpose()
        self.attributes_count = [df.categories.explode().nunique(), df.statecity.nunique()]

        # Item category #0 is reserved for the null embedding.
        # Pad every category sequence to a fixed length, then shift all
        # indices by +1 so the -1 padding value maps to the null category 0.
        df.categories = self._pad_sequences_in_df(df.categories, df.categories.apply(len).max())
        df.categories = df.categories.apply(lambda x: [y + 1 for y in x])

        return df.transpose().to_dict()

    def _pad_sequences_in_df(self, series, max_len, padding_value=-1):
        def pad_sequence(seq, max_len, padding_value):
            return seq + [padding_value] * (max_len - len(seq)) if len(seq) < max_len else seq

        padded_sequences = series.apply(lambda x: pad_sequence(x, max_len, padding_value))
        return padded_sequences
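
As an aside, a minimal sketch (with made-up category lists, not part of the commit) of how the -1 padding and the +1 shift interact so that pads land on the null index 0:

import pandas as pd

# made-up category lists of uneven length
categories = pd.Series([[3, 7], [5], [1, 2, 9]])
max_len = categories.apply(len).max()  # 3

# pad with -1 (as _pad_sequences_in_df does), then shift by +1
# so every -1 pad maps to 0, the null-embedding index
padded = categories.apply(lambda seq: seq + [-1] * (max_len - len(seq)))
shifted = padded.apply(lambda seq: [y + 1 for y in seq])
print(shifted.tolist())  # [[4, 8, 0], [6, 0, 0], [2, 3, 10]]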
@@ -0,0 +1,5 @@
from .mf_dataset import MFDataset

class DCNDataset(MFDataset):
    def __init__(self, data, num_items=None):
        super().__init__(data, num_items)
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn

from models.base_model import BaseModel

class DCN(BaseModel):
    def __init__(self, cfg, num_users, num_items, attributes_count: list):
        super().__init__()
        self.user_embedding = nn.Embedding(num_users, cfg.embed_size, dtype=torch.float32)
        self.item_embedding = nn.Embedding(num_items, cfg.embed_size, dtype=torch.float32)
        # count + 1 rows per attribute table: index 0 is the null (padding) slot
        self.attributes_embeddings = nn.ModuleList([
            nn.Embedding(count + 1, cfg.embed_size, dtype=torch.float32) for count in attributes_count
        ])
        self.hidden_dims = [(2 + len(attributes_count)) * cfg.embed_size] + cfg.hidden_dims
        self.cross_dims = [(2 + len(attributes_count)) * cfg.embed_size] * cfg.cross_orders
        self.deep = self._deep()
        self.cross_weights, self.cross_bias = self._cross()
        self.output_layer = nn.Linear(self.hidden_dims[-1] + self.cross_dims[-1], 1)
        self.device = cfg.device
        self._init_weights()

    def _deep(self):
        deep = nn.Sequential()
        for idx in range(len(self.hidden_dims) - 1):
            deep.append(nn.Linear(self.hidden_dims[idx], self.hidden_dims[idx + 1]))
            deep.append(nn.ReLU())
        return deep

    def _cross(self):
        cross_weights = nn.ParameterList([nn.Parameter(torch.rand(dim)) for dim in self.cross_dims])
        cross_bias = nn.ParameterList([nn.Parameter(torch.rand(dim)) for dim in self.cross_dims])
        return cross_weights, cross_bias

    def _init_weights(self):
        # iterate over all submodules, not just direct children, so the
        # embeddings inside the ModuleList and the Linears inside the
        # Sequential are initialized too
        for module in self.modules():
            if isinstance(module, nn.Embedding):
                nn.init.xavier_uniform_(module.weight)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.uniform_(module.bias)

    def forward(self, user_id, item_id, *attributes):
        user_emb = self.user_embedding(user_id)
        item_emb = self.item_embedding(item_id)

        attributes_emb = []
        for idx, embedding in enumerate(self.attributes_embeddings):
            emb = embedding(attributes[idx])
            if attributes[idx].dim() > 1:  # multi-valued attribute: mean-pool over the sequence
                emb = torch.mean(emb, dim=1)
            attributes_emb.append(emb)

        input_x = torch.cat([user_emb, item_emb] + attributes_emb, dim=1)
        input_x = torch.cat([self.deep(input_x), self._forward_cross(input_x)], dim=1)

        return torch.sigmoid(self.output_layer(input_x))

    def _forward_cross(self, x):
        # cross network recurrence: x_{l+1} = x_0 (x_l^T w_l) + b_l + x_l
        prev_x = x
        for weight, bias in zip(self.cross_weights, self.cross_bias):
            input_x = torch.einsum('bi,bj->bij', x, prev_x)  # batched outer product x_0 ⊗ x_l
            prev_x = torch.matmul(input_x, weight) + bias + prev_x
        return prev_x
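
For intuition, a small self-contained check (made-up dimensions, not part of the commit) that the einsum/matmul pair in _forward_cross reduces to the cross-layer recurrence x_{l+1} = x_0 * (x_l . w_l) + b_l + x_l, i.e. x_0 scaled by a learned projection of x_l:

import torch

batch, dim = 2, 4
x0 = torch.randn(batch, dim)
xl = torch.randn(batch, dim)
w = torch.randn(dim)
b = torch.randn(dim)

# outer-product form, as written in _forward_cross
outer = torch.einsum('bi,bj->bij', x0, xl)    # (batch, dim, dim)
out_einsum = torch.matmul(outer, w) + b + xl  # (batch, dim)

# equivalent scalar form: x0 scaled by the dot product <x_l, w>
out_scalar = x0 * (xl @ w).unsqueeze(1) + b + xl

print(torch.allclose(out_einsum, out_scalar))  # True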
@@ -0,0 +1,153 @@
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader
from torch.optim import Optimizer

from loguru import logger
from omegaconf.dictconfig import DictConfig

from models.dcn import DCN
from .base_trainer import BaseTrainer
from metric import *
from loss import BPRLoss

class DCNTrainer(BaseTrainer):
    def __init__(self, cfg: DictConfig, num_items: int, num_users: int, item2attributes: dict, attributes_count: list) -> None:
        super().__init__(cfg)
        self.num_items = num_items
        self.num_users = num_users
        self.model = DCN(self.cfg, num_users, num_items, attributes_count).to(self.device)
        self.optimizer: Optimizer = self._optimizer(self.cfg.optimizer, self.model, self.cfg.lr, self.cfg.weight_decay)
        self.loss = self._loss()
        self.item2attributes = item2attributes

    def _loss(self):
        return BPRLoss()

    def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, valid_eval_data: pd.DataFrame):
        logger.info("[Trainer] run...")

        best_valid_loss: float = 1e+6
        best_valid_precision_at_k: float = .0
        best_valid_recall_at_k: float = .0
        best_valid_map_at_k: float = .0
        best_valid_ndcg_at_k: float = .0
        best_epoch: int = 0
        endurance: int = 0

        # train
        for epoch in range(self.cfg.epochs):
            train_loss: float = self.train(train_dataloader)
            valid_loss: float = self.validate(valid_dataloader)
            # (valid_precision_at_k,
            #  valid_recall_at_k,
            #  valid_map_at_k,
            #  valid_ndcg_at_k) = self.evaluate(valid_eval_data, 'valid')
            logger.info(f'''\n[Trainer] epoch: {epoch} > train loss: {train_loss:.4f} /
                        valid loss: {valid_loss:.4f} /''')
            # Precision@K: {valid_precision_at_k:.4f} /
            # Recall@K: {valid_recall_at_k:.4f} /
            # MAP@K: {valid_map_at_k:.4f} /
            # NDCG@K: {valid_ndcg_at_k:.4f}''')

            # update model
            if best_valid_loss > valid_loss:
                logger.info("[Trainer] update best model...")
                best_valid_loss = valid_loss
                # best_valid_precision_at_k = valid_precision_at_k
                # best_valid_recall_at_k = valid_recall_at_k
                # best_valid_ndcg_at_k = valid_ndcg_at_k
                # best_valid_map_at_k = valid_map_at_k
                best_epoch = epoch
                endurance = 0

                # TODO: add mlflow

                torch.save(self.model.state_dict(), f'{self.cfg.model_dir}/best_model.pt')
            else:
                endurance += 1
                if endurance > self.cfg.patience:
                    logger.info("[Trainer] early stopping...")
                    break

    def train(self, train_dataloader: DataLoader) -> float:
        self.model.train()
        train_loss = 0
        for data in tqdm(train_dataloader):
            user_id, pos_item, neg_item = data['user_id'].to(self.device), data['pos_item'].to(self.device), \
                data['neg_item'].to(self.device)

            # look up the attribute vectors for each positive / negative item in the batch
            pos_item_categories = torch.tensor([self.item2attributes[item.item()]['categories'] for item in data['pos_item']]).to(self.device)
            pos_item_statecity = torch.tensor([self.item2attributes[item.item()]['statecity'] for item in data['pos_item']]).to(self.device)
            neg_item_categories = torch.tensor([self.item2attributes[item.item()]['categories'] for item in data['neg_item']]).to(self.device)
            neg_item_statecity = torch.tensor([self.item2attributes[item.item()]['statecity'] for item in data['neg_item']]).to(self.device)

            pos_pred = self.model(user_id, pos_item, pos_item_categories, pos_item_statecity)
            neg_pred = self.model(user_id, neg_item, neg_item_categories, neg_item_statecity)

            self.optimizer.zero_grad()
            loss = self.loss(pos_pred, neg_pred)
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()

        return train_loss

    @torch.no_grad()  # no gradients needed during validation
    def validate(self, valid_dataloader: DataLoader) -> float:
        self.model.eval()
        valid_loss = 0
        for data in tqdm(valid_dataloader):
            user_id, pos_item, neg_item = data['user_id'].to(self.device), data['pos_item'].to(self.device), \
                data['neg_item'].to(self.device)
            pos_item_categories = torch.tensor([self.item2attributes[item.item()]['categories'] for item in data['pos_item']]).to(self.device)
            pos_item_statecity = torch.tensor([self.item2attributes[item.item()]['statecity'] for item in data['pos_item']]).to(self.device)
            neg_item_categories = torch.tensor([self.item2attributes[item.item()]['categories'] for item in data['neg_item']]).to(self.device)
            neg_item_statecity = torch.tensor([self.item2attributes[item.item()]['statecity'] for item in data['neg_item']]).to(self.device)

            pos_pred = self.model(user_id, pos_item, pos_item_categories, pos_item_statecity)
            neg_pred = self.model(user_id, neg_item, neg_item_categories, neg_item_statecity)

            loss = self.loss(pos_pred, neg_pred)

            valid_loss += loss.item()

        return valid_loss

    @torch.no_grad()  # inference only
    def evaluate(self, eval_data: pd.DataFrame, mode='valid') -> tuple:
        self.model.eval()
        actual, predicted = [], []
        item_input = torch.tensor([item_id for item_id in range(self.num_items)]).to(self.device)
        item_categories = torch.tensor([self.item2attributes[item]['categories'] for item in range(self.num_items)]).to(self.device)
        item_statecity = torch.tensor([self.item2attributes[item]['statecity'] for item in range(self.num_items)]).to(self.device)

        # score every item for each user, mask seen items, and keep the top-K
        for user_id, row in tqdm(eval_data.iterrows(), total=eval_data.shape[0]):
            pred = self.model(torch.tensor([user_id] * self.num_items).to(self.device), item_input, item_categories, item_statecity)
            batch_predicted = \
                self._generate_top_k_recommendation(pred, row['mask_items'])
            actual.append(row['pos_items'])
            predicted.append(batch_predicted)

        test_precision_at_k = precision_at_k(actual, predicted, self.cfg.top_n)
        test_recall_at_k = recall_at_k(actual, predicted, self.cfg.top_n)
        test_map_at_k = map_at_k(actual, predicted, self.cfg.top_n)
        test_ndcg_at_k = ndcg_at_k(actual, predicted, self.cfg.top_n)

        if mode == 'test':
            logger.info(f'''\n[Trainer] Test >
                        Precision@{self.cfg.top_n}: {test_precision_at_k:.4f} /
                        Recall@{self.cfg.top_n}: {test_recall_at_k:.4f} /
                        MAP@{self.cfg.top_n}: {test_map_at_k:.4f} /
                        NDCG@{self.cfg.top_n}: {test_ndcg_at_k:.4f}''')

        return (test_precision_at_k,
                test_recall_at_k,
                test_map_at_k,
                test_ndcg_at_k)
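
To see how the four pieces connect, a rough usage sketch. The cfg object, the split helper, and the num_items / num_users attributes on the pipeline are assumptions inferred from the code above (the remaining changed files are not shown in this diff), not a definitive entry point:

from torch.utils.data import DataLoader

# hypothetical wiring; cfg would come from the project's OmegaConf config
pipeline = DCNDatapipeline(cfg)
df = pipeline.preprocess()
train_data, valid_data, valid_eval_data = pipeline.split(df)  # assumed split helper

train_loader = DataLoader(DCNDataset(train_data, pipeline.num_items),
                          batch_size=cfg.batch_size, shuffle=True)
valid_loader = DataLoader(DCNDataset(valid_data, pipeline.num_items),
                          batch_size=cfg.batch_size)

trainer = DCNTrainer(cfg, pipeline.num_items, pipeline.num_users,
                     pipeline.item2attributes, pipeline.attributes_count)
trainer.run(train_loader, valid_loader, valid_eval_data)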