From 244958a5cd036beaa40b30cdb9340c41e2d008c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20Forst=C3=A9n?= Date: Fri, 13 Dec 2019 18:14:10 +0200 Subject: [PATCH 1/3] Flip training positions randomly when castling is not possible Castling rights are determined from the current position. It's possible for history to contain positions with flipped board and flipped castling. --- tf/chunkparser.py | 26 ++++++++++++++++++++++++++ tf/lc0_az_policy_map.py | 39 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/tf/chunkparser.py b/tf/chunkparser.py index 9f654a23..0188264c 100644 --- a/tf/chunkparser.py +++ b/tf/chunkparser.py @@ -25,6 +25,7 @@ import struct import tensorflow as tf import unittest +import lc0_az_policy_map V4_VERSION = struct.pack('i', 4) V3_VERSION = struct.pack('i', 3) @@ -41,6 +42,11 @@ def next(self): return self.items.pop() +def flip_vertex(v): + c = v % 8 + r = v // 8 + c = 7 - c + return 8 * r + c class ChunkParser: # static batch size @@ -68,6 +74,15 @@ def __init__(self, chunkdatasrc, shuffle_size=1, sample=1, buffer_size=1, batch_ long. """ + self.policy_flip_map = lc0_az_policy_map.make_map(kind='flip_permutation') + + # Build full flip tables for flipping all input planes at once. + self.full_flip_map = np.array([flip_vertex(vertex) + p*8*8 + for p in range(104) for vertex in range(8*8)], dtype=np.int32) + + r = np.array(list(range(104*8*8)), dtype=np.int32) + assert np.array_equal((r[self.full_flip_map])[self.full_flip_map], r) + # Build 2 flat float32 planes with values 0,1 self.flat_planes = [] for i in range(2): @@ -182,6 +197,12 @@ def convert_v4_to_tuple(self, content): # Unpack bit planes and cast to 32 bit float planes = np.unpackbits(np.frombuffer(planes, dtype=np.uint8)).astype(np.float32) rule50_plane = (np.zeros(8*8, dtype=np.float32) + rule50_count) / 99 + if us_ooo == us_oo == them_ooo == them_oo == 0: + if random.randrange(2) == 0: + planes = planes[self.full_flip_map] + probs = np.frombuffer(probs, dtype=np.float32) + probs = probs[self.policy_flip_map] + probs = probs.tobytes() # Concatenate all byteplanes. Make the last plane all 1's so the NN can # detect edges of the board more easily @@ -207,6 +228,11 @@ def convert_v4_to_tuple(self, content): return (planes, probs, winner, best_q) + def maybe_flip_data(self, data): + planes, probs, winner, best_q = data + + self.policy_flip_permutation + def sample_record(self, chunkdata): """ diff --git a/tf/lc0_az_policy_map.py b/tf/lc0_az_policy_map.py index 368d9100..26517cae 100755 --- a/tf/lc0_az_policy_map.py +++ b/tf/lc0_az_policy_map.py @@ -16,6 +16,12 @@ def index_to_position(x): def position_to_index(p): return col_index[p[0]], row_index[p[1]] +def flip_position_lr(p): + if p == None: + return None + c, r = position_to_index(p) + return index_to_position([7 - c, r]) + def valid_index(i): if i[0] > 7 or i[0] < 0: return False @@ -46,16 +52,23 @@ def knight_move(start, direction, steps): def make_map(kind='matrix'): # 56 planes of queen moves moves = [] + flip_moves = [] for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']: for steps in range(1, 8): for r0 in rows: for c0 in columns: start = c0 + r0 end = queen_move(start, direction, steps) + flip_start = flip_position_lr(start) + flip_end = flip_position_lr(end) if end == None: moves.append('illegal') else: moves.append(start+end) + if flip_end == None: + flip_moves.append('illegal') + else: + flip_moves.append(flip_start+flip_end) # 8 planes of knight moves for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']: @@ -63,10 +76,16 @@ def make_map(kind='matrix'): for c0 in columns: start = c0 + r0 end = knight_move(start, direction, 1) + flip_start = flip_position_lr(start) + flip_end = flip_position_lr(end) if end == None: moves.append('illegal') else: moves.append(start+end) + if flip_end == None: + flip_moves.append('illegal') + else: + flip_moves.append(flip_start+flip_end) # 9 promotions for direction in ['NW', 'N', 'NE']: @@ -76,20 +95,30 @@ def make_map(kind='matrix'): # Promotion only in the second last rank if r0 != '7': moves.append('illegal') + flip_moves.append('illegal') continue start = c0 + r0 end = queen_move(start, direction, 1) + flip_start = flip_position_lr(start) + flip_end = flip_position_lr(end) if end == None: moves.append('illegal') else: moves.append(start+end+promotion) + if flip_end == None: + flip_moves.append('illegal') + else: + flip_moves.append(flip_start+flip_end+promotion) for m in policy_index: if m not in moves: raise ValueError('Missing move: {}'.format(m)) + if m not in flip_moves: + raise ValueError('Missing move: {}'.format(m)) az_to_lc0 = np.zeros((80*8*8, len(policy_index)), dtype=np.float32) indices = [] + flip_permutation = np.zeros(1858, dtype=np.int32) legal_moves = 0 for e, m in enumerate(moves): if m == 'illegal': @@ -100,18 +129,22 @@ def make_map(kind='matrix'): if m not in policy_index: raise ValueError('Missing move: {}'.format(m)) i = policy_index.index(m) + flip_i = policy_index.index(flip_moves[e]) + flip_permutation[i] = flip_i indices.append(i) az_to_lc0[e][i] = 1 + # Verify that applying flip permutation twice gives back the original policy. + assert np.array_equal(flip_permutation[flip_permutation], list(range(1858))) + assert legal_moves == len(policy_index) assert np.sum(az_to_lc0) == legal_moves - for e in range(80*8*8): - for i in range(len(policy_index)): - pass if kind == 'matrix': return az_to_lc0 elif kind == 'index': return indices + elif kind == 'flip_permutation': + return flip_permutation if __name__ == "__main__": # Generate policy map include file for lc0 From 21592af21a95fb70f26a20e88385ed1e6e0f0193 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20Forst=C3=A9n?= Date: Sat, 14 Dec 2019 08:48:30 +0200 Subject: [PATCH 2/3] Merge branch 'master' into flip --- tf/chunkparser.py | 24 +++++++++++------------- tf/train.py | 7 +++++-- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/tf/chunkparser.py b/tf/chunkparser.py index 0188264c..ba41224c 100644 --- a/tf/chunkparser.py +++ b/tf/chunkparser.py @@ -51,7 +51,8 @@ def flip_vertex(v): class ChunkParser: # static batch size BATCH_SIZE = 8 - def __init__(self, chunkdatasrc, shuffle_size=1, sample=1, buffer_size=1, batch_size=256, workers=None): + def __init__(self, chunkdatasrc, shuffle_size=1, sample=1, buffer_size=1, + batch_size=256, workers=None, flip=False): """ Read data and yield batches of raw tensors. @@ -74,14 +75,16 @@ def __init__(self, chunkdatasrc, shuffle_size=1, sample=1, buffer_size=1, batch_ long. """ - self.policy_flip_map = lc0_az_policy_map.make_map(kind='flip_permutation') + self.flip = flip + if self.flip: + self.policy_flip_map = lc0_az_policy_map.make_map(kind='flip_permutation') - # Build full flip tables for flipping all input planes at once. - self.full_flip_map = np.array([flip_vertex(vertex) + p*8*8 - for p in range(104) for vertex in range(8*8)], dtype=np.int32) + # Build full flip tables for flipping all input planes at once. + self.full_flip_map = np.array([flip_vertex(vertex) + p*8*8 + for p in range(104) for vertex in range(8*8)], dtype=np.int32) - r = np.array(list(range(104*8*8)), dtype=np.int32) - assert np.array_equal((r[self.full_flip_map])[self.full_flip_map], r) + r = np.array(list(range(104*8*8)), dtype=np.int32) + assert np.array_equal((r[self.full_flip_map])[self.full_flip_map], r) # Build 2 flat float32 planes with values 0,1 self.flat_planes = [] @@ -197,7 +200,7 @@ def convert_v4_to_tuple(self, content): # Unpack bit planes and cast to 32 bit float planes = np.unpackbits(np.frombuffer(planes, dtype=np.uint8)).astype(np.float32) rule50_plane = (np.zeros(8*8, dtype=np.float32) + rule50_count) / 99 - if us_ooo == us_oo == them_ooo == them_oo == 0: + if self.flip and (us_ooo == us_oo == them_ooo == them_oo == 0): if random.randrange(2) == 0: planes = planes[self.full_flip_map] probs = np.frombuffer(probs, dtype=np.float32) @@ -228,11 +231,6 @@ def convert_v4_to_tuple(self, content): return (planes, probs, winner, best_q) - def maybe_flip_data(self, data): - planes, probs, winner, best_q = data - - self.policy_flip_permutation - def sample_record(self, chunkdata): """ diff --git a/tf/train.py b/tf/train.py index d452f03c..f344e6e3 100755 --- a/tf/train.py +++ b/tf/train.py @@ -129,6 +129,7 @@ def main(cmd): allow_less = cfg['dataset'].get('allow_less_chunks', False) train_ratio = cfg['dataset']['train_ratio'] experimental_parser = cfg['dataset'].get('experimental_v4_only_dataset', False) + flip = cfg['dataset']['flip_augmentation'] num_train = int(num_chunks*train_ratio) num_test = num_chunks - num_train if 'input_test' in cfg['dataset']: @@ -163,7 +164,8 @@ def main(cmd): .batch(split_batch_size).map(extract_inputs_outputs).prefetch(4) else: train_parser = ChunkParser(FileDataSrc(train_chunks), - shuffle_size=shuffle_size, sample=SKIP, batch_size=ChunkParser.BATCH_SIZE) + shuffle_size=shuffle_size, sample=SKIP, batch_size=ChunkParser.BATCH_SIZE, + flip=flip) train_dataset = tf.data.Dataset.from_generator( train_parser.parse, output_types=(tf.string, tf.string, tf.string, tf.string)) train_dataset = train_dataset.map(ChunkParser.parse_function) @@ -177,7 +179,8 @@ def main(cmd): .batch(split_batch_size).map(extract_inputs_outputs).prefetch(4) else: test_parser = ChunkParser(FileDataSrc(test_chunks), - shuffle_size=shuffle_size, sample=SKIP, batch_size=ChunkParser.BATCH_SIZE) + shuffle_size=shuffle_size, sample=SKIP, batch_size=ChunkParser.BATCH_SIZE, + flip=False) test_dataset = tf.data.Dataset.from_generator( test_parser.parse, output_types=(tf.string, tf.string, tf.string, tf.string)) test_dataset = test_dataset.map(ChunkParser.parse_function) From 8e86545d8e8dd38fc889a263011507c5f2a42d8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20Forst=C3=A9n?= Date: Mon, 23 Dec 2019 06:40:58 +0200 Subject: [PATCH 3/3] Consider less positions for flipping Don't flip positions with too many pieces or positions where kings are in default or castling places as those are too easy to predict to be flipped. --- tf/chunkparser.py | 45 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/tf/chunkparser.py b/tf/chunkparser.py index ba41224c..a3eccec3 100644 --- a/tf/chunkparser.py +++ b/tf/chunkparser.py @@ -200,8 +200,34 @@ def convert_v4_to_tuple(self, content): # Unpack bit planes and cast to 32 bit float planes = np.unpackbits(np.frombuffer(planes, dtype=np.uint8)).astype(np.float32) rule50_plane = (np.zeros(8*8, dtype=np.float32) + rule50_count) / 99 - if self.flip and (us_ooo == us_oo == them_ooo == them_oo == 0): - if random.randrange(2) == 0: + + if self.flip: + can_flip = True + if not (us_ooo == us_oo == them_ooo == them_oo == 0): + can_flip = False + if can_flip and random.randrange(2) == 0: + can_flip = False + if can_flip: + # Count from the last position in the history. + pieces = np.sum(planes[13*7*64:-64]) + if pieces > 16: + can_flip = False + if can_flip: + # King plane of last position in the history. + our_king = planes[6144:6208] + their_king = planes[6528:6592] + our_king_pos = np.where(our_king == 1)[0] + their_king_pos = np.where(their_king == 1)[0] + if len(our_king_pos) != 1: + return False, None + if len(their_king_pos) != 1: + return False, None + # Default and castling positions. + if our_king_pos[0] in (2, 4, 6): + can_flip = False + if their_king_pos[0] in (60, 62, 64): + can_flip = False + if can_flip: planes = planes[self.full_flip_map] probs = np.frombuffer(probs, dtype=np.float32) probs = probs[self.policy_flip_map] @@ -219,17 +245,20 @@ def convert_v4_to_tuple(self, content): self.flat_planes[move_count].tobytes() + \ self.flat_planes[1].tobytes() - assert len(planes) == ((8*13*1 + 8*1*1) * 8 * 8 * 4) + if not (len(planes) == ((8*13*1 + 8*1*1) * 8 * 8 * 4)): + return False, None winner = float(winner) - assert winner == 1.0 or winner == -1.0 or winner == 0.0 + if not (winner == 1.0 or winner == -1.0 or winner == 0.0): + return False, None winner = struct.pack('fff', winner == 1.0, winner == 0.0, winner == -1.0) best_q_w = 0.5 * (1.0 - best_d + best_q) best_q_l = 0.5 * (1.0 - best_d - best_q) - assert -1.0 <= best_q <= 1.0 and 0.0 <= best_d <= 1.0 + if not (-1.0 <= best_q <= 1.0 and 0.0 <= best_d <= 1.0): + return False, None best_q = struct.pack('fff', best_q_w, best_d, best_q_l) - return (planes, probs, winner, best_q) + return True, (planes, probs, winner, best_q) def sample_record(self, chunkdata): @@ -306,7 +335,9 @@ def tuple_gen(self, gen): applying a random symmetry on the way. """ for r in gen: - yield self.convert_v4_to_tuple(r) + success, data = self.convert_v4_to_tuple(r) + if success: + yield data def batch_gen(self, gen):