forked from kuun1000/RL_robot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TrainEnvPandaPP.py
79 lines (62 loc) · 2.45 KB
/
TrainEnvPandaPP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from AmbienteRobot import RobotEnv
import pybullet as p
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback, CallbackList
from stable_baselines3.common.monitor import Monitor
def create_train_env():
env = RobotEnv(is_for_training=True)
return Monitor(env) # Avvolgi l'ambiente con Monitor
def create_test_env():
env = RobotEnv(is_for_training=True)
return Monitor(env) # Avvolgi l'ambiente con Monitor
if __name__ == '__main__':
# Definisce la cartella per tensorboard
log_dir = "./tensorboard_logs/"
# Numero di ambienti in parallelo
num_envs = 4
# Parallelizzazione dell'ambiente di addestramento
train_env = SubprocVecEnv([create_train_env for _ in range(num_envs)])
print("Training environment initialized and reset successfully")
train_env.reset()
# Definizione del modello DQN con parametri ottimizzati
model = DQN(
'CnnPolicy',
train_env,
learning_rate=5e-4,
buffer_size=200000,
learning_starts=7500, # Aumentato per dare più tempo all'esplorazione iniziale
batch_size=256,
gamma=0.99,
exploration_fraction=0.2,
exploration_final_eps=0.05,
train_freq=4, # Da rivedere
target_update_interval=5000,
verbose=1,
tensorboard_log=log_dir # Logging per TensorBoard
)
# Configurazione delle callback per i checkpoint e la valutazione
checkpoint_callback = CheckpointCallback(
save_freq=1000,
save_path='./checkpoints/',
name_prefix='dqn_model'
)
eval_callback = EvalCallback(
DummyVecEnv([create_test_env]),
best_model_save_path='./logs/',
log_path='./logs/',
eval_freq=1000,
n_eval_episodes=10,
deterministic=True,
render=False
)
# Lista dei callback da passare al modello
callback = CallbackList([checkpoint_callback, eval_callback])
# Addestramento del modello con i parametri ottimizzati
model.learn(total_timesteps=40000, callback=callback) # Aumentato il numero di timestep
model.save("dqn_Z_6")
# Disconnette PyBullet dopo l'addestramento
if p.isConnected():
p.disconnect()
# Chiudi l'ambiente di addestramento
train_env.close()