From f3162a15210f0fc1a2275d206b73cd2ec5ffc1ff Mon Sep 17 00:00:00 2001 From: GangBean Date: Thu, 23 May 2024 06:55:34 +0000 Subject: [PATCH] refactor: convert sweep config to split format per model #16 --- configs/cdae_sweep_config.yaml | 24 ++++++++++++++++++++++++ configs/mf_sweep_config.yaml | 18 ++++++++++++++++++ train.py | 2 +- 3 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 configs/cdae_sweep_config.yaml create mode 100644 configs/mf_sweep_config.yaml diff --git a/configs/cdae_sweep_config.yaml b/configs/cdae_sweep_config.yaml new file mode 100644 index 0000000..adf2b3c --- /dev/null +++ b/configs/cdae_sweep_config.yaml @@ -0,0 +1,24 @@ +# CDAE +sweep_count: 100 +method: grid # grid, random, bayes +name: cdae_grid_sweep +metric: + goal: minimize # minimize, maximize + name: valid_loss # valid_MAP@K +parameters: + # batch_size: ##########[COMMON]############### + # values: [16, 32, 64] + # lr: + # values: [1e-1, 1e-3, 1e-4, 1e-5] + # optimizer: + # values: [adam, adamw] + # weight_decay: + # values: [1e-1, 1e-3, 1e-5] + neg_times: ###########[MODEL]############### + values: [1, 5, 10] + hidden_size: + values: [32, 64, 128, 256, 512, 1024] + corruption_level: + values: [.1, .3, .4, .5, .6, .9] + hidden_activation: + values: [sigmoid, identity] diff --git a/configs/mf_sweep_config.yaml b/configs/mf_sweep_config.yaml new file mode 100644 index 0000000..d4df918 --- /dev/null +++ b/configs/mf_sweep_config.yaml @@ -0,0 +1,18 @@ +# CDAE +sweep_count: 100 +method: grid # grid, random, bayes +name: mf_grid_sweep +metric: + goal: minimize + name: valid_loss +parameters: + # batch_size: ##########[COMMON]############### + # values: [16, 32, 64] + # lr: + # values: [1e-1, 1e-3, 1e-4, 1e-5] + optimizer: + values: [adam, adamw] + # weight_decay: + # values: [1e-1, 1e-3, 1e-5] + embed_size: ###########[MODEL]############### + values: [32, 64, 128, 256, 512, 1024] diff --git a/train.py b/train.py index 31b01eb..64dc25a 100644 --- a/train.py +++ b/train.py @@ -151,7 +151,7 @@ def main(cfg: OmegaConf): }) if cfg.wandb and cfg.sweep: - sweep_cfg = OmegaConf.load('configs/sweep_config.yaml') + sweep_cfg = OmegaConf.load(f'configs/{cfg.model_name.lower()}_sweep_config.yaml') merge_cfg = OmegaConf.create({}) merge_cfg.update(cfg) merge_cfg.update(sweep_cfg)