diff --git a/README.md b/README.md
index 3694d0cd4a..78fdff2b6d 100644
--- a/README.md
+++ b/README.md
@@ -155,6 +155,12 @@ For more information, refer to the [download](tutorials/download_model_weights.m
### Finetune an LLM
[Finetune](tutorials/finetune.md) a model to specialize it on your own custom dataset:
+
+
```bash
# 1) Download a pretrained model
litgpt download --repo_id microsoft/phi-2
@@ -174,9 +180,17 @@ litgpt chat \
--checkpoint_dir out/phi-2-lora/final
```
+
+
### Pretrain an LLM
Train an LLM from scratch on your own data via pretraining:
+
+
```bash
mkdir -p custom_texts
curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output custom_texts/book1.txt
@@ -201,10 +215,19 @@ litgpt chat \
--checkpoint_dir out/custom-model/final
```
+
+
### Continue pretraining an LLM
This is another way of finetuning that specializes an already pretrained model by training it on custom data:
-```
+
+
+```bash
mkdir -p custom_texts
curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output custom_texts/book1.txt
curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output custom_texts/book2.txt
@@ -215,6 +238,7 @@ litgpt download --repo_id EleutherAI/pythia-160m
# 2) Continue pretraining the model
litgpt pretrain \
--model_name pythia-160m \
+ --tokenizer_dir checkpoints/EleutherAI/pythia-160m \
--initial_checkpoint_dir checkpoints/EleutherAI/pythia-160m \
--data TextFiles \
--data.train_data_path "custom_texts/" \
@@ -228,6 +252,30 @@ litgpt chat \
+### Deploy an LLM
+
+This example illustrates how to deploy an LLM using LitGPT's built-in serving capabilities:
+
+```bash
+# 1) Download a pretrained model (alternatively, use your own finetuned model)
+litgpt download --repo_id microsoft/phi-2
+
+# 2) Start the server
+litgpt serve --checkpoint_dir checkpoints/microsoft/phi-2
+```
+
+```python
+# 3) Use the server (in a separate session)
+import requests
+
+response = requests.post(
+    "http://127.0.0.1:8000/predict",
+    json={"prompt": "Fix typos in the following sentence: Exampel input"}
+)
+print(response.json()["output"])
+```
+
+
+
> [!NOTE]
> **[Read the full docs](tutorials/0_to_litgpt.md)**.
diff --git a/litgpt/__main__.py b/litgpt/__main__.py
index 06a865f75e..821c1f5801 100644
--- a/litgpt/__main__.py
+++ b/litgpt/__main__.py
@@ -25,6 +25,8 @@
from litgpt.scripts.download import download_from_hub as download_fn
from litgpt.scripts.merge_lora import merge_lora as merge_lora_fn
from litgpt.eval.evaluate import convert_and_evaluate as evaluate_fn
+from litgpt.deploy.serve import run_server as serve_fn
+
if TYPE_CHECKING:
from jsonargparse import ArgumentParser
@@ -87,6 +89,7 @@ def main() -> None:
},
"merge_lora": {"help": "Merges the LoRA weights with the base model.", "fn": merge_lora_fn},
"evaluate": {"help": "Evaluate a model with the LM Evaluation Harness.", "fn": evaluate_fn},
+ "serve": {"help": "Serve and deploy a model with LitServe.", "fn": serve_fn},
}
from jsonargparse import set_config_read_mode, set_docstring_parse_options
diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py
new file mode 100644
index 0000000000..9df48ad98d
--- /dev/null
+++ b/litgpt/deploy/serve.py
@@ -0,0 +1,137 @@
+# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+from pathlib import Path
+from typing import Dict, Any, Optional, Literal
+from litgpt.utils import check_valid_checkpoint_dir
+
+import lightning as L
+import torch
+from litserve import LitAPI, LitServer
+
+from litgpt.model import GPT
+from litgpt.config import Config
+from litgpt.tokenizer import Tokenizer
+from litgpt.generate.base import generate
+from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
+from litgpt.utils import load_checkpoint, CLI, get_default_supported_precision
+
+
+class SimpleLitAPI(LitAPI):
+ def __init__(self,
+ checkpoint_dir: Path,
+ precision: Optional[str] = None,
+ temperature: float = 0.8,
+ top_k: int = 50,
+ max_new_tokens: int = 50) -> None:
+
+ super().__init__()
+ self.checkpoint_dir = checkpoint_dir
+ self.precision = precision
+ self.temperature = temperature
+ self.top_k = top_k
+ self.max_new_tokens = max_new_tokens
+
+ def setup(self, device: str) -> None:
+ # Setup the model so it can be called in `predict`.
+ config = Config.from_file(self.checkpoint_dir / "model_config.yaml")
+ device = torch.device(device)
+ torch.set_float32_matmul_precision("high")
+
+ precision = self.precision or get_default_supported_precision(training=False)
+
+ fabric = L.Fabric(
+ accelerator=device.type,
+            devices=1 if device.type == "cpu" else [device.index],  # TODO: Update once LitServe supports "auto"
+ precision=precision,
+ )
+ checkpoint_path = self.checkpoint_dir / "lit_model.pth"
+ self.tokenizer = Tokenizer(self.checkpoint_dir)
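+        # Use the prompt style saved with the checkpoint if one exists (e.g., from finetuning), otherwise the model's default.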
+ self.prompt_style = (
+ load_prompt_style(self.checkpoint_dir)
+ if has_prompt_style(self.checkpoint_dir)
+ else PromptStyle.from_config(config)
+ )
+ with fabric.init_module(empty_init=True):
+ model = GPT(config)
+ with fabric.init_tensor():
+ # enable the kv cache
+ model.set_kv_cache(batch_size=1)
+ model.eval()
+
+ self.model = fabric.setup_module(model)
+ load_checkpoint(fabric, self.model, checkpoint_path)
+ self.device = fabric.device
+
+ def decode_request(self, request: Dict[str, Any]) -> Any:
+ # Convert the request payload to your model input.
+ prompt = request["prompt"]
+ prompt = self.prompt_style.apply(prompt)
+ encoded = self.tokenizer.encode(prompt, device=self.device)
+ return encoded
+
+ def predict(self, inputs: torch.Tensor) -> Any:
+ # Run the model on the input and return the output.
+ prompt_length = inputs.size(0)
+ max_returned_tokens = prompt_length + self.max_new_tokens
+
+ y = generate(
+ self.model,
+ inputs,
+ max_returned_tokens,
+ temperature=self.temperature,
+ top_k=self.top_k,
+ eos_id=self.tokenizer.eos_id
+ )
+
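+        # Clear the KV cache so state from this request does not leak into the next one.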
+ for block in self.model.transformer.h:
+ block.attn.kv_cache.reset_parameters()
+ return y
+
+ def encode_response(self, output: torch.Tensor) -> Dict[str, Any]:
+ # Convert the model output to a response payload.
+ decoded_output = self.tokenizer.decode(output)
+ return {"output": decoded_output}
+
+
+def run_server(
+ checkpoint_dir: Path = Path("checkpoints"),
+ precision: Optional[str] = None,
+ temperature: float = 0.8,
+ top_k: int = 200,
+ max_new_tokens: int = 50,
+ devices: int = 1,
+ accelerator: str = "cuda",
+ port: int = 8000
+) -> None:
+    """Serve a LitGPT model using LitServe.
+
+ Arguments:
+ checkpoint_dir: The checkpoint directory to load the model from.
+ precision: Optional precision setting to instantiate the model weights in. By default, this will
+ automatically be inferred from the metadata in the given ``checkpoint_dir`` directory.
+        temperature: Temperature setting for the text generation. Values above 1 increase randomness.
+ Values below 1 decrease randomness.
+ top_k: The size of the pool of potential next tokens. Values larger than 1 result in more novel
+ generated text but can also lead to more incoherent texts.
+ max_new_tokens: The number of generation steps to take.
+ devices: How many devices/GPUs to use.
+ accelerator: The type of accelerator to use. For example, "cuda" or "cpu".
+        port: The network port on which the inference server listens.
+ """
+ check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth")
+
+ server = LitServer(
+ SimpleLitAPI(
+ checkpoint_dir=checkpoint_dir,
+ precision=precision,
+ temperature=temperature,
+ top_k=top_k,
+ max_new_tokens=max_new_tokens,
+ ),
+ accelerator=accelerator,
+ devices=devices)
+
+ server.run(port=port)
+
+
+if __name__ == "__main__":
+ CLI(run_server)
diff --git a/pyproject.toml b/pyproject.toml
index 1d3c89cfd9..b6fbec18b8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,7 @@ dependencies = [
"torch>=2.2.0",
"lightning==2.3.0.dev20240328",
"jsonargparse[signatures]>=4.27.6",
+ "litserve==0.0.0.dev2", # imported by litgpt.deploy
]
[project.urls]
diff --git a/tests/test_cli.py b/tests/test_cli.py
index a45a1ca092..f95841ddc0 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -15,7 +15,7 @@ def test_cli():
main()
out = out.getvalue()
assert "usage: litgpt" in out
- assert "{download,chat,finetune,pretrain,generate,convert,merge_lora,evaluate}" in out
+ assert "{download,chat,finetune,pretrain,generate,convert,merge_lora,evaluate,serve}" in out
assert (
"""Available subcommands:
download Download weights or tokenizer data from the Hugging
@@ -24,7 +24,7 @@ def test_cli():
in out
)
assert """evaluate Evaluate a model with the LM Evaluation Harness.""" in out
-
+ assert """serve Serve and deploy a model with LitServe.""" in out
out = StringIO()
with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "finetune", "lora", "-h"]):
main()
diff --git a/tests/test_serve.py b/tests/test_serve.py
new file mode 100644
index 0000000000..46a109c807
--- /dev/null
+++ b/tests/test_serve.py
@@ -0,0 +1,42 @@
+# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+from dataclasses import asdict
+import shutil
+
+from lightning.fabric import seed_everything
+from fastapi.testclient import TestClient
+from litserve.server import LitServer
+import torch
+import yaml
+
+
+from litgpt import GPT, Config
+from litgpt.deploy.serve import SimpleLitAPI
+from litgpt.scripts.download import download_from_hub
+
+
+def test_simple(tmp_path):
+
+ # Create model checkpoint
+ seed_everything(123)
+ ours_config = Config.from_name("pythia-14m")
+ download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
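+    # Move the tokenizer files next to lit_model.pth so SimpleLitAPI finds them in the checkpoint dir.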
+ shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
+ shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
+ ours_model = GPT(ours_config)
+ checkpoint_path = tmp_path / "lit_model.pth"
+ torch.save(ours_model.state_dict(), checkpoint_path)
+ config_path = tmp_path / "model_config.yaml"
+ with open(config_path, "w", encoding="utf-8") as fp:
+ yaml.dump(asdict(ours_config), fp)
+
+ accelerator = "cpu"
+ server = LitServer(
+ SimpleLitAPI(checkpoint_dir=tmp_path, temperature=1, top_k=1),
+ accelerator=accelerator, devices=1, timeout=60
+ )
+
+ with TestClient(server.app) as client:
+ response = client.post("/predict", json={"prompt": "Hello world"})
+ # Model is a small random model, not trained, hence the gibberish.
+ # We are just testing that the server works.
+ assert response.json()["output"][:19] == "Hello world statues"
diff --git a/tutorials/0_to_litgpt.md b/tutorials/0_to_litgpt.md
index 337bf37049..e5e1c7c765 100644
--- a/tutorials/0_to_litgpt.md
+++ b/tutorials/0_to_litgpt.md
@@ -464,6 +464,44 @@ litgpt evaluate \
(A list of supported tasks can be found [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md).)
+
+## Deploy LLMs
+
+You can deploy LitGPT LLMs using your tool of choice. Below is an example using LitGPT's built-in serving capabilities:
+
+
+```bash
+# 1) Download a pretrained model (alternatively, use your own finetuned model)
+litgpt download --repo_id microsoft/phi-2
+
+# 2) Start the server
+litgpt serve --checkpoint_dir checkpoints/microsoft/phi-2
+```
+
+```python
+# 3) Use the server (in a separate session)
+import requests
+
+response = requests.post(
+    "http://127.0.0.1:8000/predict",
+    json={"prompt": "Fix typos in the following sentence: Exampel input"}
+)
+print(response.json()["output"])
+```
+
+This prints output similar to the following (the `Instruct:`/`Output:` framing is phi-2's prompt template, which the server applies to the prompt automatically):
+
+```
+Instruct: Fix typos in the following sentence: Exampel input
+Output: Example input.
+```
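+
+Since `litgpt serve` is a thin wrapper around `litgpt.deploy.serve.run_server`, you can also start the server programmatically. The snippet below is a minimal sketch based on the `run_server` signature, assuming the phi-2 checkpoint downloaded in step 1:
+
+```python
+from pathlib import Path
+
+from litgpt.deploy.serve import run_server
+
+# Start the same server that `litgpt serve` launches from the command line
+run_server(
+    checkpoint_dir=Path("checkpoints/microsoft/phi-2"),
+    temperature=0.8,
+    max_new_tokens=50,
+    accelerator="cuda",  # use "cpu" if no GPU is available
+    port=8000,
+)
+```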
+
+
+
+**More information and additional resources**
+
+- [tutorials/deploy](deploy.md): A full deployment tutorial and example
+
+
## Converting LitGPT model weights to `safetensors` format
diff --git a/tutorials/deploy.md b/tutorials/deploy.md
new file mode 100644
index 0000000000..1b1495fde7
--- /dev/null
+++ b/tutorials/deploy.md
@@ -0,0 +1,49 @@
+# Serve and Deploy LLMs
+
+This document shows how you can serve and deploy a LitGPT model.
+
+
+## Serve an LLM
+
+This section illustrates how to set up a minimal and highly scalable inference server for a phi-2 LLM using `litgpt serve`.
+
+
+
+## Step 1: Start the inference server
+
+
+```bash
+# 1) Download a pretrained model (alternatively, use your own finetuned model)
+litgpt download --repo_id microsoft/phi-2
+
+# 2) Start the server
+litgpt serve --checkpoint_dir checkpoints/microsoft/phi-2
+```
+
+> [!TIP]
+> Use `litgpt serve --help` to display additional options, including the port, devices, LLM temperature setting, and more.
+
+
+
+## Step 2: Query the inference server
+
+You can now send requests to the inference server you started in step 1. For example, in a new Python session, you can query the server as follows:
+
+
+```python
+import requests
+
+response = requests.post(
+ "http://127.0.0.1:8000/predict",
+ json={"prompt": "Fix typos in the following sentence: Exampel input"}
+)
+
+print(response.json()["output"])
+```
+
+Executing the code above prints the following output:
+
+```
+Instruct: Fix typos in the following sentence: Exampel input
+Output: Example input.
+```
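+
+If you call the server from application code, it can help to wrap the request in a small helper function. The snippet below is only an illustrative sketch built on the `/predict` endpoint and JSON payload shown above; the function name, timeout, and error handling are arbitrary choices rather than part of LitGPT:
+
+```python
+import requests
+
+
+def query_llm(prompt: str, url: str = "http://127.0.0.1:8000/predict", timeout: float = 60.0) -> str:
+    # Send the prompt to the running `litgpt serve` instance and return the generated text.
+    response = requests.post(url, json={"prompt": prompt}, timeout=timeout)
+    response.raise_for_status()  # surface HTTP errors instead of failing silently
+    return response.json()["output"]
+
+
+print(query_llm("Fix typos in the following sentence: Exampel input"))
+```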