add streaming support in litgpt serve
rasbt committed May 17, 2024 · commit c890334 (1 parent 089345c)
Showing 2 changed files with 101 additions and 16 deletions.
97 changes: 84 additions & 13 deletions litgpt/deploy/serve.py
@@ -1,4 +1,5 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+import json
 from pathlib import Path
 from typing import Dict, Any, Optional
 from litgpt.utils import check_valid_checkpoint_dir
@@ -23,7 +24,7 @@
     LitAPI, LitServer = object, object


-class SimpleLitAPI(LitAPI):
+class BaseLitAPI(LitAPI):
     def __init__(self,
                  checkpoint_dir: Path,
                  precision: Optional[str] = None,
@@ -81,6 +82,20 @@ def decode_request(self, request: Dict[str, Any]) -> Any:
         encoded = self.tokenizer.encode(prompt, device=self.device)
         return encoded

+
+class SimpleLitAPI(BaseLitAPI):
+    def __init__(self,
+                 checkpoint_dir: Path,
+                 precision: Optional[str] = None,
+                 temperature: float = 0.8,
+                 top_k: int = 50,
+                 top_p: float = 1.0,
+                 max_new_tokens: int = 50):
+        super().__init__(checkpoint_dir, precision, temperature, top_k, top_p, max_new_tokens)
+
+    def setup(self, device: str):
+        super().setup(device)
+
     def predict(self, inputs: torch.Tensor) -> Any:
         # Run the model on the input and return the output.
         prompt_length = inputs.size(0)
@@ -108,6 +123,42 @@ def encode_response(self, output: torch.Tensor) -> Dict[str, Any]:
         return {"output": decoded_output}


+class StreamLitAPI(BaseLitAPI):
+    def __init__(self,
+                 checkpoint_dir: Path,
+                 precision: Optional[str] = None,
+                 temperature: float = 0.8,
+                 top_k: int = 50,
+                 top_p: float = 1.0,
+                 max_new_tokens: int = 50):
+        super().__init__(checkpoint_dir, precision, temperature, top_k, top_p, max_new_tokens)
+
+    def setup(self, device: str):
+        super().setup(device)
+
+    def predict(self, inputs: torch.Tensor) -> Any:
+        # Run the model on the input and return the output.
+        prompt_length = inputs.size(0)
+        max_returned_tokens = prompt_length + self.max_new_tokens
+
+        for block in self.model.transformer.h:
+            block.attn.kv_cache.reset_parameters()
+
+        yield generate(
+            self.model,
+            inputs,
+            max_returned_tokens,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            eos_id=self.tokenizer.eos_id
+        )
+
+    def encode_response(self, output_stream):
+        for outputs in output_stream:
+            yield [json.dumps({"output": self.tokenizer.decode(output)}) for output in outputs]
+
+
 def run_server(
     checkpoint_dir: Path = Path("checkpoints"),
     precision: Optional[str] = None,
@@ -117,7 +168,8 @@ def run_server(
     max_new_tokens: int = 50,
     devices: int = 1,
     accelerator: str = "auto",
-    port: int = 8000
+    port: int = 8000,
+    stream: bool = False
 ) -> None:
     """Serve a LitGPT model using LitServe
@@ -148,19 +200,38 @@ def run_server(
         accelerator: The type of accelerator to use. For example, "auto", "cuda", "cpu", or "mps".
             The "auto" setting (default) chooses a GPU if available, and otherwise uses a CPU.
         port: The network port number on which the model is configured to be served.
+        stream: Whether to stream the responses.
     """
     check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth")

-    server = LitServer(
-        SimpleLitAPI(
-            checkpoint_dir=checkpoint_dir,
-            precision=precision,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            max_new_tokens=max_new_tokens,
-        ),
-        accelerator=accelerator,
-        devices=devices)
+    if not stream:
+
+        server = LitServer(
+            SimpleLitAPI(
+                checkpoint_dir=checkpoint_dir,
+                precision=precision,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                max_new_tokens=max_new_tokens,
+            ),
+            accelerator=accelerator,
+            devices=devices
+        )
+
+    else:
+        server = LitServer(
+            StreamLitAPI(
+                checkpoint_dir=checkpoint_dir,
+                precision=precision,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                max_new_tokens=max_new_tokens,
+            ),
+            accelerator=accelerator,
+            devices=devices,
+            stream=True
+        )

     server.run(port=port)
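Note on the design: in the streaming path, predict and encode_response are generators, which is what LitServer's stream=True mode consumes chunk by chunk, while the non-streaming path keeps returning a single response dict per request. For reference, a minimal launch sketch (not part of the diff) that calls the new run_server entry point directly; the checkpoint path is a placeholder and assumes the model files are already downloaded:

```python
# Sketch: start the server with streaming enabled via the new entry point.
# The checkpoint directory below is hypothetical; any valid LitGPT
# checkpoint directory works.
from pathlib import Path

from litgpt.deploy.serve import run_server

run_server(
    checkpoint_dir=Path("checkpoints/EleutherAI/pythia-14m"),
    stream=True,  # selects StreamLitAPI instead of SimpleLitAPI
    port=8000,
)
```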
20 changes: 17 additions & 3 deletions tests/test_serve.py
@@ -1,5 +1,6 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 from dataclasses import asdict
+import json
 import shutil

 from lightning.fabric import seed_everything
@@ -10,7 +11,7 @@


 from litgpt import GPT, Config
-from litgpt.deploy.serve import SimpleLitAPI
+from litgpt.deploy.serve import SimpleLitAPI, StreamLitAPI
 from litgpt.scripts.download import download_from_hub


@@ -37,6 +38,19 @@ def test_simple(tmp_path):

     with TestClient(server.app) as client:
         response = client.post("/predict", json={"prompt": "Hello world"})
         # Model is a small random model, not trained, hence the gibberish.
         # We are just testing that the server works.
         assert response.json()["output"][:19] == "Hello world statues"

+    # Test with streaming enabled
+    server = LitServer(
+        StreamLitAPI(checkpoint_dir=tmp_path, temperature=1, top_k=1),
+        accelerator=accelerator, devices=1, timeout=60, stream=True
+    )
+    with TestClient(server.app) as client:
+        response = client.post("/predict", json={"prompt": "Hello world"})
+        response_list = response.json()
+        parsed_response = []
+
+        for item in response_list:
+            parsed_dict = json.loads(item)
+            parsed_response.append(parsed_dict)
+        assert parsed_response[0]["output"][:19] == "Hello world statues"
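The test reads the stream through TestClient, which buffers the whole response into a JSON list before parsing. A real client would typically consume the chunks incrementally; below is a hypothetical sketch using requests (the exact chunk framing depends on the LitServe version, so the iter_lines handling is an assumption, not part of the commit):

```python
# Hypothetical streaming client; assumes the server runs on localhost:8000
# and each chunk arrives as a JSON string of the form {"output": "..."}.
import json

import requests

with requests.post(
    "http://127.0.0.1:8000/predict",
    json={"prompt": "Hello world"},
    stream=True,  # read the body incrementally instead of buffering it
) as response:
    for line in response.iter_lines():
        if line:
            print(json.loads(line)["output"], flush=True)
```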
