Skip to content

Commit

Permalink
feat: implements wandb logging #9
Browse files (browse the repository at this point in the history)
  • Loading branch information
GangBean committed May 10, 2024
1 parent 93ad5fb commit aa713f6
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions trainers/base_trainer.py
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,10 @@
import os
import pytz
from datetime import datetime

import numpy as np
import wandb

import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torch.nn import Module, BCELoss
from torch.optim import Optimizer, Adam, AdamW
Expand Down Expand Up @@ -52,6 +53,16 @@ def _loss(self, loss_name: str):
def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader):
logger.info(f"[Trainer] run...")

# wandb init
run_time: str = datetime.now().astimezone(pytz.timezone('Asia/Seoul')).strftime('%Y-%m-%d %H:%M:%S')
run_name: str = f'[{self.cfg.model_name}]{run_time}'

wandb.init(
project="yelp",
name=run_name,
config=dict(self.cfg),
)

best_valid_loss: float = 1e+6
best_valid_precision_at_k: float = .0
best_valid_recall_at_k: float = .0
Expand All @@ -75,6 +86,15 @@ def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader):
MAP@K: {valid_map_at_k:.4f} /
NDCG@K: {valid_ndcg_at_k:.4f}''')

wandb.log({
'train_loss': train_loss,
'valid_loss': valid_loss,
'valid_Precision@K': valid_precision_at_k,
'valid_Recall@K': valid_recall_at_k,
'valid_MAP@K': valid_map_at_k,
'valid_NDCG@K': valid_ndcg_at_k,
})

# update model
if best_valid_loss > valid_loss:
logger.info(f"[Trainer] update best model...")
Expand All @@ -86,14 +106,14 @@ def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader):
best_epoch = epoch
endurance = 0

# TODO: add mlflow

torch.save(self.model.state_dict(), f'{self.cfg.model_dir}/best_model.pt')
else:
endurance += 1
if endurance > self.cfg.patience:
logger.info(f"[Trainer] ealry stopping...")
break

wandb.finish()

@abstractmethod
def train(self, train_dataloader: DataLoader) -> float:
Expand Down

0 comments on commit aa713f6

Please sign in to comment.