diff --git a/README.md b/README.md index e503556f3..16562e429 100644 --- a/README.md +++ b/README.md @@ -13,10 +13,16 @@ Training an agent to solve an environment is as easy as running: python3 coach.py -p CartPole_DQN -r ``` -Doom Health GatheringPyBullet Minitaur Gym Extensions Ant +Doom Deathmatch CARLA MontezumaRevenge Blog post from the Intel® Nervana™ website can be found [here](https://www.intelnervana.com/reinforcement-learning-coach-intel). + +## Documentation + +Framework documentation, algorithm description and instructions on how to contribute a new agent/environment can be found [here](http://coach.nervanasys.com). + + ## Installation Note: Coach has only been tested on Ubuntu 16.04 LTS, and with Python 3.5. @@ -103,6 +109,8 @@ For example: It is easy to create new presets for different levels or environments by following the same pattern as in presets.py +More usage examples can be found [here](http://coach.nervanasys.com/usage/index.html). + ## Running Coach Dashboard (Visualization) Training an agent to solve an environment can be tricky, at times. @@ -121,11 +129,6 @@ python3 dashboard.py Coach Design -## Documentation - -Framework documentation, algoritmic description and instructions on how to contribute a new agent/environment can be found [here](http://coach.nervanasys.com). - - ## Parallelizing an Algorithm Since the introduction of [A3C](https://arxiv.org/abs/1602.01783) in 2016, many algorithms were shown to benefit from running multiple instances in parallel, on many CPU cores. So far, these algorithms include [A3C](https://arxiv.org/abs/1602.01783), [DDPG](https://arxiv.org/pdf/1704.03073.pdf), [PPO](https://arxiv.org/pdf/1707.06347.pdf), and [NAF](https://arxiv.org/pdf/1610.00633.pdf), and this is most probably only the begining. @@ -150,11 +153,11 @@ python3 coach.py -p Hopper_A3C -n 16 ## Supported Environments -* OpenAI Gym +* *OpenAI Gym:* Installed by default by Coach's installer. -* ViZDoom: +* *ViZDoom:* Follow the instructions described in the ViZDoom repository - @@ -162,13 +165,13 @@ python3 coach.py -p Hopper_A3C -n 16 Additionally, Coach assumes that the environment variable VIZDOOM_ROOT points to the ViZDoom installation directory. -* Roboschool: +* *Roboschool:* Follow the instructions described in the roboschool repository - https://github.com/openai/roboschool -* GymExtensions: +* *GymExtensions:* Follow the instructions described in the GymExtensions repository - @@ -176,10 +179,19 @@ python3 coach.py -p Hopper_A3C -n 16 Additionally, add the installation directory to the PYTHONPATH environment variable. -* PyBullet +* *PyBullet:* Follow the instructions described in the [Quick Start Guide](https://docs.google.com/document/d/10sXEhzFRSnvFcl3XxNGhnD4N2SedqwdAvK3dsihxVUA) (basically just - 'pip install pybullet') +* *CARLA:* + + Download release 0.7 from the CARLA repository - + + https://github.com/carla-simulator/carla/releases + + Create a new CARLA_ROOT environment variable pointing to CARLA's installation directory. + + A simple CARLA settings file (```CarlaSettings.ini```) is supplied with Coach, and is located in the ```environments``` directory. 
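For example, once CARLA is installed and ```CARLA_ROOT``` is set, a CARLA preset can be launched like any other environment. This is a minimal sketch — the preset name ```Carla_DDPG``` is an assumption, so substitute whichever CARLA preset is defined in presets.py:

```
# assumes CARLA 0.7 was extracted to ~/CARLA_0.7 and that a CARLA preset exists in presets.py
export CARLA_ROOT=~/CARLA_0.7
python3 coach.py -p Carla_DDPG -r
```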
## Supported Algorithms @@ -190,24 +202,24 @@ python3 coach.py -p Hopper_A3C -n 16 -* [Deep Q Network (DQN)](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) -* [Double Deep Q Network (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) +* [Deep Q Network (DQN)](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) ([code](agents/dqn_agent.py)) +* [Double Deep Q Network (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) ([code](agents/ddqn_agent.py)) * [Dueling Q Network](https://arxiv.org/abs/1511.06581) -* [Mixed Monte Carlo (MMC)](https://arxiv.org/abs/1703.01310) -* [Persistent Advantage Learning (PAL)](https://arxiv.org/abs/1512.04860) -* [Categorical Deep Q Network (C51)](https://arxiv.org/abs/1707.06887) -* [Quantile Regression Deep Q Network (QR-DQN)](https://arxiv.org/pdf/1710.10044v1.pdf) -* [Bootstrapped Deep Q Network](https://arxiv.org/abs/1602.04621) -* [N-Step Q Learning](https://arxiv.org/abs/1602.01783) | **Distributed** -* [Neural Episodic Control (NEC)](https://arxiv.org/abs/1703.01988) -* [Normalized Advantage Functions (NAF)](https://arxiv.org/abs/1603.00748.pdf) | **Distributed** -* [Policy Gradients (PG)](http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf) | **Distributed** -* [Asynchronous Advantage Actor-Critic (A3C)](https://arxiv.org/abs/1602.01783) | **Distributed** -* [Deep Deterministic Policy Gradients (DDPG)](https://arxiv.org/abs/1509.02971) | **Distributed** -* [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf) -* [Clipped Proximal Policy Optimization](https://arxiv.org/pdf/1707.06347.pdf) | **Distributed** -* [Direct Future Prediction (DFP)](https://arxiv.org/abs/1611.01779) | **Distributed** - +* [Mixed Monte Carlo (MMC)](https://arxiv.org/abs/1703.01310) ([code](agents/mmc_agent.py)) +* [Persistent Advantage Learning (PAL)](https://arxiv.org/abs/1512.04860) ([code](agents/pal_agent.py)) +* [Categorical Deep Q Network (C51)](https://arxiv.org/abs/1707.06887) ([code](agents/categorical_dqn_agent.py)) +* [Quantile Regression Deep Q Network (QR-DQN)](https://arxiv.org/pdf/1710.10044v1.pdf) ([code](agents/qr_dqn_agent.py)) +* [Bootstrapped Deep Q Network](https://arxiv.org/abs/1602.04621) ([code](agents/bootstrapped_dqn_agent.py)) +* [N-Step Q Learning](https://arxiv.org/abs/1602.01783) | **Distributed** ([code](agents/n_step_q_agent.py)) +* [Neural Episodic Control (NEC)](https://arxiv.org/abs/1703.01988) ([code](agents/nec_agent.py)) +* [Normalized Advantage Functions (NAF)](https://arxiv.org/abs/1603.00748.pdf) | **Distributed** ([code](agents/naf_agent.py)) +* [Policy Gradients (PG)](http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf) | **Distributed** ([code](agents/policy_gradients_agent.py)) +* [Asynchronous Advantage Actor-Critic (A3C)](https://arxiv.org/abs/1602.01783) | **Distributed** ([code](agents/actor_critic_agent.py)) +* [Deep Deterministic Policy Gradients (DDPG)](https://arxiv.org/abs/1509.02971) | **Distributed** ([code](agents/ddpg_agent.py)) +* [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf) ([code](agents/ppo_agent.py)) +* [Clipped Proximal Policy Optimization](https://arxiv.org/pdf/1707.06347.pdf) | **Distributed** ([code](agents/clipped_ppo_agent.py)) +* [Direct Future Prediction (DFP)](https://arxiv.org/abs/1611.01779) | **Distributed** ([code](agents/dfp_agent.py)) +* Behavioral Cloning (BC) ([code](agents/bc_agent.py)) diff --git a/agents/__init__.py b/agents/__init__.py index b1ae8d324..fdbd13e51 100644 --- a/agents/__init__.py +++ b/agents/__init__.py @@ -16,6 
+16,7 @@ from agents.actor_critic_agent import * from agents.agent import * +from agents.bc_agent import * from agents.bootstrapped_dqn_agent import * from agents.clipped_ppo_agent import * from agents.ddpg_agent import * @@ -23,6 +24,8 @@ from agents.dfp_agent import * from agents.dqn_agent import * from agents.categorical_dqn_agent import * +from agents.human_agent import * +from agents.imitation_agent import * from agents.mmc_agent import * from agents.n_step_q_agent import * from agents.naf_agent import * diff --git a/agents/agent.py b/agents/agent.py index ed9eabce6..a541fa5f4 100644 --- a/agents/agent.py +++ b/agents/agent.py @@ -50,6 +50,7 @@ def __init__(self, env, tuning_parameters, replicated_device=None, task_id=0): self.task_id = task_id self.sess = tuning_parameters.sess self.env = tuning_parameters.env_instance = env + self.imitation = False # i/o dimensions if not tuning_parameters.env.desired_observation_width or not tuning_parameters.env.desired_observation_height: @@ -61,7 +62,12 @@ def __init__(self, env, tuning_parameters, replicated_device=None, task_id=0): self.measurements_size = tuning_parameters.env.measurements_size = (self.measurements_size[0] + 1,) # modules - self.memory = eval(tuning_parameters.memory + '(tuning_parameters)') + if tuning_parameters.agent.load_memory_from_file_path: + screen.log_title("Loading replay buffer from pickle. Pickle path: {}" + .format(tuning_parameters.agent.load_memory_from_file_path)) + self.memory = read_pickle(tuning_parameters.agent.load_memory_from_file_path) + else: + self.memory = eval(tuning_parameters.memory + '(tuning_parameters)') # self.architecture = eval(tuning_parameters.architecture) self.has_global = replicated_device is not None @@ -121,11 +127,12 @@ def __init__(self, env, tuning_parameters, replicated_device=None, task_id=0): def log_to_screen(self, phase): # log to screen - if self.current_episode > 0: - if phase == RunPhase.TEST: - exploration = self.evaluation_exploration_policy.get_control_param() - else: + if self.current_episode >= 0: + if phase == RunPhase.TRAIN: exploration = self.exploration_policy.get_control_param() + else: + exploration = self.evaluation_exploration_policy.get_control_param() + screen.log_dict( OrderedDict([ ("Worker", self.task_id), @@ -135,7 +142,7 @@ def log_to_screen(self, phase): ("steps", self.total_steps_counter), ("training iteration", self.training_iteration) ]), - prefix="Heatup" if self.in_heatup else "Training" if phase == RunPhase.TRAIN else "Testing" + prefix=phase ) def update_log(self, phase=RunPhase.TRAIN): @@ -146,7 +153,7 @@ def update_log(self, phase=RunPhase.TRAIN): # log all the signals to file logger.set_current_time(self.current_episode) logger.create_signal_value('Training Iter', self.training_iteration) - logger.create_signal_value('In Heatup', int(self.in_heatup)) + logger.create_signal_value('In Heatup', int(phase == RunPhase.HEATUP)) logger.create_signal_value('ER #Transitions', self.memory.num_transitions()) logger.create_signal_value('ER #Episodes', self.memory.length()) logger.create_signal_value('Episode Length', self.current_episode_steps_counter) @@ -197,24 +204,6 @@ def reset_game(self, do_not_reset_env=False): network.curr_rnn_c_in = network.middleware_embedder.c_init network.curr_rnn_h_in = network.middleware_embedder.h_init - def stack_observation(self, curr_stack, observation): - """ - Adds a new observation to an existing stack of observations from previous time-steps. - :param curr_stack: The current observations stack. 
- :param observation: The new observation - :return: The updated observation stack - """ - - if curr_stack == []: - # starting an episode - curr_stack = np.vstack(np.expand_dims([observation] * self.tp.env.observation_stack_size, 0)) - curr_stack = self.switch_axes_order(curr_stack, from_type='channels_first', to_type='channels_last') - else: - curr_stack = np.append(curr_stack, np.expand_dims(np.squeeze(observation), axis=-1), axis=-1) - curr_stack = np.delete(curr_stack, 0, -1) - - return curr_stack - def preprocess_observation(self, observation): """ Preprocesses the given observation. @@ -335,26 +324,6 @@ def preprocess_reward(self, reward): reward = max(reward, self.tp.env.reward_clipping_min) return reward - def switch_axes_order(self, observation, from_type='channels_first', to_type='channels_last'): - """ - transpose an observation axes from channels_first to channels_last or vice versa - :param observation: a numpy array - :param from_type: can be 'channels_first' or 'channels_last' - :param to_type: can be 'channels_first' or 'channels_last' - :return: a new observation with the requested axes order - """ - if from_type == to_type or len(observation.shape) == 1: - return observation - assert 2 <= len(observation.shape) <= 3, 'num axes of an observation must be 2 for a vector or 3 for an image' - assert type(observation) == np.ndarray, 'observation must be a numpy array' - if len(observation.shape) == 3: - if from_type == 'channels_first' and to_type == 'channels_last': - return np.transpose(observation, (1, 2, 0)) - elif from_type == 'channels_last' and to_type == 'channels_first': - return np.transpose(observation, (2, 0, 1)) - else: - return np.transpose(observation, (1, 0)) - def act(self, phase=RunPhase.TRAIN): """ Take one step in the environment according to the network prediction and store the transition in memory @@ -370,7 +339,7 @@ def act(self, phase=RunPhase.TRAIN): is_first_transition_in_episode = (self.curr_state == []) if is_first_transition_in_episode: observation = self.preprocess_observation(self.env.observation) - observation = self.stack_observation([], observation) + observation = stack_observation([], observation, self.tp.env.observation_stack_size) self.curr_state = {'observation': observation} if self.tp.agent.use_measurements: @@ -378,7 +347,7 @@ def act(self, phase=RunPhase.TRAIN): if self.tp.agent.use_accumulated_reward_as_measurement: self.curr_state['measurements'] = np.append(self.curr_state['measurements'], 0) - if self.in_heatup: # we do not have a stacked curr_state yet + if phase == RunPhase.HEATUP and not self.tp.heatup_using_network_decisions: action = self.env.get_random_action() else: action, action_info = self.choose_action(self.curr_state, phase=phase) @@ -394,11 +363,11 @@ def act(self, phase=RunPhase.TRAIN): observation = self.preprocess_observation(result['observation']) # plot action values online - if self.tp.visualization.plot_action_values_online and not self.in_heatup: + if self.tp.visualization.plot_action_values_online and phase != RunPhase.HEATUP: self.plot_action_values_online() # initialize the next state - observation = self.stack_observation(self.curr_state['observation'], observation) + observation = stack_observation(self.curr_state['observation'], observation, self.tp.env.observation_stack_size) next_state = {'observation': observation} if self.tp.agent.use_measurements and 'measurements' in result.keys(): @@ -407,7 +376,7 @@ def act(self, phase=RunPhase.TRAIN): next_state['measurements'] = 
np.append(next_state['measurements'], self.total_reward_in_current_episode) # store the transition only if we are training - if phase == RunPhase.TRAIN: + if phase == RunPhase.TRAIN or phase == RunPhase.HEATUP: transition = Transition(self.curr_state, result['action'], shaped_reward, next_state, result['done']) for key in action_info.keys(): transition.info[key] = action_info[key] @@ -427,7 +396,7 @@ def act(self, phase=RunPhase.TRAIN): self.update_log(phase=phase) self.log_to_screen(phase=phase) - if phase == RunPhase.TRAIN: + if phase == RunPhase.TRAIN or phase == RunPhase.HEATUP: self.reset_game() self.current_episode += 1 @@ -462,11 +431,12 @@ def evaluate(self, num_episodes, keep_networks_synced=False): for network in self.networks: network.sync() - if self.tp.visualization.dump_gifs and self.total_reward_in_current_episode > max_reward_achieved: + if self.total_reward_in_current_episode > max_reward_achieved: max_reward_achieved = self.total_reward_in_current_episode frame_skipping = int(5/self.tp.env.frame_skip) - logger.create_gif(self.last_episode_images[::frame_skipping], - name='score-{}'.format(max_reward_achieved), fps=10) + if self.tp.visualization.dump_gifs: + logger.create_gif(self.last_episode_images[::frame_skipping], + name='score-{}'.format(max_reward_achieved), fps=10) average_evaluation_reward += self.total_reward_in_current_episode self.reset_game() @@ -496,7 +466,7 @@ def improve(self): screen.log_title("Starting heatup {}".format(self.task_id)) num_steps_required_for_one_training_batch = self.tp.batch_size * self.tp.env.observation_stack_size for step in range(max(self.tp.num_heatup_steps, num_steps_required_for_one_training_batch)): - self.act() + self.act(phase=RunPhase.HEATUP) # training phase self.in_heatup = False @@ -509,7 +479,12 @@ def improve(self): # evaluate evaluate_agent = (self.last_episode_evaluation_ran is not self.current_episode) and \ (self.current_episode % self.tp.evaluate_every_x_episodes == 0) + evaluate_agent = evaluate_agent or \ + (self.imitation and self.training_iteration > 0 and + self.training_iteration % self.tp.evaluate_every_x_training_iterations == 0) + if evaluate_agent: + self.env.reset() self.last_episode_evaluation_ran = self.current_episode self.evaluate(self.tp.evaluation_episodes) @@ -522,14 +497,15 @@ def improve(self): self.save_model(model_snapshots_periods_passed) # play and record in replay buffer - if self.tp.agent.step_until_collecting_full_episodes: - step = 0 - while step < self.tp.agent.num_consecutive_playing_steps or self.memory.get_episode(-1).length() != 0: - self.act() - step += 1 - else: - for step in range(self.tp.agent.num_consecutive_playing_steps): - self.act() + if self.tp.agent.collect_new_data: + if self.tp.agent.step_until_collecting_full_episodes: + step = 0 + while step < self.tp.agent.num_consecutive_playing_steps or self.memory.get_episode(-1).length() != 0: + self.act() + step += 1 + else: + for step in range(self.tp.agent.num_consecutive_playing_steps): + self.act() # train if self.tp.train: @@ -537,6 +513,8 @@ def improve(self): loss = self.train() self.loss.add_sample(loss) self.training_iteration += 1 + if self.imitation: + self.log_to_screen(RunPhase.TRAIN) self.post_training_commands() def save_model(self, model_id): diff --git a/agents/bc_agent.py b/agents/bc_agent.py new file mode 100644 index 000000000..e06573116 --- /dev/null +++ b/agents/bc_agent.py @@ -0,0 +1,40 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you 
may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from agents.imitation_agent import * + + +# Behavioral Cloning Agent +class BCAgent(ImitationAgent): + def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0): + ImitationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id) + + def learn_from_batch(self, batch): + current_states, _, actions, _, _, _ = self.extract_batch(batch) + + # create the inputs for the network + input = current_states + + # the targets for the network are the actions since this is supervised learning + if self.env.discrete_controls: + targets = np.eye(self.env.action_space_size)[[actions]] + else: + targets = actions + + result = self.main_network.train_and_sync_networks(input, targets) + total_loss = result[0] + + return total_loss diff --git a/agents/distributional_dqn_agent.py b/agents/distributional_dqn_agent.py new file mode 100644 index 000000000..d7c00880c --- /dev/null +++ b/agents/distributional_dqn_agent.py @@ -0,0 +1,60 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from agents.value_optimization_agent import * + + +# Distributional Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf +class DistributionalDQNAgent(ValueOptimizationAgent): + def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0): + ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id) + self.z_values = np.linspace(self.tp.agent.v_min, self.tp.agent.v_max, self.tp.agent.atoms) + + # prediction's format is (batch,actions,atoms) + def get_q_values(self, prediction): + return np.dot(prediction, self.z_values) + + def learn_from_batch(self, batch): + current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch) + + # for the action we actually took, the error is calculated by the atoms distribution + # for all other actions, the error is 0 + distributed_q_st_plus_1 = self.main_network.target_network.predict(next_states) + # initialize with the current prediction so that we will + TD_targets = self.main_network.online_network.predict(current_states) + + # only update the action that we have actually done in this transition + target_actions = np.argmax(self.get_q_values(distributed_q_st_plus_1), axis=1) + m = np.zeros((self.tp.batch_size, self.z_values.size)) + + batches = np.arange(self.tp.batch_size) + for j in range(self.z_values.size): + tzj = np.fmax(np.fmin(rewards + (1.0 - game_overs) * self.tp.agent.discount * self.z_values[j], + self.z_values[self.z_values.size - 1]), + self.z_values[0]) + bj = (tzj - self.z_values[0])/(self.z_values[1] - self.z_values[0]) + u = (np.ceil(bj)).astype(int) + l = (np.floor(bj)).astype(int) + m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj)) + m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l)) + # total_loss = cross entropy between actual result above and predicted result for the given action + TD_targets[batches, actions] = m + + result = self.main_network.train_and_sync_networks(current_states, TD_targets) + total_loss = result[0] + + return total_loss + diff --git a/agents/human_agent.py b/agents/human_agent.py new file mode 100644 index 000000000..c75c2a2a7 --- /dev/null +++ b/agents/human_agent.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from agents.agent import * +import pygame + + +class HumanAgent(Agent): + def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0): + Agent.__init__(self, env, tuning_parameters, replicated_device, thread_id) + + self.clock = pygame.time.Clock() + self.max_fps = int(self.tp.visualization.max_fps_for_human_control) + + screen.log_title("Human Control Mode") + available_keys = self.env.get_available_keys() + if available_keys: + screen.log("Use keyboard keys to move. Press escape to quit. 
Available keys:") + screen.log("") + for action, key in self.env.get_available_keys(): + screen.log("\t- {}: {}".format(action, key)) + screen.separator() + + def train(self): + return 0 + + def choose_action(self, curr_state, phase=RunPhase.TRAIN): + action = self.env.get_action_from_user() + + # keep constant fps + self.clock.tick(self.max_fps) + + if not self.env.renderer.is_open: + self.save_replay_buffer_and_exit() + + return action, {"action_value": 0} + + def save_replay_buffer_and_exit(self): + replay_buffer_path = os.path.join(logger.experiments_path, 'replay_buffer.p') + self.memory.tp = None + to_pickle(self.memory, replay_buffer_path) + screen.log_title("Replay buffer was stored in {}".format(replay_buffer_path)) + exit() + + def log_to_screen(self, phase): + # log to screen + screen.log_dict( + OrderedDict([ + ("Episode", self.current_episode), + ("total reward", self.total_reward_in_current_episode), + ("steps", self.total_steps_counter) + ]), + prefix="Recording" + ) diff --git a/agents/imitation_agent.py b/agents/imitation_agent.py new file mode 100644 index 000000000..f7f5e0695 --- /dev/null +++ b/agents/imitation_agent.py @@ -0,0 +1,70 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from agents.agent import * + + +# Imitation Agent +class ImitationAgent(Agent): + def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0): + Agent.__init__(self, env, tuning_parameters, replicated_device, thread_id) + self.main_network = NetworkWrapper(tuning_parameters, False, self.has_global, 'main', + self.replicated_device, self.worker_device) + self.networks.append(self.main_network) + self.imitation = True + + def extract_action_values(self, prediction): + return prediction.squeeze() + + def choose_action(self, curr_state, phase=RunPhase.TRAIN): + # convert to batch so we can run it through the network + observation = np.expand_dims(np.array(curr_state['observation']), 0) + if self.tp.agent.use_measurements: + measurements = np.expand_dims(np.array(curr_state['measurements']), 0) + prediction = self.main_network.online_network.predict([observation, measurements]) + else: + prediction = self.main_network.online_network.predict(observation) + + # get action values and extract the best action from it + action_values = self.extract_action_values(prediction) + if self.env.discrete_controls: + # DISCRETE + # action = np.argmax(action_values) + action = self.evaluation_exploration_policy.get_action(action_values) + action_value = {"action_probability": action_values[action]} + else: + # CONTINUOUS + action = action_values + action_value = {} + + return action, action_value + + def log_to_screen(self, phase): + # log to screen + if phase == RunPhase.TRAIN: + # for the training phase - we log during the episode to visualize the progress in training + screen.log_dict( + OrderedDict([ + ("Worker", self.task_id), + ("Episode", self.current_episode), + ("Loss", self.loss.values[-1]), + ("Training iteration", 
self.training_iteration) + ]), + prefix="Training" + ) + else: + # for the evaluation phase - logging as in regular RL + Agent.log_to_screen(self, phase) diff --git a/agents/n_step_q_agent.py b/agents/n_step_q_agent.py index 07465230f..3b464a894 100644 --- a/agents/n_step_q_agent.py +++ b/agents/n_step_q_agent.py @@ -45,7 +45,7 @@ def learn_from_batch(self, batch): # 1-Step Q learning q_st_plus_1 = self.main_network.target_network.predict(next_states) - for i in reversed(xrange(num_transitions)): + for i in reversed(range(num_transitions)): state_value_head_targets[i][actions[i]] = \ rewards[i] + (1.0 - game_overs[i]) * self.tp.agent.discount * np.max(q_st_plus_1[i], 0) @@ -56,7 +56,7 @@ def learn_from_batch(self, batch): else: R = np.max(self.main_network.target_network.predict(np.expand_dims(next_states[-1], 0))) - for i in reversed(xrange(num_transitions)): + for i in reversed(range(num_transitions)): R = rewards[i] + self.tp.agent.discount * R state_value_head_targets[i][actions[i]] = R diff --git a/agents/policy_optimization_agent.py b/agents/policy_optimization_agent.py index c64dbab6d..07aac6aff 100644 --- a/agents/policy_optimization_agent.py +++ b/agents/policy_optimization_agent.py @@ -58,7 +58,7 @@ def log_to_screen(self, phase): ("steps", self.total_steps_counter), ("training iteration", self.training_iteration) ]), - prefix="Heatup" if self.in_heatup else "Training" if phase == RunPhase.TRAIN else "Testing" + prefix=phase ) def update_episode_statistics(self, episode): diff --git a/architectures/network_wrapper.py b/architectures/network_wrapper.py index a03448505..bbe6c590a 100644 --- a/architectures/network_wrapper.py +++ b/architectures/network_wrapper.py @@ -75,11 +75,14 @@ def __init__(self, tuning_parameters, has_target, has_global, name, replicated_d network_is_local=True) if not self.tp.distributed and self.tp.framework == Frameworks.TensorFlow: - self.model_saver = tf.train.Saver() + variables_to_restore = tf.global_variables() + variables_to_restore = [v for v in variables_to_restore if '/online' in v.name] + self.model_saver = tf.train.Saver(variables_to_restore) if self.tp.sess and self.tp.checkpoint_restore_dir: checkpoint = tf.train.latest_checkpoint(self.tp.checkpoint_restore_dir) screen.log_title("Loading checkpoint: {}".format(checkpoint)) self.model_saver.restore(self.tp.sess, checkpoint) + self.update_target_network() def sync(self): """ diff --git a/architectures/tensorflow_components/embedders.py b/architectures/tensorflow_components/embedders.py index 2b3212c72..6b6acd2da 100644 --- a/architectures/tensorflow_components/embedders.py +++ b/architectures/tensorflow_components/embedders.py @@ -15,15 +15,18 @@ # import tensorflow as tf +from configurations import EmbedderComplexity class InputEmbedder(object): - def __init__(self, input_size, activation_function=tf.nn.relu, name="embedder"): + def __init__(self, input_size, activation_function=tf.nn.relu, + embedder_complexity=EmbedderComplexity.Shallow, name="embedder"): self.name = name self.input_size = input_size self.activation_function = activation_function self.input = None self.output = None + self.embedder_complexity = embedder_complexity def __call__(self, prev_input_placeholder=None): with tf.variable_scope(self.get_name()): @@ -43,31 +46,77 @@ def get_name(self): class ImageEmbedder(InputEmbedder): - def __init__(self, input_size, input_rescaler=255.0, activation_function=tf.nn.relu, name="embedder"): - InputEmbedder.__init__(self, input_size, activation_function, name) + def __init__(self, 
input_size, input_rescaler=255.0, activation_function=tf.nn.relu, + embedder_complexity=EmbedderComplexity.Shallow, name="embedder"): + InputEmbedder.__init__(self, input_size, activation_function, embedder_complexity, name) self.input_rescaler = input_rescaler def _build_module(self): # image observation rescaled_observation_stack = self.input / self.input_rescaler - self.observation_conv1 = tf.layers.conv2d(rescaled_observation_stack, - filters=32, kernel_size=(8, 8), strides=(4, 4), - activation=self.activation_function, data_format='channels_last') - self.observation_conv2 = tf.layers.conv2d(self.observation_conv1, - filters=64, kernel_size=(4, 4), strides=(2, 2), - activation=self.activation_function, data_format='channels_last') - self.observation_conv3 = tf.layers.conv2d(self.observation_conv2, - filters=64, kernel_size=(3, 3), strides=(1, 1), - activation=self.activation_function, data_format='channels_last') - self.output = tf.contrib.layers.flatten(self.observation_conv3) + if self.embedder_complexity == EmbedderComplexity.Shallow: + # same embedder as used in the original DQN paper + self.observation_conv1 = tf.layers.conv2d(rescaled_observation_stack, + filters=32, kernel_size=(8, 8), strides=(4, 4), + activation=self.activation_function, data_format='channels_last') + self.observation_conv2 = tf.layers.conv2d(self.observation_conv1, + filters=64, kernel_size=(4, 4), strides=(2, 2), + activation=self.activation_function, data_format='channels_last') + self.observation_conv3 = tf.layers.conv2d(self.observation_conv2, + filters=64, kernel_size=(3, 3), strides=(1, 1), + activation=self.activation_function, data_format='channels_last') + + self.output = tf.contrib.layers.flatten(self.observation_conv3) + + elif self.embedder_complexity == EmbedderComplexity.Deep: + # the embedder used in the CARLA papers + self.observation_conv1 = tf.layers.conv2d(rescaled_observation_stack, + filters=32, kernel_size=(5, 5), strides=(2, 2), + activation=self.activation_function, data_format='channels_last') + self.observation_conv2 = tf.layers.conv2d(self.observation_conv1, + filters=32, kernel_size=(3, 3), strides=(1, 1), + activation=self.activation_function, data_format='channels_last') + self.observation_conv3 = tf.layers.conv2d(self.observation_conv2, + filters=64, kernel_size=(3, 3), strides=(2, 2), + activation=self.activation_function, data_format='channels_last') + self.observation_conv4 = tf.layers.conv2d(self.observation_conv3, + filters=64, kernel_size=(3, 3), strides=(1, 1), + activation=self.activation_function, data_format='channels_last') + self.observation_conv5 = tf.layers.conv2d(self.observation_conv4, + filters=128, kernel_size=(3, 3), strides=(2, 2), + activation=self.activation_function, data_format='channels_last') + self.observation_conv6 = tf.layers.conv2d(self.observation_conv5, + filters=128, kernel_size=(3, 3), strides=(1, 1), + activation=self.activation_function, data_format='channels_last') + self.observation_conv7 = tf.layers.conv2d(self.observation_conv6, + filters=256, kernel_size=(3, 3), strides=(2, 2), + activation=self.activation_function, data_format='channels_last') + self.observation_conv8 = tf.layers.conv2d(self.observation_conv7, + filters=256, kernel_size=(3, 3), strides=(1, 1), + activation=self.activation_function, data_format='channels_last') + + self.output = tf.contrib.layers.flatten(self.observation_conv8) + else: + raise ValueError("The defined embedder complexity value is invalid") class VectorEmbedder(InputEmbedder): - def __init__(self, 
input_size, activation_function=tf.nn.relu, name="embedder"): - InputEmbedder.__init__(self, input_size, activation_function, name) + def __init__(self, input_size, activation_function=tf.nn.relu, + embedder_complexity=EmbedderComplexity.Shallow, name="embedder"): + InputEmbedder.__init__(self, input_size, activation_function, embedder_complexity, name) def _build_module(self): # vector observation input_layer = tf.contrib.layers.flatten(self.input) - self.output = tf.layers.dense(input_layer, 256, activation=self.activation_function) + + if self.embedder_complexity == EmbedderComplexity.Shallow: + self.output = tf.layers.dense(input_layer, 256, activation=self.activation_function) + + elif self.embedder_complexity == EmbedderComplexity.Deep: + # the embedder used in the CARLA papers + self.observation_fc1 = tf.layers.dense(input_layer, 128, activation=self.activation_function) + self.observation_fc2 = tf.layers.dense(self.observation_fc1, 128, activation=self.activation_function) + self.output = tf.layers.dense(self.observation_fc2, 128, activation=self.activation_function) + else: + raise ValueError("The defined embedder complexity value is invalid") diff --git a/coach.py b/coach.py index ffddbc956..45b7382d1 100644 --- a/coach.py +++ b/coach.py @@ -37,8 +37,29 @@ cur_time = time_started.time() cur_date = time_started.date() -def get_experiment_path(general_experiments_path): - if not os.path.exists(general_experiments_path): + +def get_experiment_name(initial_experiment_name=''): + match = None + while match is None: + if initial_experiment_name == '': + experiment_name = screen.ask_input("Please enter an experiment name: ") + else: + experiment_name = initial_experiment_name + + experiment_name = experiment_name.replace(" ", "_") + match = re.match("^$|^[\w -/]{1,100}$", experiment_name) + + if match is None: + screen.error('Experiment name must be composed only of alphanumeric letters, ' + 'underscores and dashes and should not be longer than 100 characters.') + + return match.group(0) + + +def get_experiment_path(experiment_name, create_path=True): + general_experiments_path = os.path.join('./experiments/', experiment_name) + + if not os.path.exists(general_experiments_path) and create_path: os.makedirs(general_experiments_path) experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}' .format(logger.two_digits(cur_date.day), logger.two_digits(cur_date.month), @@ -52,7 +73,8 @@ def get_experiment_path(general_experiments_path): cur_time.minute, i)) i += 1 else: - os.makedirs(experiment_path) + if create_path: + os.makedirs(experiment_path) return experiment_path @@ -96,55 +118,54 @@ def check_input_and_fill_run_dict(parser): num_workers = int(re.match("^\d+$", args.num_workers).group(0)) except ValueError: screen.error("Parameter num_workers should be an integer.") - exit(1) preset_names = list_all_classes_in_module(presets) if args.preset is not None and args.preset not in preset_names: screen.error("A non-existing preset was selected. ") - exit(1) if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir): screen.error("The requested checkpoint folder to load from does not exist. 
") - exit(1) if args.save_model_sec is not None: try: args.save_model_sec = int(args.save_model_sec) except ValueError: screen.error("Parameter save_model_sec should be an integer.") - exit(1) if args.preset is None and (args.agent_type is None or args.environment_type is None - or args.exploration_policy_type is None): + or args.exploration_policy_type is None) and not args.play: screen.error('When no preset is given for Coach to run, the user is expected to input the desired agent_type,' ' environment_type and exploration_policy_type to assemble a preset. ' '\nAt least one of these parameters was not given.') - exit(1) - - experiment_name = args.experiment_name - - if args.experiment_name == '': - experiment_name = screen.ask_input("Please enter an experiment name: ") + elif args.preset is None and args.play and args.environment_type is None: + screen.error('When no preset is given for Coach to run, and the user requests human control over the environment,' + ' the user is expected to input the desired environment_type and level.' + '\nAt least one of these parameters was not given.') + elif args.preset is None and args.play and args.environment_type: + args.agent_type = 'Human' + args.exploration_policy_type = 'ExplorationParameters' - experiment_name = experiment_name.replace(" ", "_") - match = re.match("^$|^\w{1,100}$", experiment_name) + # get experiment name and path + experiment_name = get_experiment_name(args.experiment_name) + experiment_path = get_experiment_path(experiment_name) - if match is None: - screen.error('Experiment name must be composed only of alphanumeric letters and underscores and should not be ' - 'longer than 100 characters.') - exit(1) - experiment_path = os.path.join('./experiments/', match.group(0)) - experiment_path = get_experiment_path(experiment_path) + if args.play and num_workers > 1: + screen.warning("Playing the game as a human is only available with a single worker. " + "The number of workers will be reduced to 1") + num_workers = 1 # fill run_dict run_dict = dict() run_dict['agent_type'] = args.agent_type run_dict['environment_type'] = args.environment_type run_dict['exploration_policy_type'] = args.exploration_policy_type + run_dict['level'] = args.level run_dict['preset'] = args.preset run_dict['custom_parameter'] = args.custom_parameter run_dict['experiment_path'] = experiment_path run_dict['framework'] = Frameworks().get(args.framework) + run_dict['play'] = args.play + run_dict['evaluate'] = args.evaluate# or args.play # multi-threading parameters run_dict['num_threads'] = num_workers @@ -197,6 +218,14 @@ def run_dict_to_json(_run_dict, task_id=''): help="(int) Number of workers for multi-process based agents, e.g. A3C", default='1', type=str) + parser.add_argument('--play', + help="(flag) Play as a human by controlling the game with the keyboard. " + "This option will save a replay buffer with the game play.", + action='store_true') + parser.add_argument('--evaluate', + help="(flag) Run evaluation only. This is a convenient way to disable " + "training in order to evaluate an existing checkpoint.", + action='store_true') parser.add_argument('-v', '--verbose', help="(flag) Don't suppress TensorFlow debug prints.", action='store_true') @@ -230,6 +259,12 @@ def run_dict_to_json(_run_dict, task_id=''): , default=None, type=str) + parser.add_argument('-lvl', '--level', + help="(string) Choose the level that will be played in the environment that was selected." + "This value will override the level parameter in the environment class." 
+ , + default=None, + type=str) parser.add_argument('-cp', '--custom_parameter', help="(string) Semicolon separated parameters used to override specific parameters on top of" " the selected preset (or on top of the command-line assembled one). " @@ -259,7 +294,12 @@ def run_dict_to_json(_run_dict, task_id=''): tuning_parameters.task_index = 0 env_instance = create_environment(tuning_parameters) agent = eval(tuning_parameters.agent.type + '(env_instance, tuning_parameters)') - agent.improve() + + # Start the training or evaluation + if tuning_parameters.evaluate: + agent.evaluate(sys.maxsize, keep_networks_synced=True) # evaluate forever + else: + agent.improve() # Multi-threaded runs else: diff --git a/configurations.py b/configurations.py index b7f99536a..8480c1912 100644 --- a/configurations.py +++ b/configurations.py @@ -32,6 +32,11 @@ class InputTypes(object): TimedObservation = 5 +class EmbedderComplexity(object): + Shallow = 1 + Deep = 2 + + class OutputTypes(object): Q = 1 DuelingQ = 2 @@ -60,6 +65,7 @@ class AgentParameters(object): middleware_type = MiddlewareTypes.FC loss_weights = [1.0] stop_gradients_from_head = [False] + embedder_complexity = EmbedderComplexity.Shallow num_output_head_copies = 1 use_measurements = False use_accumulated_reward_as_measurement = False @@ -90,6 +96,8 @@ class AgentParameters(object): step_until_collecting_full_episodes = False targets_horizon = 'N-Step' replace_mse_with_huber_loss = False + load_memory_from_file_path = None + collect_new_data = True # PPO related params target_kl_divergence = 0.01 @@ -132,6 +140,7 @@ class EnvironmentParameters(object): reward_scaling = 1.0 reward_clipping_min = None reward_clipping_max = None + human_control = False class ExplorationParameters(object): @@ -188,6 +197,7 @@ class GeneralParameters(object): kl_divergence_constraint = 100000 num_training_iterations = 10000000000 num_heatup_steps = 1000 + heatup_using_network_decisions = False batch_size = 32 save_model_sec = None save_model_dir = None @@ -197,6 +207,7 @@ class GeneralParameters(object): learning_rate_decay_steps = 0 evaluation_episodes = 5 evaluate_every_x_episodes = 1000000 + evaluate_every_x_training_iterations = 0 rescaling_interpolation_type = 'bilinear' # setting a seed will only work for non-parallel algorithms. 
Parallel algorithms add uncontrollable noise in @@ -224,6 +235,7 @@ class VisualizationParameters(object): dump_signals_to_csv_every_x_episodes = 10 render = False dump_gifs = True + max_fps_for_human_control = 10 class Roboschool(EnvironmentParameters): @@ -252,7 +264,7 @@ class Bullet(EnvironmentParameters): class Atari(EnvironmentParameters): type = 'Gym' - frame_skip = 1 + frame_skip = 4 observation_stack_size = 4 desired_observation_height = 84 desired_observation_width = 84 @@ -268,6 +280,31 @@ class Doom(EnvironmentParameters): desired_observation_width = 76 +class Carla(EnvironmentParameters): + type = 'Carla' + frame_skip = 1 + observation_stack_size = 4 + desired_observation_height = 128 + desired_observation_width = 180 + normalize_observation = False + server_height = 256 + server_width = 360 + config = 'environments/CarlaSettings.ini' + level = 'town1' + verbose = True + stereo = False + semantic_segmentation = False + depth = False + episode_max_time = 100000 # miliseconds for each episode + continuous_to_bool_threshold = 0.5 + allow_braking = False + + +class Human(AgentParameters): + type = 'HumanAgent' + num_episodes_in_experience_replay = 10000000 + + class NStepQ(AgentParameters): type = 'NStepQAgent' input_types = [InputTypes.Observation] @@ -299,10 +336,12 @@ class DQN(AgentParameters): class DDQN(DQN): type = 'DDQNAgent' + class DuelingDQN(DQN): type = 'DQNAgent' output_types = [OutputTypes.DuelingQ] + class BootstrappedDQN(DQN): type = 'BootstrappedDQNAgent' num_output_head_copies = 10 @@ -314,6 +353,7 @@ class CategoricalDQN(DQN): v_min = -10.0 v_max = 10.0 atoms = 51 + neon_support = False class QuantileRegressionDQN(DQN): @@ -452,6 +492,7 @@ class ClippedPPO(AgentParameters): step_until_collecting_full_episodes = True beta_entropy = 0.01 + class DFP(AgentParameters): type = 'DFPAgent' input_types = [InputTypes.Observation, InputTypes.Measurements, InputTypes.GoalVector] @@ -485,6 +526,15 @@ class PAL(AgentParameters): neon_support = True +class BC(AgentParameters): + type = 'BCAgent' + input_types = [InputTypes.Observation] + output_types = [OutputTypes.Q] + loss_weights = [1.0] + collect_new_data = False + evaluate_every_x_training_iterations = 50000 + + class EGreedyExploration(ExplorationParameters): policy = 'EGreedy' initial_epsilon = 0.5 diff --git a/docs/docs/algorithms/imitation/bc.md b/docs/docs/algorithms/imitation/bc.md new file mode 100644 index 000000000..84e477a5b --- /dev/null +++ b/docs/docs/algorithms/imitation/bc.md @@ -0,0 +1,25 @@ +# Behavioral Cloning + +**Actions space:** Discrete|Continuous + +## Network Structure + +

+ *(network structure diagram)*
+ + + +## Algorithm Description + +### Training the network + +The replay buffer contains the expert demonstrations for the task. +These demonstrations are given as state, action tuples, and with no reward. +The training goal is to reduce the difference between the actions predicted by the network and the actions taken by the expert for each state. + +1. Sample a batch of transitions from the replay buffer. +2. Use the current states as input to the network, and the expert actions as the targets of the network. +3. The loss function for the network is MSE, and therefore we use the Q head to minimize this loss. diff --git a/docs/docs/algorithms/value_optimization/distributional_dqn.md b/docs/docs/algorithms/value_optimization/distributional_dqn.md new file mode 100644 index 000000000..5dcc4c260 --- /dev/null +++ b/docs/docs/algorithms/value_optimization/distributional_dqn.md @@ -0,0 +1,33 @@ +# Distributional DQN + +**Actions space:** Discrete + +**References:** [A Distributional Perspective on Reinforcement Learning](https://arxiv.org/abs/1707.06887) + +## Network Structure + +

+ *(network structure diagram)*
+ + + +## Algorithmic Description + +### Training the network + +1. Sample a batch of transitions from the replay buffer. +2. The Bellman update is projected to the set of atoms representing the $ Q $ values distribution, such that the $i-th$ component of the projected update is calculated as follows: + $$ (\Phi \hat{T} Z_{\theta}(s_t,a_t))_i=\sum_{j=0}^{N-1}\Big[1-\frac{|[\hat{T}_{z_{j}}]^{V_{MAX}}_{V_{MIN}}-z_i|}{\Delta z}\Big]^1_0 \ p_j(s_{t+1}, \pi(s_{t+1})) $$ + where: + * $[ \cdot ] $ bounds its argument in the range [a, b] + * $\hat{T}_{z_{j}}$ is the Bellman update for atom $z_j$:     $\hat{T}_{z_{j}} := r+\gamma z_j$ + + +3. Network is trained with the cross entropy loss between the resulting probability distribution and the target probability distribution. Only the target of the actions that were actually taken is updated. +4. Once in every few thousand steps, weights are copied from the online network to the target network. + + + diff --git a/docs/docs/contributing/add_env.md b/docs/docs/contributing/add_env.md index 51c51c182..f09b4dbb8 100644 --- a/docs/docs/contributing/add_env.md +++ b/docs/docs/contributing/add_env.md @@ -1,33 +1,53 @@ Adding a new environment to Coach is as easy as solving CartPole. -There a few simple steps to follow, and we will walk through them one by one. +There are a few simple steps to follow, and we will walk through them one by one. 1. Coach defines a simple API for implementing a new environment which is defined in environment/environment_wrapper.py. There are several functions to implement, but only some of them are mandatory. - Here are the mandatory ones: + Here are the important ones: - def step(self, action_idx): + def _take_action(self, action_idx): """ - Perform a single step on the environment using the given action. - :param action_idx: the action to perform on the environment - :return: A dictionary containing the observation, reward, done flag, action and measurements + An environment dependent function that sends an action to the simulator. + :param action_idx: the action to perform on the environment. + :return: None """ pass - def render(self): + def _preprocess_observation(self, observation): """ - Call the environment function for rendering to the screen. + Do initial observation preprocessing such as cropping, rgb2gray, rescale etc. + Implementing this function is optional. + :param observation: a raw observation from the environment + :return: the preprocessed observation + """ + return observation + + def _update_state(self): + """ + Updates the state from the environment. + Should update self.observation, self.reward, self.done, self.measurements and self.info + :return: None """ pass - + def _restart_environment_episode(self, force_environment_reset=False): """ - :param force_environment_reset: Force the environment to reset even if the episode is not done yet. - :return: + :param force_environment_reset: Force the environment to reset even if the episode is not done yet. + :return: """ pass + def get_rendered_image(self): + """ + Return a numpy array containing the image that will be rendered to the screen. + This can be different from the observation. For example, mujoco's observation is a measurements vector. + :return: numpy array containing the image that will be rendered to the screen + """ + return self.observation + + 2. 
Make sure to import the environment in environments/\_\_init\_\_.py: from doom_environment_wrapper import * diff --git a/docs/docs/img/algorithms.png b/docs/docs/img/algorithms.png index 2dc14077b..f83c1e69f 100644 Binary files a/docs/docs/img/algorithms.png and b/docs/docs/img/algorithms.png differ diff --git a/docs/docs/usage.md b/docs/docs/usage.md new file mode 100644 index 000000000..26d13b9c4 --- /dev/null +++ b/docs/docs/usage.md @@ -0,0 +1,133 @@ +# Coach Usage + +## Training an Agent + +### Single-threaded Algorithms + +This is the most common case. Just choose a preset using the `-p` flag and press enter. + +*Example:* + +`python coach.py -p CartPole_DQN` + +### Multi-threaded Algorithms + +Multi-threaded algorithms are very common this days. +They typically achieve the best results, and scale gracefully with the number of threads. +In Coach, running such algorithms is done by selecting a suitable preset, and choosing the number of threads to run using the `-n` flag. + +*Example:* + +`python coach.py -p CartPole_A3C -n 8` + +## Evaluating an Agent + +There are several options for evaluating an agent during the training: + +* For multi-threaded runs, an evaluation agent will constantly run in the background and evaluate the model during the training. + +* For single-threaded runs, it is possible to define an evaluation period through the preset. This will run several episodes of evaluation once in a while. + +Additionally, it is possible to save checkpoints of the agents networks and then run only in evaluation mode. +Saving checkpoints can be done by specifying the number of seconds between storing checkpoints using the `-s` flag. +The checkpoints will be saved into the experiment directory. +Loading a model for evaluation can be done by specifying the `-crd` flag with the experiment directory, and the `--evaluate` flag to disable training. + +*Example:* + +`python coach.py -p CartPole_DQN -s 60` +`python coach.py -p CartPole_DQN --evaluate -crd CHECKPOINT_RESTORE_DIR` + +## Playing with the Environment as a Human + +Interacting with the environment as a human can be useful for understanding its difficulties and for collecting data for imitation learning. +In Coach, this can be easily done by selecting a preset that defines the environment to use, and specifying the `--play` flag. +When the environment is loaded, the available keyboard buttons will be printed to the screen. +Pressing the escape key when finished will end the simulation and store the replay buffer in the experiment dir. + +*Example:* + +`python coach.py -p Breakout_DQN --play` + +## Learning Through Imitation Learning + +Learning through imitation of human behavior is a nice way to speedup the learning. +In Coach, this can be done in two steps - + +1. Create a dataset of demonstrations by playing with the environment as a human. + After this step, a pickle of the replay buffer containing your game play will be stored in the experiment directory. + The path to this replay buffer will be printed to the screen. + To do so, you should select an environment type and level through the command line, and specify the `--play` flag. + + *Example:* + + `python coach.py -et Doom -lvl Basic --play` + + +2. Next, use an imitation learning preset and set the replay buffer path accordingly. + The path can be set either from the command line or from the preset itself. 
+ + *Example:* + + `python coach.py -p Doom_Basic_BC -cp='agent.load_memory_from_file_path=\"/replay_buffer.p\"'` + + +## Visualizations + +### Rendering the Environment + +Rendering the environment can be done by using the `-r` flag. +When working with multi-threaded algorithms, the rendered image will be representing the game play of the evaluation worker. +When working with single-threaded algorithms, the rendered image will be representing the single worker which can be either training or evaluating. +Keep in mind that rendering the environment in single-threaded algorithms may slow the training to some extent. +When playing with the environment using the `--play` flag, the environment will be rendered automatically without the need for specifying the `-r` flag. + +*Example:* + +`python coach.py -p Breakout_DQN -r` + +### Dumping GIFs + +Coach allows storing GIFs of the agent game play. +To dump GIF files, use the `-dg` flag. +The files are dumped after every evaluation episode, and are saved into the experiment directory, under a gifs sub-directory. + +*Example:* + +`python coach.py -p Breakout_A3C -n 4 -dg` + +## Switching between deep learning frameworks + +Coach uses TensorFlow as its main backend framework, but it also supports neon for some of the algorithms. +By default, TensorFlow will be used. It is possible to switch to neon using the `-f` flag. + +*Example:* + +`python coach.py -p Doom_Basic_DQN -f neon` + +## Additional Flags + +There are several convenient flags which are important to know about. +Here we will list most of the flags, but these can be updated from time to time. +The most up to date description can be found by using the `-h` flag. + + +|Flag |Type |Description | +|-------------------------------|----------|--------------| +|`-p PRESET`, ``--preset PRESET`|string |Name of a preset to run (as configured in presets.py) | +|`-l`, `--list` |flag |List all available presets| +|`-e EXPERIMENT_NAME`, `--experiment_name EXPERIMENT_NAME`|string|Experiment name to be used to store the results.| +|`-r`, `--render` |flag |Render environment| +|`-f FRAMEWORK`, `--framework FRAMEWORK`|string|Neural network framework. Available values: tensorflow, neon| +|`-n NUM_WORKERS`, `--num_workers NUM_WORKERS`|int|Number of workers for multi-process based agents, e.g. A3C| +|`--play` |flag |Play as a human by controlling the game with the keyboard. This option will save a replay buffer with the game play.| +|`--evaluate` |flag |Run evaluation only. This is a convenient way to disable training in order to evaluate an existing checkpoint.| +|`-v`, `--verbose` |flag |Don't suppress TensorFlow debug prints.| +|`-s SAVE_MODEL_SEC`, `--save_model_sec SAVE_MODEL_SEC`|int|Time in seconds between saving checkpoints of the model.| +|`-crd CHECKPOINT_RESTORE_DIR`, `--checkpoint_restore_dir CHECKPOINT_RESTORE_DIR`|string|Path to a folder containing a checkpoint to restore the model from.| +|`-dg`, `--dump_gifs` |flag |Enable the gif saving functionality.| +|`-at AGENT_TYPE`, `--agent_type AGENT_TYPE`|string|Choose an agent type class to override on top of the selected preset. If no preset is defined, a preset can be set from the command-line by combining settings which are set by using `--agent_type`, `--experiment_type`, `--environemnt_type`| +|`-et ENVIRONMENT_TYPE`, `--environment_type ENVIRONMENT_TYPE`|string|Choose an environment type class to override on top of the selected preset. 
If no preset is defined, a preset can be set from the command-line by combining settings which are set by using `--agent_type`, `--experiment_type`, `--environemnt_type`| +|`-ept EXPLORATION_POLICY_TYPE`, `--exploration_policy_type EXPLORATION_POLICY_TYPE`|string|Choose an exploration policy type class to override on top of the selected preset.If no preset is defined, a preset can be set from the command-line by combining settings which are set by using `--agent_type`, `--experiment_type`, `--environemnt_type`| +|`-lvl LEVEL`, `--level LEVEL` |string|Choose the level that will be played in the environment that was selected. This value will override the level parameter in the environment class.| +|`-cp CUSTOM_PARAMETER`, `--custom_parameter CUSTOM_PARAMETER`|string| Semicolon separated parameters used to override specific parameters on top of the selected preset (or on top of the command-line assembled one). Whenever a parameter value is a string, it should be inputted as `'\"string\"'`. For ex.: `"visualization.render=False;` `num_training_iterations=500;` `optimizer='rmsprop'"`| \ No newline at end of file diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 48fca7d3f..de96d9555 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -11,6 +11,7 @@ extra_css: [extra.css] pages: - Home : index.md - Design: design.md +- Usage: usage.md - Algorithms: - 'DQN' : algorithms/value_optimization/dqn.md - 'Double DQN' : algorithms/value_optimization/double_dqn.md @@ -28,6 +29,7 @@ pages: - 'Proximal Policy Optimization' : algorithms/policy_optimization/ppo.md - 'Clipped Proximal Policy Optimization' : algorithms/policy_optimization/cppo.md - 'Direct Future Prediction' : algorithms/other/dfp.md + - 'Behavioral Cloning' : algorithms/imitation/bc.md - Coach Dashboard : 'dashboard.md' - Contributing : diff --git a/environments/CarlaSettings.ini b/environments/CarlaSettings.ini new file mode 100644 index 000000000..236500f76 --- /dev/null +++ b/environments/CarlaSettings.ini @@ -0,0 +1,62 @@ +[CARLA/Server] +; If set to false, a mock controller will be used instead of waiting for a real +; client to connect. +UseNetworking=true +; Ports to use for the server-client communication. This can be overridden by +; the command-line switch `-world-port=N`, write and read ports will be set to +; N+1 and N+2 respectively. +WorldPort=2000 +; Time-out in milliseconds for the networking operations. +ServerTimeOut=10000000000 +; In synchronous mode, CARLA waits every frame until the control from the client +; is received. +SynchronousMode=true +; Send info about every non-player agent in the scene every frame, the +; information is attached to the measurements message. This includes other +; vehicles, pedestrians and traffic signs. Disabled by default to improve +; performance. +SendNonPlayerAgentsInfo=false + +[CARLA/LevelSettings] +; Path of the vehicle class to be used for the player. Leave empty for default. +; Paths follow the pattern "/Game/Blueprints/Vehicles/Mustang/Mustang.Mustang_C" +PlayerVehicle= +; Number of non-player vehicles to be spawned into the level. +NumberOfVehicles=15 +; Number of non-player pedestrians to be spawned into the level. +NumberOfPedestrians=30 +; Index of the weather/lighting presets to use. If negative, the default presets +; of the map will be used. +WeatherId=1 +; Seeds for the pseudo-random number generators. 
+SeedVehicles=123456789 +SeedPedestrians=123456789 + +[CARLA/SceneCapture] +; Names of the cameras to be attached to the player, comma-separated, each of +; them should be defined in its own subsection. E.g., Uncomment next line to add +; a camera called MyCamera to the vehicle + +Cameras=CameraRGB + +; Now, every camera we added needs to be defined it in its own subsection. +[CARLA/SceneCapture/CameraRGB] +; Post-processing effect to be applied. Valid values: +; * None No effects applied. +; * SceneFinal Post-processing present at scene (bloom, fog, etc). +; * Depth Depth map ground-truth only. +; * SemanticSegmentation Semantic segmentation ground-truth only. +PostProcessing=SceneFinal +; Size of the captured image in pixels. +ImageSizeX=360 +ImageSizeY=256 +; Camera (horizontal) field of view in degrees. +CameraFOV=90 +; Position of the camera relative to the car in centimeters. +CameraPositionX=200 +CameraPositionY=0 +CameraPositionZ=140 +; Rotation of the camera relative to the car in degrees. +CameraRotationPitch=0 +CameraRotationRoll=0 +CameraRotationYaw=0 diff --git a/environments/__init__.py b/environments/__init__.py index 443cdde97..1dd8d1dc8 100644 --- a/environments/__init__.py +++ b/environments/__init__.py @@ -15,13 +15,16 @@ # from logger import * -from utils import Enum +from utils import Enum, get_open_port from environments.gym_environment_wrapper import * from environments.doom_environment_wrapper import * +from environments.carla_environment_wrapper import * + class EnvTypes(Enum): Doom = "DoomEnvironmentWrapper" Gym = "GymEnvironmentWrapper" + Carla = "CarlaEnvironmentWrapper" def create_environment(tuning_parameters): diff --git a/environments/carla_environment_wrapper.py b/environments/carla_environment_wrapper.py new file mode 100644 index 000000000..ee9c3ed10 --- /dev/null +++ b/environments/carla_environment_wrapper.py @@ -0,0 +1,230 @@ +import sys +from os import path, environ + +try: + sys.path.append(path.join(environ.get('CARLA_ROOT'), 'PythonClient')) + from carla.client import CarlaClient + from carla.settings import CarlaSettings + from carla.tcp import TCPConnectionError + from carla.sensor import Camera + from carla.client import VehicleControl +except ImportError: + from logger import failed_imports + failed_imports.append("CARLA") + +import numpy as np +import time +import logging +import subprocess +import signal +from environments.environment_wrapper import EnvironmentWrapper +from utils import * +from logger import screen, logger +from PIL import Image + + +# enum of the available levels and their path +class CarlaLevel(Enum): + TOWN1 = "/Game/Maps/Town01" + TOWN2 = "/Game/Maps/Town02" + +key_map = { + 'BRAKE': (274,), # down arrow + 'GAS': (273,), # up arrow + 'TURN_LEFT': (276,), # left arrow + 'TURN_RIGHT': (275,), # right arrow + 'GAS_AND_TURN_LEFT': (273, 276), + 'GAS_AND_TURN_RIGHT': (273, 275), + 'BRAKE_AND_TURN_LEFT': (274, 276), + 'BRAKE_AND_TURN_RIGHT': (274, 275), +} + + +class CarlaEnvironmentWrapper(EnvironmentWrapper): + def __init__(self, tuning_parameters): + EnvironmentWrapper.__init__(self, tuning_parameters) + + self.tp = tuning_parameters + + # server configuration + self.server_height = self.tp.env.server_height + self.server_width = self.tp.env.server_width + self.port = get_open_port() + self.host = 'localhost' + self.map = CarlaLevel().get(self.tp.env.level) + + # client configuration + self.verbose = self.tp.env.verbose + self.depth = self.tp.env.depth + self.stereo = self.tp.env.stereo + self.semantic_segmentation = 
self.tp.env.semantic_segmentation + self.height = self.server_height * (1 + int(self.depth) + int(self.semantic_segmentation)) + self.width = self.server_width * (1 + int(self.stereo)) + self.size = (self.width, self.height) + + self.config = self.tp.env.config + if self.config: + # load settings from file + with open(self.config, 'r') as fp: + self.settings = fp.read() + else: + # hard coded settings + self.settings = CarlaSettings() + self.settings.set( + SynchronousMode=True, + SendNonPlayerAgentsInfo=False, + NumberOfVehicles=15, + NumberOfPedestrians=30, + WeatherId=1) + self.settings.randomize_seeds() + + # add cameras + camera = Camera('CameraRGB') + camera.set_image_size(self.width, self.height) + camera.set_position(200, 0, 140) + camera.set_rotation(0, 0, 0) + self.settings.add_sensor(camera) + + # open the server + self.server = self._open_server() + + logging.disable(40) + + # open the client + self.game = CarlaClient(self.host, self.port, timeout=99999999) + self.game.connect() + scene = self.game.load_settings(self.settings) + + # get available start positions + positions = scene.player_start_spots + self.num_pos = len(positions) + self.iterator_start_positions = 0 + + # action space + self.discrete_controls = False + self.action_space_size = 2 + self.action_space_high = [1, 1] + self.action_space_low = [-1, -1] + self.action_space_abs_range = np.maximum(np.abs(self.action_space_low), np.abs(self.action_space_high)) + self.steering_strength = 0.5 + self.gas_strength = 1.0 + self.brake_strength = 0.5 + self.actions = {0: [0., 0.], + 1: [0., -self.steering_strength], + 2: [0., self.steering_strength], + 3: [self.gas_strength, 0.], + 4: [-self.brake_strength, 0], + 5: [self.gas_strength, -self.steering_strength], + 6: [self.gas_strength, self.steering_strength], + 7: [self.brake_strength, -self.steering_strength], + 8: [self.brake_strength, self.steering_strength]} + self.actions_description = ['NO-OP', 'TURN_LEFT', 'TURN_RIGHT', 'GAS', 'BRAKE', + 'GAS_AND_TURN_LEFT', 'GAS_AND_TURN_RIGHT', + 'BRAKE_AND_TURN_LEFT', 'BRAKE_AND_TURN_RIGHT'] + for idx, action in enumerate(self.actions_description): + for key in key_map.keys(): + if action == key: + self.key_to_action[key_map[key]] = idx + self.num_speedup_steps = 30 + + # measurements + self.measurements_size = (1,) + self.autopilot = None + + # env initialization + self.reset(True) + + # render + if self.is_rendered: + image = self.get_rendered_image() + self.renderer.create_screen(image.shape[1], image.shape[0]) + + def _open_server(self): + log_path = path.join(logger.experiments_path, "CARLA_LOG_{}.txt".format(self.port)) + with open(log_path, "wb") as out: + cmd = [path.join(environ.get('CARLA_ROOT'), 'CarlaUE4.sh'), self.map, + "-benchmark", "-carla-server", "-fps=10", "-world-port={}".format(self.port), + "-windowed -ResX={} -ResY={}".format(self.server_width, self.server_height), + "-carla-no-hud"] + if self.config: + cmd.append("-carla-settings={}".format(self.config)) + p = subprocess.Popen(cmd, stdout=out, stderr=out) + + return p + + def _close_server(self): + os.killpg(os.getpgid(self.server.pid), signal.SIGKILL) + + def _update_state(self): + # get measurements and observations + measurements = [] + while type(measurements) == list: + measurements, sensor_data = self.game.read_data() + self.observation = sensor_data['CameraRGB'].data + + self.location = (measurements.player_measurements.transform.location.x, + measurements.player_measurements.transform.location.y, + 
measurements.player_measurements.transform.location.z) + + is_collision = measurements.player_measurements.collision_vehicles != 0 \ + or measurements.player_measurements.collision_pedestrians != 0 \ + or measurements.player_measurements.collision_other != 0 + + speed_reward = measurements.player_measurements.forward_speed - 1 + if speed_reward > 30.: + speed_reward = 30. + self.reward = speed_reward \ + - (measurements.player_measurements.intersection_otherlane * 5) \ + - (measurements.player_measurements.intersection_offroad * 5) \ + - is_collision * 100 \ + - np.abs(self.control.steer) * 10 + + # update measurements + self.measurements = [measurements.player_measurements.forward_speed] + self.autopilot = measurements.player_measurements.autopilot_control + + # action_p = ['%.2f' % member for member in [self.control.throttle, self.control.steer]] + # screen.success('REWARD: %.2f, ACTIONS: %s' % (self.reward, action_p)) + + if (measurements.game_timestamp >= self.tp.env.episode_max_time) or is_collision: + # screen.success('EPISODE IS DONE. GameTime: {}, Collision: {}'.format(str(measurements.game_timestamp), + # str(is_collision))) + self.done = True + + def _take_action(self, action_idx): + if type(action_idx) == int: + action = self.actions[action_idx] + else: + action = action_idx + self.last_action_idx = action + + self.control = VehicleControl() + self.control.throttle = np.clip(action[0], 0, 1) + self.control.steer = np.clip(action[1], -1, 1) + self.control.brake = np.abs(np.clip(action[0], -1, 0)) + if not self.tp.env.allow_braking: + self.control.brake = 0 + self.control.hand_brake = False + self.control.reverse = False + + self.game.send_control(self.control) + + def _restart_environment_episode(self, force_environment_reset=False): + self.iterator_start_positions += 1 + if self.iterator_start_positions >= self.num_pos: + self.iterator_start_positions = 0 + + try: + self.game.start_episode(self.iterator_start_positions) + except: + self.game.connect() + self.game.start_episode(self.iterator_start_positions) + + # start the game with some initial speed + observation = None + for i in range(self.num_speedup_steps): + observation = self.step([1.0, 0])['observation'] + self.observation = observation + + return observation + diff --git a/environments/doom_environment_wrapper.py b/environments/doom_environment_wrapper.py index ecdc9e001..997a067c4 100644 --- a/environments/doom_environment_wrapper.py +++ b/environments/doom_environment_wrapper.py @@ -25,6 +25,7 @@ from environments.environment_wrapper import EnvironmentWrapper from os import path, environ from utils import * +from logger import * # enum of the available levels and their path @@ -39,6 +40,43 @@ class DoomLevel(Enum): DEFEND_THE_LINE = "defend_the_line.cfg" DEADLY_CORRIDOR = "deadly_corridor.cfg" +key_map = { + 'NO-OP': 96, # ` + 'ATTACK': 13, # enter + 'CROUCH': 306, # ctrl + 'DROP_SELECTED_ITEM': ord("t"), + 'DROP_SELECTED_WEAPON': ord("t"), + 'JUMP': 32, # spacebar + 'LAND': ord("l"), + 'LOOK_DOWN': 274, # down arrow + 'LOOK_UP': 273, # up arrow + 'MOVE_BACKWARD': ord("s"), + 'MOVE_DOWN': ord("s"), + 'MOVE_FORWARD': ord("w"), + 'MOVE_LEFT': 276, + 'MOVE_RIGHT': 275, + 'MOVE_UP': ord("w"), + 'RELOAD': ord("r"), + 'SELECT_NEXT_WEAPON': ord("q"), + 'SELECT_PREV_WEAPON': ord("e"), + 'SELECT_WEAPON0': ord("0"), + 'SELECT_WEAPON1': ord("1"), + 'SELECT_WEAPON2': ord("2"), + 'SELECT_WEAPON3': ord("3"), + 'SELECT_WEAPON4': ord("4"), + 'SELECT_WEAPON5': ord("5"), + 'SELECT_WEAPON6': ord("6"), + 'SELECT_WEAPON7': ord("7"), + 
'SELECT_WEAPON8': ord("8"), + 'SELECT_WEAPON9': ord("9"), + 'SPEED': 304, # shift + 'STRAFE': 9, # tab + 'TURN180': ord("u"), + 'TURN_LEFT': ord("a"), # left arrow + 'TURN_RIGHT': ord("d"), # right arrow + 'USE': ord("f"), +} + class DoomEnvironmentWrapper(EnvironmentWrapper): def __init__(self, tuning_parameters): @@ -49,26 +87,42 @@ def __init__(self, tuning_parameters): self.scenarios_dir = path.join(environ.get('VIZDOOM_ROOT'), 'scenarios') self.game = vizdoom.DoomGame() self.game.load_config(path.join(self.scenarios_dir, self.level)) - self.game.set_window_visible(self.is_rendered) + self.game.set_window_visible(False) self.game.add_game_args("+vid_forcesurface 1") - if self.is_rendered: + + self.wait_for_explicit_human_action = True + if self.human_control: + self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_640X480) + self.renderer.create_screen(640, 480) + elif self.is_rendered: self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_320X240) + self.renderer.create_screen(320, 240) else: # lower resolution since we actually take only 76x60 and we don't need to render self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_160X120) + self.game.set_render_hud(False) self.game.set_render_crosshair(False) self.game.set_render_decals(False) self.game.set_render_particles(False) self.game.init() + # action space self.action_space_abs_range = 0 self.actions = {} - self.action_space_size = self.game.get_available_buttons_size() - for action_idx in range(self.action_space_size): - self.actions[action_idx] = [0] * self.action_space_size - self.actions[action_idx][action_idx] = 1 - self.actions_description = [str(action) for action in self.game.get_available_buttons()] + self.action_space_size = self.game.get_available_buttons_size() + 1 + self.action_vector_size = self.action_space_size - 1 + self.actions[0] = [0] * self.action_vector_size + for action_idx in range(self.action_vector_size): + self.actions[action_idx + 1] = [0] * self.action_vector_size + self.actions[action_idx + 1][action_idx] = 1 + self.actions_description = ['NO-OP'] + self.actions_description += [str(action).split(".")[1] for action in self.game.get_available_buttons()] + for idx, action in enumerate(self.actions_description): + if action in key_map.keys(): + self.key_to_action[(key_map[action],)] = idx + + # measurement self.measurements_size = self.game.get_state().game_variables.shape self.width = self.game.get_screen_width() @@ -77,27 +131,17 @@ def __init__(self, tuning_parameters): self.game.set_seed(self.tp.seed) self.reset() - def _update_observation_and_measurements(self): + def _update_state(self): # extract all data from the current state state = self.game.get_state() if state is not None and state.screen_buffer is not None: - self.observation = self._preprocess_observation(state.screen_buffer) + self.observation = state.screen_buffer self.measurements = state.game_variables + self.reward = self.game.get_last_reward() self.done = self.game.is_episode_finished() - def step(self, action_idx): - self.reward = 0 - for frame in range(self.tp.env.frame_skip): - self.reward += self.game.make_action(self._idx_to_action(action_idx)) - self._update_observation_and_measurements() - if self.done: - break - - return {'observation': self.observation, - 'reward': self.reward, - 'done': self.done, - 'action': action_idx, - 'measurements': self.measurements} + def _take_action(self, action_idx): + self.game.make_action(self._idx_to_action(action_idx), self.frame_skip) def _preprocess_observation(self, 
observation): if observation is None: @@ -108,3 +152,5 @@ def _preprocess_observation(self, observation): def _restart_environment_episode(self, force_environment_reset=False): self.game.new_episode() + + diff --git a/environments/environment_wrapper.py b/environments/environment_wrapper.py index ac73613f8..1491c5c59 100644 --- a/environments/environment_wrapper.py +++ b/environments/environment_wrapper.py @@ -17,6 +17,9 @@ import numpy as np from utils import * from configurations import Preset +from renderer import Renderer +import operator +import time class EnvironmentWrapper(object): @@ -31,13 +34,19 @@ def __init__(self, tuning_parameters): self.observation = [] self.reward = 0 self.done = False + self.default_action = 0 self.last_action_idx = 0 + self.episode_idx = 0 + self.last_episode_time = time.time() self.measurements = [] + self.info = [] self.action_space_low = 0 self.action_space_high = 0 self.action_space_abs_range = 0 + self.actions_description = {} self.discrete_controls = True self.action_space_size = 0 + self.key_to_action = {} self.width = 1 self.height = 1 self.is_state_type_image = True @@ -50,17 +59,11 @@ def __init__(self, tuning_parameters): self.is_rendered = self.tp.visualization.render self.seed = self.tp.seed self.frame_skip = self.tp.env.frame_skip - - def _update_observation_and_measurements(self): - # extract all the available measurments (ovservation, depthmap, lives, ammo etc.) - pass - - def _restart_environment_episode(self, force_environment_reset=False): - """ - :param force_environment_reset: Force the environment to reset even if the episode is not done yet. - :return: - """ - pass + self.human_control = self.tp.env.human_control + self.wait_for_explicit_human_action = False + self.is_rendered = self.is_rendered or self.human_control + self.game_is_open = True + self.renderer = Renderer() def _idx_to_action(self, action_idx): """ @@ -71,13 +74,43 @@ def _idx_to_action(self, action_idx): """ return self.actions[action_idx] - def _preprocess_observation(self, observation): + def _action_to_idx(self, action): """ - Do initial observation preprocessing such as cropping, rgb2gray, rescale etc. - :param observation: a raw observation from the environment - :return: the preprocessed observation + Convert an environment action to one of the available actions of the wrapper. 
+ For example, if the available actions are 4,5,6 then this function will map 4->0, 5->1, 6->2 + :param action: the environment action + :return: an action index between 0 and self.action_space_size - 1, or -1 if the action does not exist """ - pass + for key, val in self.actions.items(): + if val == action: + return key + return -1 + + def get_action_from_user(self): + """ + Get an action from the user keyboard + :return: action index + """ + if self.wait_for_explicit_human_action: + while len(self.renderer.pressed_keys) == 0: + self.renderer.get_events() + + if self.key_to_action == {}: + # the keys are the numbers on the keyboard corresponding to the action index + if len(self.renderer.pressed_keys) > 0: + action_idx = self.renderer.pressed_keys[0] - ord("1") + if 0 <= action_idx < self.action_space_size: + return action_idx + else: + # the keys are mapped through the environment to more intuitive keyboard keys + # key = tuple(self.renderer.pressed_keys) + # for key in self.renderer.pressed_keys: + for env_keys in self.key_to_action.keys(): + if set(env_keys) == set(self.renderer.pressed_keys): + return self.key_to_action[env_keys] + + # return the default action 0 so that the environment will continue running + return self.default_action def step(self, action_idx): """ @@ -85,13 +118,29 @@ def step(self, action_idx): :param action_idx: the action to perform on the environment :return: A dictionary containing the observation, reward, done flag, action and measurements """ - pass + self.last_action_idx = action_idx + + self._take_action(action_idx) + + self._update_state() + + if self.is_rendered: + self.render() + + self.observation = self._preprocess_observation(self.observation) + + return {'observation': self.observation, + 'reward': self.reward, + 'done': self.done, + 'action': self.last_action_idx, + 'measurements': self.measurements, + 'info': self.info} def render(self): """ Call the environment function for rendering to the screen """ - pass + self.renderer.render_image(self.get_rendered_image()) def reset(self, force_environment_reset=False): """ @@ -100,15 +149,25 @@ def reset(self, force_environment_reset=False): :return: A dictionary containing the observation, reward, done flag, action and measurements """ self._restart_environment_episode(force_environment_reset) + self.last_episode_time = time.time() self.done = False + self.episode_idx += 1 self.reward = 0.0 self.last_action_idx = 0 - self._update_observation_and_measurements() + self._update_state() + + # render before the preprocessing of the observation, so that the image will be in its original quality + if self.is_rendered: + self.render() + + self.observation = self._preprocess_observation(self.observation) + return {'observation': self.observation, 'reward': self.reward, 'done': self.done, 'action': self.last_action_idx, - 'measurements': self.measurements} + 'measurements': self.measurements, + 'info': self.info} def get_random_action(self): """ @@ -129,10 +188,62 @@ def change_phase(self, phase): """ self.phase = phase + def get_available_keys(self): + """ + Return a list of tuples mapping between action names and the keyboard key that triggers them + :return: a list of tuples mapping between action names and the keyboard key that triggers them + """ + available_keys = [] + if self.key_to_action != {}: + for key, idx in sorted(self.key_to_action.items(), key=operator.itemgetter(1)): + if key != (): + key_names = [self.renderer.get_key_names([k])[0] for k in key] + 
available_keys.append((self.actions_description[idx], ' + '.join(key_names))) + elif self.discrete_controls: + for action in range(self.action_space_size): + available_keys.append(("Action {}".format(action + 1), action + 1)) + return available_keys + + # The following functions define the interaction with the environment. + # Any new environment that inherits the EnvironmentWrapper class should use these signatures. + # Some of these functions are optional - please read their description for more details. + + def _take_action(self, action_idx): + """ + An environment dependent function that sends an action to the simulator. + :param action_idx: the action to perform on the environment + :return: None + """ + pass + + def _preprocess_observation(self, observation): + """ + Do initial observation preprocessing such as cropping, rgb2gray, rescale etc. + Implementing this function is optional. + :param observation: a raw observation from the environment + :return: the preprocessed observation + """ + return observation + + def _update_state(self): + """ + Updates the state from the environment. + Should update self.observation, self.reward, self.done, self.measurements and self.info + :return: None + """ + pass + + def _restart_environment_episode(self, force_environment_reset=False): + """ + :param force_environment_reset: Force the environment to reset even if the episode is not done yet. + :return: + """ + pass + def get_rendered_image(self): """ Return a numpy array containing the image that will be rendered to the screen. This can be different from the observation. For example, mujoco's observation is a measurements vector. :return: numpy array containing the image that will be rendered to the screen """ - return self.observation + return self.observation \ No newline at end of file diff --git a/environments/gym_environment_wrapper.py b/environments/gym_environment_wrapper.py index 0721fc94e..e6320cf98 100644 --- a/environments/gym_environment_wrapper.py +++ b/environments/gym_environment_wrapper.py @@ -15,8 +15,10 @@ # import sys +from logger import * import gym import numpy as np +import time try: import roboschool from OpenGL import GL @@ -40,8 +42,6 @@ from utils import force_list, RunPhase from environments.environment_wrapper import EnvironmentWrapper -i = 0 - class GymEnvironmentWrapper(EnvironmentWrapper): def __init__(self, tuning_parameters): @@ -53,29 +53,30 @@ def __init__(self, tuning_parameters): self.env.seed(self.seed) # self.env_spec = gym.spec(self.env_id) + self.env.frameskip = self.frame_skip self.discrete_controls = type(self.env.action_space) != gym.spaces.box.Box - # pybullet requires rendering before resetting the environment, but other gym environments (Pendulum) will crash - try: - if self.is_rendered: - self.render() - except: - pass - - o = self.reset(True)['observation'] + self.observation = self.reset(True)['observation'] # render if self.is_rendered: - self.render() + image = self.get_rendered_image() + scale = 1 + if self.human_control: + scale = 2 + self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale) - self.is_state_type_image = len(o.shape) > 1 + self.is_state_type_image = len(self.observation.shape) > 1 if self.is_state_type_image: - self.width = o.shape[1] - self.height = o.shape[0] + self.width = self.observation.shape[1] + self.height = self.observation.shape[0] else: - self.width = o.shape[0] + self.width = self.observation.shape[0] + # action space self.actions_description = {} + if hasattr(self.env.unwrapped, 
'get_action_meanings'): + self.actions_description = self.env.unwrapped.get_action_meanings() if self.discrete_controls: self.action_space_size = self.env.action_space.n self.action_space_abs_range = 0 @@ -85,34 +86,31 @@ def __init__(self, tuning_parameters): self.action_space_low = self.env.action_space.low self.action_space_abs_range = np.maximum(np.abs(self.action_space_low), np.abs(self.action_space_high)) self.actions = {i: i for i in range(self.action_space_size)} + self.key_to_action = {} + if hasattr(self.env.unwrapped, 'get_keys_to_action'): + self.key_to_action = self.env.unwrapped.get_keys_to_action() + + # measurements self.timestep_limit = self.env.spec.timestep_limit - self.current_ale_lives = 0 self.measurements_size = len(self.step(0)['info'].keys()) - # env intialization - self.observation = o - self.reward = 0 - self.done = False - self.last_action = self.actions[0] - - def render(self): - self.env.render() - - def step(self, action_idx): + def _update_state(self): + if hasattr(self.env.env, 'ale'): + if self.phase == RunPhase.TRAIN and hasattr(self, 'current_ale_lives'): + # signal termination for life loss + if self.current_ale_lives != self.env.env.ale.lives(): + self.done = True + self.current_ale_lives = self.env.env.ale.lives() + def _take_action(self, action_idx): if action_idx is None: action_idx = self.last_action_idx - self.last_action_idx = action_idx - if self.discrete_controls: action = self.actions[action_idx] else: action = action_idx - if hasattr(self.env.env, 'ale'): - prev_ale_lives = self.env.env.ale.lives() - # pendulum-v0 for example expects a list if not self.discrete_controls: # catching cases where the action for continuous control is a number instead of a list the @@ -128,42 +126,26 @@ def step(self, action_idx): self.observation, self.reward, self.done, self.info = self.env.step(action) - if hasattr(self.env.env, 'ale') and self.phase == RunPhase.TRAIN: - # signal termination for breakout life loss - if prev_ale_lives != self.env.env.ale.lives(): - self.done = True - + def _preprocess_observation(self, observation): if any(env in self.env_id for env in ["Breakout", "Pong"]): # crop image - self.observation = self.observation[34:195, :, :] - - if self.is_rendered: - self.render() - - return {'observation': self.observation, - 'reward': self.reward, - 'done': self.done, - 'action': self.last_action_idx, - 'info': self.info} + observation = observation[34:195, :, :] + return observation def _restart_environment_episode(self, force_environment_reset=False): # prevent reset of environment if there are ale lives left - if "Breakout" in self.env_id and self.env.env.ale.lives() > 0 and not force_environment_reset: + if (hasattr(self.env.env, 'ale') and self.env.env.ale.lives() > 0) \ + and not force_environment_reset and not self.env._past_limit(): return self.observation if self.seed: self.env.seed(self.seed) - observation = self.env.reset() - while observation is None: - observation = self.step(0)['observation'] - if "Breakout" in self.env_id: - # crop image - observation = observation[34:195, :, :] + self.observation = self.env.reset() + while self.observation is None: + self.step(0) - self.observation = observation - - return observation + return self.observation def get_rendered_image(self): return self.env.render(mode='rgb_array') diff --git a/img/algorithms.png b/img/algorithms.png index 2dc14077b..f83c1e69f 100644 Binary files a/img/algorithms.png and b/img/algorithms.png differ diff --git a/img/ant.gif b/img/ant.gif deleted file mode 100644 
index 919d786eb..000000000 Binary files a/img/ant.gif and /dev/null differ diff --git a/img/carla.gif b/img/carla.gif new file mode 100644 index 000000000..af61a9016 Binary files /dev/null and b/img/carla.gif differ diff --git a/img/doom.gif b/img/doom.gif deleted file mode 100644 index 53853468e..000000000 Binary files a/img/doom.gif and /dev/null differ diff --git a/img/doom_deathmatch.gif b/img/doom_deathmatch.gif new file mode 100644 index 000000000..e949b088f Binary files /dev/null and b/img/doom_deathmatch.gif differ diff --git a/img/minitaur.gif b/img/minitaur.gif deleted file mode 100644 index bff226494..000000000 Binary files a/img/minitaur.gif and /dev/null differ diff --git a/img/montezuma.gif b/img/montezuma.gif new file mode 100644 index 000000000..999798a28 Binary files /dev/null and b/img/montezuma.gif differ diff --git a/install.sh b/install.sh index 2c3606b9f..4f2a77523 100755 --- a/install.sh +++ b/install.sh @@ -192,10 +192,14 @@ if [ ${INSTALL_NEON} -eq 1 ]; then # Neon sudo -E apt-get install libhdf5-dev libyaml-dev pkg-config clang virtualenv libcurl4-openssl-dev libopencv-dev libsox-dev -y - git clone https://github.com/NervanaSystems/neon.git - cd neon && make sysinstall -j - cd .. + pip3 install nervananeon +fi + +if ! [ -x "$(command -v nvidia-smi)" ]; then + # Intel Optimized TensorFlow + pip3 install https://anaconda.org/intel/tensorflow/1.3.0/download/tensorflow-1.3.0-cp35-cp35m-linux_x86_64.whl +else + # GPU supported TensorFlow + pip3 install tensorflow-gpu fi -# Intel Optimized TensorFlow -pip3 install https://anaconda.org/intel/tensorflow/1.3.0/download/tensorflow-1.3.0-cp35-cp35m-linux_x86_64.whl diff --git a/logger.py b/logger.py index 11458a845..caa0265a1 100644 --- a/logger.py +++ b/logger.py @@ -18,6 +18,7 @@ import os from pprint import pprint import threading +from subprocess import Popen, PIPE import time import datetime from six.moves import input @@ -61,7 +62,7 @@ def separator(self): print("") def log(self, data): - print(self.name + ": " + data) + print(data) def log_dict(self, dict, prefix=""): str = "{}{}{} - ".format(Colors.PURPLE, prefix, Colors.END) @@ -78,8 +79,10 @@ def success(self, text): def warning(self, text): print("{}{}{}".format(Colors.YELLOW, text, Colors.END)) - def error(self, text): + def error(self, text, crash=True): print("{}{}{}".format(Colors.RED, text, Colors.END)) + if crash: + exit(1) def ask_input(self, title): return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END)) diff --git a/memories/episodic_experience_replay.py b/memories/episodic_experience_replay.py index b5d92f540..1d7e61f9d 100644 --- a/memories/episodic_experience_replay.py +++ b/memories/episodic_experience_replay.py @@ -74,7 +74,9 @@ def sample_last_n_episodes(self, n): def sample(self, size): assert self.num_transitions_in_complete_episodes() > size, \ - 'There are not enough transitions in the replay buffer' + 'There are not enough transitions in the replay buffer. ' \ + 'Available transitions: {}. 
Requested transitions: {}.'\ + .format(self.num_transitions_in_complete_episodes(), size) batch = [] transitions_idx = np.random.randint(self.num_transitions_in_complete_episodes(), size=size) for i in transitions_idx: diff --git a/memories/memory.py b/memories/memory.py index 5d32c062e..4c479e321 100644 --- a/memories/memory.py +++ b/memories/memory.py @@ -73,6 +73,7 @@ def update_returns(self, discount, is_bootstrapped=False, n_step_return=-1): if n_step_return == -1 or n_step_return > self.length(): n_step_return = self.length() rewards = np.array([t.reward for t in self.transitions]) + rewards = rewards.astype('float') total_return = rewards.copy() current_discount = discount for i in range(1, n_step_return): @@ -123,12 +124,30 @@ def to_batch(self): class Transition(object): - def __init__(self, state, action, reward, next_state, game_over): + def __init__(self, state, action, reward=0, next_state=None, game_over=False): + """ + A transition is a tuple containing the information of a single step of interaction + between the agent and the environment. The most basic version should contain the following values: + (current state, action, reward, next state, game over) + For imitation learning algorithms, if the reward, next state or game over is not known, + it is sufficient to store the current state and action taken by the expert. + + :param state: The current state. Assumed to be a dictionary where the observation + is located at state['observation'] + :param action: The current action that was taken + :param reward: The reward received from the environment + :param next_state: The next state of the environment after applying the action. + The next state should be similar to the state in its structure. + :param game_over: A boolean which should be True if the episode terminated after + the execution of the action. 
+ """ self.state = copy.deepcopy(state) self.state['observation'] = np.array(self.state['observation'], copy=False) self.action = action self.reward = reward self.total_return = None + if not next_state: + next_state = state self.next_state = copy.deepcopy(next_state) self.next_state['observation'] = np.array(self.next_state['observation'], copy=False) self.game_over = game_over diff --git a/presets.py b/presets.py index 3290f9750..dc3b5ecc5 100644 --- a/presets.py +++ b/presets.py @@ -38,6 +38,15 @@ def json_to_preset(json_path): if run_dict['exploration_policy_type'] is not None: tuning_parameters.exploration = eval(run_dict['exploration_policy_type'])() + # human control + if run_dict['play']: + tuning_parameters.agent.type = 'HumanAgent' + tuning_parameters.env.human_control = True + tuning_parameters.num_heatup_steps = 0 + + if run_dict['level']: + tuning_parameters.env.level = run_dict['level'] + if run_dict['custom_parameter'] is not None: unstripped_key_value_pairs = [pair.split('=') for pair in run_dict['custom_parameter'].split(';')] stripped_key_value_pairs = [tuple([pair[0].strip(), ast.literal_eval(pair[1].strip())]) for pair in @@ -331,7 +340,7 @@ def __init__(self): self.agent.num_steps_between_gradient_updates = 5 self.test = True - self.test_max_step_threshold = 1000 + self.test_max_step_threshold = 2000 self.test_min_return_threshold = 150 self.test_num_workers = 8 @@ -926,7 +935,7 @@ def __init__(self): self.agent.middleware_type = MiddlewareTypes.FC self.test = True - self.test_max_step_threshold = 200 + self.test_max_step_threshold = 1000 self.test_min_return_threshold = 150 self.test_num_workers = 8 @@ -1182,3 +1191,93 @@ def __init__(self): self.agent.beta_entropy = 0.05 self.clip_gradients = 40.0 self.agent.middleware_type = MiddlewareTypes.FC + + +class Carla_A3C(Preset): + def __init__(self): + Preset.__init__(self, ActorCritic, Carla, EntropyExploration) + self.agent.embedder_complexity = EmbedderComplexity.Deep + self.agent.policy_gradient_rescaler = 'GAE' + self.learning_rate = 0.0001 + self.num_heatup_steps = 0 + # self.env.reward_scaling = 1.0e9 + self.agent.discount = 0.99 + self.agent.apply_gradients_every_x_episodes = 1 + self.agent.num_steps_between_gradient_updates = 30 + self.agent.gae_lambda = 1 + self.agent.beta_entropy = 0.01 + self.clip_gradients = 40 + self.agent.middleware_type = MiddlewareTypes.FC + + +class Carla_DDPG(Preset): + def __init__(self): + Preset.__init__(self, DDPG, Carla, OUExploration) + self.agent.embedder_complexity = EmbedderComplexity.Deep + self.learning_rate = 0.0001 + self.num_heatup_steps = 1000 + self.agent.num_consecutive_training_steps = 5 + + +class Carla_BC(Preset): + def __init__(self): + Preset.__init__(self, BC, Carla, ExplorationParameters) + self.agent.embedder_complexity = EmbedderComplexity.Deep + self.agent.load_memory_from_file_path = 'datasets/carla_town1.p' + self.learning_rate = 0.0005 + self.num_heatup_steps = 0 + self.evaluation_episodes = 5 + self.batch_size = 120 + self.evaluate_every_x_training_iterations = 5000 + + +class Doom_Basic_BC(Preset): + def __init__(self): + Preset.__init__(self, BC, Doom, ExplorationParameters) + self.env.level = 'basic' + self.agent.load_memory_from_file_path = 'datasets/doom_basic.p' + self.learning_rate = 0.0005 + self.num_heatup_steps = 0 + self.evaluation_episodes = 5 + self.batch_size = 120 + self.evaluate_every_x_training_iterations = 100 + self.num_training_iterations = 2000 + + +class Doom_Defend_BC(Preset): + def __init__(self): + Preset.__init__(self, BC, Doom, 
ExplorationParameters) + self.env.level = 'defend' + self.agent.load_memory_from_file_path = 'datasets/doom_defend.p' + self.learning_rate = 0.0005 + self.num_heatup_steps = 0 + self.evaluation_episodes = 5 + self.batch_size = 120 + self.evaluate_every_x_training_iterations = 100 + + +class Doom_Deathmatch_BC(Preset): + def __init__(self): + Preset.__init__(self, BC, Doom, ExplorationParameters) + self.env.level = 'deathmatch' + self.agent.load_memory_from_file_path = 'datasets/doom_deathmatch.p' + self.learning_rate = 0.0005 + self.num_heatup_steps = 0 + self.evaluation_episodes = 5 + self.batch_size = 120 + self.evaluate_every_x_training_iterations = 100 + + +class MontezumaRevenge_BC(Preset): + def __init__(self): + Preset.__init__(self, BC, Atari, ExplorationParameters) + self.env.level = 'MontezumaRevenge-v0' + self.agent.load_memory_from_file_path = 'datasets/montezuma_revenge.p' + self.learning_rate = 0.0005 + self.num_heatup_steps = 0 + self.evaluation_episodes = 5 + self.batch_size = 120 + self.evaluate_every_x_training_iterations = 100 + self.exploration.evaluation_epsilon = 0.05 + self.exploration.evaluation_policy = 'EGreedy' + self.env.frame_skip = 1 diff --git a/renderer.py b/renderer.py new file mode 100644 index 000000000..cddc81017 --- /dev/null +++ b/renderer.py @@ -0,0 +1,85 @@ +import pygame +from pygame.locals import * +import numpy as np + + +class Renderer(object): + def __init__(self): + self.size = (1, 1) + self.screen = None + self.clock = pygame.time.Clock() + self.display = pygame.display + self.fps = 30 + self.pressed_keys = [] + self.is_open = False + + def create_screen(self, width, height): + """ + Creates a pygame window + :param width: the width of the window + :param height: the height of the window + :return: None + """ + self.size = (width, height) + self.screen = self.display.set_mode(self.size, HWSURFACE | DOUBLEBUF) + self.display.set_caption("Coach") + self.is_open = True + + def normalize_image(self, image): + """ + Normalize image values to be between 0 and 255 + :param image: 2D/3D array containing an image with arbitrary values + :return: the input image with values rescaled to 0-255 + """ + image_min, image_max = image.min(), image.max() + return 255.0 * (image - image_min) / (image_max - image_min) + + def render_image(self, image): + """ + Render the given image to the pygame window + :param image: a grayscale or color image in an arbitrary size. 
assumes that the channels are the last axis + :return: None + """ + if self.is_open: + if len(image.shape) == 3: + if image.shape[0] == 3 or image.shape[0] == 1: + image = np.transpose(image, (1, 2, 0)) + surface = pygame.surfarray.make_surface(image.swapaxes(0, 1)) + surface = pygame.transform.scale(surface, self.size) + self.screen.blit(surface, (0, 0)) + self.display.flip() + self.clock.tick() + self.get_events() + + def get_events(self): + """ + Get all the window events in the last tick and reponse accordingly + :return: None + """ + for event in pygame.event.get(): + if event.type == pygame.KEYDOWN: + self.pressed_keys.append(event.key) + # esc pressed + if event.key == pygame.K_ESCAPE: + self.close() + elif event.type == pygame.KEYUP: + if event.key in self.pressed_keys: + self.pressed_keys.remove(event.key) + elif event.type == pygame.QUIT: + self.close() + + def get_key_names(self, key_ids): + """ + Get the key name for each key index in the list + :param key_ids: a list of key id's + :return: a list of key names corresponding to the key id's + """ + return [pygame.key.name(key_id) for key_id in key_ids] + + def close(self): + """ + Close the pygame window + :return: None + """ + self.is_open = False + pygame.quit() diff --git a/requirements_coach.txt b/requirements_coach.txt index 799820bfb..a0fd644f8 100644 --- a/requirements_coach.txt +++ b/requirements_coach.txt @@ -3,6 +3,7 @@ Pillow==4.3.0 matplotlib==2.0.2 numpy==1.13.0 pandas==0.20.2 +pygame==1.9.3 PyOpenGL==3.1.0 scipy==0.19.0 scikit-image==0.13.0 diff --git a/run_test.py b/run_test.py new file mode 100644 index 000000000..eeb51a434 --- /dev/null +++ b/run_test.py @@ -0,0 +1,164 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# -*- coding: utf-8 -*- +import presets +import numpy as np +import pandas as pd +from os import path +import os +import glob +import shutil +import sys +import time +from logger import screen +from utils import list_all_classes_in_module, threaded_cmd_line_run, killed_processes +from subprocess import Popen +import signal +import argparse + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-p', '--preset', + help="(string) Name of a preset to run (as configured in presets.py)", + default=None, + type=str) + parser.add_argument('-itf', '--ignore_tensorflow', + help="(flag) Don't test TensorFlow presets.", + action='store_true') + parser.add_argument('-in', '--ignore_neon', + help="(flag) Don't test neon presets.", + action='store_true') + + args = parser.parse_args() + if args.preset is not None: + presets_lists = [args.preset] + else: + presets_lists = list_all_classes_in_module(presets) + win_size = 10 + fail_count = 0 + test_count = 0 + read_csv_tries = 70 + + # create a clean experiment directory + test_name = '__test' + test_path = os.path.join('./experiments', test_name) + if path.exists(test_path): + shutil.rmtree(test_path) + + for idx, preset_name in enumerate(presets_lists): + preset = eval('presets.{}()'.format(preset_name)) + if preset.test: + frameworks = [] + if preset.agent.tensorflow_support and not args.ignore_tensorflow: + frameworks.append('tensorflow') + if preset.agent.neon_support and not args.ignore_neon: + frameworks.append('neon') + + for framework in frameworks: + test_count += 1 + + # run the experiment in a separate thread + screen.log_title("Running test {} - {}".format(preset_name, framework)) + cmd = 'CUDA_VISIBLE_DEVICES='' python3 coach.py -p {} -f {} -e {} -n {} -cp "seed=0" &> test_log_{}_{}.txt '\ + .format(preset_name, framework, test_name, preset.test_num_workers, preset_name, framework) + p = Popen(cmd, shell=True, executable="/bin/bash", preexec_fn=os.setsid) + + # get the csv with the results + csv_path = None + csv_paths = [] + + if preset.test_num_workers > 1: + # we have an evaluator + reward_str = 'Evaluation Reward' + filename_pattern = 'evaluator*.csv' + else: + reward_str = 'Training Reward' + filename_pattern = 'worker*.csv' + + initialization_error = False + test_passed = False + + tries_counter = 0 + while not csv_paths: + csv_paths = glob.glob(path.join(test_path, '*', filename_pattern)) + if tries_counter > read_csv_tries: + break + tries_counter += 1 + time.sleep(1) + + if csv_paths: + csv_path = csv_paths[0] + + # verify results + csv = None + time.sleep(1) + averaged_rewards = [0] + + last_num_episodes = 0 + while csv is None or csv['Episode #'].values[-1] < preset.test_max_step_threshold: + try: + csv = pd.read_csv(csv_path) + except: + # sometimes the csv is being written at the same time we are + # trying to read it. 
no problem -> try again + continue + + if reward_str not in csv.keys(): + continue + + rewards = csv[reward_str].values + rewards = rewards[~np.isnan(rewards)] + + if len(rewards) >= win_size: + averaged_rewards = np.convolve(rewards, np.ones(win_size) / win_size, mode='valid') + else: + time.sleep(1) + continue + + # print progress + percentage = int((100*last_num_episodes)/preset.test_max_step_threshold) + sys.stdout.write("\rReward: ({}/{})".format(round(averaged_rewards[-1], 1), preset.test_min_return_threshold)) + sys.stdout.write(' Episode: ({}/{})'.format(last_num_episodes, preset.test_max_step_threshold)) + sys.stdout.write(' {}%|{}{}| '.format(percentage, '#'*int(percentage/10), ' '*(10-int(percentage/10)))) + sys.stdout.flush() + + if csv['Episode #'].shape[0] - last_num_episodes <= 0: + continue + + last_num_episodes = csv['Episode #'].values[-1] + + # check if reward is enough + if np.any(averaged_rewards > preset.test_min_return_threshold): + test_passed = True + break + time.sleep(1) + + # kill test and print result + os.killpg(os.getpgid(p.pid), signal.SIGTERM) + if test_passed: + screen.success("Passed successfully") + else: + screen.error("Failed due to a mismatch with the golden", crash=False) + fail_count += 1 + shutil.rmtree(test_path) + + screen.separator() + if fail_count == 0: + screen.success(" Summary: " + str(test_count) + "/" + str(test_count) + " tests passed successfully") + else: + screen.error(" Summary: " + str(test_count - fail_count) + "/" + str(test_count) + " tests passed successfully") diff --git a/utils.py b/utils.py index c66072440..db9799499 100644 --- a/utils.py +++ b/utils.py @@ -20,6 +20,7 @@ import numpy as np import threading from subprocess import call, Popen +import signal killed_processes = [] @@ -54,9 +55,9 @@ def to_string(self, enum): class RunPhase(Enum): - HEATUP = 0 - TRAIN = 1 - TEST = 2 + HEATUP = "Heatup" + TRAIN = "Training" + TEST = "Testing" def list_all_classes_in_module(module): @@ -292,3 +293,59 @@ def get_open_port(): s.close() return port + +class timeout: + def __init__(self, seconds=1, error_message='Timeout'): + self.seconds = seconds + self.error_message = error_message + + def _handle_timeout(self, signum, frame): + raise TimeoutError(self.error_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self._handle_timeout) + signal.alarm(self.seconds) + + def __exit__(self, type, value, traceback): + signal.alarm(0) + + +def switch_axes_order(observation, from_type='channels_first', to_type='channels_last'): + """ + transpose an observation axes from channels_first to channels_last or vice versa + :param observation: a numpy array + :param from_type: can be 'channels_first' or 'channels_last' + :param to_type: can be 'channels_first' or 'channels_last' + :return: a new observation with the requested axes order + """ + if from_type == to_type or len(observation.shape) == 1: + return observation + assert 2 <= len(observation.shape) <= 3, 'num axes of an observation must be 2 for a vector or 3 for an image' + assert type(observation) == np.ndarray, 'observation must be a numpy array' + if len(observation.shape) == 3: + if from_type == 'channels_first' and to_type == 'channels_last': + return np.transpose(observation, (1, 2, 0)) + elif from_type == 'channels_last' and to_type == 'channels_first': + return np.transpose(observation, (2, 0, 1)) + else: + return np.transpose(observation, (1, 0)) + + +def stack_observation(curr_stack, observation, stack_size): + """ + Adds a new observation to an existing stack of 
observations from previous time-steps. + :param curr_stack: The current observations stack. + :param observation: The new observation + :param stack_size: The required stack size + :return: The updated observation stack + """ + + if curr_stack == []: + # starting an episode + curr_stack = np.vstack(np.expand_dims([observation] * stack_size, 0)) + curr_stack = switch_axes_order(curr_stack, from_type='channels_first', to_type='channels_last') + else: + curr_stack = np.append(curr_stack, np.expand_dims(np.squeeze(observation), axis=-1), axis=-1) + curr_stack = np.delete(curr_stack, 0, -1) + + return curr_stack
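
*Usage sketch (not part of the patch):* the following is a minimal, hedged example of how the `stack_observation` helper added to `utils.py` above is expected to behave. It assumes Coach's repository root is on `PYTHONPATH` (so `utils` is importable) and the pinned `numpy` from `requirements_coach.txt`; the 84x84 grayscale frames and stack size of 4 are illustrative values only.

```python
# Minimal sketch (assumption: run from the Coach repo root so that `utils` is importable).
import numpy as np
from utils import stack_observation

stack_size = 4
frame = np.zeros((84, 84), dtype=np.float32)  # a single grayscale observation (illustrative size)

# First call: the helper builds the initial stack by repeating the frame,
# then converts it to channels-last, giving shape (84, 84, 4).
stack = stack_observation([], frame, stack_size)
print(stack.shape)  # (84, 84, 4)

# Later calls: the newest frame is appended on the last axis and the oldest is dropped,
# so the shape stays (84, 84, 4) while the contents roll forward in time.
next_frame = np.ones((84, 84), dtype=np.float32)
stack = stack_observation(stack, next_frame, stack_size)
print(stack.shape)           # still (84, 84, 4)
print(stack[..., -1].max())  # 1.0 -> the most recent frame sits in the last channel
```

Keeping the stack channels-last matches the default target layout of `switch_axes_order`, which appears to be the convention the rest of the wrapper code expects.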