Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Randomly flip training positions #97

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 61 additions & 6 deletions tf/chunkparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import struct
import tensorflow as tf
import unittest
import lc0_az_policy_map

V4_VERSION = struct.pack('i', 4)
V3_VERSION = struct.pack('i', 3)
Expand All @@ -41,11 +42,17 @@ def next(self):
return self.items.pop()


def flip_vertex(v):
    """Mirror a square index left-to-right within its row of eight.

    The index is split into a row (``v // 8``) and a column (``v % 8``).
    The row is kept and column ``c`` is replaced by ``7 - c``. Because only
    the column within the row changes, an index offset by a multiple of 64
    is mirrored the same way and keeps its offset.

    Args:
        v: Integer square index, eight squares per row.

    Returns:
        The index of the square in the same row with the mirrored column.
    """
    row, col = divmod(v, 8)
    return 8 * row + (7 - col)

class ChunkParser:
# static batch size
BATCH_SIZE = 8
def __init__(self, chunkdatasrc, shuffle_size=1, sample=1, buffer_size=1, batch_size=256, workers=None):
def __init__(self, chunkdatasrc, shuffle_size=1, sample=1, buffer_size=1,
batch_size=256, workers=None, flip=False):
"""
Read data and yield batches of raw tensors.

Expand All @@ -68,6 +75,17 @@ def __init__(self, chunkdatasrc, shuffle_size=1, sample=1, buffer_size=1, batch_
long.
"""

self.flip = flip
if self.flip:
self.policy_flip_map = lc0_az_policy_map.make_map(kind='flip_permutation')

# Build full flip tables for flipping all input planes at once.
self.full_flip_map = np.array([flip_vertex(vertex) + p*8*8
for p in range(104) for vertex in range(8*8)], dtype=np.int32)

r = np.array(list(range(104*8*8)), dtype=np.int32)
assert np.array_equal((r[self.full_flip_map])[self.full_flip_map], r)

# Build 2 flat float32 planes with values 0,1
self.flat_planes = []
for i in range(2):
Expand Down Expand Up @@ -183,6 +201,38 @@ def convert_v4_to_tuple(self, content):
planes = np.unpackbits(np.frombuffer(planes, dtype=np.uint8)).astype(np.float32)
rule50_plane = (np.zeros(8*8, dtype=np.float32) + rule50_count) / 99

if self.flip:
can_flip = True
if not (us_ooo == us_oo == them_ooo == them_oo == 0):
can_flip = False
if can_flip and random.randrange(2) == 0:
can_flip = False
if can_flip:
# Count from the last position in the history.
pieces = np.sum(planes[13*7*64:-64])
if pieces > 16:
can_flip = False
if can_flip:
# King plane of last position in the history.
our_king = planes[6144:6208]
their_king = planes[6528:6592]
our_king_pos = np.where(our_king == 1)[0]
their_king_pos = np.where(their_king == 1)[0]
if len(our_king_pos) != 1:
return False, None
if len(their_king_pos) != 1:
return False, None
# Default and castling positions.
if our_king_pos[0] in (2, 4, 6):
can_flip = False
if their_king_pos[0] in (60, 62, 64):
can_flip = False
if can_flip:
planes = planes[self.full_flip_map]
probs = np.frombuffer(probs, dtype=np.float32)
probs = probs[self.policy_flip_map]
probs = probs.tobytes()

# Concatenate all byteplanes. Make the last plane all 1's so the NN can
# detect edges of the board more easily
planes = planes.tobytes() + \
Expand All @@ -195,17 +245,20 @@ def convert_v4_to_tuple(self, content):
self.flat_planes[move_count].tobytes() + \
self.flat_planes[1].tobytes()

assert len(planes) == ((8*13*1 + 8*1*1) * 8 * 8 * 4)
if not (len(planes) == ((8*13*1 + 8*1*1) * 8 * 8 * 4)):
return False, None
winner = float(winner)
assert winner == 1.0 or winner == -1.0 or winner == 0.0
if not (winner == 1.0 or winner == -1.0 or winner == 0.0):
return False, None
winner = struct.pack('fff', winner == 1.0, winner == 0.0, winner == -1.0)

best_q_w = 0.5 * (1.0 - best_d + best_q)
best_q_l = 0.5 * (1.0 - best_d - best_q)
assert -1.0 <= best_q <= 1.0 and 0.0 <= best_d <= 1.0
if not (-1.0 <= best_q <= 1.0 and 0.0 <= best_d <= 1.0):
return False, None
best_q = struct.pack('fff', best_q_w, best_d, best_q_l)

return (planes, probs, winner, best_q)
return True, (planes, probs, winner, best_q)


