Merge branch 'main' into feature/finetune-default
Showing 8 changed files with 321 additions and 3 deletions.
`litgpt/deploy/serve.py` (new file, +137 lines):
```python
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from pathlib import Path
from typing import Any, Dict, Optional

import lightning as L
import torch
from litserve import LitAPI, LitServer

from litgpt.model import GPT
from litgpt.config import Config
from litgpt.tokenizer import Tokenizer
from litgpt.generate.base import generate
from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
from litgpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint


class SimpleLitAPI(LitAPI):
    def __init__(self,
                 checkpoint_dir: Path,
                 precision: Optional[str] = None,
                 temperature: float = 0.8,
                 top_k: int = 50,
                 max_new_tokens: int = 50) -> None:

        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        self.precision = precision
        self.temperature = temperature
        self.top_k = top_k
        self.max_new_tokens = max_new_tokens

    def setup(self, device: str) -> None:
        # Set up the model so it can be called in `predict`.
        config = Config.from_file(self.checkpoint_dir / "model_config.yaml")
        device = torch.device(device)
        torch.set_float32_matmul_precision("high")

        precision = self.precision or get_default_supported_precision(training=False)

        fabric = L.Fabric(
            accelerator=device.type,
            devices=1 if device.type == "cpu" else [device.index],  # TODO: Update once LitServe supports "auto"
            precision=precision,
        )
        checkpoint_path = self.checkpoint_dir / "lit_model.pth"
        self.tokenizer = Tokenizer(self.checkpoint_dir)
        self.prompt_style = (
            load_prompt_style(self.checkpoint_dir)
            if has_prompt_style(self.checkpoint_dir)
            else PromptStyle.from_config(config)
        )
        with fabric.init_module(empty_init=True):
            model = GPT(config)
        with fabric.init_tensor():
            # enable the kv cache
            model.set_kv_cache(batch_size=1)
        model.eval()

        self.model = fabric.setup_module(model)
        load_checkpoint(fabric, self.model, checkpoint_path)
        self.device = fabric.device

    def decode_request(self, request: Dict[str, Any]) -> Any:
        # Convert the request payload to your model input.
        prompt = request["prompt"]
        prompt = self.prompt_style.apply(prompt)
        encoded = self.tokenizer.encode(prompt, device=self.device)
        return encoded

    def predict(self, inputs: torch.Tensor) -> Any:
        # Run the model on the input and return the output.
        prompt_length = inputs.size(0)
        max_returned_tokens = prompt_length + self.max_new_tokens

        y = generate(
            self.model,
            inputs,
            max_returned_tokens,
            temperature=self.temperature,
            top_k=self.top_k,
            eos_id=self.tokenizer.eos_id
        )

        for block in self.model.transformer.h:
            block.attn.kv_cache.reset_parameters()
        return y

    def encode_response(self, output: torch.Tensor) -> Dict[str, Any]:
        # Convert the model output to a response payload.
        decoded_output = self.tokenizer.decode(output)
        return {"output": decoded_output}


def run_server(
    checkpoint_dir: Path = Path("checkpoints"),
    precision: Optional[str] = None,
    temperature: float = 0.8,
    top_k: int = 200,
    max_new_tokens: int = 50,
    devices: int = 1,
    accelerator: str = "cuda",
    port: int = 8000
) -> None:
    """Serve a LitGPT model using LitServe.

    Arguments:
        checkpoint_dir: The checkpoint directory to load the model from.
        precision: Optional precision setting to instantiate the model weights in. By default, this will
            automatically be inferred from the metadata in the given ``checkpoint_dir`` directory.
        temperature: Temperature setting for the text generation. Values above 1 increase randomness.
            Values below 1 decrease randomness.
        top_k: The size of the pool of potential next tokens. Values larger than 1 result in more novel
            generated text but can also lead to more incoherent texts.
        max_new_tokens: The number of generation steps to take.
        devices: How many devices/GPUs to use.
        accelerator: The type of accelerator to use. For example, "cuda" or "cpu".
        port: The network port number on which the model is configured to be served.
    """
    check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth")

    server = LitServer(
        SimpleLitAPI(
            checkpoint_dir=checkpoint_dir,
            precision=precision,
            temperature=temperature,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
        ),
        accelerator=accelerator,
        devices=devices,
    )

    server.run(port=port)


if __name__ == "__main__":
    CLI(run_server)
```
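Because the module ends with `CLI(run_server)`, each `run_server` argument doubles as a command-line flag, so the server can be launched with, for example, `python -m litgpt.deploy.serve --checkpoint_dir <path>`. Once it is running, the `/predict` route accepts the same JSON payload the test below sends. A minimal client sketch, assuming the server is listening on the default port 8000:

```python
# Client sketch (not part of the commit): queries the running LitServe server.
# Assumes the default port 8000; the /predict route and the
# {"prompt": ...} -> {"output": ...} payload shapes match SimpleLitAPI above.
import requests

response = requests.post(
    "http://127.0.0.1:8000/predict",
    json={"prompt": "Example prompt"},
    timeout=60,
)
response.raise_for_status()
print(response.json()["output"])
```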
New test file for the server (+42 lines):
```python
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from dataclasses import asdict
import shutil

from lightning.fabric import seed_everything
from fastapi.testclient import TestClient
from litserve.server import LitServer
import torch
import yaml

from litgpt import GPT, Config
from litgpt.deploy.serve import SimpleLitAPI
from litgpt.scripts.download import download_from_hub


def test_simple(tmp_path):
    # Create model checkpoint
    seed_everything(123)
    ours_config = Config.from_name("pythia-14m")
    download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
    ours_model = GPT(ours_config)
    checkpoint_path = tmp_path / "lit_model.pth"
    torch.save(ours_model.state_dict(), checkpoint_path)
    config_path = tmp_path / "model_config.yaml"
    with open(config_path, "w", encoding="utf-8") as fp:
        yaml.dump(asdict(ours_config), fp)

    accelerator = "cpu"
    server = LitServer(
        SimpleLitAPI(checkpoint_dir=tmp_path, temperature=1, top_k=1),
        accelerator=accelerator, devices=1, timeout=60
    )

    with TestClient(server.app) as client:
        response = client.post("/predict", json={"prompt": "Hello world"})
        # Model is a small random model, not trained, hence the gibberish.
        # We are just testing that the server works.
        assert response.json()["output"][:19] == "Hello world statues"
```
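The exact-prefix assertion works because the output is deterministic: `seed_everything(123)` fixes the randomly initialized weights, and `top_k=1` restricts sampling to the single most likely token at each step, which makes decoding effectively greedy regardless of the temperature setting.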