Skip to content

Commit

Permalink
Fix setting max_seq_length when using longlora
Browse files (browse the repository at this point in the history)
  • Loading branch information
belerico committed Apr 25, 2024
1 parent 9facaf3 commit 2dfa7a5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
4 changes: 3 additions & 1 deletion litgpt/finetune/full.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -232,7 +232,9 @@ def fit(
scheduler = state["scheduler"]
tokenizer = Tokenizer(checkpoint_dir)
longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset)
-    model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
+    if longlora.use_longlora:
+        longest_seq_length = find_multiple(longest_seq_length, longlora.n_groups)
+    model.max_seq_length = longest_seq_length
fabric.print(
f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
f" {model.max_seq_length} and context length is {model.config.block_size}"
Expand Down
5 changes: 4 additions & 1 deletion litgpt/finetune/lora.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -302,7 +302,10 @@ def fit(
) -> None:
tokenizer = Tokenizer(checkpoint_dir)
longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset)
-    model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
+    longest_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
+    if longlora.use_longlora:
+        longest_seq_length = find_multiple(longest_seq_length, longlora.n_groups)
+    model.max_seq_length = longest_seq_length
fabric.print(
f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
f" {model.max_seq_length} and context length is {model.config.block_size}"
Expand Down

0 comments on commit 2dfa7a5

Please sign in to comment.