main_tune.py
import argparse
import sys
import os
import torch
import time
from experiment.tune_and_exp import tune_and_experiment_multiple_runs
from utils.utils import Logger
from types import SimpleNamespace
from experiment.tune_config import config_default

if __name__ == "__main__":
    parser = argparse.ArgumentParser('Hyper-parameter tuning')
    parser.add_argument('--data', dest='data', default='wisdm', type=str)
    parser.add_argument('--encoder', dest='encoder', default='CNN', type=str)
    parser.add_argument('--agent', dest='agent', default='DT2W', type=str)
    parser.add_argument('--norm', dest='norm', default='BN', type=str)
    args = parser.parse_args()

    # Merge the CLI arguments with the unchanged general tuning params
    args = SimpleNamespace(**vars(args), **config_default)
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Set directories for this experiment run
    exp_start_time = time.strftime("%b-%d-%H-%M-%S", time.localtime())
    exp_path_1 = args.encoder + '_' + args.data
    exp_path_2 = args.agent + '_' + args.norm + '_' + exp_start_time
    exp_path = os.path.join(args.path_prefix, exp_path_1, exp_path_2)  # Path for running the whole experiment
    if not os.path.exists(exp_path):
        os.makedirs(exp_path)
    args.exp_path = exp_path

    # Route stdout through the project's Logger so output is captured in log.txt
    log_path = args.exp_path + '/log.txt'
    sys.stdout = Logger(log_path)

    tune_and_experiment_multiple_runs(args)
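
Example invocation, using the argparse defaults shown above:

    python main_tune.py --data wisdm --encoder CNN --agent DT2W --norm BN

The SimpleNamespace line merges the parsed CLI flags with the config_default dictionary imported from experiment.tune_config. A minimal, self-contained sketch of that merge pattern follows; the config_default keys here are hypothetical placeholders, not the repository's actual values:

    import argparse
    from types import SimpleNamespace

    # Hypothetical stand-in for experiment.tune_config.config_default
    config_default = {'path_prefix': 'results', 'num_runs': 3}

    parser = argparse.ArgumentParser('Merge demo')
    parser.add_argument('--data', dest='data', default='wisdm', type=str)
    args = parser.parse_args([])  # empty list so the demo runs without real CLI input

    # Unpack both sources into one flat namespace. The CLI flags and the
    # defaults must not share keys; a duplicate key raises a TypeError.
    merged = SimpleNamespace(**vars(args), **config_default)
    print(merged.data, merged.path_prefix, merged.num_runs)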