refactor: refactor wandb logging and implement log_metric decorator #9
GangBean committed May 12, 2024
1 parent aa713f6 commit 089c54b
Showing 5 changed files with 94 additions and 39 deletions.
24 changes: 20 additions & 4 deletions configs/train_config.yaml
@@ -6,14 +6,30 @@ submit_dir: outputs/submissions
data_dir: data/
log_dir: logs/

# mlflow config
exp: True # True/ False
# wandb config
wandb: True # True/ False
project: YelpRecommendation
notes: "..."
tags: [test, yelp, cdae]
# sweep: True
# sweep_cfg:
# method: random
# name: sweep
# metric:
# goal: maximize
# name: val_acc
# parameters:
# batch_size:
# values: [16, 32, 64]
# epochs:
# values: [5, 10, 15]
# lr:
# min: 0.0001
# max: 0.1

# train config
device: cpu # cpu
epochs: 10
device: cuda # cpu
epochs: 1
batch_size: 32
lr: 0.001
optimizer: adamw
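Note on the sweep block above: it is left commented out in this commit. Purely for orientation, a minimal sketch of how such a block could drive a W&B sweep once uncommented; loading sweep_cfg this way and reusing main from train.py are assumptions, not part of the change.

    # Hypothetical sketch: launch a W&B sweep from the commented-out sweep_cfg above.
    import wandb
    from omegaconf import OmegaConf

    from train import main  # training entry point defined in train.py

    cfg = OmegaConf.load("configs/train_config.yaml")
    # wandb.sweep expects a plain dict, so resolve the OmegaConf node first.
    sweep_dict = OmegaConf.to_container(cfg.sweep_cfg, resolve=True)
    sweep_id = wandb.sweep(sweep=sweep_dict, project=cfg.project)
    # Run the entry point a handful of times with sampled hyperparameters.
    wandb.agent(sweep_id, function=main, count=5)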
34 changes: 29 additions & 5 deletions train.py
@@ -1,18 +1,38 @@
import hydra
from omegaconf import OmegaConf

from data.datasets.cdae_data_pipeline import CDAEDataPipeline
from data.datasets.cdae_dataset import CDAEDataset
from trainers.cdae_trainer import CDAETrainer
from utils import set_seed
import pytz
from datetime import datetime

import wandb
import torch
from torch.utils.data import DataLoader

from loguru import logger

from data.datasets.cdae_data_pipeline import CDAEDataPipeline
from data.datasets.cdae_dataset import CDAEDataset
from trainers.cdae_trainer import CDAETrainer
from utils import set_seed


@hydra.main(version_base=None, config_path="configs", config_name="train_config")
def main(cfg: OmegaConf):

# wandb init
if cfg.wandb:
logger.info("[wandb] init...")
run_time: str = datetime.now().astimezone(pytz.timezone('Asia/Seoul')).strftime('%Y-%m-%d %H:%M:%S')
run_name: str = f'[{cfg.model_name}]{run_time}'

wandb.init(
project=cfg.project,
name=run_name,
config=dict(cfg),
notes=cfg.notes,
tags=cfg.tags,
)

logger.info(f"set seed as {cfg.seed}...")
set_seed(cfg.seed)

@@ -31,7 +51,7 @@ def main(cfg: OmegaConf):
else:
raise ValueError()

# negative sampling using pos_samples still needs to be performed
# set dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle)
valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle)
test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size)
@@ -42,5 +62,9 @@
trainer.load_best_model()
trainer.evaluate(test_dataloader)

if cfg.wandb:
logger.info("[wandb] finish...")
wandb.finish()

if __name__ == '__main__':
main()
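A small note on the cfg.wandb guard used above and at the end of main (an alternative sketch, not what this commit implements): wandb.init also takes a mode argument, so the same switch could be expressed once at init time, which makes any later wandb.log / wandb.finish calls safe no-ops when logging is turned off.

    # Alternative sketch only: a single init instead of repeated `if cfg.wandb:` guards.
    wandb.init(
        project=cfg.project,
        name=run_name,
        config=dict(cfg),
        notes=cfg.notes,
        tags=cfg.tags,
        mode="online" if cfg.wandb else "disabled",  # "disabled" makes wandb calls no-ops
    )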
36 changes: 12 additions & 24 deletions trainers/base_trainer.py
@@ -1,7 +1,4 @@
import os
import pytz
from datetime import datetime

import wandb

import torch
@@ -13,6 +10,8 @@
from omegaconf.dictconfig import DictConfig
from abc import ABC, abstractmethod

from utils import log_metric

class BaseTrainer(ABC):
def __init__(self, cfg: DictConfig) -> None:
self.cfg: DictConfig = cfg
@@ -53,16 +52,6 @@ def _loss(self, loss_name: str):
def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader):
logger.info(f"[Trainer] run...")

# wandb init
run_time: str = datetime.now().astimezone(pytz.timezone('Asia/Seoul')).strftime('%Y-%m-%d %H:%M:%S')
run_name: str = f'[{self.cfg.model_name}]{run_time}'

wandb.init(
project="yelp",
name=run_name,
config=dict(self.cfg),
)

best_valid_loss: float = 1e+6
best_valid_precision_at_k: float = .0
best_valid_recall_at_k: float = .0
@@ -86,14 +75,15 @@ def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader):
MAP@K: {valid_map_at_k:.4f} /
NDCG@K: {valid_ndcg_at_k:.4f}''')

wandb.log({
'train_loss': train_loss,
'valid_loss': valid_loss,
'valid_Precision@K': valid_precision_at_k,
'valid_Recall@K': valid_recall_at_k,
'valid_MAP@K': valid_map_at_k,
'valid_NDCG@K': valid_ndcg_at_k,
})
if self.cfg.wandb:
wandb.log({
'train_loss': train_loss,
'valid_loss': valid_loss,
'valid_Precision@K': valid_precision_at_k,
'valid_Recall@K': valid_recall_at_k,
'valid_MAP@K': valid_map_at_k,
'valid_NDCG@K': valid_ndcg_at_k,
})

# update model
if best_valid_loss > valid_loss:
@@ -113,8 +103,6 @@ def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader):
logger.info(f"[Trainer] ealry stopping...")
break

wandb.finish()

@abstractmethod
def train(self, train_dataloader: DataLoader) -> float:
pass
@@ -124,7 +112,7 @@ def validate(self, valid_dataloader: DataLoader) -> tuple[float]:
pass

@abstractmethod
def evaluate(self, test_dataloader: DataLoader) -> None:
def evaluate(self, test_dataloader: DataLoader) -> tuple[float]:
pass

def load_best_model(self):
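One side observation on the evaluate annotation introduced above (a sketch of a possible refinement, not part of the commit): in Python typing, tuple[float] denotes a 1-tuple, while the method returns four metrics, so a variadic or explicit form would describe the return value more precisely.

    # Sketch: explicit annotation for the four returned metrics.
    def evaluate(self, test_dataloader: DataLoader) -> tuple[float, float, float, float]:
        ...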
13 changes: 10 additions & 3 deletions trainers/cdae_trainer.py
@@ -10,6 +10,7 @@
from omegaconf.dictconfig import DictConfig

from models.cdae import CDAE
from utils import log_metric
from .base_trainer import BaseTrainer
from metric import *

@@ -63,8 +64,9 @@ def validate(self, valid_dataloader: DataLoader) -> tuple[float]:
valid_recall_at_k,
valid_map_at_k,
valid_ndcg_at_k)

def evaluate(self, test_dataloader: DataLoader) -> None:

@log_metric
def evaluate(self, test_dataloader: DataLoader) -> tuple[float]:
self.model.eval()
actual, predicted = [], []
for data in tqdm(test_dataloader):
@@ -90,7 +92,12 @@ def evaluate(self, test_dataloader: DataLoader) -> None:
Recall@{self.cfg.top_n}: {test_recall_at_k:.4f} /
MAP@{self.cfg.top_n}: {test_map_at_k:.4f} /
NDCG@{self.cfg.top_n}: {test_ndcg_at_k:.4f}''')


return (test_precision_at_k,
test_recall_at_k,
test_map_at_k,
test_ndcg_at_k)

def _generate_target_and_top_k_recommendation(self, pred: Tensor, actual_mask, pred_mask) -> tuple[list]:
actual, predicted = [], []

26 changes: 23 additions & 3 deletions utils.py
@@ -1,14 +1,17 @@
import os
import random

import numpy as np
from typing import Callable

import wandb

import torch
from torch.utils.data import DataLoader

from loguru import logger
from functools import wraps


def set_seed(seed):
def set_seed(seed: int):
logger.info("seed setting...")
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
@@ -19,3 +22,20 @@ def set_seed(seed):
# some cudnn methods can be random even after fixing the seed
# unless you tell it to be deterministic
torch.backends.cudnn.deterministic = True


def log_metric(func: Callable):
@wraps(func)
def log_wandb(*args, **kwargs):
precision_at_k, recall_at_k, map_at_k, ndcg_at_k = func(*args, **kwargs)

if wandb.run is not None: # validate wandb initialization
logger.info("[Trainer] logging test results...")
wandb.log({
'test_Precision@K': precision_at_k,
'test_Recall@K': recall_at_k,
'test_MAP@K': map_at_k,
'test_NDCG@K': ndcg_at_k,
})
return (precision_at_k, recall_at_k, map_at_k, ndcg_at_k)
return log_wandb
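
For reference, a minimal standalone sketch of the decorator in use; the DummyEvaluator class below is hypothetical and not part of the repository. Any method that returns the four metrics as a tuple gets them mirrored to wandb when a run is active, and simply passes them through otherwise.

    from utils import log_metric

    class DummyEvaluator:
        @log_metric
        def evaluate(self, test_dataloader=None) -> tuple[float, float, float, float]:
            # Stand-in values for Precision@K, Recall@K, MAP@K, NDCG@K.
            return 0.12, 0.34, 0.21, 0.28

    # With no wandb.init() call, wandb.run is None, so the decorator only forwards the values.
    metrics = DummyEvaluator().evaluate()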
