diff --git a/configs/train_config.yaml b/configs/train_config.yaml index c44698f..7ed19c1 100644 --- a/configs/train_config.yaml +++ b/configs/train_config.yaml @@ -6,14 +6,30 @@ submit_dir: outputs/submissions data_dir: data/ log_dir: logs/ -# mlflow config -exp: True # True/ False +# wandb config +wandb: True # True/ False project: YelpRecommendation notes: "..." +tags: [test, yelp, cdae] +# sweep: True +# sweep_cfg: +# method: random +# name: sweep +# metric: +# goal: maximize +# name: val_acc +# parameters: +# batch_size: +# values: [16, 32, 64] +# epochs: +# values: [5, 10, 15] +# lr: +# min: 0.0001 +# max: 0.1 # train config -device: cpu # cpu -epochs: 10 +device: cuda # cpu +epochs: 1 batch_size: 32 lr: 0.001 optimizer: adamw diff --git a/train.py b/train.py index 079f3c7..e43b724 100644 --- a/train.py +++ b/train.py @@ -1,18 +1,38 @@ import hydra from omegaconf import OmegaConf -from data.datasets.cdae_data_pipeline import CDAEDataPipeline -from data.datasets.cdae_dataset import CDAEDataset -from trainers.cdae_trainer import CDAETrainer -from utils import set_seed +import pytz +from datetime import datetime +import wandb import torch from torch.utils.data import DataLoader from loguru import logger +from data.datasets.cdae_data_pipeline import CDAEDataPipeline +from data.datasets.cdae_dataset import CDAEDataset +from trainers.cdae_trainer import CDAETrainer +from utils import set_seed + + @hydra.main(version_base=None, config_path="configs", config_name="train_config") def main(cfg: OmegaConf): + + # wandb init + if cfg.wandb: + logger.info("[wandb] init...") + run_time: str = datetime.now().astimezone(pytz.timezone('Asia/Seoul')).strftime('%Y-%m-%d %H:%M:%S') + run_name: str = f'[{cfg.model_name}]{run_time}' + + wandb.init( + project=cfg.project, + name=run_name, + config=dict(cfg), + notes=cfg.notes, + tags=cfg.tags, + ) + logger.info(f"set seed as {cfg.seed}...") set_seed(cfg.seed) @@ -31,7 +51,7 @@ def main(cfg: OmegaConf): else: raise ValueError() - # 
pos_samples 를 이용한 negative sample을 수행해줘야 함 + # set dataloaders train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle) valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle) test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size) @@ -42,5 +62,9 @@ def main(cfg: OmegaConf): trainer.load_best_model() trainer.evaluate(test_dataloader) + if cfg.wandb: + logger.info("[wandb] finish...") + wandb.finish() + if __name__ == '__main__': main() diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py index 0cf96dc..e540475 100644 --- a/trainers/base_trainer.py +++ b/trainers/base_trainer.py @@ -1,7 +1,4 @@ import os -import pytz -from datetime import datetime - import wandb import torch @@ -13,6 +10,8 @@ from omegaconf.dictconfig import DictConfig from abc import ABC, abstractmethod +from utils import log_metric + class BaseTrainer(ABC): def __init__(self, cfg: DictConfig) -> None: self.cfg: DictConfig = cfg @@ -53,16 +52,6 @@ def _loss(self, loss_name: str): def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader): logger.info(f"[Trainer] run...") - # wandb init - run_time: str = datetime.now().astimezone(pytz.timezone('Asia/Seoul')).strftime('%Y-%m-%d %H:%M:%S') - run_name: str = f'[{self.cfg.model_name}]{run_time}' - - wandb.init( - project="yelp", - name=run_name, - config=dict(self.cfg), - ) - best_valid_loss: float = 1e+6 best_valid_precision_at_k: float = .0 best_valid_recall_at_k: float = .0 @@ -86,14 +75,15 @@ def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader): MAP@K: {valid_map_at_k:.4f} / NDCG@K: {valid_ndcg_at_k:.4f}''') - wandb.log({ - 'train_loss': train_loss, - 'valid_loss': valid_loss, - 'valid_Precision@K': valid_precision_at_k, - 'valid_Recall@K': valid_recall_at_k, - 'valid_MAP@K': valid_map_at_k, - 'valid_NDCG@K': valid_ndcg_at_k, - }) + if self.cfg.wandb: + wandb.log({ + 'train_loss': train_loss, + 'valid_loss': 
valid_loss, + 'valid_Precision@K': valid_precision_at_k, + 'valid_Recall@K': valid_recall_at_k, + 'valid_MAP@K': valid_map_at_k, + 'valid_NDCG@K': valid_ndcg_at_k, + }) # update model if best_valid_loss > valid_loss: @@ -113,8 +103,6 @@ def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader): logger.info(f"[Trainer] ealry stopping...") break - wandb.finish() - @abstractmethod def train(self, train_dataloader: DataLoader) -> float: pass @@ -124,7 +112,7 @@ def validate(self, valid_dataloader: DataLoader) -> tuple[float]: pass @abstractmethod - def evaluate(self, test_dataloader: DataLoader) -> None: + def evaluate(self, test_dataloader: DataLoader) -> tuple[float]: pass def load_best_model(self): diff --git a/trainers/cdae_trainer.py b/trainers/cdae_trainer.py index 97ac1dd..e7eb899 100644 --- a/trainers/cdae_trainer.py +++ b/trainers/cdae_trainer.py @@ -10,6 +10,7 @@ from omegaconf.dictconfig import DictConfig from models.cdae import CDAE +from utils import log_metric from .base_trainer import BaseTrainer from metric import * @@ -63,8 +64,9 @@ def validate(self, valid_dataloader: DataLoader) -> tuple[float]: valid_recall_at_k, valid_map_at_k, valid_ndcg_at_k) - - def evaluate(self, test_dataloader: DataLoader) -> None: + + @log_metric + def evaluate(self, test_dataloader: DataLoader) -> tuple[float]: self.model.eval() actual, predicted = [], [] for data in tqdm(test_dataloader): @@ -90,7 +92,12 @@ def evaluate(self, test_dataloader: DataLoader) -> None: Recall@{self.cfg.top_n}: {test_recall_at_k:.4f} / MAP@{self.cfg.top_n}: {test_map_at_k:.4f} / NDCG@{self.cfg.top_n}: {test_ndcg_at_k:.4f}''') - + + return (test_precision_at_k, + test_recall_at_k, + test_map_at_k, + test_ndcg_at_k) + def _generate_target_and_top_k_recommendation(self, pred: Tensor, actual_mask, pred_mask) -> tuple[list]: actual, predicted = [], [] diff --git a/utils.py b/utils.py index fbd9754..0abcded 100644 --- a/utils.py +++ b/utils.py @@ -1,14 +1,17 @@ import os import random - 
def log_metric(func: Callable):
    """Decorate an evaluation function so its metrics are also sent to wandb.

    The wrapped callable must return exactly four values, in the order
    (precision@k, recall@k, MAP@k, NDCG@k). If ``wandb.run`` is not ``None``
    the four values are logged under the ``test_*@K`` keys; otherwise
    nothing is logged. In both cases the four values are handed back to the
    caller as a tuple.

    Args:
        func: Callable producing the four metric values.

    Returns:
        The wrapping function, carrying ``func``'s metadata via ``wraps``.
    """
    metric_keys = (
        'test_Precision@K',
        'test_Recall@K',
        'test_MAP@K',
        'test_NDCG@K',
    )

    @wraps(func)
    def log_wandb(*args, **kwargs):
        # Explicit unpacking keeps the strict four-value contract: any other
        # arity raises here instead of being logged partially.
        precision_at_k, recall_at_k, map_at_k, ndcg_at_k = func(*args, **kwargs)
        scores = (precision_at_k, recall_at_k, map_at_k, ndcg_at_k)

        # Without an active wandb run there is nothing to log to.
        if wandb.run is None:
            return scores

        logger.info("[Trainer] logging test results...")
        wandb.log(dict(zip(metric_keys, scores)))
        return scores

    return log_wandb