Skip to content

Commit

Permalink
feat: implement DCNTrainer evaluate method #16
Browse files Browse the repository at this point in the history
  • Loading branch information
GangBean committed May 22, 2024
1 parent e1ab728 commit 8412151
Showing 1 changed file with 63 additions and 23 deletions.
86 changes: 63 additions & 23 deletions trainers/dcn_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,25 @@ def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, valid_
for epoch in range(self.cfg.epochs):
train_loss: float = self.train(train_dataloader)
valid_loss: float = self.validate(valid_dataloader)
# (valid_precision_at_k,
# valid_recall_at_k,
# valid_map_at_k,
# valid_ndcg_at_k) = self.evaluate(valid_eval_data, 'valid')
(valid_precision_at_k,
valid_recall_at_k,
valid_map_at_k,
valid_ndcg_at_k) = self.evaluate(valid_eval_data, 'valid')
logger.info(f'''\n[Trainer] epoch: {epoch} > train loss: {train_loss:.4f} /
valid loss: {valid_loss:.4f} / ''')
# precision@K : {valid_precision_at_k:.4f} /
# Recall@K: {valid_recall_at_k:.4f} /
# MAP@K: {valid_map_at_k:.4f} /
# NDCG@K: {valid_ndcg_at_k:.4f}''')
valid loss: {valid_loss:.4f} /
precision@K : {valid_precision_at_k:.4f} /
Recall@K: {valid_recall_at_k:.4f} /
MAP@K: {valid_map_at_k:.4f} /
NDCG@K: {valid_ndcg_at_k:.4f}''')

# update model
if best_valid_loss > valid_loss:
logger.info(f"[Trainer] update best model...")
best_valid_loss = valid_loss
# best_valid_precision_at_k = valid_precision_at_k
# best_recall_k = valid_recall_at_k
# best_valid_ndcg_at_k = valid_ndcg_at_k
# best_valid_map_at_k = valid_map_at_k
best_valid_precision_at_k = valid_precision_at_k
best_recall_k = valid_recall_at_k
best_valid_ndcg_at_k = valid_ndcg_at_k
best_valid_map_at_k = valid_map_at_k
best_epoch = epoch
endurance = 0

Expand All @@ -82,13 +82,18 @@ def train(self, train_dataloader: DataLoader) -> float:
for data in tqdm(train_dataloader):
user_id, pos_item, neg_item = data['user_id'].to(self.device), data['pos_item'].to(self.device), \
data['neg_item'].to(self.device)

# logger.info(f"{type(data['pos_item'][0])}, {data['pos_item'][0]}")
# pos_item_categories, pos_item_statecity, neg_item_categories, neg_item_statecity = \
# data['pos_item_categories'].to(self.device), data['pos_item_statecity'].to(self.device), \
# data['neg_item_categories'].to(self.device), data['neg_item_statecity'].to(self.device)
pos_item_categories = torch.tensor([self.item2attributes[item.item()]['categories'] for item in data['pos_item']]).to(self.device)
pos_item_statecity = torch.tensor([self.item2attributes[item.item()]['statecity'] for item in data['pos_item']]).to(self.device)
neg_item_categories = torch.tensor([self.item2attributes[item.item()]['categories'] for item in data['neg_item']]).to(self.device)
neg_item_statecity = torch.tensor([self.item2attributes[item.item()]['statecity'] for item in data['neg_item']]).to(self.device)

# logger.info(f"pos_categories: {torch.equal(pos_item_categories, torch.tensor([self.item2attributes[item.item()]['categories'] for item in data['pos_item']]).to(self.device))}")
# logger.info(f"pos_statecity: {torch.equal(pos_item_statecity, torch.tensor([self.item2attributes[item.item()]['statecity'] for item in data['pos_item']]).to(self.device))}")
# logger.info(f"neg_categories: {torch.equal(neg_item_categories, torch.tensor([self.item2attributes[item.item()]['categories'] for item in data['neg_item']]).to(self.device))}")
# logger.info(f"neg_statecity: {torch.equal(neg_item_statecity, torch.tensor([self.item2attributes[item.item()]['statecity'] for item in data['neg_item']]).to(self.device))}")
pos_pred = self.model(user_id, pos_item, pos_item_categories, pos_item_statecity)
neg_pred = self.model(user_id, neg_item, neg_item_categories, neg_item_statecity)

Expand All @@ -107,6 +112,9 @@ def validate(self, valid_dataloader: DataLoader) -> tuple[float]:
for data in tqdm(valid_dataloader):
user_id, pos_item, neg_item = data['user_id'].to(self.device), data['pos_item'].to(self.device), \
data['neg_item'].to(self.device)
# pos_item_categories, pos_item_statecity, neg_item_categories, neg_item_statecity = \
# data['pos_item_categories'].to(self.device), data['pos_item_statecity'].to(self.device), \
# data['neg_item_categories'].to(self.device), data['neg_item_statecity'].to(self.device)
pos_item_categories = torch.tensor([self.item2attributes[item.item()]['categories'] for item in data['pos_item']]).to(self.device)
pos_item_statecity = torch.tensor([self.item2attributes[item.item()]['statecity'] for item in data['pos_item']]).to(self.device)
neg_item_categories = torch.tensor([self.item2attributes[item.item()]['categories'] for item in data['neg_item']]).to(self.device)
Expand All @@ -124,14 +132,30 @@ def validate(self, valid_dataloader: DataLoader) -> tuple[float]:
def evaluate(self, eval_data: pd.DataFrame, mode='valid') -> tuple:
self.model.eval()
actual, predicted = [], []
item_input = torch.tensor([item_id for item_id in range(self.num_items)]).to(self.device)
item_categories = torch.tensor([self.item2attributes[item]['categories'] for item in range(self.num_items)]).to(self.device)
item_statecity = torch.tensor([self.item2attributes[item]['statecity'] for item in range(self.num_items)]).to(self.device)

for user_id, row in tqdm(eval_data.iterrows(), total=eval_data.shape[0]):
pred = self.model(torch.tensor([user_id,]*self.num_items).to(self.device), item_input, item_categories, item_statecity)
logger.info(f"Before inference #0: {torch.cuda.memory_allocated(self.device)} allocated and {torch.cuda.memory_reserved(self.device)} reserved")
item_input = torch.tensor([item_id for item_id in range(self.num_items)], dtype=torch.int32).to(self.device)
# item_categories = torch.tensor([self.item2attributes[item]['categories'] for item in range(self.num_items)], dtype=torch.int32).to(self.device)
# item_statecity = torch.tensor([self.item2attributes[item]['statecity'] for item in range(self.num_items)], dtype=torch.int32).to(self.device)
chunk_size = 32 # self.cfg.batch_size
# logger.info(f"Before inference #1: {torch.cuda.memory_allocated(self.device)} allocated and {torch.cuda.memory_reserved(self.device)} reserved")
torch.cuda.empty_cache()
# logger.info(f"Before inference #2: {torch.cuda.memory_allocated(self.device)} allocated and {torch.cuda.memory_reserved(self.device)} reserved")
for user_id, row in tqdm(eval_data[:10].iterrows(), total=eval_data.shape[0]):
pred = []
for idx in range(0, eval_data.shape[0], chunk_size):
chunk_item_input = item_input[idx:idx+chunk_size]
chunk_item_categories = torch.tensor([self.item2attributes[item]['categories'] for item in range(idx, min(self.num_items, idx+chunk_size))], dtype=torch.int32).to(self.device)
chunk_item_statecity = torch.tensor([self.item2attributes[item]['statecity'] for item in range(idx, min(self.num_items, idx+chunk_size))], dtype=torch.int32).to(self.device)
# print(f"{chunk_size}, {chunk_item_input.size()}, {chunk_item_categories.size()}, {chunk_item_statecity.size()}")
# logger.info(f"{torch.cuda.memory_allocated(self.device)} allocated and {torch.cuda.memory_reserved(self.device)} reserved")

chunk_pred: Tensor = self.model(torch.tensor([user_id,]*len(chunk_item_input), dtype=torch.int32).to(self.device), chunk_item_input, chunk_item_categories, chunk_item_statecity)
pred.extend(chunk_pred.detach().cpu().numpy())

# torch.cuda.empty_cache()
# pred = self.model(torch.tensor([user_id,]*self.num_items).to(self.device), item_input, item_categories, item_statecity)
batch_predicted = \
self._generate_top_k_recommendation(pred, row['mask_items'])
self._generate_top_k_recommendation(np.array(pred).reshape(-1), row['mask_items'])
actual.append(row['pos_items'])
predicted.append(batch_predicted)

Expand All @@ -150,4 +174,20 @@ def evaluate(self, eval_data: pd.DataFrame, mode='valid') -> tuple:
return (test_precision_at_k,
test_recall_at_k,
test_map_at_k,
test_ndcg_at_k)
test_ndcg_at_k)

def _generate_top_k_recommendation(self, pred: np.ndarray, mask_items) -> np.ndarray:
    """Return the indices of the top-N scored items, best score first.

    Args:
        pred: 1-D array of scores where position ``i`` is the score of item ``i``.
            The caller's array is left untouched; masking is done on a copy.
        mask_items: Item indices to exclude from the ranking (the original
            comment describes these as the user's train items). Anything NumPy
            accepts as an integer index for a 1-D array should work.

    Returns:
        1-D integer array holding at most ``self.cfg.top_n`` item indices,
        ordered by descending score. If fewer than ``top_n`` items exist, all
        of them are returned (masked ones last). If ``top_n`` is not positive
        or ``pred`` is empty, the result is empty.
    """
    # Work on a private copy so the caller's scores are not clobbered by masking.
    scores = np.array(pred).reshape(-1)

    # Clamp the requested list length: np.argpartition raises ValueError when
    # asked for more elements than exist, and `[-0:]` would return everything.
    top_n = min(int(self.cfg.top_n), scores.shape[0])
    if top_n <= 0:
        return np.empty(0, dtype=np.intp)

    # Push masked items to the bottom of the ranking
    # (value is approximately the smallest finite float32).
    scores[mask_items] = -3.40282e+38

    # Unordered indices of the top_n largest scores (linear-time selection) ...
    top_idx = np.argpartition(scores, -top_n)[-top_n:]
    # ... then order just those candidates by descending score.
    order = np.argsort(-scores[top_idx])
    return top_idx[order]

0 comments on commit 8412151

Please sign in to comment.