Left/right flip augmentation for training data (reviewed revision).

Changes relative to the submitted patch (added lines only):
  * chunkparser: their-king castling squares are (58, 60, 62); index 64 can
    never occur in a 64-square plane and 58 was missing.
  * train: 'flip_augmentation' is optional and defaults to False.
  * lc0_az_policy_map: permutation size follows len(policy_index); distinct
    error message for a missing flipped move.
The 'index' blob hashes below are those of the submitted patch.

diff --git a/tf/chunkparser.py b/tf/chunkparser.py
index 9f654a23..a3eccec3 100644
--- a/tf/chunkparser.py
+++ b/tf/chunkparser.py
@@ -25,6 +25,7 @@
 import struct
 import tensorflow as tf
 import unittest
+import lc0_az_policy_map
 
 V4_VERSION = struct.pack('i', 4)
 V3_VERSION = struct.pack('i', 3)
@@ -41,11 +42,17 @@ def next(self):
         return self.items.pop()
 
 
+def flip_vertex(v):
+    """Mirror square index v left-to-right within its row of 8 squares."""
+    row, col = divmod(v, 8)
+    # Keep the row, reflect the column (0 <-> 7, 1 <-> 6, ...).
+    return 8 * row + (7 - col)
 
 class ChunkParser:
     # static batch size
     BATCH_SIZE = 8
-    def __init__(self, chunkdatasrc, shuffle_size=1, sample=1, buffer_size=1, batch_size=256, workers=None):
+    def __init__(self, chunkdatasrc, shuffle_size=1, sample=1, buffer_size=1,
+            batch_size=256, workers=None, flip=False):
         """
         Read data and yield batches of raw tensors.
 
@@ -68,6 +75,17 @@ def __init__(self, chunkdatasrc, shuffle_size=1, sample=1, buffer_size=1, batch_
         long.
         """
 
+        self.flip = flip
+        if self.flip:
+            self.policy_flip_map = lc0_az_policy_map.make_map(kind='flip_permutation')
+
+            # Build full flip tables for flipping all input planes at once.
+            self.full_flip_map = np.array([flip_vertex(vertex) + p*8*8
+                for p in range(104) for vertex in range(8*8)], dtype=np.int32)
+            # Sanity check: mirroring twice must give back the identity.
+            r = np.arange(104*8*8, dtype=np.int32)
+            assert np.array_equal((r[self.full_flip_map])[self.full_flip_map], r)
+
         # Build 2 flat float32 planes with values 0,1
         self.flat_planes = []
         for i in range(2):
@@ -183,6 +201,38 @@ def convert_v4_to_tuple(self, content):
         planes = np.unpackbits(np.frombuffer(planes, dtype=np.uint8)).astype(np.float32)
         rule50_plane = (np.zeros(8*8, dtype=np.float32) + rule50_count) / 99
 
+        if self.flip:
+            can_flip = True
+            if not (us_ooo == us_oo == them_ooo == them_oo == 0):
+                can_flip = False
+            if can_flip and random.randrange(2) == 0:
+                can_flip = False
+            if can_flip:
+                # Count from the last position in the history.
+                pieces = np.sum(planes[13*7*64:-64])
+                if pieces > 16:
+                    can_flip = False
+            if can_flip:
+                # King plane of last position in the history.
+                our_king = planes[6144:6208]
+                their_king = planes[6528:6592]
+                our_king_pos = np.where(our_king == 1)[0]
+                their_king_pos = np.where(their_king == 1)[0]
+                if len(our_king_pos) != 1:
+                    return False, None
+                if len(their_king_pos) != 1:
+                    return False, None
+                # Default and castling squares: columns 2/4/6 of row 0 / row 7.
+                if our_king_pos[0] in (2, 4, 6):
+                    can_flip = False
+                if their_king_pos[0] in (58, 60, 62):
+                    can_flip = False
+            if can_flip:
+                planes = planes[self.full_flip_map]
+                probs = np.frombuffer(probs, dtype=np.float32)
+                probs = probs[self.policy_flip_map]
+                probs = probs.tobytes()
+
         # Concatenate all byteplanes. Make the last plane all 1's so the NN can
         # detect edges of the board more easily
         planes = planes.tobytes() + \
