diff --git a/Arena.py b/Arena.py
index dc11bd68b..4f8949736 100644
--- a/Arena.py
+++ b/Arena.py
@@ -90,9 +90,9 @@ def playGames(self, num, verbose=False):
 
         for _ in tqdm(range(num), desc="Arena.playGames (2)"):
             gameResult = self.playGame(verbose=verbose)
-            if gameResult == -1:
+            if gameResult == 1:
                 oneWon += 1
-            elif gameResult == 1:
+            elif gameResult == -1:
                 twoWon += 1
             else:
                 draws += 1
diff --git a/Coach.py b/Coach.py
index 9d228a07b..15b649e29 100644
--- a/Coach.py
+++ b/Coach.py
@@ -58,7 +58,7 @@ def executeEpisode(self):
             pi = self.mcts.getActionProb(canonicalBoard, temp=temp)
             sym = self.game.getSymmetries(canonicalBoard, pi)
             for b, p in sym:
-                trainExamples.append([b, self.curPlayer, p, None])
+                trainExamples.append([self.game.toArray(b), self.curPlayer, p, None])
 
             action = np.random.choice(len(pi), p=pi)
             board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action)
@@ -88,7 +88,7 @@ def learn(self):
                     self.mcts = MCTS(self.game, self.nnet, self.args)  # reset search tree
                     iterationTrainExamples += self.executeEpisode()
 
-                # save the iteration examples to the history 
+                # save the iteration examples to the history
                 self.trainExamplesHistory.append(iterationTrainExamples)
 
             if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
@@ -96,7 +96,7 @@
                     f"Removing the oldest entry in trainExamples. len(trainExamplesHistory) = {len(self.trainExamplesHistory)}")
                 self.trainExamplesHistory.pop(0)
             # backup history to a file
-            # NB! the examples were collected using the model from the previous iteration, so (i-1) 
+            # NB! the examples were collected using the model from the previous iteration, so (i-1)
             self.saveTrainExamples(i - 1)
 
             # shuffle examples before training
diff --git a/Game.py b/Game.py
index 6647c3b1a..2513c1e0c 100644
--- a/Game.py
+++ b/Game.py
@@ -68,7 +68,7 @@ def getGameEnded(self, board, player):
         Returns:
             r: 0 if game has not ended. 1 if player won, -1 if player lost,
                small non-zero value for draw.
-               
+
         """
        pass
 
@@ -111,3 +111,10 @@ def stringRepresentation(self, board):
                      Required by MCTS for hashing.
         """
         pass
+
+    def toArray(self, board):
+        """
+        Returns:
+            a board representation suitable as the input to your neural network
+        """
+        return board
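Note (editor's sketch, not part of the diff): the new Game.toArray hook lets a game keep a rich board object while still handing the network a numpy array; Coach and MCTS now route training examples and predictions through it, and games whose boards are already arrays keep the identity default. A minimal illustration with a hypothetical game class (IllustrativeGame and its cells field are made up):

    import numpy as np
    from Game import Game

    class IllustrativeGame(Game):
        def toArray(self, board):
            # convert the native board object into the network's input array;
            # the base-class default simply returns the board unchanged
            return np.asarray(board.cells, dtype=np.float32)
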
log.error("All valid moves were masked, doing a workaround.") self.Ps[s] = self.Ps[s] + valids self.Ps[s] /= np.sum(self.Ps[s]) diff --git a/_chess/ChessGame.py b/_chess/ChessGame.py new file mode 100644 index 000000000..12f701654 --- /dev/null +++ b/_chess/ChessGame.py @@ -0,0 +1,112 @@ +from __future__ import print_function +import sys +sys.path.append('..') +from Game import Game + +import numpy as np +import chess + +def to_np(board): + a = [0]*(8*8*6) + for sq,pc in board.piece_map().items(): + a[sq*6+pc.piece_type-1] = 1 if pc.color else -1 + return np.array(a) + +def from_move(move): + return move.from_square*64+move.to_square + +def to_move(action): + to_sq = action % 64 + from_sq = int(action / 64) + return chess.Move(from_sq, to_sq) + +def who(turn): + return 1 if turn else -1 + +def mirror_move(move): + return chess.Move(chess.square_mirror(move.from_square), chess.square_mirror(move.to_square)) + +CHECKMATE =1 +STALEMATE= 2 +INSUFFICIENT_MATERIAL= 3 +SEVENTYFIVE_MOVES= 4 +FIVEFOLD_REPETITION= 5 +FIFTY_MOVES= 6 +THREEFOLD_REPETITION= 7 + +class ChessGame(Game): + + def __init__(self, n=8): + pass + + def getInitBoard(self): + # return initial board (numpy board) + return chess.Board() + + def getBoardSize(self): + # (a,b) tuple + # 6 piece type + return (8, 8, 6) + + def toArray(self, board): + return to_np(board) + + def getActionSize(self): + # return number of actions + return 64*64 + # return self.n*self.n*16+1 + + def getNextState(self, board, player, action): + # if player takes action on board, return next (board,player) + # action must be a valid move + assert(who(board.turn) == player) + move = to_move(action) + if not board.turn: + # assume the move comes from the canonical board... + move = mirror_move(move) + if move not in board.legal_moves: + # could be a pawn promotion, which has an extra letter in UCI format + move = chess.Move.from_uci(move.uci()+'q') # assume promotion to queen + if move not in board.legal_moves: + assert False, "%s not in %s" % (str(move), str(list(board.legal_moves))) + board = board.copy() + board.push(move) + return (board, who(board.turn)) + + def getValidMoves(self, board, player): + # return a fixed size binary vector + assert(who(board.turn) == player) + acts = [0]*self.getActionSize() + for move in board.legal_moves: + acts[from_move(move)] = 1 + return np.array(acts) + + def getGameEnded(self, board, player): + # return 0 if not ended, 1 if player 1 won, -1 if player 1 lost + outcome = board.outcome() + if outcome is not None: + if outcome.winner is None: + # draw return very little value + return 1e-4 + else: + return who(outcome.winner) + return 0 + + def getCanonicalForm(self, board, player): + # return state if player==1, else return -state if player==-1 + assert(who(board.turn) == player) + if board.turn: + return board + else: + return board.mirror() + + def getSymmetries(self, board, pi): + # mirror, rotational + return [(board,pi)] + + def stringRepresentation(self, board): + return board.fen() + + @staticmethod + def display(board): + print(board) diff --git a/_chess/ChessPlayers.py b/_chess/ChessPlayers.py new file mode 100644 index 000000000..56bbfbbfb --- /dev/null +++ b/_chess/ChessPlayers.py @@ -0,0 +1,57 @@ +import chess +import random +import numpy as np +from _chess.ChessGame import who, from_move, mirror_move +from stockfish import Stockfish + +class RandomPlayer(): + def __init__(self, game): + self.game = game + + def play(self, board): + valids = self.game.getValidMoves(board, who(board.turn)) + moves = 
diff --git a/_chess/ChessPlayers.py b/_chess/ChessPlayers.py
new file mode 100644
index 000000000..56bbfbbfb
--- /dev/null
+++ b/_chess/ChessPlayers.py
@@ -0,0 +1,57 @@
+import chess
+import random
+import numpy as np
+from _chess.ChessGame import who, from_move, mirror_move
+from stockfish import Stockfish
+
+class RandomPlayer():
+    def __init__(self, game):
+        self.game = game
+
+    def play(self, board):
+        valids = self.game.getValidMoves(board, who(board.turn))
+        moves = np.argwhere(valids == 1)
+        return random.choice(moves)[0]
+
+def move_from_uci(board, uci):
+    try:
+        move = chess.Move.from_uci(uci)
+    except ValueError:
+        print('expected a UCI move')
+        return None
+    if move not in board.legal_moves:
+        print('expected a valid move')
+        return None
+    return move
+
+class HumanChessPlayer():
+    def __init__(self, game):
+        pass
+
+    def play(self, board):
+        mboard = board
+        if board.turn:
+            mboard = board.mirror()
+        print('Valid Moves', end=':')
+        for move in mboard.legal_moves:
+            print(move.uci(), end=',')
+        print()
+        human_move = input()
+        move = move_from_uci(mboard, human_move.strip())
+        if move is None:
+            print('try again, e.g., %s' % random.choice(list(mboard.legal_moves)).uci())
+            return self.play(board)
+        if board.turn:
+            move = mirror_move(move)
+        return from_move(move)
+
+class StockFishPlayer():
+    def __init__(self, game, elo=1000):
+        self.stockfish = Stockfish(parameters={"Threads": 2, "Minimum Thinking Time": 30})
+        self.stockfish.set_elo_rating(elo)
+
+    def play(self, board):
+        self.stockfish.set_fen_position(board.fen())
+        uci_move = self.stockfish.get_best_move()
+        move = move_from_uci(board, uci_move.strip())
+        return from_move(move)
diff --git a/_chess/__init__.py b/_chess/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/_chess/pytorch/ChessNNet.py b/_chess/pytorch/ChessNNet.py
new file mode 100644
index 000000000..d4d4027a9
--- /dev/null
+++ b/_chess/pytorch/ChessNNet.py
@@ -0,0 +1,55 @@
+import sys
+sys.path.append('..')
+from utils import *
+
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.autograd import Variable
+
+class ChessNNet(nn.Module):
+    def __init__(self, game, args):
+        # game params
+        self.board_x, self.board_y, self.board_z = game.getBoardSize()
+        self.action_size = game.getActionSize()
+        self.args = args
+
+        super(ChessNNet, self).__init__()
+        self.conv1 = nn.Conv3d(1, args.num_channels, 3, stride=1, padding=1)
+        self.conv2 = nn.Conv3d(args.num_channels, args.num_channels, 3, stride=1, padding=1)
+        self.conv3 = nn.Conv3d(args.num_channels, args.num_channels, 3, stride=1)
+        self.conv4 = nn.Conv3d(args.num_channels, args.num_channels, 3, stride=1)
+
+        self.bn1 = nn.BatchNorm3d(args.num_channels)
+        self.bn2 = nn.BatchNorm3d(args.num_channels)
+        self.bn3 = nn.BatchNorm3d(args.num_channels)
+        self.bn4 = nn.BatchNorm3d(args.num_channels)
+
+        self.fc1 = nn.Linear(args.num_channels*(self.board_x-4)*(self.board_y-4)*(self.board_z-4), 1024)
+        self.fc_bn1 = nn.BatchNorm1d(1024)
+
+        self.fc2 = nn.Linear(1024, 512)
+        self.fc_bn2 = nn.BatchNorm1d(512)
+
+        self.fc3 = nn.Linear(512, self.action_size)
+
+        self.fc4 = nn.Linear(512, 1)
+
+    def forward(self, s):
+        s = s.view(-1, 1, self.board_x, self.board_y, self.board_z)
+        s = F.relu(self.bn1(self.conv1(s)))
+        s = F.relu(self.bn2(self.conv2(s)))
+        s = F.relu(self.bn3(self.conv3(s)))
+        s = F.relu(self.bn4(self.conv4(s)))
+        s = s.view(-1, self.args.num_channels*(self.board_x-4)*(self.board_y-4)*(self.board_z-4))
+
+        s = F.dropout(F.relu(self.fc_bn1(self.fc1(s))), p=self.args.dropout, training=self.training)
+        s = F.dropout(F.relu(self.fc_bn2(self.fc2(s))), p=self.args.dropout, training=self.training)
+
+        pi = self.fc3(s)
+        v = self.fc4(s)
+
+        return F.log_softmax(pi, dim=1), torch.tanh(v)
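Note (editor's sketch, not part of the diff): ChessNNet treats the 8x8x6 board encoding as a single-channel 3D volume. The two padded 3x3x3 convolutions keep the 8x8x6 shape and the two unpadded ones shrink each dimension by 2, leaving 4x4x2 feature maps, which is why fc1 expects num_channels*(8-4)*(8-4)*(6-4) inputs. A quick shape check (dotdict comes from the repo's utils):

    import torch
    from utils import dotdict
    from _chess.ChessGame import ChessGame
    from _chess.pytorch.ChessNNet import ChessNNet

    net = ChessNNet(ChessGame(), dotdict({'num_channels': 128, 'dropout': 0.3}))
    pi, v = net(torch.zeros(2, 8, 8, 6))   # a batch of two empty boards
    print(pi.shape, v.shape)               # torch.Size([2, 4096]) torch.Size([2, 1])
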
diff --git a/_chess/pytorch/NNet.py b/_chess/pytorch/NNet.py
new file mode 100644
index 000000000..df45bf914
--- /dev/null
+++ b/_chess/pytorch/NNet.py
@@ -0,0 +1,150 @@
+import argparse
+import os
+import shutil
+import time
+import random
+import numpy as np
+import math
+import sys
+sys.path.append('../../')
+from utils import *
+from pytorch_classification.utils import Bar, AverageMeter
+from NeuralNet import NeuralNet
+
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+
+from .ChessNNet import ChessNNet as onnet
+
+args = dotdict({
+    'lr': 0.002,
+    'dropout': 0.3,
+    'epochs': 10,
+    'batch_size': 64,
+    'cuda': torch.cuda.is_available(),
+    'num_channels': 128,
+})
+
+class NNetWrapper(NeuralNet):
+    def __init__(self, game):
+        self.nnet = onnet(game, args)
+        self.board_x, self.board_y, self.board_z = game.getBoardSize()
+        self.action_size = game.getActionSize()
+
+        if args.cuda:
+            self.nnet.cuda()
+
+    def train(self, examples):
+        """
+        examples: list of examples, each example is of form (board, pi, v)
+        """
+        optimizer = optim.Adam(self.nnet.parameters())
+
+        for epoch in range(args.epochs):
+            print('EPOCH ::: ' + str(epoch+1))
+            self.nnet.train()
+            data_time = AverageMeter()
+            batch_time = AverageMeter()
+            pi_losses = AverageMeter()
+            v_losses = AverageMeter()
+            end = time.time()
+
+            bar = Bar('Training Net', max=int(len(examples)/args.batch_size))
+            batch_idx = 0
+
+            while batch_idx < int(len(examples)/args.batch_size):
+                sample_ids = np.random.randint(len(examples), size=args.batch_size)
+                boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
+                boards = torch.FloatTensor(np.array(boards).astype(np.float64))
+                target_pis = torch.FloatTensor(np.array(pis))
+                target_vs = torch.FloatTensor(np.array(vs).astype(np.float64))
+
+                # predict
+                if args.cuda:
+                    boards, target_pis, target_vs = boards.contiguous().cuda(), target_pis.contiguous().cuda(), target_vs.contiguous().cuda()
+
+                # measure data loading time
+                data_time.update(time.time() - end)
+
+                # compute output
+                out_pi, out_v = self.nnet(boards)
+                l_pi = self.loss_pi(target_pis, out_pi)
+                l_v = self.loss_v(target_vs, out_v)
+                total_loss = l_pi + l_v
+
+                # record loss
+                pi_losses.update(l_pi.item(), boards.size(0))
+                v_losses.update(l_v.item(), boards.size(0))
+
+                # compute gradient and do SGD step
+                optimizer.zero_grad()
+                total_loss.backward()
+                optimizer.step()
+
+                # measure elapsed time
+                batch_time.update(time.time() - end)
+                end = time.time()
+                batch_idx += 1
+
+                # plot progress
+                bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_pi: {lpi:.4f} | Loss_v: {lv:.3f}'.format(
+                    batch=batch_idx,
+                    size=int(len(examples)/args.batch_size),
+                    data=data_time.avg,
+                    bt=batch_time.avg,
+                    total=bar.elapsed_td,
+                    eta=bar.eta_td,
+                    lpi=pi_losses.avg,
+                    lv=v_losses.avg,
+                )
+                bar.next()
+            bar.finish()
+
+
+    def predict(self, board):
+        """
+        board: np.ndarray board encoding (see ChessGame.toArray)
+        """
+        # timing
+        start = time.time()
+
+        # preparing input
+        board = torch.FloatTensor(board.astype(np.float64))
+        if args.cuda: board = board.contiguous().cuda()
+        board = board.view(1, self.board_x, self.board_y, self.board_z)
+        self.nnet.eval()
+        with torch.no_grad():
+            pi, v = self.nnet(board)
+
+        #print('PREDICTION TIME TAKEN : {0:03f}'.format(time.time()-start))
+        return torch.exp(pi).data.cpu().numpy()[0], v.data.cpu().numpy()[0]
+
+    def loss_pi(self, targets, outputs):
+        return -torch.sum(targets*outputs)/targets.size()[0]
+
+    def loss_v(self, targets, outputs):
+        return torch.sum((targets-outputs.view(-1))**2)/targets.size()[0]
+
+    def save_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
+        filepath = os.path.join(folder, filename)
+        if not os.path.exists(folder):
+            print("Checkpoint Directory does not exist! Making directory {}".format(folder))
+            os.mkdir(folder)
+        else:
+            print("Checkpoint Directory exists! ")
+        torch.save({
+            'state_dict' : self.nnet.state_dict(),
+        }, filepath)
+
+    def load_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
+        # https://github.com/pytorch/examples/blob/master/imagenet/main.py#L98
+        filepath = os.path.join(folder, filename)
+        if not os.path.exists(filepath):
+            raise FileNotFoundError("No model in path {}".format(filepath))
+        map_location = None if args.cuda else 'cpu'
+        checkpoint = torch.load(filepath, map_location=map_location)
+        self.nnet.load_state_dict(checkpoint['state_dict'])
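Note (editor's sketch, not part of the diff): because MCTS now calls game.toArray before nnet.predict, the wrapper receives a numpy array rather than a chess.Board. A minimal prediction sketch under that assumption:

    from _chess.ChessGame import ChessGame
    from _chess.pytorch.NNet import NNetWrapper

    game = ChessGame()
    nnet = NNetWrapper(game)
    board = game.getInitBoard()                  # a chess.Board
    pi, v = nnet.predict(game.toArray(board))    # pi: (4096,) move priors, v: value estimate in [-1, 1]
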
diff --git a/main.py b/main.py
index 5687c2a38..a7b39ca4a 100644
--- a/main.py
+++ b/main.py
@@ -3,8 +3,10 @@
 import coloredlogs
 
 from Coach import Coach
-from othello.OthelloGame import OthelloGame as Game
-from othello.pytorch.NNet import NNetWrapper as nn
+# from othello.OthelloGame import OthelloGame as Game
+# from othello.pytorch.NNet import NNetWrapper as nn
+from _chess.ChessGame import ChessGame as Game
+from _chess.pytorch.NNet import NNetWrapper as nn
 from utils import *
 
 log = logging.getLogger(__name__)
diff --git a/test_all_games.py b/test_all_games.py
index 9d03dc43e..6f2b96c70 100644
--- a/test_all_games.py
+++ b/test_all_games.py
@@ -12,6 +12,7 @@
 - TicTacToe [Yes]
 - Connect4 [Yes]
 - Gobang [Yes] [Yes]
+- Chess [Yes]
 
 """
 
@@ -43,6 +44,10 @@
 from gobang.keras.NNet import NNetWrapper as GobangKerasNNet
 from gobang.tensorflow.NNet import NNetWrapper as GobangTensorflowNNet
 
+from _chess.ChessGame import ChessGame
+from _chess.pytorch.NNet import NNetWrapper as ChessPytorchNNet
+
+
 import numpy as np
 from utils import *
 
@@ -80,6 +85,8 @@ def test_gobang_keras(self):
     def test_gobang_tensorflow(self):
         self.execute_game_test(GobangGame(), GobangTensorflowNNet)
 
+    def test_chess_pytorch(self):
+        self.execute_game_test(ChessGame(), ChessPytorchNNet)
 
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()
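Note (editor's sketch, not part of the diff): with the Arena fix above, a +1 game result is now credited to player one and a -1 result to player two. A rough pitting sketch, assuming Arena's constructor still takes (player1, player2, game, display) as in the upstream repo and that players are passed as callables:

    from Arena import Arena
    from _chess.ChessGame import ChessGame
    from _chess.ChessPlayers import RandomPlayer

    game = ChessGame()
    arena = Arena(RandomPlayer(game).play, RandomPlayer(game).play, game, display=ChessGame.display)
    oneWon, twoWon, draws = arena.playGames(10)  # wins for player one, wins for player two, draws
    print(oneWon, twoWon, draws)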