diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py
index af2f8af7f0..60d6472b36 100644
--- a/litgpt/deploy/serve.py
+++ b/litgpt/deploy/serve.py
@@ -22,7 +22,7 @@ def __init__(self,
                  quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
                  temperature: float = 0.8,
                  top_k: int = 200,
-                 max_generated_tokens: int = 30) -> None:
+                 max_new_tokens: int = 50) -> None:
 
         super().__init__()
         self.checkpoint_dir = checkpoint_dir
@@ -30,7 +30,7 @@ def __init__(self,
         self.quantize = quantize
         self.temperature = temperature
         self.top_k = top_k
-        self.max_generated_tokens = max_generated_tokens
+        self.max_new_tokens = max_new_tokens
 
     def setup(self, device: str) -> None:
         # Setup the model so it can be called in `predict`.
@@ -80,7 +80,7 @@ def decode_request(self, request: Dict[str, Any]) -> Any:
     def predict(self, inputs: torch.Tensor) -> Any:
         # Run the model on the input and return the output.
         prompt_length = inputs.size(0)
-        max_returned_tokens = prompt_length + self.max_generated_tokens
+        max_returned_tokens = prompt_length + self.max_new_tokens
 
         y = generate(
             self.model,
@@ -107,7 +107,7 @@ def run_server(
     quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
     temperature: float = 0.8,
     top_k: int = 200,
-    max_generated_tokens: int = 50,
+    max_new_tokens: int = 50,
     devices: int = 1,
     accelerator: str = "cuda",
     port: int = 8000
@@ -126,7 +126,7 @@ def run_server(
             Values below 1 decrease randomness.
         top_k: The size of the pool of potential next tokens. Values larger than 1 result in more novel generated
             text but can also lead to more incoherent texts.
-        max_generated_tokens: How many new tokens, in addition to the prompt length, to generate.
+        max_new_tokens: The number of generation steps to take.
         devices: How many devices/GPUs to use.
         accelerator: The type of accelerator to use. For example, "cuda" or "cpu".
         port: The network port number on which the model is configured to be served.
@@ -140,7 +140,7 @@ def run_server(
             quantize=quantize,
             temperature=temperature,
             top_k=top_k,
-            max_generated_tokens=max_generated_tokens,
+            max_new_tokens=max_new_tokens,
         ),
         accelerator=accelerator,
         devices=devices)
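
For reference, a minimal sketch of calling the renamed API after this change. The checkpoint path is a placeholder, and `checkpoint_dir` is assumed to be a parameter of `run_server` (it is passed through to `SimpleLitAPI` above but its signature line is outside these hunks):

```python
from pathlib import Path

from litgpt.deploy.serve import run_server

# Start the server with the renamed `max_new_tokens` argument
# (formerly `max_generated_tokens`). `predict` caps generation at
# prompt_length + max_new_tokens returned tokens.
run_server(
    checkpoint_dir=Path("checkpoints/EleutherAI/pythia-1b"),  # placeholder path
    max_new_tokens=100,
    accelerator="cuda",
    port=8000,
)
```

A client in another process would then POST to the server, e.g. `requests.post("http://127.0.0.1:8000/predict", json={"prompt": "..."})`, though the exact route and request schema depend on the LitServe configuration and `decode_request`.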