stream-with-chat
rasbt committed May 20, 2024
1 parent 5dabf5f commit e04667b
Showing 1 changed file with 95 additions and 24 deletions.
119 changes: 95 additions & 24 deletions litgpt/deploy/serve.py
@@ -1,4 +1,5 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+import json
from pathlib import Path
from typing import Dict, Any, Optional
from litgpt.utils import check_valid_checkpoint_dir
@@ -11,7 +12,8 @@
from litgpt.model import GPT
from litgpt.config import Config
from litgpt.tokenizer import Tokenizer
-from litgpt.generate.base import generate
+from litgpt.generate.base import generate as plain_generate
+from litgpt.chat.base import generate as stream_generate
from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
from litgpt.utils import load_checkpoint, CLI, get_default_supported_precision

@@ -23,7 +25,7 @@
    LitAPI, LitServer = object, object


-class SimpleLitAPI(LitAPI):
+class BaseLitAPI(LitAPI):
    def __init__(self,
                 checkpoint_dir: Path,
                 precision: Optional[str] = None,
@@ -53,7 +55,7 @@ def setup(self, device: str) -> None:

        fabric = L.Fabric(
            accelerator=device.type,
-            devices=1 if device.type=="cpu" else [device.index],
+            devices=1 if device.type == "cpu" else [device.index],
            precision=precision,
        )
        checkpoint_path = self.checkpoint_dir / "lit_model.pth"
@@ -81,20 +83,34 @@ def decode_request(self, request: Dict[str, Any]) -> Any:
        encoded = self.tokenizer.encode(prompt, device=self.device)
        return encoded

+
+class SimpleLitAPI(BaseLitAPI):
+    def __init__(self,
+                 checkpoint_dir: Path,
+                 precision: Optional[str] = None,
+                 temperature: float = 0.8,
+                 top_k: int = 50,
+                 top_p: float = 1.0,
+                 max_new_tokens: int = 50):
+        super().__init__(checkpoint_dir, precision, temperature, top_k, top_p, max_new_tokens)
+
+    def setup(self, device: str):
+        super().setup(device)
+
    def predict(self, inputs: torch.Tensor) -> Any:
        # Run the model on the input and return the output.
        prompt_length = inputs.size(0)
        max_returned_tokens = prompt_length + self.max_new_tokens

-        y = generate(
-            self.model,
-            inputs,
-            max_returned_tokens,
-            temperature=self.temperature,
-            top_k=self.top_k,
-            top_p=self.top_p,
-            eos_id=self.tokenizer.eos_id
-        )
+        y = plain_generate(
+            self.model,
+            inputs,
+            max_returned_tokens,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            eos_id=self.tokenizer.eos_id
+        )

        for block in self.model.transformer.h:
            block.attn.kv_cache.reset_parameters()
@@ -106,6 +122,42 @@ def encode_response(self, output: torch.Tensor) -> Dict[str, Any]:
        return {"output": decoded_output}


+class StreamLitAPI(BaseLitAPI):
+    def __init__(self,
+                 checkpoint_dir: Path,
+                 precision: Optional[str] = None,
+                 temperature: float = 0.8,
+                 top_k: int = 50,
+                 top_p: float = 1.0,
+                 max_new_tokens: int = 50):
+        super().__init__(checkpoint_dir, precision, temperature, top_k, top_p, max_new_tokens)
+
+    def setup(self, device: str):
+        super().setup(device)
+
+    def predict(self, inputs: torch.Tensor) -> Any:
+        # Run the model on the input and return the output.
+        prompt_length = inputs.size(0)
+        max_returned_tokens = prompt_length + self.max_new_tokens
+
+        for block in self.model.transformer.h:
+            block.attn.kv_cache.reset_parameters()
+
+        yield from stream_generate(
+            self.model,
+            inputs,
+            max_returned_tokens,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            stop_tokens=[self.tokenizer.decode(torch.tensor([self.tokenizer.eos_id]))]
+        )
+
+    def encode_response(self, output):
+        for out in output:
+            yield {"output": self.tokenizer.decode(out)}
+
+
def run_server(
    checkpoint_dir: Path = Path("checkpoints"),
    precision: Optional[str] = None,
@@ -115,7 +167,8 @@ def run_server(
    max_new_tokens: int = 50,
    devices: int = 1,
    accelerator: str = "auto",
-    port: int = 8000
+    port: int = 8000,
+    stream: bool = False
) -> None:
    """Serve a LitGPT model using LitServe
@@ -146,19 +199,37 @@ def run_server(
        accelerator: The type of accelerator to use. For example, "auto", "cuda", "cpu", or "mps".
            The "auto" setting (default) chooses a GPU if available, and otherwise uses a CPU.
        port: The network port number on which the model is configured to be served.
+        stream: Whether to stream the responses.
    """
    check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth")

-    server = LitServer(
-        SimpleLitAPI(
-            checkpoint_dir=checkpoint_dir,
-            precision=precision,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            max_new_tokens=max_new_tokens,
-        ),
-        accelerator=accelerator,
-        devices=devices)
+    if not stream:
+        server = LitServer(
+            SimpleLitAPI(
+                checkpoint_dir=checkpoint_dir,
+                precision=precision,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                max_new_tokens=max_new_tokens,
+            ),
+            accelerator=accelerator,
+            devices=devices
+        )
+
+    else:
+        server = LitServer(
+            StreamLitAPI(
+                checkpoint_dir=checkpoint_dir,
+                precision=precision,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                max_new_tokens=max_new_tokens,
+            ),
+            accelerator=accelerator,
+            devices=devices,
+            stream=True
+        )

    server.run(port=port)
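For reference, here is a minimal client sketch (not part of this commit) showing how the two modes introduced above might be queried once the server is running, e.g. after calling run_server(..., stream=True) on a checkpoint directory. It assumes the server listens on localhost:8000, that LitServe exposes its default /predict route, that decode_request reads a "prompt" field from the request JSON, and that streamed chunks arrive as newline-delimited JSON; adjust these details to match the actual deployment.

# Hypothetical client for the server above -- a sketch, not part of the commit.
# Assumptions: default LitServe route "/predict", a "prompt" request field,
# and newline-delimited JSON chunks when the server runs with stream=True.
import json
import requests

URL = "http://127.0.0.1:8000/predict"
payload = {"prompt": "What do llamas eat?"}

# Non-streaming server (stream=False): a single JSON response with the full text.
response = requests.post(URL, json=payload, timeout=60)
print(response.json()["output"])

# Streaming server (stream=True): consume partial outputs as they are produced.
with requests.post(URL, json=payload, stream=True, timeout=60) as response:
    for line in response.iter_lines(decode_unicode=True):
        if line:  # skip keep-alive blank lines
            chunk = json.loads(line)
            print(chunk.get("output", ""), end="", flush=True)
print()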