@@ -195,17 +245,20 @@ def convert_v4_to_tuple(self, content):
                  self.flat_planes[move_count].tobytes() + \
                  self.flat_planes[1].tobytes()
 
-        assert len(planes) == ((8*13*1 + 8*1*1) * 8 * 8 * 4)
+        if not (len(planes) == ((8*13*1 + 8*1*1) * 8 * 8 * 4)):
+            return False, None
         winner = float(winner)
-        assert winner == 1.0 or winner == -1.0 or winner == 0.0
+        if not (winner == 1.0 or winner == -1.0 or winner == 0.0):
+            return False, None
         winner = struct.pack('fff', winner == 1.0, winner == 0.0, winner == -1.0)
 
         best_q_w = 0.5 * (1.0 - best_d + best_q)
         best_q_l = 0.5 * (1.0 - best_d - best_q)
-        assert -1.0 <= best_q <= 1.0 and 0.0 <= best_d <= 1.0
+        if not (-1.0 <= best_q <= 1.0 and 0.0 <= best_d <= 1.0):
+            return False, None
         best_q = struct.pack('fff', best_q_w, best_d, best_q_l)
 
-        return (planes, probs, winner, best_q)
+        return True, (planes, probs, winner, best_q)
 
 
     def sample_record(self, chunkdata):
