
Commit

Add depthwise convolution and allow to use it in CausalAttention.
lukaszkaiser committed Sep 29, 2021
1 parent 6151599 commit d875d80
Showing 2 changed files with 63 additions and 2 deletions.
16 changes: 14 additions & 2 deletions trax/layers/attention.py
@@ -433,7 +433,7 @@ def ConfigurableAttention(q_layer, k_layer, v_layer, final_layer, # pylint: dis

@assert_shape('bld->bld')
def CausalAttention(d_feature, n_heads=1, dropout=0.0,
-                    max_inference_length=2048, mode='train'):
+                    max_inference_length=2048, use_dconv=False, mode='train'):
"""Returns a layer that maps activations to activations, with causal masking.
Like :py:class:`Attention`, this layer type represents one pass of multi-head
@@ -453,15 +453,27 @@ def CausalAttention(d_feature, n_heads=1, dropout=0.0,
created in ``'train'`` mode.
max_inference_length: Maximum sequence length allowed in non-training
modes.
use_dconv: if True, use depthwise convolutions on top of dense layers
for Q, K and V.
mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
"""
if d_feature % n_heads != 0:
raise ValueError(
f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
f'of the requested number of attention heads ({n_heads}).')

def QKVLayer():
"""Function returning the Q, K and V layer."""
if use_dconv:
return cb.Serial(
core.Dense(d_feature),
convolution.CausalDepthwiseConv()
)
else:
return core.Dense(d_feature)

return ConfigurableAttention(
-      core.Dense(d_feature), core.Dense(d_feature), core.Dense(d_feature),
+      QKVLayer(), QKVLayer(), QKVLayer(),
core.Dense(d_feature), n_heads=n_heads,
qkv_attention_layer=DotProductCausalAttention(
dropout=dropout, max_inference_length=max_inference_length,
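A minimal usage sketch of the new option; the layer sizes and input below are illustrative assumptions, not part of this commit:

import numpy as np
from trax import layers as tl, shapes

# Causal attention whose Q, K and V projections are Dense followed by a
# causal depthwise convolution (enabled by the new use_dconv flag).
attn = tl.CausalAttention(d_feature=64, n_heads=4, use_dconv=True, mode='train')

x = np.zeros((2, 16, 64), dtype=np.float32)  # (batch, length, d_feature)
attn.init(shapes.signature(x))
y = attn(x)                                  # same shape as the input: (2, 16, 64)

With use_dconv=False (the default), Q, K and V remain plain Dense(d_feature) projections, so existing configurations behave as before.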
49 changes: 49 additions & 0 deletions trax/layers/convolution.py
@@ -142,3 +142,52 @@ def Conv1d(filters, kernel_size, stride=1, padding='VALID',
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
use_bias=use_bias)


def _zero_pad(x, pad, axis):
"""Helper for jnp.pad with 0s for single-axis case."""
pad_widths = [(0, 0)] * len(x.shape)
pad_widths[axis] = pad # Padding on axis.
return jnp.pad(x, pad_widths, mode='constant')
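
The pad-then-slice trick that the new layer's forward pass relies on can be illustrated with plain numpy; this is only an editorial example, not part of the commit:

import numpy as np

x = np.arange(6, dtype=np.float32).reshape(1, 3, 2)  # (batch=1, length=3, features=2)
# Pad one step of zeros at the front of the length axis, then drop the last step.
shifted = np.pad(x, ((0, 0), (1, 0), (0, 0)))[:, :-1, :]
# shifted[0] == [[0., 0.], [0., 1.], [2., 3.]]: position t now holds x[t - 1],
# exactly what _zero_pad(x, (1, 0), 1) followed by x[:, :-1, :] computes below.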


# @assert_shape('bld->bld')
class CausalDepthwiseConv(base.Layer):
"""A causal depthwise convolution layer."""

def __init__(self,
kernel_size=3,
kernel_initializer=init.GlorotUniformInitializer(),
use_bfloat16=False):
"""Returns a causal depthwise convolution layer."""
super().__init__(n_in=1, n_out=1)
self._kernel_size = kernel_size
self._kernel_initializer = kernel_initializer
self._use_bfloat16 = use_bfloat16

def forward(self, x):
"""Executes this layer as part of a forward pass through the model.
Args:
x: Tensor of same shape and dtype as the input signature used to
initialize this layer.
Returns:
Tensor of same shape and dtype as the input.
"""
w = self.weights
res = x * w[0, :][None, None, :]
for i in range(1, self._kernel_size):
x = _zero_pad(x, (1, 0), 1)
x = x[:, :-1, :]
res += x * w[i, :][None, None, :]
return res

def init_weights_and_state(self, input_signature):
"""Randomly initializes this layer's weights."""
shape_w = (self._kernel_size, input_signature.shape[-1])
rng_w, _ = fastmath.random.split(self.rng, 2)
w = self._kernel_initializer(shape_w, rng_w)
if self._use_bfloat16:
w = w.astype(jnp.bfloat16)
self.weights = w
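
Putting the pieces together: for kernel_size K the layer computes res[b, t, d] = sum over i in [0, K) of w[i, d] * x[b, t - i, d], with positions before the start of the sequence treated as zeros, so the output at position t depends only on inputs at positions <= t. A short usage sketch with a causality check; shapes and values are illustrative assumptions:

import numpy as np
from trax import shapes
from trax.layers import convolution

layer = convolution.CausalDepthwiseConv(kernel_size=3)
x = np.random.uniform(size=(2, 10, 8)).astype(np.float32)  # (batch, length, features)
layer.init(shapes.signature(x))
y = layer(x)                                                # same shape as x

# Perturbing the input at position 5 must not change outputs at positions < 5.
x2 = x.copy()
x2[:, 5, :] += 1.0
y2 = layer(x2)
assert np.allclose(y[:, :5, :], y2[:, :5, :])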
