-
Notifications
You must be signed in to change notification settings - Fork 0
/
mean_attention_distance.py
191 lines (131 loc) · 7.71 KB
/
mean_attention_distance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# -*- coding: utf-8 -*-
"""mean_attention_distance.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/github/all-things-vits/code-samples/blob/main/probing/mean_attention_distance.ipynb
# Mean attention distance
In this notebook, we implement mean attention distance as shown in [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929).
Thanks to [Simon Kornblith](https://twitter.com/skornblith?lang=en) for helping with `compute_distance_matrix()` utility.
"""
!pip install -q datasets transformers
"""First, let's load an image for demonstration purposes. You can bring in your own image here."""
from datasets import load_dataset
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
image
"""We then implement a utility to load a pretrained ViT model. For this notebook, we use the ["google/vit-base-patch16-224"](google/vit-base-patch16-224) checkpoint. But you can find all the original checkpoints [here](https://huggingface.co/models?search=vit&author=google).
Along with the model, we also load its corresponding image processor which is just handy class for taking care of all the image preprocessing.
"""
from transformers import AutoImageProcessor, ViTForImageClassification
import torch
def load_ckpt(ckpt_id="google/vit-base-patch16-224"):
"""Loads a pretrained model along with its processor class."""
image_processor = AutoImageProcessor.from_pretrained(ckpt_id)
model = ViTForImageClassification.from_pretrained(ckpt_id).eval()
return model, image_processor
"""Next, we write a utility to run inference. This utility will print the predicted top class and will return the attention scores from all the transformer blocks. We need the attention scores to compute mean attention distance."""
from PIL import Image
def perform_inference(image: Image, model: torch.nn.Module, processor):
"""Performs inference given an image, a model, and its processor."""
inputs = processor(image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs, output_attentions=True)
# model predicts one of the 1000 ImageNet classes
predicted_label = outputs.logits.argmax(-1).item()
print(model.config.id2label[predicted_label])
return outputs.attentions
"""With this, we should be able to perform inference and extract the attention scores."""
ckpt_id = "google/vit-base-patch16-224"
model, processor = load_ckpt(ckpt_id)
attention_scores = perform_inference(image, model, processor)
"""We have got attention scores coming from different transformer blocks. We can confirm that by printing the shape of `attention_scores`:"""
for i in range(len(attention_scores)):
print(f"Transformer block {i}: {attention_scores[i].shape}")
"""So, we have attention score matrices from 12 different transformer blocks for a single image. We have 12 attention heads in each transformer blocks. Since the model uses a patch size of 16x16 and an image size of 224x224 we have 196 patches in total. It becomes 197 when we add the CLS token.
Next, we implement the utilities for computing mean attention distance. Refer to the [ViT paper](https://arxiv.org/abs/2010.11929) to learn more about the technical details. However, we also provide extensive comments in the code for easier understanding.
The comments come from [this notebook](https://github.com/sayakpaul/probing-vits/blob/main/notebooks/single-instance-probing.ipynb). We used [this gist](https://gist.github.com/simonster/155894d48aef2bd36bd2dd8267e62391) as a reference for implementing these utilities.
"""
import numpy as np
# For vanilla ViT models, this should be 1. For DeiT models,
# this should be two.
num_cls_tokens = 1
def compute_distance_matrix(patch_size, num_patches, length):
"""Helper function to compute distance matrix."""
distance_matrix = np.zeros((num_patches, num_patches))
for i in range(num_patches):
for j in range(num_patches):
if i == j: # zero distance
continue
xi, yi = (int(i / length)), (i % length)
xj, yj = (int(j / length)), (j % length)
distance_matrix[i, j] = patch_size * np.linalg.norm([xi - xj, yi - yj])
return distance_matrix
def compute_mean_attention_dist(patch_size, attention_weights):
# The attention_weights shape = (batch, num_heads, num_patches, num_patches)
attention_weights = attention_weights[
..., num_cls_tokens:, num_cls_tokens:
] # Removing the CLS token
num_patches = attention_weights.shape[-1]
length = int(np.sqrt(num_patches))
assert length**2 == num_patches, "Num patches is not perfect square"
distance_matrix = compute_distance_matrix(patch_size, num_patches, length)
h, w = distance_matrix.shape
distance_matrix = distance_matrix.reshape((1, 1, h, w))
# The attention_weights along the last axis adds to 1
# this is due to the fact that they are softmax of the raw logits
# summation of the (attention_weights * distance_matrix)
# should result in an average distance per token
mean_distances = attention_weights * distance_matrix
mean_distances = np.sum(
mean_distances, axis=-1
) # sum along last axis to get average distance per token
mean_distances = np.mean(
mean_distances, axis=-1
) # now average across all the tokens
return mean_distances
"""Now, we write a utility to gather the mean attention distances from all the transformer blocks."""
def gather_mads(attention_scores, patch_size: int = 16):
all_mean_distances = {
f"block_{i}_mean_dist": compute_mean_attention_dist(
patch_size=patch_size, attention_weights=attention_weight.numpy()
)
for i, attention_weight in enumerate(attention_scores)
}
return all_mean_distances
"""Putting this utility to test, we get:"""
patch_size = ckpt_id.split("-")[-2]
patch_size = int(patch_size.replace("patch", ""))
all_mads = gather_mads(attention_scores, patch_size)
"""It's time to visualize the distances!"""
import matplotlib.pyplot as plt
def visualize_mads(all_mads, ckpt_id):
num_heads = len(all_mads)
plt.figure(figsize=(6, 6))
for idx in range(len(all_mads)):
mean_distance = all_mads[f"block_{idx}_mean_dist"]
x = [idx] * num_heads
y = mean_distance[0, :]
plt.scatter(x=x, y=y, label=f"attention_head_{idx}")
plt.xlabel("Block index")
plt.ylabel("MAD")
plt.legend(loc="lower right")
plt.title(ckpt_id, fontsize=14)
plt.show()
visualize_mads(all_mads, ckpt_id)
"""So, the observation is that the attention blocks from the lower transformer blocks attend to BOTH local and global patches while the higher ones are mostly global.
This holds with the observations made in the following works:
* [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)
* [Do Vision Transformers See Like Convolutional Neural Networks?](https://arxiv.org/abs/2108.08810)
Let's put out utilities together into a generic utility that lets us load a compatible checkpoint and visualize the attention distances.
"""
def visualize_mads_for_a_ckpt(image, ckpt_id):
# Load model and perform inference.
model, processor = load_ckpt(ckpt_id)
attention_scores = perform_inference(image, model, processor)
# Compute MAD.
all_mads = gather_mads(attention_scores)
# Visualize.
visualize_mads(all_mads, ckpt_id)
visualize_mads_for_a_ckpt(image, ckpt_id="google/vit-base-patch16-224")
"""What about a different checkpoint? Let's try ["google/vit-base-patch16-224-in21k"](https://huggingface.co/google/vit-base-patch16-224-in21k):"""
visualize_mads_for_a_ckpt(image, ckpt_id="google/vit-base-patch16-224-in21k")