From 2ac103877bf2332b6451aaf4efb913bd599bca71 Mon Sep 17 00:00:00 2001 From: belerico Date: Wed, 24 Apr 2024 19:11:39 +0200 Subject: [PATCH 1/7] Add LongLora --- litgpt/adapter_v2.py | 3 + litgpt/args.py | 10 +-- litgpt/config.py | 5 +- litgpt/data/alpaca.py | 16 +++- litgpt/data/base.py | 35 ++++++++- litgpt/data/deita.py | 16 +++- litgpt/data/dolly.py | 2 +- litgpt/data/flan.py | 12 ++- litgpt/data/json_data.py | 16 +++- litgpt/data/lima.py | 16 +++- litgpt/data/lit_data.py | 8 +- litgpt/data/longform.py | 12 ++- litgpt/data/openwebtext.py | 8 +- litgpt/data/text_files.py | 11 ++- litgpt/data/tinyllama.py | 8 +- litgpt/data/tinystories.py | 10 ++- litgpt/deploy/serve.py | 21 ++--- litgpt/eval/evaluate.py | 7 +- litgpt/finetune/full.py | 53 +++++++++++-- litgpt/finetune/lora.py | 132 ++++++++++++++++++++++++++++---- litgpt/generate/base.py | 18 +++++ litgpt/generate/full.py | 18 +++++ litgpt/generate/sequentially.py | 18 +++++ litgpt/generate/tp.py | 18 +++++ litgpt/lora.py | 26 +++++-- litgpt/model.py | 39 +++++++++- 26 files changed, 462 insertions(+), 76 deletions(-) diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index 665527f053..4cda0d5167 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -168,6 +168,9 @@ def __init__(self, config: Config, block_idx: int) -> None: self.config = config + # LongLora + self._longlora_available = self.config.longlora_n_groups is not None and self.config.longlora_n_groups > 0 + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = { diff --git a/litgpt/args.py b/litgpt/args.py index b227ffe3f6..f3e04a283a 100644 --- a/litgpt/args.py +++ b/litgpt/args.py @@ -40,13 +40,9 @@ class TrainArgs: max_norm: Optional[float] = None min_lr: float = 6e-5 - def __post_init__(self) -> None: - if self.lr_warmup_fraction and self.lr_warmup_steps: - raise ValueError( - "Can't provide both `--train.lr_warmup_fraction` and `--train.lr_warmup_steps`. Choose one." 
- ) - if self.lr_warmup_fraction and not (0 <= self.lr_warmup_fraction <= 1): - raise ValueError("`--train.lr_warmup_fraction` must be between 0 and 1.") + # Misc args + get_longest_seq_length: bool = True + """Whether to compute the longest sequence length in the dataset""" def gradient_accumulation_iters(self, devices: int) -> int: """Number of iterations between gradient synchronizations""" diff --git a/litgpt/config.py b/litgpt/config.py index e03fa8ae34..ce9cb37407 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -61,6 +61,9 @@ class Config: rope_base: int = 10000 n_expert: int = 0 n_expert_per_token: int = 0 + longlora_n_groups: Optional[int] = None + longlora_context_length: Optional[int] = None + longlora_trainable_params: str = "" def __post_init__(self): if not self.name: @@ -836,7 +839,7 @@ def norm_class(self) -> Type: copy["name"] = c["name"].format(kind) copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) configs.append(copy) - + ############### # Meta LLaMA 3 diff --git a/litgpt/data/alpaca.py b/litgpt/data/alpaca.py index fc3d973848..05e3a62422 100644 --- a/litgpt/data/alpaca.py +++ b/litgpt/data/alpaca.py @@ -43,6 +43,7 @@ class Alpaca(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) @@ -51,11 +52,16 @@ def __post_init__(self) -> None: self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: self.download_dir.mkdir(parents=True, exist_ok=True) @@ -97,7 +103,9 @@ def train_dataloader(self) -> DataLoader: shuffle=True, generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) def val_dataloader(self) -> DataLoader: @@ -106,7 +114,9 @@ def val_dataloader(self) -> DataLoader: batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) diff --git a/litgpt/data/base.py b/litgpt/data/base.py index 36ef33fb8a..b4d34aeab9 100644 --- a/litgpt/data/base.py +++ b/litgpt/data/base.py @@ -10,6 +10,7 @@ from litgpt import Tokenizer from litgpt.prompts import PromptStyle +from litgpt.utils import find_multiple class DataModule(LightningDataModule): @@ -17,7 +18,7 @@ class DataModule(LightningDataModule): @abstractmethod def connect( - self, tokenizer: Optional[Tokenizer] = 
None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None, **kwargs ) -> None: """All settings that can't be determined at the time of instantiation need to be passed through here before any dataloaders can be accessed. @@ -44,6 +45,7 @@ class SFTDataset(Dataset): ignore_index: The index to use for elements to be ignored in the label. transform: An optional transform to apply to the sample before it gets tokenized. Use this to rename the keys in the dataset to the expected 'instruction' and 'output' keys. + pad_multiple_of: If set, sequences will be padded to a multiple of 'pad_multiple_of'. Returns a dict with two keys: input_ids: The encoded prompt + response @@ -93,18 +95,30 @@ def __getitem__(self, idx: int) -> Dict[str, Tensor]: return {"input_ids": encoded_prompt_and_response.type(torch.int64), "labels": labels.type(torch.int64)} -def get_sft_collate_fn(max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100): +def get_sft_collate_fn( + max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100, pad_multiple_of: Optional[int] = None +): """Returns the collate function for supervised finetuning (needed in the DataLoader). The collate function gets a list of dicts with keys `input_ids` and `labels`. It returns a dict with batched `input_ids` and `labels`. Also pads short sequences to the longest element in the batch. Optionally truncates all sequences to the specified maximum length. """ - return partial(_sft_collate_fn, max_seq_length=max_seq_length, pad_id=pad_id, ignore_index=ignore_index) + return partial( + _sft_collate_fn, + max_seq_length=max_seq_length, + pad_id=pad_id, + ignore_index=ignore_index, + pad_multiple_of=pad_multiple_of, + ) def _sft_collate_fn( - samples: List[Dict[str, Tensor]], max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100 + samples: List[Dict[str, Tensor]], + max_seq_length: int = -1, + pad_id: int = 0, + ignore_index: int = -100, + pad_multiple_of: Optional[int] = None, ) -> Dict[str, Tensor]: batched = {} @@ -116,6 +130,19 @@ def _sft_collate_fn( [sample[key] for sample in samples], batch_first=True, padding_value=pad_value ) + # Pad to multiple of 'pad_multiple_of' + if pad_multiple_of is not None and pad_multiple_of > 1: + pad_to = find_multiple(batched[key].shape[1], pad_multiple_of) + pad_to_add = pad_to - batched[key].shape[1] + if pad_to_add > 0: + batched[key] = torch.cat( + ( + batched[key], + torch.full((batched[key].shape[0], pad_to_add, *batched[key].shape[2:]), fill_value=pad_value), + ), + dim=1, + ) + # Truncate if needed if max_seq_length > 0: batched[key] = batched[key][:, :max_seq_length] diff --git a/litgpt/data/deita.py b/litgpt/data/deita.py index c0e52d24f0..b310e3de54 100644 --- a/litgpt/data/deita.py +++ b/litgpt/data/deita.py @@ -36,6 +36,7 @@ class Deita(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) @@ -44,11 +45,16 @@ def __post_init__(self) -> None: self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: 
int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: from datasets import load_dataset @@ -86,7 +92,9 @@ def train_dataloader(self) -> DataLoader: shuffle=True, generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) def val_dataloader(self) -> DataLoader: @@ -95,7 +103,9 @@ def val_dataloader(self) -> DataLoader: batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) diff --git a/litgpt/data/dolly.py b/litgpt/data/dolly.py index 03d973f9b2..891507e86d 100644 --- a/litgpt/data/dolly.py +++ b/litgpt/data/dolly.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass, field from pathlib import Path -from typing import Union +from typing import Optional, Union import torch from torch.utils.data import random_split diff --git a/litgpt/data/flan.py b/litgpt/data/flan.py index a2a5b443ac..dd2ee492e5 100644 --- a/litgpt/data/flan.py +++ b/litgpt/data/flan.py @@ -42,6 +42,7 @@ class FLAN(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) @@ -59,11 +60,16 @@ def __post_init__(self): self.subsets = list(supported_subsets) def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: self.download_dir.mkdir(parents=True, exist_ok=True) @@ -100,7 +106,9 @@ def _dataloader(self, split: str) -> DataLoader: shuffle=(split == "train"), generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) diff --git a/litgpt/data/json_data.py b/litgpt/data/json_data.py index a40096486d..7d2878a599 100644 --- a/litgpt/data/json_data.py +++ b/litgpt/data/json_data.py @@ -38,6 +38,7 @@ class JSON(DataModule): tokenizer: Optional[Tokenizer] = 
field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) val_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) @@ -61,11 +62,16 @@ def __post_init__(self): self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length + self.pad_multiple_of = pad_multiple_of def setup(self, stage: str = "") -> None: train_data, test_data = self.get_splits() @@ -94,7 +100,9 @@ def train_dataloader(self) -> DataLoader: shuffle=True, generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) def val_dataloader(self) -> DataLoader: @@ -103,7 +111,9 @@ def val_dataloader(self) -> DataLoader: batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) def get_splits(self) -> Tuple: diff --git a/litgpt/data/lima.py b/litgpt/data/lima.py index 8ea3db5ebd..75a2b8d0a2 100644 --- a/litgpt/data/lima.py +++ b/litgpt/data/lima.py @@ -39,6 +39,7 @@ class LIMA(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) @@ -53,11 +54,16 @@ def __post_init__(self): self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: from datasets import load_dataset @@ -102,7 +108,9 @@ def train_dataloader(self) -> DataLoader: shuffle=True, generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + 
), ) def val_dataloader(self) -> DataLoader: return DataLoader( self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) diff --git a/litgpt/data/lit_data.py b/litgpt/data/lit_data.py index 1350c16dbd..07ba98841e 100644 --- a/litgpt/data/lit_data.py +++ b/litgpt/data/lit_data.py @@ -29,16 +29,22 @@ class LitData(DataModule): batch_size: int = field(init=False, repr=False, default=1) seq_length: int = field(init=False, repr=False, default=2048) + pad_multiple_of: Optional[int] = field(init=False, repr=False, default=None) def __post_init__(self) -> None: if self.split_names is not None and len(self.split_names) != 2: raise ValueError("If provided `split_names` must be a tuple of two strings, for example: ('train', 'val').") def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.batch_size = batch_size self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well + self.pad_multiple_of = pad_multiple_of def train_dataloader(self) -> DataLoader: input_dir = os.path.join(self.data_path, self.split_names[0]) if self.split_names else str(self.data_path) diff --git a/litgpt/data/longform.py b/litgpt/data/longform.py index 34fcd29906..fb02ca06fe 100644 --- a/litgpt/data/longform.py +++ b/litgpt/data/longform.py @@ -36,6 +36,7 @@ class LongForm(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) @@ -44,11 +45,16 @@ def __post_init__(self) -> None: self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: self.download_dir.mkdir(parents=True, exist_ok=True) @@ -80,7 +86,9 @@ def _dataloader(self, split: str) -> DataLoader: shuffle=(split == "train"), generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) diff --git a/litgpt/data/openwebtext.py b/litgpt/data/openwebtext.py index c6cc3151b3..6e5e3f1482 100644 --- a/litgpt/data/openwebtext.py +++ b/litgpt/data/openwebtext.py @@ -28,6 +28,7 @@ class
OpenWebText(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, repr=False, init=False) batch_size: int = field(default=1, repr=False, init=False) seq_length: int = field(default=2048, repr=False, init=False) + pad_multiple_of: Optional[int] = field(default=None, repr=False, init=False) def __post_init__(self) -> None: # Could be a remote path (s3://) or a local path @@ -35,11 +36,16 @@ def __post_init__(self) -> None: self.data_path_val = str(self.data_path).rstrip("/") + "/val" def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = 2048 + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = 2048, + pad_multiple_of: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: from datasets import Dataset, load_dataset diff --git a/litgpt/data/text_files.py b/litgpt/data/text_files.py index 5989937669..8a954f109f 100644 --- a/litgpt/data/text_files.py +++ b/litgpt/data/text_files.py @@ -21,6 +21,7 @@ class TextFiles(DataModule): and provides training and validation dataloaders that return batches of tokens. Every sample is set to a fixed length. """ + train_data_path: Path """The path to the data directory used for training that contains .txt files""" val_data_path: Optional[Path] = None @@ -35,6 +36,7 @@ class TextFiles(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) def __post_init__(self) -> None: self.out_path_train = self.train_data_path / "train" @@ -43,10 +45,17 @@ def __post_init__(self) -> None: else: self.out_path_val = Path(self.val_data_path) / "val" - def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None: + def connect( + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: int = -1, + pad_multiple_of: Optional[int] = None, + ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: from litdata import optimize diff --git a/litgpt/data/tinyllama.py b/litgpt/data/tinyllama.py index 0f32507aa2..7f5afedc3b 100644 --- a/litgpt/data/tinyllama.py +++ b/litgpt/data/tinyllama.py @@ -27,6 +27,7 @@ class TinyLlama(DataModule): batch_size: int = field(init=False, repr=False, default=1) seq_length: int = field(init=False, repr=False, default=2048) + pad_multiple_of: Optional[int] = field(init=False, repr=False, default=None) def __post_init__(self): # Could be a remote path (s3://) or a local path @@ -35,10 +36,15 @@ def __post_init__(self): self.starcoder_train = str(self.data_path).rstrip("/") + "/starcoder" def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.batch_size = batch_size self.seq_length = max_seq_length + 1 # Increase by one 
because we need the next token as well + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: for path in (self.slimpajama_train, self.slimpajama_val, self.starcoder_train): diff --git a/litgpt/data/tinystories.py b/litgpt/data/tinystories.py index 632a015e44..bd13e23395 100644 --- a/litgpt/data/tinystories.py +++ b/litgpt/data/tinystories.py @@ -34,15 +34,23 @@ class TinyStories(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) def __post_init__(self) -> None: self.data_path_train = self.data_path / "train" self.data_path_val = self.data_path / "val" - def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None: + def connect( + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: int = -1, + pad_multiple_of: Optional[int] = None, + ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: from litdata import optimize diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py index 4a26e0b14f..ad9c863a79 100644 --- a/litgpt/deploy/serve.py +++ b/litgpt/deploy/serve.py @@ -16,12 +16,14 @@ class SimpleLitAPI(LitAPI): - def __init__(self, - checkpoint_dir: Path, - precision: Optional[str] = None, - temperature: float = 0.8, - top_k: int = 50, - max_new_tokens: int = 50) -> None: + def __init__( + self, + checkpoint_dir: Path, + precision: Optional[str] = None, + temperature: float = 0.8, + top_k: int = 50, + max_new_tokens: int = 50, + ) -> None: super().__init__() self.checkpoint_dir = checkpoint_dir @@ -79,7 +81,7 @@ def predict(self, inputs: torch.Tensor) -> Any: max_returned_tokens, temperature=self.temperature, top_k=self.top_k, - eos_id=self.tokenizer.eos_id + eos_id=self.tokenizer.eos_id, ) for block in self.model.transformer.h: @@ -127,9 +129,10 @@ def run_server( temperature=temperature, top_k=top_k, max_new_tokens=max_new_tokens, - ), + ), accelerator=accelerator, - devices=devices) + devices=devices, + ) server.run(port=port) diff --git a/litgpt/eval/evaluate.py b/litgpt/eval/evaluate.py index 78e0ed0f59..2612730ecc 100644 --- a/litgpt/eval/evaluate.py +++ b/litgpt/eval/evaluate.py @@ -19,9 +19,7 @@ def prepare_results(results, save_filepath, print_results=True): if "groups" in results: print(make_table(results, "groups")) - json_result = json.dumps( - results, indent=2, ensure_ascii=False - ) + json_result = json.dumps(results, indent=2, ensure_ascii=False) save_filepath.open("w", encoding="utf-8").write(json_result) @@ -62,6 +60,7 @@ def convert_and_evaluate( if tasks is None: from lm_eval.tasks import TaskManager + taskm = TaskManager() print("\n".join(taskm.task_index.keys())) print( @@ -84,7 +83,7 @@ def convert_and_evaluate( out_dir.mkdir(parents=True, exist_ok=True) save_filepath = out_dir / Path("results.json") if save_filepath is None else Path(save_filepath) - config_filepath = checkpoint_dir/"model_config.yaml" + config_filepath = checkpoint_dir / "model_config.yaml" with open(config_filepath, encoding="utf-8") as f: config_dict = yaml.safe_load(f) diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py index 23de9b622c..0db02997b5 100644 --- 
a/litgpt/finetune/full.py +++ b/litgpt/finetune/full.py @@ -26,6 +26,7 @@ choose_logger, chunked_cross_entropy, copy_config_files, + find_multiple, get_default_supported_precision, load_checkpoint, init_out_dir, @@ -42,6 +43,8 @@ def setup( devices: Union[int, str] = 1, resume: Union[bool, Path] = False, data: Optional[DataModule] = None, + longlora_n_groups: Optional[int] = None, + longlora_context_length: Optional[int] = None, train: TrainArgs = TrainArgs( save_interval=1000, log_interval=1, @@ -67,6 +70,8 @@ def setup( resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume from the latest checkpoint in ``out_dir``. data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. + longlora_n_groups: The number of groups to use for LongLora. + longlora_context_length: The increased context length to use for LongLora. train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details. eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details. logger_name: The name of the logger to send metrics to. @@ -78,8 +83,17 @@ def setup( devices = parse_devices(devices) out_dir = init_out_dir(out_dir) + # Check longlora params: if one is set, then all must be set + longlora_params = [longlora_n_groups is not None, longlora_context_length is not None] + if any(longlora_params) and not all(longlora_params): + raise ValueError("If any of 'longlora_n_groups' or 'longlora_context_length' are set," " then all must be set.") + check_valid_checkpoint_dir(checkpoint_dir) - config = Config.from_file(checkpoint_dir / "model_config.yaml") + config = Config.from_file( + checkpoint_dir / "model_config.yaml", + longlora_n_groups=longlora_n_groups, + longlora_context_length=longlora_context_length, + ) precision = precision or get_default_supported_precision(training=True) logger = choose_logger( @@ -116,7 +130,9 @@ def main( validate_args(train, eval) tokenizer = Tokenizer(checkpoint_dir) - train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) + train_dataloader, val_dataloader = get_dataloaders( + fabric, data, tokenizer, train, pad_multiple_of=config.longlora_n_groups + ) steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices) lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) @@ -127,6 +143,15 @@ def main( checkpoint_path = checkpoint_dir / "lit_model.pth" with fabric.init_module(empty_init=(devices > 1)): + if config.longlora_context_length is not None and config.longlora_context_length > config.block_size: + old_block_size = config.block_size + config.block_size = config.longlora_context_length + config.rope_condense_ratio = config.longlora_context_length / old_block_size + fabric.print( + f"The model context length has been increased from {old_block_size} to {config.longlora_context_length}" + ) + fabric.print(f"The 'rope_condense_ratio' has been adapted to {config.rope_condense_ratio}") + model = GPT(config) fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}") @@ -183,12 +208,21 @@ def fit( eval: EvalArgs, data: DataModule, ) -> None: - model = state["model"] + model: GPT = state["model"] optimizer = state["optimizer"] scheduler = state["scheduler"] tokenizer = Tokenizer(checkpoint_dir) - longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) - model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) + 
pad_multiple_of = data.pad_multiple_of or 1 + if train.get_longest_seq_length: + longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) + longest_seq_length = find_multiple( + min(longest_seq_length, train.max_seq_length or float("inf")), pad_multiple_of + ) + else: + longest_seq_length = find_multiple( + min(model.max_seq_length, train.max_seq_length or float("inf")), pad_multiple_of + ) + model.max_seq_length = longest_seq_length fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" @@ -331,9 +365,14 @@ def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): def get_dataloaders( - fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs + fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs, pad_multiple_of: Optional[int] = None ) -> Tuple[DataLoader, DataLoader]: - data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length) + data.connect( + tokenizer=tokenizer, + batch_size=train.micro_batch_size, + max_seq_length=train.max_seq_length, + pad_multiple_of=pad_multiple_of, + ) with fabric.rank_zero_first(): data.prepare_data() data.setup() diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py index 39e805befe..dcead661c7 100644 --- a/litgpt/finetune/lora.py +++ b/litgpt/finetune/lora.py @@ -1,5 +1,6 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import dataclasses +from functools import partial import math import os import time @@ -18,7 +19,14 @@ from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca, DataModule from litgpt.generate.base import generate -from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable +from litgpt.lora import ( + GPT, + Block, + Config, + longlora_filter, + lora_filter, + mark_only_lora_as_trainable, +) from litgpt.prompts import save_prompt_style from litgpt.scripts.merge_lora import merge_lora from litgpt.tokenizer import Tokenizer @@ -29,6 +37,7 @@ choose_logger, chunked_cross_entropy, copy_config_files, + find_multiple, get_default_supported_precision, load_checkpoint, init_out_dir, @@ -53,6 +62,9 @@ def setup( lora_projection: bool = False, lora_mlp: bool = False, lora_head: bool = False, + longlora_n_groups: Optional[int] = None, + longlora_context_length: Optional[int] = None, + longlora_trainable_params: str = "", data: Optional[DataModule] = None, train: TrainArgs = TrainArgs( save_interval=1000, @@ -86,6 +98,10 @@ def setup( lora_projection: Whether to apply LoRA to the output projection in the attention block. lora_mlp: Whether to apply LoRA to the weights of the MLP in the attention block. lora_head: Whether to apply LoRA to output head in GPT. + longlora_n_groups: The number of groups to use for LongLora. + longlora_context_length: The increased context length to use for LongLora. + longlora_trainable_params: The names of the parameters to make trainable for LongLora. + The parameters should be comma-separated, if any. data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details. eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details. 
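# --- Editorial sketch (not part of the patch) ---------------------------------
# The finetuning scripts above pass `pad_multiple_of=longlora_n_groups` when
# connecting the data module because LongLora's shifted sparse attention splits
# every sequence into `longlora_n_groups` equal groups, so the padded length must
# be divisible by the group count. A minimal, self-contained sketch of that
# constraint follows; `find_multiple` is assumed to behave like the helper
# imported from `litgpt.utils`, and the tensor names are illustrative only.
import torch

def find_multiple(n: int, k: int) -> int:
    # Smallest multiple of k that is greater than or equal to n.
    return n if n % k == 0 else n + k - (n % k)

def pad_to_group_multiple(batch: torch.Tensor, n_groups: int, pad_id: int = 0) -> torch.Tensor:
    # batch: (B, T) token ids, already padded to the longest sample in the batch.
    target_len = find_multiple(batch.shape[1], n_groups)
    extra = target_len - batch.shape[1]
    if extra > 0:
        padding = torch.full((batch.shape[0], extra), pad_id, dtype=batch.dtype)
        batch = torch.cat((batch, padding), dim=1)
    return batch

# Example: with longlora_n_groups=4, a (2, 10) batch is padded to (2, 12), and the
# attention in litgpt/model.py can then use a group size of 12 // 4 = 3 tokens.
# padded = pad_to_group_multiple(torch.ones(2, 10, dtype=torch.long), n_groups=4)
# --- End editorial sketch ------------------------------------------------------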
@@ -98,6 +114,18 @@ def setup( devices = parse_devices(devices) out_dir = init_out_dir(out_dir) + # Check longlora params: if one is set, then all must be set + longlora_params = [ + longlora_n_groups is not None, + longlora_context_length is not None, + longlora_trainable_params != "", + ] + if any(longlora_params) and not all(longlora_params[:-1]): + raise ValueError( + "If any of 'longlora_n_groups', 'longlora_context_length', or 'longlora_trainable_params' are set," + " then all must be set." + ) + check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file( checkpoint_dir / "model_config.yaml", @@ -110,16 +138,28 @@ def setup( lora_projection=lora_projection, lora_mlp=lora_mlp, lora_head=lora_head, + longlora_n_groups=longlora_n_groups, + longlora_context_length=longlora_context_length, + longlora_trainable_params=longlora_trainable_params, ) precision = precision or get_default_supported_precision(training=True) - logger = choose_logger(logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval) + logger = choose_logger( + logger_name, + out_dir, + name=f"finetune-{config.name}", + log_interval=train.log_interval, + ) plugins = None if quantize is not None and quantize.startswith("bnb."): if "mixed" in precision: raise ValueError("Quantization and mixed precision is not supported.") - dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] + dtype = { + "16-true": torch.float16, + "bf16-true": torch.bfloat16, + "32-true": torch.float32, + }[precision] plugins = BitsandbytesPrecision(quantize[4:], dtype) precision = None @@ -139,7 +179,13 @@ def setup( else: strategy = "auto" - fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins) + fabric = L.Fabric( + devices=devices, + strategy=strategy, + precision=precision, + loggers=logger, + plugins=plugins, + ) fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval) @@ -157,7 +203,9 @@ def main( validate_args(train, eval) tokenizer = Tokenizer(checkpoint_dir) - train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) + train_dataloader, val_dataloader = get_dataloaders( + fabric, data, tokenizer, train, pad_multiple_of=config.longlora_n_groups + ) steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices) lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) @@ -168,9 +216,28 @@ def main( checkpoint_path = checkpoint_dir / "lit_model.pth" with fabric.init_module(empty_init=(devices > 1)): + if config.longlora_context_length is not None and config.longlora_context_length > config.block_size: + old_block_size = config.block_size + config.block_size = config.longlora_context_length + old_rope_condense_ratio = config.rope_condense_ratio + config.rope_condense_ratio = config.longlora_context_length / old_block_size + fabric.print( + f"The model context length has been increased from {old_block_size} to {config.longlora_context_length}" + ) + fabric.print( + f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" + ) + model = GPT(config) mark_only_lora_as_trainable(model) + # Let other layers be trainable + if config.longlora_trainable_params != "": + trainable_params = set(config.longlora_trainable_params.strip().split(",")) + for n, p in model.named_parameters(): + if any(trainable_p_name in n for trainable_p_name in trainable_params): + 
p.requires_grad = True + fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}") fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}") @@ -184,7 +251,10 @@ def main( else: optimizer_cls = torch.optim.AdamW optimizer = optimizer_cls( - trainable_params, lr=train.learning_rate, weight_decay=train.weight_decay, betas=(train.beta1, train.beta2) + trainable_params, + lr=train.learning_rate, + weight_decay=train.weight_decay, + betas=(train.beta1, train.beta2), ) optimizer = fabric.setup_optimizers(optimizer) scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) @@ -244,8 +314,17 @@ def fit( data: DataModule, ) -> None: tokenizer = Tokenizer(checkpoint_dir) - longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) - model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) + pad_multiple_of = data.pad_multiple_of or 1 + if train.get_longest_seq_length: + longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) + longest_seq_length = find_multiple( + min(longest_seq_length, train.max_seq_length or float("inf")), pad_multiple_of + ) + else: + longest_seq_length = find_multiple( + min(model.max_seq_length, train.max_seq_length or float("inf")), pad_multiple_of + ) + model.max_seq_length = longest_seq_length fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" @@ -292,7 +371,10 @@ def fit( loss = running_loss.compute().item() # expensive device-to-host synchronization t1 = time.perf_counter() throughput.update( - time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths + time=t1 - total_t0, + batches=iter_num, + samples=iter_num * train.micro_batch_size, + lengths=total_lengths, ) throughput.compute_and_log(step=iter_num) metrics = { @@ -367,7 +449,11 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E # do not set `max_seq_length=max_returned_token` because memory is not a concern here model.set_kv_cache(batch_size=1) output = generate( - model, encoded, max_returned_tokens=len(encoded) + eval.max_new_tokens, temperature=0.8, eos_id=tokenizer.eos_id + model, + encoded, + max_returned_tokens=len(encoded) + eval.max_new_tokens, + temperature=0.8, + eos_id=tokenizer.eos_id, ) model.clear_kv_cache() model.train() @@ -383,9 +469,14 @@ def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): def get_dataloaders( - fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs + fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs, pad_multiple_of: Optional[int] = None ) -> Tuple[DataLoader, DataLoader]: - data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length) + data.connect( + tokenizer=tokenizer, + batch_size=train.micro_batch_size, + max_seq_length=train.max_seq_length, + pad_multiple_of=pad_multiple_of, + ) with fabric.rank_zero_first(): data.prepare_data() data.setup() @@ -403,9 +494,22 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: return longest_seq_length, longest_seq_ix -def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None: +def save_lora_checkpoint(fabric: L.Fabric, model: GPT, file_path: Path) -> None: 
fabric.print(f"Saving LoRA weights to {str(file_path)!r}") - fabric.save(file_path, {"model": model}, filter={"model": lora_filter}) + fabric.save( + file_path, + {"model": model}, + filter={ + "model": ( + lora_filter + if model.config.longlora_context_length is None + else partial( + longlora_filter, + additional_weights=model.config.longlora_trainable_params.strip().split(","), + ) + ) + }, + ) def validate_args(train: TrainArgs, eval: EvalArgs) -> None: diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index 060604b43f..6928e9b5fb 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -9,6 +9,7 @@ import torch import torch._dynamo.config import torch._inductor.config +import yaml from lightning.fabric.plugins import BitsandbytesPrecision from litgpt import GPT, Config, PromptStyle, Tokenizer @@ -134,6 +135,12 @@ def main( check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") + if (hyperparams_dir := (checkpoint_dir / "hyperparameters.yaml")).is_file(): + with open(hyperparams_dir, "r", encoding="utf-8") as hparams_file: + hparams = yaml.safe_load(hparams_file) + longlora_context_length = hparams.get("longlora_context_length", config.block_size) + else: + longlora_context_length = config.block_size checkpoint_path = checkpoint_dir / "lit_model.pth" @@ -150,6 +157,17 @@ def main( fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) t0 = time.perf_counter() with fabric.init_module(empty_init=True): + if longlora_context_length is not None and longlora_context_length > config.block_size: + old_block_size = config.block_size + config.block_size = longlora_context_length + old_rope_condense_ratio = config.rope_condense_ratio + config.rope_condense_ratio = longlora_context_length / old_block_size + fabric.print( + f"The model context length has been increased from {old_block_size} to {config.longlora_context_length}" + ) + fabric.print( + f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" + ) model = GPT(config) fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) with fabric.init_tensor(): diff --git a/litgpt/generate/full.py b/litgpt/generate/full.py index c570e8dd2e..d8466fc802 100644 --- a/litgpt/generate/full.py +++ b/litgpt/generate/full.py @@ -8,6 +8,7 @@ import lightning as L import torch from lightning.fabric.plugins import BitsandbytesPrecision +import yaml from litgpt import GPT, Config, PromptStyle, Tokenizer from litgpt.generate.base import generate @@ -60,6 +61,12 @@ def main( check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") + if (hyperparams_dir := (checkpoint_dir / "hyperparameters.yaml")).is_file(): + with open(hyperparams_dir, "r", encoding="utf-8") as hparams_file: + hparams = yaml.safe_load(hparams_file) + longlora_context_length = hparams.get("longlora_context_length", config.block_size) + else: + longlora_context_length = config.block_size checkpoint_path = finetuned_path @@ -76,6 +83,17 @@ def main( fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) t0 = time.perf_counter() with fabric.init_module(empty_init=True): + if longlora_context_length is not None and longlora_context_length > config.block_size: + old_block_size = config.block_size + config.block_size = longlora_context_length + old_rope_condense_ratio = config.rope_condense_ratio + 
config.rope_condense_ratio = longlora_context_length / old_block_size + fabric.print( + f"The model context length has been increased from {old_block_size} to {config.longlora_context_length}" + ) + fabric.print( + f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" + ) model = GPT(config) fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) with fabric.init_tensor(): diff --git a/litgpt/generate/sequentially.py b/litgpt/generate/sequentially.py index 9f006ab47f..41d904c809 100644 --- a/litgpt/generate/sequentially.py +++ b/litgpt/generate/sequentially.py @@ -16,6 +16,7 @@ from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.utilities.init import _materialize_meta_tensors from typing_extensions import Type +import yaml import litgpt.generate.base as generate_base from litgpt import GPT, Config, Tokenizer @@ -159,6 +160,12 @@ def main( check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") + if (hyperparams_dir := (checkpoint_dir / "hyperparameters.yaml")).is_file(): + with open(hyperparams_dir, "r", encoding="utf-8") as hparams_file: + hparams = yaml.safe_load(hparams_file) + longlora_context_length = hparams.get("longlora_context_length", config.block_size) + else: + longlora_context_length = config.block_size checkpoint_path = checkpoint_dir / "lit_model.pth" @@ -173,6 +180,17 @@ def main( # which means that the weights will get quantized on cuda:0 on checkpoint load. we need to load and then convert # still, use init_tensor for the precision with fabric.init_tensor(), torch.device("meta"): + if longlora_context_length is not None and longlora_context_length > config.block_size: + old_block_size = config.block_size + config.block_size = longlora_context_length + old_rope_condense_ratio = config.rope_condense_ratio + config.rope_condense_ratio = longlora_context_length / old_block_size + fabric.print( + f"The model context length has been increased from {old_block_size} to {config.longlora_context_length}" + ) + fabric.print( + f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" + ) model = GPT(config) print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) diff --git a/litgpt/generate/tp.py b/litgpt/generate/tp.py index 41492f75b2..92da3bc43f 100644 --- a/litgpt/generate/tp.py +++ b/litgpt/generate/tp.py @@ -14,6 +14,7 @@ from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.utilities import rank_zero_only from torch.distributed._functional_collectives import all_reduce +import yaml import litgpt.generate.base as generate_base from litgpt import GPT, Config, Tokenizer @@ -138,6 +139,12 @@ def main( check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") + if (hyperparams_dir := (checkpoint_dir / "hyperparameters.yaml")).is_file(): + with open(hyperparams_dir, "r", encoding="utf-8") as hparams_file: + hparams = yaml.safe_load(hparams_file) + longlora_context_length = hparams.get("longlora_context_length", config.block_size) + else: + longlora_context_length = config.block_size model_file = "lit_model.pth" checkpoint_path = checkpoint_dir / model_file @@ -153,6 +160,17 @@ def main( # which means that the weights will get quantized on cuda:0 on checkpoint load. 
we need to load and then convert # still, use init_tensor for the precision with fabric.init_tensor(), torch.device("meta"): + if longlora_context_length is not None and longlora_context_length > config.block_size: + old_block_size = config.block_size + config.block_size = longlora_context_length + old_rope_condense_ratio = config.rope_condense_ratio + config.rope_condense_ratio = longlora_context_length / old_block_size + fabric.print( + f"The model context length has been increased from {old_block_size} to {config.longlora_context_length}" + ) + fabric.print( + f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" + ) model = GPT(config) fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) diff --git a/litgpt/lora.py b/litgpt/lora.py index 51fd66713d..60c9db2c60 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -45,7 +45,7 @@ import math from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union import torch import torch.nn as nn @@ -335,7 +335,9 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor: result = x.new_zeros((*x.shape[:-1], self.linear.out_features)) # (64, 64, 384) result = result.view(-1, self.linear.out_features) # (4096, 384) result = result.index_copy( - 1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes)) + 1, + torch.tensor(self.lora_ind, device=result.device), + x.reshape(-1, sum(self.qkv_shapes)), ) # (4096, 256) return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1) # (64, 64, 384) @@ -371,7 +373,8 @@ def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T) weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1) return torch.cat( - [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1 # (B, C_output', T) + [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], + dim=1, # (B, C_output', T) ) # (B, C_output, T) def get_lora_AB(self) -> torch.Tensor: @@ -468,6 +471,10 @@ def lora_filter(key: str, value: Any) -> bool: return "lora_" in key +def longlora_filter(key: str, value: Any, additional_weights: Sequence[str] = ["lora_"]) -> bool: + return any(x in key for x in additional_weights + ["lora_"]) + + @dataclass class Config(BaseConfig): """ @@ -521,7 +528,10 @@ def __init__(self, config: Config) -> None: self.mask_cache: Optional[torch.Tensor] = None def forward( - self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0 + self, + idx: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + lm_head_chunk_size: int = 0, ) -> Union[torch.Tensor, List[torch.Tensor]]: T = idx.size(1) if self.max_seq_length < T: @@ -561,7 +571,10 @@ def _init_weights(self, module: nn.Module) -> None: def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" - mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"} + mapping = { + "lm_head.weight": "lm_head.linear.weight", + "lm_head.bias": "lm_head.linear.bias", + } state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) @@ -613,6 +626,9 @@ def __init__(self, config: 
Config) -> None: self.config = config + # LongLora + self._longlora_available = self.config.longlora_n_groups is not None and self.config.longlora_n_groups > 0 + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = { diff --git a/litgpt/model.py b/litgpt/model.py index fe71c60b80..1c25666de1 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -202,6 +202,9 @@ def __init__(self, config: Config) -> None: self.config = config + # LongLora + self._longlora_available = self.config.longlora_n_groups is not None and self.config.longlora_n_groups > 0 + def forward( self, x: torch.Tensor, @@ -212,6 +215,13 @@ def forward( ) -> torch.Tensor: B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + if input_pos is None and self._longlora_available: + if T % self.config.longlora_n_groups != 0: + raise ValueError(f"sequence length {T} should be divisible by group size {longlora_group_size}.") + longlora_group_size = T // self.config.longlora_n_groups + else: + longlora_group_size = 0 + qkv = self.attn(x) # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) @@ -243,16 +253,33 @@ def forward( if not isinstance(self.kv_cache, KVCache): raise TypeError("You need to call `gpt.set_kv_cache()`") k, v = self.kv_cache(input_pos, k, v) + elif longlora_group_size > 0: + q = roll_and_group(q, B, T, longlora_group_size, q.shape[1], self.config.head_size) + k = roll_and_group(k, B, T, longlora_group_size, k.shape[1], self.config.head_size) + v = roll_and_group(v, B, T, longlora_group_size, v.shape[1], self.config.head_size) y = self.scaled_dot_product_attention(q, k, v, mask) + y_cloned = y + + if input_pos is None and longlora_group_size > 0: + # shift back and unroll + n_heads = y.shape[2] + y_cloned = y.clone() + y_cloned = y_cloned.reshape(B, T, n_heads, self.config.head_size) # (B, T, nh, hs) + y_cloned[:, :, n_heads // 2 :] = y_cloned[:, :, n_heads // 2 :].roll(longlora_group_size // 2, dims=1) - y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side + # re-assemble all head outputs side by side + y_cloned = y_cloned.reshape(B, T, self.config.head_size * self.config.n_head) # output projection - return self.proj(y) + return self.proj(y_cloned) def scaled_dot_product_attention( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: scale = 1.0 / math.sqrt(self.config.head_size) y = torch.nn.functional.scaled_dot_product_attention( @@ -284,6 +311,12 @@ def build_kv_cache( return KVCache(k_shape, v_shape, device=device, dtype=dtype) +def roll_and_group(qkv, bsz, q_len, group_size, num_heads, head_dim): + qkv[:, num_heads // 2 :] = qkv[:, num_heads // 2 :].roll(-group_size // 2, dims=2) + qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2) + return qkv + + class GptNeoxMLP(nn.Module): def __init__(self, config: Config) -> None: super().__init__() From 95a6539afeca947d07def653dc6051710df10291 Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 25 Apr 2024 13:55:48 +0200 Subject: [PATCH 2/7] Add LongLoraArgs --- litgpt/adapter_v2.py | 3 -- litgpt/args.py | 18 +++++-- litgpt/config.py | 5 +- litgpt/finetune/full.py | 95 +++++++++++++++++++++------------ 
litgpt/finetune/lora.py | 91 +++++++++++++++---------------- litgpt/generate/base.py | 12 +++-- litgpt/generate/full.py | 12 +++-- litgpt/generate/sequentially.py | 12 +++-- litgpt/generate/tp.py | 12 +++-- litgpt/lora.py | 3 -- litgpt/model.py | 34 +++++++----- 11 files changed, 171 insertions(+), 126 deletions(-) diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index 4cda0d5167..665527f053 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -168,9 +168,6 @@ def __init__(self, config: Config, block_idx: int) -> None: self.config = config - # LongLora - self._longlora_available = self.config.longlora_n_groups is not None and self.config.longlora_n_groups > 0 - def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = { diff --git a/litgpt/args.py b/litgpt/args.py index f3e04a283a..bbeaaf738c 100644 --- a/litgpt/args.py +++ b/litgpt/args.py @@ -40,10 +40,6 @@ class TrainArgs: max_norm: Optional[float] = None min_lr: float = 6e-5 - # Misc args - get_longest_seq_length: bool = True - """Whether to compute the longest sequence length in the dataset""" - def gradient_accumulation_iters(self, devices: int) -> int: """Number of iterations between gradient synchronizations""" gradient_accumulation_iters = self.batch_size(devices) // self.micro_batch_size @@ -75,3 +71,17 @@ class EvalArgs: """Number of tokens to generate""" max_iters: int = 100 """Number of iterations""" + + +@dataclass +class LongLoraArgs: + """LongLora-related arguments""" + + use_longlora: bool = False + """Whether to enable LongLora.""" + n_groups: int = 4 + """Number of groups to divide the sequence length into.""" + context_length: int = 8192 + """Length of the enlarged context window.""" + trainable_params: str = "wte,norm,ln" + """List of comma-separated parameters to train in LongLora.""" diff --git a/litgpt/config.py b/litgpt/config.py index ce9cb37407..5bc0907f72 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -61,9 +61,8 @@ class Config: rope_base: int = 10000 n_expert: int = 0 n_expert_per_token: int = 0 - longlora_n_groups: Optional[int] = None - longlora_context_length: Optional[int] = None - longlora_trainable_params: str = "" + use_longlora: bool = False + longlora_n_groups: int = 4 def __post_init__(self): if not self.name: diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py index 0db02997b5..3c7c325021 100644 --- a/litgpt/finetune/full.py +++ b/litgpt/finetune/full.py @@ -6,14 +6,16 @@ from pathlib import Path from pprint import pprint from typing import Dict, List, Literal, Optional, Tuple, Union +import warnings import lightning as L import torch from lightning.fabric.strategies import FSDPStrategy from torch.utils.data import DataLoader from torchmetrics import RunningMean +import yaml -from litgpt.args import EvalArgs, TrainArgs +from litgpt.args import EvalArgs, TrainArgs, LongLoraArgs from litgpt.data import Alpaca, DataModule from litgpt.generate.base import generate from litgpt.model import GPT, Block, Config @@ -43,8 +45,6 @@ def setup( devices: Union[int, str] = 1, resume: Union[bool, Path] = False, data: Optional[DataModule] = None, - longlora_n_groups: Optional[int] = None, - longlora_context_length: Optional[int] = None, train: TrainArgs = TrainArgs( save_interval=1000, log_interval=1, @@ -56,6 +56,7 @@ max_seq_length=None, ), eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100), + longlora: LongLoraArgs =
LongLoraArgs(use_longlora=False, n_groups=4, context_length=8192, trainable_params=""), logger_name: Literal["wandb", "tensorboard", "csv"] = "csv", seed: int = 1337, ) -> None: @@ -70,10 +71,9 @@ def setup( resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume from the latest checkpoint in ``out_dir``. data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. - longlora_n_groups: The number of groups to use for LongLora. - longlora_context_length: The increased context length to use for LongLora. train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details. eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details. + longlora: LongLoRA-related arguments. See ``litgpt.args.LongLoraArgs`` for details. logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. """ @@ -83,16 +83,9 @@ def setup( devices = parse_devices(devices) out_dir = init_out_dir(out_dir) - # Check longlora params: if one is set, then all must be set - longlora_params = [longlora_n_groups is not None, longlora_context_length is not None] - if any(longlora_params) and not all(longlora_params): - raise ValueError("If any of 'longlora_n_groups' or 'longlora_context_length' are set," " then all must be set.") - check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file( - checkpoint_dir / "model_config.yaml", - longlora_n_groups=longlora_n_groups, - longlora_context_length=longlora_context_length, + checkpoint_dir / "model_config.yaml", use_longlora=longlora.use_longlora, longlora_n_groups=longlora.n_groups ) precision = precision or get_default_supported_precision(training=True) @@ -112,7 +105,7 @@ def setup( strategy = "auto" fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger) - fabric.launch(main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval) + fabric.launch(main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, longlora) def main( @@ -126,12 +119,26 @@ def main( out_dir: Path, train: TrainArgs, eval: EvalArgs, + longlora: LongLoraArgs, ) -> None: validate_args(train, eval) + if resume is True: + resume = max(out_dir.rglob("step-*/*.pth"), key=(lambda p: int(p.parent.name.split("-")[1]))) + if resume: + with open(resume.parent / "hyperparameters.yaml", "r") as f: + hyperparams = yaml.safe_load(f) + longlora_cfg = hyperparams.get("longlora", None) + if longlora_cfg is not None: + longlora.use_longlora = longlora_cfg.get("use_longlora", False) + longlora.n_groups = longlora_cfg.get("n_groups", longlora.n_groups) + longlora.context_length = longlora_cfg.get("context_length", longlora.context_length) + config.use_longlora = longlora.use_longlora + config.longlora_n_groups = longlora.n_groups + validate_longlora_args(config, longlora) tokenizer = Tokenizer(checkpoint_dir) train_dataloader, val_dataloader = get_dataloaders( - fabric, data, tokenizer, train, pad_multiple_of=config.longlora_n_groups + fabric, data, tokenizer, train, pad_multiple_of=longlora.n_groups if longlora.use_longlora else None ) steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices) lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) @@ -143,15 +150,15 @@ def main( checkpoint_path = checkpoint_dir / "lit_model.pth" with fabric.init_module(empty_init=(devices > 1)): - if config.longlora_context_length is not None and 
config.longlora_context_length > config.block_size: + if longlora.use_longlora and longlora.context_length > config.block_size: old_block_size = config.block_size - config.block_size = config.longlora_context_length - config.rope_condense_ratio = config.longlora_context_length / old_block_size + config.block_size = longlora.context_length + old_rope_condense_ratio = config.rope_condense_ratio + config.rope_condense_ratio = longlora.context_length / old_block_size + fabric.print(f"The model context length has been increased from {old_block_size} to {config.block_size}") fabric.print( - f"The model context length has been increased from {old_block_size} to {config.longlora_context_length}" + f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" ) - fabric.print(f"The 'rope_condense_ratio' has been adapted to {config.rope_condense_ratio}") - model = GPT(config) fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}") @@ -164,8 +171,6 @@ def main( scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) state = {"model": model, "optimizer": optimizer, "scheduler": scheduler, "iter_num": 0, "step_count": 0} - if resume is True: - resume = max(out_dir.rglob("step-*/*.pth"), key=(lambda p: int(p.parent.name.split("-")[1]))) if resume: fabric.print(f"Resuming training from {resume}") fabric.load(resume, state) @@ -173,7 +178,20 @@ def main( load_checkpoint(fabric, state["model"], checkpoint_path) train_time = time.perf_counter() - fit(fabric, state, train_dataloader, val_dataloader, devices, resume, checkpoint_dir, out_dir, train, eval, data) + fit( + fabric, + state, + train_dataloader, + val_dataloader, + devices, + resume, + checkpoint_dir, + out_dir, + train, + eval, + longlora, + data, + ) fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") @@ -206,23 +224,15 @@ def fit( out_dir: Path, train: TrainArgs, eval: EvalArgs, + longlora: LongLoraArgs, data: DataModule, ) -> None: model: GPT = state["model"] optimizer = state["optimizer"] scheduler = state["scheduler"] tokenizer = Tokenizer(checkpoint_dir) - pad_multiple_of = data.pad_multiple_of or 1 - if train.get_longest_seq_length: - longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) - longest_seq_length = find_multiple( - min(longest_seq_length, train.max_seq_length or float("inf")), pad_multiple_of - ) - else: - longest_seq_length = find_multiple( - min(model.max_seq_length, train.max_seq_length or float("inf")), pad_multiple_of - ) - model.max_seq_length = longest_seq_length + longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) + model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" @@ -408,6 +418,21 @@ def validate_args(train: TrainArgs, eval: EvalArgs) -> None: raise ValueError("\n".join(issues)) +def validate_longlora_args(config: Config, longlora: LongLoraArgs): + if longlora.use_longlora: + if longlora.context_length <= config.block_size: + warnings.warn( + f"LongLora is disabled because the LongLora context length ({longlora.context_length}) " + f"is less than the model 
original block size {config.block_size}. " + ) + longlora.use_longlora = False + elif longlora.context_length % longlora.n_groups != 0: + raise ValueError( + f"LongLora context length ({longlora.context_length}) must be a multiple of the number of groups " + f"({longlora.n_groups})." + ) + + if __name__ == "__main__": torch.set_float32_matmul_precision("high") diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py index dcead661c7..56f131cb2d 100644 --- a/litgpt/finetune/lora.py +++ b/litgpt/finetune/lora.py @@ -7,6 +7,7 @@ from pathlib import Path from pprint import pprint from typing import Dict, List, Literal, Optional, Tuple, Union +import warnings import lightning as L import torch @@ -16,7 +17,7 @@ from torch.utils.data import DataLoader from torchmetrics import RunningMean -from litgpt.args import EvalArgs, TrainArgs +from litgpt.args import EvalArgs, LongLoraArgs, TrainArgs from litgpt.data import Alpaca, DataModule from litgpt.generate.base import generate from litgpt.lora import ( @@ -62,9 +63,6 @@ def setup( lora_projection: bool = False, lora_mlp: bool = False, lora_head: bool = False, - longlora_n_groups: Optional[int] = None, - longlora_context_length: Optional[int] = None, - longlora_trainable_params: str = "", data: Optional[DataModule] = None, train: TrainArgs = TrainArgs( save_interval=1000, @@ -77,6 +75,9 @@ def setup( max_seq_length=None, ), eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100), + longlora: LongLoraArgs = LongLoraArgs( + use_longlora=False, n_groups=4, context_length=8192, trainable_params="wte,norm,ln" + ), logger_name: Literal["wandb", "tensorboard", "csv"] = "csv", seed: int = 1337, ) -> None: @@ -98,13 +99,10 @@ def setup( lora_projection: Whether to apply LoRA to the output projection in the attention block. lora_mlp: Whether to apply LoRA to the weights of the MLP in the attention block. lora_head: Whether to apply LoRA to output head in GPT. - longlora_n_groups: The number of groups to use for LongLora. - longlora_context_length: The increased context length to use for LongLora. - longlora_trainable_params: The names of the parameters to make trainable for LongLora. - The parameters should be comma-separated, if any. data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details. eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details. + longlora: LongLoRA-related arguments. See ``litgpt.args.LongLoraArgs`` for details. logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. """ @@ -114,18 +112,6 @@ def setup( devices = parse_devices(devices) out_dir = init_out_dir(out_dir) - # Check longlora params: if one is set, then all must be set - longlora_params = [ - longlora_n_groups is not None, - longlora_context_length is not None, - longlora_trainable_params != "", - ] - if any(longlora_params) and not all(longlora_params[:-1]): - raise ValueError( - "If any of 'longlora_n_groups', 'longlora_context_length', or 'longlora_trainable_params' are set," - " then all must be set." 
- ) - check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file( checkpoint_dir / "model_config.yaml", @@ -138,9 +124,8 @@ def setup( lora_projection=lora_projection, lora_mlp=lora_mlp, lora_head=lora_head, - longlora_n_groups=longlora_n_groups, - longlora_context_length=longlora_context_length, - longlora_trainable_params=longlora_trainable_params, + use_longlora=longlora.use_longlora, + longlora_n_groups=longlora.n_groups, ) precision = precision or get_default_supported_precision(training=True) @@ -186,7 +171,7 @@ def setup( loggers=logger, plugins=plugins, ) - fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval) + fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, longlora) def main( @@ -199,12 +184,14 @@ def main( out_dir: Path, train: TrainArgs, eval: EvalArgs, + longlora: LongLoraArgs, ) -> None: validate_args(train, eval) + validate_longlora_args(config, longlora) tokenizer = Tokenizer(checkpoint_dir) train_dataloader, val_dataloader = get_dataloaders( - fabric, data, tokenizer, train, pad_multiple_of=config.longlora_n_groups + fabric, data, tokenizer, train, pad_multiple_of=longlora.n_groups if longlora.use_longlora else None ) steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices) lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) @@ -216,14 +203,12 @@ def main( checkpoint_path = checkpoint_dir / "lit_model.pth" with fabric.init_module(empty_init=(devices > 1)): - if config.longlora_context_length is not None and config.longlora_context_length > config.block_size: + if longlora.use_longlora and longlora.context_length > config.block_size: old_block_size = config.block_size - config.block_size = config.longlora_context_length + config.block_size = longlora.context_length old_rope_condense_ratio = config.rope_condense_ratio - config.rope_condense_ratio = config.longlora_context_length / old_block_size - fabric.print( - f"The model context length has been increased from {old_block_size} to {config.longlora_context_length}" - ) + config.rope_condense_ratio = longlora.context_length / old_block_size + fabric.print(f"The model context length has been increased from {old_block_size} to {config.block_size}") fabric.print( f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" ) @@ -232,8 +217,8 @@ def main( mark_only_lora_as_trainable(model) # Let other layers be trainable - if config.longlora_trainable_params != "": - trainable_params = set(config.longlora_trainable_params.strip().split(",")) + if longlora.use_longlora and longlora.trainable_params != "": + trainable_params = set(longlora.trainable_params.strip().split(",")) for n, p in model.named_parameters(): if any(trainable_p_name in n for trainable_p_name in trainable_params): p.requires_grad = True @@ -275,6 +260,7 @@ def main( out_dir, train, eval, + longlora, data, ) fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") @@ -290,7 +276,7 @@ def main( # Save the final LoRA checkpoint at the end of training save_path = out_dir / "final" / "lit_model.pth.lora" save_path.parent.mkdir(parents=True, exist_ok=True) - save_lora_checkpoint(fabric, model, save_path) + save_lora_checkpoint(fabric, model, save_path, longlora=longlora) if fabric.global_rank == 0: # Copy checkpoint files from original checkpoint dir copy_config_files(checkpoint_dir, save_path.parent) @@ -311,20 +297,12 @@ def fit( out_dir: Path, train: 
TrainArgs, eval: EvalArgs, + longlora: LongLoraArgs, data: DataModule, ) -> None: tokenizer = Tokenizer(checkpoint_dir) - pad_multiple_of = data.pad_multiple_of or 1 - if train.get_longest_seq_length: - longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) - longest_seq_length = find_multiple( - min(longest_seq_length, train.max_seq_length or float("inf")), pad_multiple_of - ) - else: - longest_seq_length = find_multiple( - min(model.max_seq_length, train.max_seq_length or float("inf")), pad_multiple_of - ) - model.max_seq_length = longest_seq_length + longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) + model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" @@ -411,7 +389,7 @@ def fit( if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0: checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.lora" checkpoint_file.parent.mkdir(parents=True, exist_ok=True) - save_lora_checkpoint(fabric, model, checkpoint_file) + save_lora_checkpoint(fabric, model, checkpoint_file, longlora=longlora) if fabric.global_rank == 0: copy_config_files(checkpoint_dir, checkpoint_file.parent) save_hyperparameters(setup, checkpoint_file.parent) @@ -494,7 +472,7 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: return longest_seq_length, longest_seq_ix -def save_lora_checkpoint(fabric: L.Fabric, model: GPT, file_path: Path) -> None: +def save_lora_checkpoint(fabric: L.Fabric, model: GPT, file_path: Path, longlora: LongLoraArgs) -> None: fabric.print(f"Saving LoRA weights to {str(file_path)!r}") fabric.save( file_path, @@ -502,10 +480,10 @@ def save_lora_checkpoint(fabric: L.Fabric, model: GPT, file_path: Path) -> None: filter={ "model": ( lora_filter - if model.config.longlora_context_length is None + if not longlora.use_longlora else partial( longlora_filter, - additional_weights=model.config.longlora_trainable_params.strip().split(","), + additional_weights=longlora.trainable_params.strip().split(","), ) ) }, @@ -530,6 +508,21 @@ def validate_args(train: TrainArgs, eval: EvalArgs) -> None: raise ValueError("\n".join(issues)) +def validate_longlora_args(config: Config, longlora: LongLoraArgs): + if longlora.use_longlora: + if longlora.context_length <= config.block_size: + warnings.warn( + f"LongLora is disabled because the LongLora context length ({longlora.context_length}) " + f"is less than the model original block size {config.block_size}. " + ) + longlora.use_longlora = False + elif longlora.context_length % longlora.n_groups != 0: + raise ValueError( + f"LongLora context length ({longlora.context_length}) must be a multiple of the number of groups " + f"({longlora.n_groups})." 
+ ) + + if __name__ == "__main__": torch.set_float32_matmul_precision("high") diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index 6928e9b5fb..ef7ac3b0b4 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -138,9 +138,13 @@ def main( if (hyperparams_dir := (checkpoint_dir / "hyperparameters.yaml")).is_file(): with open(hyperparams_dir, "r", encoding="utf-8") as hparams_file: hparams = yaml.safe_load(hparams_file) - longlora_context_length = hparams.get("longlora_context_length", config.block_size) + longlora_cfg = hparams.get("longlora", None) + use_longlora = False + if longlora_cfg is not None: + use_longlora = longlora_cfg.get("use_longlora", False) + longlora_context_length = longlora_cfg.get("context_length", config.block_size) else: - longlora_context_length = config.block_size + use_longlora = False checkpoint_path = checkpoint_dir / "lit_model.pth" @@ -157,13 +161,13 @@ def main( fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) t0 = time.perf_counter() with fabric.init_module(empty_init=True): - if longlora_context_length is not None and longlora_context_length > config.block_size: + if use_longlora and longlora_context_length > config.block_size: old_block_size = config.block_size config.block_size = longlora_context_length old_rope_condense_ratio = config.rope_condense_ratio config.rope_condense_ratio = longlora_context_length / old_block_size fabric.print( - f"The model context length has been increased from {old_block_size} to {config.longlora_context_length}" + f"The model context length has been increased from {old_block_size} to {config.block_size}" ) fabric.print( f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" diff --git a/litgpt/generate/full.py b/litgpt/generate/full.py index d8466fc802..1d12b23a95 100644 --- a/litgpt/generate/full.py +++ b/litgpt/generate/full.py @@ -64,9 +64,13 @@ def main( if (hyperparams_dir := (checkpoint_dir / "hyperparameters.yaml")).is_file(): with open(hyperparams_dir, "r", encoding="utf-8") as hparams_file: hparams = yaml.safe_load(hparams_file) - longlora_context_length = hparams.get("longlora_context_length", config.block_size) + longlora_cfg = hparams.get("longlora", None) + use_longlora = False + if longlora_cfg is not None: + use_longlora = longlora_cfg.get("use_longlora", False) + longlora_context_length = longlora_cfg.get("context_length", config.block_size) else: - longlora_context_length = config.block_size + use_longlora = False checkpoint_path = finetuned_path @@ -83,13 +87,13 @@ def main( fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) t0 = time.perf_counter() with fabric.init_module(empty_init=True): - if longlora_context_length is not None and longlora_context_length > config.block_size: + if use_longlora and longlora_context_length > config.block_size: old_block_size = config.block_size config.block_size = longlora_context_length old_rope_condense_ratio = config.rope_condense_ratio config.rope_condense_ratio = longlora_context_length / old_block_size fabric.print( - f"The model context length has been increased from {old_block_size} to {config.longlora_context_length}" + f"The model context length has been increased from {old_block_size} to {config.block_size}" ) fabric.print( f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" diff --git a/litgpt/generate/sequentially.py 
b/litgpt/generate/sequentially.py index 41d904c809..5c5fa7d726 100644 --- a/litgpt/generate/sequentially.py +++ b/litgpt/generate/sequentially.py @@ -163,9 +163,13 @@ def main( if (hyperparams_dir := (checkpoint_dir / "hyperparameters.yaml")).is_file(): with open(hyperparams_dir, "r", encoding="utf-8") as hparams_file: hparams = yaml.safe_load(hparams_file) - longlora_context_length = hparams.get("longlora_context_length", config.block_size) + longlora_cfg = hparams.get("longlora", None) + use_longlora = False + if longlora_cfg is not None: + use_longlora = longlora_cfg.get("use_longlora", False) + longlora_context_length = longlora_cfg.get("context_length", config.block_size) else: - longlora_context_length = config.block_size + use_longlora = False checkpoint_path = checkpoint_dir / "lit_model.pth" @@ -180,13 +184,13 @@ def main( # which means that the weights will get quantized on cuda:0 on checkpoint load. we need to load and then convert # still, use init_tensor for the precision with fabric.init_tensor(), torch.device("meta"): - if longlora_context_length is not None and longlora_context_length > config.block_size: + if use_longlora and longlora_context_length > config.block_size: old_block_size = config.block_size config.block_size = longlora_context_length old_rope_condense_ratio = config.rope_condense_ratio config.rope_condense_ratio = longlora_context_length / old_block_size fabric.print( - f"The model context length has been increased from {old_block_size} to {config.longlora_context_length}" + f"The model context length has been increased from {old_block_size} to {config.block_size}" ) fabric.print( f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" diff --git a/litgpt/generate/tp.py b/litgpt/generate/tp.py index 92da3bc43f..778391744f 100644 --- a/litgpt/generate/tp.py +++ b/litgpt/generate/tp.py @@ -142,9 +142,13 @@ def main( if (hyperparams_dir := (checkpoint_dir / "hyperparameters.yaml")).is_file(): with open(hyperparams_dir, "r", encoding="utf-8") as hparams_file: hparams = yaml.safe_load(hparams_file) - longlora_context_length = hparams.get("longlora_context_length", config.block_size) + longlora_cfg = hparams.get("longlora", None) + use_longlora = False + if longlora_cfg is not None: + use_longlora = longlora_cfg.get("use_longlora", False) + longlora_context_length = longlora_cfg.get("context_length", config.block_size) else: - longlora_context_length = config.block_size + use_longlora = False model_file = "lit_model.pth" checkpoint_path = checkpoint_dir / model_file @@ -160,13 +164,13 @@ def main( # which means that the weights will get quantized on cuda:0 on checkpoint load. 
we need to load and then convert # still, use init_tensor for the precision with fabric.init_tensor(), torch.device("meta"): - if longlora_context_length is not None and longlora_context_length > config.block_size: + if use_longlora and longlora_context_length > config.block_size: old_block_size = config.block_size config.block_size = longlora_context_length old_rope_condense_ratio = config.rope_condense_ratio config.rope_condense_ratio = longlora_context_length / old_block_size fabric.print( - f"The model context length has been increased from {old_block_size} to {config.longlora_context_length}" + f"The model context length has been increased from {old_block_size} to {config.block_size}" ) fabric.print( f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" diff --git a/litgpt/lora.py b/litgpt/lora.py index 60c9db2c60..bba6bb2ea8 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -626,9 +626,6 @@ def __init__(self, config: Config) -> None: self.config = config - # LongLora - self._longlora_available = self.config.longlora_n_groups is not None and self.config.longlora_n_groups > 0 - def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = { diff --git a/litgpt/model.py b/litgpt/model.py index 1c25666de1..7b4f419142 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -202,9 +202,6 @@ def __init__(self, config: Config) -> None: self.config = config - # LongLora - self._longlora_available = self.config.longlora_n_groups is not None and self.config.longlora_n_groups > 0 - def forward( self, x: torch.Tensor, @@ -215,9 +212,11 @@ def forward( ) -> torch.Tensor: B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) - if input_pos is None and self._longlora_available: + if input_pos is None and self.config.use_longlora: if T % self.config.longlora_n_groups != 0: - raise ValueError(f"sequence length {T} should be divisible by group size {longlora_group_size}.") + raise ValueError( + f"sequence length {T} should be divisible by the number of groups {self.config.longlora_n_groups}." 
+ ) longlora_group_size = T // self.config.longlora_n_groups else: longlora_group_size = 0 @@ -259,20 +258,20 @@ def forward( v = roll_and_group(v, B, T, longlora_group_size, v.shape[1], self.config.head_size) y = self.scaled_dot_product_attention(q, k, v, mask) - y_cloned = y if input_pos is None and longlora_group_size > 0: # shift back and unroll n_heads = y.shape[2] - y_cloned = y.clone() - y_cloned = y_cloned.reshape(B, T, n_heads, self.config.head_size) # (B, T, nh, hs) - y_cloned[:, :, n_heads // 2 :] = y_cloned[:, :, n_heads // 2 :].roll(longlora_group_size // 2, dims=1) + y = y.reshape(B, T, n_heads, self.config.head_size) # (B, T, nh, hs) + y0, y1 = y.split(n_heads // 2, dim=2) + y1 = y1.roll(longlora_group_size // 2, dims=1) + y = torch.cat((y0, y1), dim=2) # re-assemble all head outputs side by side - y_cloned = y_cloned.reshape(B, T, self.config.head_size * self.config.n_head) + y = y.reshape(B, T, self.config.head_size * self.config.n_head) # output projection - return self.proj(y_cloned) + return self.proj(y) def scaled_dot_product_attention( self, @@ -311,8 +310,17 @@ def build_kv_cache( return KVCache(k_shape, v_shape, device=device, dtype=dtype) -def roll_and_group(qkv, bsz, q_len, group_size, num_heads, head_dim): - qkv[:, num_heads // 2 :] = qkv[:, num_heads // 2 :].roll(-group_size // 2, dims=2) +def roll_and_group( + qkv: torch.Tensor, bsz: int, q_len: int, group_size: int, num_heads: int, head_dim: int +) -> torch.Tensor: + # Split, roll and recompose to avoid the following error: + # RuntimeError: Output 0 of SliceBackward0 is a view and is being modified inplace. + # This view is the output of a function that returns multiple views. + # Such functions do not allow the output views to be modified inplace. + # You should replace the inplace operation by an out-of-place one. 
+ qkv0, qkv1 = qkv.split(num_heads // 2, dim=1) + qkv1 = qkv1.roll(-group_size // 2, dims=2) + qkv = torch.cat((qkv0, qkv1), dim=1) qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2) return qkv From b38f0cec8dc34bcad159c0a2135c3d7f6b2fbbbb Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 25 Apr 2024 13:55:55 +0200 Subject: [PATCH 3/7] Add tests --- tests/test_full.py | 88 ++++++++++++++++++++++++++++++++++++++++++++++ tests/test_lora.py | 66 +++++++++++++++++++++++++++++++++- 2 files changed, 153 insertions(+), 1 deletion(-) diff --git a/tests/test_full.py b/tests/test_full.py index 74bc10f22e..993b50bb7a 100644 --- a/tests/test_full.py +++ b/tests/test_full.py @@ -12,6 +12,7 @@ import litgpt.finetune.full as module from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca +from litgpt.utils import CLI @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) @@ -69,3 +70,90 @@ def test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path): assert f"Resuming training from {out_dir / 'step-000006' / 'lit_model.pth'}" in logs assert logs.count("(step)") == 2 assert out_dir / "step-000008" in set(out_dir.iterdir()) + + +@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) +def test_full_longlora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path): + model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8) + (fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config)) + monkeypatch.setattr(module, "load_checkpoint", Mock()) + + tokenizer_mock = Mock() + tokenizer_mock.return_value = tokenizer_mock + tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1]) + monkeypatch.setattr(module, "Tokenizer", tokenizer_mock) + + out_dir = tmp_path / "out" + setup_kwargs = dict( + data=Alpaca(download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0), + checkpoint_dir=fake_checkpoint_dir, + out_dir=out_dir, + precision="32-true", + train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1), + eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1), + ) + stdout = StringIO() + with redirect_stdout(stdout), mock.patch( + # Needed to save_hyperparameters function saves correctly LongLora params to be used when resuming + "sys.argv", + [ + "full.py", + "--data=litgpt.data.Alpaca", + "--data.download_dir=" + str(alpaca_path.parent), + "--data.file_name=" + str(alpaca_path.name), + "--data.val_split_fraction=0.5", + "--data.num_workers=0", + "--checkpoint_dir=" + str(fake_checkpoint_dir), + "--out_dir=" + str(out_dir), + "--precision=32-true", + "--train.global_batch_size=1", + "--train.save_interval=2", + "--train.epochs=1", + "--train.max_steps=6", + "--train.micro_batch_size=1", + "--eval.interval=2", + "--eval.max_iters=2", + "--eval.max_new_tokens=1", + "--longlora.use_longlora=True", + "--longlora.n_groups=4", + "--longlora.context_length=256", + ], + ): + CLI(module.setup) + + out_dir_contents = set(os.listdir(out_dir)) + checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"} + assert checkpoint_dirs.issubset(out_dir_contents) + assert all((out_dir / p).is_dir() for p in checkpoint_dirs) + for checkpoint_dir in checkpoint_dirs: + assert set(os.listdir(out_dir / checkpoint_dir)) == { + "lit_model.pth", + "model_config.yaml", + "tokenizer_config.json", + "tokenizer.json", + "hyperparameters.yaml", + "prompt_style.yaml", + } + assert (out_dir / 
"logs" / "csv" / "version_0" / "metrics.csv").is_file() + + logs = stdout.getvalue() + assert logs.count("(step)") == 6 + assert logs.count("val loss") == 4 # 3 validations + 1 final validation + assert logs.count("Final evaluation") == 1 + assert "of trainable parameters: 1,888" in logs + assert "The model context length has been increased from 128 to 256" in logs + assert "The 'rope_condense_ratio' has been adapted from 1 to 2.0" in logs + + # Resume training and do 2 steps more + setup_kwargs["train"].max_steps = 8 + setup_kwargs["resume"] = True + stdout = StringIO() + with redirect_stdout(stdout), mock.patch("sys.argv", ["full.py"]): + module.setup(**setup_kwargs) + logs = stdout.getvalue() + assert f"Resuming training from {out_dir / 'step-000006' / 'lit_model.pth'}" in logs + assert logs.count("(step)") == 2 + assert out_dir / "step-000008" in set(out_dir.iterdir()) + assert "The model context length has been increased from 128 to 256" in logs + assert "The 'rope_condense_ratio' has been adapted from 1 to 2.0" in logs + diff --git a/tests/test_lora.py b/tests/test_lora.py index f8764c39bb..6a47defc2a 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -20,7 +20,7 @@ import litgpt.config as config_module import litgpt.finetune.lora as module -from litgpt.args import EvalArgs, TrainArgs +from litgpt.args import EvalArgs, LongLoraArgs, TrainArgs from litgpt.data import Alpaca from litgpt.lora import GPT as LoRAGPT from litgpt.lora import CausalSelfAttention as LoRACausalSelfAttention @@ -226,6 +226,70 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path): assert "of trainable parameters: 512" in logs +@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) +def test_longlora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path): + model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8) + (fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config)) + monkeypatch.setattr(module, "load_checkpoint", Mock()) + monkeypatch.setattr(module, "merge_lora", Mock()) + + tokenizer_mock = Mock() + tokenizer_mock.return_value = tokenizer_mock + tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1]) + monkeypatch.setattr(module, "Tokenizer", tokenizer_mock) + + out_dir = tmp_path / "out" + stdout = StringIO() + with redirect_stdout(stdout), mock.patch("sys.argv", ["lora.py"]): + module.setup( + data=Alpaca( + download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0 + ), + checkpoint_dir=fake_checkpoint_dir, + out_dir=out_dir, + precision="32-true", + train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1), + eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1), + longlora=LongLoraArgs(use_longlora=True, context_length=256), + ) + + out_dir_contents = set(os.listdir(out_dir)) + checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"} + assert checkpoint_dirs.issubset(out_dir_contents) + assert all((out_dir / p).is_dir() for p in checkpoint_dirs) + for checkpoint_dir in checkpoint_dirs: + assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == { + "lit_model.pth.lora", + "model_config.yaml", + "tokenizer_config.json", + "tokenizer.json", + "hyperparameters.yaml", + "prompt_style.yaml", + } + lora_ckpt = torch.load(out_dir / checkpoint_dir / "lit_model.pth.lora")["model"] + lora_ckpt_keys = lora_ckpt.keys() + assert all( + param in lora_ckpt_keys + for param in [ + "transformer.wte.weight", 
+ "transformer.h.0.norm_1.weight", + "transformer.h.0.norm_2.weight", + "transformer.h.1.norm_1.weight", + "transformer.h.1.norm_2.weight", + "transformer.ln_f.weight", + ] + ) + assert (out_dir / "logs" / "csv" / "version_0" / "metrics.csv").is_file() + + logs = stdout.getvalue() + assert logs.count("(step)") == 6 + assert logs.count("val loss") == 4 # 3 validations + 1 final validation + assert logs.count("Final evaluation") == 1 + assert "of trainable parameters: 656" in logs + assert "The model context length has been increased from 128 to 256" in logs + assert "The 'rope_condense_ratio' has been adapted from 1 to 2.0" in logs + + def test_lora_init_when_linear_overridden(): class MyLinear(torch.nn.Linear): def __init__(self, *args, **kwargs): From 7bfe9efdc98067179e71792e1000730362d77aa7 Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 25 Apr 2024 13:56:00 +0200 Subject: [PATCH 4/7] Add configs --- config_hub/finetune/llama-2-7b/longlora.yaml | 135 +++++++++++++++++++ config_hub/finetune/mistral-7b/longlora.yaml | 135 +++++++++++++++++++ 2 files changed, 270 insertions(+) create mode 100644 config_hub/finetune/llama-2-7b/longlora.yaml create mode 100644 config_hub/finetune/mistral-7b/longlora.yaml diff --git a/config_hub/finetune/llama-2-7b/longlora.yaml b/config_hub/finetune/llama-2-7b/longlora.yaml new file mode 100644 index 0000000000..28cf3a8f46 --- /dev/null +++ b/config_hub/finetune/llama-2-7b/longlora.yaml @@ -0,0 +1,135 @@ + +# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) +checkpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf + +# Directory in which to save checkpoints and logs. (type: , default: out/lora) +out_dir: out/finetune/lora-llama2-7b + +# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) +precision: bf16-true + +# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) +quantize: + +# How many devices/GPUs to use. (type: Union[int, str], default: 1) +devices: 1 + +# The LoRA rank. (type: int, default: 8) +lora_r: 8 + +# The LoRA alpha. (type: int, default: 16) +lora_alpha: 16 + +# The LoRA dropout value. (type: float, default: 0.05) +lora_dropout: 0.0 + +# Whether to apply LoRA to the query weights in attention. (type: bool, default: True) +lora_query: true + +# Whether to apply LoRA to the key weights in attention. (type: bool, default: False) +lora_key: true + +# Whether to apply LoRA to the value weights in attention. (type: bool, default: True) +lora_value: true + +# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) +lora_projection: true + +# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) +lora_mlp: false + +# Whether to apply LoRA to output head in GPT. (type: bool, default: False) +lora_head: false + +# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. +data: + class_path: litgpt.data.Alpaca2k + init_args: + mask_prompt: false + prompt_style: alpaca + ignore_index: -100 + seed: 42 + num_workers: 4 + +# Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details +train: + + # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) + save_interval: 200 + + # Number of iterations between logging calls (type: int, default: 1) + log_interval: 1 + + # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) + global_batch_size: 8 + + # Number of samples per data-parallel rank (type: int, default: 4) + micro_batch_size: 2 + + # Number of iterations with learning rate warmup active (type: int, default: 100) + lr_warmup_steps: 10 + + # Number of epochs to train on (type: Optional[int], default: 5) + epochs: 4 + + # Total number of tokens to train on (type: Optional[int], default: null) + max_tokens: + + # Limits the number of optimizer steps to run. (type: Optional[int], default: null) + max_steps: + + # Limits the length of samples. Off by default (type: Optional[int], default: null) + max_seq_length: 512 + + # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) + tie_embeddings: + + # (type: float, default: 0.0003) + learning_rate: 0.0002 + + # (type: float, default: 0.02) + weight_decay: 0.0 + + # (type: float, default: 0.9) + beta1: 0.9 + + # (type: float, default: 0.95) + beta2: 0.95 + + # (type: Optional[float], default: null) + max_norm: + + # (type: float, default: 6e-05) + min_lr: 6.0e-05 + +# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details +eval: + + # Number of optimizer steps between evaluation calls (type: int, default: 100) + interval: 100 + + # Number of tokens to generate (type: Optional[int], default: 100) + max_new_tokens: 100 + + # Number of iterations (type: int, default: 100) + max_iters: 100 + +# LongLoRA-related arguments. See ``litgpt.args.LongLoRAArgs`` for details +longlora: + # Whether to use LongLoRA. (type: bool, default: false) + use_longlora: true + + # The enlarged context length for LongLoRA. (type: int, default: 8192) + context_length: 8192 + + # The number of groups to split the sequence into. (type: int, default: 4) + n_groups: 4 + + # The additional trainable parameters for LongLoRA. (type: str, default: "wte,norm,ln") + trainable_params: "wte,norm,ln" + +# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: csv) +logger_name: csv + +# The random seed to use for reproducibility. (type: int, default: 1337) +seed: 1337 diff --git a/config_hub/finetune/mistral-7b/longlora.yaml b/config_hub/finetune/mistral-7b/longlora.yaml new file mode 100644 index 0000000000..6bcc98f600 --- /dev/null +++ b/config_hub/finetune/mistral-7b/longlora.yaml @@ -0,0 +1,135 @@ + +# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) +checkpoint_dir: checkpoints/mistralai/Mistral-7B-v0.1 + +# Directory in which to save checkpoints and logs. (type: , default: out/lora) +out_dir: out/finetune/lora-mistral-7b + +# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) +precision: bf16-true + +# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) +quantize: + +# How many devices/GPUs to use. (type: Union[int, str], default: 1) +devices: 1 + +# The LoRA rank. 
(type: int, default: 8) +lora_r: 8 + +# The LoRA alpha. (type: int, default: 16) +lora_alpha: 16 + +# The LoRA dropout value. (type: float, default: 0.05) +lora_dropout: 0.0 + +# Whether to apply LoRA to the query weights in attention. (type: bool, default: True) +lora_query: true + +# Whether to apply LoRA to the key weights in attention. (type: bool, default: False) +lora_key: true + +# Whether to apply LoRA to the value weights in attention. (type: bool, default: True) +lora_value: true + +# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) +lora_projection: true + +# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) +lora_mlp: false + +# Whether to apply LoRA to output head in GPT. (type: bool, default: False) +lora_head: false + +# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. +data: + class_path: litgpt.data.Alpaca2k + init_args: + mask_prompt: false + prompt_style: alpaca + ignore_index: -100 + seed: 42 + num_workers: 4 + +# Training-related arguments. See ``litgpt.args.TrainArgs`` for details +train: + + # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) + save_interval: 200 + + # Number of iterations between logging calls (type: int, default: 1) + log_interval: 1 + + # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) + global_batch_size: 8 + + # Number of samples per data-parallel rank (type: int, default: 4) + micro_batch_size: 2 + + # Number of iterations with learning rate warmup active (type: int, default: 100) + lr_warmup_steps: 10 + + # Number of epochs to train on (type: Optional[int], default: 5) + epochs: 4 + + # Total number of tokens to train on (type: Optional[int], default: null) + max_tokens: + + # Limits the number of optimizer steps to run. (type: Optional[int], default: null) + max_steps: + + # Limits the length of samples. Off by default (type: Optional[int], default: null) + max_seq_length: 512 + + # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) + tie_embeddings: + + # (type: float, default: 0.0003) + learning_rate: 0.0002 + + # (type: float, default: 0.02) + weight_decay: 0.0 + + # (type: float, default: 0.9) + beta1: 0.9 + + # (type: float, default: 0.95) + beta2: 0.95 + + # (type: Optional[float], default: null) + max_norm: + + # (type: float, default: 6e-05) + min_lr: 6.0e-05 + +# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details +eval: + + # Number of optimizer steps between evaluation calls (type: int, default: 100) + interval: 100 + + # Number of tokens to generate (type: Optional[int], default: 100) + max_new_tokens: 100 + + # Number of iterations (type: int, default: 100) + max_iters: 100 + +# LongLoRA-related arguments. See ``litgpt.args.LongLoRAArgs`` for details +longlora: + # Whether to use LongLoRA. (type: bool, default: false) + use_longlora: true + + # The enlarged context length for LongLoRA. (type: int, default: 8192) + context_length: 8192 + + # The number of groups to split the sequence into. (type: int, default: 4) + n_groups: 4 + + # The additional trainable parameters for LongLoRA. (type: str, default: "wte,norm,ln") + trainable_params: "wte,norm,ln" + +# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: csv) +logger_name: csv + +# The random seed to use for reproducibility. 
(type: int, default: 1337) +seed: 1337 From 2dfa7a5676371cac95740ad2c1a939ffb93a85dc Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 25 Apr 2024 14:17:30 +0200 Subject: [PATCH 5/7] Fix setting max_seq_length when using longlora --- litgpt/finetune/full.py | 4 +++- litgpt/finetune/lora.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py index 3c7c325021..64dd0e6187 100644 --- a/litgpt/finetune/full.py +++ b/litgpt/finetune/full.py @@ -232,7 +232,9 @@ def fit( scheduler = state["scheduler"] tokenizer = Tokenizer(checkpoint_dir) longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) - model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) + if longlora.use_longlora: + longest_seq_length = find_multiple(longest_seq_length, longlora.n_groups) + model.max_seq_length = longest_seq_length fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py index 56f131cb2d..bd1f044787 100644 --- a/litgpt/finetune/lora.py +++ b/litgpt/finetune/lora.py @@ -302,7 +302,10 @@ def fit( ) -> None: tokenizer = Tokenizer(checkpoint_dir) longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) - model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) + longest_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) + if longlora.use_longlora: + longest_seq_length = find_multiple(longest_seq_length, longlora.n_groups) + model.max_seq_length = longest_seq_length fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" From f7c6971a39febee3cb9b1a0b468e00bf2480be83 Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 25 Apr 2024 18:51:15 +0200 Subject: [PATCH 6/7] Fix formatting --- litgpt/args.py | 2 +- litgpt/deploy/serve.py | 21 +++++++++------------ litgpt/eval/evaluate.py | 7 ++++--- litgpt/finetune/full.py | 24 +++--------------------- litgpt/finetune/lora.py | 39 +++++++-------------------------------- litgpt/lora.py | 17 ++++------------- litgpt/model.py | 6 +----- 7 files changed, 29 insertions(+), 87 deletions(-) diff --git a/litgpt/args.py b/litgpt/args.py index bbeaaf738c..4924875922 100644 --- a/litgpt/args.py +++ b/litgpt/args.py @@ -75,7 +75,7 @@ class EvalArgs: @dataclass class LongLoraArgs: - """GaLore-related arguments""" + """LongLora-related arguments""" use_longlora: bool = False """Whether to enable LongLora.""" diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py index ad9c863a79..4a26e0b14f 100644 --- a/litgpt/deploy/serve.py +++ b/litgpt/deploy/serve.py @@ -16,14 +16,12 @@ class SimpleLitAPI(LitAPI): - def __init__( - self, - checkpoint_dir: Path, - precision: Optional[str] = None, - temperature: float = 0.8, - top_k: int = 50, - max_new_tokens: int = 50, - ) -> None: + def __init__(self, + checkpoint_dir: Path, + precision: Optional[str] = None, + temperature: float = 0.8, + top_k: int = 50, + max_new_tokens: int = 50) -> None: super().__init__() self.checkpoint_dir = checkpoint_dir @@ -81,7 +79,7 @@ def predict(self, inputs: torch.Tensor) -> Any: max_returned_tokens, temperature=self.temperature, top_k=self.top_k, - 
eos_id=self.tokenizer.eos_id, + eos_id=self.tokenizer.eos_id ) for block in self.model.transformer.h: @@ -129,10 +127,9 @@ def run_server( temperature=temperature, top_k=top_k, max_new_tokens=max_new_tokens, - ), + ), accelerator=accelerator, - devices=devices, - ) + devices=devices) server.run(port=port) diff --git a/litgpt/eval/evaluate.py b/litgpt/eval/evaluate.py index 2612730ecc..78e0ed0f59 100644 --- a/litgpt/eval/evaluate.py +++ b/litgpt/eval/evaluate.py @@ -19,7 +19,9 @@ def prepare_results(results, save_filepath, print_results=True): if "groups" in results: print(make_table(results, "groups")) - json_result = json.dumps(results, indent=2, ensure_ascii=False) + json_result = json.dumps( + results, indent=2, ensure_ascii=False + ) save_filepath.open("w", encoding="utf-8").write(json_result) @@ -60,7 +62,6 @@ def convert_and_evaluate( if tasks is None: from lm_eval.tasks import TaskManager - taskm = TaskManager() print("\n".join(taskm.task_index.keys())) print( @@ -83,7 +84,7 @@ def convert_and_evaluate( out_dir.mkdir(parents=True, exist_ok=True) save_filepath = out_dir / Path("results.json") if save_filepath is None else Path(save_filepath) - config_filepath = checkpoint_dir / "model_config.yaml" + config_filepath = checkpoint_dir/"model_config.yaml" with open(config_filepath, encoding="utf-8") as f: config_dict = yaml.safe_load(f) diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py index 64dd0e6187..ccca8f7297 100644 --- a/litgpt/finetune/full.py +++ b/litgpt/finetune/full.py @@ -178,20 +178,7 @@ def main( load_checkpoint(fabric, state["model"], checkpoint_path) train_time = time.perf_counter() - fit( - fabric, - state, - train_dataloader, - val_dataloader, - devices, - resume, - checkpoint_dir, - out_dir, - train, - eval, - longlora, - data, - ) + fit(fabric, state, train_dataloader, val_dataloader, devices, resume, checkpoint_dir,out_dir, train, eval, longlora, data) fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") @@ -227,7 +214,7 @@ def fit( longlora: LongLoraArgs, data: DataModule, ) -> None: - model: GPT = state["model"] + model = state["model"] optimizer = state["optimizer"] scheduler = state["scheduler"] tokenizer = Tokenizer(checkpoint_dir) @@ -379,12 +366,7 @@ def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): def get_dataloaders( fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs, pad_multiple_of: Optional[int] = None ) -> Tuple[DataLoader, DataLoader]: - data.connect( - tokenizer=tokenizer, - batch_size=train.micro_batch_size, - max_seq_length=train.max_seq_length, - pad_multiple_of=pad_multiple_of, - ) + data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length, pad_multiple_of=pad_multiple_of) with fabric.rank_zero_first(): data.prepare_data() data.setup() diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py index bd1f044787..b3a063ac65 100644 --- a/litgpt/finetune/lora.py +++ b/litgpt/finetune/lora.py @@ -129,22 +129,13 @@ def setup( ) precision = precision or get_default_supported_precision(training=True) - logger = choose_logger( - logger_name, - out_dir, - name=f"finetune-{config.name}", - log_interval=train.log_interval, - ) + logger = choose_logger(logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval) plugins = None if quantize is not None and quantize.startswith("bnb."): if "mixed" in 
precision: raise ValueError("Quantization and mixed precision is not supported.") - dtype = { - "16-true": torch.float16, - "bf16-true": torch.bfloat16, - "32-true": torch.float32, - }[precision] + dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] plugins = BitsandbytesPrecision(quantize[4:], dtype) precision = None @@ -164,13 +155,7 @@ def setup( else: strategy = "auto" - fabric = L.Fabric( - devices=devices, - strategy=strategy, - precision=precision, - loggers=logger, - plugins=plugins, - ) + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins) fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, longlora) @@ -236,10 +221,7 @@ def main( else: optimizer_cls = torch.optim.AdamW optimizer = optimizer_cls( - trainable_params, - lr=train.learning_rate, - weight_decay=train.weight_decay, - betas=(train.beta1, train.beta2), + trainable_params, lr=train.learning_rate, weight_decay=train.weight_decay, betas=(train.beta1, train.beta2) ) optimizer = fabric.setup_optimizers(optimizer) scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) @@ -352,10 +334,7 @@ def fit( loss = running_loss.compute().item() # expensive device-to-host synchronization t1 = time.perf_counter() throughput.update( - time=t1 - total_t0, - batches=iter_num, - samples=iter_num * train.micro_batch_size, - lengths=total_lengths, + time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths ) throughput.compute_and_log(step=iter_num) metrics = { @@ -430,11 +409,7 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E # do not set `max_seq_length=max_returned_token` because memory is not a concern here model.set_kv_cache(batch_size=1) output = generate( - model, - encoded, - max_returned_tokens=len(encoded) + eval.max_new_tokens, - temperature=0.8, - eos_id=tokenizer.eos_id, + model, encoded, max_returned_tokens=len(encoded) + eval.max_new_tokens, temperature=0.8, eos_id=tokenizer.eos_id ) model.clear_kv_cache() model.train() @@ -456,7 +431,7 @@ def get_dataloaders( tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length, - pad_multiple_of=pad_multiple_of, + pad_multiple_of=pad_multiple_of ) with fabric.rank_zero_first(): data.prepare_data() diff --git a/litgpt/lora.py b/litgpt/lora.py index bba6bb2ea8..30f8a6b2c2 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -335,9 +335,7 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor: result = x.new_zeros((*x.shape[:-1], self.linear.out_features)) # (64, 64, 384) result = result.view(-1, self.linear.out_features) # (4096, 384) result = result.index_copy( - 1, - torch.tensor(self.lora_ind, device=result.device), - x.reshape(-1, sum(self.qkv_shapes)), + 1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes)) ) # (4096, 256) return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1) # (64, 64, 384) @@ -373,8 +371,7 @@ def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T) weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1) return torch.cat( - [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], - dim=1, # (B, C_output', T) + [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1 # (B, 
C_output', T) ) # (B, C_output, T) def get_lora_AB(self) -> torch.Tensor: @@ -528,10 +525,7 @@ def __init__(self, config: Config) -> None: self.mask_cache: Optional[torch.Tensor] = None def forward( - self, - idx: torch.Tensor, - input_pos: Optional[torch.Tensor] = None, - lm_head_chunk_size: int = 0, + self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0 ) -> Union[torch.Tensor, List[torch.Tensor]]: T = idx.size(1) if self.max_seq_length < T: @@ -571,10 +565,7 @@ def _init_weights(self, module: nn.Module) -> None: def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" - mapping = { - "lm_head.weight": "lm_head.linear.weight", - "lm_head.bias": "lm_head.linear.bias", - } + mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"} state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) diff --git a/litgpt/model.py b/litgpt/model.py index 7b4f419142..9e0c131101 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -274,11 +274,7 @@ def forward( return self.proj(y) def scaled_dot_product_attention( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: Optional[torch.Tensor] = None, + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: scale = 1.0 / math.sqrt(self.config.head_size) y = torch.nn.functional.scaled_dot_product_attention( From 073b0275893a2b686f198888f0b8f7fde3323bc3 Mon Sep 17 00:00:00 2001 From: belerico Date: Wed, 29 May 2024 12:06:44 +0200 Subject: [PATCH 7/7] Update LongLora configs --- config_hub/finetune/llama-2-7b/longlora.yaml | 30 ++++++++++++-------- config_hub/finetune/mistral-7b/longlora.yaml | 30 ++++++++++++-------- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/config_hub/finetune/llama-2-7b/longlora.yaml b/config_hub/finetune/llama-2-7b/longlora.yaml index 28cf3a8f46..13a0dd0b16 100644 --- a/config_hub/finetune/llama-2-7b/longlora.yaml +++ b/config_hub/finetune/llama-2-7b/longlora.yaml @@ -84,18 +84,6 @@ train: # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: - # (type: float, default: 0.0003) - learning_rate: 0.0002 - - # (type: float, default: 0.02) - weight_decay: 0.0 - - # (type: float, default: 0.9) - beta1: 0.9 - - # (type: float, default: 0.95) - beta2: 0.95 - # (type: Optional[float], default: null) max_norm: @@ -133,3 +121,21 @@ logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 + +# Optimizer-related arguments +optimizer: + + class_path: torch.optim.AdamW + + init_args: + + # (type: float, default: 0.001) + lr: 0.0002 + + # (type: float, default: 0.01) + weight_decay: 0.0 + + # (type: tuple, default: (0.9,0.999)) + betas: + - 0.9 + - 0.95 \ No newline at end of file diff --git a/config_hub/finetune/mistral-7b/longlora.yaml b/config_hub/finetune/mistral-7b/longlora.yaml index 6bcc98f600..1369e1c459 100644 --- a/config_hub/finetune/mistral-7b/longlora.yaml +++ b/config_hub/finetune/mistral-7b/longlora.yaml @@ -84,18 +84,6 @@ train: # Whether to tie the embedding weights with the language modeling head weights. 
(type: Optional[bool], default: null) tie_embeddings: - # (type: float, default: 0.0003) - learning_rate: 0.0002 - - # (type: float, default: 0.02) - weight_decay: 0.0 - - # (type: float, default: 0.9) - beta1: 0.9 - - # (type: float, default: 0.95) - beta2: 0.95 - # (type: Optional[float], default: null) max_norm: @@ -133,3 +121,21 @@ logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 + +# Optimizer-related arguments +optimizer: + + class_path: torch.optim.AdamW + + init_args: + + # (type: float, default: 0.001) + lr: 0.0002 + + # (type: float, default: 0.01) + weight_decay: 0.0 + + # (type: tuple, default: (0.9,0.999)) + betas: + - 0.9 + - 0.95 \ No newline at end of file
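
The shift-and-group attention that PATCH 1/7 and 2/7 add to litgpt/model.py can be exercised in isolation. The sketch below mirrors the final form of roll_and_group and the post-attention un-shift in CausalSelfAttention.forward; the tensor sizes (B=2, T=16, 4 heads, 4 groups) are invented for illustration, and the snippet is a standalone approximation of the mechanism rather than code from the patches.

import torch

def roll_and_group(qkv, bsz, q_len, group_size, num_heads, head_dim):
    # Shift the second half of the heads by half a group, then fold the sequence
    # into bsz * (q_len // group_size) independent attention windows.
    qkv0, qkv1 = qkv.split(num_heads // 2, dim=1)
    qkv1 = qkv1.roll(-group_size // 2, dims=2)
    qkv = torch.cat((qkv0, qkv1), dim=1)
    return qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2)

B, T, n_heads, head_dim, n_groups = 2, 16, 4, 8, 4
group_size = T // n_groups  # each group attends over 4 tokens instead of 16

q, k, v = (torch.randn(B, n_heads, T, head_dim) for _ in range(3))
qg = roll_and_group(q, B, T, group_size, n_heads, head_dim)
kg = roll_and_group(k, B, T, group_size, n_heads, head_dim)
vg = roll_and_group(v, B, T, group_size, n_heads, head_dim)
assert qg.shape == (B * n_groups, n_heads, group_size, head_dim)

# Attention now runs per group, so its cost scales with group_size rather than T.
y = torch.nn.functional.scaled_dot_product_attention(qg, kg, vg, is_causal=True)

# Undo the grouping and roll the shifted half of the heads back, mirroring the
# post-attention step added to CausalSelfAttention.forward.
y = y.transpose(1, 2).reshape(B, T, n_heads, head_dim)
y0, y1 = y.split(n_heads // 2, dim=2)
y1 = y1.roll(group_size // 2, dims=1)
y = torch.cat((y0, y1), dim=2).reshape(B, T, n_heads * head_dim)
assert y.shape == (B, T, n_heads * head_dim)

Because half of the heads are rolled by half a group before grouping and rolled back afterwards, tokens near a group boundary still exchange information with the neighbouring group, while each scaled_dot_product_attention call only ever covers group_size tokens.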
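
The divisibility check in CausalSelfAttention.forward is also what motivates the pad_multiple_of collate argument and the find_multiple rounding added in PATCH 5/7. A minimal sketch of that rounding, assuming the usual round-up-to-a-multiple contract of litgpt.utils.find_multiple and a hypothetical longest sample of 187 tokens:

def find_multiple(n: int, k: int) -> int:
    # Round n up to the nearest multiple of k (same contract as litgpt.utils.find_multiple).
    assert k > 0
    if n % k == 0:
        return n
    return n + k - (n % k)

n_groups = 4
longest_seq_length = 187                     # hypothetical longest sample in the train split
padded = find_multiple(longest_seq_length, n_groups)
assert padded == 188 and padded % n_groups == 0
group_size = padded // n_groups              # each LongLora group then covers 47 tokens

With model.max_seq_length rounded this way and the batches padded to the same multiple by the collate function, T % longlora_n_groups == 0 holds for every training batch, so the grouped attention path never raises the divisibility error.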