diff --git a/README.md b/README.md
index c91e434f02..ddf733291e 100644
--- a/README.md
+++ b/README.md
@@ -220,13 +220,14 @@ litgpt chat \
 ### Continue pretraining an LLM
 
 This is another way of finetuning that specialize an already pretrained model by training on custom data:
 
+
 Open In Studio
 
 &nbsp;
 
-```
+```bash
 mkdir -p custom_texts
 curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output custom_texts/book1.txt
 curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output custom_texts/book2.txt
@@ -251,6 +252,30 @@ litgpt chat \
 
 &nbsp;
 
+### Deploy an LLM
+
+This example illustrates how to deploy an LLM using LitGPT:
+
+```bash
+# 1) Download a pretrained model (alternatively, use your own finetuned model)
+litgpt download --repo_id microsoft/phi-2
+
+# 2) Start the server
+litgpt serve --checkpoint_dir checkpoints/microsoft/phi-2
+```
+
+```python
+# 3) Use the server (in a separate session)
+import requests, json
+response = requests.post(
+    "http://127.0.0.1:8000/predict",
+    json={"prompt": "Fix typos in the following sentence: Exampel input"}
+)
+print(response.json()["output"])
+```
+
+&nbsp;
+
 > [!NOTE]
 > **[Read the full docs](tutorials/0_to_litgpt.md)**.
diff --git a/litgpt/__main__.py b/litgpt/__main__.py
index 59d53ac904..e88c6212c8 100644
--- a/litgpt/__main__.py
+++ b/litgpt/__main__.py
@@ -24,6 +24,8 @@
 from litgpt.scripts.download import download_from_hub as download_fn
 from litgpt.scripts.merge_lora import merge_lora as merge_lora_fn
 from litgpt.eval.evaluate import convert_and_evaluate as evaluate_fn
+from litgpt.deploy.serve import run_server as serve_fn
+
 
 if TYPE_CHECKING:
     from jsonargparse import ArgumentParser
@@ -80,6 +82,7 @@ def main() -> None:
         },
         "merge_lora": {"help": "Merges the LoRA weights with the base model.", "fn": merge_lora_fn},
         "evaluate": {"help": "Evaluate a model with the LM Evaluation Harness.", "fn": evaluate_fn},
+        "serve": {"help": "Serve and deploy a model with LitServe.", "fn": serve_fn},
     }
 
     from jsonargparse import set_config_read_mode, set_docstring_parse_options
diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py
new file mode 100644
index 0000000000..9df48ad98d
--- /dev/null
+++ b/litgpt/deploy/serve.py
@@ -0,0 +1,137 @@
+# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+from pathlib import Path
+from typing import Dict, Any, Optional, Literal
+from litgpt.utils import check_valid_checkpoint_dir
+
+import lightning as L
+import torch
+from litserve import LitAPI, LitServer
+
+from litgpt.model import GPT
+from litgpt.config import Config
+from litgpt.tokenizer import Tokenizer
+from litgpt.generate.base import generate
+from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
+from litgpt.utils import load_checkpoint, CLI, get_default_supported_precision
+
+
+class SimpleLitAPI(LitAPI):
+    def __init__(self,
+                 checkpoint_dir: Path,
+                 precision: Optional[str] = None,
+                 temperature: float = 0.8,
+                 top_k: int = 50,
+                 max_new_tokens: int = 50) -> None:
+
+        super().__init__()
+        self.checkpoint_dir = checkpoint_dir
+        self.precision = precision
+        self.temperature = temperature
+        self.top_k = top_k
+        self.max_new_tokens = max_new_tokens
+
+    def setup(self, device: str) -> None:
+        # Set up the model so it can be called in `predict`.
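+        # The checkpoint directory provides the model config, tokenizer, and
+        # prompt style; the weights are then loaded onto the target device via Fabric.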
+        config = Config.from_file(self.checkpoint_dir / "model_config.yaml")
+        device = torch.device(device)
+        torch.set_float32_matmul_precision("high")
+
+        precision = self.precision or get_default_supported_precision(training=False)
+
+        fabric = L.Fabric(
+            accelerator=device.type,
+            devices=1 if device.type == "cpu" else [device.index],  # TODO: Update once LitServe supports "auto"
+            precision=precision,
+        )
+        checkpoint_path = self.checkpoint_dir / "lit_model.pth"
+        self.tokenizer = Tokenizer(self.checkpoint_dir)
+        self.prompt_style = (
+            load_prompt_style(self.checkpoint_dir)
+            if has_prompt_style(self.checkpoint_dir)
+            else PromptStyle.from_config(config)
+        )
+        with fabric.init_module(empty_init=True):
+            model = GPT(config)
+        with fabric.init_tensor():
+            # enable the kv cache
+            model.set_kv_cache(batch_size=1)
+        model.eval()
+
+        self.model = fabric.setup_module(model)
+        load_checkpoint(fabric, self.model, checkpoint_path)
+        self.device = fabric.device
+
+    def decode_request(self, request: Dict[str, Any]) -> Any:
+        # Convert the request payload to your model input.
+        prompt = request["prompt"]
+        prompt = self.prompt_style.apply(prompt)
+        encoded = self.tokenizer.encode(prompt, device=self.device)
+        return encoded
+
+    def predict(self, inputs: torch.Tensor) -> Any:
+        # Run the model on the input and return the output.
+        prompt_length = inputs.size(0)
+        max_returned_tokens = prompt_length + self.max_new_tokens
+
+        y = generate(
+            self.model,
+            inputs,
+            max_returned_tokens,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            eos_id=self.tokenizer.eos_id
+        )
+
+        for block in self.model.transformer.h:
+            block.attn.kv_cache.reset_parameters()
+        return y
+
+    def encode_response(self, output: torch.Tensor) -> Dict[str, Any]:
+        # Convert the model output to a response payload.
+        decoded_output = self.tokenizer.decode(output)
+        return {"output": decoded_output}
+
+
+def run_server(
+    checkpoint_dir: Path = Path("checkpoints"),
+    precision: Optional[str] = None,
+    temperature: float = 0.8,
+    top_k: int = 200,
+    max_new_tokens: int = 50,
+    devices: int = 1,
+    accelerator: str = "cuda",
+    port: int = 8000
+) -> None:
+    """Serve a LitGPT model using LitServe.
+
+    Arguments:
+        checkpoint_dir: The checkpoint directory to load the model from.
+        precision: Optional precision setting to instantiate the model weights in. By default, this will
+            automatically be inferred from the metadata in the given ``checkpoint_dir`` directory.
+        temperature: Temperature setting for the text generation. Values above 1 increase randomness.
+            Values below 1 decrease randomness.
+        top_k: The size of the pool of potential next tokens. Values larger than 1 result in more novel
+            generated text but can also lead to more incoherent texts.
+        max_new_tokens: The number of generation steps to take.
+        devices: How many devices/GPUs to use.
+        accelerator: The type of accelerator to use. For example, "cuda" or "cpu".
+        port: The network port number on which the model is configured to be served.
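+
+    Example:
+        litgpt serve --checkpoint_dir checkpoints/microsoft/phi-2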
+ """ + check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth") + + server = LitServer( + SimpleLitAPI( + checkpoint_dir=checkpoint_dir, + precision=precision, + temperature=temperature, + top_k=top_k, + max_new_tokens=max_new_tokens, + ), + accelerator=accelerator, + devices=devices) + + server.run(port=port) + + +if __name__ == "__main__": + CLI(run_server) diff --git a/pyproject.toml b/pyproject.toml index 1d3c89cfd9..b6fbec18b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "torch>=2.2.0", "lightning==2.3.0.dev20240328", "jsonargparse[signatures]>=4.27.6", + "litserve==0.0.0.dev2", # imported by litgpt.deploy ] [project.urls] diff --git a/tests/test_cli.py b/tests/test_cli.py index 2c994fcf96..49a10a07ab 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -15,7 +15,7 @@ def test_cli(): main() out = out.getvalue() assert "usage: litgpt" in out - assert "{download,chat,finetune,pretrain,generate,convert,merge_lora,evaluate}" in out + assert "{download,chat,finetune,pretrain,generate,convert,merge_lora,evaluate,serve}" in out assert ( """Available subcommands: download Download weights or tokenizer data from the Hugging @@ -24,7 +24,7 @@ def test_cli(): in out ) assert ("""evaluate Evaluate a model with the LM Evaluation Harness.""") in out - + assert ("""serve Serve and deploy a model with LitServe.""") in out out = StringIO() with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "finetune", "-h"]): main() diff --git a/tests/test_serve.py b/tests/test_serve.py new file mode 100644 index 0000000000..46a109c807 --- /dev/null +++ b/tests/test_serve.py @@ -0,0 +1,42 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +from dataclasses import asdict +import shutil + +from lightning.fabric import seed_everything +from fastapi.testclient import TestClient +from litserve.server import LitServer +import torch +import yaml + + +from litgpt import GPT, Config +from litgpt.deploy.serve import SimpleLitAPI +from litgpt.scripts.download import download_from_hub + + +def test_simple(tmp_path): + + # Create model checkpoint + seed_everything(123) + ours_config = Config.from_name("pythia-14m") + download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path) + shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path)) + shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path)) + ours_model = GPT(ours_config) + checkpoint_path = tmp_path / "lit_model.pth" + torch.save(ours_model.state_dict(), checkpoint_path) + config_path = tmp_path / "model_config.yaml" + with open(config_path, "w", encoding="utf-8") as fp: + yaml.dump(asdict(ours_config), fp) + + accelerator = "cpu" + server = LitServer( + SimpleLitAPI(checkpoint_dir=tmp_path, temperature=1, top_k=1), + accelerator=accelerator, devices=1, timeout=60 + ) + + with TestClient(server.app) as client: + response = client.post("/predict", json={"prompt": "Hello world"}) + # Model is a small random model, not trained, hence the gibberish. + # We are just testing that the server works. 
+        assert response.json()["output"][:19] == "Hello world statues"
diff --git a/tutorials/0_to_litgpt.md b/tutorials/0_to_litgpt.md
index 337bf37049..e5e1c7c765 100644
--- a/tutorials/0_to_litgpt.md
+++ b/tutorials/0_to_litgpt.md
@@ -464,6 +464,44 @@ litgpt evaluate \
 (A list of supported tasks can be found [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md).)
 
+&nbsp;
+## Deploy LLMs
+
+You can deploy LitGPT LLMs using your tool of choice. Below is an example using LitGPT's built-in serving capabilities:
+
+
+```bash
+# 1) Download a pretrained model (alternatively, use your own finetuned model)
+litgpt download --repo_id microsoft/phi-2
+
+# 2) Start the server
+litgpt serve --checkpoint_dir checkpoints/microsoft/phi-2
+```
+
+```python
+# 3) Use the server (in a separate session)
+import requests, json
+response = requests.post(
+    "http://127.0.0.1:8000/predict",
+    json={"prompt": "Fix typos in the following sentence: Exampel input"}
+)
+print(response.json()["output"])
+```
+
+This prints:
+
+```
+Instruct: Fix typos in the following sentence: Exampel input
+Output: Example input.
+```
+
+
+&nbsp;
+**More information and additional resources**
+
+- [tutorials/deploy](deploy.md): A full deployment tutorial and example
+
+
 &nbsp;
 
 ## Converting LitGPT model weights to `safetensors` format
diff --git a/tutorials/deploy.md b/tutorials/deploy.md
new file mode 100644
index 0000000000..1b1495fde7
--- /dev/null
+++ b/tutorials/deploy.md
@@ -0,0 +1,49 @@
+# Serve and Deploy LLMs
+
+This document shows how you can serve a LitGPT model for deployment.
+
+&nbsp;
+## Serve an LLM
+
+This section illustrates how to set up a minimal, highly scalable inference server for a phi-2 LLM using `litgpt serve`.
+
+
+&nbsp;
+## Step 1: Start the inference server
+
+
+```bash
+# 1) Download a pretrained model (alternatively, use your own finetuned model)
+litgpt download --repo_id microsoft/phi-2
+
+# 2) Start the server
+litgpt serve --checkpoint_dir checkpoints/microsoft/phi-2
+```
+
+> [!TIP]
+> Use `litgpt serve --help` to display additional options, including the port, devices, LLM temperature setting, and more.
+
+
+&nbsp;
+## Step 2: Query the inference server
+
+You can now send requests to the inference server you started in step 1. For example, in a new Python session, you can query it as follows:
+
+
+```python
+import requests, json
+
+response = requests.post(
+    "http://127.0.0.1:8000/predict",
+    json={"prompt": "Fix typos in the following sentence: Exampel input"}
+)
+
+print(response.json()["output"])
+```
+
+Executing the code above prints the following output:
+
+```
+Instruct: Fix typos in the following sentence: Exampel input
+Output: Example input.
+```
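+
+&nbsp;
+## Optional: Start the server from Python
+
+If you prefer to start the server from Python instead of the CLI, the `run_server` function in `litgpt.deploy.serve` (which backs the `litgpt serve` command) exposes the same options. The snippet below is a minimal sketch that mirrors the CLI defaults and assumes the phi-2 checkpoint downloaded in step 1:
+
+```python
+from pathlib import Path
+
+from litgpt.deploy.serve import run_server
+
+# Start the same inference server as `litgpt serve`; the sampling settings
+# and port mirror the CLI defaults.
+run_server(
+    checkpoint_dir=Path("checkpoints/microsoft/phi-2"),
+    temperature=0.8,
+    top_k=200,
+    max_new_tokens=50,
+    devices=1,
+    accelerator="cuda",  # use "cpu" if no GPU is available
+    port=8000,
+)
+```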