run.py
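"""Train or evaluate the DQN trading agent.

Example invocations (the weights filename below is illustrative; in test mode,
pass a weights file produced by a previous training run):

    python run.py --mode train
    python run.py --mode test --weights weights/201901010000-dqn.h5
"""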
import argparse
import pickle
import re
import time

import numpy as np

from envs import TradingEnv
from agent import DQNAgent
from utils import get_data, get_scaler, maybe_make_dir

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--episode', type=int, default=2000,
                        help='number of episodes to run')
    parser.add_argument('-b', '--batch_size', type=int, default=32,
                        help='batch size for experience replay')
    parser.add_argument('-i', '--initial_invest', type=int, default=20000,
                        help='initial investment amount')
    parser.add_argument('-m', '--mode', type=str, required=True,
                        help='either "train" or "test"')
    parser.add_argument('-w', '--weights', type=str,
                        help='path to trained model weights (required for "test" mode)')
    args = parser.parse_args()

    maybe_make_dir('weights')
    maybe_make_dir('portfolio_val')

    timestamp = time.strftime('%Y%m%d%H%M')

    # load price data and split it into train / test segments along the time axis
    data = np.around(get_data())
    train_data = data[:, :3526]
    test_data = data[:, 3526:]

    # set up the environment, the DQN agent, and a scaler fitted on the training env
    env = TradingEnv(train_data, args.initial_invest)
    state_size = env.observation_space.shape
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    scaler = get_scaler(env)

    portfolio_value = []

    if args.mode == 'test':
        # remake the env with test data
        env = TradingEnv(test_data, args.initial_invest)
        # load trained weights
        agent.load(args.weights)
        # in test mode, reuse the timestamp embedded in the weights filename,
        # i.e. the time at which the weights were trained
        timestamp = re.findall(r'\d{12}', args.weights)[0]

    for e in range(args.episode):
        state = env.reset()
        state = scaler.transform([state])
        # `step` (rather than `time`) avoids shadowing the `time` module imported above
        for step in range(env.n_step):
            action = agent.act(state)
            next_state, reward, done, info = env.step(action)
            next_state = scaler.transform([next_state])
            if args.mode == 'train':
                agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                print("episode: {}/{}, episode end value: {}".format(
                    e + 1, args.episode, info['cur_val']))
                portfolio_value.append(info['cur_val'])  # append episode end portfolio value
                break
            if args.mode == 'train' and len(agent.memory) > args.batch_size:
                agent.replay(args.batch_size)
        if args.mode == 'train' and (e + 1) % 10 == 0:  # checkpoint weights every 10 episodes
            agent.save('weights/{}-dqn.h5'.format(timestamp))

    # save portfolio value history to disk
    with open('portfolio_val/{}-{}.p'.format(timestamp, args.mode), 'wb') as fp:
        pickle.dump(portfolio_value, fp)