From 24aa2f2065392187c26e2612f3867f5b4ae2b1a4 Mon Sep 17 00:00:00 2001
From: Marcos Galvão
Date: Wed, 25 Nov 2020 19:50:17 -0300
Subject: [PATCH] Fix saved training examples

---
 training.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/training.py b/training.py
index dc6d561..4912ba1 100644
--- a/training.py
+++ b/training.py
@@ -118,23 +118,29 @@ def execute_episode(board_size, neural_network, degree_exploration, num_simulati

         policy = mcts.get_policy_action_probabilities(state, policy_temperature)

-        if board_view_type == BoardView.ONE_CHANNEL:
-            example = game.board(BoardView.ONE_CHANNEL), policy, game.current_player
-        else:
-            example = state, policy, game.current_player
-        examples.append(example)
-
         #e-greedy
         coin = random.random()
         if coin <= e_greedy:
             action = np.argwhere(policy == policy.max())[0]
         else:
             action = mcts.get_state_actions(state)[np.random.choice(len(mcts.get_state_actions(state)))]
+
+
+        action_choosed = np.zeros((board_size, board_size))
+        action_choosed[action[0]][action[1]] = 1
+
+        # save examples
+        if board_view_type == BoardView.ONE_CHANNEL:
+            example = game.board(BoardView.ONE_CHANNEL), action_choosed, game.current_player
+        else:
+            example = state, action_choosed, game.current_player
+        examples.append(example)

         game.play(*action)

         logging.info(game.board(BoardView.ONE_CHANNEL))

     winner, winner_points = game.get_winning_player()
+    logging.info(f'The winner obtained {winner_points} points.')

     return [(state, policy, 1 if winner == player else -1) for state, policy, player in examples]
@@ -197,8 +203,8 @@ def training(board_size, num_iterations, num_episodes, num_simulations, degree_e
     total_episodes_done = 0
     historic = []
+    training_examples = []

     for i in range(1, num_iterations + 1):
-        training_examples = []
         old_neural_network = neural_network.copy()

         logging.info(f'Iteration {i}/{num_iterations}: Starting iteration')
@@ -233,6 +239,7 @@ def training(board_size, num_iterations, num_episodes, num_simulations, degree_e

         logging.info(f'Iteration {i}/{num_iterations}: Training model with episodes examples')

+        random.shuffle(training_examples)
         history = neural_network.train(training_examples, verbose=training_verbose)
@@ -276,6 +283,8 @@ def training(board_size, num_iterations, num_episodes, num_simulations, degree_e
         else:
             neural_network = old_neural_network

+        # workaround: always save the current network's checkpoint
+        neural_network.save_checkpoint(checkpoint_filepath)

         if (i % random_agent_interval) == 0:
             color = [OthelloPlayer.BLACK, OthelloPlayer.WHITE]
@@ -309,6 +318,8 @@ def training(board_size, num_iterations, num_episodes, num_simulations, degree_e

     with open(f'historic-last-training-session-{board_size}.txt', 'w') as output:
         output.write(str(historic))
+    with open(f'examples-{board_size}.txt', 'w') as output:
+        output.write(str(training_examples))
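
Notes on the change (below the patch, not part of it):

The core of the patch is that execute_episode now stores a one-hot encoding of
the move actually played (action_choosed) in each example instead of the full
MCTS policy. A minimal sketch of that encoding, assuming the policy is an
(N, N) numpy grid of move probabilities; board_size, the random grid, and the
seed are illustrative stand-ins, not values from training.py:

    import numpy as np

    # Stand-in for mcts.get_policy_action_probabilities(): any (N, N) grid
    # of move probabilities illustrates the encoding.
    board_size = 8
    rng = np.random.default_rng(0)
    policy = rng.dirichlet(np.ones(board_size * board_size)).reshape(board_size, board_size)

    # Greedy move: coordinates of the highest-probability square, as in the patch.
    action = np.argwhere(policy == policy.max())[0]

    # One-hot training target: 1 at the chosen square, 0 elsewhere.
    action_choosed = np.zeros((board_size, board_size))
    action_choosed[action[0]][action[1]] = 1
    assert action_choosed.sum() == 1

The network is thus trained to reproduce the move that was played rather than
the full visit distribution: a sharper but noisier policy target.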
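
The unchanged return statement at the end of execute_episode then turns each
stored (state, target, player) tuple into a final example whose value reflects
the game outcome. A sketch with a hypothetical label_examples helper and dummy
data (the helper name and the string placeholders are illustrative only):

    # Mirrors the return of execute_episode: the eventual winner's positions
    # get value +1, the loser's get -1.
    def label_examples(examples, winner):
        return [(state, target, 1 if winner == player else -1)
                for state, target, player in examples]

    examples = [('s1', 't1', 'BLACK'), ('s2', 't2', 'WHITE')]
    print(label_examples(examples, 'BLACK'))  # [('s1', 't1', 1), ('s2', 't2', -1)]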
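
Moving training_examples = [] out of the iteration loop turns the list into a
buffer that accumulates episodes across all iterations, and the new
random.shuffle call mixes it before each training pass. A toy sketch of the
resulting behaviour; fake_episode is a stand-in for execute_episode:

    import random

    def fake_episode(iteration):
        # Stand-in for execute_episode(): three dummy (state, target, value) tuples.
        return [(f'state-{iteration}-{n}', f'target-{iteration}-{n}', 1) for n in range(3)]

    training_examples = []                 # outside the loop: nothing is discarded
    for i in range(1, 4):
        training_examples += fake_episode(i)
        random.shuffle(training_examples)  # decorrelate consecutive positions
        # neural_network.train(training_examples, ...) would run here
    print(len(training_examples))          # 9: examples from every iteration are kept

Note the buffer now grows without bound; AlphaZero-style pipelines commonly cap
it to the most recent iterations.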
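
The final hunk dumps the accumulated examples with str(), which is readable but
cannot be loaded back, and numpy abbreviates large arrays when converted to
strings. If the examples ever need to be reloaded, a pickle-based variant could
be used instead (a suggested alternative, not what the patch does; the .pkl
filename and the dummy example are illustrative):

    import pickle
    import numpy as np

    # Dummy example shaped like the real ones: (board, one-hot target, value).
    board_size = 8
    training_examples = [(np.zeros((board_size, board_size)),
                          np.zeros((board_size, board_size)), 1)]

    with open(f'examples-{board_size}.pkl', 'wb') as output:
        pickle.dump(training_examples, output)

    with open(f'examples-{board_size}.pkl', 'rb') as source:
        restored = pickle.load(source)
    assert np.array_equal(restored[0][0], training_examples[0][0])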