def sample_record(self, chunkdata):
Expand Down Expand Up @@ -282,7 +335,9 @@ def tuple_gen(self, gen):
applying a random symmetry on the way.
"""
for r in gen:
yield self.convert_v4_to_tuple(r)
success, data = self.convert_v4_to_tuple(r)
if success:
yield data


def batch_gen(self, gen):
Expand Down
39 changes: 36 additions & 3 deletions tf/lc0_az_policy_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def index_to_position(x):
def position_to_index(p):
    """Translate a position into a ``(column, row)`` pair of lookups.

    ``p[0]`` is looked up in the module-level ``col_index`` mapping and
    ``p[1]`` in ``row_index``.

    NOTE(review): ``p`` is presumably a two-character square name built as
    column + row (as ``make_map`` does with ``c0 + r0``) -- confirm against
    the definitions of ``col_index`` and ``row_index``.

    Returns:
        A tuple ``(col_index[p[0]], row_index[p[1]])``.
    """
    return col_index[p[0]], row_index[p[1]]

def flip_position_lr(p):
    """Mirror a position left-to-right, passing ``None`` through unchanged.

    The position is converted to a ``(column, row)`` index pair with
    ``position_to_index``, the column ``c`` is replaced by ``7 - c``, and
    the pair is converted back with ``index_to_position``.

    Args:
        p: A position accepted by ``position_to_index``, or ``None``.
            Callers pass ``None`` when a move generator produced no end
            square.

    Returns:
        The mirrored position, or ``None`` when ``p`` is ``None``.
    """
    # Identity check: ``None`` is a singleton, and ``==`` could be
    # overridden by the operand's type.
    if p is None:
        return None
    c, r = position_to_index(p)
    return index_to_position([7 - c, r])

def valid_index(i):
if i[0] > 7 or i[0] < 0:
return False
Expand Down Expand Up @@ -46,27 +52,40 @@ def knight_move(start, direction, steps):
def make_map(kind='matrix'):
# 56 planes of queen moves
moves = []
flip_moves = []
for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']:
for steps in range(1, 8):
for r0 in rows:
for c0 in columns:
start = c0 + r0
end = queen_move(start, direction, steps)
flip_start = flip_position_lr(start)
flip_end = flip_position_lr(end)
if end == None:
moves.append('illegal')
else:
moves.append(start+end)
if flip_end == None:
flip_moves.append('illegal')
else:
flip_moves.append(flip_start+flip_end)

# 8 planes of knight moves
for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']:
for r0 in rows:
for c0 in columns:
start = c0 + r0
end = knight_move(start, direction, 1)
flip_start = flip_position_lr(start)
flip_end = flip_position_lr(end)
if end == None:
moves.append('illegal')
else:
moves.append(start+end)
if flip_end == None:
flip_moves.append('illegal')
else:
flip_moves.append(flip_start+flip_end)

# 9 promotions
for direction in ['NW', 'N', 'NE']:
Expand All @@ -76,20 +95,30 @@ def make_map(kind='matrix'):
# Promotion only in the second last rank
if r0 != '7':
moves.append('illegal')
flip_moves.append('illegal')
continue
start = c0 + r0
end = queen_move(start, direction, 1)
flip_start = flip_position_lr(start)
flip_end = flip_position_lr(end)
if end == None:
moves.append('illegal')
else:
moves.append(start+end+promotion)
if flip_end == None:
flip_moves.append('illegal')
else:
flip_moves.append(flip_start+flip_end+promotion)

for m in policy_index:
if m not in moves:
raise ValueError('Missing move: {}'.format(m))
if m not in flip_moves:
raise ValueError('Missing move: {}'.format(m))

az_to_lc0 = np.zeros((80*8*8, len(policy_index)), dtype=np.float32)
indices = []
flip_permutation = np.zeros(1858, dtype=np.int32)
legal_moves = 0
for e, m in enumerate(moves):
if m == 'illegal':
Expand All @@ -100,18 +129,22 @@ def make_map(kind='matrix'):
if m not in policy_index:
raise ValueError('Missing move: {}'.format(m))
i = policy_index.index(m)
flip_i = policy_index.index(flip_moves[e])
flip_permutation[i] = flip_i
indices.append(i)
az_to_lc0[e][i] = 1

# Verify that applying flip permutation twice gives back the original policy.
assert np.array_equal(flip_permutation[flip_permutation], list(range(1858)))

assert legal_moves == len(policy_index)
assert np.sum(az_to_lc0) == legal_moves
for e in range(80*8*8):
for i in range(len(policy_index)):
pass
if kind == 'matrix':
return az_to_lc0
elif kind == 'index':
return indices
elif kind == 'flip_permutation':
return flip_permutation

if __name__ == "__main__":
# Generate policy map include file for lc0
Expand Down
7 changes: 5 additions & 2 deletions tf/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def main(cmd):
allow_less = cfg['dataset'].get('allow_less_chunks', False)
train_ratio = cfg['dataset']['train_ratio']
experimental_parser = cfg['dataset'].get('experimental_v4_only_dataset', False)
flip = cfg['dataset']['flip_augmentation']
num_train = int(num_chunks*train_ratio)
num_test = num_chunks - num_train
if 'input_test' in cfg['dataset']:
Expand Down Expand Up @@ -163,7 +164,8 @@ def main(cmd):
.batch(split_batch_size).map(extract_inputs_outputs).prefetch(4)
else:
train_parser = ChunkParser(FileDataSrc(train_chunks),
shuffle_size=shuffle_size, sample=SKIP, batch_size=ChunkParser.BATCH_SIZE)
shuffle_size=shuffle_size, sample=SKIP, batch_size=ChunkParser.BATCH_SIZE,
flip=flip)
train_dataset = tf.data.Dataset.from_generator(
train_parser.parse, output_types=(tf.string, tf.string, tf.string, tf.string))
train_dataset = train_dataset.map(ChunkParser.parse_function)
Expand All @@ -177,7 +179,8 @@ def main(cmd):
.batch(split_batch_size).map(extract_inputs_outputs).prefetch(4)
else:
test_parser = ChunkParser(FileDataSrc(test_chunks),
shuffle_size=shuffle_size, sample=SKIP, batch_size=ChunkParser.BATCH_SIZE)
shuffle_size=shuffle_size, sample=SKIP, batch_size=ChunkParser.BATCH_SIZE,
flip=False)
test_dataset = tf.data.Dataset.from_generator(
test_parser.parse, output_types=(tf.string, tf.string, tf.string, tf.string))
test_dataset = test_dataset.map(ChunkParser.parse_function)
Expand Down