Validation step3 #21

Open · wants to merge 18 commits into base: main

Conversation

kylematoba
No description provided.

kylematoba requested review from ischlag and AleHD on November 19, 2024 at 19:15
ischlag left a comment:

lgtm

AleHD left a comment:

Looks good to me. I was able to do a minimal test after the dp>1 fix. Have you tried saving and loading checkpoints to make sure the valid dataloader samples are reproducible?
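
For the dataloader question, a toy sanity check along these lines covers the RNG side of it (generic torch only, not nanotron's actual checkpointing machinery; the checkpoint is simulated by saving/restoring the sampler's generator state):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the valid dataloader: restoring the sampler's RNG state should
# reproduce the same shuffled batches, which is what resuming from a checkpoint needs.
dataset = TensorDataset(torch.arange(1000))
generator = torch.Generator().manual_seed(0)
loader = DataLoader(dataset, batch_size=8, shuffle=True, generator=generator)

saved_state = generator.get_state()           # what a checkpoint would store
first_run = [batch[0] for _, batch in zip(range(3), loader)]

generator.set_state(saved_state)              # what loading the checkpoint would restore
second_run = [batch[0] for _, batch in zip(range(3), loader)]

for a, b in zip(first_run, second_run):
    torch.testing.assert_close(a, b)
print("valid dataloader batches reproduced after restoring RNG state")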

Also, I'd vote to add a way to hold out a portion of the train set and use it as validation, instead of necessarily needing to add an additional dataset. Something like the --split argument in Megatron: https://github.com/NVIDIA/Megatron-LM/blob/81fee9b0047fb3ac6001b5e71e4df89fc01b2a1c/megatron/training/arguments.py#L1760-L1764. How hard would it be to implement this?
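
A minimal sketch of the idea (the function name and the split-string format here just mirror Megatron's convention; this is not an existing nanotron API):

def split_indices(num_samples: int, split: str = "969,30,1"):
    """Carve one dataset's index range into train/valid/test, Megatron-style.

    `split` gives comma-separated relative weights; the returned ranges are
    disjoint and together cover range(num_samples).
    """
    weights = [float(w) for w in split.split(",")]
    total = sum(weights)
    cumulative = [sum(weights[: i + 1]) / total for i in range(len(weights))]
    bounds = [0] + [round(c * num_samples) for c in cumulative]
    bounds[-1] = num_samples  # absorb any rounding error in the last split
    return [range(lo, hi) for lo, hi in zip(bounds[:-1], bounds[1:])]


# e.g. split_indices(1000, split="969,30,1") -> ranges of length 969, 30 and 1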

src/nanotron/trainer.py (review comments outdated, resolved)
Comment on lines +1 to +91
import torch
import torchtune
import flash_attn
import flash_attn.layers.rotary


class RotaryEmbeddingKyleLikeFA(torch.nn.Module):
    """
    Has the same function signature as FA, for interleaved=True and separate q, kv,
    with seqlen_offset = 0.
    Does not operate in place, but that's fine for how it's used in Nanotron.
    """

    def __init__(self, dim: int, base: float):
        super().__init__()
        self.dim = dim
        self.base = float(base)

        # The torchtune module is built lazily, once the sequence length is known.
        self.max_seq_len = None
        self.rpe = None

    def forward(self, q, kv):
        bs, q_len, n_heads, head_dim = q.shape
        assert self.dim == head_dim
        assert (bs, q_len, 2, n_heads, self.dim) == kv.shape

        # (Re)build the torchtune RoPE whenever the sequence length changes.
        if (self.rpe is None) or (self.max_seq_len != q_len):
            self.max_seq_len = q_len
            self.rpe = torchtune.modules.RotaryPositionalEmbeddings(
                dim=self.dim, max_seq_len=self.max_seq_len, base=self.base
            ).to(q.device)

        q_out = self.rpe(q)
        # Only the keys (index 0 along dim 2) get rotated; the values pass through.
        kv_out = torch.stack((self.rpe(kv[:, :, 0]), kv[:, :, 1]), 2)
        return q_out, kv_out


if __name__ == "__main__":
    device = torch.device(0)
    theta = 10000

    batch_size = 3
    dim_qk = 4
    q_len = 256
    kv_len = 256
    n_heads = 4

    max_seq_len = max(q_len, kv_len)
    print(max_seq_len)

    query_states = torch.rand(batch_size, q_len, n_heads, dim_qk, device=device)
    key_value_states = torch.rand(batch_size, kv_len, 2, n_heads, dim_qk, device=device).contiguous()

    interleaved = True
    # interleaved = False
    re1 = flash_attn.layers.rotary.RotaryEmbedding(dim=dim_qk, interleaved=interleaved, base=theta).to(device)
    re2 = torchtune.modules.RotaryPositionalEmbeddings(dim=dim_qk, max_seq_len=max_seq_len, base=theta).to(device)
    re3 = RotaryEmbeddingKyleLikeFA(dim=dim_qk, base=theta).to(device)

    print(key_value_states[:, :, 0].shape)

    # torchtune reference outputs (computed before the inplace flash-attn call below).
    out2 = re2(query_states)
    out3 = re2(key_value_states[:, :, 0])
    # out4 = re2(key_value_states[:, :, 1])

    out_eq = re3(query_states, kv=key_value_states)

    # torch.testing.assert_close(out2, query_states)
    out1 = re1(query_states, kv=key_value_states)

    # The wrapper should match flash-attn exactly.
    torch.testing.assert_close(out_eq[0], out1[0])
    torch.testing.assert_close(out_eq[1], out1[1])

    # Do this second, since the flash-attn computation is inplace:
    # after re1, query_states itself has been rotated.
    torch.testing.assert_close(out1[0], query_states)

    test = torch.stack((out3, key_value_states[:, :, 1]), 2)
    torch.testing.assert_close(out1[1], test)
    # torch.testing.assert_close(out1[1][:, :, 0], out3)

    torch.testing.assert_close(out1[0], out2)

    print("done")

I'd say this one should either go somewhere under tests or be removed.
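
If it moves under tests, a pytest-style version could look roughly like this (a sketch only: it assumes CUDA plus the flash-attn and torchtune packages are available, and the test name is made up):

import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA for flash-attn")
def test_torchtune_rotary_matches_flash_attn():
    import flash_attn.layers.rotary
    import torchtune

    device = torch.device("cuda")
    torch.manual_seed(0)
    batch_size, seq_len, n_heads, dim, theta = 3, 256, 4, 4, 10000

    q = torch.rand(batch_size, seq_len, n_heads, dim, device=device)
    kv = torch.rand(batch_size, seq_len, 2, n_heads, dim, device=device)

    fa_rope = flash_attn.layers.rotary.RotaryEmbedding(dim=dim, interleaved=True, base=theta).to(device)
    tt_rope = torchtune.modules.RotaryPositionalEmbeddings(dim=dim, max_seq_len=seq_len, base=theta).to(device)

    # torchtune reference: rotate q and the keys, leave the values untouched.
    q_ref = tt_rope(q)
    kv_ref = torch.stack((tt_rope(kv[:, :, 0]), kv[:, :, 1]), 2)

    # flash-attn applies RoPE in place, so pass clones to keep q and kv intact.
    q_fa, kv_fa = fa_rope(q.clone(), kv=kv.clone())

    torch.testing.assert_close(q_fa, q_ref)
    torch.testing.assert_close(kv_fa, kv_ref)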
