I have a question about the rotary positional encoding part of the code.
Your code:

```python
def rotate_as_if_first(x, rotary_emb):
    # x: [bs, num_attention_heads, seq_len, head_size]
    # apply rotary as if all elements were first in the sequence
    cos, sin = rotary_emb(x, x.shape[-2])
    return rotate_one(x, cos, sin, torch.zeros(x.shape[0], x.shape[-2], dtype=torch.long, device=cos.device))
```
Should it be like this instead:

```python
def rotate_as_if_first(x, rotary_emb, position_ids):
    # x: [bs, num_attention_heads, seq_len, head_size]
    # apply rotary as if all elements were first in the sequence
    cos, sin = rotary_emb(x, x.shape[-2])
    return rotate_one(x, cos, sin, position_ids)
```
When `rotate_as_if_first` calls `rotate_one`, shouldn't the `position_ids` parameter be passed in instead of generating the positions with `torch.zeros(x.shape[0], x.shape[-2], dtype=torch.long, device=cos.device)`?
Hi, thanks for the question! We always treat the memory keys as if they have position 0. Position ids inside the local context are converted to be in the range [0, 2047] here.
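For intuition, here is a minimal sketch of such a conversion (an illustrative assumption, not the exact code linked above; the helper name and window size are made up for this example). It shifts the absolute positions of the current window down so that they start at 0:

```python
import torch

LOCAL_CONTEXT = 2048  # assumed local context length

def to_local_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # position_ids: [bs, seq_len] holding absolute positions, e.g. 2048..4095
    # shift each window so its positions fall into 0..LOCAL_CONTEXT - 1
    return position_ids - position_ids.min(dim=-1, keepdim=True).values

# example: the second context window
abs_ids = torch.arange(2048, 4096).unsqueeze(0)  # shape [1, 2048]
local_ids = to_local_position_ids(abs_ids)       # values 0..2047
```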
More context:
Memory layers use positional encodings for the local context in the standard way, whereas the memory keys are encoded as if they were at the beginning of the local context.
In other words, let $$t_0, t_1, t_2, t_3, \ldots t_{2047}, t_{2048}, \ldots, t_{4095}, \ldots$$
be some input.
LongLLaMA will process it in context windows. First, it will process $$t_0, t_1, t_2, t_3, \ldots t_{2047}$$
and move the (key, value) pairs from memory layers to the memory cache. The local context part ($t_0, \ldots, t_{2047}$) uses $2048$ rotary positional encodings.
Then LongLLaMA will process $$t_{2048}, \ldots, t_{4095}$$
Here again, the local context part ($t_{2048}, \ldots, t_{4095}$) uses the same $2048$ rotary positional encodings as the previous local context ($t_0, \ldots, t_{2047}$).
Memory layers see the previous embeddings (keys and values corresponding to $t_0, \ldots, t_{2047}$), but as if they were located at the same position as $t_{2048}$ (which is position 0 after the conversion).
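This is also why `rotate_as_if_first` passes zeros rather than `position_ids`: a rotary embedding at position 0 applies no rotation at all ($\cos 0 = 1$, $\sin 0 = 0$), so the memory keys are stored unrotated and the relative rotation seen by a query depends only on the query's own local position. Below is a small self-contained check using simplified stand-in helpers (these are not the repository's `rotary_emb`/`rotate_one`, just a standard RoPE sketch):

```python
import torch

def rotary_angles(head_size: int, positions: torch.Tensor):
    # standard RoPE frequencies; simplified stand-in for the model's rotary_emb
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_size, 2).float() / head_size))
    freqs = positions.float()[:, None] * inv_freq[None, :]  # [seq, head_size/2]
    emb = torch.cat([freqs, freqs], dim=-1)                 # [seq, head_size]
    return emb.cos(), emb.sin()

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    return x * cos + rotate_half(x) * sin

head_size, seq_len = 8, 4
k = torch.randn(seq_len, head_size)

# "as if first": every key gets position 0 -> cos = 1, sin = 0 -> identity rotation
cos0, sin0 = rotary_angles(head_size, torch.zeros(seq_len, dtype=torch.long))
assert torch.allclose(apply_rotary(k, cos0, sin0), k)
```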