# DiffTransformer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.
    Applies normalization across the last dimension and scales the output.
    """
    def __init__(self, d, eps=1e-5):
        """
        Args:
            d (int): Dimension of the input features.
            eps (float): Small value to avoid division by zero.
        """
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d))

    def forward(self, x):
        """
        Forward pass for RMSNorm.
        Args:
            x (Tensor): Input tensor of shape (batch, sequence_length, d).
        Returns:
            Tensor: Normalized and scaled tensor.
        """
        norm = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / norm * self.scale
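
# Illustrative sketch (not part of the original module): a quick shape and scale
# check for RMSNorm. The tensor sizes below are arbitrary assumptions.
def _rmsnorm_sanity_check():
    x = torch.randn(2, 16, 64)   # (batch, sequence_length, d)
    y = RMSNorm(64)(x)           # normalization preserves the input shape
    assert y.shape == (2, 16, 64)
    # With the scale initialized to ones, each position has an approximately
    # unit root-mean-square across the feature dimension.
    rms = y.pow(2).mean(dim=-1).sqrt()
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)
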
class SwiGLU(nn.Module):
    """
    SwiGLU Activation Function.
    Combines the Swish activation with Gated Linear Units.
    """
    def __init__(self, d_model):
        """
        Args:
            d_model (int): Dimension of the input features.
        """
        super().__init__()
        # Intermediate projection layers
        # Typically, SwiGLU splits the computation into two parts
        self.WG = nn.Linear(d_model, d_model * 2)
        self.W1 = nn.Linear(d_model, d_model * 2)
        self.W2 = nn.Linear(d_model * 2, d_model)

    def forward(self, x):
        """
        Forward pass for SwiGLU.
        Args:
            x (Tensor): Input tensor of shape (batch, sequence_length, d_model).
        Returns:
            Tensor: Output tensor after applying SwiGLU.
        """
        # Apply the gates
        g = F.silu(self.WG(x))  # Activation part
        z = self.W1(x)          # Linear part
        # Element-wise multiplication and projection
        return self.W2(g * z)
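
# Illustrative sketch (assumed dimensions, not part of the original module): SwiGLU
# maps (batch, seq, d_model) -> (batch, seq, d_model), expanding to 2 * d_model
# internally for the gated and linear branches.
def _swiglu_sanity_check():
    x = torch.randn(2, 16, 64)   # (batch, sequence_length, d_model)
    y = SwiGLU(64)(x)            # gated feed-forward keeps the model dimension
    assert y.shape == (2, 16, 64)
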
class MultiHeadDifferentialAttention(nn.Module):
    """
    Multi-Head Differential Attention Mechanism.
    Replaces the conventional softmax attention with a differential attention.
    Incorporates a causal mask to ensure autoregressive behavior.
    """
    def __init__(self, d_model, num_heads, lambda_init):
        """
        Args:
            d_model (int): Dimension of the model. Must be divisible by num_heads.
            num_heads (int): Number of attention heads.
            lambda_init (float): Initial value for lambda.
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        # Linear projections for queries, keys, and values
        # Project to 2 * d_head per head for differential attention
        self.W_q = nn.Linear(d_model, 2 * self.d_head * num_heads, bias=False)
        self.W_k = nn.Linear(d_model, 2 * self.d_head * num_heads, bias=False)
        self.W_v = nn.Linear(d_model, 2 * self.d_head * num_heads, bias=False)
        self.W_o = nn.Linear(2 * self.d_head * num_heads, d_model, bias=False)

        # Learnable parameters for lambda reparameterization
        self.lambda_q1 = nn.Parameter(torch.randn(num_heads, self.d_head))
        self.lambda_k1 = nn.Parameter(torch.randn(num_heads, self.d_head))
        self.lambda_q2 = nn.Parameter(torch.randn(num_heads, self.d_head))
        self.lambda_k2 = nn.Parameter(torch.randn(num_heads, self.d_head))
        self.lambda_init = lambda_init

        # Scale parameter for RMSNorm
        self.rms_scale = nn.Parameter(torch.ones(2 * self.d_head))
        self.eps = 1e-5  # Epsilon for numerical stability

        # Initialize weights (optional but recommended)
        self._reset_parameters()

    def _reset_parameters(self):
        """
        Initialize parameters for improved training stability.
        """
        nn.init.xavier_uniform_(self.W_q.weight)
        nn.init.xavier_uniform_(self.W_k.weight)
        nn.init.xavier_uniform_(self.W_v.weight)
        nn.init.xavier_uniform_(self.W_o.weight)
        nn.init.constant_(self.rms_scale, 1.0)
    def forward(self, X):
        """
        Forward pass for Multi-Head Differential Attention.
        Args:
            X (Tensor): Input tensor of shape (batch, sequence_length, d_model).
        Returns:
            Tensor: Output tensor after applying differential attention.
        """
        batch, N, d_model = X.shape

        # Project inputs to queries, keys, and values
        Q = self.W_q(X)  # Shape: (batch, N, 2 * num_heads * d_head)
        K = self.W_k(X)  # Shape: (batch, N, 2 * num_heads * d_head)
        V = self.W_v(X)  # Shape: (batch, N, 2 * num_heads * d_head)

        # Reshape and permute for multi-head attention
        # New shape: (batch, num_heads, sequence_length, 2 * d_head)
        Q = Q.view(batch, N, self.num_heads, 2 * self.d_head).transpose(1, 2)
        K = K.view(batch, N, self.num_heads, 2 * self.d_head).transpose(1, 2)
        V = V.view(batch, N, self.num_heads, 2 * self.d_head).transpose(1, 2)

        # Split Q and K into Q1, Q2 and K1, K2
        Q1, Q2 = Q.chunk(2, dim=-1)  # Each of shape: (batch, num_heads, N, d_head)
        K1, K2 = K.chunk(2, dim=-1)  # Each of shape: (batch, num_heads, N, d_head)

        # Compute lambda using reparameterization:
        # lambda_val = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
        # Dot products are taken per head, so lambda_val has shape (num_heads,)
        lambda_q1_dot_k1 = torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()  # (num_heads,)
        lambda_q2_dot_k2 = torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()  # (num_heads,)
        lambda_val = torch.exp(lambda_q1_dot_k1) - torch.exp(lambda_q2_dot_k2) + self.lambda_init  # (num_heads,)

        # Expand lambda_val to match attention dimensions
        # Shape: (1, num_heads, 1, 1), broadcast over batch and positions
        lambda_val = lambda_val.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
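
        # For reference, the remainder of this method computes the differential attention
        #     DiffAttn(X) = (softmax(Q1 K1^T / sqrt(d_head))
        #                    - lambda_val * softmax(Q2 K2^T / sqrt(d_head))) @ V
        # i.e. two causal softmax attention maps whose difference cancels common-mode
        # attention noise, weighted per head by the learned lambda_val.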
        # ------------------- Causal Mask Implementation ------------------- #
        # Create a causal mask to prevent attention to future tokens
        # Shape of mask: (1, 1, N, N)
        mask = torch.tril(torch.ones((N, N), device=X.device)).unsqueeze(0).unsqueeze(0)  # (1, 1, N, N)
        # Replace 1s with 0.0 and 0s with -inf (additive mask)
        mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)
        # ------------------------------------------------------------------- #

        # Compute attention scores
        scaling = 1 / math.sqrt(self.d_head)
        A1 = torch.matmul(Q1, K1.transpose(-2, -1)) * scaling  # (batch, num_heads, N, N)
        A2 = torch.matmul(Q2, K2.transpose(-2, -1)) * scaling  # (batch, num_heads, N, N)

        # Apply the causal mask
        A1 = A1 + mask  # Mask out future positions
        A2 = A2 + mask  # Mask out future positions

        # Apply softmax to get attention weights
        attention1 = F.softmax(A1, dim=-1)  # (batch, num_heads, N, N)
        attention2 = F.softmax(A2, dim=-1)  # (batch, num_heads, N, N)
        attention = attention1 - lambda_val * attention2  # (batch, num_heads, N, N)

        # Apply attention weights to values
        O = torch.matmul(attention, V)  # (batch, num_heads, N, 2 * d_head)

        # Normalize each head independently using RMSNorm
        # First, reshape for RMSNorm
        O_reshaped = O.contiguous().view(batch * self.num_heads, N, 2 * self.d_head)  # (batch*num_heads, N, 2*d_head)

        # Compute RMSNorm
        rms_norm = torch.sqrt(O_reshaped.pow(2).mean(dim=-1, keepdim=True) + self.eps)  # (batch*num_heads, N, 1)
        O_normalized = (O_reshaped / rms_norm) * self.rms_scale  # (batch*num_heads, N, 2*d_head)

        # Reshape back to (batch, num_heads, N, 2 * d_head)
        O_normalized = O_normalized.view(batch, self.num_heads, N, 2 * self.d_head)

        # Scale the normalized output
        O_normalized = O_normalized * (1 - self.lambda_init)  # Scalar scaling

        # Concatenate all heads
        # New shape: (batch, N, num_heads * 2 * d_head)
        O_concat = O_normalized.transpose(1, 2).contiguous().view(batch, N, self.num_heads * 2 * self.d_head)

        # Final linear projection
        out = self.W_o(O_concat)  # (batch, N, d_model)
        return out
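
# Illustrative sketch (hypothetical sizes, not part of the original module): instantiate
# the attention module directly and check that it maps (batch, N, d_model) back to
# (batch, N, d_model).
def _differential_attention_sanity_check():
    attn = MultiHeadDifferentialAttention(d_model=64, num_heads=4, lambda_init=0.8)
    x = torch.randn(2, 16, 64)   # (batch, sequence_length, d_model)
    out = attn(x)                # causal, differential multi-head attention
    assert out.shape == (2, 16, 64)
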
class DiffTransformerLayer(nn.Module):
    """
    Single Layer of the DiffTransformer Architecture.
    Consists of Multi-Head Differential Attention followed by a SwiGLU Feed-Forward Network.
    """
    def __init__(self, d_model, num_heads, lambda_init):
        """
        Args:
            d_model (int): Dimension of the model.
            num_heads (int): Number of attention heads.
            lambda_init (float): Initial value for lambda in Differential Attention.
        """
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = MultiHeadDifferentialAttention(d_model, num_heads, lambda_init)
        self.norm2 = RMSNorm(d_model)
        self.ff = SwiGLU(d_model)

    def forward(self, x):
        """
        Forward pass for a single transformer layer.
        Args:
            x (Tensor): Input tensor of shape (batch, sequence_length, d_model).
        Returns:
            Tensor: Output tensor after processing through the layer.
        """
        # Apply Multi-Head Differential Attention with residual connection
        y = self.attn(self.norm1(x)) + x
        # Apply SwiGLU Feed-Forward Network with residual connection
        z = self.ff(self.norm2(y)) + y
        return z
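
# Illustrative sketch (assumed sizes): a single pre-norm layer preserves the model
# dimension, so layers can be stacked freely inside the full model below.
def _layer_sanity_check():
    layer = DiffTransformerLayer(d_model=64, num_heads=4, lambda_init=0.8)
    x = torch.randn(2, 16, 64)
    y = layer(x)                 # attention + SwiGLU, each with a residual connection
    assert y.shape == x.shape
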
class DiffTransformer(nn.Module):
    """
    The DiffTransformer Model incorporating multiple DiffTransformerLayers.
    Suitable for sequence modeling tasks such as language modeling.
    """
    def __init__(self, vocab_size, d_model, num_heads, num_layers, max_seq_length=512):
        """
        Args:
            vocab_size (int): Size of the vocabulary.
            d_model (int): Dimension of the model.
            num_heads (int): Number of attention heads.
            num_layers (int): Number of transformer layers.
            max_seq_length (int): Maximum sequence length.
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_length, d_model)
        self.layers = nn.ModuleList([
            DiffTransformerLayer(
                d_model=d_model,
                num_heads=num_heads,
                lambda_init=0.8 - 0.6 * math.exp(-0.3 * (l - 1))  # Depth-dependent lambda_init (the exponential term decays with layer index l)
            )
            for l in range(1, num_layers + 1)
        ])
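        # With this schedule, lambda_init grows with depth: layer 1 gets
        # 0.8 - 0.6 * exp(0) = 0.2, while layer 12 gets roughly
        # 0.8 - 0.6 * exp(-3.3) ~= 0.78, so deeper layers start closer to 0.8.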
        self.norm = RMSNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Initialize weights (optional but recommended)
        self._reset_parameters()

    def _reset_parameters(self):
        """
        Initialize parameters for improved training stability.
        """
        nn.init.xavier_uniform_(self.token_emb.weight)
        nn.init.xavier_uniform_(self.pos_emb.weight)
        nn.init.xavier_uniform_(self.head.weight)

    def forward(self, x):
        """
        Forward pass for the DiffTransformer.
        Args:
            x (Tensor): Input tensor of token indices of shape (batch, sequence_length).
        Returns:
            Tensor: Logits for each token in the vocabulary of shape (batch, sequence_length, vocab_size).
        """
        batch, N = x.shape
        positions = torch.arange(N, device=x.device).unsqueeze(0).expand(batch, N)  # (batch, N)
        X = self.token_emb(x) + self.pos_emb(positions)  # (batch, N, d_model)
        for layer in self.layers:
            X = layer(X)
        X = self.norm(X)  # (batch, N, d_model)
        logits = self.head(X)  # (batch, N, vocab_size)
        return logits
# Example usage:
if __name__ == "__main__":
    # Define model hyperparameters
    vocab_size = 30522  # Example vocabulary size (e.g., BERT)
    d_model = 768
    num_heads = 12
    num_layers = 12
    max_seq_length = 512

    # Instantiate the model
    model = DiffTransformer(
        vocab_size=vocab_size,
        d_model=d_model,
        num_heads=num_heads,
        num_layers=num_layers,
        max_seq_length=max_seq_length
    )

    # Move model to device (GPU if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Example input: batch of token indices
    batch_size = 2
    seq_length = 128
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)).to(device)  # (batch, N)

    # Forward pass
    logits = model(input_ids)  # (batch, N, vocab_size)
    print(logits.shape)  # Should output: torch.Size([2, 128, 30522])
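
    # Illustrative extension (not in the original script): a next-token language-modeling
    # loss over the same random batch, using shifted targets. Purely a sketch of how the
    # logits would typically be consumed during training.
    targets = input_ids[:, 1:]              # predict token t+1 from tokens <= t
    shifted_logits = logits[:, :-1, :]      # drop the final position's prediction
    loss = F.cross_entropy(
        shifted_logits.reshape(-1, vocab_size),  # (batch * (N - 1), vocab_size)
        targets.reshape(-1)                      # (batch * (N - 1),)
    )
    print(f"Example cross-entropy loss: {loss.item():.4f}")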