Validation step3 #21
base: main
Conversation
…ectly (like dtype != float, requires grad), make flash attention dep optional (still needs to be checked)
lgtm
Looks good to me. I was able to do a minimal test after the dp>1 fix. Have you tried saving and loading checkpoints to make sure the validation dataloader samples are reproducible?
Also, I'd vote to add a way to hold out a portion of the train set and use it as validation, instead of necessarily needing to add an additional dataset. Something like the --split argument in Megatron: https://github.com/NVIDIA/Megatron-LM/blob/81fee9b0047fb3ac6001b5e71e4df89fc01b2a1c/megatron/training/arguments.py#L1760-L1764. How hard would it be to implement this?
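On the checkpoint question: one way to sanity-check it, independent of Nanotron's actual checkpoint format (everything below is a generic PyTorch sketch, not the project's API), is to checkpoint the dataloader's generator state and verify that a resumed loader replays the same batches:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the validation set.
dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))

def make_loader(rng_state=None):
    # A seeded torch.Generator makes the shuffle order deterministic;
    # its state is the only thing that needs to go into the checkpoint.
    g = torch.Generator()
    g.manual_seed(1234)
    if rng_state is not None:
        g.set_state(rng_state)
    return DataLoader(dataset, batch_size=8, shuffle=True, generator=g), g

# First run: consume one epoch, "checkpoint" the generator, then record the next epoch.
loader, g = make_loader()
_ = list(loader)                                  # epoch 1
torch.save({"rng": g.get_state()}, "val_loader_ckpt.pt")
first_run = [batch[0] for batch in loader][:4]    # start of epoch 2

# Resumed run: rebuild the loader from the checkpointed state and replay epoch 2.
ckpt = torch.load("val_loader_ckpt.pt")
loader, _ = make_loader(ckpt["rng"])
second_run = [batch[0] for batch in loader][:4]

for a, b in zip(first_run, second_run):
    torch.testing.assert_close(a, b)
print("validation batches are reproducible across save/load")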
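On the --split idea: at its core it is just a deterministic index split by ratio over a single dataset. A rough sketch of the mechanism (the split_dataset helper and split_ratio argument below are hypothetical, not existing Nanotron or Megatron code):

import torch
from torch.utils.data import Subset, TensorDataset

def split_dataset(dataset, split_ratio="949,50,1", seed=1234):
    """Split one dataset into train/valid/test subsets by a Megatron-style
    comma-separated ratio string (e.g. "949,50,1")."""
    weights = [float(w) for w in split_ratio.split(",")]
    total = sum(weights)
    n = len(dataset)
    sizes = [int(n * w / total) for w in weights]
    sizes[0] += n - sum(sizes)  # give rounding leftovers to the train split

    # Deterministic shuffle so the held-out split is stable across runs and resumes.
    perm = torch.randperm(n, generator=torch.Generator().manual_seed(seed)).tolist()
    splits, start = [], 0
    for size in sizes:
        splits.append(Subset(dataset, perm[start:start + size]))
        start += size
    return splits  # [train, valid, test]

if __name__ == "__main__":
    data = TensorDataset(torch.arange(1000).float())
    train_set, valid_set, test_set = split_dataset(data, split_ratio="949,50,1")
    print(len(train_set), len(valid_set), len(test_set))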
import torch
import torchtune
import flash_attn
import flash_attn.layers.rotary


class RotaryEmbeddingKyleLikeFA(torch.nn.Module):
    """
    Has the same function signature as FA, for interleaved=True and separate q, kv.
    seqlen_offset = 0
    Does not operate inplace, but that's fine for how it's used in Nanotron.
    """
    def __init__(self, dim: int, base: float):
        super().__init__()
        self.dim = dim
        self.base = float(base)

        self.max_seq_len = None
        self.rpe = None

    def forward(self, q, kv):
        bs, q_len, n_heads, head_dim = q.shape
        assert self.dim == head_dim
        assert (bs, q_len, 2, n_heads, self.dim) == kv.shape

        # (Re)build the torchtune rotary embedding if the sequence length changed.
        if (self.rpe is None) or (self.max_seq_len != q_len):
            self.max_seq_len = q_len
            self.rpe = torchtune.modules.RotaryPositionalEmbeddings(
                dim=self.dim, max_seq_len=self.max_seq_len, base=self.base
            ).to(q.device)
        q_out = self.rpe(q)
        # Only the keys (index 0) are rotated; the values (index 1) pass through unchanged.
        kv_out = torch.stack((self.rpe(kv[:, :, 0]), kv[:, :, 1]), 2)
        return q_out, kv_out


if __name__ == "__main__":
    device = torch.device(0)
    theta = 10000

    batch_size = 3
    dim_qk = 4
    q_len = 256
    kv_len = 256
    n_heads = 4

    max_seq_len = max(q_len, kv_len)
    print(max_seq_len)

    query_states = torch.rand(batch_size, q_len, n_heads, dim_qk, device=device)
    key_value_states = torch.rand(batch_size, kv_len, 2, n_heads, dim_qk, device=device).contiguous()

    interleaved = True
    # interleaved = False
    re1 = flash_attn.layers.rotary.RotaryEmbedding(dim=dim_qk, interleaved=interleaved, base=theta).to(device)
    re2 = torchtune.modules.RotaryPositionalEmbeddings(dim=dim_qk, max_seq_len=max_seq_len, base=theta).to(device)
    re3 = RotaryEmbeddingKyleLikeFA(dim=dim_qk, base=theta).to(device)

    print(key_value_states[:, :, 0].shape)

    out2 = re2(query_states)
    out3 = re2(key_value_states[:, :, 0])
    # out4 = re2(key_value_states[:, :, 1])

    out_eq = re3(query_states, kv=key_value_states)

    # torch.testing.assert_close(out2, query_states)
    out1 = re1(query_states, kv=key_value_states)

    torch.testing.assert_close(out_eq[0], out1[0])
    torch.testing.assert_close(out_eq[1], out1[1])

    # Do this second, since flash-attn's rotary rotates query_states in place.
    torch.testing.assert_close(out1[0], query_states)

    test = torch.stack((out3, key_value_states[:, :, 1]), 2)
    torch.testing.assert_close(out1[1], test)
    # torch.testing.assert_close(out1[1][:, :, 0], out3)

    torch.testing.assert_close(out1[0], out2)

    print("done")
I'd say this one should either go somewhere into the tests or be removed.
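If it does move into the tests, the script above could shrink to a single pytest case along these lines (a sketch; the test name, the importorskip guards, and the CUDA skip condition are assumptions rather than existing project conventions):

import pytest
import torch

# Skip cleanly if the optional dependencies are not installed.
torchtune_modules = pytest.importorskip("torchtune.modules")
flash_attn_rotary = pytest.importorskip("flash_attn.layers.rotary")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="flash-attn rotary kernels require CUDA")
def test_torchtune_rope_matches_flash_attn_rope():
    device = torch.device(0)
    theta = 10000
    batch_size, seq_len, n_heads, dim_qk = 3, 256, 4, 4

    q = torch.rand(batch_size, seq_len, n_heads, dim_qk, device=device)
    kv = torch.rand(batch_size, seq_len, 2, n_heads, dim_qk, device=device).contiguous()

    # Reference rotation with torchtune (clone first: flash-attn rotates q/kv in place below).
    rope_tt = torchtune_modules.RotaryPositionalEmbeddings(dim=dim_qk, max_seq_len=seq_len, base=theta).to(device)
    expected_q = rope_tt(q.clone())
    expected_k = rope_tt(kv[:, :, 0].clone())

    rope_fa = flash_attn_rotary.RotaryEmbedding(dim=dim_qk, interleaved=True, base=theta).to(device)
    out_q, out_kv = rope_fa(q, kv=kv)

    torch.testing.assert_close(out_q, expected_q)
    torch.testing.assert_close(out_kv[:, :, 0], expected_k)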
Co-authored-by: AleHC <[email protected]>