diff --git a/ga3c/ProcessAgent.py b/ga3c/ProcessAgent.py index 55fe8fa..1b2412d 100644 --- a/ga3c/ProcessAgent.py +++ b/ga3c/ProcessAgent.py @@ -72,7 +72,11 @@ def predict(self, state): # put the state in the prediction q self.prediction_q.put((self.id, state)) # wait for the prediction to come back - p, v = self.wait_q.get() + try: + p, v = self.wait_q.get(10) + except: + return None, None + return p, v def select_action(self, prediction): @@ -90,13 +94,21 @@ def run_episode(self): time_count = 0 reward_sum = 0.0 - while not done: + while not done and self.exit_flag.value == 0: + # very first few frames if self.env.current_state is None: self.env.step(0) # 0 == NOOP continue prediction, value = self.predict(self.env.current_state) + if prediction is None and value is None: + if self.exit_flag.value !=0: + break + else: + print("Warning: couldn't get prediction. Giving up.") + continue + action = self.select_action(prediction) reward, done = self.env.step(action) reward_sum += reward diff --git a/ga3c/ProcessStats.py b/ga3c/ProcessStats.py index 937e0fb..ccb7f29 100644 --- a/ga3c/ProcessStats.py +++ b/ga3c/ProcessStats.py @@ -50,6 +50,7 @@ def __init__(self): self.predictor_count = Value('i', 0) self.agent_count = Value('i', 0) self.total_frame_count = 0 + self.exit_flag = Value('i', 0) def FPS(self): # average FPS from the beginning of the training (not current FPS) @@ -67,8 +68,12 @@ def run(self): self.start_time = time.time() first_time = datetime.now() - while True: - episode_time, reward, length = self.episode_log_q.get() + while self.exit_flag.value == 0: + try: + episode_time, reward, length = self.episode_log_q.get(timeout=0.1) + except: + continue + results_logger.write('%s, %d, %d\n' % (episode_time.strftime("%Y-%m-%d %H:%M:%S"), reward, length)) results_logger.flush() diff --git a/ga3c/Server.py b/ga3c/Server.py index 28d8a46..5036a5a 100644 --- a/ga3c/Server.py +++ b/ga3c/Server.py @@ -63,27 +63,36 @@ def add_agent(self): self.agents[-1].start() def remove_agent(self): - self.agents[-1].exit_flag.value = True - self.agents[-1].join() - self.agents.pop() + + for p in self.agents: + p.exit_flag.value = True + for p in self.agents: + p.join() + self.agents.pop() def add_predictor(self): self.predictors.append(ThreadPredictor(self, len(self.predictors))) self.predictors[-1].start() def remove_predictor(self): - self.predictors[-1].exit_flag = True - self.predictors[-1].join() - self.predictors.pop() + + for p in self.predictors: + p.exit_flag = True + for p in self.predictors: + p.join() + self.predictors.pop() def add_trainer(self): self.trainers.append(ThreadTrainer(self, len(self.trainers))) self.trainers[-1].start() def remove_trainer(self): - self.trainers[-1].exit_flag = True - self.trainers[-1].join() - self.trainers.pop() + + for p in self.trainers: + p.exit_flag = True + for p in self.trainers: + p.join() + self.trainers.pop() def train_model(self, x_, r_, a_, trainer_id): self.model.train(x_, r_, a_, trainer_id) @@ -122,11 +131,20 @@ def main(self): self.stats.should_save_model.value = 0 time.sleep(0.01) - + + print('Finished. Exiting subprocesses ...') + join_start=time.time() self.dynamic_adjustment.exit_flag = True + self.dynamic_adjustment.join() while self.agents: self.remove_agent() while self.predictors: self.remove_predictor() while self.trainers: self.remove_trainer() + self.stats.exit_flag.value = True + self.stats.join() + print('Exit. Joining takes %.2f s' % (time.time()-join_start)) + + + diff --git a/ga3c/ThreadPredictor.py b/ga3c/ThreadPredictor.py index 38c9ed1..0b4e1e2 100644 --- a/ga3c/ThreadPredictor.py +++ b/ga3c/ThreadPredictor.py @@ -47,13 +47,21 @@ def run(self): dtype=np.float32) while not self.exit_flag: - ids[0], states[0] = self.server.prediction_q.get() + try: + ids[0], states[0] = self.server.prediction_q.get(timeout=0.1) + except: + continue size = 1 while size < Config.PREDICTION_BATCH_SIZE and not self.server.prediction_q.empty(): - ids[size], states[size] = self.server.prediction_q.get() - size += 1 - + try: + ids[size], states[size] = self.server.prediction_q.get(timeout=0.1) + size += 1 + except: + if self.exit_flag: break + + if self.exit_flag: break + batch = states[:size] p, v = self.server.model.predict_p_and_v(batch) diff --git a/ga3c/ThreadTrainer.py b/ga3c/ThreadTrainer.py index 4e364ad..fa2182b 100644 --- a/ga3c/ThreadTrainer.py +++ b/ga3c/ThreadTrainer.py @@ -41,9 +41,14 @@ def __init__(self, server, id): def run(self): while not self.exit_flag: + batch_size = 0 while batch_size <= Config.TRAINING_MIN_BATCH_SIZE: - x_, r_, a_ = self.server.training_q.get() + try: + x_, r_, a_ = self.server.training_q.get(timeout=0.1) + except: + if self.exit_flag: break + continue if batch_size == 0: x__ = x_; r__ = r_; a__ = a_ else: