diff --git a/tf/multi_head_relative_attention.py b/tf/multi_head_relative_attention.py new file mode 100644 index 00000000..00eae86f --- /dev/null +++ b/tf/multi_head_relative_attention.py @@ -0,0 +1,505 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras-based attention layer.""" + +import tensorflow.compat.v2 as tf +# pylint: disable=g-classes-have-attributes + +import collections +import math +import string + +import numpy as np +from keras import constraints +from keras import initializers +from keras import regularizers +from keras.engine.base_layer import Layer +from keras.layers import advanced_activations +from keras.layers import core +from keras.layers import einsum_dense +from keras.utils import tf_utils +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import keras_export + + +_CHR_IDX = string.ascii_lowercase + + +def _build_attention_equation(rank, attn_axes): + """Builds einsum equations for the attention computation. + + Query, key, value inputs after projection are expected to have the shape as: + `(bs, , , num_heads, channels)`. + `bs` and `` are treated as ``. + + The attention operations can be generalized: + (1) Query-key dot product: + `(, , num_heads, channels), (, + , num_heads, channels) -> (, + num_heads, , )` + (2) Combination: + `(, num_heads, , ), + (, , num_heads, channels) -> (, + , num_heads, channels)` + + Args: + rank: Rank of query, key, value tensors. + attn_axes: List/tuple of axes, `[-1, rank)`, + that attention will be applied to. + + Returns: + Einsum equations. + """ + target_notation = _CHR_IDX[:rank] + # `batch_dims` includes the head dim. 
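+  # Worked example (illustration only): for a projected query of rank 5 with
+  # shape (batch, height, width, num_heads, channels) and attn_axes=(1, 2),
+  # this helper returns dot_product_equation "afgde,abcde->adbcfg",
+  # combine_equation "adbcfg,afgde->abcde" and attn_scores_rank 6, i.e.
+  # attention scores of shape (batch, num_heads, height, width, height, width).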
+ batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,))) + letter_offset = rank + source_notation = "" + for i in range(rank): + if i in batch_dims or i == rank - 1: + source_notation += target_notation[i] + else: + source_notation += _CHR_IDX[letter_offset] + letter_offset += 1 + + product_notation = "".join([target_notation[i] for i in batch_dims] + + [target_notation[i] for i in attn_axes] + + [source_notation[i] for i in attn_axes]) + dot_product_equation = "%s,%s->%s" % (source_notation, target_notation, + product_notation) + attn_scores_rank = len(product_notation) + combine_equation = "%s,%s->%s" % (product_notation, source_notation, + target_notation) + return dot_product_equation, combine_equation, attn_scores_rank + + +def _build_proj_equation(free_dims, bound_dims, output_dims): + """Builds an einsum equation for projections inside multi-head attention.""" + input_str = "" + kernel_str = "" + output_str = "" + bias_axes = "" + letter_offset = 0 + for i in range(free_dims): + char = _CHR_IDX[i + letter_offset] + input_str += char + output_str += char + + letter_offset += free_dims + for i in range(bound_dims): + char = _CHR_IDX[i + letter_offset] + input_str += char + kernel_str += char + + letter_offset += bound_dims + for i in range(output_dims): + char = _CHR_IDX[i + letter_offset] + kernel_str += char + output_str += char + bias_axes += char + equation = "%s,%s->%s" % (input_str, kernel_str, output_str) + + return equation, bias_axes, len(output_str) + + +def _get_output_shape(output_rank, known_last_dims): + return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims) + + +class MultiHeadRelativeAttention(Layer): + """MultiHeadRelativeAttention layer. + + This is an implementation of multi-headed attention as described in the paper + "Attention is all you Need" (Vaswani et al., 2017). + If `query`, `key,` `value` are the same, then + this is self-attention. Each timestep in `query` attends to the + corresponding sequence in `key`, and returns a fixed-width vector. + + This layer first projects `query`, `key` and `value`. These are + (effectively) a list of tensors of length `num_attention_heads`, where the + corresponding shapes are `(batch_size, , key_dim)`, + `(batch_size, , key_dim)`, + `(batch_size, , value_dim)`. + + Then, the query and key tensors are dot-producted and scaled. These are + softmaxed to obtain attention probabilities. The value tensors are then + interpolated by these probabilities, then concatenated back to a single + tensor. + + Finally, the result tensor with the last dimension as value_dim can take an + linear projection and return. + + Examples: + + Performs 1D cross-attention over two sequence inputs with an attention mask. + Returns the additional attention weights over heads. + + >>> layer = MultiHeadAttention(num_heads=2, key_dim=2) + >>> target = tf.keras.Input(shape=[8, 16]) + >>> source = tf.keras.Input(shape=[4, 16]) + >>> output_tensor, weights = layer(target, source, + ... return_attention_scores=True) + >>> print(output_tensor.shape) + (None, 8, 16) + >>> print(weights.shape) + (None, 2, 8, 4) + + Performs 2D self-attention over a 5D input tensor on axes 2 and 3. + + >>> layer = MultiHeadAttention(num_heads=2, key_dim=2, attention_axes=(2, 3)) + >>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16]) + >>> output_tensor = layer(input_tensor, input_tensor) + >>> print(output_tensor.shape) + (None, 5, 3, 4, 16) + + Args: + num_heads: Number of attention heads. 
+    key_dim: Size of each attention head for query and key.
+    value_dim: Size of each attention head for value.
+    dropout: Dropout probability.
+    use_bias: Boolean, whether the dense layers use bias vectors/matrices.
+    output_shape: The expected shape of an output tensor, besides the batch
+      and sequence dims. If not specified, projects back to the key feature
+      dim.
+    attention_axes: Axes over which the attention is applied. `None` means
+      attention over all axes except batch, heads, and features.
+    kernel_initializer: Initializer for dense layer kernels.
+    bias_initializer: Initializer for dense layer biases.
+    kernel_regularizer: Regularizer for dense layer kernels.
+    bias_regularizer: Regularizer for dense layer biases.
+    activity_regularizer: Regularizer for dense layer activity.
+    kernel_constraint: Constraint for dense layer kernels.
+    bias_constraint: Constraint for dense layer biases.
+
+  Call arguments:
+    query: Query `Tensor` of shape `(B, T, dim)`.
+    value: Value `Tensor` of shape `(B, S, dim)`.
+    key: Optional key `Tensor` of shape `(B, S, dim)`. If not given, will use
+      `value` for both `key` and `value`, which is the most common case.
+    attention_mask: A boolean mask of shape `(B, T, S)` that prevents
+      attention to certain positions. The boolean mask specifies which query
+      elements can attend to which key elements; 1 indicates attention and 0
+      indicates no attention. Broadcasting can happen for the missing batch
+      dimensions and the head dimension.
+    return_attention_scores: A boolean to indicate whether the output should
+      be `(attention_output, attention_scores)` if `True`, or
+      `attention_output` if `False`. Defaults to `False`.
+    training: Python boolean indicating whether the layer should behave in
+      training mode (adding dropout) or in inference mode (no dropout).
+      Defaults to either using the training mode of the parent layer/model,
+      or False (inference) if there is no parent layer.
+
+  Returns:
+    attention_output: The result of the computation, of shape `(B, T, E)`,
+      where `T` is for target sequence shapes and `E` is the query input last
+      dimension if `output_shape` is `None`. Otherwise, the multi-head outputs
+      are projected to the shape specified by `output_shape`.
+    attention_scores: [Optional] multi-head attention coefficients over
+      attention axes.
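+
+  A minimal self-attention sketch for this layer over an 8x8 board with 256
+  channels (illustrative; exact shapes are assumed, not verified against a
+  running model):
+
+  >>> layer = MultiHeadRelativeAttention(num_heads=8, key_dim=32)
+  >>> board = tf.keras.Input(shape=[8, 8, 256])
+  >>> output_tensor = layer(board)
+  >>> print(output_tensor.shape)
+  (None, 8, 8, 256)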
+ """ + + def __init__(self, + num_heads, + key_dim, + value_dim=None, + dropout=0.0, + use_bias=True, + output_shape=None, + attention_axes=None, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs): + super(MultiHeadRelativeAttention, self).__init__(**kwargs) + self._num_heads = num_heads + self._key_dim = key_dim + self._value_dim = value_dim if value_dim else key_dim + self._dropout = dropout + self._use_bias = use_bias + self._output_shape = output_shape + self._kernel_initializer = initializers.get(kernel_initializer) + self._bias_initializer = initializers.get(bias_initializer) + self._kernel_regularizer = regularizers.get(kernel_regularizer) + self._bias_regularizer = regularizers.get(bias_regularizer) + self._kernel_constraint = constraints.get(kernel_constraint) + self._bias_constraint = constraints.get(bias_constraint) + if attention_axes is not None and not isinstance(attention_axes, + collections.abc.Sized): + self._attention_axes = (attention_axes,) + else: + self._attention_axes = attention_axes + self._built_from_signature = False + self._query_shape, self._key_shape, self._value_shape = None, None, None + + def get_config(self): + config = { + "num_heads": self._num_heads, + "key_dim": self._key_dim, + "value_dim": self._value_dim, + "dropout": self._dropout, + "use_bias": self._use_bias, + "output_shape": self._output_shape, + "attention_axes": self._attention_axes, + "kernel_initializer": + initializers.serialize(self._kernel_initializer), + "bias_initializer": + initializers.serialize(self._bias_initializer), + "kernel_regularizer": + regularizers.serialize(self._kernel_regularizer), + "bias_regularizer": + regularizers.serialize(self._bias_regularizer), + "activity_regularizer": + regularizers.serialize(self._activity_regularizer), + "kernel_constraint": + constraints.serialize(self._kernel_constraint), + "bias_constraint": + constraints.serialize(self._bias_constraint), + "query_shape": self._query_shape, + "key_shape": self._key_shape, + "value_shape": self._value_shape, + } + base_config = super(MultiHeadRelativeAttention, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config): + # If the layer has a different build() function from the Keras default, + # we need to trigger the customized build to create weights. + query_shape = config.pop("query_shape") + key_shape = config.pop("key_shape") + value_shape = config.pop("value_shape") + layer = cls(**config) + if None in [query_shape, key_shape, value_shape]: + logging.warning( + "One of dimensions of the input shape is missing. It should have been" + " memorized when the layer was serialized. " + "%s is created without weights.", + str(cls)) + else: + layer._build_from_signature(query_shape, value_shape, key_shape) # pylint: disable=protected-access + return layer + + def _build_from_signature(self, query, value, key=None): + """Builds layers and variables. + + Once the method is called, self._built_from_signature will be set to True. + + Args: + query: Query tensor or TensorShape. + value: Value tensor or TensorShape. + key: Key tensor or TensorShape. 
+ """ + self._built_from_signature = True + if hasattr(query, "shape"): + self._query_shape = tf.TensorShape(query.shape) + else: + self._query_shape = tf.TensorShape(query) + if hasattr(value, "shape"): + self._value_shape = tf.TensorShape(value.shape) + else: + self._value_shape = tf.TensorShape(value) + if key is None: + self._key_shape = self._value_shape + elif hasattr(key, "shape"): + self._key_shape = tf.TensorShape(key.shape) + else: + self._key_shape = tf.TensorShape(key) + + common_kwargs = dict( + kernel_initializer=self._kernel_initializer, + bias_initializer=self._bias_initializer, + kernel_regularizer=self._kernel_regularizer, + bias_regularizer=self._bias_regularizer, + activity_regularizer=self._activity_regularizer, + kernel_constraint=self._kernel_constraint, + bias_constraint=self._bias_constraint) + # Any setup work performed only once should happen in an `init_scope` + # to avoid creating symbolic Tensors that will later pollute any eager + # operations. + with tf_utils.maybe_init_scope(self): + free_dims = self._query_shape.rank - 1 + einsum_equation, bias_axes, output_rank = _build_proj_equation( + free_dims, bound_dims=1, output_dims=2) + self._query_dense = einsum_dense.EinsumDense( + einsum_equation, + output_shape=_get_output_shape(output_rank - 1, + [self._num_heads, self._key_dim*3]), + bias_axes=bias_axes if self._use_bias else None, + name="query", + **common_kwargs) + + # Builds the attention computations for multi-head dot product attention. + # These computations could be wrapped into the keras attention layer once + # it support mult-head einsum computations. + self._build_attention(output_rank) + height = self._query_shape[-3] + width = self._query_shape[-2] + self._height = height + self._width = width + self._relative_bias_table = self.add_weight("rel_bias_table", + shape=[(2 * height - 1) * (2 * width - 1), self._num_heads], + initializer=self._bias_initializer, + regularizer=self._bias_regularizer, + constraint=self._bias_constraint, + trainable=True) + + coords = tf.meshgrid(tf.range(height), tf.range(width), indexing="ij") + coords = tf.reshape(tf.stack(coords), [2, -1]) + relative_coords = coords[:, :, None] - coords[:, None, :] + + relative_coords = tf.add(relative_coords, [[[height - 1]],[[width -1]]]) + relative_coords = tf.multiply(relative_coords, [[[2 * width -1]],[[1]]]) + relative_coords = tf.transpose(relative_coords, [1, 2, 0]) + self._relative_index = tf.reshape(tf.reduce_sum(relative_coords, axis=-1), [height*width,height*width]) + self._output_dense = self._make_output_dense( + free_dims, common_kwargs, "attention_output") + + def _make_output_dense(self, free_dims, common_kwargs, name=None): + """Builds the output projection matrix. + + Args: + free_dims: Number of free dimensions for einsum equation building. + common_kwargs: Common keyword arguments for einsum layer. + name: Name for the projection layer. + + Returns: + Projection layer. 
+ """ + if self._output_shape: + if not isinstance(self._output_shape, collections.abc.Sized): + output_shape = [self._output_shape] + else: + output_shape = self._output_shape + else: + output_shape = [self._query_shape[-1]] + einsum_equation, bias_axes, output_rank = _build_proj_equation( + free_dims, bound_dims=2, output_dims=len(output_shape)) + return einsum_dense.EinsumDense( + einsum_equation, + output_shape=_get_output_shape(output_rank - 1, output_shape), + bias_axes=bias_axes if self._use_bias else None, + name=name, + **common_kwargs) + + def _build_attention(self, rank): + """Builds multi-head dot-product attention computations. + + This function builds attributes necessary for `_compute_attention` to + costomize attention computation to replace the default dot-product + attention. + + Args: + rank: the rank of query, key, value tensors. + """ + if self._attention_axes is None: + self._attention_axes = tuple(range(1, rank - 2)) + else: + self._attention_axes = tuple(self._attention_axes) + self._dot_product_equation, self._combine_equation, attn_scores_rank = ( + _build_attention_equation(rank, attn_axes=self._attention_axes)) + norm_axes = tuple( + range(attn_scores_rank - len(self._attention_axes), attn_scores_rank)) + self._softmax = advanced_activations.Softmax(axis=norm_axes) + self._dropout_layer = core.Dropout(rate=self._dropout) + + def _masked_softmax(self, attention_scores, attention_mask=None): + # Normalize the attention scores to probabilities. + # `attention_scores` = [B, N, T, S] + if attention_mask is not None: + # The expand dim happens starting from the `num_heads` dimension, + # (, num_heads, ) + mask_expansion_axis = -len(self._attention_axes) * 2 - 1 + for _ in range(len(attention_scores.shape) - len(attention_mask.shape)): + attention_mask = tf.expand_dims( + attention_mask, axis=mask_expansion_axis) + return self._softmax(attention_scores, attention_mask) + + def _compute_attention(self, + query, + key, + value, + attention_mask=None, + training=None): + """Applies Dot-product attention with query, key, value tensors. + + This function defines the computation inside `call` with projected + multi-head Q, K, V inputs. Users can override this function for customized + attention implementation. + + Args: + query: Projected query `Tensor` of shape `(B, T, N, key_dim)`. + key: Projected key `Tensor` of shape `(B, T, N, key_dim)`. + value: Projected value `Tensor` of shape `(B, T, N, value_dim)`. + attention_mask: a boolean mask of shape `(B, T, S)`, that prevents + attention to certain positions. + training: Python boolean indicating whether the layer should behave in + training mode (adding dropout) or in inference mode (doing nothing). + + Returns: + attention_output: Multi-headed outputs of attention computation. + attention_scores: Multi-headed attention weights. + """ + # Note: Applying scalar multiply at the smaller end of einsum improves + # XLA performance, but may introduce slight numeric differences in + # the Transformer attention head. + query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim))) + + # Take the dot product between "query" and "key" to get the raw + # attention scores. 
+ attention_scores = tf.einsum(self._dot_product_equation, key, + query) + + relative_bias = tf.gather(self._relative_bias_table, axis=0, indices=self._relative_index) + relative_bias = tf.reshape(tf.transpose(relative_bias, [2, 0, 1]), [self._num_heads, self._height, self._width, self._height, self._width]) + + attention_scores = attention_scores + relative_bias + + attention_scores = self._masked_softmax(attention_scores, attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_scores_dropout = self._dropout_layer( + attention_scores, training=training) + + # `context_layer` = [B, T, N, H] + attention_output = tf.einsum(self._combine_equation, + attention_scores_dropout, value) + return attention_output, attention_scores + + def call(self, + x, + attention_mask=None, + return_attention_scores=False, + training=None): + if not self._built_from_signature: + self._build_from_signature(query=x, value=x, key=x) + + # N = `num_attention_heads` + # H = `size_per_head` + # `query` = [B, T, N ,H] + query, key, value = tf.split(self._query_dense(x), 3, axis=-1) + + attention_output, attention_scores = self._compute_attention( + query, key, value, attention_mask, training) + #attention_output = tf.reshape(attention_output, tf.shape(x)) + attention_output = self._output_dense(attention_output) + + if return_attention_scores: + return attention_output, attention_scores + return attention_output diff --git a/tf/net.py b/tf/net.py index de779ca0..0c148dee 100755 --- a/tf/net.py +++ b/tf/net.py @@ -284,16 +284,55 @@ def moves_left_to_bp(l, w): return d[w].format(n) + def ln_to_bp(w): + w = w.split(':')[0] + d = {'gamma': 'gammas', 'beta': 'betas'} + + return d[w] + + def mhra_to_bp(l, w): + w = w.split(':')[0] + d = {'rel_bias_table': 'rel_bias_table', 'kernel': 'w', 'bias': 'b'} + + prefix = '' + d2 = {'query': 'q', 'key': 'k', 'value': 'v', 'attention_output': 'o'} + if l in d2: + prefix = d2[l] + + return prefix + d[w] + + def ffn_to_bp(w): + w = w.split(':')[0] + d = {'kernel': 'w', 'bias': 'b'} + + return d[w] + + layers = name.split('/') base_layer = layers[0] weights_name = layers[-1] pb_name = None block = None + block2 = None if base_layer == 'input': pb_name = 'input.' + convblock_to_bp(weights_name) elif base_layer == 'policy1': pb_name = 'policy1.' + convblock_to_bp(weights_name) + elif base_layer == 'policyra': + if layers[1] == 'mhra': + if layers[2] == 'ln': + pb_name = 'mhra.ln' + ln_to_bp(weights_name) + else: + pb_name = 'mhra.' + mhra_to_bp(layers[2], weights_name) + elif layers[1] == 'ffn': + if layers[2] == 'ln': + pb_name = 'ffn.ln' + ln_to_bp(weights_name) + elif layers[2] == 'fc1': + pb_name = 'ffn.fc1' + ffn_to_bp(weights_name) + else: + pb_name = 'ffn.fc2' + ffn_to_bp(weights_name) + pb_name = 'policyra.' + pb_name elif base_layer == 'policy': if 'dense' in layers[1]: pb_name = policy_to_bp(weights_name) @@ -317,8 +356,26 @@ def moves_left_to_bp(l, w): pb_name = 'conv2.' + convblock_to_bp(weights_name) elif layers[1] == 'se': pb_name = 'se.' + se_to_bp(layers[-2], weights_name) - - return (pb_name, block) + elif base_layer.startswith('rra'): + block2 = int(base_layer.split('_')[1]) - 1 # 1 indexed + if layers[1] == 'mhra': + if layers[2] == 'ln': + pb_name = 'mhra.ln' + ln_to_bp(weights_name) + else: + pb_name = 'mhra.' 
+ mhra_to_bp(layers[2], weights_name) + elif layers[1] == 'ffn': + if layers[2] == 'ln': + pb_name = 'ffn.ln' + ln_to_bp(weights_name) + elif layers[2] == 'fc1': + pb_name = 'ffn.fc1' + ffn_to_bp(weights_name) + elif layers[2] == 'fc2': + pb_name = 'ffn.fc2' + ffn_to_bp(weights_name) + elif layers[2] == 'keys': + pb_name = 'ffn.keys' + ffn_to_bp(weights_name) + elif layers[2] == 'queries': + pb_name = 'ffn.queries' + ffn_to_bp(weights_name) + + return (pb_name, block, block2) def get_weights_v2(self, names): # `names` is a list of Tensorflow tensor names to get from the protobuf. @@ -334,15 +391,17 @@ def get_weights_v2(self, names): # Renorm variables are not populated. continue - pb_name, block = self.tf_name_to_pb_name(name) + pb_name, block, block2 = self.tf_name_to_pb_name(name) if pb_name is None: raise ValueError( "Don't know where to store weight in protobuf: {}".format( name)) - if block == None: + if block == None and block2 == None: pb_weights = self.pb.weights + elif block == None: + pb_weights = self.pb.weights.rra[block2] else: pb_weights = self.pb.weights.residual[block] @@ -478,15 +537,20 @@ def fill_net_v2(self, all_weights): # 50 move rule is the 110th input, or 109 starting from 0. weights[:, 109, :, :] /= 99 - pb_name, block = self.tf_name_to_pb_name(name) + pb_name, block, block2 = self.tf_name_to_pb_name(name) if pb_name is None: raise ValueError( "Don't know where to store weight in protobuf: {}".format( name)) - if block == None: + if block == None and block2 == None: pb_weights = self.pb.weights + elif block == None: + assert block2 >= 0 + while block2 >= len(self.pb.weights.rra): + self.pb.weights.rra.add() + pb_weights = self.pb.weights.rra[block2] else: assert block >= 0 while block >= len(self.pb.weights.residual): diff --git a/tf/tfprocess.py b/tf/tfprocess.py index bcfc1c0d..224e2909 100644 --- a/tf/tfprocess.py +++ b/tf/tfprocess.py @@ -17,14 +17,17 @@ # along with Leela Zero. If not, see . 
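As a sanity check of the tf_name_to_pb_name mapping added to net.py above (a sketch with hypothetical weight names, assuming the protobuf gains the repeated `rra` field and the POLICY_RA format referenced by this patch):

    'rra_2/mhra/query/kernel:0'    ->  ('mhra.qw', None, 1)
    'rra_2/mhra/rel_bias_table:0'  ->  ('mhra.rel_bias_table', None, 1)

get_weights_v2 and fill_net_v2 then use block2 to index self.pb.weights.rra[block2], growing the list with rra.add() as needed.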
import numpy as np +import math import os import random import tensorflow as tf +import tensorflow_addons as tfa import time import bisect import lc0_az_policy_map import proto.net_pb2 as pb from functools import reduce +from multi_head_relative_attention import MultiHeadRelativeAttention import operator from net import Net @@ -102,6 +105,7 @@ def __init__(self, cfg): # Network structure self.RESIDUAL_FILTERS = self.cfg['model']['filters'] self.RESIDUAL_BLOCKS = self.cfg['model']['residual_blocks'] + self.RRA_BLOCKS = self.cfg['model']['rra_blocks'] self.SE_ratio = self.cfg['model']['se_ratio'] self.policy_channels = self.cfg['model'].get('policy_channels', 32) precision = self.cfg['training'].get('precision', 'single') @@ -131,6 +135,8 @@ def __init__(self, cfg): if policy_head == "classical": self.POLICY_HEAD = pb.NetworkFormat.POLICY_CLASSICAL + elif policy_head == "ra": + self.POLICY_HEAD = pb.NetworkFormat.POLICY_RA elif policy_head == "convolution": self.POLICY_HEAD = pb.NetworkFormat.POLICY_CONVOLUTION else: @@ -257,14 +263,13 @@ def init_net(self): ] self.active_lr = 0.01 - self.optimizer = tf.keras.optimizers.SGD( - learning_rate=lambda: self.active_lr, momentum=0.9, nesterov=True) + self.optimizer = tfa.optimizers.AdamW(weight_decay=lambda: self.active_lr / 5.0, + learning_rate=lambda: self.active_lr) self.orig_optimizer = self.optimizer if self.loss_scale != 1: self.optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer( self.optimizer, self.loss_scale) if self.cfg['training'].get('lookahead_optimizer'): - import tensorflow_addons as tfa self.optimizer = tfa.optimizers.Lookahead(self.optimizer) def correct_policy(target, output): @@ -397,7 +402,7 @@ def moves_left_loss(target, output): reg_term_w = self.cfg['training'].get('reg_term_weight', 1.0) def _lossMix(policy, value, moves_left, reg_term): - return pol_loss_w * policy + val_loss_w * value + moves_loss_w * moves_left + reg_term_w * reg_term + return pol_loss_w * policy + val_loss_w * value + moves_loss_w * moves_left #+ reg_term_w * reg_term self.lossMix = _lossMix @@ -643,10 +648,10 @@ def strategy_process_inner_loop(self, x, y, z, q, m): return metrics, new_grads def apply_grads(self, grads, effective_batch_splits): - grads = [ - g[0] for g in self.orig_optimizer.gradient_aggregator( - zip(grads, self.model.trainable_weights)) - ] + #grads = [ + # g[0] for g in self.orig_optimizer.gradient_aggregator( + # zip(grads, self.model.trainable_weights)) + #] if self.loss_scale != 1: grads = self.optimizer.get_unscaled_gradients(grads) max_grad_norm = self.cfg['training'].get( @@ -654,7 +659,8 @@ def apply_grads(self, grads, effective_batch_splits): grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm) self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights), - experimental_aggregate_gradients=False) + ) + #experimental_aggregate_gradients=False) return grad_norm @tf.function() @@ -1098,7 +1104,37 @@ def residual_block(self, inputs, channels, name): return tf.keras.layers.Activation('relu')(tf.keras.layers.add( [inputs, out2])) + def residual_block_ra(self, flow, channels, name, out_channels=None, last_activation='gelu'): + if out_channels is None: + out_channels = channels + orig_flow = flow + flow = tf.keras.layers.LayerNormalization(scale=False, + name=name + '/mhra/ln')(flow) + flow = orig_flow + MultiHeadRelativeAttention( + channels // 32, + 32, + kernel_regularizer=self.l2reg, + dropout=0.1, + name=name + '/mhra')(flow) + orig_flow = flow + flow = 
tf.keras.layers.LayerNormalization(scale=False, + name=name + '/ffn/ln')(flow) + flow = tf.keras.layers.Dense(channels, + kernel_initializer='glorot_normal', + kernel_regularizer=self.l2reg, + activation='gelu', + name=name + '/ffn/fc1')(flow) + flow = tf.keras.layers.Dense(out_channels, + kernel_initializer='glorot_normal', + kernel_regularizer=self.l2reg, + activation=last_activation, + name=name + '/ffn/fc2')(flow) + if out_channels == channels: + flow = orig_flow + flow + return flow + def construct_net(self, inputs): + inputs = tf.concat([inputs, tf.reshape(tf.eye(64, batch_shape=[tf.shape(inputs)[0]]), [-1, 64, 8, 8])], axis=-3) flow = self.conv_block(inputs, filter_size=3, output_channels=self.RESIDUAL_FILTERS, @@ -1108,9 +1144,40 @@ def construct_net(self, inputs): flow = self.residual_block(flow, self.RESIDUAL_FILTERS, name='residual_{}'.format(i + 1)) - + flow = tf.keras.layers.Permute((2, 3, 1))(flow) + for i in range(self.RRA_BLOCKS): + flow = self.residual_block_ra(flow, + self.RESIDUAL_FILTERS, + name='rra_{}'.format(i + 1)) + perm_flow = flow + flow = tf.keras.layers.Permute((3, 1, 2))(flow) # Policy head - if self.POLICY_HEAD == pb.NetworkFormat.POLICY_CONVOLUTION: + if self.POLICY_HEAD == pb.NetworkFormat.POLICY_RA: + conv_pol = self.conv_block(flow, + filter_size=1, + output_channels=self.RESIDUAL_FILTERS, + name='policy') + conv_pol = tf.keras.layers.Permute((2, 3, 1))(conv_pol) + keys = tf.keras.layers.Dense(self.RESIDUAL_FILTERS, + kernel_initializer='glorot_normal', + kernel_regularizer=self.l2reg, name='policy/keys/dense')(conv_pol) + keys = tf.reshape(keys, [-1, 64, self.RESIDUAL_FILTERS]) + queries = tf.keras.layers.Dense(self.RESIDUAL_FILTERS, + kernel_initializer='glorot_normal', + kernel_regularizer=self.l2reg, name='policy/queries/dense')(conv_pol) + queries = tf.reshape(queries, [-1, 64, self.RESIDUAL_FILTERS]) + queries = tf.multiply(queries, 1.0 / math.sqrt(float(self.RESIDUAL_FILTERS))) + qk = tf.linalg.matmul(keys, queries, transpose_b=True) + qk = tf.reshape(qk, [-1, 8, 8, 64]) + mix = tf.concat([qk, conv_pol], axis=-1) + + h_conv_pol_flat = tf.keras.layers.Flatten()(mix) + h_fc1 = tf.keras.layers.Dense(1858, + kernel_initializer='glorot_normal', + kernel_regularizer=self.l2reg, + bias_regularizer=self.l2reg, + name='policy/dense')(h_conv_pol_flat) + elif self.POLICY_HEAD == pb.NetworkFormat.POLICY_CONVOLUTION: conv_pol = self.conv_block(flow, filter_size=3, output_channels=self.RESIDUAL_FILTERS,
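Each rra block added above is a pre-LN transformer block: LayerNormalization -> MultiHeadRelativeAttention -> residual add, then LayerNormalization -> Dense(gelu) -> Dense -> residual add. For the new POLICY_RA head, a rough shape walk-through (a sketch only; B is the batch dim, C is RESIDUAL_FILTERS, and 1858 is the size of the lc0 policy map):

    conv_pol: 1x1 conv_block (B, C, 8, 8) -> Permute((2, 3, 1)) -> (B, 8, 8, C)
    keys, queries: Dense(C) on conv_pol, reshaped -> (B, 64, C)
    qk = matmul(keys, queries / sqrt(C), transpose_b=True) -> (B, 64, 64) -> reshape (B, 8, 8, 64)
    mix = concat([qk, conv_pol], axis=-1) -> (B, 8, 8, 64 + C)
    Flatten -> Dense(1858) -> policy logits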