
Commit

working deployment example
rasbt committed Apr 12, 2024
1 parent 876b1b7 commit 286d626
Showing 4 changed files with 218 additions and 36 deletions.
60 changes: 26 additions & 34 deletions README.md
@@ -125,64 +125,56 @@ litgpt chat \
  --checkpoint_dir out/phi-2-lora/final
```

&nbsp;

### Deploy an LLM

Deploy a LitGPT model using [LitServe](https://github.com/Lightning-AI/litserve):

-```bash
-# pip install litserve
-
-import litserve as ls
-from litgpt.generate.base import main
-from functools import partial
-from pathlib import Path
+```python
+from litserve import LitAPI, LitServer
+...

# STEP 1: DEFINE YOUR MODEL API
-class SimpleAPIForLitGPT(ls.LitAPI):
+class SimpleLitAPI(LitAPI):

    def setup(self, device):
        # Setup the model so it can be called in `predict`.
-        self.generate = partial(
-            main,
-            top_k=200,
-            temperature=0.8,
-            checkpoint_dir=Path("litgpt/checkpoints/microsoft/phi-2"),
-            precision="bf16-true",
-            quantize=None,
-            compile=False
-        )
+        repo_id = "microsoft/phi-2"
+        checkpoint_dir = Path(f"checkpoints/{repo_id}")
+        ...

    def decode_request(self, request):
        # Convert the request payload to your model input.
-        return request["input"]
+        prompt = request["prompt"]
+        ...
+        return encoded

-    def predict(self, x):
+    def predict(self, inputs):
        # Run the model on the input and return the output.
-        return self.generate(prompt=x)
+        ...
+        y = generate(...)
+        return y

    def encode_response(self, output):
        # Convert the model output to a response payload.
-        return {"output": output}
+        decoded_output = self.tokenizer.decode(output)
+        return {"output": decoded_output}


# STEP 2: START THE SERVER
-api = SimpleAPIForLitGPT()
-server = ls.LitServer(api, accelerator="gpu")
-server.run(port=8000)
+if __name__ == "__main__":
+    server = LitServer(SimpleLitAPI(), accelerator="cuda", devices=1)
+    server.run(port=8000)
```

In a new Python session:

```python
# STEP 3: USE THE SERVER
-import requests
-
-response = requests.post(
-    "http://127.0.0.1:8000/predict",
-    json={"prompt": "Fix typos in the following sentence: Exampel input"}
-)
-print(response.content)
+import requests, json
+response = requests.post("http://127.0.0.1:8000/predict",
+                         json={"prompt": "Fix typos in the following sentence: Exampel input"})
```

 
2 changes: 1 addition & 1 deletion litgpt/scripts/convert_hf_checkpoint.py
@@ -10,7 +10,7 @@
import torch
from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor

-from litgpt import Config
+from litgpt.config import Config
from litgpt.utils import incremental_save, lazy_load, save_config


56 changes: 55 additions & 1 deletion tutorials/0_to_litgpt.md
@@ -441,13 +441,67 @@ Time for inference: 1.14 sec total, 26.26 tokens/sec, 30 tokens
> [!TIP]
> Most model weights are already represented in an efficient bfloat16 format. However, if the model currently exceeds your GPU memory, you can try to pass the `--precision bf16-true` option. In addition, you can check the quantization documentation for further optimization, which is linked below.

 
**More information and additional resources**

- [tutorials/inference](inference.md): Chat and inference tutorial
- [tutorials/quantize](quantize.md): Quantizing models to reduce GPU memory requirements

 
## Deploy LLMs

You can deploy LitGPT LLMs using your tool of choice. Below is an abbreviated example using [LitServe](https://github.com/Lightning-AI/litserve) to deploy an LLM:

```python
from litserve import LitAPI, LitServer
...

# STEP 1: DEFINE YOUR MODEL API
class SimpleLitAPI(LitAPI):

    def setup(self, device):
        # Setup the model so it can be called in `predict`.
        repo_id = "microsoft/phi-2"
        checkpoint_dir = Path(f"checkpoints/{repo_id}")
        ...


    def decode_request(self, request):
        # Convert the request payload to your model input.
        prompt = request["prompt"]
        ...
        return encoded

    def predict(self, inputs):
        # Run the model on the input and return the output.
        ...
        y = generate(...)
        return y

    def encode_response(self, output):
        # Convert the model output to a response payload.
        decoded_output = self.tokenizer.decode(output)
        return {"output": decoded_output}


# STEP 2: START THE SERVER
if __name__ == "__main__":
    server = LitServer(SimpleLitAPI(), accelerator="cuda", devices=1)
    server.run(port=8000)
```

```python
# STEP 3: USE THE SERVER
import requests, json
response = requests.post("http://127.0.0.1:8000/predict",
                         json={"prompt": "Fix typos in the following sentence: Exampel input"})
```
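
The `json` import above is there for reading the server's reply; for example, you can extract the `"output"` field (the key set in `encode_response`) as follows — a minimal sketch that assumes the server from step 2 is running:

```python
# Decode the JSON payload returned by the /predict endpoint.
# The "output" key is the one defined in encode_response() above.
output_str = json.loads(response.content)["output"]
print(output_str)
```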

 
**More information and additional resources**

- [tutorials/deploy](deploy.md): A full deployment tutorial and example


 
## Evaluating models
136 changes: 136 additions & 0 deletions tutorials/deploy.md
@@ -0,0 +1,136 @@
# Deploy and Serve LLMs

This document shows how you can deploy and serve a LitGPT model.

 
## Serve an LLM with LitServe

This section illustrates how we can set up an inference server for a phi-2 LLM using [LitServe](https://github.com/Lightning-AI/litserve).

[LitServe](https://github.com/Lightning-AI/litserve) is an inference server for AI/ML models that is minimal and highly scalable.

You can install LitServe as follows:

```bash
pip install litserve
```

 
### Step 1: Create a `server.py` file

First, copy the following code into a file called `server.py`:

```python
from pathlib import Path

import lightning as L
import torch
from litserve import LitAPI, LitServer

from litgpt.model import GPT
from litgpt.config import Config
from litgpt.tokenizer import Tokenizer
from litgpt.generate.base import generate
from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
from litgpt.scripts.download import download_from_hub
from litgpt.utils import load_checkpoint


# DEFINE YOUR MODEL API
class SimpleLitAPI(LitAPI):

    def setup(self, device):
        # Setup the model so it can be called in `predict`.
        repo_id = "microsoft/phi-2"
        checkpoint_dir = Path(f"checkpoints/{repo_id}")

        if not checkpoint_dir.exists():
            download_from_hub(repo_id=repo_id)

        config = Config.from_file(checkpoint_dir / "model_config.yaml")

        device = torch.device(device)
        torch.set_float32_matmul_precision("high")
        fabric = L.Fabric(accelerator=device.type, devices=[device.index], precision="bf16-true")

        checkpoint_path = checkpoint_dir / "lit_model.pth"

        self.tokenizer = Tokenizer(checkpoint_dir)
        self.prompt_style = (
            load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
        )

        with fabric.init_module(empty_init=True):
            model = GPT(config)
        with fabric.init_tensor():
            # enable the kv cache
            model.set_kv_cache(batch_size=1)
        model.eval()

        self.model = fabric.setup_module(model)

        load_checkpoint(fabric, self.model, checkpoint_path)

        self.device = fabric.device

    def decode_request(self, request):
        # Convert the request payload to your model input.
        prompt = request["prompt"]
        prompt = self.prompt_style.apply(prompt)
        encoded = self.tokenizer.encode(prompt, device=self.device)
        return encoded

    def predict(self, inputs):
        # Run the model on the input and return the output.
        prompt_length = inputs.size(0)
        max_returned_tokens = prompt_length + 30

        y = generate(self.model, inputs, max_returned_tokens, temperature=0.8, top_k=200, eos_id=self.tokenizer.eos_id)

        for block in self.model.transformer.h:
            block.attn.kv_cache.reset_parameters()
        return y

    def encode_response(self, output):
        # Convert the model output to a response payload.
        decoded_output = self.tokenizer.decode(output)
        return {"output": decoded_output}


# START THE SERVER
if __name__ == "__main__":
    server = LitServer(SimpleLitAPI(), accelerator="cuda", devices=1)
    server.run(port=8000)
```
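
Optionally, you can sanity-check the API class in a plain Python session before starting the server. The snippet below is only a rough sketch: it assumes a single CUDA GPU at index 0, enough memory to hold phi-2 in bf16, and an arbitrary example prompt:

```python
# Quick local smoke test of SimpleLitAPI, bypassing HTTP entirely.
api = SimpleLitAPI()
api.setup("cuda:0")  # downloads the checkpoint on first use and loads the model
encoded = api.decode_request({"prompt": "Hello, world"})
output = api.predict(encoded)
print(api.encode_response(output))  # {"output": "..."}
```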

 
### Step 2: Start the inference server

After saving the code from step 1 to a `server.py` file, start the inference server from your command-line terminal:

```bash
python server.py
```

 
### Step 3: Query the inference server

You can now send requests to the inference server you started in step 2. For example, from a new Python session:


```python
import requests, json
response = requests.post("http://127.0.0.1:8000/predict",
                         json={"prompt": "Fix typos in the following sentence: Exampel input"})

decoded_string = response.content.decode("utf-8")
output_str = json.loads(decoded_string)["output"]
print(output_str)
```
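
If you plan to send several prompts, a small client helper keeps this tidy. This is just a sketch built on the endpoint and payload format defined by `SimpleLitAPI` above; the second prompt is an arbitrary example:

```python
import requests

def query(prompt, url="http://127.0.0.1:8000/predict"):
    # Post a prompt to the /predict endpoint and return the decoded "output" field.
    response = requests.post(url, json={"prompt": prompt})
    response.raise_for_status()
    return response.json()["output"]

for p in ("Fix typos in the following sentence: Exampel input",
          "What is the capital of France?"):
    print(query(p))
```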

Executing the code above prints the following output:

```
Instruct:Fix typos in the following sentence: Exampel input
Output: Example input.
```
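
Since phi-2's prompt style wraps the reply in an `Instruct:`/`Output:` template, you may want to strip the echoed prompt before using the result. A possible post-processing step, assuming the output format shown above:

```python
# Keep only the text after the "Output:" marker, if present.
answer = output_str.split("Output:", 1)[-1].strip()
print(answer)  # Example input.
```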
