
Enable depth_k != depth_v in local_attention_2d and masked_local_attention_2d #1899

Open
wants to merge 1 commit into master

Conversation

sgrigory

This PR resolves a TODO in tensor2tensor/layers/common_attention_test.py: enable depth_v != depth_k for common_attention.local_attention_2d and common_attention.masked_local_attention_2d.

The modification is simple: alter the shape passed to scatter_blocks_2d when generating the output, so that its last dimension is taken from v rather than from q.
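
For context, here is a minimal standalone illustration (with made-up shapes, not the actual library code) of why the scatter shape has to end in v's depth: the blocked attention output is the attention weights multiplied by v, so its last dimension is already depth_v rather than depth_k.

import tensorflow as tf

# Hypothetical block shapes, chosen only to illustrate the point
batch, heads, blocks, q_len, kv_len, depth_k, depth_v = 2, 4, 6, 16, 49, 16, 30

q = tf.random_normal([batch, heads, blocks, q_len, depth_k], dtype=tf.float64)
k = tf.random_normal([batch, heads, blocks, kv_len, depth_k], dtype=tf.float64)
v = tf.random_normal([batch, heads, blocks, kv_len, depth_v], dtype=tf.float64)

weights = tf.nn.softmax(tf.matmul(q, k, transpose_b=True))  # [..., q_len, kv_len]
output = tf.matmul(weights, v)                              # [..., q_len, depth_v]

print(output.shape)  # (2, 4, 6, 16, 30): the last dimension comes from v,
                     # so the shape used when scattering the output back into
                     # place must also take its last dimension from v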

Tests:

  • pytest tensor2tensor/layers/common_attention_test.py tensor2tensor/layers/common_image_attention_test.py passes
  • Sanity check: if v is split along the depth dimension, then applying local_attention_2d to each of the two parts separately and concatenating the results should be equivalent to applying local_attention_2d with the original v; see the code below
Sanity check code and output:
import numpy as np

import tensorflow as tf

from tensor2tensor.layers.common_attention import local_attention_2d

# Check that attention is commutative with splitting/concatenating v along the depth dimension
# Try various split points
for split_idx in range(1, 30):

    batch, heads, length, depth_k, depth_v, query_shape = 3, 4, 25, 16, 30, (4, 4)

    q = tf.random_normal([batch, heads, length, length, depth_k], dtype=tf.float64)
    k = tf.random_normal([batch, heads, length, length, depth_k], dtype=tf.float64)
    v = tf.random_normal([batch, heads, length, length, depth_v], dtype=tf.float64)
    
    # Apply attention with the first part of v
    output_part1 = local_attention_2d(
        q,
        k,
        v[:, :, :, :, :split_idx],
        query_shape=query_shape,
        memory_flange=(3, 3))
    
    # Apply attention with the second part of v
    output_part2 = local_attention_2d(
        q,
        k,
        v[:, :, :, :, split_idx:],
        query_shape=query_shape,
        memory_flange=(3, 3))
    
    # Put together results of two parts
    output_concat = tf.concat([output_part1, output_part2], axis=4)
    
    # Apply attention with the original v
    output_full = local_attention_2d(
        q,
        k,
        v,
        query_shape=query_shape,
        memory_flange=(3, 3))
    
    # Compute the difference - should be small
    with tf.Session() as sess:
        res_diff = (output_concat - output_full).eval()
        print(np.abs(res_diff).max())

Output (maximum absolute difference for each split point):
2.220446049250313e-15
2.6645352591003757e-15
2.220446049250313e-15
1.7763568394002505e-15
2.220446049250313e-15
2.220446049250313e-15
2.220446049250313e-15
0.0
2.6645352591003757e-15
1.7763568394002505e-15
2.220446049250313e-15
2.220446049250313e-15
2.6645352591003757e-15
2.220446049250313e-15
2.220446049250313e-15
0.0
2.6645352591003757e-15
2.220446049250313e-15
2.4424906541753444e-15
2.6645352591003757e-15
3.1086244689504383e-15
2.220446049250313e-15
2.220446049250313e-15
0.0
1.3322676295501878e-15
0.0
2.220446049250313e-15
0.0
1.7763568394002505e-15

@google-cla

google-cla bot commented Oct 30, 2021

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.



@google-cla google-cla bot added the cla: no (PR author has not signed CLA) label Oct 30, 2021
@sgrigory
Author

@googlebot I signed it!

@google-cla google-cla bot added the cla: yes (PR author has signed CLA) label and removed the cla: no (PR author has not signed CLA) label Oct 30, 2021