# sb3_DQN_train.py: DQN training on the 2048 environment with Stable-Baselines3 (3D CNN feature extractor)
import gym
import torch as th
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
from gym_2048 import Gym2048Env
from stable_baselines3 import PPO, DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecNormalize, DummyVecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
from approximator import MapCNN, Map3DCNN
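#%% Reference: custom feature extractor pattern.
# Map3DCNN / MapCNN are defined in approximator.py and not shown here. The class
# below is a hypothetical minimal sketch of the standard SB3 BaseFeaturesExtractor
# pattern they follow, assuming a channels-first Box observation (e.g. one-hot
# planes over the 4x4 board); it is NOT the actual Map3DCNN implementation.
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class TinyBoardCNN(BaseFeaturesExtractor):
    def __init__(self, observation_space, features_dim=256):
        super().__init__(observation_space, features_dim)
        n_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_channels, 64, kernel_size=2), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=2), nn.ReLU(),
            nn.Flatten(),
        )
        # Infer the flattened size from a dummy forward pass.
        with th.no_grad():
            sample = th.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(sample).shape[1]
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations):
        return self.linear(self.cnn(observations))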
#%% DQN training with 3D CNN
venv = make_vec_env(Gym2048Env, n_envs=4, env_kwargs=dict(obstype="tensor"))
venv_norm = VecNormalize(venv, training=True, norm_obs=False, norm_reward=True,
                         clip_obs=10.0, clip_reward=1500, gamma=0.99, epsilon=1e-08)
policy_kwargs = dict(
    features_extractor_class=Map3DCNN,
    features_extractor_kwargs=dict(features_dim=256),
)
model = DQN("CnnPolicy", venv_norm, policy_kwargs=policy_kwargs, verbose=1, batch_size=64,
            tensorboard_log=r"E:\Github_Projects\2048\results\DQN_3dCNN_rew_norm_clip1500")
#%%
model.learn(10000000, tb_log_name="DQN", log_interval=1, reset_num_timesteps=False)
#%%
venv_norm.clip_reward = 2500  # raise the reward clipping bound before continuing training
model.learn(10000000, tb_log_name="DQN", log_interval=1, reset_num_timesteps=False)
#%%
model.save(r"E:\Github_Projects\2048\results\DQN_3dCNN_rew_norm_clip1500\DQN_0\DQN_18M_rew_norm_clip2500")
venv_norm.save(r"E:\Github_Projects\2048\results\DQN_3dCNN_rew_norm_clip1500\DQN_0\DQN_18M_rew_norm_clip2500_vecnorm.pkl")
#%%
model.save(r"E:\Github_Projects\2048\results\DQN_3dCNN_rew_norm_clip1500\DQN_0\DQN_08M_rew_norm_clip1500")
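#%% (Optional) Resume training from a saved checkpoint.
# Hedged sketch: reload the model together with the saved VecNormalize statistics
# so reward normalization continues from the same running averages. Paths reuse
# the save calls above. Note that model.save() does not store the DQN replay
# buffer; use model.save_replay_buffer()/load_replay_buffer() if that matters.
venv_resume = make_vec_env(Gym2048Env, n_envs=4, env_kwargs=dict(obstype="tensor"))
venv_norm_resume = VecNormalize.load(
    r"E:\Github_Projects\2048\results\DQN_3dCNN_rew_norm_clip1500\DQN_0\DQN_18M_rew_norm_clip2500_vecnorm.pkl",
    venv_resume)
model = DQN.load(
    r"E:\Github_Projects\2048\results\DQN_3dCNN_rew_norm_clip1500\DQN_0\DQN_18M_rew_norm_clip2500",
    env=venv_norm_resume)
model.learn(1000000, tb_log_name="DQN", log_interval=1, reset_num_timesteps=False)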
#%% Second run: larger batch size, larger buffer
venv = make_vec_env(Gym2048Env, n_envs=4, env_kwargs=dict(obstype="tensor"))
venv_norm = VecNormalize(venv, training=True, norm_obs=False, norm_reward=True,
                         clip_obs=10.0, clip_reward=2000, gamma=0.99, epsilon=1e-08)
policy_kwargs = dict(
    features_extractor_class=Map3DCNN,
    features_extractor_kwargs=dict(features_dim=256),
)
model = DQN("CnnPolicy", venv_norm, policy_kwargs=policy_kwargs, verbose=1, batch_size=256,
            buffer_size=int(5e6), exploration_final_eps=0.01, exploration_fraction=0.1,
            tensorboard_log=r"E:\Github_Projects\2048\results\DQN_3dCNN_rew_norm_5mBuffer_batch256")
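#%% (Optional) Periodic checkpoints during the long run.
# Hedged sketch: CheckpointCallback saves the model every save_freq calls (one call
# per vec-env step, so 250_000 calls is roughly 1M timesteps with 4 envs). The
# "checkpoints" subdirectory is an assumed location. Pass callback=checkpoint_cb
# to model.learn() below to enable it.
from stable_baselines3.common.callbacks import CheckpointCallback
checkpoint_cb = CheckpointCallback(
    save_freq=250_000,
    save_path=r"E:\Github_Projects\2048\results\DQN_3dCNN_rew_norm_5mBuffer_batch256\checkpoints",
    name_prefix="DQN_ckpt")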
#%%
model.learn(20000000, tb_log_name="DQN", log_interval=1, reset_num_timesteps=False)
#%% Fine-tune with a larger batch and a lower learning rate
model.batch_size = 768
model.learning_rate = 1E-5
model._setup_lr_schedule()  # rebuild lr_schedule; assigning model.learning_rate alone does not change it
model.learn(5000000, tb_log_name="DQN", log_interval=1, reset_num_timesteps=False)
#%%
model.save(r"E:\Github_Projects\2048\results\DQN_3dCNN_rew_norm_5mBuffer_batch256\DQN_0\DQN_26M_rew_norm_clip2000")
venv_norm.save(r"E:\Github_Projects\2048\results\DQN_3dCNN_rew_norm_5mBuffer_batch256\DQN_0\DQN_26M_rew_norm_clip2000_vecnorm.pkl")
#%% Evaluation
# Evaluate on the raw vec env (unnormalized rewards); observations match training since norm_obs=False.
eps_rew, eps_len = evaluate_policy(model, venv, n_eval_episodes=1000, render=False, return_episode_rewards=True)
print(f"Episode reward {np.mean(eps_rew)}+-{np.std(eps_rew)}")
print(f"Episode length {np.mean(eps_len)}+-{np.std(eps_len)}")
#%%
np.savez(r"E:\Github_Projects\2048\exp_data\DQN2_3dCNN_26M_scores.npz", eps_rew=eps_rew, eps_len=eps_len)
#%% Visualization
plt.hist(eps_rew, bins=65)
plt.title(f"DQN 3D CNN 26M step reward {np.mean(eps_rew):.2f}+-{np.std(eps_rew):.2f}")
plt.xlabel("Episode Reward")
plt.savefig("DQN3_eps_reward_hist.png")
plt.show()
#%%
plt.hist(eps_len, bins=65)
plt.title(f"DQN 3D CNN 26M step episode len {np.mean(eps_len):.2f}+-{np.std(eps_len):.2f}")
plt.xlabel("Episode Length")
plt.savefig("DQN3_eps_len_hist.png")
plt.show()
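#%% (Optional) Watch the trained agent play.
# Hedged sketch: greedy rollout on the raw vec env. Assumes Gym2048Env implements
# render(); drop the render() call if it does not.
obs = venv.reset()
for _ in range(500):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = venv.step(action)
    venv.render()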