forked from huggingface/nanotron
-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Validation step3 #21
Open
kylematoba
wants to merge
18
commits into
main
Choose a base branch
from
validation_step3
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Validation step3 #21
Changes from 5 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
f587ee8
add validation step, take out some steps that I don't think work corr…
kylematoba d0facd6
fix FA, start to make validation step look like train step
kylematoba a821f1e
FA-less ROPE with FA signature
kylematoba d6d89ae
Can do training without FA
kylematoba ffb5fbf
work
kylematoba 5c4c0c6
Fixed wrong lr initialization when loading checkpoints
AleHD ace4d33
Log against iteration_step
kylematoba 8659582
Update src/nanotron/trainer.py
kylematoba b8731e2
Update src/nanotron/trainer.py
kylematoba 0fc0f4e
Merge branch 'validation_step3' of github.com:swiss-ai/nanotron into …
kylematoba f73f33b
Redo saving and loading
kylematoba c4effa9
Merge branch 'load_correct_lr' into validation_step3
kylematoba 36d5819
work
kylematoba 5329310
work
kylematoba 3ead1c7
tweaks
kylematoba 2bc34de
tweaks
kylematoba 1182d75
Merge branch 'validation_step3' of github.com:swiss-ai/nanotron into …
kylematoba 2e74edf
work
kylematoba File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
checkpoints: | ||
checkpoint_interval: 10 | ||
checkpoints_path: checkpoints | ||
checkpoints_path_is_shared_file_system: false | ||
resume_checkpoint_path: null | ||
save_initial_state: false | ||
data_stages: | ||
- data: | ||
dataset: | ||
dataset_overwrite_cache: false | ||
dataset_processing_num_proc_per_process: 1 | ||
hf_dataset_config_name: null | ||
hf_dataset_or_datasets: stas/openwebtext-10k | ||
hf_dataset_splits: train | ||
text_column_name: text | ||
num_loading_workers: 1 | ||
seed: 42 | ||
name: Stable Training Stage | ||
start_training_step: 1 | ||
- data: | ||
dataset: | ||
dataset_overwrite_cache: false | ||
dataset_processing_num_proc_per_process: 1 | ||
hf_dataset_config_name: null | ||
hf_dataset_or_datasets: stas/openwebtext-10k | ||
hf_dataset_splits: train | ||
text_column_name: text | ||
num_loading_workers: 1 | ||
seed: 42 | ||
name: Annealing Phase | ||
start_training_step: 10 | ||
valid_data_stages: | ||
- data: | ||
dataset: | ||
dataset_overwrite_cache: false | ||
dataset_processing_num_proc_per_process: 1 | ||
hf_dataset_config_name: null | ||
hf_dataset_or_datasets: stas/oscar-en-10k | ||
hf_dataset_splits: train | ||
text_column_name: text | ||
num_loading_workers: 1 | ||
seed: 42 | ||
name: Stable Training Stage | ||
start_training_step: 1 | ||
- data: | ||
dataset: | ||
dataset_overwrite_cache: false | ||
dataset_processing_num_proc_per_process: 1 | ||
hf_dataset_config_name: null | ||
hf_dataset_or_datasets: stas/oscar-en-10k | ||
hf_dataset_splits: train | ||
text_column_name: text | ||
num_loading_workers: 1 | ||
seed: 42 | ||
name: Annealing Phase | ||
start_training_step: 8 | ||
general: | ||
benchmark_csv_path: null | ||
consumed_train_samples: null | ||
ignore_sanity_checks: true | ||
project: debug | ||
run: tiny_llama_%date_%jobid | ||
seed: 42 | ||
step: null | ||
lighteval: null | ||
logging: | ||
iteration_step_info_interval: 1 | ||
log_level: info | ||
log_level_replica: info | ||
model: | ||
ddp_bucket_cap_mb: 25 | ||
dtype: float32 | ||
init_method: | ||
std: 0.025 | ||
make_vocab_size_divisible_by: 1 | ||
model_config: | ||
bos_token_id: 1 | ||
eos_token_id: 2 | ||
hidden_act: silu | ||
hidden_size: 16 | ||
initializer_range: 0.02 | ||
intermediate_size: 64 | ||
is_llama_config: true | ||
max_position_embeddings: 256 | ||
num_attention_heads: 4 | ||
num_hidden_layers: 2 | ||
num_key_value_heads: 4 | ||
pad_token_id: null | ||
pretraining_tp: 1 | ||
rms_norm_eps: 1.0e-05 | ||
rope_scaling: null | ||
tie_word_embeddings: true | ||
use_cache: true | ||
vocab_size: 256 | ||
optimizer: | ||
accumulate_grad_in_fp32: true | ||
clip_grad: 1.0 | ||
learning_rate_scheduler: | ||
learning_rate: 0.0003 | ||
lr_decay_starting_step: null | ||
lr_decay_steps: 13 | ||
lr_decay_style: cosine | ||
lr_warmup_steps: 2 | ||
lr_warmup_style: linear | ||
min_decay_lr: 1.0e-05 | ||
optimizer_factory: | ||
adam_beta1: 0.9 | ||
adam_beta2: 0.95 | ||
adam_eps: 1.0e-08 | ||
name: adamW | ||
torch_adam_is_fused: true | ||
weight_decay: 0.01 | ||
zero_stage: 0 | ||
parallelism: | ||
dp: 1 | ||
expert_parallel_size: 1 | ||
pp: 1 | ||
pp_engine: 1f1b | ||
tp: 1 | ||
tp_linear_async_communication: true | ||
tp_mode: REDUCE_SCATTER | ||
profiler: null | ||
tokenizer: | ||
tokenizer_max_length: null | ||
tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel | ||
tokenizer_revision: null | ||
tokens: | ||
batch_accumulation_per_replica: 1 | ||
limit_test_batches: 0 | ||
limit_val_batches: 5 | ||
micro_batch_size: 2 | ||
sequence_length: 256 | ||
train_steps: 200 | ||
val_check_interval: 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import torch | ||
import torchtune | ||
import flash_attn | ||
import flash_attn.layers.rotary | ||
|
||
|
||
class RotaryEmbeddingKyleLikeFA(torch.nn.Module): | ||
""" | ||
Has the same function signature as FA, for interleaved=True and separate q, kv. | ||
seqlen_offset = 0 | ||
Does not operate inplace, but that's fine for how it's used in Nanotron. | ||
""" | ||
def __init__(self, dim: int, base: float): | ||
super().__init__() | ||
self.dim = dim | ||
self.base = float(base) | ||
|
||
self.max_seq_len = None | ||
self.rpe = None | ||
|
||
def forward(self, q, kv): | ||
bs, q_len, n_heads, _ = q.shape | ||
assert self.dim == _ | ||
|
||
assert (bs, q_len, 2, n_heads, self.dim) == kv.shape | ||
|
||
if (self.rpe is None) or (self.max_seq_len != q_len): | ||
self.max_seq_len = q_len | ||
self.rpe = torchtune.modules.RotaryPositionalEmbeddings(dim=self.dim, | ||
max_seq_len=self.max_seq_len, | ||
base=self.base).to(device) | ||
q_out = self.rpe(q) | ||
kv_out = torch.stack((self.rpe(kv[:, :, 0]), kv[:, :, 1]), 2) | ||
return q_out, kv_out | ||
|
||
|
||
|
||
if __name__ == "__main__": | ||
device = torch.device(0) | ||
theta = 10000 | ||
|
||
batch_size = 3 | ||
dim_qk = 4 | ||
q_len = 256 | ||
kv_len = 256 | ||
n_heads = 4 | ||
|
||
max_seq_len = max(q_len, kv_len) | ||
|
||
print(max_seq_len) | ||
|
||
|
||
query_states = torch.rand(batch_size, q_len, n_heads, dim_qk, device=device) | ||
key_value_states = torch.rand(batch_size, kv_len, 2, n_heads, dim_qk, device=device).contiguous() | ||
|
||
|
||
interleaved = True | ||
# interleaved = False | ||
re1 = flash_attn.layers.rotary.RotaryEmbedding(dim=dim_qk, interleaved=interleaved, base=theta).to(device) | ||
re2 = torchtune.modules.RotaryPositionalEmbeddings(dim=dim_qk, max_seq_len=max_seq_len, base=theta).to(device) | ||
re3 = RotaryEmbeddingKyleLikeFA(dim=dim_qk, base=theta).to(device) | ||
|
||
|
||
|
||
print(key_value_states[:, :, 0].shape) | ||
|
||
out2 = re2(query_states) | ||
out3 = re2(key_value_states[:, :, 0]) | ||
# out4 = re2(key_value_states[:, :, 1]) | ||
|
||
out_eq = re3(query_states, kv=key_value_states) | ||
|
||
# torch.testing.assert_close(out2, query_states) | ||
out1 = re1(query_states, kv=key_value_states) | ||
|
||
torch.testing.assert_close(out_eq[0], out1[0]) | ||
torch.testing.assert_close(out_eq[1], out1[1]) | ||
|
||
|
||
# Do this second, since the computation is inplace | ||
torch.testing.assert_close(out1[0], query_states) | ||
|
||
test = torch.stack((out3, key_value_states[:, :, 1]), 2) | ||
torch.testing.assert_close(out1[1], test) | ||
# torch.testing.assert_close(out1[1][:, :, 0], out3) | ||
|
||
|
||
torch.testing.assert_close(out1[0], out2) | ||
|
||
print("done") | ||
|
||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd say this one should either go somewhere to
tests
or be removed.