From 582a8526f20c5b6b05d337958c26218f296db9e5 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 2 Oct 2018 11:09:36 +0200 Subject: [PATCH 1/2] Fix framestack for run atari acer --- stable_baselines/acer/run_atari.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stable_baselines/acer/run_atari.py b/stable_baselines/acer/run_atari.py index 4d86a77cfe..f653132e89 100644 --- a/stable_baselines/acer/run_atari.py +++ b/stable_baselines/acer/run_atari.py @@ -4,6 +4,7 @@ from stable_baselines.acer import ACER from stable_baselines.common.policies import CnnPolicy, CnnLstmPolicy from stable_baselines.common.cmd_util import make_atari_env, atari_arg_parser +from stable_baselines.common.vec_env import VecFrameStack def train(env_id, num_timesteps, seed, policy, lr_schedule, num_cpu): @@ -18,7 +19,7 @@ def train(env_id, num_timesteps, seed, policy, lr_schedule, num_cpu): 'double_linear_con', 'middle_drop' or 'double_middle_drop') :param num_cpu: (int) The number of cpu to train on """ - env = make_atari_env(env_id, num_cpu, seed) + env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4) if policy == 'cnn': policy_fn = CnnPolicy elif policy == 'lstm': From 8c43ec392858a0ab49f3a8fc794aaede8ed5ca6b Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 2 Oct 2018 11:14:18 +0200 Subject: [PATCH 2/2] Hotfix param --- stable_baselines/acer/run_atari.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines/acer/run_atari.py b/stable_baselines/acer/run_atari.py index f653132e89..92ba6c472b 100644 --- a/stable_baselines/acer/run_atari.py +++ b/stable_baselines/acer/run_atari.py @@ -19,7 +19,7 @@ def train(env_id, num_timesteps, seed, policy, lr_schedule, num_cpu): 'double_linear_con', 'middle_drop' or 'double_middle_drop') :param num_cpu: (int) The number of cpu to train on """ - env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4) + env = VecFrameStack(make_atari_env(env_id, num_cpu, seed), 4) if policy == 'cnn': policy_fn = CnnPolicy elif policy == 'lstm':