fix bug in test_model and hyper (#277)
* fix bug in test_model and hyper

* fix pipeline

* add hyper example

* unify example
aptx1231 authored Apr 26, 2022
1 parent 939ef5d commit 312e5ed
Showing 7 changed files with 63 additions and 26 deletions.
6 changes: 6 additions & 0 deletions hyper_example.json
@@ -0,0 +1,6 @@
{
    "learning_rate": {
        "type": "choice",
        "list": [0.01, 0.005, 0.001]
    }
}
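The new hyper_example.json provides a default search-space spec for the ray[tune]-based hyper_tune.py (its --space_file default is changed below). As a rough sketch (build_search_space is a hypothetical helper, not LibCity's actual parser), a "choice" entry like the one above could be mapped onto a ray.tune search space as follows:

import json

from ray import tune  # tune.choice is ray[tune]'s categorical search-space primitive


def build_search_space(space_file):
    # Hypothetical helper: read the JSON spec and turn every "choice" entry
    # into a ray.tune categorical distribution over the listed candidates.
    with open(space_file, 'r') as f:
        spec = json.load(f)
    search_space = {}
    for name, entry in spec.items():
        if entry['type'] == 'choice':
            # e.g. learning_rate -> tune.choice([0.01, 0.005, 0.001])
            search_space[name] = tune.choice(entry['list'])
        else:
            raise ValueError('unsupported parameter type: {}'.format(entry['type']))
    return search_space


print(build_search_space('hyper_example.json'))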
1 change: 1 addition & 0 deletions hyper_example.txt
@@ -0,0 +1 @@
learning_rate choice [0.01, 0.005, 0.001]
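hyper_example.txt is the analogous default for the hyperopt-based run_hyper.py (its --params_file default is changed below), using a one-line "name type candidates" layout. A minimal sketch, assuming that layout, of how such a line could become a hyperopt search space; parse_space_line is an illustrative name, not LibCity's API:

import ast

from hyperopt import hp  # hp.choice is hyperopt's categorical distribution


def parse_space_line(line):
    # Hypothetical helper: split "learning_rate choice [0.01, 0.005, 0.001]"
    # into a parameter name, a type and a candidate list, then build the
    # matching hyperopt expression.
    name, kind, candidates = line.split(' ', 2)
    if kind == 'choice':
        return name, hp.choice(name, ast.literal_eval(candidates))
    raise ValueError('unsupported parameter type: {}'.format(kind))


name, dist = parse_space_line('learning_rate choice [0.01, 0.005, 0.001]')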
4 changes: 2 additions & 2 deletions hyper_tune.py
@@ -1,5 +1,5 @@
"""
-Script for training and evaluating a single model
+Model hyper-parameter tuning script (based on ray[tune])
"""

import argparse
@@ -20,7 +20,7 @@
parser.add_argument('--config_file', type=str,
default=None, help='the file name of config file')
parser.add_argument('--space_file', type=str,
-default=None, help='the file which specifies the parameter search space')
+default='hyper_example', help='the file which specifies the parameter search space')
parser.add_argument('--scheduler', type=str,
default='FIFO', help='the trial sheduler which will be used in ray.tune.run')
parser.add_argument('--search_alg', type=str,
20 changes: 15 additions & 5 deletions libcity/pipeline/pipeline.py
@@ -136,8 +136,15 @@ def hyper_parameter(task=None, model_name=None, dataset_name=None, config_file=N
# load config
experiment_config = ConfigParser(task, model_name, dataset_name, config_file=config_file,
other_args=other_args)
+# exp_id
+exp_id = experiment_config.get('exp_id', None)
+if exp_id is None:
+exp_id = int(random.SystemRandom().random() * 100000)
+experiment_config['exp_id'] = exp_id
# logger
logger = get_logger(experiment_config)
+logger.info('Begin ray-tune, task={}, model_name={}, dataset_name={}, exp_id={}'.
+format(str(task), str(model_name), str(dataset_name), str(exp_id)))
logger.info(experiment_config.config)
# check space_file
if space_file is None:
@@ -167,8 +174,11 @@ def train(config, checkpoint_dir=None, experiment_config=None,
experiment_config[key] = config[key]
experiment_config['hyper_tune'] = True
logger = get_logger(experiment_config)
-logger.info('Begin pipeline, task={}, model_name={}, dataset_name={}'
-.format(str(task), str(model_name), str(dataset_name)))
+# exp_id
+exp_id = int(random.SystemRandom().random() * 100000)
+experiment_config['exp_id'] = exp_id
+logger.info('Begin pipeline, task={}, model_name={}, dataset_name={}, exp_id={}'.
+format(str(task), str(model_name), str(dataset_name), str(exp_id)))
logger.info('running parameters: ' + str(config))
# load model
model = get_model(experiment_config, data_feature)
@@ -215,9 +225,9 @@ def train(config, checkpoint_dir=None, experiment_config=None,
# save best
best_path = os.path.join(best_trial.checkpoint.value, "checkpoint")
model_state, optimizer_state = torch.load(best_path)
-model_cache_file = './libcity/cache/model_cache/{}_{}.m'.format(
-model_name, dataset_name)
-ensure_dir('./libcity/cache/model_cache')
+model_cache_file = './libcity/cache/{}/model_cache/{}_{}.m'.format(
+exp_id, model_name, dataset_name)
+ensure_dir('./libcity/cache/{}/model_cache'.format(exp_id))
torch.save((model_state, optimizer_state), model_cache_file)


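The changed lines above write the best ray.tune trial's weights and optimizer state into a per-experiment cache directory instead of the shared model_cache folder. A minimal sketch of reading that checkpoint back afterwards; the exp_id, model and dataset values are placeholders, and the commented lines assume a model and optimizer built the same way as during tuning:

import torch

# Placeholder identifiers; in a real run they come from the experiment's config.
exp_id, model_name, dataset_name = 12345, 'GRU', 'METR_LA'
model_cache_file = './libcity/cache/{}/model_cache/{}_{}.m'.format(exp_id, model_name, dataset_name)

# The file stores a (model_state, optimizer_state) tuple, mirroring the torch.save call above.
model_state, optimizer_state = torch.load(model_cache_file)
# model.load_state_dict(model_state)          # restore the best trial's weights
# optimizer.load_state_dict(optimizer_state)  # restore optimizer state, e.g. to resume training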
17 changes: 12 additions & 5 deletions run_hyper.py
@@ -1,9 +1,9 @@
"""
-Hyper-parameter tuning script for a single model
+Model hyper-parameter tuning script (based on hyperopt)
"""

import argparse

+import random
from libcity.pipeline import objective_function
from libcity.executor import HyperTuning
from libcity.utils import str2bool, get_logger, set_random_seed, add_general_args
@@ -26,7 +26,7 @@
help='whether re-train model if the model is \
trained before')
parser.add_argument('--params_file', type=str,
-default=None, help='the file which specify the \
+default='hyper_example.txt', help='the file which specify the \
hyper-parameters and ranges to be adjusted')
parser.add_argument('--hyper_algo', type=str,
default='grid_search', help='hyper-parameters search algorithm')
@@ -43,11 +43,18 @@
other_args = {key: val for key, val in dict_args.items() if key not in [
'task', 'model', 'dataset', 'config_file', 'saved_model', 'train',
'params_file', 'hyper_algo'] and val is not None}

-logger = get_logger({'model': args.model, 'dataset': args.dataset})
+# exp_id
+exp_id = dict_args.get('exp_id', None)
+if exp_id is None:
+# Make a new experiment ID
+exp_id = int(random.SystemRandom().random() * 100000)
+other_args['exp_id'] = exp_id
+# logger
+logger = get_logger({'model': args.model, 'dataset': args.dataset, 'exp_id': exp_id})
# seed
seed = dict_args.get('seed', 0)
set_random_seed(seed)
other_args['seed'] = seed
hp = HyperTuning(objective_function, params_file=args.params_file, algo=args.hyper_algo,
max_evals=args.max_evals, task=args.task, model_name=args.model,
dataset_name=args.dataset, config_file=args.config_file,
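run_hyper.py, hyper_tune.py and the pipeline now all stamp a run with an exp_id, so each run's logs and cached models land under their own ./libcity/cache/<exp_id>/ tree. A small sketch of the pattern; the rationale for SystemRandom is an inference, not something stated in the commit:

import random

# SystemRandom draws from the OS entropy pool, so it is unaffected by
# set_random_seed(seed) fixing Python's global generator for reproducibility;
# two runs launched with the same seed still get distinct experiment IDs.
exp_id = int(random.SystemRandom().random() * 100000)   # an integer in [0, 100000)
cache_dir = './libcity/cache/{}/model_cache'.format(exp_id)
print(exp_id, cache_dir)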
35 changes: 24 additions & 11 deletions test_model.py
@@ -1,14 +1,26 @@
from libcity.config import ConfigParser
from libcity.data import get_dataset
-from libcity.utils import get_model, get_executor
+from libcity.utils import get_model, get_executor, get_logger, set_random_seed
+import random

"""
取一个batch的数据进行初步测试
+Take the data of a batch for preliminary testing
"""

# load the config file
-config = ConfigParser(task='traj_loc_pred', model='TemplateTLP',
-dataset='foursquare_tky', config_file=None,
-other_args={'batch_size': 2})
-# For traffic flow/speed prediction tasks, use the config-loading statement below instead
-# config = ConfigParser(task='traffic_state_pred', model='TemplateTSP',
-# dataset='METR_LA', config_file=None, other_args={'batch_size': 2})
+config = ConfigParser(task='traffic_state_pred', model='RNN',
+dataset='METR_LA', other_args={'batch_size': 2})
+exp_id = config.get('exp_id', None)
+if exp_id is None:
+exp_id = int(random.SystemRandom().random() * 100000)
+config['exp_id'] = exp_id
+# logger
+logger = get_logger(config)
+logger.info(config.config)
+# seed
+seed = config.get('seed', 0)
+set_random_seed(seed)
# load the data module
dataset = get_dataset(config)
# preprocess the data and split the dataset
@@ -18,10 +30,11 @@
batch = train_data.__iter__().__next__()
# load the model
model = get_model(config, data_feature)
-self = model.to(config['device'])
+model = model.to(config['device'])
+# load the executor
+executor = get_executor(config, model, data_feature)
# model prediction
batch.to_tensor(config['device'])
res = model.predict(batch)
-# please check for yourself that the shape of res meets the constraints of the competition track
-# if you want to load the executor
-executor = get_executor(config, model)
+logger.info('Result shape is {}'.format(res.shape))
+logger.info('Success test the model!')
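test_model.py now pushes a single batch through the model and logs the shape of the prediction. If a stricter smoke test is wanted, a possible follow-up check is sketched below; the 4-D (batch_size, output_window, num_nodes, output_dim) layout is the usual LibCity convention for traffic-state models, but it is an assumption here, not part of this commit:

# Optional extra assertions after res = model.predict(batch).
assert res.dim() == 4, 'expected (batch_size, output_window, num_nodes, output_dim)'
assert res.shape[0] == config['batch_size']  # batch_size was forced to 2 above
logger.info('Result dtype is {}'.format(res.dtype))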
6 changes: 3 additions & 3 deletions unit_test.py
@@ -5,10 +5,10 @@

#############################################
# The parameter to control the unit testing #
-tested_trajectory_model = 'TemplateTLP'
-tested_trajectory_dataset = 'foursquare_tky'
+tested_trajectory_model = 'RNN'
+tested_trajectory_dataset = 'foursquare_nyc'
tested_trajectory_encoder = 'StandardTrajectoryEncoder'
-tested_traffic_state_model = 'DCRNN'
+tested_traffic_state_model = 'RNN'
tested_traffic_state_dataset = 'METR_LA'
#############################################