@@ -282,7 +335,9 @@ def tuple_gen(self, gen):
         applying a random symmetry on the way.
         """
         for r in gen:
-            yield self.convert_v4_to_tuple(r)
+            success, data = self.convert_v4_to_tuple(r)
+            if success:
+                yield data
 
 
     def batch_gen(self, gen):
diff --git a/tf/lc0_az_policy_map.py b/tf/lc0_az_policy_map.py
index 368d9100..26517cae 100755
--- a/tf/lc0_az_policy_map.py
+++ b/tf/lc0_az_policy_map.py
@@ -16,6 +16,13 @@ def index_to_position(x):
 def position_to_index(p):
     return col_index[p[0]], row_index[p[1]]
 
+def flip_position_lr(p):
+    """Mirror position p left-to-right (column c -> 7 - c); None stays None."""
+    if p is None:
+        return None
+    c, r = position_to_index(p)
+    return index_to_position([7 - c, r])
+
 def valid_index(i):
     if i[0] > 7 or i[0] < 0:
         return False
@@ -46,16 +53,23 @@ def knight_move(start, direction, steps):
 def make_map(kind='matrix'):
     # 56 planes of queen moves
     moves = []
+    flip_moves = []
     for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']:
         for steps in range(1, 8):
             for r0 in rows:
                 for c0 in columns:
                     start = c0 + r0
                     end = queen_move(start, direction, steps)
+                    flip_start = flip_position_lr(start)
+                    flip_end = flip_position_lr(end)
                     if end == None:
                         moves.append('illegal')
                     else:
                         moves.append(start+end)
+                    if flip_end is None:
+                        flip_moves.append('illegal')
+                    else:
+                        flip_moves.append(flip_start+flip_end)
 
     # 8 planes of knight moves
     for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']:
@@ -63,10 +77,16 @@ def make_map(kind='matrix'):
             for c0 in columns:
                 start = c0 + r0
                 end = knight_move(start, direction, 1)
+                flip_start = flip_position_lr(start)
+                flip_end = flip_position_lr(end)
                 if end == None:
                     moves.append('illegal')
                 else:
                     moves.append(start+end)
+                if flip_end is None:
+                    flip_moves.append('illegal')
+                else:
+                    flip_moves.append(flip_start+flip_end)
 
     # 9 promotions
     for direction in ['NW', 'N', 'NE']:
@@ -76,20 +96,30 @@ def make_map(kind='matrix'):
                     # Promotion only in the second last rank
                     if r0 != '7':
                         moves.append('illegal')
+                        flip_moves.append('illegal')
                         continue
                     start = c0 + r0
                     end = queen_move(start, direction, 1)
+                    flip_start = flip_position_lr(start)
+                    flip_end = flip_position_lr(end)
                     if end == None:
                         moves.append('illegal')
                     else:
                         moves.append(start+end+promotion)
+                    if flip_end is None:
+                        flip_moves.append('illegal')
+                    else:
+                        flip_moves.append(flip_start+flip_end+promotion)
 
     for m in policy_index:
         if m not in moves:
             raise ValueError('Missing move: {}'.format(m))
+        if m not in flip_moves:
+            raise ValueError('Missing flipped move: {}'.format(m))
 
     az_to_lc0 = np.zeros((80*8*8, len(policy_index)), dtype=np.float32)
     indices = []
+    flip_permutation = np.zeros(len(policy_index), dtype=np.int32)
     legal_moves = 0
     for e, m in enumerate(moves):
         if m == 'illegal':
@@ -100,18 +130,22 @@ def make_map(kind='matrix'):
         if m not in policy_index:
             raise ValueError('Missing move: {}'.format(m))
         i = policy_index.index(m)
+        flip_i = policy_index.index(flip_moves[e])
+        flip_permutation[i] = flip_i
         indices.append(i)
         az_to_lc0[e][i] = 1
 
+    # Verify that applying flip permutation twice gives back the original policy.
+    assert np.array_equal(flip_permutation[flip_permutation], np.arange(len(policy_index)))
+
     assert legal_moves == len(policy_index)
     assert np.sum(az_to_lc0) == legal_moves
-    for e in range(80*8*8):
-        for i in range(len(policy_index)):
-            pass
     if kind == 'matrix':
         return az_to_lc0
     elif kind == 'index':
         return indices
+    elif kind == 'flip_permutation':
+        return flip_permutation
 
 if __name__ == "__main__":
     # Generate policy map include file for lc0
diff --git a/tf/train.py b/tf/train.py
index d452f03c..f344e6e3 100755
--- a/tf/train.py
+++ b/tf/train.py
@@ -129,6 +129,7 @@ def main(cmd):
     allow_less = cfg['dataset'].get('allow_less_chunks', False)
     train_ratio = cfg['dataset']['train_ratio']
     experimental_parser = cfg['dataset'].get('experimental_v4_only_dataset', False)
+    flip = cfg['dataset'].get('flip_augmentation', False)
     num_train = int(num_chunks*train_ratio)
     num_test = num_chunks - num_train
     if 'input_test' in cfg['dataset']:
@@ -163,7 +164,8 @@ def main(cmd):
             .batch(split_batch_size).map(extract_inputs_outputs).prefetch(4)
     else:
         train_parser = ChunkParser(FileDataSrc(train_chunks),
-                shuffle_size=shuffle_size, sample=SKIP, batch_size=ChunkParser.BATCH_SIZE)
+                shuffle_size=shuffle_size, sample=SKIP, batch_size=ChunkParser.BATCH_SIZE,
+                flip=flip)
         train_dataset = tf.data.Dataset.from_generator(
             train_parser.parse, output_types=(tf.string, tf.string, tf.string, tf.string))
         train_dataset = train_dataset.map(ChunkParser.parse_function)
@@ -177,7 +179,8 @@ def main(cmd):
             .batch(split_batch_size).map(extract_inputs_outputs).prefetch(4)
     else:
         test_parser = ChunkParser(FileDataSrc(test_chunks),
-                shuffle_size=shuffle_size, sample=SKIP, batch_size=ChunkParser.BATCH_SIZE)
+                shuffle_size=shuffle_size, sample=SKIP, batch_size=ChunkParser.BATCH_SIZE,
+                flip=False)
         test_dataset = tf.data.Dataset.from_generator(
             test_parser.parse, output_types=(tf.string, tf.string, tf.string, tf.string))
         test_dataset = test_dataset.map(ChunkParser.parse_function)