diff --git a/README.md b/README.md
index e503556f3..16562e429 100644
--- a/README.md
+++ b/README.md
@@ -13,10 +13,16 @@ Training an agent to solve an environment is as easy as running:
python3 coach.py -p CartPole_DQN -r
```
-
+
Blog post from the Intel® Nervana™ website can be found [here](https://www.intelnervana.com/reinforcement-learning-coach-intel).
+
+## Documentation
+
+Framework documentation, algorithm description and instructions on how to contribute a new agent/environment can be found [here](http://coach.nervanasys.com).
+
+
## Installation
Note: Coach has only been tested on Ubuntu 16.04 LTS, and with Python 3.5.
@@ -103,6 +109,8 @@ For example:
It is easy to create new presets for different levels or environments by following the same pattern as in presets.py
+More usage examples can be found [here](http://coach.nervanasys.com/usage/index.html).
+
## Running Coach Dashboard (Visualization)
Training an agent to solve an environment can be tricky, at times.
@@ -121,11 +129,6 @@ python3 dashboard.py
-## Documentation
-
-Framework documentation, algoritmic description and instructions on how to contribute a new agent/environment can be found [here](http://coach.nervanasys.com).
-
-
## Parallelizing an Algorithm
Since the introduction of [A3C](https://arxiv.org/abs/1602.01783) in 2016, many algorithms were shown to benefit from running multiple instances in parallel, on many CPU cores. So far, these algorithms include [A3C](https://arxiv.org/abs/1602.01783), [DDPG](https://arxiv.org/pdf/1704.03073.pdf), [PPO](https://arxiv.org/pdf/1707.06347.pdf), and [NAF](https://arxiv.org/pdf/1610.00633.pdf), and this is most probably only the begining.
@@ -150,11 +153,11 @@ python3 coach.py -p Hopper_A3C -n 16
## Supported Environments
-* OpenAI Gym
+* *OpenAI Gym:*
Installed by default by Coach's installer.
-* ViZDoom:
+* *ViZDoom:*
Follow the instructions described in the ViZDoom repository -
@@ -162,13 +165,13 @@ python3 coach.py -p Hopper_A3C -n 16
Additionally, Coach assumes that the environment variable VIZDOOM_ROOT points to the ViZDoom installation directory.
-* Roboschool:
+* *Roboschool:*
Follow the instructions described in the roboschool repository -
https://github.com/openai/roboschool
-* GymExtensions:
+* *GymExtensions:*
Follow the instructions described in the GymExtensions repository -
@@ -176,10 +179,19 @@ python3 coach.py -p Hopper_A3C -n 16
Additionally, add the installation directory to the PYTHONPATH environment variable.
-* PyBullet
+* *PyBullet:*
Follow the instructions described in the [Quick Start Guide](https://docs.google.com/document/d/10sXEhzFRSnvFcl3XxNGhnD4N2SedqwdAvK3dsihxVUA) (basically just - 'pip install pybullet')
+* *CARLA:*
+
+ Download release 0.7 from the CARLA repository -
+
+ https://github.com/carla-simulator/carla/releases
+
+ Create a new CARLA_ROOT environment variable pointing to CARLA's installation directory.
+
+ A simple CARLA settings file (```CarlaSettings.ini```) is supplied with Coach, and is located in the ```environments``` directory.
## Supported Algorithms
@@ -190,24 +202,24 @@ python3 coach.py -p Hopper_A3C -n 16
-* [Deep Q Network (DQN)](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf)
-* [Double Deep Q Network (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
+* [Deep Q Network (DQN)](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) ([code](agents/dqn_agent.py))
+* [Double Deep Q Network (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) ([code](agents/ddqn_agent.py))
* [Dueling Q Network](https://arxiv.org/abs/1511.06581)
-* [Mixed Monte Carlo (MMC)](https://arxiv.org/abs/1703.01310)
-* [Persistent Advantage Learning (PAL)](https://arxiv.org/abs/1512.04860)
-* [Categorical Deep Q Network (C51)](https://arxiv.org/abs/1707.06887)
-* [Quantile Regression Deep Q Network (QR-DQN)](https://arxiv.org/pdf/1710.10044v1.pdf)
-* [Bootstrapped Deep Q Network](https://arxiv.org/abs/1602.04621)
-* [N-Step Q Learning](https://arxiv.org/abs/1602.01783) | **Distributed**
-* [Neural Episodic Control (NEC)](https://arxiv.org/abs/1703.01988)
-* [Normalized Advantage Functions (NAF)](https://arxiv.org/abs/1603.00748.pdf) | **Distributed**
-* [Policy Gradients (PG)](http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf) | **Distributed**
-* [Asynchronous Advantage Actor-Critic (A3C)](https://arxiv.org/abs/1602.01783) | **Distributed**
-* [Deep Deterministic Policy Gradients (DDPG)](https://arxiv.org/abs/1509.02971) | **Distributed**
-* [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
-* [Clipped Proximal Policy Optimization](https://arxiv.org/pdf/1707.06347.pdf) | **Distributed**
-* [Direct Future Prediction (DFP)](https://arxiv.org/abs/1611.01779) | **Distributed**
-
+* [Mixed Monte Carlo (MMC)](https://arxiv.org/abs/1703.01310) ([code](agents/mmc_agent.py))
+* [Persistent Advantage Learning (PAL)](https://arxiv.org/abs/1512.04860) ([code](agents/pal_agent.py))
+* [Categorical Deep Q Network (C51)](https://arxiv.org/abs/1707.06887) ([code](agents/categorical_dqn_agent.py))
+* [Quantile Regression Deep Q Network (QR-DQN)](https://arxiv.org/pdf/1710.10044v1.pdf) ([code](agents/qr_dqn_agent.py))
+* [Bootstrapped Deep Q Network](https://arxiv.org/abs/1602.04621) ([code](agents/bootstrapped_dqn_agent.py))
+* [N-Step Q Learning](https://arxiv.org/abs/1602.01783) | **Distributed** ([code](agents/n_step_q_agent.py))
+* [Neural Episodic Control (NEC)](https://arxiv.org/abs/1703.01988) ([code](agents/nec_agent.py))
+* [Normalized Advantage Functions (NAF)](https://arxiv.org/abs/1603.00748.pdf) | **Distributed** ([code](agents/naf_agent.py))
+* [Policy Gradients (PG)](http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf) | **Distributed** ([code](agents/policy_gradients_agent.py))
+* [Asynchronous Advantage Actor-Critic (A3C)](https://arxiv.org/abs/1602.01783) | **Distributed** ([code](agents/actor_critic_agent.py))
+* [Deep Deterministic Policy Gradients (DDPG)](https://arxiv.org/abs/1509.02971) | **Distributed** ([code](agents/ddpg_agent.py))
+* [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf) ([code](agents/ppo_agent.py))
+* [Clipped Proximal Policy Optimization](https://arxiv.org/pdf/1707.06347.pdf) | **Distributed** ([code](agents/clipped_ppo_agent.py))
+* [Direct Future Prediction (DFP)](https://arxiv.org/abs/1611.01779) | **Distributed** ([code](agents/dfp_agent.py))
+* Behavioral Cloning (BC) ([code](agents/bc_agent.py))
diff --git a/agents/__init__.py b/agents/__init__.py
index b1ae8d324..fdbd13e51 100644
--- a/agents/__init__.py
+++ b/agents/__init__.py
@@ -16,6 +16,7 @@
from agents.actor_critic_agent import *
from agents.agent import *
+from agents.bc_agent import *
from agents.bootstrapped_dqn_agent import *
from agents.clipped_ppo_agent import *
from agents.ddpg_agent import *
@@ -23,6 +24,8 @@
from agents.dfp_agent import *
from agents.dqn_agent import *
from agents.categorical_dqn_agent import *
+from agents.human_agent import *
+from agents.imitation_agent import *
from agents.mmc_agent import *
from agents.n_step_q_agent import *
from agents.naf_agent import *
diff --git a/agents/agent.py b/agents/agent.py
index ed9eabce6..a541fa5f4 100644
--- a/agents/agent.py
+++ b/agents/agent.py
@@ -50,6 +50,7 @@ def __init__(self, env, tuning_parameters, replicated_device=None, task_id=0):
self.task_id = task_id
self.sess = tuning_parameters.sess
self.env = tuning_parameters.env_instance = env
+ self.imitation = False
# i/o dimensions
if not tuning_parameters.env.desired_observation_width or not tuning_parameters.env.desired_observation_height:
@@ -61,7 +62,12 @@ def __init__(self, env, tuning_parameters, replicated_device=None, task_id=0):
self.measurements_size = tuning_parameters.env.measurements_size = (self.measurements_size[0] + 1,)
# modules
- self.memory = eval(tuning_parameters.memory + '(tuning_parameters)')
+ if tuning_parameters.agent.load_memory_from_file_path:
+ screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
+ .format(tuning_parameters.agent.load_memory_from_file_path))
+ self.memory = read_pickle(tuning_parameters.agent.load_memory_from_file_path)
+ else:
+ self.memory = eval(tuning_parameters.memory + '(tuning_parameters)')
# self.architecture = eval(tuning_parameters.architecture)
self.has_global = replicated_device is not None
@@ -121,11 +127,12 @@ def __init__(self, env, tuning_parameters, replicated_device=None, task_id=0):
def log_to_screen(self, phase):
# log to screen
- if self.current_episode > 0:
- if phase == RunPhase.TEST:
- exploration = self.evaluation_exploration_policy.get_control_param()
- else:
+ if self.current_episode >= 0:
+ if phase == RunPhase.TRAIN:
exploration = self.exploration_policy.get_control_param()
+ else:
+ exploration = self.evaluation_exploration_policy.get_control_param()
+
screen.log_dict(
OrderedDict([
("Worker", self.task_id),
@@ -135,7 +142,7 @@ def log_to_screen(self, phase):
("steps", self.total_steps_counter),
("training iteration", self.training_iteration)
]),
- prefix="Heatup" if self.in_heatup else "Training" if phase == RunPhase.TRAIN else "Testing"
+ prefix=phase
)
def update_log(self, phase=RunPhase.TRAIN):
@@ -146,7 +153,7 @@ def update_log(self, phase=RunPhase.TRAIN):
# log all the signals to file
logger.set_current_time(self.current_episode)
logger.create_signal_value('Training Iter', self.training_iteration)
- logger.create_signal_value('In Heatup', int(self.in_heatup))
+ logger.create_signal_value('In Heatup', int(phase == RunPhase.HEATUP))
logger.create_signal_value('ER #Transitions', self.memory.num_transitions())
logger.create_signal_value('ER #Episodes', self.memory.length())
logger.create_signal_value('Episode Length', self.current_episode_steps_counter)
@@ -197,24 +204,6 @@ def reset_game(self, do_not_reset_env=False):
network.curr_rnn_c_in = network.middleware_embedder.c_init
network.curr_rnn_h_in = network.middleware_embedder.h_init
- def stack_observation(self, curr_stack, observation):
- """
- Adds a new observation to an existing stack of observations from previous time-steps.
- :param curr_stack: The current observations stack.
- :param observation: The new observation
- :return: The updated observation stack
- """
-
- if curr_stack == []:
- # starting an episode
- curr_stack = np.vstack(np.expand_dims([observation] * self.tp.env.observation_stack_size, 0))
- curr_stack = self.switch_axes_order(curr_stack, from_type='channels_first', to_type='channels_last')
- else:
- curr_stack = np.append(curr_stack, np.expand_dims(np.squeeze(observation), axis=-1), axis=-1)
- curr_stack = np.delete(curr_stack, 0, -1)
-
- return curr_stack
-
def preprocess_observation(self, observation):
"""
Preprocesses the given observation.
@@ -335,26 +324,6 @@ def preprocess_reward(self, reward):
reward = max(reward, self.tp.env.reward_clipping_min)
return reward
- def switch_axes_order(self, observation, from_type='channels_first', to_type='channels_last'):
- """
- transpose an observation axes from channels_first to channels_last or vice versa
- :param observation: a numpy array
- :param from_type: can be 'channels_first' or 'channels_last'
- :param to_type: can be 'channels_first' or 'channels_last'
- :return: a new observation with the requested axes order
- """
- if from_type == to_type or len(observation.shape) == 1:
- return observation
- assert 2 <= len(observation.shape) <= 3, 'num axes of an observation must be 2 for a vector or 3 for an image'
- assert type(observation) == np.ndarray, 'observation must be a numpy array'
- if len(observation.shape) == 3:
- if from_type == 'channels_first' and to_type == 'channels_last':
- return np.transpose(observation, (1, 2, 0))
- elif from_type == 'channels_last' and to_type == 'channels_first':
- return np.transpose(observation, (2, 0, 1))
- else:
- return np.transpose(observation, (1, 0))
-
def act(self, phase=RunPhase.TRAIN):
"""
Take one step in the environment according to the network prediction and store the transition in memory
@@ -370,7 +339,7 @@ def act(self, phase=RunPhase.TRAIN):
is_first_transition_in_episode = (self.curr_state == [])
if is_first_transition_in_episode:
observation = self.preprocess_observation(self.env.observation)
- observation = self.stack_observation([], observation)
+ observation = stack_observation([], observation, self.tp.env.observation_stack_size)
self.curr_state = {'observation': observation}
if self.tp.agent.use_measurements:
@@ -378,7 +347,7 @@ def act(self, phase=RunPhase.TRAIN):
if self.tp.agent.use_accumulated_reward_as_measurement:
self.curr_state['measurements'] = np.append(self.curr_state['measurements'], 0)
- if self.in_heatup: # we do not have a stacked curr_state yet
+ if phase == RunPhase.HEATUP and not self.tp.heatup_using_network_decisions:
action = self.env.get_random_action()
else:
action, action_info = self.choose_action(self.curr_state, phase=phase)
@@ -394,11 +363,11 @@ def act(self, phase=RunPhase.TRAIN):
observation = self.preprocess_observation(result['observation'])
# plot action values online
- if self.tp.visualization.plot_action_values_online and not self.in_heatup:
+ if self.tp.visualization.plot_action_values_online and phase != RunPhase.HEATUP:
self.plot_action_values_online()
# initialize the next state
- observation = self.stack_observation(self.curr_state['observation'], observation)
+ observation = stack_observation(self.curr_state['observation'], observation, self.tp.env.observation_stack_size)
next_state = {'observation': observation}
if self.tp.agent.use_measurements and 'measurements' in result.keys():
@@ -407,7 +376,7 @@ def act(self, phase=RunPhase.TRAIN):
next_state['measurements'] = np.append(next_state['measurements'], self.total_reward_in_current_episode)
# store the transition only if we are training
- if phase == RunPhase.TRAIN:
+ if phase == RunPhase.TRAIN or phase == RunPhase.HEATUP:
transition = Transition(self.curr_state, result['action'], shaped_reward, next_state, result['done'])
for key in action_info.keys():
transition.info[key] = action_info[key]
@@ -427,7 +396,7 @@ def act(self, phase=RunPhase.TRAIN):
self.update_log(phase=phase)
self.log_to_screen(phase=phase)
- if phase == RunPhase.TRAIN:
+ if phase == RunPhase.TRAIN or phase == RunPhase.HEATUP:
self.reset_game()
self.current_episode += 1
@@ -462,11 +431,12 @@ def evaluate(self, num_episodes, keep_networks_synced=False):
for network in self.networks:
network.sync()
- if self.tp.visualization.dump_gifs and self.total_reward_in_current_episode > max_reward_achieved:
+ if self.total_reward_in_current_episode > max_reward_achieved:
max_reward_achieved = self.total_reward_in_current_episode
frame_skipping = int(5/self.tp.env.frame_skip)
- logger.create_gif(self.last_episode_images[::frame_skipping],
- name='score-{}'.format(max_reward_achieved), fps=10)
+ if self.tp.visualization.dump_gifs:
+ logger.create_gif(self.last_episode_images[::frame_skipping],
+ name='score-{}'.format(max_reward_achieved), fps=10)
average_evaluation_reward += self.total_reward_in_current_episode
self.reset_game()
@@ -496,7 +466,7 @@ def improve(self):
screen.log_title("Starting heatup {}".format(self.task_id))
num_steps_required_for_one_training_batch = self.tp.batch_size * self.tp.env.observation_stack_size
for step in range(max(self.tp.num_heatup_steps, num_steps_required_for_one_training_batch)):
- self.act()
+ self.act(phase=RunPhase.HEATUP)
# training phase
self.in_heatup = False
@@ -509,7 +479,12 @@ def improve(self):
# evaluate
evaluate_agent = (self.last_episode_evaluation_ran is not self.current_episode) and \
(self.current_episode % self.tp.evaluate_every_x_episodes == 0)
+ evaluate_agent = evaluate_agent or \
+ (self.imitation and self.training_iteration > 0 and
+ self.training_iteration % self.tp.evaluate_every_x_training_iterations == 0)
+
if evaluate_agent:
+ self.env.reset()
self.last_episode_evaluation_ran = self.current_episode
self.evaluate(self.tp.evaluation_episodes)
@@ -522,14 +497,15 @@ def improve(self):
self.save_model(model_snapshots_periods_passed)
# play and record in replay buffer
- if self.tp.agent.step_until_collecting_full_episodes:
- step = 0
- while step < self.tp.agent.num_consecutive_playing_steps or self.memory.get_episode(-1).length() != 0:
- self.act()
- step += 1
- else:
- for step in range(self.tp.agent.num_consecutive_playing_steps):
- self.act()
+ if self.tp.agent.collect_new_data:
+ if self.tp.agent.step_until_collecting_full_episodes:
+ step = 0
+ while step < self.tp.agent.num_consecutive_playing_steps or self.memory.get_episode(-1).length() != 0:
+ self.act()
+ step += 1
+ else:
+ for step in range(self.tp.agent.num_consecutive_playing_steps):
+ self.act()
# train
if self.tp.train:
@@ -537,6 +513,8 @@ def improve(self):
loss = self.train()
self.loss.add_sample(loss)
self.training_iteration += 1
+ if self.imitation:
+ self.log_to_screen(RunPhase.TRAIN)
self.post_training_commands()
def save_model(self, model_id):
diff --git a/agents/bc_agent.py b/agents/bc_agent.py
new file mode 100644
index 000000000..e06573116
--- /dev/null
+++ b/agents/bc_agent.py
@@ -0,0 +1,40 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.imitation_agent import *
+
+
+# Behavioral Cloning Agent
+class BCAgent(ImitationAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ ImitationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+
+ def learn_from_batch(self, batch):
+ current_states, _, actions, _, _, _ = self.extract_batch(batch)
+
+ # create the inputs for the network
+ input = current_states
+
+ # the targets for the network are the actions since this is supervised learning
+ if self.env.discrete_controls:
+ targets = np.eye(self.env.action_space_size)[[actions]]
+ else:
+ targets = actions
+
+ result = self.main_network.train_and_sync_networks(input, targets)
+ total_loss = result[0]
+
+ return total_loss
diff --git a/agents/distributional_dqn_agent.py b/agents/distributional_dqn_agent.py
new file mode 100644
index 000000000..d7c00880c
--- /dev/null
+++ b/agents/distributional_dqn_agent.py
@@ -0,0 +1,60 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.value_optimization_agent import *
+
+
+# Distributional Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
+class DistributionalDQNAgent(ValueOptimizationAgent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+ self.z_values = np.linspace(self.tp.agent.v_min, self.tp.agent.v_max, self.tp.agent.atoms)
+
+ # prediction's format is (batch,actions,atoms)
+ def get_q_values(self, prediction):
+ return np.dot(prediction, self.z_values)
+
+ def learn_from_batch(self, batch):
+ current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)
+
+ # for the action we actually took, the error is calculated by the atoms distribution
+ # for all other actions, the error is 0
+ distributed_q_st_plus_1 = self.main_network.target_network.predict(next_states)
+ # initialize with the current prediction so that we will
+ TD_targets = self.main_network.online_network.predict(current_states)
+
+ # only update the action that we have actually done in this transition
+ target_actions = np.argmax(self.get_q_values(distributed_q_st_plus_1), axis=1)
+ m = np.zeros((self.tp.batch_size, self.z_values.size))
+
+ batches = np.arange(self.tp.batch_size)
+ for j in range(self.z_values.size):
+ tzj = np.fmax(np.fmin(rewards + (1.0 - game_overs) * self.tp.agent.discount * self.z_values[j],
+ self.z_values[self.z_values.size - 1]),
+ self.z_values[0])
+ bj = (tzj - self.z_values[0])/(self.z_values[1] - self.z_values[0])
+ u = (np.ceil(bj)).astype(int)
+ l = (np.floor(bj)).astype(int)
+ m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj))
+ m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l))
+ # total_loss = cross entropy between actual result above and predicted result for the given action
+ TD_targets[batches, actions] = m
+
+ result = self.main_network.train_and_sync_networks(current_states, TD_targets)
+ total_loss = result[0]
+
+ return total_loss
+
diff --git a/agents/human_agent.py b/agents/human_agent.py
new file mode 100644
index 000000000..c75c2a2a7
--- /dev/null
+++ b/agents/human_agent.py
@@ -0,0 +1,67 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.agent import *
+import pygame
+
+
+class HumanAgent(Agent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ Agent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+
+ self.clock = pygame.time.Clock()
+ self.max_fps = int(self.tp.visualization.max_fps_for_human_control)
+
+ screen.log_title("Human Control Mode")
+ available_keys = self.env.get_available_keys()
+ if available_keys:
+ screen.log("Use keyboard keys to move. Press escape to quit. Available keys:")
+ screen.log("")
+ for action, key in self.env.get_available_keys():
+ screen.log("\t- {}: {}".format(action, key))
+ screen.separator()
+
+ def train(self):
+ return 0
+
+ def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+ action = self.env.get_action_from_user()
+
+ # keep constant fps
+ self.clock.tick(self.max_fps)
+
+ if not self.env.renderer.is_open:
+ self.save_replay_buffer_and_exit()
+
+ return action, {"action_value": 0}
+
+ def save_replay_buffer_and_exit(self):
+ replay_buffer_path = os.path.join(logger.experiments_path, 'replay_buffer.p')
+ self.memory.tp = None
+ to_pickle(self.memory, replay_buffer_path)
+ screen.log_title("Replay buffer was stored in {}".format(replay_buffer_path))
+ exit()
+
+ def log_to_screen(self, phase):
+ # log to screen
+ screen.log_dict(
+ OrderedDict([
+ ("Episode", self.current_episode),
+ ("total reward", self.total_reward_in_current_episode),
+ ("steps", self.total_steps_counter)
+ ]),
+ prefix="Recording"
+ )
diff --git a/agents/imitation_agent.py b/agents/imitation_agent.py
new file mode 100644
index 000000000..f7f5e0695
--- /dev/null
+++ b/agents/imitation_agent.py
@@ -0,0 +1,70 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.agent import *
+
+
+# Imitation Agent
+class ImitationAgent(Agent):
+ def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+ Agent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+ self.main_network = NetworkWrapper(tuning_parameters, False, self.has_global, 'main',
+ self.replicated_device, self.worker_device)
+ self.networks.append(self.main_network)
+ self.imitation = True
+
+ def extract_action_values(self, prediction):
+ return prediction.squeeze()
+
+ def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+ # convert to batch so we can run it through the network
+ observation = np.expand_dims(np.array(curr_state['observation']), 0)
+ if self.tp.agent.use_measurements:
+ measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
+ prediction = self.main_network.online_network.predict([observation, measurements])
+ else:
+ prediction = self.main_network.online_network.predict(observation)
+
+ # get action values and extract the best action from it
+ action_values = self.extract_action_values(prediction)
+ if self.env.discrete_controls:
+ # DISCRETE
+ # action = np.argmax(action_values)
+ action = self.evaluation_exploration_policy.get_action(action_values)
+ action_value = {"action_probability": action_values[action]}
+ else:
+ # CONTINUOUS
+ action = action_values
+ action_value = {}
+
+ return action, action_value
+
+ def log_to_screen(self, phase):
+ # log to screen
+ if phase == RunPhase.TRAIN:
+ # for the training phase - we log during the episode to visualize the progress in training
+ screen.log_dict(
+ OrderedDict([
+ ("Worker", self.task_id),
+ ("Episode", self.current_episode),
+ ("Loss", self.loss.values[-1]),
+ ("Training iteration", self.training_iteration)
+ ]),
+ prefix="Training"
+ )
+ else:
+ # for the evaluation phase - logging as in regular RL
+ Agent.log_to_screen(self, phase)
diff --git a/agents/n_step_q_agent.py b/agents/n_step_q_agent.py
index 07465230f..3b464a894 100644
--- a/agents/n_step_q_agent.py
+++ b/agents/n_step_q_agent.py
@@ -45,7 +45,7 @@ def learn_from_batch(self, batch):
# 1-Step Q learning
q_st_plus_1 = self.main_network.target_network.predict(next_states)
- for i in reversed(xrange(num_transitions)):
+ for i in reversed(range(num_transitions)):
state_value_head_targets[i][actions[i]] = \
rewards[i] + (1.0 - game_overs[i]) * self.tp.agent.discount * np.max(q_st_plus_1[i], 0)
@@ -56,7 +56,7 @@ def learn_from_batch(self, batch):
else:
R = np.max(self.main_network.target_network.predict(np.expand_dims(next_states[-1], 0)))
- for i in reversed(xrange(num_transitions)):
+ for i in reversed(range(num_transitions)):
R = rewards[i] + self.tp.agent.discount * R
state_value_head_targets[i][actions[i]] = R
diff --git a/agents/policy_optimization_agent.py b/agents/policy_optimization_agent.py
index c64dbab6d..07aac6aff 100644
--- a/agents/policy_optimization_agent.py
+++ b/agents/policy_optimization_agent.py
@@ -58,7 +58,7 @@ def log_to_screen(self, phase):
("steps", self.total_steps_counter),
("training iteration", self.training_iteration)
]),
- prefix="Heatup" if self.in_heatup else "Training" if phase == RunPhase.TRAIN else "Testing"
+ prefix=phase
)
def update_episode_statistics(self, episode):
diff --git a/architectures/network_wrapper.py b/architectures/network_wrapper.py
index a03448505..bbe6c590a 100644
--- a/architectures/network_wrapper.py
+++ b/architectures/network_wrapper.py
@@ -75,11 +75,14 @@ def __init__(self, tuning_parameters, has_target, has_global, name, replicated_d
network_is_local=True)
if not self.tp.distributed and self.tp.framework == Frameworks.TensorFlow:
- self.model_saver = tf.train.Saver()
+ variables_to_restore = tf.global_variables()
+ variables_to_restore = [v for v in variables_to_restore if '/online' in v.name]
+ self.model_saver = tf.train.Saver(variables_to_restore)
if self.tp.sess and self.tp.checkpoint_restore_dir:
checkpoint = tf.train.latest_checkpoint(self.tp.checkpoint_restore_dir)
screen.log_title("Loading checkpoint: {}".format(checkpoint))
self.model_saver.restore(self.tp.sess, checkpoint)
+ self.update_target_network()
def sync(self):
"""
diff --git a/architectures/tensorflow_components/embedders.py b/architectures/tensorflow_components/embedders.py
index 2b3212c72..6b6acd2da 100644
--- a/architectures/tensorflow_components/embedders.py
+++ b/architectures/tensorflow_components/embedders.py
@@ -15,15 +15,18 @@
#
import tensorflow as tf
+from configurations import EmbedderComplexity
class InputEmbedder(object):
- def __init__(self, input_size, activation_function=tf.nn.relu, name="embedder"):
+ def __init__(self, input_size, activation_function=tf.nn.relu,
+ embedder_complexity=EmbedderComplexity.Shallow, name="embedder"):
self.name = name
self.input_size = input_size
self.activation_function = activation_function
self.input = None
self.output = None
+ self.embedder_complexity = embedder_complexity
def __call__(self, prev_input_placeholder=None):
with tf.variable_scope(self.get_name()):
@@ -43,31 +46,77 @@ def get_name(self):
class ImageEmbedder(InputEmbedder):
- def __init__(self, input_size, input_rescaler=255.0, activation_function=tf.nn.relu, name="embedder"):
- InputEmbedder.__init__(self, input_size, activation_function, name)
+ def __init__(self, input_size, input_rescaler=255.0, activation_function=tf.nn.relu,
+ embedder_complexity=EmbedderComplexity.Shallow, name="embedder"):
+ InputEmbedder.__init__(self, input_size, activation_function, embedder_complexity, name)
self.input_rescaler = input_rescaler
def _build_module(self):
# image observation
rescaled_observation_stack = self.input / self.input_rescaler
- self.observation_conv1 = tf.layers.conv2d(rescaled_observation_stack,
- filters=32, kernel_size=(8, 8), strides=(4, 4),
- activation=self.activation_function, data_format='channels_last')
- self.observation_conv2 = tf.layers.conv2d(self.observation_conv1,
- filters=64, kernel_size=(4, 4), strides=(2, 2),
- activation=self.activation_function, data_format='channels_last')
- self.observation_conv3 = tf.layers.conv2d(self.observation_conv2,
- filters=64, kernel_size=(3, 3), strides=(1, 1),
- activation=self.activation_function, data_format='channels_last')
- self.output = tf.contrib.layers.flatten(self.observation_conv3)
+ if self.embedder_complexity == EmbedderComplexity.Shallow:
+ # same embedder as used in the original DQN paper
+ self.observation_conv1 = tf.layers.conv2d(rescaled_observation_stack,
+ filters=32, kernel_size=(8, 8), strides=(4, 4),
+ activation=self.activation_function, data_format='channels_last')
+ self.observation_conv2 = tf.layers.conv2d(self.observation_conv1,
+ filters=64, kernel_size=(4, 4), strides=(2, 2),
+ activation=self.activation_function, data_format='channels_last')
+ self.observation_conv3 = tf.layers.conv2d(self.observation_conv2,
+ filters=64, kernel_size=(3, 3), strides=(1, 1),
+ activation=self.activation_function, data_format='channels_last')
+
+ self.output = tf.contrib.layers.flatten(self.observation_conv3)
+
+ elif self.embedder_complexity == EmbedderComplexity.Deep:
+ # the embedder used in the CARLA papers
+ self.observation_conv1 = tf.layers.conv2d(rescaled_observation_stack,
+ filters=32, kernel_size=(5, 5), strides=(2, 2),
+ activation=self.activation_function, data_format='channels_last')
+ self.observation_conv2 = tf.layers.conv2d(self.observation_conv1,
+ filters=32, kernel_size=(3, 3), strides=(1, 1),
+ activation=self.activation_function, data_format='channels_last')
+ self.observation_conv3 = tf.layers.conv2d(self.observation_conv2,
+ filters=64, kernel_size=(3, 3), strides=(2, 2),
+ activation=self.activation_function, data_format='channels_last')
+ self.observation_conv4 = tf.layers.conv2d(self.observation_conv3,
+ filters=64, kernel_size=(3, 3), strides=(1, 1),
+ activation=self.activation_function, data_format='channels_last')
+ self.observation_conv5 = tf.layers.conv2d(self.observation_conv4,
+ filters=128, kernel_size=(3, 3), strides=(2, 2),
+ activation=self.activation_function, data_format='channels_last')
+ self.observation_conv6 = tf.layers.conv2d(self.observation_conv5,
+ filters=128, kernel_size=(3, 3), strides=(1, 1),
+ activation=self.activation_function, data_format='channels_last')
+ self.observation_conv7 = tf.layers.conv2d(self.observation_conv6,
+ filters=256, kernel_size=(3, 3), strides=(2, 2),
+ activation=self.activation_function, data_format='channels_last')
+ self.observation_conv8 = tf.layers.conv2d(self.observation_conv7,
+ filters=256, kernel_size=(3, 3), strides=(1, 1),
+ activation=self.activation_function, data_format='channels_last')
+
+ self.output = tf.contrib.layers.flatten(self.observation_conv8)
+ else:
+ raise ValueError("The defined embedder complexity value is invalid")
class VectorEmbedder(InputEmbedder):
- def __init__(self, input_size, activation_function=tf.nn.relu, name="embedder"):
- InputEmbedder.__init__(self, input_size, activation_function, name)
+ def __init__(self, input_size, activation_function=tf.nn.relu,
+ embedder_complexity=EmbedderComplexity.Shallow, name="embedder"):
+ InputEmbedder.__init__(self, input_size, activation_function, embedder_complexity, name)
def _build_module(self):
# vector observation
input_layer = tf.contrib.layers.flatten(self.input)
- self.output = tf.layers.dense(input_layer, 256, activation=self.activation_function)
+
+ if self.embedder_complexity == EmbedderComplexity.Shallow:
+ self.output = tf.layers.dense(input_layer, 256, activation=self.activation_function)
+
+ elif self.embedder_complexity == EmbedderComplexity.Deep:
+ # the embedder used in the CARLA papers
+ self.observation_fc1 = tf.layers.dense(input_layer, 128, activation=self.activation_function)
+ self.observation_fc2 = tf.layers.dense(self.observation_fc1, 128, activation=self.activation_function)
+ self.output = tf.layers.dense(self.observation_fc2, 128, activation=self.activation_function)
+ else:
+ raise ValueError("The defined embedder complexity value is invalid")
diff --git a/coach.py b/coach.py
index ffddbc956..45b7382d1 100644
--- a/coach.py
+++ b/coach.py
@@ -37,8 +37,29 @@
cur_time = time_started.time()
cur_date = time_started.date()
-def get_experiment_path(general_experiments_path):
- if not os.path.exists(general_experiments_path):
+
+def get_experiment_name(initial_experiment_name=''):
+ match = None
+ while match is None:
+ if initial_experiment_name == '':
+ experiment_name = screen.ask_input("Please enter an experiment name: ")
+ else:
+ experiment_name = initial_experiment_name
+
+ experiment_name = experiment_name.replace(" ", "_")
+ match = re.match("^$|^[\w -/]{1,100}$", experiment_name)
+
+ if match is None:
+ screen.error('Experiment name must be composed only of alphanumeric letters, '
+ 'underscores and dashes and should not be longer than 100 characters.')
+
+ return match.group(0)
+
+
+def get_experiment_path(experiment_name, create_path=True):
+ general_experiments_path = os.path.join('./experiments/', experiment_name)
+
+ if not os.path.exists(general_experiments_path) and create_path:
os.makedirs(general_experiments_path)
experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}'
.format(logger.two_digits(cur_date.day), logger.two_digits(cur_date.month),
@@ -52,7 +73,8 @@ def get_experiment_path(general_experiments_path):
cur_time.minute, i))
i += 1
else:
- os.makedirs(experiment_path)
+ if create_path:
+ os.makedirs(experiment_path)
return experiment_path
@@ -96,55 +118,54 @@ def check_input_and_fill_run_dict(parser):
num_workers = int(re.match("^\d+$", args.num_workers).group(0))
except ValueError:
screen.error("Parameter num_workers should be an integer.")
- exit(1)
preset_names = list_all_classes_in_module(presets)
if args.preset is not None and args.preset not in preset_names:
screen.error("A non-existing preset was selected. ")
- exit(1)
if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
screen.error("The requested checkpoint folder to load from does not exist. ")
- exit(1)
if args.save_model_sec is not None:
try:
args.save_model_sec = int(args.save_model_sec)
except ValueError:
screen.error("Parameter save_model_sec should be an integer.")
- exit(1)
if args.preset is None and (args.agent_type is None or args.environment_type is None
- or args.exploration_policy_type is None):
+ or args.exploration_policy_type is None) and not args.play:
screen.error('When no preset is given for Coach to run, the user is expected to input the desired agent_type,'
' environment_type and exploration_policy_type to assemble a preset. '
'\nAt least one of these parameters was not given.')
- exit(1)
-
- experiment_name = args.experiment_name
-
- if args.experiment_name == '':
- experiment_name = screen.ask_input("Please enter an experiment name: ")
+ elif args.preset is None and args.play and args.environment_type is None:
+ screen.error('When no preset is given for Coach to run, and the user requests human control over the environment,'
+ ' the user is expected to input the desired environment_type and level.'
+ '\nAt least one of these parameters was not given.')
+ elif args.preset is None and args.play and args.environment_type:
+ args.agent_type = 'Human'
+ args.exploration_policy_type = 'ExplorationParameters'
- experiment_name = experiment_name.replace(" ", "_")
- match = re.match("^$|^\w{1,100}$", experiment_name)
+ # get experiment name and path
+ experiment_name = get_experiment_name(args.experiment_name)
+ experiment_path = get_experiment_path(experiment_name)
- if match is None:
- screen.error('Experiment name must be composed only of alphanumeric letters and underscores and should not be '
- 'longer than 100 characters.')
- exit(1)
- experiment_path = os.path.join('./experiments/', match.group(0))
- experiment_path = get_experiment_path(experiment_path)
+ if args.play and num_workers > 1:
+ screen.warning("Playing the game as a human is only available with a single worker. "
+ "The number of workers will be reduced to 1")
+ num_workers = 1
# fill run_dict
run_dict = dict()
run_dict['agent_type'] = args.agent_type
run_dict['environment_type'] = args.environment_type
run_dict['exploration_policy_type'] = args.exploration_policy_type
+ run_dict['level'] = args.level
run_dict['preset'] = args.preset
run_dict['custom_parameter'] = args.custom_parameter
run_dict['experiment_path'] = experiment_path
run_dict['framework'] = Frameworks().get(args.framework)
+ run_dict['play'] = args.play
+ run_dict['evaluate'] = args.evaluate# or args.play
# multi-threading parameters
run_dict['num_threads'] = num_workers
@@ -197,6 +218,14 @@ def run_dict_to_json(_run_dict, task_id=''):
help="(int) Number of workers for multi-process based agents, e.g. A3C",
default='1',
type=str)
+ parser.add_argument('--play',
+ help="(flag) Play as a human by controlling the game with the keyboard. "
+ "This option will save a replay buffer with the game play.",
+ action='store_true')
+ parser.add_argument('--evaluate',
+ help="(flag) Run evaluation only. This is a convenient way to disable "
+ "training in order to evaluate an existing checkpoint.",
+ action='store_true')
parser.add_argument('-v', '--verbose',
help="(flag) Don't suppress TensorFlow debug prints.",
action='store_true')
@@ -230,6 +259,12 @@ def run_dict_to_json(_run_dict, task_id=''):
,
default=None,
type=str)
+ parser.add_argument('-lvl', '--level',
+ help="(string) Choose the level that will be played in the environment that was selected."
+ "This value will override the level parameter in the environment class."
+ ,
+ default=None,
+ type=str)
parser.add_argument('-cp', '--custom_parameter',
help="(string) Semicolon separated parameters used to override specific parameters on top of"
" the selected preset (or on top of the command-line assembled one). "
@@ -259,7 +294,12 @@ def run_dict_to_json(_run_dict, task_id=''):
tuning_parameters.task_index = 0
env_instance = create_environment(tuning_parameters)
agent = eval(tuning_parameters.agent.type + '(env_instance, tuning_parameters)')
- agent.improve()
+
+ # Start the training or evaluation
+ if tuning_parameters.evaluate:
+ agent.evaluate(sys.maxsize, keep_networks_synced=True) # evaluate forever
+ else:
+ agent.improve()
# Multi-threaded runs
else:
diff --git a/configurations.py b/configurations.py
index b7f99536a..8480c1912 100644
--- a/configurations.py
+++ b/configurations.py
@@ -32,6 +32,11 @@ class InputTypes(object):
TimedObservation = 5
+class EmbedderComplexity(object):
+ Shallow = 1
+ Deep = 2
+
+
class OutputTypes(object):
Q = 1
DuelingQ = 2
@@ -60,6 +65,7 @@ class AgentParameters(object):
middleware_type = MiddlewareTypes.FC
loss_weights = [1.0]
stop_gradients_from_head = [False]
+ embedder_complexity = EmbedderComplexity.Shallow
num_output_head_copies = 1
use_measurements = False
use_accumulated_reward_as_measurement = False
@@ -90,6 +96,8 @@ class AgentParameters(object):
step_until_collecting_full_episodes = False
targets_horizon = 'N-Step'
replace_mse_with_huber_loss = False
+ load_memory_from_file_path = None
+ collect_new_data = True
# PPO related params
target_kl_divergence = 0.01
@@ -132,6 +140,7 @@ class EnvironmentParameters(object):
reward_scaling = 1.0
reward_clipping_min = None
reward_clipping_max = None
+ human_control = False
class ExplorationParameters(object):
@@ -188,6 +197,7 @@ class GeneralParameters(object):
kl_divergence_constraint = 100000
num_training_iterations = 10000000000
num_heatup_steps = 1000
+ heatup_using_network_decisions = False
batch_size = 32
save_model_sec = None
save_model_dir = None
@@ -197,6 +207,7 @@ class GeneralParameters(object):
learning_rate_decay_steps = 0
evaluation_episodes = 5
evaluate_every_x_episodes = 1000000
+ evaluate_every_x_training_iterations = 0
rescaling_interpolation_type = 'bilinear'
# setting a seed will only work for non-parallel algorithms. Parallel algorithms add uncontrollable noise in
@@ -224,6 +235,7 @@ class VisualizationParameters(object):
dump_signals_to_csv_every_x_episodes = 10
render = False
dump_gifs = True
+ max_fps_for_human_control = 10
class Roboschool(EnvironmentParameters):
@@ -252,7 +264,7 @@ class Bullet(EnvironmentParameters):
class Atari(EnvironmentParameters):
type = 'Gym'
- frame_skip = 1
+ frame_skip = 4
observation_stack_size = 4
desired_observation_height = 84
desired_observation_width = 84
@@ -268,6 +280,31 @@ class Doom(EnvironmentParameters):
desired_observation_width = 76
+class Carla(EnvironmentParameters):
+ type = 'Carla'
+ frame_skip = 1
+ observation_stack_size = 4
+ desired_observation_height = 128
+ desired_observation_width = 180
+ normalize_observation = False
+ server_height = 256
+ server_width = 360
+ config = 'environments/CarlaSettings.ini'
+ level = 'town1'
+ verbose = True
+ stereo = False
+ semantic_segmentation = False
+ depth = False
+ episode_max_time = 100000 # miliseconds for each episode
+ continuous_to_bool_threshold = 0.5
+ allow_braking = False
+
+
+class Human(AgentParameters):
+ type = 'HumanAgent'
+ num_episodes_in_experience_replay = 10000000
+
+
class NStepQ(AgentParameters):
type = 'NStepQAgent'
input_types = [InputTypes.Observation]
@@ -299,10 +336,12 @@ class DQN(AgentParameters):
class DDQN(DQN):
type = 'DDQNAgent'
+
class DuelingDQN(DQN):
type = 'DQNAgent'
output_types = [OutputTypes.DuelingQ]
+
class BootstrappedDQN(DQN):
type = 'BootstrappedDQNAgent'
num_output_head_copies = 10
@@ -314,6 +353,7 @@ class CategoricalDQN(DQN):
v_min = -10.0
v_max = 10.0
atoms = 51
+ neon_support = False
class QuantileRegressionDQN(DQN):
@@ -452,6 +492,7 @@ class ClippedPPO(AgentParameters):
step_until_collecting_full_episodes = True
beta_entropy = 0.01
+
class DFP(AgentParameters):
type = 'DFPAgent'
input_types = [InputTypes.Observation, InputTypes.Measurements, InputTypes.GoalVector]
@@ -485,6 +526,15 @@ class PAL(AgentParameters):
neon_support = True
+class BC(AgentParameters):
+ type = 'BCAgent'
+ input_types = [InputTypes.Observation]
+ output_types = [OutputTypes.Q]
+ loss_weights = [1.0]
+ collect_new_data = False
+ evaluate_every_x_training_iterations = 50000
+
+
class EGreedyExploration(ExplorationParameters):
policy = 'EGreedy'
initial_epsilon = 0.5
diff --git a/docs/docs/algorithms/imitation/bc.md b/docs/docs/algorithms/imitation/bc.md
new file mode 100644
index 000000000..84e477a5b
--- /dev/null
+++ b/docs/docs/algorithms/imitation/bc.md
@@ -0,0 +1,25 @@
+# Behavioral Cloning
+
+**Actions space:** Discrete|Continuous
+
+## Network Structure
+
+
+
+
+
+
+
+
+
+## Algorithm Description
+
+### Training the network
+
+The replay buffer contains the expert demonstrations for the task.
+These demonstrations are given as state, action tuples, and with no reward.
+The training goal is to reduce the difference between the actions predicted by the network and the actions taken by the expert for each state.
+
+1. Sample a batch of transitions from the replay buffer.
+2. Use the current states as input to the network, and the expert actions as the targets of the network.
+3. The loss function for the network is MSE, and therefore we use the Q head to minimize this loss.
diff --git a/docs/docs/algorithms/value_optimization/distributional_dqn.md b/docs/docs/algorithms/value_optimization/distributional_dqn.md
new file mode 100644
index 000000000..5dcc4c260
--- /dev/null
+++ b/docs/docs/algorithms/value_optimization/distributional_dqn.md
@@ -0,0 +1,33 @@
+# Distributional DQN
+
+**Actions space:** Discrete
+
+**References:** [A Distributional Perspective on Reinforcement Learning](https://arxiv.org/abs/1707.06887)
+
+## Network Structure
+
+
+
+
+
+
+
+
+
+## Algorithmic Description
+
+### Training the network
+
+1. Sample a batch of transitions from the replay buffer.
+2. The Bellman update is projected to the set of atoms representing the $ Q $ values distribution, such that the $i-th$ component of the projected update is calculated as follows:
+ $$ (\Phi \hat{T} Z_{\theta}(s_t,a_t))_i=\sum_{j=0}^{N-1}\Big[1-\frac{|[\hat{T}_{z_{j}}]^{V_{MAX}}_{V_{MIN}}-z_i|}{\Delta z}\Big]^1_0 \ p_j(s_{t+1}, \pi(s_{t+1})) $$
+ where:
+ * $[ \cdot ] $ bounds its argument in the range [a, b]
+ * $\hat{T}_{z_{j}}$ is the Bellman update for atom $z_j$: $\hat{T}_{z_{j}} := r+\gamma z_j$
+
+
+3. Network is trained with the cross entropy loss between the resulting probability distribution and the target probability distribution. Only the target of the actions that were actually taken is updated.
+4. Once in every few thousand steps, weights are copied from the online network to the target network.
+
+
+
diff --git a/docs/docs/contributing/add_env.md b/docs/docs/contributing/add_env.md
index 51c51c182..f09b4dbb8 100644
--- a/docs/docs/contributing/add_env.md
+++ b/docs/docs/contributing/add_env.md
@@ -1,33 +1,53 @@
Adding a new environment to Coach is as easy as solving CartPole.
-There a few simple steps to follow, and we will walk through them one by one.
+There are a few simple steps to follow, and we will walk through them one by one.
1. Coach defines a simple API for implementing a new environment which is defined in environment/environment_wrapper.py.
There are several functions to implement, but only some of them are mandatory.
- Here are the mandatory ones:
+ Here are the important ones:
- def step(self, action_idx):
+ def _take_action(self, action_idx):
"""
- Perform a single step on the environment using the given action.
- :param action_idx: the action to perform on the environment
- :return: A dictionary containing the observation, reward, done flag, action and measurements
+ An environment dependent function that sends an action to the simulator.
+ :param action_idx: the action to perform on the environment.
+ :return: None
"""
pass
- def render(self):
+ def _preprocess_observation(self, observation):
"""
- Call the environment function for rendering to the screen.
+ Do initial observation preprocessing such as cropping, rgb2gray, rescale etc.
+ Implementing this function is optional.
+ :param observation: a raw observation from the environment
+ :return: the preprocessed observation
+ """
+ return observation
+
+ def _update_state(self):
+ """
+ Updates the state from the environment.
+ Should update self.observation, self.reward, self.done, self.measurements and self.info
+ :return: None
"""
pass
-
+
def _restart_environment_episode(self, force_environment_reset=False):
"""
- :param force_environment_reset: Force the environment to reset even if the episode is not done yet.
- :return:
+ :param force_environment_reset: Force the environment to reset even if the episode is not done yet.
+ :return:
"""
pass
+ def get_rendered_image(self):
+ """
+ Return a numpy array containing the image that will be rendered to the screen.
+ This can be different from the observation. For example, mujoco's observation is a measurements vector.
+ :return: numpy array containing the image that will be rendered to the screen
+ """
+ return self.observation
+
+
2. Make sure to import the environment in environments/\_\_init\_\_.py:
from doom_environment_wrapper import *
diff --git a/docs/docs/img/algorithms.png b/docs/docs/img/algorithms.png
index 2dc14077b..f83c1e69f 100644
Binary files a/docs/docs/img/algorithms.png and b/docs/docs/img/algorithms.png differ
diff --git a/docs/docs/usage.md b/docs/docs/usage.md
new file mode 100644
index 000000000..26d13b9c4
--- /dev/null
+++ b/docs/docs/usage.md
@@ -0,0 +1,133 @@
+# Coach Usage
+
+## Training an Agent
+
+### Single-threaded Algorithms
+
+This is the most common case. Just choose a preset using the `-p` flag and press enter.
+
+*Example:*
+
+`python coach.py -p CartPole_DQN`
+
+### Multi-threaded Algorithms
+
+Multi-threaded algorithms are very common this days.
+They typically achieve the best results, and scale gracefully with the number of threads.
+In Coach, running such algorithms is done by selecting a suitable preset, and choosing the number of threads to run using the `-n` flag.
+
+*Example:*
+
+`python coach.py -p CartPole_A3C -n 8`
+
+## Evaluating an Agent
+
+There are several options for evaluating an agent during the training:
+
+* For multi-threaded runs, an evaluation agent will constantly run in the background and evaluate the model during the training.
+
+* For single-threaded runs, it is possible to define an evaluation period through the preset. This will run several episodes of evaluation once in a while.
+
+Additionally, it is possible to save checkpoints of the agents networks and then run only in evaluation mode.
+Saving checkpoints can be done by specifying the number of seconds between storing checkpoints using the `-s` flag.
+The checkpoints will be saved into the experiment directory.
+Loading a model for evaluation can be done by specifying the `-crd` flag with the experiment directory, and the `--evaluate` flag to disable training.
+
+*Example:*
+
+`python coach.py -p CartPole_DQN -s 60`
+`python coach.py -p CartPole_DQN --evaluate -crd CHECKPOINT_RESTORE_DIR`
+
+## Playing with the Environment as a Human
+
+Interacting with the environment as a human can be useful for understanding its difficulties and for collecting data for imitation learning.
+In Coach, this can be easily done by selecting a preset that defines the environment to use, and specifying the `--play` flag.
+When the environment is loaded, the available keyboard buttons will be printed to the screen.
+Pressing the escape key when finished will end the simulation and store the replay buffer in the experiment dir.
+
+*Example:*
+
+`python coach.py -p Breakout_DQN --play`
+
+## Learning Through Imitation Learning
+
+Learning through imitation of human behavior is a nice way to speedup the learning.
+In Coach, this can be done in two steps -
+
+1. Create a dataset of demonstrations by playing with the environment as a human.
+ After this step, a pickle of the replay buffer containing your game play will be stored in the experiment directory.
+ The path to this replay buffer will be printed to the screen.
+ To do so, you should select an environment type and level through the command line, and specify the `--play` flag.
+
+ *Example:*
+
+ `python coach.py -et Doom -lvl Basic --play`
+
+
+2. Next, use an imitation learning preset and set the replay buffer path accordingly.
+ The path can be set either from the command line or from the preset itself.
+
+ *Example:*
+
+ `python coach.py -p Doom_Basic_BC -cp='agent.load_memory_from_file_path=\"/replay_buffer.p\"'`
+
+
+## Visualizations
+
+### Rendering the Environment
+
+Rendering the environment can be done by using the `-r` flag.
+When working with multi-threaded algorithms, the rendered image will be representing the game play of the evaluation worker.
+When working with single-threaded algorithms, the rendered image will be representing the single worker which can be either training or evaluating.
+Keep in mind that rendering the environment in single-threaded algorithms may slow the training to some extent.
+When playing with the environment using the `--play` flag, the environment will be rendered automatically without the need for specifying the `-r` flag.
+
+*Example:*
+
+`python coach.py -p Breakout_DQN -r`
+
+### Dumping GIFs
+
+Coach allows storing GIFs of the agent game play.
+To dump GIF files, use the `-dg` flag.
+The files are dumped after every evaluation episode, and are saved into the experiment directory, under a gifs sub-directory.
+
+*Example:*
+
+`python coach.py -p Breakout_A3C -n 4 -dg`
+
+## Switching between deep learning frameworks
+
+Coach uses TensorFlow as its main backend framework, but it also supports neon for some of the algorithms.
+By default, TensorFlow will be used. It is possible to switch to neon using the `-f` flag.
+
+*Example:*
+
+`python coach.py -p Doom_Basic_DQN -f neon`
+
+## Additional Flags
+
+There are several convenient flags which are important to know about.
+Here we will list most of the flags, but these can be updated from time to time.
+The most up to date description can be found by using the `-h` flag.
+
+
+|Flag |Type |Description |
+|-------------------------------|----------|--------------|
+|`-p PRESET`, ``--preset PRESET`|string |Name of a preset to run (as configured in presets.py) |
+|`-l`, `--list` |flag |List all available presets|
+|`-e EXPERIMENT_NAME`, `--experiment_name EXPERIMENT_NAME`|string|Experiment name to be used to store the results.|
+|`-r`, `--render` |flag |Render environment|
+|`-f FRAMEWORK`, `--framework FRAMEWORK`|string|Neural network framework. Available values: tensorflow, neon|
+|`-n NUM_WORKERS`, `--num_workers NUM_WORKERS`|int|Number of workers for multi-process based agents, e.g. A3C|
+|`--play` |flag |Play as a human by controlling the game with the keyboard. This option will save a replay buffer with the game play.|
+|`--evaluate` |flag |Run evaluation only. This is a convenient way to disable training in order to evaluate an existing checkpoint.|
+|`-v`, `--verbose` |flag |Don't suppress TensorFlow debug prints.|
+|`-s SAVE_MODEL_SEC`, `--save_model_sec SAVE_MODEL_SEC`|int|Time in seconds between saving checkpoints of the model.|
+|`-crd CHECKPOINT_RESTORE_DIR`, `--checkpoint_restore_dir CHECKPOINT_RESTORE_DIR`|string|Path to a folder containing a checkpoint to restore the model from.|
+|`-dg`, `--dump_gifs` |flag |Enable the gif saving functionality.|
+|`-at AGENT_TYPE`, `--agent_type AGENT_TYPE`|string|Choose an agent type class to override on top of the selected preset. If no preset is defined, a preset can be set from the command-line by combining settings which are set by using `--agent_type`, `--experiment_type`, `--environemnt_type`|
+|`-et ENVIRONMENT_TYPE`, `--environment_type ENVIRONMENT_TYPE`|string|Choose an environment type class to override on top of the selected preset. If no preset is defined, a preset can be set from the command-line by combining settings which are set by using `--agent_type`, `--experiment_type`, `--environemnt_type`|
+|`-ept EXPLORATION_POLICY_TYPE`, `--exploration_policy_type EXPLORATION_POLICY_TYPE`|string|Choose an exploration policy type class to override on top of the selected preset.If no preset is defined, a preset can be set from the command-line by combining settings which are set by using `--agent_type`, `--experiment_type`, `--environemnt_type`|
+|`-lvl LEVEL`, `--level LEVEL` |string|Choose the level that will be played in the environment that was selected. This value will override the level parameter in the environment class.|
+|`-cp CUSTOM_PARAMETER`, `--custom_parameter CUSTOM_PARAMETER`|string| Semicolon separated parameters used to override specific parameters on top of the selected preset (or on top of the command-line assembled one). Whenever a parameter value is a string, it should be inputted as `'\"string\"'`. For ex.: `"visualization.render=False;` `num_training_iterations=500;` `optimizer='rmsprop'"`|
\ No newline at end of file
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 48fca7d3f..de96d9555 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -11,6 +11,7 @@ extra_css: [extra.css]
pages:
- Home : index.md
- Design: design.md
+- Usage: usage.md
- Algorithms:
- 'DQN' : algorithms/value_optimization/dqn.md
- 'Double DQN' : algorithms/value_optimization/double_dqn.md
@@ -28,6 +29,7 @@ pages:
- 'Proximal Policy Optimization' : algorithms/policy_optimization/ppo.md
- 'Clipped Proximal Policy Optimization' : algorithms/policy_optimization/cppo.md
- 'Direct Future Prediction' : algorithms/other/dfp.md
+ - 'Behavioral Cloning' : algorithms/imitation/bc.md
- Coach Dashboard : 'dashboard.md'
- Contributing :
diff --git a/environments/CarlaSettings.ini b/environments/CarlaSettings.ini
new file mode 100644
index 000000000..236500f76
--- /dev/null
+++ b/environments/CarlaSettings.ini
@@ -0,0 +1,62 @@
+[CARLA/Server]
+; If set to false, a mock controller will be used instead of waiting for a real
+; client to connect.
+UseNetworking=true
+; Ports to use for the server-client communication. This can be overridden by
+; the command-line switch `-world-port=N`, write and read ports will be set to
+; N+1 and N+2 respectively.
+WorldPort=2000
+; Time-out in milliseconds for the networking operations.
+ServerTimeOut=10000000000
+; In synchronous mode, CARLA waits every frame until the control from the client
+; is received.
+SynchronousMode=true
+; Send info about every non-player agent in the scene every frame, the
+; information is attached to the measurements message. This includes other
+; vehicles, pedestrians and traffic signs. Disabled by default to improve
+; performance.
+SendNonPlayerAgentsInfo=false
+
+[CARLA/LevelSettings]
+; Path of the vehicle class to be used for the player. Leave empty for default.
+; Paths follow the pattern "/Game/Blueprints/Vehicles/Mustang/Mustang.Mustang_C"
+PlayerVehicle=
+; Number of non-player vehicles to be spawned into the level.
+NumberOfVehicles=15
+; Number of non-player pedestrians to be spawned into the level.
+NumberOfPedestrians=30
+; Index of the weather/lighting presets to use. If negative, the default presets
+; of the map will be used.
+WeatherId=1
+; Seeds for the pseudo-random number generators.
+SeedVehicles=123456789
+SeedPedestrians=123456789
+
+[CARLA/SceneCapture]
+; Names of the cameras to be attached to the player, comma-separated, each of
+; them should be defined in its own subsection. E.g., Uncomment next line to add
+; a camera called MyCamera to the vehicle
+
+Cameras=CameraRGB
+
+; Now, every camera we added needs to be defined it in its own subsection.
+[CARLA/SceneCapture/CameraRGB]
+; Post-processing effect to be applied. Valid values:
+; * None No effects applied.
+; * SceneFinal Post-processing present at scene (bloom, fog, etc).
+; * Depth Depth map ground-truth only.
+; * SemanticSegmentation Semantic segmentation ground-truth only.
+PostProcessing=SceneFinal
+; Size of the captured image in pixels.
+ImageSizeX=360
+ImageSizeY=256
+; Camera (horizontal) field of view in degrees.
+CameraFOV=90
+; Position of the camera relative to the car in centimeters.
+CameraPositionX=200
+CameraPositionY=0
+CameraPositionZ=140
+; Rotation of the camera relative to the car in degrees.
+CameraRotationPitch=0
+CameraRotationRoll=0
+CameraRotationYaw=0
diff --git a/environments/__init__.py b/environments/__init__.py
index 443cdde97..1dd8d1dc8 100644
--- a/environments/__init__.py
+++ b/environments/__init__.py
@@ -15,13 +15,16 @@
#
from logger import *
-from utils import Enum
+from utils import Enum, get_open_port
from environments.gym_environment_wrapper import *
from environments.doom_environment_wrapper import *
+from environments.carla_environment_wrapper import *
+
class EnvTypes(Enum):
Doom = "DoomEnvironmentWrapper"
Gym = "GymEnvironmentWrapper"
+ Carla = "CarlaEnvironmentWrapper"
def create_environment(tuning_parameters):
diff --git a/environments/carla_environment_wrapper.py b/environments/carla_environment_wrapper.py
new file mode 100644
index 000000000..ee9c3ed10
--- /dev/null
+++ b/environments/carla_environment_wrapper.py
@@ -0,0 +1,230 @@
+import sys
+from os import path, environ
+
+try:
+ sys.path.append(path.join(environ.get('CARLA_ROOT'), 'PythonClient'))
+ from carla.client import CarlaClient
+ from carla.settings import CarlaSettings
+ from carla.tcp import TCPConnectionError
+ from carla.sensor import Camera
+ from carla.client import VehicleControl
+except ImportError:
+ from logger import failed_imports
+ failed_imports.append("CARLA")
+
+import numpy as np
+import time
+import logging
+import subprocess
+import signal
+from environments.environment_wrapper import EnvironmentWrapper
+from utils import *
+from logger import screen, logger
+from PIL import Image
+
+
+# enum of the available levels and their path
+class CarlaLevel(Enum):
+ TOWN1 = "/Game/Maps/Town01"
+ TOWN2 = "/Game/Maps/Town02"
+
+key_map = {
+ 'BRAKE': (274,), # down arrow
+ 'GAS': (273,), # up arrow
+ 'TURN_LEFT': (276,), # left arrow
+ 'TURN_RIGHT': (275,), # right arrow
+ 'GAS_AND_TURN_LEFT': (273, 276),
+ 'GAS_AND_TURN_RIGHT': (273, 275),
+ 'BRAKE_AND_TURN_LEFT': (274, 276),
+ 'BRAKE_AND_TURN_RIGHT': (274, 275),
+}
+
+
+class CarlaEnvironmentWrapper(EnvironmentWrapper):
+ def __init__(self, tuning_parameters):
+ EnvironmentWrapper.__init__(self, tuning_parameters)
+
+ self.tp = tuning_parameters
+
+ # server configuration
+ self.server_height = self.tp.env.server_height
+ self.server_width = self.tp.env.server_width
+ self.port = get_open_port()
+ self.host = 'localhost'
+ self.map = CarlaLevel().get(self.tp.env.level)
+
+ # client configuration
+ self.verbose = self.tp.env.verbose
+ self.depth = self.tp.env.depth
+ self.stereo = self.tp.env.stereo
+ self.semantic_segmentation = self.tp.env.semantic_segmentation
+ self.height = self.server_height * (1 + int(self.depth) + int(self.semantic_segmentation))
+ self.width = self.server_width * (1 + int(self.stereo))
+ self.size = (self.width, self.height)
+
+ self.config = self.tp.env.config
+ if self.config:
+ # load settings from file
+ with open(self.config, 'r') as fp:
+ self.settings = fp.read()
+ else:
+ # hard coded settings
+ self.settings = CarlaSettings()
+ self.settings.set(
+ SynchronousMode=True,
+ SendNonPlayerAgentsInfo=False,
+ NumberOfVehicles=15,
+ NumberOfPedestrians=30,
+ WeatherId=1)
+ self.settings.randomize_seeds()
+
+ # add cameras
+ camera = Camera('CameraRGB')
+ camera.set_image_size(self.width, self.height)
+ camera.set_position(200, 0, 140)
+ camera.set_rotation(0, 0, 0)
+ self.settings.add_sensor(camera)
+
+ # open the server
+ self.server = self._open_server()
+
+ logging.disable(40)
+
+ # open the client
+ self.game = CarlaClient(self.host, self.port, timeout=99999999)
+ self.game.connect()
+ scene = self.game.load_settings(self.settings)
+
+ # get available start positions
+ positions = scene.player_start_spots
+ self.num_pos = len(positions)
+ self.iterator_start_positions = 0
+
+ # action space
+ self.discrete_controls = False
+ self.action_space_size = 2
+ self.action_space_high = [1, 1]
+ self.action_space_low = [-1, -1]
+ self.action_space_abs_range = np.maximum(np.abs(self.action_space_low), np.abs(self.action_space_high))
+ self.steering_strength = 0.5
+ self.gas_strength = 1.0
+ self.brake_strength = 0.5
+ self.actions = {0: [0., 0.],
+ 1: [0., -self.steering_strength],
+ 2: [0., self.steering_strength],
+ 3: [self.gas_strength, 0.],
+ 4: [-self.brake_strength, 0],
+ 5: [self.gas_strength, -self.steering_strength],
+ 6: [self.gas_strength, self.steering_strength],
+ 7: [self.brake_strength, -self.steering_strength],
+ 8: [self.brake_strength, self.steering_strength]}
+ self.actions_description = ['NO-OP', 'TURN_LEFT', 'TURN_RIGHT', 'GAS', 'BRAKE',
+ 'GAS_AND_TURN_LEFT', 'GAS_AND_TURN_RIGHT',
+ 'BRAKE_AND_TURN_LEFT', 'BRAKE_AND_TURN_RIGHT']
+ for idx, action in enumerate(self.actions_description):
+ for key in key_map.keys():
+ if action == key:
+ self.key_to_action[key_map[key]] = idx
+ self.num_speedup_steps = 30
+
+ # measurements
+ self.measurements_size = (1,)
+ self.autopilot = None
+
+ # env initialization
+ self.reset(True)
+
+ # render
+ if self.is_rendered:
+ image = self.get_rendered_image()
+ self.renderer.create_screen(image.shape[1], image.shape[0])
+
+ def _open_server(self):
+ log_path = path.join(logger.experiments_path, "CARLA_LOG_{}.txt".format(self.port))
+ with open(log_path, "wb") as out:
+ cmd = [path.join(environ.get('CARLA_ROOT'), 'CarlaUE4.sh'), self.map,
+ "-benchmark", "-carla-server", "-fps=10", "-world-port={}".format(self.port),
+ "-windowed -ResX={} -ResY={}".format(self.server_width, self.server_height),
+ "-carla-no-hud"]
+ if self.config:
+ cmd.append("-carla-settings={}".format(self.config))
+ p = subprocess.Popen(cmd, stdout=out, stderr=out)
+
+ return p
+
+ def _close_server(self):
+ os.killpg(os.getpgid(self.server.pid), signal.SIGKILL)
+
+ def _update_state(self):
+ # get measurements and observations
+ measurements = []
+ while type(measurements) == list:
+ measurements, sensor_data = self.game.read_data()
+ self.observation = sensor_data['CameraRGB'].data
+
+ self.location = (measurements.player_measurements.transform.location.x,
+ measurements.player_measurements.transform.location.y,
+ measurements.player_measurements.transform.location.z)
+
+ is_collision = measurements.player_measurements.collision_vehicles != 0 \
+ or measurements.player_measurements.collision_pedestrians != 0 \
+ or measurements.player_measurements.collision_other != 0
+
+ speed_reward = measurements.player_measurements.forward_speed - 1
+ if speed_reward > 30.:
+ speed_reward = 30.
+ self.reward = speed_reward \
+ - (measurements.player_measurements.intersection_otherlane * 5) \
+ - (measurements.player_measurements.intersection_offroad * 5) \
+ - is_collision * 100 \
+ - np.abs(self.control.steer) * 10
+
+ # update measurements
+ self.measurements = [measurements.player_measurements.forward_speed]
+ self.autopilot = measurements.player_measurements.autopilot_control
+
+ # action_p = ['%.2f' % member for member in [self.control.throttle, self.control.steer]]
+ # screen.success('REWARD: %.2f, ACTIONS: %s' % (self.reward, action_p))
+
+ if (measurements.game_timestamp >= self.tp.env.episode_max_time) or is_collision:
+ # screen.success('EPISODE IS DONE. GameTime: {}, Collision: {}'.format(str(measurements.game_timestamp),
+ # str(is_collision)))
+ self.done = True
+
+ def _take_action(self, action_idx):
+ if type(action_idx) == int:
+ action = self.actions[action_idx]
+ else:
+ action = action_idx
+ self.last_action_idx = action
+
+ self.control = VehicleControl()
+ self.control.throttle = np.clip(action[0], 0, 1)
+ self.control.steer = np.clip(action[1], -1, 1)
+ self.control.brake = np.abs(np.clip(action[0], -1, 0))
+ if not self.tp.env.allow_braking:
+ self.control.brake = 0
+ self.control.hand_brake = False
+ self.control.reverse = False
+
+ self.game.send_control(self.control)
+
+ def _restart_environment_episode(self, force_environment_reset=False):
+ self.iterator_start_positions += 1
+ if self.iterator_start_positions >= self.num_pos:
+ self.iterator_start_positions = 0
+
+ try:
+ self.game.start_episode(self.iterator_start_positions)
+ except:
+ self.game.connect()
+ self.game.start_episode(self.iterator_start_positions)
+
+ # start the game with some initial speed
+ observation = None
+ for i in range(self.num_speedup_steps):
+ observation = self.step([1.0, 0])['observation']
+ self.observation = observation
+
+ return observation
+
diff --git a/environments/doom_environment_wrapper.py b/environments/doom_environment_wrapper.py
index ecdc9e001..997a067c4 100644
--- a/environments/doom_environment_wrapper.py
+++ b/environments/doom_environment_wrapper.py
@@ -25,6 +25,7 @@
from environments.environment_wrapper import EnvironmentWrapper
from os import path, environ
from utils import *
+from logger import *
# enum of the available levels and their path
@@ -39,6 +40,43 @@ class DoomLevel(Enum):
DEFEND_THE_LINE = "defend_the_line.cfg"
DEADLY_CORRIDOR = "deadly_corridor.cfg"
+key_map = {
+ 'NO-OP': 96, # `
+ 'ATTACK': 13, # enter
+ 'CROUCH': 306, # ctrl
+ 'DROP_SELECTED_ITEM': ord("t"),
+ 'DROP_SELECTED_WEAPON': ord("t"),
+ 'JUMP': 32, # spacebar
+ 'LAND': ord("l"),
+ 'LOOK_DOWN': 274, # down arrow
+ 'LOOK_UP': 273, # up arrow
+ 'MOVE_BACKWARD': ord("s"),
+ 'MOVE_DOWN': ord("s"),
+ 'MOVE_FORWARD': ord("w"),
+ 'MOVE_LEFT': 276,
+ 'MOVE_RIGHT': 275,
+ 'MOVE_UP': ord("w"),
+ 'RELOAD': ord("r"),
+ 'SELECT_NEXT_WEAPON': ord("q"),
+ 'SELECT_PREV_WEAPON': ord("e"),
+ 'SELECT_WEAPON0': ord("0"),
+ 'SELECT_WEAPON1': ord("1"),
+ 'SELECT_WEAPON2': ord("2"),
+ 'SELECT_WEAPON3': ord("3"),
+ 'SELECT_WEAPON4': ord("4"),
+ 'SELECT_WEAPON5': ord("5"),
+ 'SELECT_WEAPON6': ord("6"),
+ 'SELECT_WEAPON7': ord("7"),
+ 'SELECT_WEAPON8': ord("8"),
+ 'SELECT_WEAPON9': ord("9"),
+ 'SPEED': 304, # shift
+ 'STRAFE': 9, # tab
+ 'TURN180': ord("u"),
+ 'TURN_LEFT': ord("a"), # left arrow
+ 'TURN_RIGHT': ord("d"), # right arrow
+ 'USE': ord("f"),
+}
+
class DoomEnvironmentWrapper(EnvironmentWrapper):
def __init__(self, tuning_parameters):
@@ -49,26 +87,42 @@ def __init__(self, tuning_parameters):
self.scenarios_dir = path.join(environ.get('VIZDOOM_ROOT'), 'scenarios')
self.game = vizdoom.DoomGame()
self.game.load_config(path.join(self.scenarios_dir, self.level))
- self.game.set_window_visible(self.is_rendered)
+ self.game.set_window_visible(False)
self.game.add_game_args("+vid_forcesurface 1")
- if self.is_rendered:
+
+ self.wait_for_explicit_human_action = True
+ if self.human_control:
+ self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_640X480)
+ self.renderer.create_screen(640, 480)
+ elif self.is_rendered:
self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_320X240)
+ self.renderer.create_screen(320, 240)
else:
# lower resolution since we actually take only 76x60 and we don't need to render
self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_160X120)
+
self.game.set_render_hud(False)
self.game.set_render_crosshair(False)
self.game.set_render_decals(False)
self.game.set_render_particles(False)
self.game.init()
+ # action space
self.action_space_abs_range = 0
self.actions = {}
- self.action_space_size = self.game.get_available_buttons_size()
- for action_idx in range(self.action_space_size):
- self.actions[action_idx] = [0] * self.action_space_size
- self.actions[action_idx][action_idx] = 1
- self.actions_description = [str(action) for action in self.game.get_available_buttons()]
+ self.action_space_size = self.game.get_available_buttons_size() + 1
+ self.action_vector_size = self.action_space_size - 1
+ self.actions[0] = [0] * self.action_vector_size
+ for action_idx in range(self.action_vector_size):
+ self.actions[action_idx + 1] = [0] * self.action_vector_size
+ self.actions[action_idx + 1][action_idx] = 1
+ self.actions_description = ['NO-OP']
+ self.actions_description += [str(action).split(".")[1] for action in self.game.get_available_buttons()]
+ for idx, action in enumerate(self.actions_description):
+ if action in key_map.keys():
+ self.key_to_action[(key_map[action],)] = idx
+
+ # measurement
self.measurements_size = self.game.get_state().game_variables.shape
self.width = self.game.get_screen_width()
@@ -77,27 +131,17 @@ def __init__(self, tuning_parameters):
self.game.set_seed(self.tp.seed)
self.reset()
- def _update_observation_and_measurements(self):
+ def _update_state(self):
# extract all data from the current state
state = self.game.get_state()
if state is not None and state.screen_buffer is not None:
- self.observation = self._preprocess_observation(state.screen_buffer)
+ self.observation = state.screen_buffer
self.measurements = state.game_variables
+ self.reward = self.game.get_last_reward()
self.done = self.game.is_episode_finished()
- def step(self, action_idx):
- self.reward = 0
- for frame in range(self.tp.env.frame_skip):
- self.reward += self.game.make_action(self._idx_to_action(action_idx))
- self._update_observation_and_measurements()
- if self.done:
- break
-
- return {'observation': self.observation,
- 'reward': self.reward,
- 'done': self.done,
- 'action': action_idx,
- 'measurements': self.measurements}
+ def _take_action(self, action_idx):
+ self.game.make_action(self._idx_to_action(action_idx), self.frame_skip)
def _preprocess_observation(self, observation):
if observation is None:
@@ -108,3 +152,5 @@ def _preprocess_observation(self, observation):
def _restart_environment_episode(self, force_environment_reset=False):
self.game.new_episode()
+
+
diff --git a/environments/environment_wrapper.py b/environments/environment_wrapper.py
index ac73613f8..1491c5c59 100644
--- a/environments/environment_wrapper.py
+++ b/environments/environment_wrapper.py
@@ -17,6 +17,9 @@
import numpy as np
from utils import *
from configurations import Preset
+from renderer import Renderer
+import operator
+import time
class EnvironmentWrapper(object):
@@ -31,13 +34,19 @@ def __init__(self, tuning_parameters):
self.observation = []
self.reward = 0
self.done = False
+ self.default_action = 0
self.last_action_idx = 0
+ self.episode_idx = 0
+ self.last_episode_time = time.time()
self.measurements = []
+ self.info = []
self.action_space_low = 0
self.action_space_high = 0
self.action_space_abs_range = 0
+ self.actions_description = {}
self.discrete_controls = True
self.action_space_size = 0
+ self.key_to_action = {}
self.width = 1
self.height = 1
self.is_state_type_image = True
@@ -50,17 +59,11 @@ def __init__(self, tuning_parameters):
self.is_rendered = self.tp.visualization.render
self.seed = self.tp.seed
self.frame_skip = self.tp.env.frame_skip
-
- def _update_observation_and_measurements(self):
- # extract all the available measurments (ovservation, depthmap, lives, ammo etc.)
- pass
-
- def _restart_environment_episode(self, force_environment_reset=False):
- """
- :param force_environment_reset: Force the environment to reset even if the episode is not done yet.
- :return:
- """
- pass
+ self.human_control = self.tp.env.human_control
+ self.wait_for_explicit_human_action = False
+ self.is_rendered = self.is_rendered or self.human_control
+ self.game_is_open = True
+ self.renderer = Renderer()
def _idx_to_action(self, action_idx):
"""
@@ -71,13 +74,43 @@ def _idx_to_action(self, action_idx):
"""
return self.actions[action_idx]
- def _preprocess_observation(self, observation):
+ def _action_to_idx(self, action):
"""
- Do initial observation preprocessing such as cropping, rgb2gray, rescale etc.
- :param observation: a raw observation from the environment
- :return: the preprocessed observation
+ Convert an environment action to one of the available actions of the wrapper.
+ For example, if the available actions are 4,5,6 then this function will map 4->0, 5->1, 6->2
+ :param action: the environment action
+ :return: an action index between 0 and self.action_space_size - 1, or -1 if the action does not exist
"""
- pass
+ for key, val in self.actions.items():
+ if val == action:
+ return key
+ return -1
+
+ def get_action_from_user(self):
+ """
+ Get an action from the user keyboard
+ :return: action index
+ """
+ if self.wait_for_explicit_human_action:
+ while len(self.renderer.pressed_keys) == 0:
+ self.renderer.get_events()
+
+ if self.key_to_action == {}:
+ # the keys are the numbers on the keyboard corresponding to the action index
+ if len(self.renderer.pressed_keys) > 0:
+ action_idx = self.renderer.pressed_keys[0] - ord("1")
+ if 0 <= action_idx < self.action_space_size:
+ return action_idx
+ else:
+ # the keys are mapped through the environment to more intuitive keyboard keys
+ # key = tuple(self.renderer.pressed_keys)
+ # for key in self.renderer.pressed_keys:
+ for env_keys in self.key_to_action.keys():
+ if set(env_keys) == set(self.renderer.pressed_keys):
+ return self.key_to_action[env_keys]
+
+ # return the default action 0 so that the environment will continue running
+ return self.default_action
def step(self, action_idx):
"""
@@ -85,13 +118,29 @@ def step(self, action_idx):
:param action_idx: the action to perform on the environment
:return: A dictionary containing the observation, reward, done flag, action and measurements
"""
- pass
+ self.last_action_idx = action_idx
+
+ self._take_action(action_idx)
+
+ self._update_state()
+
+ if self.is_rendered:
+ self.render()
+
+ self.observation = self._preprocess_observation(self.observation)
+
+ return {'observation': self.observation,
+ 'reward': self.reward,
+ 'done': self.done,
+ 'action': self.last_action_idx,
+ 'measurements': self.measurements,
+ 'info': self.info}
def render(self):
"""
Call the environment function for rendering to the screen
"""
- pass
+ self.renderer.render_image(self.get_rendered_image())
def reset(self, force_environment_reset=False):
"""
@@ -100,15 +149,25 @@ def reset(self, force_environment_reset=False):
:return: A dictionary containing the observation, reward, done flag, action and measurements
"""
self._restart_environment_episode(force_environment_reset)
+ self.last_episode_time = time.time()
self.done = False
+ self.episode_idx += 1
self.reward = 0.0
self.last_action_idx = 0
- self._update_observation_and_measurements()
+ self._update_state()
+
+ # render before the preprocessing of the observation, so that the image will be in its original quality
+ if self.is_rendered:
+ self.render()
+
+ self.observation = self._preprocess_observation(self.observation)
+
return {'observation': self.observation,
'reward': self.reward,
'done': self.done,
'action': self.last_action_idx,
- 'measurements': self.measurements}
+ 'measurements': self.measurements,
+ 'info': self.info}
def get_random_action(self):
"""
@@ -129,10 +188,62 @@ def change_phase(self, phase):
"""
self.phase = phase
+ def get_available_keys(self):
+ """
+ Return a list of tuples mapping between action names and the keyboard key that triggers them
+ :return: a list of tuples mapping between action names and the keyboard key that triggers them
+ """
+ available_keys = []
+ if self.key_to_action != {}:
+ for key, idx in sorted(self.key_to_action.items(), key=operator.itemgetter(1)):
+ if key != ():
+ key_names = [self.renderer.get_key_names([k])[0] for k in key]
+ available_keys.append((self.actions_description[idx], ' + '.join(key_names)))
+ elif self.discrete_controls:
+ for action in range(self.action_space_size):
+ available_keys.append(("Action {}".format(action + 1), action + 1))
+ return available_keys
+
+ # The following functions define the interaction with the environment.
+ # Any new environment that inherits the EnvironmentWrapper class should use these signatures.
+ # Some of these functions are optional - please read their description for more details.
+
+ def _take_action(self, action_idx):
+ """
+ An environment dependent function that sends an action to the simulator.
+ :param action_idx: the action to perform on the environment
+ :return: None
+ """
+ pass
+
+ def _preprocess_observation(self, observation):
+ """
+ Do initial observation preprocessing such as cropping, rgb2gray, rescale etc.
+ Implementing this function is optional.
+ :param observation: a raw observation from the environment
+ :return: the preprocessed observation
+ """
+ return observation
+
+ def _update_state(self):
+ """
+ Updates the state from the environment.
+ Should update self.observation, self.reward, self.done, self.measurements and self.info
+ :return: None
+ """
+ pass
+
+ def _restart_environment_episode(self, force_environment_reset=False):
+ """
+ :param force_environment_reset: Force the environment to reset even if the episode is not done yet.
+ :return:
+ """
+ pass
+
def get_rendered_image(self):
"""
Return a numpy array containing the image that will be rendered to the screen.
This can be different from the observation. For example, mujoco's observation is a measurements vector.
:return: numpy array containing the image that will be rendered to the screen
"""
- return self.observation
+ return self.observation
\ No newline at end of file
diff --git a/environments/gym_environment_wrapper.py b/environments/gym_environment_wrapper.py
index 0721fc94e..e6320cf98 100644
--- a/environments/gym_environment_wrapper.py
+++ b/environments/gym_environment_wrapper.py
@@ -15,8 +15,10 @@
#
import sys
+from logger import *
import gym
import numpy as np
+import time
try:
import roboschool
from OpenGL import GL
@@ -40,8 +42,6 @@
from utils import force_list, RunPhase
from environments.environment_wrapper import EnvironmentWrapper
-i = 0
-
class GymEnvironmentWrapper(EnvironmentWrapper):
def __init__(self, tuning_parameters):
@@ -53,29 +53,30 @@ def __init__(self, tuning_parameters):
self.env.seed(self.seed)
# self.env_spec = gym.spec(self.env_id)
+ self.env.frameskip = self.frame_skip
self.discrete_controls = type(self.env.action_space) != gym.spaces.box.Box
- # pybullet requires rendering before resetting the environment, but other gym environments (Pendulum) will crash
- try:
- if self.is_rendered:
- self.render()
- except:
- pass
-
- o = self.reset(True)['observation']
+ self.observation = self.reset(True)['observation']
# render
if self.is_rendered:
- self.render()
+ image = self.get_rendered_image()
+ scale = 1
+ if self.human_control:
+ scale = 2
+ self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
- self.is_state_type_image = len(o.shape) > 1
+ self.is_state_type_image = len(self.observation.shape) > 1
if self.is_state_type_image:
- self.width = o.shape[1]
- self.height = o.shape[0]
+ self.width = self.observation.shape[1]
+ self.height = self.observation.shape[0]
else:
- self.width = o.shape[0]
+ self.width = self.observation.shape[0]
+ # action space
self.actions_description = {}
+ if hasattr(self.env.unwrapped, 'get_action_meanings'):
+ self.actions_description = self.env.unwrapped.get_action_meanings()
if self.discrete_controls:
self.action_space_size = self.env.action_space.n
self.action_space_abs_range = 0
@@ -85,34 +86,31 @@ def __init__(self, tuning_parameters):
self.action_space_low = self.env.action_space.low
self.action_space_abs_range = np.maximum(np.abs(self.action_space_low), np.abs(self.action_space_high))
self.actions = {i: i for i in range(self.action_space_size)}
+ self.key_to_action = {}
+ if hasattr(self.env.unwrapped, 'get_keys_to_action'):
+ self.key_to_action = self.env.unwrapped.get_keys_to_action()
+
+ # measurements
self.timestep_limit = self.env.spec.timestep_limit
- self.current_ale_lives = 0
self.measurements_size = len(self.step(0)['info'].keys())
- # env intialization
- self.observation = o
- self.reward = 0
- self.done = False
- self.last_action = self.actions[0]
-
- def render(self):
- self.env.render()
-
- def step(self, action_idx):
+ def _update_state(self):
+ if hasattr(self.env.env, 'ale'):
+ if self.phase == RunPhase.TRAIN and hasattr(self, 'current_ale_lives'):
+ # signal termination for life loss
+ if self.current_ale_lives != self.env.env.ale.lives():
+ self.done = True
+ self.current_ale_lives = self.env.env.ale.lives()
+ def _take_action(self, action_idx):
if action_idx is None:
action_idx = self.last_action_idx
- self.last_action_idx = action_idx
-
if self.discrete_controls:
action = self.actions[action_idx]
else:
action = action_idx
- if hasattr(self.env.env, 'ale'):
- prev_ale_lives = self.env.env.ale.lives()
-
# pendulum-v0 for example expects a list
if not self.discrete_controls:
# catching cases where the action for continuous control is a number instead of a list the
@@ -128,42 +126,26 @@ def step(self, action_idx):
self.observation, self.reward, self.done, self.info = self.env.step(action)
- if hasattr(self.env.env, 'ale') and self.phase == RunPhase.TRAIN:
- # signal termination for breakout life loss
- if prev_ale_lives != self.env.env.ale.lives():
- self.done = True
-
+ def _preprocess_observation(self, observation):
if any(env in self.env_id for env in ["Breakout", "Pong"]):
# crop image
- self.observation = self.observation[34:195, :, :]
-
- if self.is_rendered:
- self.render()
-
- return {'observation': self.observation,
- 'reward': self.reward,
- 'done': self.done,
- 'action': self.last_action_idx,
- 'info': self.info}
+ observation = observation[34:195, :, :]
+ return observation
def _restart_environment_episode(self, force_environment_reset=False):
# prevent reset of environment if there are ale lives left
- if "Breakout" in self.env_id and self.env.env.ale.lives() > 0 and not force_environment_reset:
+ if (hasattr(self.env.env, 'ale') and self.env.env.ale.lives() > 0) \
+ and not force_environment_reset and not self.env._past_limit():
return self.observation
if self.seed:
self.env.seed(self.seed)
- observation = self.env.reset()
- while observation is None:
- observation = self.step(0)['observation']
- if "Breakout" in self.env_id:
- # crop image
- observation = observation[34:195, :, :]
+ self.observation = self.env.reset()
+ while self.observation is None:
+ self.step(0)
- self.observation = observation
-
- return observation
+ return self.observation
def get_rendered_image(self):
return self.env.render(mode='rgb_array')
diff --git a/img/algorithms.png b/img/algorithms.png
index 2dc14077b..f83c1e69f 100644
Binary files a/img/algorithms.png and b/img/algorithms.png differ
diff --git a/img/ant.gif b/img/ant.gif
deleted file mode 100644
index 919d786eb..000000000
Binary files a/img/ant.gif and /dev/null differ
diff --git a/img/carla.gif b/img/carla.gif
new file mode 100644
index 000000000..af61a9016
Binary files /dev/null and b/img/carla.gif differ
diff --git a/img/doom.gif b/img/doom.gif
deleted file mode 100644
index 53853468e..000000000
Binary files a/img/doom.gif and /dev/null differ
diff --git a/img/doom_deathmatch.gif b/img/doom_deathmatch.gif
new file mode 100644
index 000000000..e949b088f
Binary files /dev/null and b/img/doom_deathmatch.gif differ
diff --git a/img/minitaur.gif b/img/minitaur.gif
deleted file mode 100644
index bff226494..000000000
Binary files a/img/minitaur.gif and /dev/null differ
diff --git a/img/montezuma.gif b/img/montezuma.gif
new file mode 100644
index 000000000..999798a28
Binary files /dev/null and b/img/montezuma.gif differ
diff --git a/install.sh b/install.sh
index 2c3606b9f..4f2a77523 100755
--- a/install.sh
+++ b/install.sh
@@ -192,10 +192,14 @@ if [ ${INSTALL_NEON} -eq 1 ]; then
# Neon
sudo -E apt-get install libhdf5-dev libyaml-dev pkg-config clang virtualenv libcurl4-openssl-dev libopencv-dev libsox-dev -y
- git clone https://github.com/NervanaSystems/neon.git
- cd neon && make sysinstall -j
- cd ..
+ pip3 install nervananeon
+fi
+
+if ! [ -x "$(command -v nvidia-smi)" ]; then
+ # Intel Optimized TensorFlow
+ pip3 install https://anaconda.org/intel/tensorflow/1.3.0/download/tensorflow-1.3.0-cp35-cp35m-linux_x86_64.whl
+else
+ # GPU supported TensorFlow
+ pip3 install tensorflow-gpu
fi
-# Intel Optimized TensorFlow
-pip3 install https://anaconda.org/intel/tensorflow/1.3.0/download/tensorflow-1.3.0-cp35-cp35m-linux_x86_64.whl
diff --git a/logger.py b/logger.py
index 11458a845..caa0265a1 100644
--- a/logger.py
+++ b/logger.py
@@ -18,6 +18,7 @@
import os
from pprint import pprint
import threading
+from subprocess import Popen, PIPE
import time
import datetime
from six.moves import input
@@ -61,7 +62,7 @@ def separator(self):
print("")
def log(self, data):
- print(self.name + ": " + data)
+ print(data)
def log_dict(self, dict, prefix=""):
str = "{}{}{} - ".format(Colors.PURPLE, prefix, Colors.END)
@@ -78,8 +79,10 @@ def success(self, text):
def warning(self, text):
print("{}{}{}".format(Colors.YELLOW, text, Colors.END))
- def error(self, text):
+ def error(self, text, crash=True):
print("{}{}{}".format(Colors.RED, text, Colors.END))
+ if crash:
+ exit(1)
def ask_input(self, title):
return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END))
diff --git a/memories/episodic_experience_replay.py b/memories/episodic_experience_replay.py
index b5d92f540..1d7e61f9d 100644
--- a/memories/episodic_experience_replay.py
+++ b/memories/episodic_experience_replay.py
@@ -74,7 +74,9 @@ def sample_last_n_episodes(self, n):
def sample(self, size):
assert self.num_transitions_in_complete_episodes() > size, \
- 'There are not enough transitions in the replay buffer'
+ 'There are not enough transitions in the replay buffer. ' \
+ 'Available transitions: {}. Requested transitions: {}.'\
+ .format(self.num_transitions_in_complete_episodes(), size)
batch = []
transitions_idx = np.random.randint(self.num_transitions_in_complete_episodes(), size=size)
for i in transitions_idx:
diff --git a/memories/memory.py b/memories/memory.py
index 5d32c062e..4c479e321 100644
--- a/memories/memory.py
+++ b/memories/memory.py
@@ -73,6 +73,7 @@ def update_returns(self, discount, is_bootstrapped=False, n_step_return=-1):
if n_step_return == -1 or n_step_return > self.length():
n_step_return = self.length()
rewards = np.array([t.reward for t in self.transitions])
+ rewards = rewards.astype('float')
total_return = rewards.copy()
current_discount = discount
for i in range(1, n_step_return):
@@ -123,12 +124,30 @@ def to_batch(self):
class Transition(object):
- def __init__(self, state, action, reward, next_state, game_over):
+ def __init__(self, state, action, reward=0, next_state=None, game_over=False):
+ """
+ A transition is a tuple containing the information of a single step of interaction
+ between the agent and the environment. The most basic version should contain the following values:
+ (current state, action, reward, next state, game over)
+ For imitation learning algorithms, if the reward, next state or game over is not known,
+ it is sufficient to store the current state and action taken by the expert.
+
+ :param state: The current state. Assumed to be a dictionary where the observation
+ is located at state['observation']
+ :param action: The current action that was taken
+ :param reward: The reward received from the environment
+ :param next_state: The next state of the environment after applying the action.
+ The next state should be similar to the state in its structure.
+ :param game_over: A boolean which should be True if the episode terminated after
+ the execution of the action.
+ """
self.state = copy.deepcopy(state)
self.state['observation'] = np.array(self.state['observation'], copy=False)
self.action = action
self.reward = reward
self.total_return = None
+ if not next_state:
+ next_state = state
self.next_state = copy.deepcopy(next_state)
self.next_state['observation'] = np.array(self.next_state['observation'], copy=False)
self.game_over = game_over
diff --git a/presets.py b/presets.py
index 3290f9750..dc3b5ecc5 100644
--- a/presets.py
+++ b/presets.py
@@ -38,6 +38,15 @@ def json_to_preset(json_path):
if run_dict['exploration_policy_type'] is not None:
tuning_parameters.exploration = eval(run_dict['exploration_policy_type'])()
+ # human control
+ if run_dict['play']:
+ tuning_parameters.agent.type = 'HumanAgent'
+ tuning_parameters.env.human_control = True
+ tuning_parameters.num_heatup_steps = 0
+
+ if run_dict['level']:
+ tuning_parameters.env.level = run_dict['level']
+
if run_dict['custom_parameter'] is not None:
unstripped_key_value_pairs = [pair.split('=') for pair in run_dict['custom_parameter'].split(';')]
stripped_key_value_pairs = [tuple([pair[0].strip(), ast.literal_eval(pair[1].strip())]) for pair in
@@ -331,7 +340,7 @@ def __init__(self):
self.agent.num_steps_between_gradient_updates = 5
self.test = True
- self.test_max_step_threshold = 1000
+ self.test_max_step_threshold = 2000
self.test_min_return_threshold = 150
self.test_num_workers = 8
@@ -926,7 +935,7 @@ def __init__(self):
self.agent.middleware_type = MiddlewareTypes.FC
self.test = True
- self.test_max_step_threshold = 200
+ self.test_max_step_threshold = 1000
self.test_min_return_threshold = 150
self.test_num_workers = 8
@@ -1182,3 +1191,93 @@ def __init__(self):
self.agent.beta_entropy = 0.05
self.clip_gradients = 40.0
self.agent.middleware_type = MiddlewareTypes.FC
+
+
+class Carla_A3C(Preset):
+ def __init__(self):
+ Preset.__init__(self, ActorCritic, Carla, EntropyExploration)
+ self.agent.embedder_complexity = EmbedderComplexity.Deep
+ self.agent.policy_gradient_rescaler = 'GAE'
+ self.learning_rate = 0.0001
+ self.num_heatup_steps = 0
+ # self.env.reward_scaling = 1.0e9
+ self.agent.discount = 0.99
+ self.agent.apply_gradients_every_x_episodes = 1
+ self.agent.num_steps_between_gradient_updates = 30
+ self.agent.gae_lambda = 1
+ self.agent.beta_entropy = 0.01
+ self.clip_gradients = 40
+ self.agent.middleware_type = MiddlewareTypes.FC
+
+
+class Carla_DDPG(Preset):
+ def __init__(self):
+ Preset.__init__(self, DDPG, Carla, OUExploration)
+ self.agent.embedder_complexity = EmbedderComplexity.Deep
+ self.learning_rate = 0.0001
+ self.num_heatup_steps = 1000
+ self.agent.num_consecutive_training_steps = 5
+
+
+class Carla_BC(Preset):
+ def __init__(self):
+ Preset.__init__(self, BC, Carla, ExplorationParameters)
+ self.agent.embedder_complexity = EmbedderComplexity.Deep
+ self.agent.load_memory_from_file_path = 'datasets/carla_town1.p'
+ self.learning_rate = 0.0005
+ self.num_heatup_steps = 0
+ self.evaluation_episodes = 5
+ self.batch_size = 120
+ self.evaluate_every_x_training_iterations = 5000
+
+
+class Doom_Basic_BC(Preset):
+ def __init__(self):
+ Preset.__init__(self, BC, Doom, ExplorationParameters)
+ self.env.level = 'basic'
+ self.agent.load_memory_from_file_path = 'datasets/doom_basic.p'
+ self.learning_rate = 0.0005
+ self.num_heatup_steps = 0
+ self.evaluation_episodes = 5
+ self.batch_size = 120
+ self.evaluate_every_x_training_iterations = 100
+ self.num_training_iterations = 2000
+
+
+class Doom_Defend_BC(Preset):
+ def __init__(self):
+ Preset.__init__(self, BC, Doom, ExplorationParameters)
+ self.env.level = 'defend'
+ self.agent.load_memory_from_file_path = 'datasets/doom_defend.p'
+ self.learning_rate = 0.0005
+ self.num_heatup_steps = 0
+ self.evaluation_episodes = 5
+ self.batch_size = 120
+ self.evaluate_every_x_training_iterations = 100
+
+
+class Doom_Deathmatch_BC(Preset):
+ def __init__(self):
+ Preset.__init__(self, BC, Doom, ExplorationParameters)
+ self.env.level = 'deathmatch'
+ self.agent.load_memory_from_file_path = 'datasets/doom_deathmatch.p'
+ self.learning_rate = 0.0005
+ self.num_heatup_steps = 0
+ self.evaluation_episodes = 5
+ self.batch_size = 120
+ self.evaluate_every_x_training_iterations = 100
+
+
+class MontezumaRevenge_BC(Preset):
+ def __init__(self):
+ Preset.__init__(self, BC, Atari, ExplorationParameters)
+ self.env.level = 'MontezumaRevenge-v0'
+ self.agent.load_memory_from_file_path = 'datasets/montezuma_revenge.p'
+ self.learning_rate = 0.0005
+ self.num_heatup_steps = 0
+ self.evaluation_episodes = 5
+ self.batch_size = 120
+ self.evaluate_every_x_training_iterations = 100
+ self.exploration.evaluation_epsilon = 0.05
+ self.exploration.evaluation_policy = 'EGreedy'
+ self.env.frame_skip = 1
diff --git a/renderer.py b/renderer.py
new file mode 100644
index 000000000..cddc81017
--- /dev/null
+++ b/renderer.py
@@ -0,0 +1,85 @@
+import pygame
+from pygame.locals import *
+import numpy as np
+
+
+class Renderer(object):
+ def __init__(self):
+ self.size = (1, 1)
+ self.screen = None
+ self.clock = pygame.time.Clock()
+ self.display = pygame.display
+ self.fps = 30
+ self.pressed_keys = []
+ self.is_open = False
+
+ def create_screen(self, width, height):
+ """
+ Creates a pygame window
+ :param width: the width of the window
+ :param height: the height of the window
+ :return: None
+ """
+ self.size = (width, height)
+ self.screen = self.display.set_mode(self.size, HWSURFACE | DOUBLEBUF)
+ self.display.set_caption("Coach")
+ self.is_open = True
+
+ def normalize_image(self, image):
+ """
+ Normalize image values to be between 0 and 255
+ :param image: 2D/3D array containing an image with arbitrary values
+ :return: the input image with values rescaled to 0-255
+ """
+ image_min, image_max = image.min(), image.max()
+ return 255.0 * (image - image_min) / (image_max - image_min)
+
+ def render_image(self, image):
+ """
+ Render the given image to the pygame window
+ :param image: a grayscale or color image in an arbitrary size. assumes that the channels are the last axis
+ :return: None
+ """
+ if self.is_open:
+ if len(image.shape) == 3:
+ if image.shape[0] == 3 or image.shape[0] == 1:
+ image = np.transpose(image, (1, 2, 0))
+ surface = pygame.surfarray.make_surface(image.swapaxes(0, 1))
+ surface = pygame.transform.scale(surface, self.size)
+ self.screen.blit(surface, (0, 0))
+ self.display.flip()
+ self.clock.tick()
+ self.get_events()
+
+ def get_events(self):
+ """
+ Get all the window events in the last tick and reponse accordingly
+ :return: None
+ """
+ for event in pygame.event.get():
+ if event.type == pygame.KEYDOWN:
+ self.pressed_keys.append(event.key)
+ # esc pressed
+ if event.key == pygame.K_ESCAPE:
+ self.close()
+ elif event.type == pygame.KEYUP:
+ if event.key in self.pressed_keys:
+ self.pressed_keys.remove(event.key)
+ elif event.type == pygame.QUIT:
+ self.close()
+
+ def get_key_names(self, key_ids):
+ """
+ Get the key name for each key index in the list
+ :param key_ids: a list of key id's
+ :return: a list of key names corresponding to the key id's
+ """
+ return [pygame.key.name(key_id) for key_id in key_ids]
+
+ def close(self):
+ """
+ Close the pygame window
+ :return: None
+ """
+ self.is_open = False
+ pygame.quit()
diff --git a/requirements_coach.txt b/requirements_coach.txt
index 799820bfb..a0fd644f8 100644
--- a/requirements_coach.txt
+++ b/requirements_coach.txt
@@ -3,6 +3,7 @@ Pillow==4.3.0
matplotlib==2.0.2
numpy==1.13.0
pandas==0.20.2
+pygame==1.9.3
PyOpenGL==3.1.0
scipy==0.19.0
scikit-image==0.13.0
diff --git a/run_test.py b/run_test.py
new file mode 100644
index 000000000..eeb51a434
--- /dev/null
+++ b/run_test.py
@@ -0,0 +1,164 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# -*- coding: utf-8 -*-
+import presets
+import numpy as np
+import pandas as pd
+from os import path
+import os
+import glob
+import shutil
+import sys
+import time
+from logger import screen
+from utils import list_all_classes_in_module, threaded_cmd_line_run, killed_processes
+from subprocess import Popen
+import signal
+import argparse
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-p', '--preset',
+ help="(string) Name of a preset to run (as configured in presets.py)",
+ default=None,
+ type=str)
+ parser.add_argument('-itf', '--ignore_tensorflow',
+ help="(flag) Don't test TensorFlow presets.",
+ action='store_true')
+ parser.add_argument('-in', '--ignore_neon',
+ help="(flag) Don't test neon presets.",
+ action='store_true')
+
+ args = parser.parse_args()
+ if args.preset is not None:
+ presets_lists = [args.preset]
+ else:
+ presets_lists = list_all_classes_in_module(presets)
+ win_size = 10
+ fail_count = 0
+ test_count = 0
+ read_csv_tries = 70
+
+ # create a clean experiment directory
+ test_name = '__test'
+ test_path = os.path.join('./experiments', test_name)
+ if path.exists(test_path):
+ shutil.rmtree(test_path)
+
+ for idx, preset_name in enumerate(presets_lists):
+ preset = eval('presets.{}()'.format(preset_name))
+ if preset.test:
+ frameworks = []
+ if preset.agent.tensorflow_support and not args.ignore_tensorflow:
+ frameworks.append('tensorflow')
+ if preset.agent.neon_support and not args.ignore_neon:
+ frameworks.append('neon')
+
+ for framework in frameworks:
+ test_count += 1
+
+ # run the experiment in a separate thread
+ screen.log_title("Running test {} - {}".format(preset_name, framework))
+ cmd = 'CUDA_VISIBLE_DEVICES='' python3 coach.py -p {} -f {} -e {} -n {} -cp "seed=0" &> test_log_{}_{}.txt '\
+ .format(preset_name, framework, test_name, preset.test_num_workers, preset_name, framework)
+ p = Popen(cmd, shell=True, executable="/bin/bash", preexec_fn=os.setsid)
+
+ # get the csv with the results
+ csv_path = None
+ csv_paths = []
+
+ if preset.test_num_workers > 1:
+ # we have an evaluator
+ reward_str = 'Evaluation Reward'
+ filename_pattern = 'evaluator*.csv'
+ else:
+ reward_str = 'Training Reward'
+ filename_pattern = 'worker*.csv'
+
+ initialization_error = False
+ test_passed = False
+
+ tries_counter = 0
+ while not csv_paths:
+ csv_paths = glob.glob(path.join(test_path, '*', filename_pattern))
+ if tries_counter > read_csv_tries:
+ break
+ tries_counter += 1
+ time.sleep(1)
+
+ if csv_paths:
+ csv_path = csv_paths[0]
+
+ # verify results
+ csv = None
+ time.sleep(1)
+ averaged_rewards = [0]
+
+ last_num_episodes = 0
+ while csv is None or csv['Episode #'].values[-1] < preset.test_max_step_threshold:
+ try:
+ csv = pd.read_csv(csv_path)
+ except:
+ # sometimes the csv is being written at the same time we are
+ # trying to read it. no problem -> try again
+ continue
+
+ if reward_str not in csv.keys():
+ continue
+
+ rewards = csv[reward_str].values
+ rewards = rewards[~np.isnan(rewards)]
+
+ if len(rewards) >= win_size:
+ averaged_rewards = np.convolve(rewards, np.ones(win_size) / win_size, mode='valid')
+ else:
+ time.sleep(1)
+ continue
+
+ # print progress
+ percentage = int((100*last_num_episodes)/preset.test_max_step_threshold)
+ sys.stdout.write("\rReward: ({}/{})".format(round(averaged_rewards[-1], 1), preset.test_min_return_threshold))
+ sys.stdout.write(' Episode: ({}/{})'.format(last_num_episodes, preset.test_max_step_threshold))
+ sys.stdout.write(' {}%|{}{}| '.format(percentage, '#'*int(percentage/10), ' '*(10-int(percentage/10))))
+ sys.stdout.flush()
+
+ if csv['Episode #'].shape[0] - last_num_episodes <= 0:
+ continue
+
+ last_num_episodes = csv['Episode #'].values[-1]
+
+ # check if reward is enough
+ if np.any(averaged_rewards > preset.test_min_return_threshold):
+ test_passed = True
+ break
+ time.sleep(1)
+
+ # kill test and print result
+ os.killpg(os.getpgid(p.pid), signal.SIGTERM)
+ if test_passed:
+ screen.success("Passed successfully")
+ else:
+ screen.error("Failed due to a mismatch with the golden", crash=False)
+ fail_count += 1
+ shutil.rmtree(test_path)
+
+ screen.separator()
+ if fail_count == 0:
+ screen.success(" Summary: " + str(test_count) + "/" + str(test_count) + " tests passed successfully")
+ else:
+ screen.error(" Summary: " + str(test_count - fail_count) + "/" + str(test_count) + " tests passed successfully")
diff --git a/utils.py b/utils.py
index c66072440..db9799499 100644
--- a/utils.py
+++ b/utils.py
@@ -20,6 +20,7 @@
import numpy as np
import threading
from subprocess import call, Popen
+import signal
killed_processes = []
@@ -54,9 +55,9 @@ def to_string(self, enum):
class RunPhase(Enum):
- HEATUP = 0
- TRAIN = 1
- TEST = 2
+ HEATUP = "Heatup"
+ TRAIN = "Training"
+ TEST = "Testing"
def list_all_classes_in_module(module):
@@ -292,3 +293,59 @@ def get_open_port():
s.close()
return port
+
+class timeout:
+ def __init__(self, seconds=1, error_message='Timeout'):
+ self.seconds = seconds
+ self.error_message = error_message
+
+ def _handle_timeout(self, signum, frame):
+ raise TimeoutError(self.error_message)
+
+ def __enter__(self):
+ signal.signal(signal.SIGALRM, self._handle_timeout)
+ signal.alarm(self.seconds)
+
+ def __exit__(self, type, value, traceback):
+ signal.alarm(0)
+
+
+def switch_axes_order(observation, from_type='channels_first', to_type='channels_last'):
+ """
+ transpose an observation axes from channels_first to channels_last or vice versa
+ :param observation: a numpy array
+ :param from_type: can be 'channels_first' or 'channels_last'
+ :param to_type: can be 'channels_first' or 'channels_last'
+ :return: a new observation with the requested axes order
+ """
+ if from_type == to_type or len(observation.shape) == 1:
+ return observation
+ assert 2 <= len(observation.shape) <= 3, 'num axes of an observation must be 2 for a vector or 3 for an image'
+ assert type(observation) == np.ndarray, 'observation must be a numpy array'
+ if len(observation.shape) == 3:
+ if from_type == 'channels_first' and to_type == 'channels_last':
+ return np.transpose(observation, (1, 2, 0))
+ elif from_type == 'channels_last' and to_type == 'channels_first':
+ return np.transpose(observation, (2, 0, 1))
+ else:
+ return np.transpose(observation, (1, 0))
+
+
+def stack_observation(curr_stack, observation, stack_size):
+ """
+ Adds a new observation to an existing stack of observations from previous time-steps.
+ :param curr_stack: The current observations stack.
+ :param observation: The new observation
+ :param stack_size: The required stack size
+ :return: The updated observation stack
+ """
+
+ if curr_stack == []:
+ # starting an episode
+ curr_stack = np.vstack(np.expand_dims([observation] * stack_size, 0))
+ curr_stack = switch_axes_order(curr_stack, from_type='channels_first', to_type='channels_last')
+ else:
+ curr_stack = np.append(curr_stack, np.expand_dims(np.squeeze(observation), axis=-1), axis=-1)
+ curr_stack = np.delete(curr_stack, 0, -1)
+
+ return curr_stack