diff --git a/_train.py b/_train.py
deleted file mode 100644
index 472cbe0..0000000
--- a/_train.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import hydra
-from omegaconf import OmegaConf
-
-import pytz
-from datetime import datetime
-
-import wandb
-from torch.utils.data import DataLoader
-
-from loguru import logger
-
-from data.datasets.cdae_data_pipeline import CDAEDataPipeline
-from data.datasets.cdae_dataset import CDAEDataset
-from trainers.cdae_trainer import CDAETrainer
-from utils import set_seed
-
-
-def init_wandb_if_needed(cfg):
-    if cfg.wandb:
-        logger.info("[wandb] init...")
-        run_time: str = datetime.now().astimezone(pytz.timezone('Asia/Seoul')).strftime('%Y-%m-%d %H:%M:%S')
-        run_name: str = f'[{cfg.model_name}]{run_time}'
-
-        wandb.init(
-            project=cfg.project,
-            name=run_name,
-            config=dict(cfg),
-            notes=cfg.notes,
-            tags=cfg.tags,
-        )
-
-def finish_wandb_if_needed(cfg):
-    if cfg.wandb:
-        logger.info("[wandb] finish...")
-        wandb.finish()
-
-def update_config_hyperparameters(cfg):
-    logger.info("[Sweep] Update hyper-parameters...")
-    for parameter in cfg.parameters:
-        cfg[parameter] = wandb.config[parameter]
-        logger.info(f"[{parameter}] {cfg[parameter]}")
-
-def run(cfg, train_dataset, valid_dataset, test_dataset, model_info):
-    set_seed(cfg.seed)
-    init_wandb_if_needed(cfg)
-    train(cfg, train_dataset, valid_dataset, test_dataset, model_info)
-    finish_wandb_if_needed(cfg)
-
-def run_sweep(cfg, *datasets):
-    sweep_id = wandb.sweep(sweep=OmegaConf.to_container(cfg, resolve=True), project=cfg.project)
-    wandb.agent(sweep_id,
-                function=lambda: sweep(cfg, *datasets),
-                count=cfg.sweep_count)
-
-def sweep(cfg, *datasets):
-    set_seed(cfg.seed)
-    init_wandb_if_needed(cfg)
-    update_config_hyperparameters(cfg)
-    train(cfg, *datasets)
-
-def train(cfg, train_dataset, valid_dataset, test_dataset, model_info):
-    # set dataloaders
-    train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle)
-    valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle)
-    test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size)
-
-    if cfg.model_name in ('CDAE', ):
-        trainer = CDAETrainer(cfg, model_info['num_items'], model_info['num_users'])
-        trainer.run(train_dataloader, valid_dataloader)
-        trainer.load_best_model()
-        trainer.evaluate(test_dataloader)
-
-@hydra.main(version_base=None, config_path="configs", config_name="_train_config")
-def main(cfg: OmegaConf):
-    if cfg.model_name in ('CDAE', ):
-        data_pipeline = CDAEDataPipeline(cfg)
-    else:
-        raise ValueError()
-
-    df = data_pipeline.preprocess()
-    train_data, valid_data, test_data = data_pipeline.split(df)
-
-    model_info = dict() # additional infos needed to create model object
-    if cfg.model_name in ('CDAE', ):
-        train_dataset = CDAEDataset(train_data, 'train')
-        valid_dataset = CDAEDataset(valid_data, 'valid')
-        test_dataset = CDAEDataset(test_data, 'test')
-        model_info['num_items'], model_info['num_users'] = len(df.columns)-1, len(train_data)
-    else:
-        raise ValueError()
-
-    if cfg.wandb and cfg.sweep:
-        run_sweep(cfg, train_dataset, valid_dataset, test_dataset, model_info)
-    else:
-        run(cfg, train_dataset, valid_dataset, test_dataset, model_info)
-
-if __name__ == '__main__':
-    main()
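The prototype `_train.py` removed above introduced the `init_wandb_if_needed`/`run`/`run_sweep`/`sweep`/`train` split for the CDAE model; the same structure is re-introduced into the real `train.py` at the end of this patch, extended with MF support and an EasyDict argument bundle. Its companion config, `configs/_train_config.yaml`, is deleted next.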
diff --git a/configs/_train_config.yaml b/configs/_train_config.yaml
deleted file mode 100644
index 246ffb5..0000000
--- a/configs/_train_config.yaml
+++ /dev/null
@@ -1,47 +0,0 @@
-# run config
-seed: 42
-shuffle: True
-model_dir: outputs/models
-submit_dir: outputs/submissions
-data_dir: data/
-log_dir: logs/
-
-# wandb config
-wandb: True # True/ False
-project: YelpRecommendation
-notes: "..."
-tags: [test, yelp, cdae]
-
-# train config
-device: cuda # cpu
-epochs: 1
-batch_size: 32
-lr: 0.001
-optimizer: adamw
-loss: bce
-patience: 5
-top_n: 10
-
-# model config
-model_name: CDAE
-hidden_size: 64
-corruption_level: 0.6
-hidden_activation: sigmoid
-output_activation: sigmoid
-
-# sweep config
-sweep: False
-sweep_count: 3
-method: random
-name: sweep
-metric:
-  goal: minimize
-  name: valid_loss
-parameters:
-  batch_size:
-    values: [16, 32, 64]
-  hidden_size:
-    values: [50, 100, 150]
-  lr:
-    min: 0.0001
-    max: 0.1
diff --git a/configs/sweep_config.yaml b/configs/sweep_config.yaml
index 95f7ea6..25e58f8 100644
--- a/configs/sweep_config.yaml
+++ b/configs/sweep_config.yaml
@@ -1,36 +1,5 @@
-# run config
-seed: 42
-shuffle: True
-model_dir: outputs/models
-submit_dir: outputs/submissions
-data_dir: data/
-log_dir: logs/
-
-# wandb config
-wandb: True # True/ False
-project: YelpRecommendation
-notes: "..."
-tags: [test, yelp, cdae]
-
-# train config
-device: cuda # cpu
-epochs: 1
-batch_size: 32
-lr: 0.001
-optimizer: adamw
-loss: bce
-patience: 5
-top_n: 10
-
-# model config
-model_name: CDAE
-hidden_size: 64
-corruption_level: 0.6
-hidden_activation: sigmoid
-output_activation: sigmoid
-
-
 # sweep config
+## MF
 sweep_count: 3
 method: random
 name: sweep
diff --git a/configs/train_config.yaml b/configs/train_config.yaml
index 9521e51..a396584 100644
--- a/configs/train_config.yaml
+++ b/configs/train_config.yaml
@@ -5,6 +5,7 @@ model_dir: outputs/models
 submit_dir: outputs/submissions
 data_dir: data/
 log_dir: logs/
+sweep: True
 
 # wandb config
 wandb: True # True/ False
@@ -17,20 +18,20 @@ device: cuda # cpu
 epochs: 10
 batch_size: 32
 lr: 0.001
-optimizer: sgd # adamw
-loss_name: bpr # pointwise # bce
+optimizer: adam # adamw
+loss_name: bce # bpr # pointwise # bce
 patience: 5
 top_n: 10
 weight_decay: 0 #1e-5
 
 # model config
-#model_name: CDAE
-#negative_sampling: True # False
-#neg_times: 5 # this works only when negative_sampling == True, if value is 5, the number of negative samples will be 5 times the number of positive samples by users
-#hidden_size: 64
-#corruption_level: 0.6
-#hidden_activation: sigmoid
-#output_activation: sigmoid
+model_name: CDAE
+negative_sampling: True # False
+neg_times: 5 # this works only when negative_sampling == True, if value is 5, the number of negative samples will be 5 times the number of positive samples by users
+hidden_size: 64
+corruption_level: 0.6
+hidden_activation: sigmoid
+output_activation: sigmoid
 
-model_name: MF
-embed_size: 64
+#model_name: MF
+#embed_size: 64
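With `sweep: True` added to `train_config.yaml` and `sweep_config.yaml` trimmed down to sweep-only keys, the two files get combined at runtime in `train.py`'s `main()` (see the hunks at the end of this patch). A minimal standalone sketch of that combination, assuming both files live under `configs/`; `OmegaConf.merge` is used here as a documented one-liner equivalent of the create/update sequence in the patch:

    from omegaconf import OmegaConf

    # Later arguments win, so sweep keys extend/override the training config.
    cfg = OmegaConf.merge(
        OmegaConf.load('configs/train_config.yaml'),
        OmegaConf.load('configs/sweep_config.yaml'),
    )
    print(cfg.sweep_count)  # 3, from sweep_config.yaml
    print(cfg.batch_size)   # 32, from train_config.yaml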
diff --git a/poetry.lock b/poetry.lock
index 48d018f..8eebde7 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -389,6 +389,17 @@ files = [
 [package.dependencies]
 six = ">=1.4.0"
 
+[[package]]
+name = "easydict"
+version = "1.13"
+description = "Access dict values as attributes (works recursively)."
+optional = false
+python-versions = "*"
+files = [
+    {file = "easydict-1.13-py3-none-any.whl", hash = "sha256:6b787daf4dcaf6377b4ad9403a5cee5a86adbc0ca9a5bcf5410e9902002aeac2"},
+    {file = "easydict-1.13.tar.gz", hash = "sha256:b1135dedbc41c8010e2bc1f77ec9744c7faa42bce1a1c87416791449d6c87780"},
+]
+
 [[package]]
 name = "executing"
 version = "2.0.1"
@@ -2496,4 +2507,4 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "4770979d79ff85b8236e29ce7d14321e918bb053d3634602a9727ec01742a416"
+content-hash = "c48fed95b49242d84a6ded1fc730c5002a363d58316454757d82e1d2c440fdeb"
diff --git a/pyproject.toml b/pyproject.toml
index 7bc3b5b..bd3215a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ hydra-core = "^1.3.2"
 loguru = "^0.7.2"
 wandb = "^0.17.0"
 scikit-learn = "1.4.0"
+easydict = "^1.13"
 
 [tool.poetry.group.dev.dependencies]
 ipykernel = "^6.29.4"
diff --git a/sweep.py b/sweep.py
deleted file mode 100644
index c0f0b1b..0000000
--- a/sweep.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import pytz
-from datetime import datetime
-
-import wandb
-
-from torch.utils.data import DataLoader
-
-from loguru import logger
-
-from omegaconf import OmegaConf
-
-from data.datasets.cdae_data_pipeline import CDAEDataPipeline
-from data.datasets.cdae_dataset import CDAEDataset
-from trainers.cdae_trainer import CDAETrainer
-from utils import set_seed
-
-
-cfg: OmegaConf = None
-num_items = 0
-num_users = 0
-
-def main():
-    logger.info(f"set seed as {cfg.seed}...")
-    set_seed(cfg.seed)
-
-    logger.info("[wandb] init...")
-    run_time: str = datetime.now().astimezone(pytz.timezone('Asia/Seoul')).strftime('%Y-%m-%d %H:%M:%S')
-    run_name: str = f'[Sweep][{cfg.model_name}]{run_time}'
-
-    wandb.init(
-        project=cfg.project,
-        name=run_name,
-        config=dict(cfg),
-        notes=cfg.notes,
-        tags=cfg.tags,
-    )
-
-    # update hyperparameters to selected values
-    logger.info("[Sweep] Update hyper-parameters...")
-    for parameter in cfg.parameters:
-        cfg[parameter] = wandb.config[parameter]
-        logger.info(f"[{parameter}] {cfg.lr}")
-
-    if cfg.model_name in ('CDAE',):
-        trainer = CDAETrainer(cfg, num_items, num_users)
-    elif cfg.model_name in ('WDN', ):
-        trainer = CDAETrainer(cfg, num_items, num_users)
-
-    train_dataloader = DataLoader(train_data, batch_size=cfg.batch_size, shuffle=True)
-    valid_dataloader = DataLoader(valid_data, batch_size=cfg.batch_size, shuffle=True)
-    test_dataloader = DataLoader(test_data, batch_size=cfg.batch_size)
-
-    trainer.run(train_dataloader, valid_dataloader)
-    trainer.load_best_model()
-    trainer.evaluate(test_dataloader)
-
-
-if __name__ == '__main__':
-    cfg = OmegaConf.load('./configs/sweep_config.yaml')
-
-    data_pipeline = CDAEDataPipeline(cfg)
-    df = data_pipeline.preprocess()
-    train_data, valid_data, test_data = data_pipeline.split(df)
-
-    train_data = CDAEDataset(train_data, 'train')
-    valid_data = CDAEDataset(valid_data, 'valid')
-    test_data = CDAEDataset(test_data, 'test')
-
-    num_items = len(df.columns) - 1
-    num_users = len(train_data)
-
-    sweep_id = wandb.sweep(sweep=OmegaConf.to_container(cfg, resolve=True), project=cfg.project)
-    wandb.agent(sweep_id, function=main, count=cfg.sweep_count)
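The standalone `sweep.py` entry point deleted above (note it reached module-level `train_data`/`valid_data`/`test_data` from inside `main()` and logged `cfg.lr` for every swept parameter) is superseded by the `run_sweep`/`sweep` helpers added to `train.py` below, which fix both issues. A minimal sketch of the consolidated flow, assuming a merged config carrying both the training keys and the `method`/`metric`/`parameters` sweep keys; the `objective` wrapper is hypothetical glue, not repository code:

    import wandb
    from omegaconf import OmegaConf

    cfg = OmegaConf.merge(
        OmegaConf.load('configs/train_config.yaml'),
        OmegaConf.load('configs/sweep_config.yaml'),
    )

    def objective():
        wandb.init(project=cfg.project)
        # The agent fills wandb.config with one sampled combination per run.
        for parameter in cfg.parameters:
            cfg[parameter] = wandb.config[parameter]
        # ... build datasets and a trainer, then fit, as train.py's train() does ...
        wandb.finish()

    sweep_id = wandb.sweep(sweep=OmegaConf.to_container(cfg, resolve=True), project=cfg.project)
    wandb.agent(sweep_id, function=objective, count=cfg.sweep_count)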
diff --git a/train.py b/train.py
index c6d3c69..89e9e86 100644
--- a/train.py
+++ b/train.py
@@ -3,9 +3,9 @@ import pytz
 
 from datetime import datetime
 
+from easydict import EasyDict
 import wandb
-import torch
 from torch.utils.data import DataLoader
 
 from loguru import logger
 
@@ -19,10 +19,7 @@ from utils import set_seed
 
 
-@hydra.main(version_base=None, config_path="configs", config_name="train_config")
-def main(cfg: OmegaConf):
-
-    # wandb init
+def init_wandb_if_needed(cfg):
     if cfg.wandb:
         logger.info("[wandb] init...")
         run_time: str = datetime.now().astimezone(pytz.timezone('Asia/Seoul')).strftime('%Y-%m-%d %H:%M:%S')
         run_name: str = f'[{cfg.model_name}]{run_time}'
@@ -35,10 +32,56 @@ def main(cfg: OmegaConf):
             notes=cfg.notes,
             tags=cfg.tags,
         )
-
-    logger.info(f"set seed as {cfg.seed}...")
+
+def finish_wandb_if_needed(cfg):
+    if cfg.wandb:
+        logger.info("[wandb] finish...")
+        wandb.finish()
+
+def update_config_hyperparameters(cfg):
+    logger.info("[Sweep] Update hyper-parameters...")
+    for parameter in cfg.parameters:
+        cfg[parameter] = wandb.config[parameter]
+        logger.info(f"[{parameter}] {cfg[parameter]}")
+
+def run(cfg, args):#train_dataset, valid_dataset, test_dataset, model_info):
+    set_seed(cfg.seed)
+    init_wandb_if_needed(cfg)
+    train(cfg, args)#train_dataset, valid_dataset, test_dataset, model_info)
+    finish_wandb_if_needed(cfg)
+
+def run_sweep(cfg, args):
+    sweep_id = wandb.sweep(sweep=OmegaConf.to_container(cfg, resolve=True), project=cfg.project)
+    wandb.agent(sweep_id,
+                function=lambda: sweep(cfg, args),
+                count=cfg.sweep_count)
+
+def sweep(cfg, args):# *datasets):
     set_seed(cfg.seed)
-
+    init_wandb_if_needed(cfg)
+    update_config_hyperparameters(cfg)
+    train(cfg, args)#*datasets)
+
+def train(cfg, args):#train_dataset, valid_dataset, test_dataset, model_info):
+    # set dataloaders
+    train_dataloader = DataLoader(args.train_dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle)
+    valid_dataloader = DataLoader(args.valid_dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle)
+
+    if cfg.model_name != 'MF':
+        test_dataloader = DataLoader(args.test_dataset, batch_size=cfg.batch_size)
+
+    if cfg.model_name in ('CDAE', ):
+        trainer = CDAETrainer(cfg, args.model_info['num_items'], args.model_info['num_users'])
+        trainer.run(train_dataloader, valid_dataloader)
+        trainer.load_best_model()
+        trainer.evaluate(test_dataloader)
+    elif cfg.model_name in ('MF', ):
+        trainer = MFTrainer(cfg, args.model_info['num_items'], args.model_info['num_users'])
+        trainer.run(train_dataloader, valid_dataloader, args.valid_eval_data)
+        trainer.evaluate(args.test_eval_data, 'test')
+
+@hydra.main(version_base=None, config_path="configs", config_name="train_config")
+def main(cfg: OmegaConf):
     if cfg.model_name in ('CDAE', ):
         data_pipeline = CDAEDataPipeline(cfg)
     elif cfg.model_name == 'MF':
@@ -48,38 +91,39 @@ def main(cfg: OmegaConf):
 
     df = data_pipeline.preprocess()
 
+    args = EasyDict()
+
+    model_info = dict() # additional infos needed to create model object
     if cfg.model_name in ('CDAE', ):
         train_data, valid_data, test_data = data_pipeline.split(df)
         train_dataset = CDAEDataset(train_data, 'train', neg_times=cfg.neg_times)
         valid_dataset = CDAEDataset(valid_data, 'valid', neg_times=cfg.neg_times)
         test_dataset = CDAEDataset(test_data, 'test')
+        args.update({'test_dataset': test_dataset})
+        model_info['num_items'], model_info['num_users'] = len(df.columns)-1, len(train_data)
     elif cfg.model_name == 'MF':
         train_data, valid_data, valid_eval_data, test_eval_data = data_pipeline.split(df)
         train_dataset = MFDataset(train_data, num_items=data_pipeline.num_items)
         valid_dataset = MFDataset(valid_data, num_items=data_pipeline.num_items)
+        args.update({'valid_eval_data': valid_eval_data, 'test_eval_data': test_eval_data})
+        model_info['num_items'], model_info['num_users'] = data_pipeline.num_items, data_pipeline.num_users
     else:
         raise ValueError()
 
-    # set dataloaders
-    train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle)
-    valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle)
+    args.update({
+        'train_dataset': train_dataset,
+        'valid_dataset': valid_dataset,
+        'model_info': model_info,
+    })
 
-    if cfg.model_name != 'MF':
-        test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size)
-
-    if cfg.model_name in ('CDAE', ):
-        trainer = CDAETrainer(cfg, len(df.columns)-1, len(train_dataset))
-        trainer.run(train_dataloader, valid_dataloader)
-        trainer.load_best_model()
-        trainer.evaluate(test_dataloader)
-    elif cfg.model_name in ('MF', ):
-        trainer = MFTrainer(cfg, data_pipeline.num_items, data_pipeline.num_users)
-        trainer.run(train_dataloader, valid_dataloader, valid_eval_data)
-        trainer.evaluate(test_eval_data, 'test')
-
-    if cfg.wandb:
-        logger.info("[wandb] finish...")
-        wandb.finish()
+    if cfg.wandb and cfg.sweep:
+        sweep_cfg = OmegaConf.load('configs/sweep_config.yaml')
+        merge_cfg = OmegaConf.create({})
+        merge_cfg.update(cfg)
+        merge_cfg.update(sweep_cfg)
+        run_sweep(merge_cfg, args)
+    else:
+        run(cfg, args)
 
 if __name__ == '__main__':
     main()
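The `args` bundle that `main()` now hands to `run`/`run_sweep` is an EasyDict, which is why `train()` can read `args.train_dataset` or `args.model_info['num_items']` instead of taking four positional parameters whose set differs between CDAE and MF. A minimal sketch of the behavior this relies on, with stand-in values rather than repository objects:

    from easydict import EasyDict

    args = EasyDict()
    args.update({
        'train_dataset': ['u1', 'u2'],                      # stand-in for a Dataset
        'model_info': {'num_items': 100, 'num_users': 10},  # as built in main()
    })

    assert args.train_dataset == ['u1', 'u2']   # dict keys become attributes
    assert args.model_info.num_items == 100     # nested dicts are wrapped recursively

Since `train_config.yaml` now sets both `wandb: True` and `sweep: True`, a plain `python train.py` launches a sweep; passing `sweep=False` as a Hydra command-line override runs a single CDAE training instead.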