remove quantize
rasbt committed Apr 17, 2024
1 parent 090107c commit 694f94c
Showing 1 changed file with 0 additions and 17 deletions.
litgpt/deploy/serve.py (17 changes: 0 additions & 17 deletions)
@@ -4,7 +4,6 @@
 from litgpt.utils import check_valid_checkpoint_dir
 
 import lightning as L
-from lightning.fabric.plugins import BitsandbytesPrecision
 import torch
 from litserve import LitAPI, LitServer
 
@@ -20,15 +19,13 @@ class SimpleLitAPI(LitAPI):
     def __init__(self,
                  checkpoint_dir: Path,
                  precision: Optional[str] = None,
-                 quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
                  temperature: float = 0.8,
                  top_k: int = 50,
                  max_new_tokens: int = 50) -> None:
 
         super().__init__()
         self.checkpoint_dir = checkpoint_dir
         self.precision = precision
-        self.quantize = quantize
         self.temperature = temperature
         self.top_k = top_k
         self.max_new_tokens = max_new_tokens
@@ -40,19 +37,11 @@ def setup(self, device: str) -> None:
         torch.set_float32_matmul_precision("high")
 
         precision = self.precision or get_default_supported_precision(training=False)
-        plugins = None
-        if self.quantize is not None and self.quantize.startswith("bnb."):
-            if "mixed" in self.precision:
-                raise ValueError("Quantization and mixed precision is not supported.")
-            dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
-            plugins = BitsandbytesPrecision(self.quantize[4:], dtype)
-            precision = None
 
         fabric = L.Fabric(
             accelerator=device.type,
             devices=1 if device.type=="cpu" else [device.index],  # TODO: Update once LitServe supports "auto"
             precision=precision,
-            plugins=plugins,
         )
         checkpoint_path = self.checkpoint_dir / "lit_model.pth"
         self.tokenizer = Tokenizer(self.checkpoint_dir)
@@ -106,7 +95,6 @@ def encode_response(self, output: torch.Tensor) -> Dict[str, Any]:
 def run_server(
     checkpoint_dir: Path = Path("checkpoints"),
     precision: Optional[str] = None,
-    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
     temperature: float = 0.8,
     top_k: int = 200,
     max_new_tokens: int = 50,
@@ -120,10 +108,6 @@ def run_server(
         checkpoint_dir: The checkpoint directory to load the model from.
         precision: Optional precision setting to instantiate the model weights in. By default, this will
             automatically be inferred from the metadata in the given ``checkpoint_dir`` directory.
-        quantize: Whether to quantize the model and using which method:
-            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
-            - bnb.int8: 8-bit quantization from bitsandbytes
-            for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
         temperature: Temperature setting for the text generation. Value above 1 increase randomness.
             Values below 1 decrease randomness.
         top_k: The size of the pool of potential next tokens. Values larger than 1 result in more novel
@@ -139,7 +123,6 @@
         SimpleLitAPI(
             checkpoint_dir=checkpoint_dir,
             precision=precision,
-            quantize=quantize,
             temperature=temperature,
             top_k=top_k,
             max_new_tokens=max_new_tokens,
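For reference, here is a minimal sketch of what the removed code path did: it translated the `bnb.*` quantize option into Lightning Fabric's BitsandbytesPrecision plugin before constructing the Fabric instance. This is reconstructed from the deleted lines above and is not part of the commit itself; the example values are hypothetical.

    # Sketch reconstructed from the lines this commit deletes; values are hypothetical.
    import torch
    import lightning as L
    from lightning.fabric.plugins import BitsandbytesPrecision

    quantize = "bnb.nf4"     # one of: bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq, bnb.int8
    precision = "bf16-true"  # mixed-precision settings were rejected with a ValueError

    plugins = None
    if quantize is not None and quantize.startswith("bnb."):
        # Map the true-precision string to a torch dtype for the plugin.
        dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
        plugins = BitsandbytesPrecision(quantize[4:], dtype)  # strip the "bnb." prefix
        precision = None  # the plugin takes over precision handling

    fabric = L.Fabric(accelerator="cuda", devices=1, precision=precision, plugins=plugins)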
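After this commit, serving no longer accepts a quantize argument. A hypothetical call to run_server using only the parameters visible in this diff (the checkpoint directory is illustrative; the defaults match the signature shown above):

    from pathlib import Path
    from litgpt.deploy.serve import run_server

    run_server(
        checkpoint_dir=Path("checkpoints"),  # default shown in the signature above
        precision="bf16-true",               # or None to infer from checkpoint metadata
        temperature=0.8,
        top_k=200,
        max_new_tokens=50,
    )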
