diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py
index 3c7c325021..64dd0e6187 100644
--- a/litgpt/finetune/full.py
+++ b/litgpt/finetune/full.py
@@ -232,7 +232,9 @@ def fit(
     scheduler = state["scheduler"]
     tokenizer = Tokenizer(checkpoint_dir)
     longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset)
-    model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
+    if longlora.use_longlora:
+        longest_seq_length = find_multiple(longest_seq_length, longlora.n_groups)
+    model.max_seq_length = longest_seq_length
     fabric.print(
         f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
         f" {model.max_seq_length} and context length is {model.config.block_size}"
diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py
index 56f131cb2d..bd1f044787 100644
--- a/litgpt/finetune/lora.py
+++ b/litgpt/finetune/lora.py
@@ -302,7 +302,10 @@ def fit(
 ) -> None:
     tokenizer = Tokenizer(checkpoint_dir)
     longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset)
-    model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
+    longest_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
+    if longlora.use_longlora:
+        longest_seq_length = find_multiple(longest_seq_length, longlora.n_groups)
+    model.max_seq_length = longest_seq_length
     fabric.print(
         f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
         f" {model.max_seq_length} and context length is {model.config.block_size}"
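
Note on the rounding step: LongLoRA's shifted sparse attention splits each sequence into longlora.n_groups groups, so the model's maximum sequence length has to be divisible by that group count. The diff rounds the longest training sequence up with find_multiple before assigning model.max_seq_length. Below is a minimal sketch of that behavior, assuming find_multiple has the same semantics as the litgpt.utils helper (smallest multiple of k that is >= n); the concrete numbers are illustrative only.

# Sketch of the rounding applied in the diff above (assumed semantics of
# litgpt.utils.find_multiple: smallest multiple of k greater than or equal to n).
def find_multiple(n: int, k: int) -> int:
    assert k > 0
    if n % k == 0:
        return n
    return n + k - (n % k)

# Hypothetical example: with use_longlora enabled and n_groups = 4, a longest
# training sequence of 1021 tokens would set model.max_seq_length to 1024,
# so the sequence divides evenly into the 4 attention groups.
assert find_multiple(1021, 4) == 1024
assert find_multiple(1024, 4) == 1024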