diff --git a/README.md b/README.md
index f6e426b9fa..4bd3d8bf75 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,5 @@
-LitGPT
-
 # ⚡ LitGPT
@@ -20,11 +18,11 @@ Uses the latest state-of-the-art techniques:

Lightning AI • Models •
- Install •
- Get started •
- Evaluate •
+ Quick start •
+ Inference • Finetune • Pretrain •
+ Deploy • Features • Training recipes (YAML)

@@ -34,13 +32,13 @@ Uses the latest state-of-the-art techniques:   # Finetune, pretrain and deploy LLMs Lightning fast ⚡⚡ -LitGPT is a command-line tool designed to easily [finetune](#finetune-an-llm), [pretrain](#pretrain-an-llm), [evaluate](#use-an-llm), and deploy [20+ LLMs](#choose-from-20-llms) **on your own data**. It features highly-optimized [training recipes](#training-recipes) for the world's most powerful open-source large-language-models (LLMs). +LitGPT is a command-line tool designed to easily [finetune](#finetune-an-llm), [pretrain](#pretrain-an-llm), [evaluate](#use-an-llm), and [deploy](#deploy-an-llm) [20+ LLMs](#choose-from-20-llms) **on your own data**. It features highly-optimized [training recipes](#training-recipes) for the world's most powerful open-source large language models (LLMs). We reimplemented all model architectures and training recipes from scratch for 4 reasons: 1. Remove all abstraction layers and have single file implementations. 2. Guarantee Apache 2.0 compliance to enable enterprise use without limits. -3. Optimized each model architectural detail to maximize performance, reduce costs, and speed up training. +3. Optimized each model's architectural detail to maximize performance, reduce costs, and speed up training. 4. Highly-optimized [recipe configs](#training-recipes) we have tested at enterprise scale. --- @@ -50,6 +48,7 @@ LitGPT has 🤯 **custom, from-scratch implementations** of [20+ LLMs](tutorials | Model | Model size | Author | Reference | |----|----|----|----| +| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | | Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | | Mistral | 7B | Mistral AI | [Mistral website](https://mistral.ai/) | @@ -72,6 +71,7 @@ LitGPT has 🤯 **custom, from-scratch implementations** of [20+ LLMs](tutorials | Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | | Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | | Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 
2023](https://arxiv.org/abs/2307.09288) | +| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | | Mistral | 7B | Mistral AI | [Mistral website](https://mistral.ai/) | | Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) | @@ -114,21 +114,22 @@ pip install -e '.[all]' --- -# Get started +# Quick start After installing LitGPT, select the model and action you want to take on that model (finetune, pretrain, evaluate, deploy, etc...): ```bash # ligpt [action] [model] -litgpt download mistralai/Mistral-7B-Instruct-v0.2 -litgpt chat mistralai/Mistral-7B-Instruct-v0.2 -litgpt finetune mistralai/Mistral-7B-Instruct-v0.2 -litgpt pretrain mistralai/Mistral-7B-Instruct-v0.2 -litgpt serve mistralai/Mistral-7B-Instruct-v0.2 +litgpt download meta-llama/Meta-Llama-3-8B-Instruct +litgpt chat meta-llama/Meta-Llama-3-8B-Instruct +litgpt finetune meta-llama/Meta-Llama-3-8B-Instruct +litgpt pretrain meta-llama/Meta-Llama-3-8B-Instruct +litgpt serve meta-llama/Meta-Llama-3-8B-Instruct ```   -### Use an LLM +### Use an LLM for inference +Use LLMs for inference to test its chatting capabilities, run evaluations, or extract embeddings, etc... Here's an example showing how to use the Mistral 7B LLM. @@ -155,14 +156,20 @@ For more information, refer to the [download](tutorials/download_model_weights.m ### Finetune an LLM [Finetune](tutorials/finetune.md) a model to specialize it on your own custom dataset: + + Open In Studio + + +  + ```bash # 1) Download a pretrained model litgpt download --repo_id microsoft/phi-2 # 2) Finetune the model -curl -L https://huggingface.co/datasets/medalpaca/medical_meadow_health_advice/raw/main/medical_meadow_health_advice.json -o my_custom_dataset.json +curl -L https://huggingface.co/datasets/ksaw008/finance_alpaca/resolve/main/finance_alpaca.json -o my_custom_dataset.json -litgpt finetune lora \ +litgpt finetune \ --checkpoint_dir checkpoints/microsoft/phi-2 \ --data JSON \ --data.json_path my_custom_dataset.json \ @@ -174,9 +181,17 @@ litgpt chat \ --checkpoint_dir out/phi-2-lora/final ``` +  + ### Pretrain an LLM Train an LLM from scratch on your own data via pretraining: + +Open In Studio + + +  + ```bash mkdir -p custom_texts curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output custom_texts/book1.txt @@ -201,10 +216,19 @@ litgpt chat \ --checkpoint_dir out/custom-model/final ``` +  + ### Continue pretraining an LLM -This is another way of finetuning that specialize an already pretrained model by training on custom data: +This is another way of finetuning that specializes an already pretrained model by training on custom data: -``` + + +Open In Studio + + +  + +```bash mkdir -p custom_texts curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output custom_texts/book1.txt curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output custom_texts/book2.txt @@ -215,6 +239,7 @@ litgpt download --repo_id EleutherAI/pythia-160m # 2) Continue pretraining the model litgpt pretrain \ --model_name pythia-160m \ + --tokenizer_dir checkpoints/EleutherAI/pythia-160m \ --initial_checkpoint_dir checkpoints/EleutherAI/pythia-160m \ --data TextFiles \ --data.train_data_path "custom_texts/" \ @@ -228,6 +253,37 @@ litgpt chat \   +### Deploy an LLM +Once you're ready to deploy a finetuned LLM, run this command: + + + Open In Studio + + +  + +```bash +# locate the checkpoint to your finetuned or 
pretrained model and call the `serve` command: +litgpt serve --checkpoint_dir path/to/your/checkpoint/microsoft/phi-2 + +# Alternative: if you haven't finetuned, download any checkpoint to deploy it: +litgpt download --repo_id microsoft/phi-2 +litgpt serve --checkpoint_dir checkpoints/microsoft/phi-2 +``` + +Test the server in a separate terminal and integrate the model API into your AI product: +```python +# 3) Use the server (in a separate session) +import requests, json + response = requests.post( + "http://127.0.0.1:8000/predict", + json={"prompt": "Fix typos in the following sentence: Exampel input"} +) +print(response.json()["output"]) +``` + +  + > [!NOTE] > **[Read the full docs](tutorials/0_to_litgpt.md)**. @@ -267,7 +323,7 @@ Browse all training recipes [here](config_hub). ### Example ```bash -litgpt finetune lora \ +litgpt finetune \ --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml ``` @@ -422,7 +478,7 @@ seed: 1337 Override any parameter in the CLI: ```bash -litgpt finetune lora \ +litgpt finetune \ --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml \ --lora_r 4 ``` diff --git a/config_hub/finetune/README.md b/config_hub/finetune/README.md index fc82e0854b..2aaa2d3b79 100644 --- a/config_hub/finetune/README.md +++ b/config_hub/finetune/README.md @@ -22,6 +22,10 @@ For more information, see the [Dealing with out-of-memory (OOM) errors](../../tu | llama-2-7b/qlora.yaml | 7B | Alpaca 2k | 4 | 0.814 | 13.68 GB | 512 | 2 | bfloat16 | 45.68 min (A10G) | | llama-2-7b/full.yaml | 7B | Alpaca 2k | 1 | 0.941 | 26.81 GB | 512 | 4 | bfloat16 | 1.78 min (4xA100) | | | | | | | | | | | | +| llama-3-8b/lora.yaml | 8B | Alpaca 2k | 2 | 0.890 | 19.73 GB | 512 | 1 | bfloat16 | 14.80 min (A10G) | +| llama-3-8b/qlora.yaml | 8B | Alpaca 2k | 2 | 0.941 | 17.41 GB | 512 | 2 | bfloat16 | 22.34 min (A10G) | +| llama-3-8b/full.yaml | 8B | Alpaca 2k | 1 | 1.451 | 35.48 GB | 512 | 4 | bfloat16 | 2.14 min (4xA100) | +| | | | | | | | | | | | mistral-7b/lora.yaml (v0.1) | 7B | Alpaca 2k | 4 | 0.796 | 20.65 GB | 512 | 2 | bfloat16 | 31.04 min (1xA10G) | | mistral-7b/qlora.yaml (v0.1) | 7B | Alpaca 2k | 4 | 0.803 | 14.29 GB | 512 | 2 | bfloat16 | 44.69 min (1xA10G) | | | | | | | | | | | | diff --git a/config_hub/finetune/llama-3-8b/full.yaml b/config_hub/finetune/llama-3-8b/full.yaml new file mode 100644 index 0000000000..11aebcb155 --- /dev/null +++ b/config_hub/finetune/llama-3-8b/full.yaml @@ -0,0 +1,95 @@ + +# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) +checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3-8B + +# Directory in which to save checkpoints and logs. (type: , default: out/finetune/full) +out_dir: out/finetune/full-llama-3-8b + +# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) +precision: bf16-true + +# How many devices/GPUs to use (type: Union[int, str], default: 1) +devices: 4 + +# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume +# from the latest checkpoint in ``out_dir``. (type: Union[bool, Path], default: False) +resume: false + +# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. 
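+# The ``class_path``/``init_args`` layout below follows the jsonargparse convention
+# used by LitGPT configs, so other ``litgpt.data`` classes can be substituted here.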
+data: + class_path: litgpt.data.Alpaca2k + init_args: + mask_prompt: false + prompt_style: alpaca + ignore_index: -100 + seed: 42 + num_workers: 4 + +# Training-related arguments. See ``litgpt.args.TrainArgs`` for details +train: + + # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) + save_interval: 200 + + # Number of iterations between logging calls (type: int, default: 1) + log_interval: 1 + + # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64) + global_batch_size: 64 + + # Number of samples per data-parallel rank (type: int, default: 1) + micro_batch_size: 4 + + # Number of iterations with learning rate warmup active (type: int, default: 100) + lr_warmup_steps: 25 + + # Number of epochs to train on (type: Optional[int], default: 5) + epochs: 1 + + # Total number of tokens to train on (type: Optional[int], default: null) + max_tokens: + + # Limits the number of optimizer steps to run. (type: Optional[int], default: null) + max_steps: + + # Limits the length of samples. Off by default (type: Optional[int], default: null) + max_seq_length: 512 + + # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) + tie_embeddings: + + # (type: float, default: 0.003) + learning_rate: 0.0002 + + # (type: float, default: 0.02) + weight_decay: 0.1 + + # (type: float, default: 0.9) + beta1: 0.9 + + # (type: float, default: 0.95) + beta2: 0.95 + + # (type: Optional[float], default: null) + max_norm: + + # (type: float, default: 6e-05) + min_lr: 6.0e-05 + +# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details +eval: + + # Number of optimizer steps between evaluation calls (type: int, default: 600) + interval: 25 + + # Number of tokens to generate (type: Optional[int], default: 100) + max_new_tokens: 100 + + # Number of iterations (type: int, default: 100) + max_iters: 100 + +# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: csv) +logger_name: csv + +# The random seed to use for reproducibility. (type: int, default: 1337) +seed: 1337 diff --git a/config_hub/finetune/llama-3-8b/lora.yaml b/config_hub/finetune/llama-3-8b/lora.yaml new file mode 100644 index 0000000000..700a3b62f4 --- /dev/null +++ b/config_hub/finetune/llama-3-8b/lora.yaml @@ -0,0 +1,121 @@ + +# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) +checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3-8B + +# Directory in which to save checkpoints and logs. (type: , default: out/lora) +out_dir: out/finetune/lora-llama-3-8b + +# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) +precision: bf16-true + +# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) +quantize: + +# How many devices/GPUs to use. (type: Union[int, str], default: 1) +devices: 1 + +# The LoRA rank. (type: int, default: 8) +lora_r: 32 + +# The LoRA alpha. (type: int, default: 16) +lora_alpha: 16 + +# The LoRA dropout value. (type: float, default: 0.05) +lora_dropout: 0.05 + +# Whether to apply LoRA to the query weights in attention. (type: bool, default: True) +lora_query: true + +# Whether to apply LoRA to the key weights in attention. 
(type: bool, default: False) +lora_key: false + +# Whether to apply LoRA to the value weights in attention. (type: bool, default: True) +lora_value: true + +# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) +lora_projection: false + +# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) +lora_mlp: false + +# Whether to apply LoRA to output head in GPT. (type: bool, default: False) +lora_head: false + +# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. +data: + class_path: litgpt.data.Alpaca2k + init_args: + mask_prompt: false + prompt_style: alpaca + ignore_index: -100 + seed: 42 + num_workers: 4 + +# Training-related arguments. See ``litgpt.args.TrainArgs`` for details +train: + + # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) + save_interval: 200 + + # Number of iterations between logging calls (type: int, default: 1) + log_interval: 1 + + # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) + global_batch_size: 8 + + # Number of samples per data-parallel rank (type: int, default: 4) + micro_batch_size: 1 + + # Number of iterations with learning rate warmup active (type: int, default: 100) + lr_warmup_steps: 10 + + # Number of epochs to train on (type: Optional[int], default: 5) + epochs: 2 + + # Total number of tokens to train on (type: Optional[int], default: null) + max_tokens: + + # Limits the number of optimizer steps to run. (type: Optional[int], default: null) + max_steps: + + # Limits the length of samples. Off by default (type: Optional[int], default: null) + max_seq_length: 512 + + # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) + tie_embeddings: + + # (type: float, default: 0.0003) + learning_rate: 0.0002 + + # (type: float, default: 0.02) + weight_decay: 0.0 + + # (type: float, default: 0.9) + beta1: 0.9 + + # (type: float, default: 0.95) + beta2: 0.95 + + # (type: Optional[float], default: null) + max_norm: + + # (type: float, default: 6e-05) + min_lr: 6.0e-05 + +# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details +eval: + + # Number of optimizer steps between evaluation calls (type: int, default: 100) + interval: 100 + + # Number of tokens to generate (type: Optional[int], default: 100) + max_new_tokens: 100 + + # Number of iterations (type: int, default: 100) + max_iters: 100 + +# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: csv) +logger_name: csv + +# The random seed to use for reproducibility. (type: int, default: 1337) +seed: 1337 diff --git a/config_hub/finetune/llama-3-8b/qlora.yaml b/config_hub/finetune/llama-3-8b/qlora.yaml new file mode 100644 index 0000000000..1da95eaac5 --- /dev/null +++ b/config_hub/finetune/llama-3-8b/qlora.yaml @@ -0,0 +1,123 @@ + +# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) +checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3-8B + +# Directory in which to save checkpoints and logs. (type: , default: out/lora) +out_dir: out/finetune/qlora-llama3-8b + +# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) +precision: bf16-true + +# If set, quantize the model with this algorithm. 
See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) +quantize: bnb.nf4 + +# How many devices/GPUs to use. (type: Union[int, str], default: 1) +devices: 1 + +# The LoRA rank. (type: int, default: 8) +lora_r: 32 + +# The LoRA alpha. (type: int, default: 16) +lora_alpha: 16 + +# The LoRA dropout value. (type: float, default: 0.05) +lora_dropout: 0.05 + +# Whether to apply LoRA to the query weights in attention. (type: bool, default: True) +lora_query: true + +# Whether to apply LoRA to the key weights in attention. (type: bool, default: False) +lora_key: false + +# Whether to apply LoRA to the value weights in attention. (type: bool, default: True) +lora_value: true + +# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) +lora_projection: false + +# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) +lora_mlp: false + +# Whether to apply LoRA to output head in GPT. (type: bool, default: False) +lora_head: false + +# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. +data: + class_path: litgpt.data.Alpaca2k + init_args: + mask_prompt: false + val_split_fraction: 0.05 + prompt_style: alpaca + ignore_index: -100 + seed: 42 + num_workers: 4 + download_dir: data/alpaca2k + +# Training-related arguments. See ``litgpt.args.TrainArgs`` for details +train: + + # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) + save_interval: 200 + + # Number of iterations between logging calls (type: int, default: 1) + log_interval: 1 + + # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) + global_batch_size: 8 + + # Number of samples per data-parallel rank (type: int, default: 4) + micro_batch_size: 2 + + # Number of iterations with learning rate warmup active (type: int, default: 100) + lr_warmup_steps: 10 + + # Number of epochs to train on (type: Optional[int], default: 5) + epochs: 2 + + # Total number of tokens to train on (type: Optional[int], default: null) + max_tokens: + + # Limits the number of optimizer steps to run (type: Optional[int], default: null) + max_steps: + + # Limits the length of samples (type: Optional[int], default: null) + max_seq_length: 512 + + # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null) + tie_embeddings: + + # (type: float, default: 0.0003) + learning_rate: 0.0002 + + # (type: float, default: 0.02) + weight_decay: 0.0 + + # (type: float, default: 0.9) + beta1: 0.9 + + # (type: float, default: 0.95) + beta2: 0.95 + + # (type: Optional[float], default: null) + max_norm: + + # (type: float, default: 6e-05) + min_lr: 6.0e-05 + +# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details +eval: + + # Number of optimizer steps between evaluation calls (type: int, default: 100) + interval: 100 + + # Number of tokens to generate (type: Optional[int], default: 100) + max_new_tokens: 100 + + # Number of iterations (type: int, default: 100) + max_iters: 100 + +# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: csv) +logger_name: csv + +# The random seed to use for reproducibility. 
(type: int, default: 1337) +seed: 1337 diff --git a/litgpt/__main__.py b/litgpt/__main__.py index 59d53ac904..821c1f5801 100644 --- a/litgpt/__main__.py +++ b/litgpt/__main__.py @@ -1,4 +1,5 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +import sys from typing import TYPE_CHECKING, Any @@ -24,6 +25,8 @@ from litgpt.scripts.download import download_from_hub as download_fn from litgpt.scripts.merge_lora import merge_lora as merge_lora_fn from litgpt.eval.evaluate import convert_and_evaluate as evaluate_fn +from litgpt.deploy.serve import run_server as serve_fn + if TYPE_CHECKING: from jsonargparse import ArgumentParser @@ -39,6 +42,12 @@ def _new_parser(**kwargs: Any) -> "ArgumentParser": return parser +def _rewrite_argv_for_default_subcommand(parser_data: dict, command: str, subcommand: str) -> None: + """Rewrites the `sys.argv` such that `litgpt command` defaults to `litgpt command subcommand`.""" + if len(sys.argv) > 2 and sys.argv[1] == command and sys.argv[2] not in parser_data[command].keys(): + sys.argv.insert(2, subcommand) + + def main() -> None: parser_data = { "download": {"help": "Download weights or tokenizer data from the Hugging Face Hub.", "fn": download_fn}, @@ -80,6 +89,7 @@ def main() -> None: }, "merge_lora": {"help": "Merges the LoRA weights with the base model.", "fn": merge_lora_fn}, "evaluate": {"help": "Evaluate a model with the LM Evaluation Harness.", "fn": evaluate_fn}, + "serve": {"help": "Serve and deploy a model with LitServe.", "fn": serve_fn}, } from jsonargparse import set_config_read_mode, set_docstring_parse_options @@ -87,6 +97,8 @@ def main() -> None: set_docstring_parse_options(attribute_docstrings=True) set_config_read_mode(urls_enabled=True) + _rewrite_argv_for_default_subcommand(parser_data, "finetune", "lora") + root_parser = _new_parser(prog="litgpt") # register level 1 subcommands and level 2 subsubcommands. 
If there are more levels in the future we would want to diff --git a/litgpt/config.py b/litgpt/config.py index 0a4234222d..e03fa8ae34 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -836,6 +836,56 @@ def norm_class(self) -> Type: copy["name"] = c["name"].format(kind) copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) configs.append(copy) + + +############### +# Meta LLaMA 3 +############### +llama_3 = [ + # https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json + dict( + name="Llama-3-8B{}", + hf_config=dict(org="meta-llama", name="Meta-Llama-3-8B{}"), + block_size=8192, + vocab_size=128000, + padded_vocab_size=128256, + n_layer=32, + n_head=32, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=14336, + rope_base=500000, + ), + # https://huggingface.co/meta-llama/Meta-Llama-3-70B/blob/main/config.json + dict( + name="Llama-3-70B{}", + hf_config=dict(org="meta-llama", name="Meta-Llama-3-70B{}"), + block_size=8192, + vocab_size=128000, + padded_vocab_size=128256, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=28672, + rope_base=500000, + ), +] +for c in llama_3: + for kind in ("", "-Instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) ############### diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py new file mode 100644 index 0000000000..9df48ad98d --- /dev/null +++ b/litgpt/deploy/serve.py @@ -0,0 +1,137 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +from pathlib import Path +from typing import Dict, Any, Optional, Literal +from litgpt.utils import check_valid_checkpoint_dir + +import lightning as L +import torch +from litserve import LitAPI, LitServer + +from litgpt.model import GPT +from litgpt.config import Config +from litgpt.tokenizer import Tokenizer +from litgpt.generate.base import generate +from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle +from litgpt.utils import load_checkpoint, CLI, get_default_supported_precision + + +class SimpleLitAPI(LitAPI): + def __init__(self, + checkpoint_dir: Path, + precision: Optional[str] = None, + temperature: float = 0.8, + top_k: int = 50, + max_new_tokens: int = 50) -> None: + + super().__init__() + self.checkpoint_dir = checkpoint_dir + self.precision = precision + self.temperature = temperature + self.top_k = top_k + self.max_new_tokens = max_new_tokens + + def setup(self, device: str) -> None: + # Setup the model so it can be called in `predict`. 
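+ # The steps below: read the saved model config, pick a precision, create a
+ # single-device Fabric, load the tokenizer and prompt style, build the GPT
+ # module with a batch-size-1 KV cache for generation, and finally load the
+ # checkpoint weights and record the target device.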
+ config = Config.from_file(self.checkpoint_dir / "model_config.yaml") + device = torch.device(device) + torch.set_float32_matmul_precision("high") + + precision = self.precision or get_default_supported_precision(training=False) + + fabric = L.Fabric( + accelerator=device.type, + devices=1 if device.type=="cpu" else [device.index], # TODO: Update once LitServe supports "auto" + precision=precision, + ) + checkpoint_path = self.checkpoint_dir / "lit_model.pth" + self.tokenizer = Tokenizer(self.checkpoint_dir) + self.prompt_style = ( + load_prompt_style(self.checkpoint_dir) + if has_prompt_style(self.checkpoint_dir) + else PromptStyle.from_config(config) + ) + with fabric.init_module(empty_init=True): + model = GPT(config) + with fabric.init_tensor(): + # enable the kv cache + model.set_kv_cache(batch_size=1) + model.eval() + + self.model = fabric.setup_module(model) + load_checkpoint(fabric, self.model, checkpoint_path) + self.device = fabric.device + + def decode_request(self, request: Dict[str, Any]) -> Any: + # Convert the request payload to your model input. + prompt = request["prompt"] + prompt = self.prompt_style.apply(prompt) + encoded = self.tokenizer.encode(prompt, device=self.device) + return encoded + + def predict(self, inputs: torch.Tensor) -> Any: + # Run the model on the input and return the output. + prompt_length = inputs.size(0) + max_returned_tokens = prompt_length + self.max_new_tokens + + y = generate( + self.model, + inputs, + max_returned_tokens, + temperature=self.temperature, + top_k=self.top_k, + eos_id=self.tokenizer.eos_id + ) + + for block in self.model.transformer.h: + block.attn.kv_cache.reset_parameters() + return y + + def encode_response(self, output: torch.Tensor) -> Dict[str, Any]: + # Convert the model output to a response payload. + decoded_output = self.tokenizer.decode(output) + return {"output": decoded_output} + + +def run_server( + checkpoint_dir: Path = Path("checkpoints"), + precision: Optional[str] = None, + temperature: float = 0.8, + top_k: int = 200, + max_new_tokens: int = 50, + devices: int = 1, + accelerator: str = "cuda", + port: int = 8000 +) -> None: + """Serve a LitGPT model using LitServe + + Arguments: + checkpoint_dir: The checkpoint directory to load the model from. + precision: Optional precision setting to instantiate the model weights in. By default, this will + automatically be inferred from the metadata in the given ``checkpoint_dir`` directory. + temperature: Temperature setting for the text generation. Value above 1 increase randomness. + Values below 1 decrease randomness. + top_k: The size of the pool of potential next tokens. Values larger than 1 result in more novel + generated text but can also lead to more incoherent texts. + max_new_tokens: The number of generation steps to take. + devices: How many devices/GPUs to use. + accelerator: The type of accelerator to use. For example, "cuda" or "cpu". + port: The network port number on which the model is configured to be served. 
+ """ + check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth") + + server = LitServer( + SimpleLitAPI( + checkpoint_dir=checkpoint_dir, + precision=precision, + temperature=temperature, + top_k=top_k, + max_new_tokens=max_new_tokens, + ), + accelerator=accelerator, + devices=devices) + + server.run(port=port) + + +if __name__ == "__main__": + CLI(run_server) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index d827413913..04a0551cd1 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -200,6 +200,24 @@ def apply(self, prompt: str, **kwargs: str) -> str: ) +class Llama3(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + # https://github.com/meta-llama/llama3/blob/359887376f0aaf30e433f23e25df858d8c2a9833/llama/tokenizer.py#L202-L229 + return ( + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant.<|eot_id|>\n" # The system prompt is optional + "<|start_header_id|>user<|end_header_id|>\n\n" + f"{prompt}<|eot_id|>\n" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + + def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: + return ( + [tokenizer.eos_id], + [tokenizer.token_to_id("<|eot_id|>")], + ) + + class FreeWilly2(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: return ( @@ -316,6 +334,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return Llama2FunctionCalling() if re.search("Llama-2.*-chat", model_name): return Llama2() + if re.search("Llama-3.*-Instruct", model_name): + return Llama3() if re.search("FreeWilly2", model_name): return FreeWilly2() if re.search("Platypus", model_name): diff --git a/pyproject.toml b/pyproject.toml index 3b22a124ea..d2c26fc33e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "torch>=2.2.0", "lightning==2.3.0.dev20240328", "jsonargparse[signatures]>=4.27.6", + "litserve==0.0.0.dev2", # imported by litgpt.deploy ] [project.urls] diff --git a/tests/test_cli.py b/tests/test_cli.py index 2c994fcf96..f95841ddc0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -15,7 +15,7 @@ def test_cli(): main() out = out.getvalue() assert "usage: litgpt" in out - assert "{download,chat,finetune,pretrain,generate,convert,merge_lora,evaluate}" in out + assert "{download,chat,finetune,pretrain,generate,convert,merge_lora,evaluate,serve}" in out assert ( """Available subcommands: download Download weights or tokenizer data from the Hugging @@ -23,19 +23,8 @@ def test_cli(): chat Chat with a model.""" in out ) - assert ("""evaluate Evaluate a model with the LM Evaluation Harness.""") in out - - out = StringIO() - with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "finetune", "-h"]): - main() - out = out.getvalue() - assert ( - """Available subcommands: - lora Finetune a model with LoRA. 
- full Finetune a model.""" - in out - ) - + assert """evaluate Evaluate a model with the LM Evaluation Harness.""" in out + assert """serve Serve and deploy a model with LitServe.""" in out out = StringIO() with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "finetune", "lora", "-h"]): main() @@ -61,3 +50,13 @@ def test_cli(): Optional[int], default: 3000000000000)""" in out ) + + +def test_rewrite_finetune_command(): + out1 = StringIO() + with pytest.raises(SystemExit), redirect_stdout(out1), mock.patch("sys.argv", ["litgpt", "fineune", "-h"]): + main() + out2 = StringIO() + with pytest.raises(SystemExit), redirect_stdout(out2), mock.patch("sys.argv", ["litgpt", "fineune", "lora", "-h"]): + main() + assert out1.getvalue() == out2.getvalue() diff --git a/tests/test_model.py b/tests/test_model.py index 7bc0ccb5b4..0537098342 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -206,7 +206,13 @@ def test_against_original_open_llama_3b(device, dtype): @torch.inference_mode() @pytest.mark.parametrize( "ours_kwargs", - [{"name": "Llama-2-7b-hf"}, {"name": "CodeLlama-7b-hf"}, {"name": "Llama-2-70b-chat-hf", "n_query_groups": 1}], + [ + {"name": "Llama-2-7b-hf"}, + {"name": "CodeLlama-7b-hf"}, + {"name": "Llama-2-70b-chat-hf", "n_query_groups": 1}, + {"name": "Llama-3-8B"}, + {"name": "Llama-3-8B-Instruct"} + ], ) @pytest.mark.parametrize( ("device", "dtype"), @@ -224,7 +230,7 @@ def test_against_original_open_llama_3b(device, dtype): ), ], ) -def test_against_hf_llama2(ours_kwargs, device, dtype): +def test_against_hf_llama_2_and_3(ours_kwargs, device, dtype): torch.set_default_dtype(dtype) ours_config = Config.from_name( diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 3250ce4801..20f2c84e0c 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -50,6 +50,8 @@ def test_prompt_style_from_config(): "Llama-2-7b-chat-hf", "Llama-2-13b-chat-hf", "Llama-2-70b-chat-hf", + "Llama-3-8B-Instruct", + "Llama-3-70B-Instruct", "Gemma-2b-it", "Gemma-7b-it", "FreeWilly2", diff --git a/tests/test_serve.py b/tests/test_serve.py new file mode 100644 index 0000000000..46a109c807 --- /dev/null +++ b/tests/test_serve.py @@ -0,0 +1,42 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
+from dataclasses import asdict +import shutil + +from lightning.fabric import seed_everything +from fastapi.testclient import TestClient +from litserve.server import LitServer +import torch +import yaml + + +from litgpt import GPT, Config +from litgpt.deploy.serve import SimpleLitAPI +from litgpt.scripts.download import download_from_hub + + +def test_simple(tmp_path): + + # Create model checkpoint + seed_everything(123) + ours_config = Config.from_name("pythia-14m") + download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path) + shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path)) + shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path)) + ours_model = GPT(ours_config) + checkpoint_path = tmp_path / "lit_model.pth" + torch.save(ours_model.state_dict(), checkpoint_path) + config_path = tmp_path / "model_config.yaml" + with open(config_path, "w", encoding="utf-8") as fp: + yaml.dump(asdict(ours_config), fp) + + accelerator = "cpu" + server = LitServer( + SimpleLitAPI(checkpoint_dir=tmp_path, temperature=1, top_k=1), + accelerator=accelerator, devices=1, timeout=60 + ) + + with TestClient(server.app) as client: + response = client.post("/predict", json={"prompt": "Hello world"}) + # Model is a small random model, not trained, hence the gibberish. + # We are just testing that the server works. + assert response.json()["output"][:19] == "Hello world statues" diff --git a/tutorials/0_to_litgpt.md b/tutorials/0_to_litgpt.md index 337bf37049..e5e1c7c765 100644 --- a/tutorials/0_to_litgpt.md +++ b/tutorials/0_to_litgpt.md @@ -464,6 +464,44 @@ litgpt evaluate \ (A list of supported tasks can be found [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md).) +  +## Deploy LLMs + +You can deploy LitGPT LLMs using your tool of choice. Below is an example using LitGPT built-in serving capabilities: + + +```bash +# 1) Download a pretrained model (alternatively, use your own finetuned model) +litgpt download --repo_id microsoft/phi-2 + +# 2) Start the server +litgpt serve --checkpoint_dir checkpoints/microsoft/phi-2 +``` + +```python +# 3) Use the server (in a separate session) +import requests, json + response = requests.post( + "http://127.0.0.1:8000/predict", + json={"prompt": "Fix typos in the following sentence: Exampel input"} +) +print(response.json()["output"]) +``` + +This prints: + +``` +Instruct: Fix typos in the following sentence: Exampel input +Output: Example input. +``` + + +  +**More information and additional resources** + +- [tutorials/deploy](deploy.md): A full deployment tutorial and example + +   ## Converting LitGPT model weights to `safetensors` format diff --git a/tutorials/deploy.md b/tutorials/deploy.md new file mode 100644 index 0000000000..1b1495fde7 --- /dev/null +++ b/tutorials/deploy.md @@ -0,0 +1,49 @@ +# Serve and Deploy LLMs + +This document shows how you can serve a LitGPT for deployment. + +  +## Serve an LLM + +This section illustrates how we can set up an inference server for a phi-2 LLM using `litgpt serve` that is minimal and highly scalable. 
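
Under the hood, `litgpt serve` dispatches to the `run_server` helper in `litgpt/deploy/serve.py`. If you prefer to launch the server from Python rather than through the CLI, a minimal sketch looks like this (the phi-2 checkpoint path is illustrative; point it at any downloaded or finetuned checkpoint):

```python
from pathlib import Path
from litgpt.deploy.serve import run_server

# Starts a LitServe inference server on port 8000 for the given checkpoint.
run_server(
    checkpoint_dir=Path("checkpoints/microsoft/phi-2"),
    temperature=0.8,      # values above 1 increase randomness, below 1 reduce it
    top_k=200,            # size of the pool of candidate next tokens
    max_new_tokens=50,    # number of generation steps per request
    accelerator="cuda",   # use "cpu" if no GPU is available
    port=8000,
)
```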
+ + +  +## Step 1: Start the inference server + + +```bash +# 1) Download a pretrained model (alternatively, use your own finetuned model) +litgpt download --repo_id microsoft/phi-2 + +# 2) Start the server +litgpt serve --checkpoint_dir checkpoints/microsoft/phi-2 +``` + +> [!TIP] +> Use `litgpt serve --help` to display additional options, including the port, devices, LLM temperature setting, and more. + + +  +## Step 2: Query the inference server + +You can now send requests to the inference server you started in step 2. For example, in a new Python session, we can send requests to the inference server as follows: + + +```python +import requests, json + +response = requests.post( + "http://127.0.0.1:8000/predict", + json={"prompt": "Fix typos in the following sentence: Exampel input"} +) + +print(response.json()["output"]) +``` + +Executing the code above prints the following output: + +``` +Instruct: Fix typos in the following sentence: Exampel input +Output: Example input. +``` diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index d1c320ac33..45c9c7d50c 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -3,29 +3,30 @@ LitGPT supports a variety of LLM architectures with publicly available weights. You can download model weights and access a list of supported models using the LitGPT `download.py` script. -| Model | Model size | Reference | -|----------------------------------------------|------------------------------------------|------------------------------------------------------------------------------------------------------------------------------| -| CodeGemma by Google | 7B | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | -| Code Llama by Meta AI | 7B, 13B, 34B, 70B | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | -| Dolly by Databricks | 3B, 7B, 12B | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) | -| Falcon by TII UAE | 7B, 40B, 180B | [TII 2023](https://falconllm.tii.ae) | -| FreeWilly2 (Stable Beluga 2) by Stability AI | 70B | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | -| Function Calling Llama 2 by Trelis | 7B | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | -| Gemma by Google | 2B, 7B | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | -| Llama 2 by Meta AI | 7B, 13B, 70B | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | -| LongChat by LMSYS | 7B, 13B | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | -| Mistral and Mixtral by Mistral AI | 7B | [Mistral website](https://mistral.ai/) | -| Nous-Hermes by NousResearch | 7B, 13B, 70B | [Org page](https://huggingface.co/NousResearch) | -| OpenLLaMA by OpenLM Research | 3B, 7B, 13B | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) | -| Phi by Microsoft Research | 1.3B, 2.7B | [Li et al. 2023](https://arxiv.org/abs/2309.05463) | -| Platypus by Lee at el. | 7B, 13B, 70B | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) | -| Pythia by EleutherAI | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | [Biderman et al. 
2023](https://arxiv.org/abs/2304.01373) | -| RedPajama-INCITE by Together | 3B, 7B | [Together 2023](https://together.ai/blog/redpajama-models-v1) | -| StableCode by Stability AI | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | -| StableLM by Stability AI | 3B, 7B | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | -| StableLM Zephyr by Stability AI | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | -| TinyLlama by Zhang et al. | 1.1B | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) | -| Vicuna by LMSYS | 7B, 13B, 33B | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) | +| Model | Model size | Reference | +|----------------------------------------------|-----------------------------------------|--------------------------------------------------------------------------------------------------------------------------| +| CodeGemma by Google | 7B | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | +| Code Llama by Meta AI | 7B, 13B, 34B, 70B | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | +| Dolly by Databricks | 3B, 7B, 12B | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) | +| Falcon by TII UAE | 7B, 40B, 180B | [TII 2023](https://falconllm.tii.ae) | +| FreeWilly2 (Stable Beluga 2) by Stability AI | 70B | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | +| Function Calling Llama 2 by Trelis | 7B | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | +| Gemma by Google | 2B, 7B | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | +| Llama 2 by Meta AI | 7B, 13B, 70B | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | +| Llama 3 by Meta AI | 8B, 70B | [Meta AI 2024](https://github.com/meta-llama/llama3) | +| LongChat by LMSYS | 7B, 13B | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | +| Mistral and Mixtral by Mistral AI | 7B | [Mistral website](https://mistral.ai/) | +| Nous-Hermes by NousResearch | 7B, 13B, 70B | [Org page](https://huggingface.co/NousResearch) | +| OpenLLaMA by OpenLM Research | 3B, 7B, 13B | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) | +| Phi by Microsoft Research | 1.3B, 2.7B | [Li et al. 2023](https://arxiv.org/abs/2309.05463) | +| Platypus by Lee at el. | 7B, 13B, 70B | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) | +| Pythia by EleutherAI | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) | +| RedPajama-INCITE by Together | 3B, 7B | [Together 2023](https://together.ai/blog/redpajama-models-v1) | +| StableCode by Stability AI | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | +| StableLM by Stability AI | 3B, 7B | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | +| StableLM Zephyr by Stability AI | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | +| TinyLlama by Zhang et al. | 1.1B | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) | +| Vicuna by LMSYS | 7B, 13B, 33B | [Li et al. 
2023](https://lmsys.org/blog/2023-03-30-vicuna/) | @@ -105,6 +106,10 @@ meta-llama/Llama-2-70b-chat-hf meta-llama/Llama-2-70b-hf meta-llama/Llama-2-7b-chat-hf meta-llama/Llama-2-7b-hf +meta-llama/Meta-Llama-3-70B +meta-llama/Meta-Llama-3-70B-Instruct +meta-llama/Meta-Llama-3-8B +meta-llama/Meta-Llama-3-8B-Instruct microsoft/phi-1_5 microsoft/phi-2 mistralai/Mistral-7B-Instruct-v0.1
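
The newly listed `meta-llama/Meta-Llama-3-*` checkpoints can be fetched with `litgpt download --repo_id meta-llama/Meta-Llama-3-8B-Instruct`, or programmatically via the same helper the CLI uses. A minimal sketch is shown below; the Meta weights are gated, so this assumes you have accepted the license on Hugging Face and authenticated (for example with `huggingface-cli login`):

```python
from pathlib import Path
from litgpt.scripts.download import download_from_hub

# Fetches the Llama 3 8B Instruct weights into checkpoints/meta-llama/Meta-Llama-3-8B-Instruct
download_from_hub(
    repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
    checkpoint_dir=Path("checkpoints"),
)
```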